# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any
import jax
import triton
import triton.language as tl
from jax import numpy as jnp
from jaxtyping import Array, Bool, Float, Int
from triton import Config
from ejkernel.callib import triton_call
from ejkernel.ops import BwdParams, FwdParams
from ejkernel.utils import dtype_index, get_sharding, get_strides
from ._utilities import attention_pack_with_static_shape, calc_bias_strides, padded_load
def config_prune_kernel(
    configs: list[Config],
    named_args: dict[str, Any],
    **kwargs: Any,
) -> list[Config]:
    """Prune autotuning configurations for the backward-pass kernel.

    A configuration is kept only when neither of its query-side tile sizes
    (``BLOCK_M1``, ``BLOCK_M2``) exceeds ``named_args["QSeq"]`` and neither of
    its key-side tile sizes (``BLOCK_N1``, ``BLOCK_N2``) exceeds
    ``named_args["KSeq"]``.  When every candidate is rejected, two small
    32x32 fallback configurations are returned so the autotuner always has
    something to run.

    Args:
        configs: Candidate triton autotuning configurations.
        named_args: Kernel arguments by name; ``QSeq`` and ``KSeq`` are read.
        **kwargs: Accepted for compatibility with the pruning hook; unused.

    Returns:
        list[Config]: The surviving configurations in their original order,
        or the fallback pair when none survive.
    """
    kept_configs = [
        config
        for config in configs
        if max(config.kwargs["BLOCK_M1"], config.kwargs["BLOCK_M2"]) <= named_args["QSeq"]
        and max(config.kwargs["BLOCK_N1"], config.kwargs["BLOCK_N2"]) <= named_args["KSeq"]
    ]
    if kept_configs:
        return kept_configs

    # Nothing fits the problem size: fall back to the smallest supported tiles,
    # once with 4 warps and once with 2 (same order as before).
    return [
        Config(
            {
                "BLOCK_M1": 32,
                "BLOCK_N1": 32,
                "BLOCK_M2": 32,
                "BLOCK_N2": 32,
            },
            num_warps=num_warps,
            num_stages=0,
        )
        for num_warps in (4, 2)
    ]
@triton.autotune(
    configs=[
        Config({"BLOCK_M": 16}, num_warps=4, num_stages=0),
        Config({"BLOCK_M": 32}, num_warps=4, num_stages=0),
        Config({"BLOCK_M": 64}, num_warps=4, num_stages=0),
        Config({"BLOCK_M": 128}, num_warps=4, num_stages=0),
    ],
    key=["CQSeq", "DRuntime"],
)
@triton.jit
def _attn_bwd_preprocess(
    Po,
    Do,
    stride_oz,
    stride_om,
    stride_oh,
    stride_dez,
    stride_dem,
    stride_deh,
    nheads,
    QSeq,
    max_seqlen_q_rounded,
    cum_seqlens_q,
    headdim,
    CQSeq,
    DRuntime,
    Delta,
    VARLEN: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
):
    """Preprocessing kernel for the backward pass.

    For one tile of ``BLOCK_M`` query rows of one (batch, head) pair, computes
    ``delta[m] = sum_d O[m, d] * dO[m, d]`` in float32 and writes it to
    ``Delta`` at ``off_zh * max_seqlen_q_rounded + m``.

    Program axis 0 selects the row tile; axis 1 is the flattened
    ``batch * nheads`` index.  In ``VARLEN`` mode the sequence start and
    length are read from ``cum_seqlens_q`` and the batch offset into ``Po`` /
    ``Do`` is dropped (``off_z = 0``); otherwise the length is ``QSeq``.

    Rows at or beyond the sequence length and columns at or beyond ``headdim``
    are loaded as zero, so they contribute zero to ``delta``.  The store
    itself is unmasked.  ``CQSeq`` and ``DRuntime`` are only autotune keys.
    """
    start_m = tl.program_id(0)
    off_zh = tl.program_id(1)
    off_z = off_zh // nheads
    off_h = off_zh % nheads
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    if VARLEN:
        # Single load of the sequence start; it serves both as the row offset
        # into the tensors and as the base for the length computation
        # (same pattern as `_attn_bwd`).
        cu_seq_start_q = tl.load(cum_seqlens_q + off_z)
        actual_seqlen_q = tl.load(cum_seqlens_q + off_z + 1) - cu_seq_start_q
        off_z = 0
    else:
        cu_seq_start_q = 0
        actual_seqlen_q = QSeq

    o_ptrs = (
        Po
        + off_z * stride_oz
        + off_h * stride_oh
        + cu_seq_start_q * stride_om
        + offs_m[:, None] * stride_om
        + offs_d[None, :]
    )
    do_ptrs = (
        Do
        + off_z * stride_dez
        + off_h * stride_deh
        + cu_seq_start_q * stride_dem
        + offs_m[:, None] * stride_dem
        + offs_d[None, :]
    )

    # Out-of-range rows / head-dim columns read as 0 so they do not affect the sum.
    mask = (offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim)
    o = tl.load(o_ptrs, mask=mask, other=0.0).to(tl.float32)
    do = tl.load(do_ptrs, mask=mask, other=0.0).to(tl.float32)
    delta = tl.sum(o * do, axis=1)

    # Indexed with the original flattened (batch, head) id, not the VARLEN-reset off_z.
    tl.store(Delta + off_zh * max_seqlen_q_rounded + offs_m, delta)
