Source code for ejkernel.kernels._triton.flash_attention._triton_impl_fwd

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
from typing import Any

import jax
import jax.numpy as jnp
import triton
import triton.language as tl
from jaxtyping import Array, Bool, Float, Int
from triton import Config

from ejkernel.callib import triton_call
from ejkernel.ops import BwdParams, FwdParams
from ejkernel.utils import dtype_index, get_strides

from ._utilities import (
    attention_pack_from_cu_static,
    attention_pack_with_static_shape,
    attention_unpack_with_static_shape,
    calc_bias_strides,
    padded_load,
)


def config_prune_kernel(
    configs: list[Config],
    named_args: dict[str, Any],
    **kwargs: Any,
) -> list[Config]:
    """Drop configs whose tiles are larger than the sequence lengths.

    Presumably used as a Triton autotune ``early_config_prune`` hook (the
    ``(configs, named_args, **kwargs)`` signature matches it) — TODO confirm.

    A config is kept only when ``BLOCK_M <= named_args["QSeq"]`` and
    ``BLOCK_N <= named_args["KSeq"]``. If no config survives, a fixed list of
    small-tile fallback configs is returned, so callers always get at least
    one candidate.

    Args:
        configs: Candidate configs; each ``config.kwargs`` must contain
            ``"BLOCK_M"`` and ``"BLOCK_N"``.
        named_args: Kernel arguments by name; ``"QSeq"`` and ``"KSeq"`` are read.
        **kwargs: Ignored.

    Returns:
        The surviving configs (in their original order), or the fallback list.
    """
    kept_configs = [
        config
        for config in configs
        if config.kwargs["BLOCK_M"] <= named_args["QSeq"] and config.kwargs["BLOCK_N"] <= named_args["KSeq"]
    ]
    if kept_configs:
        return kept_configs
    # Nothing fits: fall back to the smallest tiles rather than returning an
    # empty candidate list.
    return [
        Config({"BLOCK_M": 16, "BLOCK_N": 64}, num_warps=4, num_stages=4),
        Config({"BLOCK_M": 16, "BLOCK_N": 64}, num_warps=2, num_stages=4),
        Config({"BLOCK_M": 16, "BLOCK_N": 32}, num_warps=2, num_stages=4),
        Config({"BLOCK_M": 16, "BLOCK_N": 64}, num_warps=2, num_stages=3),
        Config({"BLOCK_M": 16, "BLOCK_N": 32}, num_warps=2, num_stages=3),
    ]
