Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 164 additions & 6 deletions optimum/executorch/attentions/custom_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,56 @@
except ImportError:
raise ImportError("ExecutorTorch is not installed. Please install it to use Custom Cache.")

try:
from executorch.examples.models.llama.source_transformation.attention_sink import (
CachePositionsManagerWithSink,
_create_causal_mask_for_attention_sink,
        # NOTE(review): requires the ExecuTorch-side attention-sink change to merge first.
)
except ImportError:
CachePositionsManagerWithSink = None
_create_causal_mask_for_attention_sink = None


class CustomRingKVCacheWithSink(CustomKVCache):
    """Ring buffer KV cache with attention sink — preserves first sink_size tokens.

    The first ``sink_size`` cache slots are never evicted; the remaining
    ``max_context_length - sink_size`` slots act as a sliding-window ring
    buffer. Position bookkeeping is delegated to CachePositionsManagerWithSink.
    """

    def __init__(
        self,
        max_batch_size: int,
        max_context_length: int,
        n_heads: int,
        head_dim: int,
        sink_size: int,
        dtype=torch.float32,
    ):
        # Raise instead of assert so the availability check survives `python -O`
        # (asserts are stripped under optimized bytecode).
        if CachePositionsManagerWithSink is None:
            raise ImportError(
                "CachePositionsManagerWithSink not available. "
                "Install ExecuTorch with attention sink support."
            )
        super().__init__(max_batch_size, max_context_length, n_heads, head_dim, dtype)
        self.sink_size = sink_size
        # Number of non-sink slots that behave as a ring buffer.
        self.window_size = max_context_length - sink_size
        self.is_ring_buffer = True
        self.cache_positions_manager = CachePositionsManagerWithSink(
            max_context_length, sink_size
        )

    def update(self, input_pos, k_val, v_val):
        """Write k_val/v_val at input_pos and return the updated cache tensors.

        The physical write indices are computed by the positions manager so
        that sink slots are pinned and the rest wraps around the ring buffer.
        """
        # assumes k_val is (batch, n_heads, seq, head_dim) — the transpose
        # exposes the sequence dimension; TODO confirm against CustomKVCache.
        seq_len = k_val.transpose(1, 2).size(1)
        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
            input_pos, seq_len
        ).unsqueeze(0)
        return super().update(input_pos, k_val, v_val, indices)

    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
        """Build the attention mask honoring both causality and sink retention."""
        return _create_causal_mask_for_attention_sink(
            self.cache_positions_manager.cache_positions,
            self.window_size,
            self.sink_size,
            start_pos,
            seq_len,
        )