@triton.jit
def _attn_bwd_dkdv(
    index_start_m,
    k,
    v,
    dk,
    dv,
    M,
    D,
    offs_m,
    offs_n,
    offs_d,
    q_ptrs,
    bias_ptrs,
    dropout_offs,
    do_ptrs,
    softmax_scale,
    stride_qm,
    stride_bm,
    stride_dom,
    actual_seqlen_q,
    actual_seqlen_k,
    fully_masked_lines,
    headdim,
    q_segment_ids_ptr,
    kv_segment_ids_ptr,
    stride_qsm,
    stride_ksn,
    window_left,
    window_right,
    logits_soft_cap,
    softmax_aux_ptrs,
    num_sinks,
    MASKED: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BIAS_ON: tl.constexpr,
    BOOL_BIAS: tl.constexpr,
    USE_DROPOUT: tl.constexpr,
    PAD_ROWS: tl.constexpr,
    PAD_COLS: tl.constexpr,
    HEADS_PADDED: tl.constexpr,
    SLIDING: tl.constexpr,
    SOFTCAP: tl.constexpr,
    USE_SINKS: tl.constexpr,
    USE_SEGMENTS: tl.constexpr,
):
    """Accumulate the dK/dV contribution of one query block into a K/V tile.

    Loads the query block starting at row ``index_start_m`` (plus its ``dO``,
    and bias when ``BIAS_ON``), recomputes the attention probabilities ``p``
    against the already-loaded ``k`` tile, and updates the accumulators:

    * ``dv += p^T @ dO``
    * ``dk += ds^T @ q`` with ``ds = p * (dO @ v^T - D[row]) * jac``

    ``offs_m`` is relative to the query block (``index_start_m`` is added
    here); ``offs_n`` is compared directly against ``actual_seqlen_k``.
    ``M`` and ``D`` are read at the absolute query rows.

    Within this function ``dropout_offs`` is only advanced, and
    ``softmax_aux_ptrs`` / ``num_sinks`` / ``USE_SINKS`` are not read.

    Returns:
        The updated ``(dk, dv)`` accumulators.
    """
    # Fill value used where a boolean bias disables a position.
    BIG_NEG: tl.constexpr = -2147483648
    # NOTE: despite the name, this value is log2(e) = 1 / ln(2); it rescales
    # the logits so the exponential can be evaluated with `tl.exp2`.
    LN2: tl.constexpr = 1.44269504089
    # Move the block-relative pointers to the query block being processed.
    q_ptrs = q_ptrs + index_start_m * stride_qm
    do_ptrs = do_ptrs + index_start_m * stride_dom
    if BIAS_ON:
        bias_ptrs = bias_ptrs + index_start_m * stride_bm
    if USE_DROPOUT:
        dropout_offs += index_start_m * actual_seqlen_k
    # Absolute query row indices of this block.
    offs_m_curr = index_start_m + offs_m
    q = padded_load(
        q_ptrs,
        offs_m_curr,
        offs_d,
        PA0=PAD_ROWS or HEADS_PADDED,
        PA1=PAD_ROWS or HEADS_PADDED,
        LA0=actual_seqlen_q,
        LA1=headdim,
    )
    # Per-row value subtracted from the base-2 logits below
    # (presumably the forward pass's log-sum-exp -- confirm against the fwd kernel).
    me_i = tl.load(M + offs_m_curr)
    if BIAS_ON:
        bias = padded_load(
            bias_ptrs,
            offs_m_curr,
            offs_n,
            PA0=PAD_ROWS or HEADS_PADDED,
            PA1=PAD_ROWS or HEADS_PADDED,
            LA0=actual_seqlen_q,
            LA1=actual_seqlen_k,
        )
    # Unscaled logits for this (query block, key tile) pair.
    qk = tl.dot(q, tl.trans(k)).to(tl.float32)
    if BIAS_ON:
        if BOOL_BIAS:
            qk = tl.where(bias, qk, BIG_NEG)
        else:
            # Bias is divided by the scale because `softmax_scale` is applied to qk later.
            qk += bias / softmax_scale
    # Key column indices shifted so that the last key aligns with the last query.
    offs_n_causal = offs_n - actual_seqlen_k + actual_seqlen_q
    if MASKED:
        if PAD_COLS:
            if IS_CAUSAL:
                qk = tl.where(
                    tl.minimum(actual_seqlen_q - 1, offs_m_curr)[:, None] >= offs_n_causal[None, :],
                    qk,
                    float("-inf"),
                )
            else:
                qk = tl.where(actual_seqlen_q - 1 >= offs_n_causal[None, :], qk, float("-inf"))
        elif IS_CAUSAL:
            qk = tl.where(offs_m_curr[:, None] >= offs_n_causal[None, :], qk, float("-inf"))
    if SLIDING:
        # Keep only keys within [row - window_left, row + window_right] after alignment.
        shift = actual_seqlen_k - actual_seqlen_q
        j_aligned = offs_n[None, :] - shift
        i_idx = offs_m_curr[:, None]
        in_window = (j_aligned >= (i_idx - window_left)) & (j_aligned <= (i_idx + window_right))
        qk = tl.where(in_window, qk, float("-inf"))
    # Keep `attn_mask` shape stable across control-flow (see fwd kernel).
    # The mask mirrors every condition applied to `qk` above and is used to
    # zero `p` explicitly after the exponential.
    attn_mask = (offs_m_curr[:, None] < actual_seqlen_q) & (offs_n[None, :] >= 0)
    if PAD_COLS:
        attn_mask = attn_mask & (offs_n[None, :] < actual_seqlen_k)
    if MASKED:
        if PAD_COLS:
            if IS_CAUSAL:
                attn_mask = attn_mask & (tl.minimum(actual_seqlen_q - 1, offs_m_curr)[:, None] >= offs_n_causal[None, :])
            else:
                attn_mask = attn_mask & ((actual_seqlen_q - 1) >= offs_n_causal[None, :])
        elif IS_CAUSAL:
            attn_mask = attn_mask & (offs_m_curr[:, None] >= offs_n_causal[None, :])
    if SLIDING:
        attn_mask = attn_mask & in_window
    if USE_SEGMENTS:
        # Out-of-range positions load id -1; a query only attends to keys with
        # the same, non-negative segment id.
        q_ids = tl.load(q_segment_ids_ptr + offs_m_curr * stride_qsm, mask=offs_m_curr < actual_seqlen_q, other=-1)
        kv_ids = tl.load(kv_segment_ids_ptr + offs_n * stride_ksn, mask=offs_n < actual_seqlen_k, other=-1)
        seg_mask = (q_ids[:, None] == kv_ids[None, :]) & (q_ids[:, None] >= 0)
        attn_mask = attn_mask & seg_mask
    if BIAS_ON and BOOL_BIAS:
        attn_mask = attn_mask & bias
    if SOFTCAP:
        # logits -> cap * tanh(scale * qk / cap); `jac` is d(capped logit)/d(qk).
        s = qk * softmax_scale
        x = s / logits_soft_cap
        exp_2x = tl.exp(2.0 * x)
        tanh_x = (exp_2x - 1.0) / (exp_2x + 1.0)
        qk_after = (logits_soft_cap * tanh_x) * LN2
        jac = softmax_scale * (1.0 - tanh_x * tanh_x)
    else:
        qk_after = qk * (softmax_scale * LN2)
        jac = softmax_scale
    tl.debug_barrier()
    # Recomputed attention probabilities.
    p = tl.exp2(qk_after - me_i[:, None])
    if MASKED and (fully_masked_lines > 0):
        # Query rows below `fully_masked_lines` contribute nothing.
        p = tl.where(offs_m_curr[:, None] < fully_masked_lines, 0, p)
    p = tl.where(attn_mask, p, 0.0)
    do = padded_load(
        do_ptrs,
        offs_m_curr,
        offs_d,
        PA0=PAD_ROWS,
        PA1=HEADS_PADDED,
        LA0=actual_seqlen_q,
        LA1=headdim,
    ).to(tl.float32)
    dv += tl.dot(tl.trans(p), do)
    dp = tl.dot(do, tl.trans(v.to(tl.float32)))
    # Per-row term subtracted from dp (loaded from `D` at the absolute query rows).
    Di = tl.load(D + offs_m_curr)
    ds = (p * (dp - Di[:, None]) * jac).to(q.dtype)
    dk += tl.dot(tl.trans(ds), q)
    return dk, dv
