Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,8 @@ def _bwd_dq_inner_split(
dp = tl.dot(do, vT)

if ENABLE_DROPOUT:
dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale
scaled_mask = dropout_mask.to(dp.dtype) * dropout_scale
dp = dp * scaled_mask

# ds
delta_i = Di[:, None]
Expand Down Expand Up @@ -967,7 +968,8 @@ def _bwd_dkdv_inner_split(

# dV
if ENABLE_DROPOUT:
pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale
scaled_mask = dropout_mask.to(pT.dtype) * dropout_scale
pT_dropout = pT * scaled_mask
dv = tl.dot(pT_dropout.to(do.type.element_ty), do, acc=dv)
else:
dv = tl.dot(pT.to(do.type.element_ty), do, acc=dv)
Expand All @@ -982,7 +984,8 @@ def _bwd_dkdv_inner_split(
dpT = tl.dot(v, tl.trans(do))

if ENABLE_DROPOUT:
dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale
scaled_mask = dropout_mask.to(dpT.dtype) * dropout_scale
dpT = dpT * scaled_mask

delta_i = Di[None, :]
dsT = pT * (dpT - delta_i)
Expand Down Expand Up @@ -1149,7 +1152,8 @@ def _bwd_dkdvdq_inner_atomic(

# dV
if ENABLE_DROPOUT:
pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale
scaled_mask = dropout_mask.to(pT.dtype) * dropout_scale
pT_dropout = pT * scaled_mask
dv = tl.dot(pT_dropout.to(do.type.element_ty), do, acc=dv)
else:
dv = tl.dot(pT.to(do.type.element_ty), do, acc=dv)
Expand All @@ -1164,7 +1168,8 @@ def _bwd_dkdvdq_inner_atomic(
dpT = tl.dot(v, tl.trans(do))

if ENABLE_DROPOUT:
dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale
scaled_mask = dropout_mask.to(dpT.dtype) * dropout_scale
dpT = dpT * scaled_mask

delta_i = Di[None, :]
dsT = pT * (dpT - delta_i)
Expand Down Expand Up @@ -2920,7 +2925,8 @@ def _bwd_dkdv_inner(
do = tl.load(do_ptrs, mask=mask_do, other=0.0)
# Compute dV.
if ENABLE_DROPOUT:
pT_dropout = pT * dropout_mask.to(pT.dtype) * dropout_scale
scaled_mask = dropout_mask.to(pT.dtype) * dropout_scale
pT_dropout = pT * scaled_mask
dv = tl.dot(pT_dropout.to(do.type.element_ty), do, acc=dv)
else:
dv = tl.dot(pT.to(do.type.element_ty), do, acc=dv)
Expand All @@ -2936,7 +2942,8 @@ def _bwd_dkdv_inner(
else:
dpT = tl.dot(v, tl.trans(do))
if ENABLE_DROPOUT:
dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale
scaled_mask = dropout_mask.to(dpT.dtype) * dropout_scale
dpT = dpT * scaled_mask
delta_i = Di[None, :]
dsT = pT * (dpT - delta_i)
if IS_FP8:
Expand Down Expand Up @@ -3098,7 +3105,8 @@ def _bwd_dq_inner(
else:
dp = tl.dot(do, vT)
if ENABLE_DROPOUT:
dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale
scaled_mask = dropout_mask.to(dp.dtype) * dropout_scale
dp = dp * scaled_mask
delta_i = Di[:, None]
ds = p * (dp - delta_i)
# Compute dQ.
Expand Down
Loading