Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..utils import deprecate, logging
from ..utils.import_utils import is_torch_npu_available, is_xformers_available
from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph

from flash_attn import flash_attn_func

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -1880,24 +1880,26 @@ def __call__(

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

if head_dim <= 128 and attention_mask is None:
query = query.reshape(batch_size, -1, attn.heads, head_dim)
key = key.reshape(batch_size, -1, attn.heads, head_dim)
value = value.reshape(batch_size, -1, attn.heads, head_dim)
hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
else:
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

hidden_states = hidden_states.to(query.dtype)

# linear proj
Expand All @@ -1915,7 +1917,6 @@ def __call__(

return hidden_states


class StableAudioAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
Expand Down