@triton.jit
def _attn_bwd_block_dkdv(
    index_start_n,
    Q,
    K,
    V,
    QSeg,
    KSeg,
    B,
    Dropout,
    Do,
    Dk,
    Dv,
    M,
    D,
    softmax_scale,
    stride_qm,
    stride_kn,
    stride_vn,
    stride_bm,
    stride_dom,
    stride_dkn,
    stride_dvn,
    actual_seqlen_q,
    actual_seqlen_k,
    headdim,
    stride_qsm,
    stride_ksn,
    window_left,
    window_right,
    logits_soft_cap,
    softmax_aux_ptrs,
    num_sinks,
    IS_CAUSAL: tl.constexpr,
    BIAS_ON: tl.constexpr,
    BOOL_BIAS: tl.constexpr,
    USE_DROPOUT: tl.constexpr,
    PAD_COLS: tl.constexpr,
    HEADS_PADDED: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    SLIDING: tl.constexpr,
    SOFTCAP: tl.constexpr,
    USE_SINKS: tl.constexpr,
    USE_SEGMENTS: tl.constexpr,
):
    """Compute dK and dV for the K/V tile that starts at ``index_start_n``.

    Loads the ``BLOCK_N`` key/value rows once, then walks the query blocks
    from ``index_begin_m`` to ``actual_seqlen_q`` in steps of ``BLOCK_M``,
    accumulating into float32 ``dk`` / ``dv`` via ``_attn_bwd_dkdv``:

    * the first ``num_masked_blocks`` query blocks (causal only) are
      processed with ``MASKED=True``;
    * the remaining query blocks are processed with ``MASKED=False``.

    Both phases pass ``PAD_ROWS=True``.  The accumulated tile is finally
    stored to ``Dk`` / ``Dv`` with row/column masks chosen from ``PAD_COLS``
    and ``HEADS_PADDED``.

    NOTE(review): the early ``return`` happens before any store, so in that
    case this tile of ``Dk`` / ``Dv`` is left untouched -- confirm the caller
    provides zero-initialised outputs.
    """
    # With causal masking, query rows before the aligned diagonal cannot attend
    # to this key tile, so start at the first query block that can.
    index_begin_m = max(index_start_n + actual_seqlen_q - actual_seqlen_k, 0) if IS_CAUSAL else 0
    # Round down to a BLOCK_M boundary.
    index_begin_m = (index_begin_m // BLOCK_M) * BLOCK_M
    index_end_m = actual_seqlen_q
    fully_masked_lines = (actual_seqlen_q - actual_seqlen_k) if IS_CAUSAL else 0
    # Nothing to do when the tile lies outside the query or key range.
    if (index_begin_m >= actual_seqlen_q) or (index_start_n >= actual_seqlen_k):
        return
    offs_n = index_start_n + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Query-side pointers are block-relative; `_attn_bwd_dkdv` adds the block offset.
    q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
    dk_ptrs = Dk + (offs_n[:, None] * stride_dkn + offs_d[None, :])
    dv_ptrs = Dv + (offs_n[:, None] * stride_dvn + offs_d[None, :])
    do_ptrs = Do + (offs_m[:, None] * stride_dom + offs_d[None, :])
    bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n[None, :]) if BIAS_ON else None
    dropout_offs = Dropout + offs_m[:, None] * actual_seqlen_k + offs_n[None, :] if USE_DROPOUT else None
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    k = padded_load(k_ptrs, offs_n, offs_d, PA0=PAD_COLS, PA1=HEADS_PADDED, LA0=actual_seqlen_k, LA1=headdim)
    v = padded_load(v_ptrs, offs_n, offs_d, PA0=PAD_COLS, PA1=HEADS_PADDED, LA0=actual_seqlen_k, LA1=headdim)
    # `fr`: query row aligned with the last key column of this tile (clamped at 0).
    fr = max(0, index_start_n + BLOCK_N - 1 + actual_seqlen_q - actual_seqlen_k)
    # `fb`: that row, capped at the sequence length and rounded up to a BLOCK_M boundary.
    fb = BLOCK_M * ((min(fr, actual_seqlen_q) + BLOCK_M - 1) // BLOCK_M)
    # Query blocks between `index_begin_m` and `fb` need the masked code path.
    num_masked_blocks = (fb - index_begin_m) // BLOCK_M if IS_CAUSAL else 0
    index_next_start_m = index_begin_m
    if num_masked_blocks > 0:
        for _ in range(0, num_masked_blocks):
            dk, dv = _attn_bwd_dkdv(
                index_next_start_m,
                k,
                v,
                dk,
                dv,
                M,
                D,
                offs_m,
                offs_n,
                offs_d,
                q_ptrs,
                bias_ptrs,
                dropout_offs,
                do_ptrs,
                softmax_scale,
                stride_qm,
                stride_bm,
                stride_dom,
                actual_seqlen_q,
                actual_seqlen_k,
                fully_masked_lines,
                headdim,
                QSeg,
                KSeg,
                stride_qsm,
                stride_ksn,
                window_left,
                window_right,
                logits_soft_cap,
                softmax_aux_ptrs,
                num_sinks,
                MASKED=True,
                IS_CAUSAL=IS_CAUSAL,
                BIAS_ON=BIAS_ON,
                BOOL_BIAS=BOOL_BIAS,
                USE_DROPOUT=USE_DROPOUT,
                PAD_ROWS=True,
                PAD_COLS=PAD_COLS,
                HEADS_PADDED=HEADS_PADDED,
                SLIDING=SLIDING,
                SOFTCAP=SOFTCAP,
                USE_SINKS=USE_SINKS,
                USE_SEGMENTS=USE_SEGMENTS,
            )
            index_next_start_m += BLOCK_M
    # Remaining query blocks, processed without the causal/pad `MASKED` branch.
    if index_next_start_m < index_end_m:
        for index_start_m in range(index_next_start_m, index_end_m, BLOCK_M):
            dk, dv = _attn_bwd_dkdv(
                index_start_m,
                k,
                v,
                dk,
                dv,
                M,
                D,
                offs_m,
                offs_n,
                offs_d,
                q_ptrs,
                bias_ptrs,
                dropout_offs,
                do_ptrs,
                softmax_scale,
                stride_qm,
                stride_bm,
                stride_dom,
                actual_seqlen_q,
                actual_seqlen_k,
                fully_masked_lines,
                headdim,
                QSeg,
                KSeg,
                stride_qsm,
                stride_ksn,
                window_left,
                window_right,
                logits_soft_cap,
                softmax_aux_ptrs,
                num_sinks,
                MASKED=False,
                IS_CAUSAL=IS_CAUSAL,
                BIAS_ON=BIAS_ON,
                BOOL_BIAS=BOOL_BIAS,
                USE_DROPOUT=USE_DROPOUT,
                PAD_ROWS=True,
                PAD_COLS=PAD_COLS,
                HEADS_PADDED=HEADS_PADDED,
                SLIDING=SLIDING,
                SOFTCAP=SOFTCAP,
                USE_SINKS=USE_SINKS,
                USE_SEGMENTS=USE_SEGMENTS,
            )
    # Store the tile, masking only the dimensions that can run out of bounds.
    if HEADS_PADDED:
        if PAD_COLS:
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < actual_seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < actual_seqlen_k) & (offs_d[None, :] < headdim))
        else:
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
    else:
        if PAD_COLS:
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < actual_seqlen_k)
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < actual_seqlen_k)
        else:
            tl.store(dk_ptrs, dk)
            tl.store(dv_ptrs, dv)