@triton.jit
def _attn_fwd_inner(
    q,
    m_i,
    me_i,
    k_ptrs,
    v_ptrs,
    bias_ptrs,
    acc_o,
    offs_m,
    offs_n,
    offs_d,
    softmax_scale,
    dropout_prob,
    dropout_seed,
    dropout_offs,
    window_left,
    window_right,
    logits_soft_cap,
    softmax_aux_ptrs,
    num_sinks,
    stride_kn,
    stride_vn,
    index_start_n,
    actual_seqlen_q,
    actual_seqlen_k,
    headdim,
    q_segment_ids_ptr,
    kv_segment_ids_ptr,
    stride_qsm,
    stride_ksn,
    USE_SEGMENTS: tl.constexpr,
    USE_DROPOUT: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BIAS_ON: tl.constexpr,
    BOOL_BIAS: tl.constexpr,
    MASKED: tl.constexpr,
    SLIDING: tl.constexpr,
    SOFTCAP: tl.constexpr,
    USE_SINKS: tl.constexpr,
    PADDED_COLS: tl.constexpr,
    PADDED_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Process one (BLOCK_M x BLOCK_N) key/value tile of the online softmax.

    Scores are kept in base-2 units: ``softmax_scale`` arrives already
    multiplied by log2(e) (see ``_attn_fwd``), so ``exp2`` is used throughout.

    Args:
        q: The query tile loaded by the caller.
        m_i: Running row max of the (scaled, log2-unit) scores.
        me_i: Running row log2-sum-exp (the LSE in base-2 units).
        acc_o: Running un-normalised output accumulator (fp32).
        index_start_n: First key column of this tile; the first tile processed
            by ``_attn_fwd`` always has ``index_start_n == 0``.
        MASKED: Apply the causal mask inside this tile.
        PADDED_COLS: Mask key columns at or beyond ``actual_seqlen_k``.

    Returns:
        Updated ``(m_i, me_i, acc_o)``; ``acc_o`` stays in fp32.
    """
    # log2(e) = 1 / ln(2): converts natural-log units to base-2 units.
    LOG2E: tl.constexpr = 1.44269504089
    index_start_n = tl.multiple_of(index_start_n, BLOCK_N)

    offset_k_ptrs = k_ptrs + index_start_n * stride_kn
    k = padded_load(
        offset_k_ptrs,
        index_start_n + offs_n,
        offs_d,
        PA0=PADDED_COLS,
        PA1=PADDED_HEADS,
        LA0=actual_seqlen_k,
        LA1=headdim,
    )
    if BIAS_ON:
        if PADDED_COLS:
            bias = tl.load(
                bias_ptrs + index_start_n,
                mask=(offs_m[:, None] < actual_seqlen_q) & ((index_start_n + offs_n) < actual_seqlen_k)[None, :],
                other=0.0,
            )
        else:
            bias = tl.load(
                bias_ptrs + index_start_n,
                mask=offs_m[:, None] < actual_seqlen_q,
                other=0.0,
            )

    # tl.dot returns fp32 by default, so `qk += ...` promotes qk to fp32.
    qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.bfloat16)
    if USE_SINKS:
        # NOTE(review): fp16 inputs are routed through bf16 here; kept as-is so
        # the scores match whatever the backward kernel recomputes.
        qk += tl.dot(q.to(tl.bfloat16), tl.trans(k.to(tl.bfloat16)))
    else:
        qk += tl.dot(q, tl.trans(k))

    if PADDED_COLS:
        qk += tl.where(
            (index_start_n + offs_n)[None, :] < actual_seqlen_k,
            0,
            float("-inf"),
        )

    if MASKED and IS_CAUSAL:
        # Bottom-right aligned causal mask: row i may attend to column j iff
        # j <= i + (actual_seqlen_k - actual_seqlen_q).
        causal_mask = offs_m[:, None] >= (index_start_n + offs_n - actual_seqlen_k + actual_seqlen_q)[None, :]
        qk += tl.where(causal_mask, 0, float("-inf"))

    if SLIDING:
        shift = actual_seqlen_k - actual_seqlen_q
        j_aligned = (index_start_n + offs_n)[None, :] - shift
        i_idx = offs_m[:, None]
        in_window = (j_aligned >= (i_idx - window_left)) & (j_aligned <= (i_idx + window_right))
        qk = tl.where(in_window, qk, float("-inf"))

    if BIAS_ON:
        if BOOL_BIAS:
            BIG_NEG: tl.constexpr = -2147483648
            qk = tl.where(bias, qk, BIG_NEG)
        else:
            # qk is multiplied by softmax_scale below; pre-divide so the bias
            # ends up added in natural-log units.
            qk += bias * (LOG2E / softmax_scale)

    # Keep `attn_mask` shape stable across control-flow. Triton requires values
    # carried through an `if` to have identical types/shapes in both branches.
    # Start with a (BLOCK_M, BLOCK_N) mask (the `offs_n >= 0` term is always True).
    attn_mask = (offs_m[:, None] < actual_seqlen_q) & (offs_n[None, :] >= 0)
    if PADDED_COLS:
        attn_mask = attn_mask & ((index_start_n + offs_n)[None, :] < actual_seqlen_k)
    if MASKED and IS_CAUSAL:
        attn_mask = attn_mask & causal_mask
    if SLIDING:
        attn_mask = attn_mask & in_window
    if USE_SEGMENTS:
        q_ids = tl.load(q_segment_ids_ptr + offs_m * stride_qsm, mask=offs_m < actual_seqlen_q, other=-1)
        kv_ids = tl.load(
            kv_segment_ids_ptr + (index_start_n + offs_n) * stride_ksn,
            mask=(index_start_n + offs_n) < actual_seqlen_k,
            other=-1,
        )
        seg_mask = (q_ids[:, None] == kv_ids[None, :]) & (q_ids[:, None] >= 0)
        attn_mask = attn_mask & seg_mask
    if BIAS_ON and BOOL_BIAS:
        attn_mask = attn_mask & bias

    if SOFTCAP:
        qk_natural = qk * (softmax_scale / LOG2E)
        x = qk_natural / logits_soft_cap
        # exp(2x) overflows to inf for x > ~44 and (inf-1)/(inf+1) is NaN.
        # tanh(15) already rounds to 1.0 in fp32, so clamping is lossless.
        x = tl.maximum(tl.minimum(x, 15.0), -15.0)
        exp_2x = tl.exp(2.0 * x)
        tanh_x = (exp_2x - 1.0) / (exp_2x + 1.0)
        qk = (logits_soft_cap * tanh_x) * LOG2E
    else:
        qk = qk * softmax_scale
    qk = tl.where(attn_mask, qk, float("-inf"))

    if USE_SINKS:
        # At most 16 sink logits are read per head (the host enforces this).
        sink_offs = tl.arange(0, 16)
        sink_mask = sink_offs < num_sinks
        aux_logits = tl.load(softmax_aux_ptrs + sink_offs, mask=sink_mask, other=float("-inf")).to(tl.bfloat16)
        if SOFTCAP:
            x_aux = aux_logits / logits_soft_cap
            x_aux = tl.maximum(tl.minimum(x_aux, 8.0), -8.0)
            exp_2x = tl.exp(2.0 * x_aux)
            tanh_x = (exp_2x - 1.0) / (exp_2x + 1.0)
            aux_natural = logits_soft_cap * tanh_x
            aux_log2 = aux_natural * LOG2E
        else:
            aux_log2 = aux_logits * LOG2E
        qk_max = tl.max(qk, 1)
        aux_max = tl.max(tl.where(sink_mask, aux_log2, float("-inf")))
        m_ij = tl.maximum(tl.maximum(qk_max, aux_max), me_i)
        m_ij_safe = tl.where(m_ij == float("-inf"), 0.0, m_ij)
        P_ij = tl.exp2(qk - m_ij_safe[:, None])
        aux_log2_row = tl.where(sink_mask[None, :], aux_log2[None, :], float("-inf"))
        l_aux_row = tl.sum(tl.exp2(aux_log2_row - m_ij_safe[:, None]), axis=1)
        # The sinks belong to the softmax denominator exactly once. me_i carries
        # everything accumulated so far, so adding them on every tile would
        # count them once per KV block.
        l_aux_row = tl.where(index_start_n == 0, l_aux_row, 0.0)
        l_ij = tl.sum(P_ij, 1) + l_aux_row
    else:
        m_ij = tl.maximum(tl.max(qk, 1), me_i)
        m_ij_safe = tl.where(m_ij == float("-inf"), 0.0, m_ij)
        P_ij = tl.exp2(qk - m_ij_safe[:, None])
        l_ij = tl.sum(P_ij, 1)

    if USE_DROPOUT:
        # l_ij above is computed before dropout; the 1/(1-p) rescale is
        # applied once at the end in `_attn_fwd`.
        dropout_offs = dropout_offs + index_start_n
        dropout_mask = tl.rand(dropout_seed, dropout_offs) > dropout_prob
        P_ij = tl.where(dropout_mask, P_ij, 0.0)

    # Rescale the previous accumulator to the new running max (0 for rows that
    # are still fully masked).
    acc_o_scale = tl.where(m_ij == float("-inf"), 0.0, tl.exp2(m_i - m_ij_safe))
    acc_o = acc_o * acc_o_scale[:, None]

    offset_v_ptrs = v_ptrs + index_start_n * stride_vn
    v = padded_load(
        offset_v_ptrs,
        index_start_n + offs_n,
        offs_d,
        PA0=PADDED_COLS,
        PA1=PADDED_HEADS,
        LA0=actual_seqlen_k,
        LA1=headdim,
    )
    # Feed P in V's own dtype (no lossy fp16->bf16 detour) and keep the
    # accumulator in fp32 across tiles.
    acc_o += tl.dot(P_ij.to(v.dtype), v)

    m_i = m_ij
    l_i_new = tl.exp2(me_i - m_ij_safe) + l_ij
    me_i = m_ij_safe + tl.log2(l_i_new)
    return m_i, me_i, acc_o


@triton.heuristics(
    {
        "EVEN_M": lambda args: args["QSeq"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["KSeq"] % args["BLOCK_N"] == 0,
    }
)
@triton.jit
def _attn_fwd(
    q,
    k,
    v,
    QSeg,
    KSeg,
    B,
    softmax_scale,
    dropout_prob,
    dropout_seed,
    logits_soft_cap,
    softmax_aux,
    num_sinks,
    stride_qz,
    stride_qm,
    stride_qh,
    stride_kz,
    stride_kn,
    stride_kh,
    stride_vz,
    stride_vn,
    stride_vh,
    stride_qsz,
    stride_qsm,
    stride_ksz,
    stride_ksn,
    stride_oz,
    stride_om,
    stride_oh,
    stride_bz,
    stride_bm,
    stride_bh,
    nheads_q,
    num_repeats,
    window_left,
    window_right,
    QSeq,
    cum_seqlens_q,
    KSeq,
    cum_seqlens_k,
    max_seqlen_q_rounded,
    headdim,
    CQSeq,
    CKSeq,
    DRuntime,
    Po,
    M,
    VARLEN: tl.constexpr,
    USE_DROPOUT: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BIAS_ON: tl.constexpr,
    SLIDING: tl.constexpr,
    SOFTCAP: tl.constexpr,
    USE_SINKS: tl.constexpr,
    BOOL_BIAS: tl.constexpr,
    USE_SEGMENTS: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    PADDED_HEADS: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Triton kernel for the flash-attention forward pass.

    Grid: axis 0 tiles the query rows by BLOCK_M, axis 1 enumerates
    (batch, query-head) pairs. Each program streams all needed key/value tiles
    through ``_attn_fwd_inner`` and writes one output tile plus its LSE rows.

    Args:
        q, k, v: Query/key/value base pointers.
        QSeg, KSeg: Query / key segment-id pointers (used when USE_SEGMENTS).
        B: Bias (or boolean mask when BOOL_BIAS) pointer, used when BIAS_ON.
        softmax_scale: Score scale in natural-log units.
        dropout_prob, dropout_seed: Dropout parameters (used when USE_DROPOUT).
        logits_soft_cap: Soft-cap value (used when SOFTCAP).
        softmax_aux, num_sinks: Per-head sink logits (used when USE_SINKS).
        stride_*: Element strides of the corresponding tensors.
        nheads_q: Number of query heads.
        num_repeats: Query heads per KV head (GQA/MQA factor).
        window_left, window_right: Sliding-window extents (used when SLIDING).
        QSeq, KSeq: Query / key sequence lengths (maxima in VARLEN mode).
        cum_seqlens_q, cum_seqlens_k: Cumulative lengths (VARLEN mode).
        max_seqlen_q_rounded: Row stride of the LSE output.
        headdim: Actual head dimension.
        CQSeq, CKSeq: Host passes ``max_seqlen // 128``; not read in this body.
        DRuntime: Host passes ``dtype_index(q)``; not read in this body.
        Po: Output pointer.
        M: Log-sum-exp output pointer (base-2 units).
        VARLEN: Sequences are packed and delimited by cum_seqlens_*.
        BLOCK_HEADDIM / PADDED_HEADS: Padded head dim and whether padding exists.
        EVEN_M / EVEN_N: QSeq / KSeq divisible by BLOCK_M / BLOCK_N.
        BLOCK_M / BLOCK_N: Tile sizes.
    """
    i_start_m = tl.program_id(0)
    off_zh = tl.program_id(1)
    off_head_q = off_zh % nheads_q
    off_head_kv = off_head_q // num_repeats
    off_z = off_zh // nheads_q

    if VARLEN:
        cu_q0 = tl.load(cum_seqlens_q + off_z)
        cu_q1 = tl.load(cum_seqlens_q + off_z + 1)
        cu_k0 = tl.load(cum_seqlens_k + off_z)
        cu_k1 = tl.load(cum_seqlens_k + off_z + 1)
        actual_seqlen_q = cu_q1 - cu_q0
        actual_seqlen_k = cu_k1 - cu_k0
        if i_start_m * BLOCK_M >= actual_seqlen_q:
            return
        cu_seq_start_q = cu_q0
        cu_seq_start_k = cu_k0
        # Packed layout: the batch offset is folded into cu_seq_start_*.
        off_z = 0
    else:
        actual_seqlen_q = QSeq
        actual_seqlen_k = KSeq
        cu_seq_start_q = 0
        cu_seq_start_k = 0

    # Work in base-2 units so the inner loop can use exp2/log2.
    LOG2E: tl.constexpr = 1.44269504089
    softmax_scale = softmax_scale * LOG2E

    offs_m = i_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # Output / LSE destinations, computed up front so the fully-masked early
    # exit below can write them too.
    out_ptrs = (
        Po
        + off_z * stride_oz
        + off_head_q * stride_oh
        + cu_seq_start_q * stride_om
        + (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    out_mask = (offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim)
    lse_ptrs = M + off_zh * max_seqlen_q_rounded + offs_m
    # Guard against BLOCK_M values that do not divide max_seqlen_q_rounded.
    lse_mask = offs_m < max_seqlen_q_rounded

    # With bottom-right causal alignment and QSeq > KSeq, the first
    # (QSeq - KSeq) query rows see no keys at all.
    fully_masked_lines = (actual_seqlen_q - actual_seqlen_k) if IS_CAUSAL else 0
    if IS_CAUSAL and fully_masked_lines >= (i_start_m + 1) * BLOCK_M:
        # Whole tile is fully masked: emit zeros (the same values a partially
        # masked tile produces for such rows) instead of leaving the output
        # buffers unwritten.
        tl.store(out_ptrs, tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32), mask=out_mask)
        tl.store(lse_ptrs, tl.zeros([BLOCK_M], dtype=tl.float32), mask=lse_mask)
        return

    q_ptrs = (
        q
        + off_z * stride_qz
        + off_head_q * stride_qh
        + cu_seq_start_q * stride_qm
        + (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        k
        + off_z * stride_kz
        + off_head_kv * stride_kh
        + cu_seq_start_k * stride_kn
        + (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        v
        + off_z * stride_vz
        + off_head_kv * stride_vh
        + cu_seq_start_k * stride_vn
        + (offs_n[:, None] * stride_vn + offs_d[None, :])
    )
    q_seg_ptrs = QSeg + off_z * stride_qsz + cu_seq_start_q * stride_qsm
    kv_seg_ptrs = KSeg + off_z * stride_ksz + cu_seq_start_k * stride_ksn

    if BIAS_ON:
        # NOTE(review): the bias is indexed by the KV head although the host
        # validates it against nheads_q; with GQA and a per-query-head bias this
        # looks like it should be off_head_q. Left unchanged so it stays
        # consistent with the backward kernel — TODO confirm.
        bias_ptrs = (
            B
            + off_z * stride_bz
            + off_head_kv * stride_bh
            + cu_seq_start_q * stride_bm
            + (offs_m[:, None] * stride_bm + offs_n[None, :])
        )
    else:
        bias_ptrs = None

    if USE_DROPOUT:
        dropout_off = actual_seqlen_k * (cu_seq_start_q + actual_seqlen_q * (off_head_q + nheads_q * off_z))
        dropout_offs = dropout_off + offs_m[:, None] * actual_seqlen_k + offs_n[None, :]
    else:
        dropout_offs = None

    if USE_SINKS:
        softmax_aux_ptrs = softmax_aux + off_head_q * num_sinks
    else:
        softmax_aux_ptrs = softmax_aux

    me_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    # fp32 accumulator: rounding to bf16 after every KV tile compounds error.
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)

    # In VARLEN mode EVEN_M refers to the maximum length, so rows must also be
    # padded whenever this tile runs past the end of *this* sequence.
    pad_rows = (not EVEN_M) or (VARLEN and ((i_start_m + 1) * BLOCK_M > actual_seqlen_q))
    q = padded_load(q_ptrs, offs_m, offs_d, PA0=pad_rows, PA1=PADDED_HEADS, LA0=actual_seqlen_q, LA1=headdim)

    if IS_CAUSAL:
        end_n = tl.minimum(actual_seqlen_k - actual_seqlen_q + (i_start_m + 1) * BLOCK_M, actual_seqlen_k)
        if end_n < 0:
            # Defensive: already covered by the fully-masked exit above.
            return
    else:
        end_n = actual_seqlen_k

    uneven_n = actual_seqlen_k % BLOCK_N != 0
    attention_padding = VARLEN & uneven_n
    if IS_CAUSAL:
        first_masked_col = i_start_m * BLOCK_M + 1 + actual_seqlen_k - actual_seqlen_q
    elif attention_padding:
        first_masked_col = actual_seqlen_k
    else:
        first_masked_col = end_n
    nb_full_blocks = first_masked_col // BLOCK_N

    # Phase 1: tiles that need neither causal nor column-padding masks.
    next_start_n = 0
    if nb_full_blocks > 0:
        for _ in range(0, nb_full_blocks):
            m_i, me_i, acc_o = _attn_fwd_inner(
                q,
                m_i,
                me_i,
                k_ptrs,
                v_ptrs,
                bias_ptrs,
                acc_o,
                offs_m,
                offs_n,
                offs_d,
                softmax_scale,
                dropout_prob,
                dropout_seed,
                dropout_offs,
                window_left,
                window_right,
                logits_soft_cap,
                softmax_aux_ptrs,
                num_sinks,
                stride_kn,
                stride_vn,
                next_start_n,
                actual_seqlen_q,
                actual_seqlen_k,
                headdim,
                q_seg_ptrs,
                kv_seg_ptrs,
                stride_qsm,
                stride_ksn,
                USE_DROPOUT=USE_DROPOUT,
                IS_CAUSAL=IS_CAUSAL,
                BIAS_ON=BIAS_ON,
                BOOL_BIAS=BOOL_BIAS,
                MASKED=False,
                SLIDING=SLIDING,
                SOFTCAP=SOFTCAP,
                USE_SINKS=USE_SINKS,
                USE_SEGMENTS=USE_SEGMENTS,
                PADDED_COLS=False,
                PADDED_HEADS=PADDED_HEADS,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )
            next_start_n += BLOCK_N

    # Phase 2: remaining tiles, with causal and column-padding masks enabled.
    if next_start_n < end_n:
        pad_cols = (not EVEN_N) or VARLEN
        for index_start_n in range(next_start_n, end_n, BLOCK_N):
            m_i, me_i, acc_o = _attn_fwd_inner(
                q,
                m_i,
                me_i,
                k_ptrs,
                v_ptrs,
                bias_ptrs,
                acc_o,
                offs_m,
                offs_n,
                offs_d,
                softmax_scale,
                dropout_prob,
                dropout_seed,
                dropout_offs,
                window_left,
                window_right,
                logits_soft_cap,
                softmax_aux_ptrs,
                num_sinks,
                stride_kn,
                stride_vn,
                index_start_n,
                actual_seqlen_q,
                actual_seqlen_k,
                headdim,
                q_seg_ptrs,
                kv_seg_ptrs,
                stride_qsm,
                stride_ksn,
                USE_DROPOUT=USE_DROPOUT,
                IS_CAUSAL=IS_CAUSAL,
                BIAS_ON=BIAS_ON,
                BOOL_BIAS=BOOL_BIAS,
                MASKED=True,
                SLIDING=SLIDING,
                SOFTCAP=SOFTCAP,
                USE_SINKS=USE_SINKS,
                USE_SEGMENTS=USE_SEGMENTS,
                PADDED_COLS=pad_cols,
                PADDED_HEADS=PADDED_HEADS,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

    # Rows that never saw a valid key get LSE 0 and output 0.
    invalid = me_i == float("-inf")
    me_i = tl.where(invalid, 0.0, me_i)
    m_i = tl.where(invalid, 0.0, m_i)
    acc_o = tl.where(invalid[:, None], 0.0, acc_o)

    # Normalise: acc_o is in units of exp2(. - m_i), the denominator is exp2(me_i).
    if USE_DROPOUT:
        o_scale = tl.exp2((m_i - me_i) - tl.log2(1 - dropout_prob))
    else:
        o_scale = tl.exp2(m_i - me_i)
    acc_o = acc_o * o_scale[:, None]
    if IS_CAUSAL and fully_masked_lines > i_start_m * BLOCK_M:
        acc_o = tl.where(offs_m[:, None] < fully_masked_lines, 0, acc_o)

    tl.store(lse_ptrs, me_i, mask=lse_mask)
    tl.store(out_ptrs, acc_o, mask=out_mask)


def _fwd_attention_kernel_call(
    q: Float[Array, "batch seq_len_q num_heads head_dim"] | None,
    k: Float[Array, "batch seq_len_k num_heads head_dim"] | None,
    v: Float[Array, "batch seq_len_k num_heads head_dim"] | None,
    attention_mask: Bool[Array, "batch num_heads_or_1 seq_len_q seq_len_k"]
    | Int[Array, "batch num_heads_or_1 seq_len_q seq_len_k"]
    | None = None,
    bias: Float[Array, "batch num_heads seq_len_q seq_len_k"] | None = None,
    softmax_scale: float | None = None,
    dropout_prob: float = 0.0,
    causal: bool = False,
    dropout_seed: int | None = None,
    fwd_params: FwdParams | None = None,
    bwd_params: BwdParams | None = None,
    cum_seqlens_q: Int[Array, "batch_plus_one"] | None = None,
    cum_seqlens_k: Int[Array, "batch_plus_one"] | None = None,
    sliding_window: int | tuple[int, int] | None = None,
    logits_soft_cap: float | None = None,
    softmax_aux: Float[Array, "num_heads num_sinks"] | Float[Array, "num_sinks"] | None = None,
    q_segment_ids: Int[Array, "batch seq_len_q"] | None = None,
    kv_segment_ids: Int[Array, "batch seq_len_k"] | None = None,
) -> tuple[Float[Array, "batch seq_len_q num_heads head_dim"], Float[Array, "batch num_heads max_seqlen_q_rounded"]]:
    """Validate inputs, build kernel arguments and launch ``_attn_fwd``.

    Two launch paths exist:

    * ``cum_seqlens_q`` and ``cum_seqlens_k`` both given: q/k/v are packed
      per the cumulative lengths, run in VARLEN mode and unpacked again.
      Bias and segment ids are rejected on this path.
    * otherwise: dense launch; ``attention_mask`` (if any) is passed to the
      kernel as a boolean bias and cannot be combined with ``bias``.

    Args:
        q, k, v: ``[batch, seq, heads, head_dim]`` fp16/bf16 tensors; the
            query head count must be a multiple of the KV head count.
        attention_mask: Optional mask, converted to ``bool``.
        bias: Optional additive bias.
        softmax_scale: Defaults to ``1 / sqrt(head_dim)``.
        dropout_prob, dropout_seed: Dropout configuration.
        causal: Enable (bottom-right aligned) causal masking.
        fwd_params: Tile sizes / warps / stages; unset fields use defaults.
        bwd_params: Accepted but unused here.
        cum_seqlens_q, cum_seqlens_k: int32 cumulative lengths (VARLEN path).
        sliding_window: ``int`` (left extent; right extent too unless causal)
            or ``(left, right)``.
        logits_soft_cap: Optional tanh soft-cap.
        softmax_aux: Sink logits, ``[num_sinks]`` (shared by all heads) or
            ``[num_heads, num_sinks]``; at most 16 sinks.
        q_segment_ids, kv_segment_ids: Optional int segment ids; a missing one
            defaults to the other.

    Returns:
        ``(out, lse)``: ``out`` has the shape/dtype of ``q`` and ``lse`` is
        float32 ``[batch, num_heads_q, max_seqlen_q_rounded]`` (base-2 units).

    Raises:
        ValueError: Malformed ``softmax_aux``/segment ids, too many sinks, or a
            bias on the VARLEN path.
        NotImplementedError: Segment ids combined with ``cum_seqlens``.
        AssertionError: Dtype/shape/head-count mismatches.
    """
    if sliding_window is None:
        window_left = 0
        window_right = 0
        sliding_flag = False
    else:
        if isinstance(sliding_window, int):
            window_left = int(sliding_window)
            window_right = 0 if causal else int(sliding_window)
        else:
            wl, wr = sliding_window
            window_left = int(wl)
            window_right = int(wr)
        assert window_left >= 0 and window_right >= 0
        sliding_flag = (window_left > 0) or (window_right > 0)

    if logits_soft_cap is None:
        logits_soft_cap_val = 0.0
        softcap_flag = False
    else:
        logits_soft_cap_val = float(logits_soft_cap)
        softcap_flag = True

    if softmax_aux is None:
        use_sinks = False
        num_sinks_val = 0
        softmax_aux_tensor = jnp.zeros((1,), dtype=q.dtype)
    else:
        use_sinks = True
        if softmax_aux.ndim == 1:
            num_sinks_val = softmax_aux.shape[0]
            num_heads = q.shape[2]
            softmax_aux_tensor = jnp.broadcast_to(softmax_aux[None, :], (num_heads, num_sinks_val))
        elif softmax_aux.ndim == 2:
            num_sinks_val = softmax_aux.shape[1]
            softmax_aux_tensor = softmax_aux
        else:
            raise ValueError(f"softmax_aux must be 1D or 2D, got shape {softmax_aux.shape}")
        # The kernel loads sinks with `tl.arange(0, 16)`; extra sinks would be
        # silently dropped.
        if num_sinks_val > 16:
            raise ValueError(f"At most 16 attention sinks are supported, got {num_sinks_val}.")

    use_segments = (q_segment_ids is not None) or (kv_segment_ids is not None)
    if use_segments:
        if q_segment_ids is None:
            q_segment_ids = kv_segment_ids
        if kv_segment_ids is None:
            kv_segment_ids = q_segment_ids
        q_segment_ids = jnp.asarray(q_segment_ids, dtype=jnp.int32)
        kv_segment_ids = jnp.asarray(kv_segment_ids, dtype=jnp.int32)
        if q_segment_ids.ndim != 2 or kv_segment_ids.ndim != 2:
            raise ValueError("q_segment_ids/kv_segment_ids must be 2D int32 arrays.")
        if q_segment_ids.shape[0] != q.shape[0] or q_segment_ids.shape[1] != q.shape[1]:
            raise ValueError("q_segment_ids must have shape [batch, seq_len_q].")
        if kv_segment_ids.shape[0] != k.shape[0] or kv_segment_ids.shape[1] != k.shape[1]:
            raise ValueError("kv_segment_ids must have shape [batch, seq_len_k].")
        qsz, qsm = get_strides(q_segment_ids.shape)
        ksz, ksn = get_strides(kv_segment_ids.shape)
    else:
        q_segment_ids = jnp.zeros((1,), dtype=jnp.int32)
        kv_segment_ids = jnp.zeros((1,), dtype=jnp.int32)
        qsz = qsm = ksz = ksn = 0

    if fwd_params is None:
        fwd_params = FwdParams()
    # Tile sizes must be powers of two (tl.arange) and at least 16 (tl.dot);
    # plain min(128, seqlen) broke for e.g. seqlen=100 or seqlen=8.
    default_block_m = max(16, min(128, triton.next_power_of_2(int(q.shape[1]))))
    default_block_n = max(16, min(128, triton.next_power_of_2(int(k.shape[1]))))
    block_m = default_block_m if fwd_params.q_blocksize is None else int(fwd_params.q_blocksize)
    block_n = default_block_n if fwd_params.kv_blocksize is None else int(fwd_params.kv_blocksize)
    num_warps = 4 if fwd_params.num_warps is None else int(fwd_params.num_warps)
    num_stages = 2 if fwd_params.num_stages is None else int(fwd_params.num_stages)

    varlen_from_cu = (cum_seqlens_q is not None) and (cum_seqlens_k is not None)
    if varlen_from_cu:
        if use_segments:
            raise NotImplementedError("segment_ids are not supported with cum_seqlens in triton flash-attention.")
        assert cum_seqlens_q.dtype == jnp.int32 and cum_seqlens_k.dtype == jnp.int32
        batch = q.shape[0]
        QSeq_max = int(q.shape[1])
        KSeq_max = int(k.shape[1])
        nheads_q = q.shape[2]
        nheads_kv = k.shape[2]
        head_dim = q.shape[3]
        assert nheads_q % nheads_kv == 0
        assert q.dtype == k.dtype == v.dtype
        assert q.dtype in [jnp.float16, jnp.bfloat16]

        max_seqlen_q = QSeq_max
        max_seqlen_k = KSeq_max
        q_packed = attention_pack_from_cu_static(q, cum_seqlens_q, max_tokens=batch * QSeq_max)
        k_packed = attention_pack_from_cu_static(k, cum_seqlens_k, max_tokens=batch * KSeq_max)
        v_packed = attention_pack_from_cu_static(v, cum_seqlens_k, max_tokens=batch * KSeq_max)

        qz, qm, qh, _ = get_strides(q_packed.shape)
        kz, kn, kh, _ = get_strides(k_packed.shape)
        vz, vn, vh, _ = get_strides(v_packed.shape)
        oz, om, oh, _ = get_strides(q_packed.shape)

        if bias is not None:
            raise ValueError("Bias with VARLEN requires a packed bias; pass None or pre-pack bias.")

        softmax_scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
        BOOL_BIAS = False
        max_seqlen_q_rounded = math.ceil(max_seqlen_q / 128) * 128
        BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
        PADDED_HEADS = BLOCK_HEADDIM > head_dim
        num_repeats = nheads_q // nheads_kv

        metaparams = dict(
            VARLEN=True,
            USE_DROPOUT=(dropout_prob > 0),
            IS_CAUSAL=causal,
            BIAS_ON=False,
            SLIDING=sliding_flag,
            SOFTCAP=softcap_flag,
            USE_SINKS=use_sinks,
            BOOL_BIAS=BOOL_BIAS,
            USE_SEGMENTS=False,
            BLOCK_HEADDIM=BLOCK_HEADDIM,
            PADDED_HEADS=PADDED_HEADS,
            BLOCK_N=block_n,
            BLOCK_M=block_m,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        out_shape = [
            jax.ShapeDtypeStruct(q_packed.shape, q_packed.dtype),
            jax.ShapeDtypeStruct((batch, nheads_q, max_seqlen_q_rounded), jnp.float32),
        ]
        out, lse = triton_call(
            q_packed,
            k_packed,
            v_packed,
            q_segment_ids,
            kv_segment_ids,
            jnp.zeros((1,), q.dtype),
            softmax_scale,
            dropout_prob,
            dropout_seed if dropout_seed is not None else jnp.zeros((1,), q.dtype),
            logits_soft_cap_val,
            softmax_aux_tensor,
            num_sinks_val,
            qz,
            qm,
            qh,
            kz,
            kn,
            kh,
            vz,
            vn,
            vh,
            qsz,
            qsm,
            ksz,
            ksn,
            oz,
            om,
            oh,
            0,
            0,
            0,
            nheads_q,
            num_repeats,
            window_left,
            window_right,
            max_seqlen_q,
            cum_seqlens_q,
            max_seqlen_k,
            cum_seqlens_k,
            max_seqlen_q_rounded,
            head_dim,
            max_seqlen_q // 128,
            max_seqlen_k // 128,
            dtype_index(q_packed),
            kernel=_attn_fwd,
            out_shape=out_shape,
            grid=lambda META: (triton.cdiv(max_seqlen_q, META["BLOCK_M"]), batch * nheads_q),
            name="ejkernel::triton::flash_attn_fwd_varlen",
            **metaparams,
        )
        out_unpacked = attention_unpack_with_static_shape(out, cum_seqlens_q, batch, QSeq_max)
        return out_unpacked, lse

    # NOTE(review): varlen_from_cu already returned above, so this branch is
    # unreachable and varlen_mode is always False; attention_mask is therefore
    # always consumed as a boolean bias below.
    if attention_mask is not None and varlen_from_cu:
        varlen_mode = attention_mask.shape[0] > 1
        assert bias is None, "Attention mask is not supported along with attention bias. Just use bias instead."
        assert q.shape[1] == k.shape[1], "Attention mask is not supported with QSeq != KSeq"
    else:
        varlen_mode = False

    batch, QSeq, nheads_q, head_dim = q.shape
    _, KSeq, nheads_kv, _ = k.shape
    expected_kv_shape = (batch, KSeq, nheads_kv, head_dim)

    assert k.shape == expected_kv_shape
    assert v.shape == expected_kv_shape
    assert nheads_q % nheads_kv == 0
    assert q.dtype == k.dtype == v.dtype
    assert q.dtype in [jnp.float16, jnp.bfloat16]

    softmax_scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale

    BOOL_BIAS = False
    varlen_mode = varlen_mode and (batch > 1)
    if not varlen_mode and attention_mask is not None:
        assert bias is None, "when using attention mask (bool) you can't use bias"
        BOOL_BIAS = True
        bias = attention_mask.astype(jnp.bool_)

    if varlen_mode:
        cum_seqlens_q = jnp.zeros(shape=(attention_mask.shape[0] + 1,), dtype=jnp.int32)
        cum_seqlens_q = cum_seqlens_q.at[1:].set(jnp.cumsum(attention_mask.sum(axis=1, dtype="i4"), axis=0, dtype="i4"))
        max_seqlen_q = attention_mask.shape[1]
        max_seqlen_k = attention_mask.shape[1]
        q = attention_pack_with_static_shape(q, attention_mask)
        k = attention_pack_with_static_shape(k, attention_mask)
        v = attention_pack_with_static_shape(v, attention_mask)
        QSeq = q.shape[1]
    else:
        cum_seqlens_q = None
        max_seqlen_q = QSeq
        max_seqlen_k = KSeq

    bz, bh, bm = calc_bias_strides(bias, batch, nheads_q, QSeq, KSeq)

    max_seqlen_q_rounded = math.ceil(max_seqlen_q / 128) * 128

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
    PADDED_HEADS = BLOCK_HEADDIM > head_dim
    num_repeats = nheads_q // nheads_kv

    qz, qm, qh, _ = get_strides(q.shape)
    oz, om, oh, _ = get_strides(q.shape)
    kz, kn, kh, _ = get_strides(k.shape)
    vz, vn, vh, _ = get_strides(v.shape)

    metaparams = dict(
        VARLEN=varlen_mode,
        USE_DROPOUT=(dropout_prob > 0),
        IS_CAUSAL=causal,
        BIAS_ON=(bias is not None),
        SLIDING=sliding_flag,
        SOFTCAP=softcap_flag,
        USE_SINKS=use_sinks,
        BOOL_BIAS=BOOL_BIAS,
        USE_SEGMENTS=use_segments,
        BLOCK_HEADDIM=BLOCK_HEADDIM,
        PADDED_HEADS=PADDED_HEADS,
        BLOCK_N=block_n,
        BLOCK_M=block_m,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    out_shape = [
        jax.ShapeDtypeStruct(q.shape, q.dtype),
        jax.ShapeDtypeStruct((batch, nheads_q, max_seqlen_q_rounded), jnp.float32),
    ]

    out, lse = triton_call(
        q,
        k,
        v,
        q_segment_ids,
        kv_segment_ids,
        bias if bias is not None else jnp.zeros((1,), q.dtype),
        softmax_scale,
        dropout_prob,
        dropout_seed if dropout_seed is not None else jnp.zeros((1,), q.dtype),
        logits_soft_cap_val,
        softmax_aux_tensor,
        num_sinks_val,
        qz,
        qm,
        qh,
        kz,
        kn,
        kh,
        vz,
        vn,
        vh,
        qsz,
        qsm,
        ksz,
        ksn,
        oz,
        om,
        oh,
        bz,
        bm,
        bh,
        nheads_q,
        num_repeats,
        window_left,
        window_right,
        QSeq,
        cum_seqlens_q if cum_seqlens_q is not None else jnp.zeros((1,), jnp.int32),
        KSeq,
        jnp.zeros((1,), jnp.int32),
        max_seqlen_q_rounded,
        head_dim,
        max_seqlen_q // 128,
        max_seqlen_k // 128,
        dtype_index(q),
        kernel=_attn_fwd,
        out_shape=out_shape,
        grid=lambda META: (triton.cdiv(max_seqlen_q, META["BLOCK_M"]), batch * nheads_q),
        name="ejkernel::triton::flash_attn_fwd",
        **metaparams,
    )

    if varlen_mode:
        out = attention_unpack_with_static_shape(out, cum_seqlens_q, *attention_mask.shape)
    return out, lse