Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions optimum/exporters/executorch/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,24 @@ class VisionExportableModule(torch.nn.Module):
def __init__(self, model: torch.nn.Module):
    """Store the wrapped model and start with no pre-computed position data.

    Args:
        model: The eager model this module wraps; kept as ``self.model``.
    """
    super().__init__()
    self.model = model
    # Set to True by `_precompute_vision_positions`; `forward` checks it to decide
    # whether to take the pre-computed-position path.
    self._precomputed_pos = False

def _precompute_vision_positions(self, image_grid_thw: torch.Tensor):
"""Pre-compute position-related values for M-RoPE vision encoders (e.g. Qwen3-VL).

The visual encoder uses grid_thw values in data-dependent ops (torch.linspace,
repeat_interleave) that torch.export cannot trace. We compute them eagerly and
store as buffers so they become constants in the exported graph.
"""
visual = self.model.model.visual
with torch.no_grad():
self.register_buffer("_pos_embeds", visual.fast_pos_embed_interpolate(image_grid_thw))
self.register_buffer("_rotary_pos_emb", visual.rot_pos_emb(image_grid_thw))
cu = torch.repeat_interleave(
image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
self.register_buffer("_cu_seqlens", torch.nn.functional.pad(cu, (1, 0), value=0))
self._precomputed_pos = True

def prepare_export_inputs(self):
# 1. Get export inputs
Expand Down Expand Up @@ -73,15 +91,40 @@ def prepare_export_inputs(self):
)
export_inputs = processed_inputs["pixel_values"].to(dtype=self.model.dtype)

if "image_grid_thw" in processed_inputs:
self._precompute_vision_positions(processed_inputs["image_grid_thw"])

# 2. Get export dynamic shapes
dynamic_shapes = None # No batching for now.

return export_inputs, dynamic_shapes

def _forward_with_precomputed_pos(self, input_features: torch.FloatTensor):
"""Forward through the visual encoder using pre-computed position data."""
visual = self.model.model.visual
hidden_states = visual.patch_embed(input_features.type(visual.dtype))
hidden_states = hidden_states + self._pos_embeds

seq_len, _ = hidden_states.size()
hidden_states = hidden_states.reshape(seq_len, -1)
rotary = self._rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary, rotary), dim=-1)
position_embeddings = (emb.cos(), emb.sin())

for blk in visual.blocks:
hidden_states = blk(
hidden_states,
cu_seqlens=self._cu_seqlens,
position_embeddings=position_embeddings,
)
return visual.merger(hidden_states)

def forward(
self,
input_features: torch.FloatTensor,
):
if self._precomputed_pos:
return self._forward_with_precomputed_pos(input_features)
image_embeds = self.model.get_image_features(input_features)
if isinstance(image_embeds, list):
image_embeds = torch.stack(image_embeds)
Expand Down Expand Up @@ -274,6 +317,31 @@ def _register_custom_attention(self, exportable_module: torch.nn.Module):
# This handles both regular sdpa and one for sliding window/local attention
exportable_module.model.model.config._attn_implementation = "custom_sdpa"

def _register_mrope_hook(self):
"""Register a forward pre-hook on M-RoPE models (e.g. Qwen3-VL) to inject position_ids.

During text decoder export, only inputs_embeds and cache_position are provided (no
input_ids). M-RoPE models try to compute position_ids via get_rope_index which
requires input_ids and crashes. This hook injects position_ids derived from
cache_position so the model skips that code path entirely.

Returns the hook handle (or None if not applicable) so the caller can remove it.
"""
inner_model = getattr(self.model, "model", None)
if inner_model is None or not hasattr(inner_model, "get_rope_index"):
return None

def _inject_mrope_position_ids(module, args, kwargs):
if kwargs.get("position_ids") is None and kwargs.get("input_ids") is None:
inputs_embeds = kwargs.get("inputs_embeds")
cache_position = kwargs.get("cache_position")
if inputs_embeds is not None and cache_position is not None:
batch_size = inputs_embeds.shape[0]
kwargs["position_ids"] = cache_position.view(1, 1, -1).expand(3, batch_size, -1)
return args, kwargs