@triton.jit
def _attn_bwd_dq(
    index_start_n,
    q,
    dq,
    do,
    me_i,
    de_i,
    offs_m,
    offs_n,
    offs_d,
    k_ptrs,
    v_ptrs,
    bias_ptrs,
    dropout_offs,
    softmax_scale,
    dropout_prob,
    dropout_seed,
    stride_kn,
    stride_vn,
    actual_seqlen_q,
    actual_seqlen_k,
    headdim,
    q_segment_ids_ptr,
    kv_segment_ids_ptr,
    stride_qsm,
    stride_ksn,
    window_left,
    window_right,
    logits_soft_cap,
    softmax_aux_ptrs,
    num_sinks,
    MASKED: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BIAS_ON: tl.constexpr,
    BOOL_BIAS: tl.constexpr,
    USE_DROPOUT: tl.constexpr,
    PAD_COLS: tl.constexpr,
    HEADS_PADDED: tl.constexpr,
    SLIDING: tl.constexpr,
    SOFTCAP: tl.constexpr,
    USE_SINKS: tl.constexpr,
    USE_SEGMENTS: tl.constexpr,
):
    """Accumulate the dQ contribution of one key/value block into a query tile.

    Loads the key/value block starting at column ``index_start_n`` (and bias
    when ``BIAS_ON``), recomputes the attention probabilities ``p`` against
    the already-loaded ``q`` tile, and updates the accumulator:

    * ``dq += ds @ k`` with ``ds = p * (dO @ v^T - de_i) * jac``

    ``q``, ``do``, ``me_i`` and ``de_i`` are passed in as already-loaded
    values.  ``offs_m`` is compared directly against ``actual_seqlen_q``;
    ``offs_n`` is relative to the key block (``index_start_n`` is added here).

    Within this function ``dropout_offs`` is only advanced, and
    ``dropout_prob`` / ``dropout_seed`` / ``softmax_aux_ptrs`` /
    ``num_sinks`` / ``USE_SINKS`` are not read.

    Returns:
        The updated ``dq`` accumulator.
    """
    # Fill value used where a boolean bias disables a position.
    BIG_NEG: tl.constexpr = -2147483648
    # NOTE: despite the name, this value is log2(e) = 1 / ln(2); it rescales
    # the logits so the exponential can be evaluated with `tl.exp2`.
    LN2: tl.constexpr = 1.44269504089
    # Move the block-relative pointers to the key/value block being processed.
    k_ptrs = k_ptrs + index_start_n * stride_kn
    v_ptrs = v_ptrs + index_start_n * stride_vn
    # Absolute key column indices of this block.
    offs_n_curr = index_start_n + offs_n
    if BIAS_ON:
        bias_ptrs += index_start_n
    if USE_DROPOUT:
        dropout_offs += index_start_n
    k = padded_load(k_ptrs, offs_n_curr, offs_d, PA0=PAD_COLS, PA1=HEADS_PADDED, LA0=actual_seqlen_k, LA1=headdim)
    v = padded_load(v_ptrs, offs_n_curr, offs_d, PA0=PAD_COLS, PA1=HEADS_PADDED, LA0=actual_seqlen_k, LA1=headdim)
    if BIAS_ON:
        bias = padded_load(
            bias_ptrs, offs_m, offs_n_curr, PA0=True, PA1=PAD_COLS, LA0=actual_seqlen_q, LA1=actual_seqlen_k
        )
    # Unscaled logits for this (query tile, key block) pair.
    qk = tl.dot(q, tl.trans(k))
    if BIAS_ON:
        if BOOL_BIAS:
            qk = tl.where(bias, qk, BIG_NEG)
        else:
            # Bias is divided by the scale because `softmax_scale` is applied to qk later.
            qk += bias / softmax_scale
    # Key column indices shifted so that the last key aligns with the last query.
    offs_n_causal = offs_n_curr - actual_seqlen_k + actual_seqlen_q
    if MASKED:
        if PAD_COLS:
            if IS_CAUSAL:
                qk = tl.where(
                    tl.minimum(actual_seqlen_q - 1, offs_m)[:, None] >= offs_n_causal[None, :],
                    qk,
                    float("-inf"),
                )
            else:
                qk = tl.where(actual_seqlen_q - 1 >= offs_n_causal[None, :], qk, float("-inf"))
        elif IS_CAUSAL:
            qk = tl.where(offs_m[:, None] >= offs_n_causal[None, :], qk, float("-inf"))
    if SLIDING:
        # Keep only keys within [row - window_left, row + window_right] after alignment.
        shift = actual_seqlen_k - actual_seqlen_q
        j_aligned = offs_n_curr[None, :] - shift
        i_idx = offs_m[:, None]
        in_window = (j_aligned >= (i_idx - window_left)) & (j_aligned <= (i_idx + window_right))
        qk = tl.where(in_window, qk, float("-inf"))
    # Keep `attn_mask` shape stable across control-flow (see fwd kernel).
    # The mask mirrors every condition applied to `qk` above and is used to
    # zero `p` explicitly after the exponential.
    attn_mask = (offs_m[:, None] < actual_seqlen_q) & (offs_n_curr[None, :] >= 0)
    if PAD_COLS:
        attn_mask = attn_mask & (offs_n_curr[None, :] < actual_seqlen_k)
    if MASKED:
        if PAD_COLS:
            if IS_CAUSAL:
                attn_mask = attn_mask & (tl.minimum(actual_seqlen_q - 1, offs_m)[:, None] >= offs_n_causal[None, :])
            else:
                attn_mask = attn_mask & ((actual_seqlen_q - 1) >= offs_n_causal[None, :])
        elif IS_CAUSAL:
            attn_mask = attn_mask & (offs_m[:, None] >= offs_n_causal[None, :])
    if SLIDING:
        attn_mask = attn_mask & in_window
    if USE_SEGMENTS:
        # Out-of-range positions load id -1; a query only attends to keys with
        # the same, non-negative segment id.
        q_ids = tl.load(q_segment_ids_ptr + offs_m * stride_qsm, mask=offs_m < actual_seqlen_q, other=-1)
        kv_ids = tl.load(kv_segment_ids_ptr + offs_n_curr * stride_ksn, mask=offs_n_curr < actual_seqlen_k, other=-1)
        seg_mask = (q_ids[:, None] == kv_ids[None, :]) & (q_ids[:, None] >= 0)
        attn_mask = attn_mask & seg_mask
    if BIAS_ON and BOOL_BIAS:
        attn_mask = attn_mask & bias
    if SOFTCAP:
        # logits -> cap * tanh(scale * qk / cap); `jac` is d(capped logit)/d(qk).
        s = qk * softmax_scale
        x = s / logits_soft_cap
        exp_2x = tl.exp(2.0 * x)
        tanh_x = (exp_2x - 1.0) / (exp_2x + 1.0)
        qk_after = (logits_soft_cap * tanh_x) * LN2
        jac = softmax_scale * (1.0 - tanh_x * tanh_x)
    else:
        qk_after = qk * (softmax_scale * LN2)
        jac = softmax_scale
    tl.debug_barrier()
    # Recomputed attention probabilities, zeroed wherever the mask is false.
    p = tl.exp2(qk_after - me_i[:, None])
    p = tl.where(attn_mask, p, 0.0)
    dp = tl.dot(do, tl.trans(v.to(tl.float32)))
    ds = (p * (dp - de_i[:, None]) * jac).to(q.dtype)
    dq += tl.dot(ds, k)
    return dq
@triton.jit
def _attn_bwd_block_dq(
    index_start_m,
    Q,
    K,
    V,
    QSeg,
    KSeg,
    B,
    Dropout,
    Do,
    Dq,
    M,
    D,
    softmax_scale,
    dropout_prob,
    dropout_seed,
    stride_qm,
    stride_kn,
    stride_vn,
    stride_bm,
    stride_dom,
    stride_dqm,
    actual_seqlen_q,
    actual_seqlen_k,
    headdim,
    stride_qsm,
    stride_ksn,
    window_left,
    window_right,
    logits_soft_cap,
    softmax_aux_ptrs,
    num_sinks,
    VARLEN: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BIAS_ON: tl.constexpr,
    BOOL_BIAS: tl.constexpr,
    USE_DROPOUT: tl.constexpr,
    PAD_ROWS: tl.constexpr,
    HEADS_PADDED: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_N: tl.constexpr,
    SLIDING: tl.constexpr,
    SOFTCAP: tl.constexpr,
    USE_SINKS: tl.constexpr,
    USE_SEGMENTS: tl.constexpr,
):
    """Compute dQ for the query tile that starts at ``index_start_m``.

    Loads the ``BLOCK_M`` query rows, their ``dO`` rows and the per-row
    ``M`` / ``D`` values once, then sweeps the key/value blocks in steps of
    ``BLOCK_N``, accumulating into a float32 ``dq`` via ``_attn_bwd_dq``:

    * the first ``nb_full_blocks`` key blocks are processed with
      ``MASKED=False`` and ``PAD_COLS=False``;
    * the remaining key blocks up to ``index_end_n`` are processed with
      ``MASKED=True`` and a per-block ``pad_cols`` flag.

    Rows below ``fully_masked_lines`` are zeroed, and the tile is stored to
    ``Dq`` with masks chosen from ``PAD_ROWS`` and ``HEADS_PADDED``.

    NOTE(review): every early ``return`` happens before any store, so in
    those cases this tile of ``Dq`` is left untouched -- confirm the caller
    provides zero-initialised outputs.
    """
    if IS_CAUSAL:
        # Last key column (exclusive) any row of this query tile may attend to.
        index_end_n = min(
            actual_seqlen_k - actual_seqlen_q + index_start_m + BLOCK_M,
            actual_seqlen_k,
        )
        if index_end_n < 0:
            return
    else:
        index_end_n = actual_seqlen_k
    fully_masked_lines = actual_seqlen_q - actual_seqlen_k if IS_CAUSAL else 0
    # True when every row of this tile lies below `fully_masked_lines`.
    mask_reached = fully_masked_lines >= index_start_m + BLOCK_M
    if (index_start_m >= actual_seqlen_q) or mask_reached:
        return
    # Here `offs_m` holds absolute query rows; `offs_n` is block-relative.
    offs_m = tl.arange(0, BLOCK_M) + index_start_m
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
    dq_ptrs = Dq + (offs_m[:, None] * stride_dqm + offs_d[None, :])
    do_ptrs = Do + (offs_m[:, None] * stride_dom + offs_d[None, :])
    if BIAS_ON:
        bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n[None, :])
    else:
        bias_ptrs = None
    if USE_DROPOUT:
        dropout_offs = Dropout + offs_m[:, None] * actual_seqlen_k + offs_n[None, :]
    else:
        dropout_offs = None
    dq = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    q = padded_load(q_ptrs, offs_m, offs_d, PA0=PAD_ROWS, PA1=HEADS_PADDED, LA0=actual_seqlen_q, LA1=headdim)
    do = padded_load(do_ptrs, offs_m, offs_d, PA0=PAD_ROWS, PA1=HEADS_PADDED, LA0=actual_seqlen_q, LA1=headdim).to(
        tl.float32
    )
    me_i = tl.load(M + offs_m)
    de_i = tl.load(D + offs_m)
    uneven_n = actual_seqlen_k % BLOCK_N != 0
    attention_padding = VARLEN & uneven_n
    # First key column at which some row of the tile needs masking; every key
    # block that ends at or before it can take the unmasked path.
    if IS_CAUSAL:
        first_masked_col = index_start_m + 1 + actual_seqlen_k - actual_seqlen_q
    elif attention_padding:
        first_masked_col = actual_seqlen_k
    else:
        first_masked_col = index_end_n
    nb_full_blocks = first_masked_col // BLOCK_N
    index_next_start_n = 0
    if nb_full_blocks > 0:
        for _ in range(0, nb_full_blocks):
            # Compiler hint: the block start is a multiple of BLOCK_N.
            index_next_start_n = tl.multiple_of(index_next_start_n, BLOCK_N)
            dq = _attn_bwd_dq(
                index_next_start_n,
                q,
                dq,
                do,
                me_i,
                de_i,
                offs_m,
                offs_n,
                offs_d,
                k_ptrs,
                v_ptrs,
                bias_ptrs,
                dropout_offs,
                softmax_scale,
                dropout_prob,
                dropout_seed,
                stride_kn,
                stride_vn,
                actual_seqlen_q,
                actual_seqlen_k,
                headdim,
                QSeg,
                KSeg,
                stride_qsm,
                stride_ksn,
                window_left,
                window_right,
                logits_soft_cap,
                softmax_aux_ptrs,
                num_sinks,
                IS_CAUSAL=IS_CAUSAL,
                BIAS_ON=BIAS_ON,
                BOOL_BIAS=BOOL_BIAS,
                USE_DROPOUT=USE_DROPOUT,
                MASKED=False,
                PAD_COLS=False,
                HEADS_PADDED=HEADS_PADDED,
                SLIDING=SLIDING,
                SOFTCAP=SOFTCAP,
                USE_SINKS=USE_SINKS,
                USE_SEGMENTS=USE_SEGMENTS,
            )
            index_next_start_n += BLOCK_N
    # Remaining key blocks need the masked path (diagonal and/or ragged tail).
    if index_next_start_n < index_end_n:
        for index_start_n in range(index_next_start_n, index_end_n, BLOCK_N):
            # Bounds-check key columns when KSeq is not a multiple of BLOCK_N, or
            # in VARLEN mode when this block extends past the sequence length.
            pad_cols = (not EVEN_N) or (VARLEN and (index_start_n + BLOCK_N > actual_seqlen_k))
            dq = _attn_bwd_dq(
                index_start_n,
                q,
                dq,
                do,
                me_i,
                de_i,
                offs_m,
                offs_n,
                offs_d,
                k_ptrs,
                v_ptrs,
                bias_ptrs,
                dropout_offs,
                softmax_scale,
                dropout_prob,
                dropout_seed,
                stride_kn,
                stride_vn,
                actual_seqlen_q,
                actual_seqlen_k,
                headdim,
                QSeg,
                KSeg,
                stride_qsm,
                stride_ksn,
                window_left,
                window_right,
                logits_soft_cap,
                softmax_aux_ptrs,
                num_sinks,
                IS_CAUSAL=IS_CAUSAL,
                BIAS_ON=BIAS_ON,
                BOOL_BIAS=BOOL_BIAS,
                USE_DROPOUT=USE_DROPOUT,
                MASKED=True,
                PAD_COLS=pad_cols,
                HEADS_PADDED=HEADS_PADDED,
                SLIDING=SLIDING,
                SOFTCAP=SOFTCAP,
                USE_SINKS=USE_SINKS,
                USE_SEGMENTS=USE_SEGMENTS,
            )
    # Query rows below `fully_masked_lines` get a zero gradient.
    if fully_masked_lines > 0:
        dq = tl.where(offs_m[:, None] < fully_masked_lines, 0, dq)
    # Store the tile, masking only the dimensions that can run out of bounds.
    if HEADS_PADDED:
        if PAD_ROWS:
            tl.store(dq_ptrs, dq, mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim))
        else:
            tl.store(dq_ptrs, dq, mask=offs_d[None, :] < headdim)
    else:
        if PAD_ROWS:
            tl.store(dq_ptrs, dq, mask=offs_m[:, None] < actual_seqlen_q)
        else:
            tl.store(dq_ptrs, dq)
