diff --git a/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/bwd.py b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/bwd.py index e5f7fbdec1..3650af18b7 100755 --- a/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/bwd.py +++ b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/bwd.py @@ -827,7 +827,8 @@ def _bwd_dq_inner_split( dp = tl.dot(do, vT) if ENABLE_DROPOUT: - dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + scaled_mask = dropout_mask.to(dp.dtype) * dropout_scale + dp = dp * scaled_mask # ds delta_i = Di[:, None] @@ -967,7 +968,8 @@ def _bwd_dkdv_inner_split( # dV if ENABLE_DROPOUT: - pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + scaled_mask = dropout_mask.to(pT.dtype) * dropout_scale + pT_dropout = pT * scaled_mask dv = tl.dot(pT_dropout.to(do.type.element_ty), do, acc=dv) else: dv = tl.dot(pT.to(do.type.element_ty), do, acc=dv) @@ -982,7 +984,8 @@ def _bwd_dkdv_inner_split( dpT = tl.dot(v, tl.trans(do)) if ENABLE_DROPOUT: - dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + scaled_mask = dropout_mask.to(dpT.dtype) * dropout_scale + dpT = dpT * scaled_mask delta_i = Di[None, :] dsT = pT * (dpT - delta_i) @@ -1149,7 +1152,8 @@ def _bwd_dkdvdq_inner_atomic( # dV if ENABLE_DROPOUT: - pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + scaled_mask = dropout_mask.to(pT.dtype) * dropout_scale + pT_dropout = pT * scaled_mask dv = tl.dot(pT_dropout.to(do.type.element_ty), do, acc=dv) else: dv = tl.dot(pT.to(do.type.element_ty), do, acc=dv) @@ -1164,7 +1168,8 @@ def _bwd_dkdvdq_inner_atomic( dpT = tl.dot(v, tl.trans(do)) if ENABLE_DROPOUT: - dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + scaled_mask = dropout_mask.to(dpT.dtype) * dropout_scale + dpT = dpT * scaled_mask delta_i = Di[None, :] dsT = pT * (dpT - delta_i) @@ -2920,7 +2925,8 @@ def _bwd_dkdv_inner( do = tl.load(do_ptrs, mask=mask_do, other=0.0) # Compute dV. if ENABLE_DROPOUT: - pT_dropout = pT * dropout_mask.to(pT.dtype) * dropout_scale + scaled_mask = dropout_mask.to(pT.dtype) * dropout_scale + pT_dropout = pT * scaled_mask dv = tl.dot(pT_dropout.to(do.type.element_ty), do, acc=dv) else: dv = tl.dot(pT.to(do.type.element_ty), do, acc=dv) @@ -2936,7 +2942,8 @@ def _bwd_dkdv_inner( else: dpT = tl.dot(v, tl.trans(do)) if ENABLE_DROPOUT: - dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + scaled_mask = dropout_mask.to(dpT.dtype) * dropout_scale + dpT = dpT * scaled_mask delta_i = Di[None, :] dsT = pT * (dpT - delta_i) if IS_FP8: @@ -3098,7 +3105,8 @@ def _bwd_dq_inner( else: dp = tl.dot(do, vT) if ENABLE_DROPOUT: - dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + scaled_mask = dropout_mask.to(dp.dtype) * dropout_scale + dp = dp * scaled_mask delta_i = Di[:, None] ds = p * (dp - delta_i) # Compute dQ.