return inner_model.register_forward_pre_hook(_inject_mrope_position_ids, with_kwargs=True)

def export(
self,
) -> Dict[str, ExportedProgram]:
Expand Down Expand Up @@ -321,12 +389,22 @@ def export(
# Move inputs to the same device as the model
inputs_embeds = inputs_embeds.to(self.model.device)
cache_position = cache_position.to(self.model.device)

# M-RoPE models (e.g. Qwen3-VL) compute position_ids via get_rope_index which
# requires input_ids. During text decoder export only inputs_embeds is provided,
# so we inject position_ids derived from cache_position to skip that code path.
# For text-only decode, all 3 M-RoPE dimensions equal cache_position.
mrope_hook = self._register_mrope_hook()

exported_program = exportable_module.export(
inputs_embeds=inputs_embeds,
cache_position=cache_position,
dynamic_shapes=dynamic_shapes,
strict=True,
)
if mrope_hook is not None:
mrope_hook.remove()

# Apply RemoveTransposes pass to remove
# any back-to-back transpose ops that are not needed
# e.g. output of update_cache is transposed and
Expand Down
37 changes: 21 additions & 16 deletions optimum/exporters/executorch/tasks/multimodal_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os.path

import torchao
from transformers import AutoConfig, AutoModelForPreTraining, GenerationConfig
from transformers import AutoConfig, AutoModelForImageTextToText, AutoModelForPreTraining, GenerationConfig

from ..integrations import MultiModalTextToTextExportableModule
from ..quantization import quantize_model_
Expand Down Expand Up @@ -96,19 +96,19 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
if hasattr(config, "use_cache") and config.use_cache is False:
config.use_cache = True

# Using `AutoModelForPreTraining` since it usually routes to the correct model variant and there is no
# auto model class that captures both audio and image.
# The correct model variant we are looking for is <Model>ForConditionalGeneration, since it is the top-level
# model and thus will always contain the necessary model components. As an example of why this is needed,
# if you just use Gemma3Model instead of Gemma3ForConditionalGeneration, Gemma3Model (which is the decoder part)
# will not contain the LM head, which is only applied in the latter.
eager_model = AutoModelForPreTraining.from_pretrained(
model_name_or_path,
# We want the <Model>ForConditionalGeneration variant since it's the top-level model containing all
# necessary components (decoder + LM head + encoder). AutoModelForPreTraining works for some models
# (e.g. Gemma3) but not all (e.g. Qwen3-VL), so we fall back to AutoModelForImageTextToText.
from_pretrained_kwargs = dict(
device_map=device,
dtype=dtype,
config=config,
attn_implementation=attn_implementation,
)
try:
eager_model = AutoModelForPreTraining.from_pretrained(model_name_or_path, **from_pretrained_kwargs)
except ValueError:
eager_model = AutoModelForImageTextToText.from_pretrained(model_name_or_path, **from_pretrained_kwargs)
eager_model.generation_config = GenerationConfig(
use_cache=True,
cache_implementation=cache_implementation,
Expand All @@ -120,15 +120,20 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
},
)

# Find the modality (only one modality of either image/audio is supported at the moment).
if len(eager_model.input_modalities) != 2:
# Find the primary non-text modality. We pick the first of "image" or "audio" since those
# are the modalities supported downstream. Models like Qwen3-VL report ("image", "video", "text")
# but video shares the same visual encoder as image, so "image" is the right pick.
non_text_modalities = [m for m in eager_model.input_modalities if m != "text"]
modality = None
for candidate in ("image", "audio"):
if candidate in non_text_modalities:
modality = candidate
break
if modality is None:
raise AttributeError(
"Only one modality is supported for multimodal models at the moment. The input modalities must be either ['text', 'image'] or ['text, 'audio']"
f"No supported non-text modality found for {model_name_or_path}. "
f"Got modalities: {eager_model.input_modalities}. Expected 'image' or 'audio'."
)
for input_modality in eager_model.input_modalities:
if input_modality == "text":
continue
modality = input_modality
eager_encoder = eager_model.get_encoder(modality)

# Need to do this since apparently when nested modules (e.g. model.language_model) access the .property
Expand Down