@triton.autotune(
    configs=[
        Config(
            {"BLOCK_M1": 16, "BLOCK_N1": 16, "BLOCK_M2": 16, "BLOCK_N2": 16},
            num_warps=2,
            num_stages=0,
        ),
        Config(
            {"BLOCK_M1": 32, "BLOCK_N1": 16, "BLOCK_M2": 16, "BLOCK_N2": 32},
            num_warps=2,
            num_stages=0,
        ),
        Config(
            {"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 32},
            num_warps=2,
            num_stages=0,
        ),
        Config(
            {"BLOCK_M1": 64, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64},
            num_warps=2,
            num_stages=0,
        ),
        Config(
            {"BLOCK_M1": 16, "BLOCK_N1": 16, "BLOCK_M2": 16, "BLOCK_N2": 16},
            num_warps=4,
            num_stages=0,
        ),
        Config(
            {"BLOCK_M1": 32, "BLOCK_N1": 16, "BLOCK_M2": 16, "BLOCK_N2": 32},
            num_warps=4,
            num_stages=0,
        ),
        Config(
            {"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 32},
            num_warps=4,
            num_stages=0,
        ),
        Config(
            {"BLOCK_M1": 64, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64},
            num_warps=4,
            num_stages=0,
        ),
    ],
    key=[
        "CQSeq",
        "CKSeq",
        "DRuntime",
        "VARLEN",
        "USE_DROPOUT",
        "IS_CAUSAL",
        "BIAS_ON",
        "BLOCK_HEADDIM",
        "SLIDING",
    ],
    # Drop configs whose tiles exceed QSeq/KSeq before benchmarking.
    prune_configs_by={"early_config_prune": config_prune_kernel},
)
@triton.heuristics(
    {
        # Divisibility flags let the kernel skip bounds checks on evenly tiled inputs.
        "EVEN_M1": lambda args: args["QSeq"] % args["BLOCK_M1"] == 0,
        "EVEN_N1": lambda args: args["KSeq"] % args["BLOCK_N1"] == 0,
        "EVEN_M2": lambda args: args["QSeq"] % args["BLOCK_M2"] == 0,
        "EVEN_N2": lambda args: args["KSeq"] % args["BLOCK_N2"] == 0,
        "HEADS_PADDED": lambda args: args["headdim"] != args["BLOCK_HEADDIM"],
        # Number of K/V tiles; also the pid boundary between dK/dV and dQ programs.
        "NUM_BLOCKS_KV": lambda args: math.ceil(args["KSeq"] / args["BLOCK_N1"]),
    }
)
@triton.jit
def _attn_bwd(
    Q,
    K,
    V,
    QSeg,
    KSeg,
    B,
    Do,
    M,
    D,
    softmax_scale,
    dropout_prob,
    dropout_seed,
    stride_qz,
    stride_qm,
    stride_qh,
    stride_kz,
    stride_kn,
    stride_kh,
    stride_vz,
    stride_vn,
    stride_vh,
    stride_qsz,
    stride_qsm,
    stride_ksz,
    stride_ksn,
    stride_bz,
    stride_bm,
    stride_bh,
    stride_doz,
    stride_dom,
    stride_doh,
    stride_dqz,
    stride_dqm,
    stride_dqh,
    stride_dkz,
    stride_dkn,
    stride_dkh,
    stride_dvz,
    stride_dvn,
    stride_dvh,
    nheads_q,
    num_repeats,
    window_left,
    window_right,
    QSeq,
    cum_seqlens_q,
    KSeq,
    cum_seqlens_k,
    seqlen_q_rounded,
    headdim,
    CQSeq,
    CKSeq,
    DRuntime,
    logits_soft_cap,
    softmax_aux,
    num_sinks,
    Dq,
    Dk,
    Dv,
    VARLEN: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BIAS_ON: tl.constexpr,
    BOOL_BIAS: tl.constexpr,
    USE_SEGMENTS: tl.constexpr,
    USE_DROPOUT: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M1: tl.constexpr,
    EVEN_N1: tl.constexpr,
    EVEN_M2: tl.constexpr,
    EVEN_N2: tl.constexpr,
    NUM_BLOCKS_KV: tl.constexpr,
    HEADS_PADDED: tl.constexpr,
    BLOCK_M1: tl.constexpr,
    BLOCK_N1: tl.constexpr,
    BLOCK_M2: tl.constexpr,
    BLOCK_N2: tl.constexpr,
    SLIDING: tl.constexpr,
    SOFTCAP: tl.constexpr,
    USE_SINKS: tl.constexpr,
):
    """Entry-point kernel for the flash-attention backward pass.

    Program axis 1 is the flattened ``batch * nheads_q`` index; the K/V head
    is ``off_head_q // num_repeats``.  Program axis 0 partitions the work:

    * ``pid < NUM_BLOCKS_KV``: this program owns the K/V tile ``pid`` (tile
      size ``BLOCK_N1``) and computes dK/dV via ``_attn_bwd_block_dkdv``
      using the ``BLOCK_M1`` / ``BLOCK_N1`` tiling;
    * otherwise: this program owns the query tile ``pid - NUM_BLOCKS_KV``
      and computes dQ via ``_attn_bwd_block_dq`` using the ``BLOCK_M2`` /
      ``BLOCK_N2`` tiling.

    Before dispatching, all base pointers are advanced to the current
    batch/head (and, in ``VARLEN`` mode, to the sequence start read from
    ``cum_seqlens_q`` / ``cum_seqlens_k``, with the batch offset dropped).

    ``CQSeq``, ``CKSeq`` and ``DRuntime`` are only autotune keys, and
    ``EVEN_M1`` is not read in this body.
    """
    pid = tl.program_id(0)
    off_zh = tl.program_id(1)
    off_z = off_zh // nheads_q
    off_head_q = off_zh % nheads_q
    # Several query heads may share one K/V head.
    off_head_kv = off_head_q // num_repeats
    if VARLEN:
        cu_seq_start_q = tl.load(cum_seqlens_q + off_z)
        cu_seq_start_k = tl.load(cum_seqlens_k + off_z)
        actual_seqlen_q = tl.load(cum_seqlens_q + off_z + 1) - cu_seq_start_q
        actual_seqlen_k = tl.load(cum_seqlens_k + off_z + 1) - cu_seq_start_k
        # Sequences are addressed through the cumulative offsets, not a batch stride.
        off_z = 0
    else:
        cu_seq_start_q = 0
        cu_seq_start_k = 0
        actual_seqlen_q = QSeq
        actual_seqlen_k = KSeq
    # Advance every base pointer to this (batch, head, sequence start).
    Q += off_z * stride_qz + off_head_q * stride_qh + cu_seq_start_q * stride_qm
    K += off_z * stride_kz + off_head_kv * stride_kh + cu_seq_start_k * stride_kn
    V += off_z * stride_vz + off_head_kv * stride_vh + cu_seq_start_k * stride_vn
    QSeg += off_z * stride_qsz + cu_seq_start_q * stride_qsm
    KSeg += off_z * stride_ksz + cu_seq_start_k * stride_ksn
    Do += off_z * stride_doz + off_head_q * stride_doh + cu_seq_start_q * stride_dom
    Dq += off_z * stride_dqz + off_head_q * stride_dqh + cu_seq_start_q * stride_dqm
    # NOTE(review): Dk/Dv are offset by the *query* head index, i.e. one gradient
    # slice per query head -- confirm the caller's output layout / reduction.
    Dk += off_z * stride_dkz + off_head_q * stride_dkh + cu_seq_start_k * stride_dkn
    Dv += off_z * stride_dvz + off_head_q * stride_dvh + cu_seq_start_k * stride_dvn
    if BIAS_ON:
        B += off_z * stride_bz + off_head_q * stride_bh + cu_seq_start_q * stride_bm
    # Integer base offset (not a pointer) for this (batch, head)'s dropout positions.
    Dropout = (
        actual_seqlen_k * (cu_seq_start_q + actual_seqlen_q * (off_head_q + nheads_q * off_z)) if USE_DROPOUT else None
    )
    if USE_SINKS:
        softmax_aux_ptrs = softmax_aux + off_head_q * num_sinks
    else:
        softmax_aux_ptrs = softmax_aux
    # Per-row statistics are laid out as [batch * heads, seqlen_q_rounded].
    D += off_zh * seqlen_q_rounded
    M += off_zh * seqlen_q_rounded
    if pid < NUM_BLOCKS_KV:
        # dK/dV program: one K/V tile of BLOCK_N1 rows.
        i_start_n = pid
        # Bounds-check key rows when KSeq is not a multiple of BLOCK_N1, or in
        # VARLEN mode when this tile extends past the sequence length.
        pad_cols = (not EVEN_N1) or (VARLEN and ((i_start_n + 1) * BLOCK_N1 > actual_seqlen_k))
        _attn_bwd_block_dkdv(
            i_start_n * BLOCK_N1,
            Q,
            K,
            V,
            QSeg,
            KSeg,
            B,
            Dropout,
            Do,
            Dk,
            Dv,
            M,
            D,
            softmax_scale,
            stride_qm,
            stride_kn,
            stride_vn,
            stride_bm,
            stride_dom,
            stride_dkn,
            stride_dvn,
            actual_seqlen_q,
            actual_seqlen_k,
            headdim,
            stride_qsm,
            stride_ksn,
            window_left,
            window_right,
            logits_soft_cap,
            softmax_aux_ptrs,
            num_sinks,
            IS_CAUSAL=IS_CAUSAL,
            BIAS_ON=BIAS_ON,
            BOOL_BIAS=BOOL_BIAS,
            USE_DROPOUT=USE_DROPOUT,
            PAD_COLS=pad_cols,
            HEADS_PADDED=HEADS_PADDED,
            BLOCK_M=BLOCK_M1,
            BLOCK_N=BLOCK_N1,
            BLOCK_HEADDIM=BLOCK_HEADDIM,
            SLIDING=SLIDING,
            SOFTCAP=SOFTCAP,
            USE_SINKS=USE_SINKS,
            USE_SEGMENTS=USE_SEGMENTS,
        )
    else:
        # dQ program: one query tile of BLOCK_M2 rows.
        i_start_m = pid - NUM_BLOCKS_KV
        # Bounds-check query rows when QSeq is not a multiple of BLOCK_M2, or in
        # VARLEN mode when this tile extends past the sequence length.
        pad_rows = (not EVEN_M2) or (VARLEN and ((i_start_m + 1) * BLOCK_M2 > actual_seqlen_q))
        _attn_bwd_block_dq(
            i_start_m * BLOCK_M2,
            Q,
            K,
            V,
            QSeg,
            KSeg,
            B,
            Dropout,
            Do,
            Dq,
            M,
            D,
            softmax_scale,
            dropout_prob,
            dropout_seed,
            stride_qm,
            stride_kn,
            stride_vn,
            stride_bm,
            stride_dom,
            stride_dqm,
            actual_seqlen_q,
            actual_seqlen_k,
            headdim,
            stride_qsm,
            stride_ksn,
            window_left,
            window_right,
            logits_soft_cap,
            softmax_aux_ptrs,
            num_sinks,
            VARLEN=VARLEN,
            IS_CAUSAL=IS_CAUSAL,
            BIAS_ON=BIAS_ON,
            BOOL_BIAS=BOOL_BIAS,
            USE_DROPOUT=USE_DROPOUT,
            PAD_ROWS=pad_rows,
            HEADS_PADDED=HEADS_PADDED,
            BLOCK_M=BLOCK_M2,
            BLOCK_N=BLOCK_N2,
            BLOCK_HEADDIM=BLOCK_HEADDIM,
            EVEN_N=EVEN_N2,
            SLIDING=SLIDING,
            SOFTCAP=SOFTCAP,
            USE_SINKS=USE_SINKS,
            USE_SEGMENTS=USE_SEGMENTS,
        )
def _bwd_attention_kernel_call(
dO: Float[Array, "batch seq_len_q num_heads head_dim"],
q: Float[Array, "batch seq_len_q num_heads head_dim"],
k: Float[Array, "batch seq_len_k num_heads head_dim"],
v: Float[Array, "batch seq_len_k num_heads head_dim"],
bias: Float[Array, "batch num_heads seq_len_q seq_len_k"] | None,
attention_mask: Bool[Array, "batch seq_len"] | None,
o: Float[Array, "batch seq_len_q num_heads head_dim"],
M: Float[Array, "batch num_heads max_seqlen_q_rounded"],
dropout_prob: float,
causal: bool,
softmax_scale: float | None,
dropout_seed: int | None,
fwd_params: FwdParams | None = None,
bwd_params: BwdParams | None = None,
cum_seqlens_q: Int[Array, "batch_plus_one"] | None = None,
cum_seqlens_k: Int[Array, "batch_plus_one"] | None = None,
sliding_window: int | tuple[int, int] | None = None,
logits_soft_cap: float | None = None,
softmax_aux: Float[Array, "num_sinks"] | Float[Array, "num_heads num_sinks"] | None = None,
q_segment_ids: Int[Array, "batch seq_len_q"] | None = None,
kv_segment_ids: Int[Array, "batch seq_len_k"] | None = None,
) -> tuple[
Float[Array, "batch seq_len_q num_heads head_dim"],
Float[Array, "batch seq_len_k num_heads head_dim"],
Float[Array, "batch seq_len_k num_heads head_dim"],
]:
"""Execute flash attention backward pass using Triton kernels.
Prepares inputs and launches Triton kernels for gradient computation.
Handles preprocessing and main backward kernel execution.
Args:
dO: Gradient of loss with respect to attention output
q, k, v: Query, key, value tensors from forward pass
bias: Optional attention bias from forward pass
attention_mask: Legacy mask parameter
o: Output from forward pass
M: Log-sum-exp values from forward pass
dropout_prob: Dropout probability used in forward pass
causal: Whether causal masking was applied
softmax_scale: Attention score scaling factor
dropout_seed: Random seed for dropout
cum_seqlens_q/k: Cumulative sequence lengths for variable-length mode
sliding_window: Local attention window size
Returns:
tuple: Gradients (dq, dk, dv) for query, key, and value tensors
"""
if sliding_window is None:
window_left = 0
window_right = 0
sliding_flag = False
else:
if isinstance(sliding_window, int):
window_left = int(sliding_window)
window_right = 0 if causal else int(sliding_window)
else:
wl, wr = sliding_window
window_left = int(wl)
window_right = int(wr)
assert window_left >= 0 and window_right >= 0
sliding_flag = (window_left > 0) or (window_right > 0)
if logits_soft_cap is None:
logits_soft_cap_val = 0.0
softcap_flag = False
else:
logits_soft_cap_val = float(logits_soft_cap)
softcap_flag = True
if softmax_aux is None:
use_sinks = False
num_sinks_val = 0
softmax_aux_tensor = jnp.zeros((1,), dtype=q.dtype)
else:
use_sinks = True
if softmax_aux.ndim == 1:
num_sinks_val = softmax_aux.shape[0]
num_heads = q.shape[2]
softmax_aux_tensor = jnp.broadcast_to(softmax_aux[None, :], (num_heads, num_sinks_val))
elif softmax_aux.ndim == 2:
num_sinks_val = softmax_aux.shape[1]
softmax_aux_tensor = softmax_aux
else:
raise ValueError(f"softmax_aux must be 1D or 2D, got shape {softmax_aux.shape}")
head_dim = q.shape[-1]
softmax_scale = 1.0 / math.sqrt(float(head_dim)) if softmax_scale is None else softmax_scale
softmax_aux_tensor = softmax_aux_tensor * softmax_scale
if softcap_flag:
softmax_aux_tensor = logits_soft_cap_val * jnp.tanh(softmax_aux_tensor / logits_soft_cap_val)
use_segments = (q_segment_ids is not None) or (kv_segment_ids is not None)
if use_segments:
if q_segment_ids is None:
q_segment_ids = kv_segment_ids
if kv_segment_ids is None:
kv_segment_ids = q_segment_ids
q_segment_ids = jnp.asarray(q_segment_ids, dtype=jnp.int32)
kv_segment_ids = jnp.asarray(kv_segment_ids, dtype=jnp.int32)
if q_segment_ids.ndim != 2 or kv_segment_ids.ndim != 2:
raise ValueError("q_segment_ids/kv_segment_ids must be 2D int32 arrays.")
if q_segment_ids.shape[0] != q.shape[0] or q_segment_ids.shape[1] != q.shape[1]:
raise ValueError("q_segment_ids must have shape [batch, seq_len_q].")
if kv_segment_ids.shape[0] != k.shape[0] or kv_segment_ids.shape[1] != k.shape[1]:
raise ValueError("kv_segment_ids must have shape [batch, seq_len_k].")
qsz, qsm = get_strides(q_segment_ids.shape)
ksz, ksn = get_strides(kv_segment_ids.shape)
else:
q_segment_ids = jnp.zeros((1,), dtype=jnp.int32)
kv_segment_ids = jnp.zeros((1,), dtype=jnp.int32)
qsz = qsm = ksz = ksn = 0
varlen_from_cu = (cum_seqlens_q is not None) and (cum_seqlens_k is not None)
if varlen_from_cu:
if use_segments:
raise NotImplementedError("segment_ids are not supported with cum_seqlens in triton flash-attention.")
assert cum_seqlens_q.dtype == jnp.int32 and cum_seqlens_k.dtype == jnp.int32
batch_size, QSeq_max, nheads_q, head_dim = q.shape
_, KSeq_max, nheads_kv, _ = k.shape
assert nheads_q % nheads_kv == 0
num_repeats = nheads_q // nheads_kv
BOOL_BIAS = False
softmax_scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
max_seqlen_q = QSeq_max
max_seqlen_k = KSeq_max
max_seqlen_q_rounded = math.ceil(max_seqlen_q / 128) * 128
BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
from ._utilities import attention_pack_from_cu_static, attention_unpack_with_static_shape
q_p = attention_pack_from_cu_static(q, cum_seqlens_q, max_tokens=batch_size * QSeq_max)
k_p = attention_pack_from_cu_static(k, cum_seqlens_k, max_tokens=batch_size * KSeq_max)
v_p = attention_pack_from_cu_static(v, cum_seqlens_k, max_tokens=batch_size * KSeq_max)
o_p = attention_pack_from_cu_static(o, cum_seqlens_q, max_tokens=batch_size * QSeq_max)
dO_p = attention_pack_from_cu_static(dO, cum_seqlens_q, max_tokens=batch_size * QSeq_max)
oz, om, oh, _ = get_strides(o_p)
doz, dom, doh, _ = get_strides(dO_p)
qz, qm, qh, _ = get_strides(q_p)
kz, kn, kh, _ = get_strides(k_p)
vz, vn, vh, _ = get_strides(v_p)
(delta,) = triton_call(
o_p,
dO_p,
oz,
om,
oh,
doz,
dom,
doh,
nheads_q,
max_seqlen_q,
max_seqlen_q_rounded,
cum_seqlens_q,
head_dim,
max_seqlen_q // 32,
dtype_index(q_p),
VARLEN=True,
BLOCK_HEADDIM=BLOCK_HEADDIM,
out_shape=[jax.ShapeDtypeStruct(shape=M.shape, dtype="f4", sharding=get_sharding(M))],
grid=lambda META: (triton.cdiv(max_seqlen_q, META["BLOCK_M"]), batch_size * nheads_q),
kernel=_attn_bwd_preprocess,
name="ejkernel::triton::flash_attn_bwd_preprocess",
)
bz = bm = bh = 0
dq, dk, dv = triton_call(
q_p,
k_p,
v_p,
q_segment_ids,
kv_segment_ids,
jnp.zeros((1,), q.dtype),
dO_p,
M,
delta,
softmax_scale,
dropout_prob,
dropout_seed if dropout_seed is not None else jnp.zeros((1,), q.dtype),
qz,
qm,
qh,
kz,
kn,
kh,
vz,
vn,
vh,
qsz,
qsm,
ksz,
ksn,
bz,
bm,
bh,
doz,
dom,
doh,
qz,
qm,
qh,
kz,
kn,
kh,
vz,
vn,
vh,
nheads_q,
num_repeats,
window_left,
window_right,
max_seqlen_q,
cum_seqlens_q,
max_seqlen_k,
cum_seqlens_k,
max_seqlen_q_rounded,
head_dim,
max_seqlen_q // 32,
max_seqlen_k // 32,
dtype_index(q_p),
logits_soft_cap_val,
softmax_aux_tensor,
num_sinks_val,
BIAS_ON=False,
VARLEN=True,
IS_CAUSAL=causal,
USE_DROPOUT=(dropout_prob > 0),
BLOCK_HEADDIM=BLOCK_HEADDIM,
BOOL_BIAS=False,
USE_SEGMENTS=False,
SLIDING=sliding_flag,
SOFTCAP=softcap_flag,
USE_SINKS=use_sinks,
kernel=_attn_bwd,
grid=lambda META: (
triton.cdiv(max_seqlen_k, META["BLOCK_N1"]) + triton.cdiv(max_seqlen_q, META["BLOCK_M2"]),
batch_size * nheads_q,
),
out_shape=[
jax.ShapeDtypeStruct(shape=q_p.shape, dtype="f4", sharding=get_sharding(q)),
jax.ShapeDtypeStruct(shape=(*k_p.shape[:2], q_p.shape[2], k_p.shape[3]), dtype=k.dtype),
jax.ShapeDtypeStruct(shape=(*v_p.shape[:2], q_p.shape[2], v_p.shape[3]), dtype=v.dtype),
],
name="ejkernel::triton::flash_attn_bwd",
)
if num_repeats > 1:
dk = dk.reshape(dk.shape[0], dk.shape[1], (nheads_q // num_repeats), num_repeats, -1).sum(axis=3)
dv = dv.reshape(dv.shape[0], dv.shape[1], (nheads_q // num_repeats), num_repeats, -1).sum(axis=3)
dq = attention_unpack_with_static_shape(dq, cum_seqlens_q, batch_size, QSeq_max)
dk = attention_unpack_with_static_shape(dk, cum_seqlens_k, batch_size, KSeq_max)
dv = attention_unpack_with_static_shape(dv, cum_seqlens_k, batch_size, KSeq_max)
return dq.astype(q.dtype), dk.astype(k.dtype), dv.astype(v.dtype)
if attention_mask is not None and varlen_from_cu:
assert bias is None, "mask + bias not supported; use bias alone or pack bias."
assert q.shape[1] == k.shape[1], "mask varlen path supports QSeq == KSeq only."
varlen_mode = attention_mask.shape[0] > 1
useless_padding = attention_mask.shape[1] - attention_mask.sum(-1).max().item()
if useless_padding > 0:
dO = dO[:, :-useless_padding]
q = q[:, :-useless_padding]
k = k[:, :-useless_padding]
v = v[:, :-useless_padding]
attention_mask = attention_mask[:, :-useless_padding]
o = o[:, :-useless_padding]
else:
varlen_mode = False
useless_padding = 0
batch_size, QSeq, nheads_q, head_dim = q.shape
_, KSeq, nheads_kv, _ = k.shape
max_seqlen_q_rounded = math.ceil(QSeq / 128) * 128
softmax_scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
assert nheads_q % nheads_kv == 0
assert M.shape == (batch_size, nheads_q, max_seqlen_q_rounded)
BOOL_BIAS = False
if not varlen_mode and attention_mask is not None:
assert bias is None, "mask + bias not supported"
BOOL_BIAS = True
bias = attention_mask.astype(jnp.bool_)
if varlen_mode:
cum_seqlens_q = jnp.zeros((attention_mask.shape[0] + 1,), dtype=jnp.int32)
cum_seqlens_k = jnp.zeros((attention_mask.shape[0] + 1,), dtype=jnp.int32)
lengths = attention_mask.sum(axis=1, dtype="i4")
cum_seqlens_q = cum_seqlens_q.at[1:].set(jnp.cumsum(lengths, axis=0, dtype="i4"))
cum_seqlens_k = cum_seqlens_k.at[1:].set(jnp.cumsum(lengths, axis=0, dtype="i4"))
max_seqlen_q = attention_mask.shape[1]
max_seqlen_k = attention_mask.shape[1]
dO = attention_pack_with_static_shape(dO, attention_mask)
q = attention_pack_with_static_shape(q, attention_mask)
k = attention_pack_with_static_shape(k, attention_mask)
v = attention_pack_with_static_shape(v, attention_mask)
o = attention_pack_with_static_shape(o, attention_mask)
QSeq = q.shape[1]
KSeq = k.shape[1]
else:
cum_seqlens_q = None
cum_seqlens_k = None
max_seqlen_q = QSeq
max_seqlen_k = KSeq
bz, bh, bm = calc_bias_strides(bias, batch_size, nheads_q, QSeq, KSeq)
num_repeats = nheads_q // nheads_kv
BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
oz, om, oh, _ = get_strides(o)
doz, dom, doh, _ = get_strides(dO)
qz, qm, qh, _ = get_strides(q)
kz, kn, kh, _ = get_strides(k)
vz, vn, vh, _ = get_strides(v)
(delta,) = triton_call(
o,
dO,
oz,
om,
oh,
doz,
dom,
doh,
nheads_q,
QSeq,
max_seqlen_q_rounded,
cum_seqlens_q if cum_seqlens_q is not None else jnp.zeros((1,), jnp.int32),
head_dim,
max_seqlen_q // 32,
dtype_index(q),
VARLEN=varlen_mode,
BLOCK_HEADDIM=BLOCK_HEADDIM,
out_shape=[jax.ShapeDtypeStruct(shape=M.shape, dtype="f4", sharding=get_sharding(M))],
grid=lambda META: (triton.cdiv(max_seqlen_q, META["BLOCK_M"]), batch_size * nheads_q),
kernel=_attn_bwd_preprocess,
name="ejkernel::triton::flash_attn_bwd_preprocess",
)
dq, dk, dv = triton_call(
q,
k,
v,
q_segment_ids,
kv_segment_ids,
bias if bias is not None else jnp.zeros((1,), jnp.float16),
dO,
M,
delta,
softmax_scale,
dropout_prob,
dropout_seed if dropout_seed is not None else jnp.zeros((1,), jnp.float16),
qz,
qm,
qh,
kz,
kn,
kh,
vz,
vn,
vh,
qsz,
qsm,
ksz,
ksn,
bz,
bm,
bh,
doz,
dom,
doh,
qz,
qm,
qh,
kz,
kn,
kh,
vz,
vn,
vh,
nheads_q,
num_repeats,
window_left,
window_right,
QSeq,
cum_seqlens_q if cum_seqlens_q is not None else jnp.zeros((1,), jnp.int32),
KSeq,
cum_seqlens_k if cum_seqlens_k is not None else jnp.zeros((1,), jnp.int32),
max_seqlen_q_rounded,
head_dim,
max_seqlen_q // 32,
max_seqlen_k // 32,
dtype_index(q),
logits_soft_cap_val,
softmax_aux_tensor,
num_sinks_val,
BIAS_ON=(bias is not None),
VARLEN=varlen_mode,
IS_CAUSAL=causal,
USE_DROPOUT=(dropout_prob > 0),
BLOCK_HEADDIM=BLOCK_HEADDIM,
BOOL_BIAS=BOOL_BIAS,
USE_SEGMENTS=use_segments,
SLIDING=sliding_flag,
SOFTCAP=softcap_flag,
USE_SINKS=use_sinks,
kernel=_attn_bwd,
grid=lambda META: (
triton.cdiv(KSeq, META["BLOCK_N1"]) + triton.cdiv(QSeq, META["BLOCK_M2"]),
batch_size * nheads_q,
),
out_shape=[
jax.ShapeDtypeStruct(shape=q.shape, dtype="f4", sharding=get_sharding(q)),
jax.ShapeDtypeStruct(shape=(k.shape[0], k.shape[1], q.shape[2], k.shape[3]), dtype="f4"),
jax.ShapeDtypeStruct(shape=(v.shape[0], v.shape[1], q.shape[2], v.shape[3]), dtype="f4"),
],
name="ejkernel::triton::flash_attn_bwd",
)
if num_repeats > 1:
dk = dk.reshape(dk.shape[0], dk.shape[1], (nheads_q // num_repeats), num_repeats, -1).sum(axis=3)
dv = dv.reshape(dv.shape[0], dv.shape[1], (nheads_q // num_repeats), num_repeats, -1).sum(axis=3)
if varlen_mode:
dq = attention_unpack_with_static_shape(dq, cum_seqlens_q, batch_size, max_seqlen_q)
dk = attention_unpack_with_static_shape(dk, cum_seqlens_k, batch_size, max_seqlen_k)
dv = attention_unpack_with_static_shape(dv, cum_seqlens_k, batch_size, max_seqlen_k)
if useless_padding > 0:
dq = jnp.pad(dq, ((0, useless_padding), (0, 0), (0, 0)))
dk = jnp.pad(dk, ((0, useless_padding), (0, 0), (0, 0)))
dv = jnp.pad(dv, ((0, useless_padding), (0, 0), (0, 0)))
return dq.astype(q.dtype), dk.astype(k.dtype), dv.astype(v.dtype)