Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 74 additions & 28 deletions multimodal/dashinfer_vlm/vl_inference/utils/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
import torch
import glob
import warnings
from modelscope import snapshot_download
from transformers import Qwen2VLForConditionalGeneration, AutoConfig, AutoTokenizer
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
from tqdm import tqdm

from transformers import AutoConfig, AutoTokenizer, AutoProcessor

from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig

from safetensors.torch import safe_open
from dashinfer import allspark
from dashinfer.allspark.model_loader import HuggingFaceModel, ModelSerializerException
Expand Down Expand Up @@ -59,25 +63,58 @@ def load_model(
# the open-source model can be loaded by huggingface
try:
if not os.path.isdir(self.hf_model_path):
from modelscope import snapshot_download
self.hf_model_path = snapshot_download(self.hf_model_path)
self.torch_model = Qwen2VLForConditionalGeneration.from_pretrained(
self.hf_model_path,
trust_remote_code=self.trust_remote_code,
torch_dtype=dtype_to_torch_dtype(self.data_type),
device_map="cpu",
**kwargs,
).eval()
self.vit_config = Qwen2VLVisionConfig.from_pretrained(
self.hf_model_path,
trust_remote_code=True,
revision=None,
code_revision=None,
)
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path,
trust_remote_code=self.trust_remote_code,
**kwargs,

# Read config to determine model architecture
self.hf_model_config = AutoConfig.from_pretrained(
self.hf_model_path, trust_remote_code=self.trust_remote_code
)

if hasattr(self.hf_model_config, "architectures") and "Qwen2_5_VLForConditionalGeneration" in self.hf_model_config.architectures:
self.torch_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
self.hf_model_path,
trust_remote_code=self.trust_remote_code,
torch_dtype=dtype_to_torch_dtype(self.data_type),
device_map="cpu",
**kwargs,
).eval()
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path,
trust_remote_code=self.trust_remote_code,
**kwargs,
)
self.processor = AutoProcessor.from_pretrained(
self.hf_model_path,
trust_remote_code=self.trust_remote_code,
**kwargs,
)
self.vit_config = Qwen2_5_VLVisionConfig.from_pretrained(
self.hf_model_path,
trust_remote_code=True,
revision=None,
code_revision=None,
)
else:
self.torch_model = Qwen2VLForConditionalGeneration.from_pretrained(
self.hf_model_path,
trust_remote_code=self.trust_remote_code,
torch_dtype=dtype_to_torch_dtype(self.data_type),
device_map="cpu",
**kwargs,
).eval()
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path,
trust_remote_code=self.trust_remote_code,
**kwargs,
)
self.vit_config = Qwen2VLVisionConfig.from_pretrained(
self.hf_model_path,
trust_remote_code=True,
revision=None,
code_revision=None,
)
pass
except Exception as e:
print(
f"exception when load model: {self.hf_model_path} , exception: {e}"
Expand All @@ -102,10 +139,10 @@ def read_model_config(self):
self.hf_model_config = AutoConfig.from_pretrained(
self.hf_model_path, trust_remote_code=self.trust_remote_code
)
self.adapter = QWen2ConfigAdapter(self.hf_model_config)
self.as_model_config = self.adapter.model_config
if self.user_set_data_type is None:
self.data_type = self.adapter.get_model_data_type()
self.adapter = QWen2ConfigAdapter(self.hf_model_config)
self.as_model_config = self.adapter.model_config
if self.user_set_data_type is None:
self.data_type = self.adapter.get_model_data_type()
return self

def serialize(
Expand All @@ -127,17 +164,26 @@ def serialize(
onnx_trt_obj.export_onnx(onnxFile)
onnx_trt_obj.generate_trt_engine(onnxFile, self.vision_model_path)
elif self.vision_engine == "transformers":
visual_model = Qwen2VLForConditionalGeneration.from_pretrained(
if hasattr(self.hf_model_config, "architectures") and "Qwen2_5_VLForConditionalGeneration" in self.hf_model_config.architectures:
visual_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
self.hf_model_path,
trust_remote_code=self.trust_remote_code,
torch_dtype=dtype_to_torch_dtype(self.data_type),
device_map="cpu",
attn_implementation="flash_attention_2",
device_map="auto",
attn_implementation="sdpa",
).visual.eval()
else:
visual_model = Qwen2VLForConditionalGeneration.from_pretrained(
self.hf_model_path,
trust_remote_code=self.trust_remote_code,
torch_dtype=dtype_to_torch_dtype(self.data_type),
device_map="auto",
attn_implementation="sdpa",
).visual.eval()
self.vision_model_path = visual_model
else:
raise ValueError(f"unsupported engine {self.vision_engine}")

# Convert Allspark LLM
enable_quant = False
weight_only_quant=False
Expand Down
10 changes: 5 additions & 5 deletions multimodal/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
dashinfer@https://github.com/modelscope/dash-infer/releases/download/v2.0.0-rc3/dashinfer-2.0.0rc3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
av
numpy==1.24.3
requests==2.32.3
nvtx==0.2.10
transformers>=4.45.0
numpy>=1.24.3
requests>=2.32.3
nvtx>=0.2.10
transformers>=4.48.9
cachetools>=5.4.0
six
tiktoken
Expand All @@ -12,7 +12,7 @@ shortuuid
fastapi
pydantic_settings
uvicorn
cmake==3.22.6
cmake>=3.22.6
modelscope
aiohttp
onnx
Expand Down
Loading