diff --git a/multimodal/dashinfer_vlm/vl_inference/utils/model_loader.py b/multimodal/dashinfer_vlm/vl_inference/utils/model_loader.py
index 8af2ed253..f6f6aaa41 100644
--- a/multimodal/dashinfer_vlm/vl_inference/utils/model_loader.py
+++ b/multimodal/dashinfer_vlm/vl_inference/utils/model_loader.py
@@ -6,10 +6,14 @@
 import torch
 import glob
 import warnings
-from modelscope import snapshot_download
-from transformers import Qwen2VLForConditionalGeneration, AutoConfig, AutoTokenizer
-from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
 from tqdm import tqdm
+
+from transformers import AutoConfig, AutoTokenizer, AutoProcessor
+
+from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
+from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
+from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
+
 from safetensors.torch import safe_open
 from dashinfer import allspark
 from dashinfer.allspark.model_loader import HuggingFaceModel, ModelSerializerException
@@ -59,25 +63,58 @@ def load_model(
         # the open-source model can be loaded by huggingface
         try:
             if not os.path.isdir(self.hf_model_path):
+                from modelscope import snapshot_download
                 self.hf_model_path = snapshot_download(self.hf_model_path)
-            self.torch_model = Qwen2VLForConditionalGeneration.from_pretrained(
-                self.hf_model_path,
-                trust_remote_code=self.trust_remote_code,
-                torch_dtype=dtype_to_torch_dtype(self.data_type),
-                device_map="cpu",
-                **kwargs,
-            ).eval()
-            self.vit_config = Qwen2VLVisionConfig.from_pretrained(
-                self.hf_model_path,
-                trust_remote_code=True,
-                revision=None,
-                code_revision=None,
-            )
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                self.hf_model_path,
-                trust_remote_code=self.trust_remote_code,
-                **kwargs,
+
+            # Read config to determine model architecture
+            self.hf_model_config = AutoConfig.from_pretrained(
+                self.hf_model_path, trust_remote_code=self.trust_remote_code
             )
+
+            if hasattr(self.hf_model_config, "architectures") and "Qwen2_5_VLForConditionalGeneration" in self.hf_model_config.architectures:
+                self.torch_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=self.trust_remote_code,
+                    torch_dtype=dtype_to_torch_dtype(self.data_type),
+                    device_map="cpu",
+                    **kwargs,
+                ).eval()
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=self.trust_remote_code,
+                    **kwargs,
+                )
+                self.processor = AutoProcessor.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=self.trust_remote_code,
+                    **kwargs,
+                )
+                self.vit_config = Qwen2_5_VLVisionConfig.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=True,
+                    revision=None,
+                    code_revision=None,
+                )
+            else:
+                self.torch_model = Qwen2VLForConditionalGeneration.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=self.trust_remote_code,
+                    torch_dtype=dtype_to_torch_dtype(self.data_type),
+                    device_map="cpu",
+                    **kwargs,
+                ).eval()
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=self.trust_remote_code,
+                    **kwargs,
+                )
+                self.vit_config = Qwen2VLVisionConfig.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=True,
+                    revision=None,
+                    code_revision=None,
+                )
+            pass
         except Exception as e:
             print(
                 f"exception when load model: {self.hf_model_path} , exception: {e}"
@@ -102,10 +139,10 @@ def read_model_config(self):
         self.hf_model_config = AutoConfig.from_pretrained(
             self.hf_model_path, trust_remote_code=self.trust_remote_code
         )
-        self.adapter = QWen2ConfigAdapter(self.hf_model_config)
-        self.as_model_config = self.adapter.model_config
-        if self.user_set_data_type is None:
-            self.data_type = self.adapter.get_model_data_type()
+        self.adapter = QWen2ConfigAdapter(self.hf_model_config)
+        self.as_model_config = self.adapter.model_config
+        if self.user_set_data_type is None:
+            self.data_type = self.adapter.get_model_data_type()
         return self

     def serialize(
@@ -127,17 +164,26 @@ def serialize(
             onnx_trt_obj.export_onnx(onnxFile)
             onnx_trt_obj.generate_trt_engine(onnxFile, self.vision_model_path)
         elif self.vision_engine == "transformers":
-            visual_model = Qwen2VLForConditionalGeneration.from_pretrained(
+            if hasattr(self.hf_model_config, "architectures") and "Qwen2_5_VLForConditionalGeneration" in self.hf_model_config.architectures:
+                visual_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                 self.hf_model_path,
                 trust_remote_code=self.trust_remote_code,
                 torch_dtype=dtype_to_torch_dtype(self.data_type),
-                device_map="cpu",
-                attn_implementation="flash_attention_2",
+                    device_map="auto",
+                    attn_implementation="sdpa",
+                ).visual.eval()
+            else:
+                visual_model = Qwen2VLForConditionalGeneration.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=self.trust_remote_code,
+                    torch_dtype=dtype_to_torch_dtype(self.data_type),
+                    device_map="auto",
+                    attn_implementation="sdpa",
             ).visual.eval()
             self.vision_model_path = visual_model
         else:
             raise ValueError(f"unsupported engine {self.vision_engine}")
-
+        # Convert Allspark LLM
         enable_quant = False
         weight_only_quant=False

diff --git a/multimodal/requirements.txt b/multimodal/requirements.txt
index 5e87ce66b..c96f0909a 100644
--- a/multimodal/requirements.txt
+++ b/multimodal/requirements.txt
@@ -1,9 +1,9 @@
 dashinfer@https://github.com/modelscope/dash-infer/releases/download/v2.0.0-rc3/dashinfer-2.0.0rc3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
 av
-numpy==1.24.3
-requests==2.32.3
-nvtx==0.2.10
-transformers>=4.45.0
+numpy>=1.24.3
+requests>=2.32.3
+nvtx>=0.2.10
+transformers>=4.48.9
 cachetools>=5.4.0
 six
 tiktoken
@@ -12,7 +12,7 @@ shortuuid
 fastapi
 pydantic_settings
 uvicorn
-cmake==3.22.6
+cmake>=3.22.6
 modelscope
 aiohttp
 onnx
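
Note for reviewers: the core of this change is the `config.architectures` dispatch that both `load_model()` and `serialize()` now perform before choosing between the Qwen2-VL and Qwen2.5-VL classes. Below is a minimal, self-contained sketch of that pattern; `pick_vl_classes` is a hypothetical helper written for illustration (it is not part of the patch), and the model IDs in the trailing comment are examples only.

    from transformers import AutoConfig

    def pick_vl_classes(model_path: str, trust_remote_code: bool = False):
        # Mirror the branch added in load_model(): inspect the checkpoint's
        # `architectures` field and pick the matching model / vision-config classes.
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        archs = getattr(config, "architectures", None) or []
        if "Qwen2_5_VLForConditionalGeneration" in archs:
            return "Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLVisionConfig"
        return "Qwen2VLForConditionalGeneration", "Qwen2VLVisionConfig"

    # Example (illustrative model IDs):
    #   pick_vl_classes("Qwen/Qwen2.5-VL-7B-Instruct") -> Qwen2.5-VL classes
    #   pick_vl_classes("Qwen/Qwen2-VL-7B-Instruct")   -> Qwen2-VL classes

Keying off `architectures` (rather than, say, `model_type`) matches what the patch itself checks, so a checkpoint is routed to the Qwen2.5-VL classes exactly when its config declares that architecture; this in turn assumes a transformers version that ships `Qwen2_5_VL`, which is why requirements.txt raises the transformers floor.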