class ETCustomStaticCache(StaticCache):
"""
Expand Down Expand Up @@ -333,33 +383,141 @@ def get_layer_cache(self, layer_idx: int):
return self.kv_cache[layer_idx]


def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
class ETCustomAttentionSinkCache(StaticCache):
    """
    KV Cache with attention sink for ExecuTorch. All layers use ring buffer
    with sink token preservation.

    Sink tokens (first sink_size positions) are never evicted from cache.
    Remaining positions use a ring buffer for sliding window.
    """

    def __init__(
        self,
        config,
        max_batch_size: int,
        max_cache_len: Optional[int] = None,
        sink_size: int = 4,
        device: Union[torch.device, str, None] = None,
        dtype: torch.dtype = torch.float32,
    ):
        """Build one CustomRingKVCacheWithSink per decoder layer.

        Args:
            config: Model configuration; head_dim and num_key_value_heads are
                read from it with fallbacks derived from hidden_size and
                num_attention_heads.
            max_batch_size: Batch size the cache tensors are allocated for.
            max_cache_len: Total cache length (sink slots + sliding window).
            sink_size: Number of leading positions that are never evicted.
            device: Device forwarded to the StaticCache base class.
            dtype: Dtype of the cache tensors.
        """
        super().__init__(
            config=config,
            max_batch_size=max_batch_size,
            max_cache_len=max_cache_len,
            device=device,
            dtype=dtype,
        )
        # Fall back to hidden_size // num_attention_heads when head_dim is
        # absent, and to num_attention_heads when there are no separate KV heads.
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        num_heads = getattr(config, "num_key_value_heads", None) or config.num_attention_heads
        # NOTE(review): early_initialization is assumed to populate self.layers
        # with per-layer geometry (max_batch_size, max_cache_len, num_heads,
        # head_dim) — confirm against the StaticCache implementation in use.
        self.early_initialization(
            batch_size=max_batch_size, num_heads=num_heads, head_dim=head_dim, dtype=dtype, device=device
        )
        self.sink_size = sink_size
        # Last cache_position tensor seen in update(); consumed by the
        # attention-sink SDPA path to compute the ring-buffer mask.
        self.cache_position = None

        # One ring-buffer-with-sink cache per transformer layer; a ModuleList
        # so the underlying k/v buffers are registered on this module.
        self.kv_cache = torch.nn.ModuleList()
        for layer in self.layers:
            layer_cache = CustomRingKVCacheWithSink(
                max_batch_size=layer.max_batch_size,
                max_context_length=layer.max_cache_len,
                n_heads=layer.num_heads,
                head_dim=layer.head_dim,
                sink_size=sink_size,
                dtype=dtype,
            )
            self.kv_cache.append(layer_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write new K/V states into layer ``layer_idx``'s ring buffer.

        Requires ``cache_kwargs["cache_position"]`` (a tensor of write
        positions); it is stashed on ``self.cache_position`` for the
        attention-sink SDPA hook. Returns the layer's full (k, v) cache tensors.
        """
        assert cache_kwargs is not None
        cache_position = cache_kwargs.get("cache_position")
        assert cache_position is not None
        assert isinstance(cache_position, torch.Tensor)
        self.cache_position = cache_position

        layer_cache = self.kv_cache[layer_idx]
        k_out, v_out = layer_cache.update(
            input_pos=cache_position,
            k_val=key_states,
            v_val=value_states,
        )
        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return one past the highest position recorded in the layer's cache."""
        if layer_idx is None:
            layer_idx = 0
        return self.kv_cache[layer_idx].cache_positions_manager.cache_positions.max().item() + 1

    def get_layer_cache(self, layer_idx: int):
        """Return the CustomRingKVCacheWithSink for the given layer index."""
        return self.kv_cache[layer_idx]


def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype, attention_sink=None):
    """
    Replace all KV caches in the module with ETCustomStaticCache, ETCustomHybridCache,
    or ETCustomAttentionSinkCache.
    This modifies the model in place.

    Args:
        module: The module to modify
        config: The model configuration
        generation_config: Generation config supplying cache_config (batch size, device, ...)
        cache_dtype: Dtype for the replacement cache tensors
        attention_sink: Optional tuple (sink_size, window_size) for attention sink mode

    Returns:
        The modified module
    """
    # Recursively replace KV caches; diff residue previously left two
    # conflicting return statements here — only the attention_sink-aware
    # call is correct.
    return _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype, attention_sink)


def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype, attention_sink=None):
"""
Helper function to recursively replace KV caches in the module.

Args:
module: The module to modify
config: The model configuration
attention_sink: Optional tuple (sink_size, window_size) for attention sink mode

Returns:
The modified module
"""
# Attention sink mode: replace static_cache with ETCustomAttentionSinkCache
if attention_sink is not None:
sink_size, window_size = attention_sink
cache_size = sink_size + window_size

if hasattr(module, "static_cache"):
sink_cache = ETCustomAttentionSinkCache(
config=config,
max_batch_size=generation_config.cache_config.get("batch_size"),
max_cache_len=cache_size,
sink_size=sink_size,
device=generation_config.cache_config.get("device"),
dtype=cache_dtype,
)
if getattr(module, "replace_cache", None) is not None:
module.replace_cache(sink_cache)
else:
module.static_cache = sink_cache
for i in range(len(sink_cache.kv_cache)):
setattr(module, f"key_cache_{i}", sink_cache.kv_cache[i].k_cache)
setattr(module, f"value_cache_{i}", sink_cache.kv_cache[i].v_cache)
module.register_buffer(
f"cache_positions_{i}",
sink_cache.kv_cache[i].cache_positions_manager.cache_positions,
persistent=False,
)
else:
raise ValueError(
"Attention sink requires 'static_cache' attribute on module"
)
return module

# Check if module has static_cache (TorchExportableModuleWithStaticCache)
if hasattr(module, "static_cache"):
assert isinstance(module.static_cache, StaticCache), f"Expected StaticCache, got {type(module.static_cache)}"
Expand Down
53 changes: 53 additions & 0 deletions optimum/executorch/attentions/custom_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,56 @@ def _custom_sdpa_for_ring_kv_cache(
)

return _custom_sdpa_for_ring_kv_cache


