diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 93bf58e526ec..a23665970950 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -23,7 +23,11 @@
 from ..utils import deprecate, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
-
+
+try:
+    from flash_attn import flash_attn_func
+except ImportError:
+    flash_attn_func = None
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -1880,24 +1884,46 @@ def __call__(
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+        # flash_attn_func takes no arbitrary attention mask and its kernels
+        # require fp16/bf16 inputs; keep the conservative head_dim <= 128
+        # guard and fall back to SDPA whenever any condition is not met.
+        if (
+            flash_attn_func is not None
+            and head_dim <= 128
+            and attention_mask is None
+            and query.dtype in (torch.float16, torch.bfloat16)
+        ):
+            # flash_attn_func expects (batch, seq_len, heads, head_dim)
+            query = query.reshape(batch_size, -1, attn.heads, head_dim)
+            key = key.reshape(batch_size, -1, attn.heads, head_dim)
+            value = value.reshape(batch_size, -1, attn.heads, head_dim)
+
+            # QK norm must be applied on this path as well
+            if attn.norm_q is not None:
+                query = attn.norm_q(query)
+            if attn.norm_k is not None:
+                key = attn.norm_k(key)
+
+            hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
+            hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
+        else:
+            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+            if attn.norm_q is not None:
+                query = attn.norm_q(query)
+            if attn.norm_k is not None:
+                key = attn.norm_k(key)
+
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+
+            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
         hidden_states = hidden_states.to(query.dtype)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
@@ -1915,7 +1941,6 @@ def __call__(
 
         return hidden_states
 
-
 class StableAudioAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
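
Note: flash_attn_func consumes and returns (batch, seq_len, heads, head_dim) tensors, which is why the flash path above needs only reshapes, while the SDPA path transposes to the (batch, heads, seq_len, head_dim) layout that F.scaled_dot_product_attention expects. A minimal parity sketch of the two paths, assuming flash-attn is installed and a CUDA device is available (shapes and tolerances are illustrative, not taken from this PR):

    # Check that the flash-attn path matches the SDPA path when no mask is used.
    import torch
    import torch.nn.functional as F
    from flash_attn import flash_attn_func

    batch, seq_len, heads, head_dim = 2, 64, 8, 64
    q, k, v = (
        torch.randn(batch, seq_len, heads, head_dim, device="cuda", dtype=torch.float16)
        for _ in range(3)
    )

    # flash_attn_func: (batch, seq_len, heads, head_dim) in and out
    out_flash = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)

    # SDPA: (batch, heads, seq_len, head_dim), so transpose in and back out
    out_sdpa = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        attn_mask=None, dropout_p=0.0, is_causal=False,
    ).transpose(1, 2)

    torch.testing.assert_close(out_flash, out_sdpa, atol=2e-3, rtol=2e-3)

The two results agree only up to fp16 tolerance, since the kernels accumulate in different orders; the dtype guard in the patch matters because flash-attn raises on fp32 inputs rather than silently falling back.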