Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 43 additions & 15 deletions optimum/exporters/executorch/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,10 @@ def _prepare_export_inputs(self):
sliding_window = getattr(self.config, "sliding_window", None)
if sliding_window is None:
sliding_window = self.metadata.get("sliding_window", float("inf"))
max_dim = min(max_seq_len, sliding_window) - 1
if self.use_custom_kv_cache and self.use_custom_sdpa:
max_dim = max_seq_len - 1
else:
max_dim = min(max_seq_len, sliding_window) - 1
seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim)
dynamic_shapes = {
"input_ids": {1: seq_len_dim},
Expand Down Expand Up @@ -499,22 +502,47 @@ def export(
f"Exporting using input_ids({input_ids.shape})={input_ids}, cache_position({cache_position.shape})={cache_position}, dynamic_shapes={dynamic_shapes}, strict={strict}"
)

exportable_module = TorchExportableModuleForDecoderOnlyLM(
self.model,
)
self._register_custom_attention(exportable_module)

if self.use_custom_kv_cache:
from optimum.executorch.attentions.custom_kv_cache import (
replace_with_et_custom_kv_cache,
max_seq_len = self.metadata.get("get_max_seq_len")
sliding_window = self.metadata.get("sliding_window")
use_ring_cache = self.use_custom_kv_cache and self.use_custom_sdpa and sliding_window is not None

if use_ring_cache:
from transformers.integrations.executorch import TorchExportableModuleWithHybridCache
from optimum.executorch.attentions.custom_kv_cache import ETCustomHybridCache

# Bypass TorchExportableModuleWithHybridCache.__init__ — it calls StaticCache which
# caps sliding layers to sliding_window via StaticSlidingWindowLayer, baking a
# <= sliding_window guard into torch.export. Instead, directly install
# ETCustomHybridCache sized to max_seq_len, then patch sliding layer max_cache_len
# so get_mask_sizes() returns max_seq_len during tracing.
exportable_module_inner = TorchExportableModuleWithHybridCache.__new__(TorchExportableModuleWithHybridCache)
torch.nn.Module.__init__(exportable_module_inner)
exportable_module_inner.model = self.model
exportable_module_inner.cache = ETCustomHybridCache(
config=self.model.config, max_batch_size=1, max_cache_len=max_seq_len,
device=self.model.device, dtype=self.model.dtype,
)
for layer in exportable_module_inner.cache.layers:
if layer.is_sliding:
layer.max_cache_len = max_seq_len
for i in range(len(exportable_module_inner.cache.kv_cache)):
exportable_module_inner.register_buffer(
f"key_cache_{i}", exportable_module_inner.cache.kv_cache[i].k_cache, persistent=False)
exportable_module_inner.register_buffer(
f"value_cache_{i}", exportable_module_inner.cache.kv_cache[i].v_cache, persistent=False)
if exportable_module_inner.cache.layers[i].is_sliding:
exportable_module_inner.register_buffer(
f"cache_positions_{i}",
exportable_module_inner.cache.kv_cache[i].cache_positions_manager.cache_positions,
persistent=False,
)
exportable_module = TorchExportableModuleForDecoderOnlyLM.__new__(TorchExportableModuleForDecoderOnlyLM)
torch.nn.Module.__init__(exportable_module)
exportable_module.model = exportable_module_inner
else:
exportable_module = TorchExportableModuleForDecoderOnlyLM(self.model)

replace_with_et_custom_kv_cache(
exportable_module.model,
self.model.config,
self.model.generation_config,
self.model.dtype,
)
self._register_custom_attention(exportable_module)

with torch.no_grad():
exported_program = exportable_module.export(
Expand Down