def get_custom_sdpa_for_attention_sink(
    exportable_module: torch.nn.Module,
) -> Callable:
    """Build an SDPA callable for attention-sink models.

    Every layer is treated as a ring buffer whose mask preserves the sink
    tokens; the returned closure reads the layer cache from
    ``exportable_module.model.static_cache`` at call time.
    """

    from optimum.executorch.attentions.custom_kv_cache import (
        CustomRingKVCacheWithSink,
        ETCustomAttentionSinkCache,
    )

    def _custom_sdpa_for_attention_sink(
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Union[torch.Tensor, "BlockMask"],  # noqa
        position_ids: Optional[torch.Tensor] = None,
        scaling: Optional[float] = None,
        softcap: Optional[float] = None,
        head_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
        layer_idx = module.layer_idx
        assert layer_idx is not None, "layer_idx is not set."

        # Resolve the per-layer ring-buffer cache from the exported module.
        sink_cache = exportable_module.model.static_cache
        assert isinstance(sink_cache, ETCustomAttentionSinkCache), (
            f"Expected ETCustomAttentionSinkCache, got {type(sink_cache)}"
        )
        ring_cache = sink_cache.get_layer_cache(layer_idx)
        assert isinstance(ring_cache, CustomRingKVCacheWithSink), (
            f"Expected CustomRingKVCacheWithSink, got {type(ring_cache)}"
        )

        # Swap the incoming mask for one that honors both causality and
        # the pinned sink tokens at the current decode position.
        seqlen = query.shape[2]
        input_pos = sink_cache.cache_position[0].item()
        attention_mask = ring_cache.create_causal_mask_for_ring_buffer(input_pos, seqlen)

        # Route downstream through the sliding-window code path.
        kwargs["is_sliding"] = True
        return custom_sdpa_with_start_pos_forward(
            module,
            query,
            key,
            value,
            attention_mask,
            position_ids,
            scaling,
            softcap,
            head_mask,
            **kwargs,
        )

    return _custom_sdpa_for_attention_sink
20 changes: 14 additions & 6 deletions optimum/exporters/executorch/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
from transformers.masking_utils import AttentionMaskInterface
from transformers.modeling_utils import AttentionInterface

from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache, sdpa_mask_passthrough
from optimum.executorch.attentions.custom_sdpa import (
get_custom_sdpa_for_attention_sink,
get_custom_sdpa_for_ring_kv_cache,
sdpa_mask_passthrough,
)
from optimum.executorch.attentions.whisper_attention import WhisperCrossAttention

from .utils import apply_chat_template_with_fallback, save_config_to_constant_methods
Expand Down Expand Up @@ -412,13 +416,15 @@ def __init__(
use_custom_kv_cache=False,
use_custom_sdpa=False,
disable_dynamic_shapes=False,
attention_sink=None,
):
super().__init__()
self.model = model
self.config = model.config
self.use_custom_kv_cache = use_custom_kv_cache
self.use_custom_sdpa = use_custom_sdpa
self.disable_dynamic_shapes = disable_dynamic_shapes
self.attention_sink = attention_sink
self.metadata = save_config_to_constant_methods(
model.config,
generation_config=getattr(model, "generation_config", None),
Expand Down Expand Up @@ -471,16 +477,17 @@ def _register_custom_attention(self, exportable_module: torch.nn.Module):
from transformers.modeling_utils import AttentionInterface

if self.use_custom_sdpa:
if self.use_custom_kv_cache:
if self.attention_sink is not None:
_custom_sdpa = get_custom_sdpa_for_attention_sink(exportable_module)
AttentionInterface.register("custom_sdpa_attention_sink", _custom_sdpa)
AttentionMaskInterface.register("custom_sdpa_attention_sink", sdpa_mask_passthrough)
exportable_module.model.model.config._attn_implementation = "custom_sdpa_attention_sink"
elif self.use_custom_kv_cache:
_custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module)
AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_passthrough)
# Manually set the attention implementation to custom_sdpa_ring_kv_cache
# This handles both regular sdpa and one for sliding window/local attention
exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
else:
# Manually set the attention implementation to custom_sdpa_ring_kv_cache
# This handles both regular sdpa and one for sliding window/local attention
exportable_module.model.model.config._attn_implementation = "custom_sdpa"

def export(
Expand All @@ -506,6 +513,7 @@ def export(
self.model.config,
self.model.generation_config,
self.model.dtype,
attention_sink=self.attention_sink,
)

with torch.no_grad():
Expand Down