# Source code for ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Implementation of Sparse Flash Attention, a.k.a. "Splash" attention."""

from __future__ import annotations

import dataclasses
import enum
import functools
import typing
from collections.abc import Mapping
from typing import Any, Literal, NamedTuple, Union, overload

import jax
import jax.numpy as jnp
import jaxtyping
import numpy as np
from beartype import beartype
from beartype.typing import Callable
from jax import ad_checkpoint, lax, tree_util
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jaxtyping import Array, Bool, Float, Int

from ejkernel.callib import ejit
from ejkernel.kernels._registry import Backend, Platform, kernel_registry
from ejkernel.ops import BwdParams, FwdParams

from . import _info as mask_info_lib
from . import _masks as mask_lib
from ._masks import (
    CausalMask,
    ChunkedCausalMask,
    FullMask,
    LocalMask,
    Mask,
    MultiHeadMask,
)

if typing.TYPE_CHECKING:
    from ejkernel.kernels._triton.blocksparse_attention._mask import SparseMask

# Short alias so the rest of the module can write `partial(...)` directly.
partial = functools.partial
# Fill value written into masked-out logits: -0.7 times the largest finite float32.
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)
# Tile widths used by the kernels for shape assertions and `pltpu.repeat` counts
# (presumably the TPU vector lane / sublane counts -- TODO confirm).
NUM_LANES = 128
NUM_SUBLANES = 8

# `lax.dot_general` dimension numbers with no batch dimensions:
# NN contracts lhs dim 1 with rhs dim 0; NT contracts lhs dim 1 with rhs dim 1.
NN_DIM_NUMBERS = (((1,), (0,)), ((), ()))
NT_DIM_NUMBERS = (((1,), (1,)), ((), ()))

# mypy: ignore-errors


class SegmentIds(NamedTuple):
    """Per-token segment ids for the Q and KV sequences.

    Segment ids prevent cross-attention between segments (fractions of a
    sequence) that were concatenated together into one sequence. Each array
    holds one integer id per token, and only tokens carrying the same id are
    allowed to attend to each other.

    The static mask (e.g. causal) is AND-ed with the segment-id mask to form
    the effective attention mask. The effective mask must not contain any
    all-zero row along the kv dimension, otherwise the softmax is invalid (its
    denominator would be 0). Causal self-attention satisfies this because the
    segment ids then form a block-diagonal matrix, so every row keeps at least
    one set element. Non-self-attention configurations can easily break it.

    Attributes:
      q: segment ids along the Q sequence
      kv: segment ids along the KV sequence
    """

    q: jax.Array
    kv: jax.Array
SplashCustomReturnType = Union[jax.Array, tuple[jax.Array, tuple[jax.Array,]]] # noqa SplashResidualsType = tuple[ jax.Array, jax.Array, jax.Array, SegmentIds | None, jax.Array | None, jax.Array, jax.Array, mask_info_lib.MaskInfo | None, mask_info_lib.MaskInfo | None, ] MaskFunctionType = Callable[..., jax.Array]
def get_kernel_name(
    block_metadata: Mapping[str, Any],
    is_mqa: bool,
    save_residuals: bool,
    is_segmented: bool,
    phase: str,
) -> str:
    """Build the name that identifies one SplashAttention kernel variant.

    Args:
      block_metadata: extra key/value pairs, appended as `key=value` entries in
        sorted order.
      is_mqa: selects the `mqa` tag when true, `mha` otherwise.
      save_residuals: whether the kernel emits residuals; only allowed when
        `phase` is `"fwd"`.
      is_segmented: adds the `_segmented` tag when true.
      phase: one of `"fwd"`, `"dq"` or `"dkv"`.

    Returns:
      The assembled kernel name.
    """
    assert phase in ("dq", "dkv", "fwd")
    assert phase == "fwd" or not save_residuals

    # Only the forward phase distinguishes between saving and not saving residuals.
    if save_residuals:
        residuals_tag = "_residuals"
    elif phase == "fwd":
        residuals_tag = "_no_residuals"
    else:
        residuals_tag = ""

    head_tag = "mqa" if is_mqa else "mha"
    segment_tag = "_segmented" if is_segmented else ""
    metadata_tag = "_".join(f"{key}={value}" for key, value in sorted(block_metadata.items()))
    return f"splash_{head_tag}_{phase}{segment_tag}{residuals_tag}_" + metadata_tag
# NOTE: the overloads list `mask_value` before `save_residuals`, matching the
# implementation below, so that positional calls are typed correctly.
@overload
def _attention_reference(
    mask: jax.Array,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: SegmentIds | None,
    sinks: jax.Array | None,
    mask_value: float,
    save_residuals: Literal[False],
    custom_type: str,
    logits_soft_cap: float | None,
) -> jax.Array: ...


@overload
def _attention_reference(
    mask: jax.Array,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: SegmentIds | None,
    sinks: jax.Array | None,
    mask_value: float,
    save_residuals: Literal[True],
    custom_type: str,
    logits_soft_cap: float | None,
) -> tuple[jax.Array, tuple[jax.Array]]: ...


def _attention_reference(
    mask: jax.Array,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: SegmentIds | None,
    sinks: jax.Array | None,
    mask_value: float,
    save_residuals: bool,
    custom_type: str,
    logits_soft_cap: float | None,
):
    """Single-head reference attention; forwards to `_attention_reference_default`.

    Returns:
      The attention output, or `(output, (logsumexp,))` when `save_residuals`
      is true.
    """
    return _attention_reference_default(
        mask,
        q,
        k,
        v,
        segment_ids,
        sinks,
        mask_value,
        save_residuals,
        custom_type,
        logits_soft_cap,
    )


def _attention_reference_default(
    mask: jax.Array,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: SegmentIds | None,
    sinks: jax.Array | None,
    mask_value: float,
    save_residuals: bool,
    custom_type: str,
    logits_soft_cap: float | None,
):
    """Plain (non-kernel) masked softmax attention for one head, computed in float32.

    Args:
      mask: boolean mask indexed `[q_position, kv_position]`; true entries are kept.
      q: query array, contracted with `k` as `"sd,td->st"`.
      k: key array.
      v: value array, combined with the probabilities as `"st,td->sd"`.
      segment_ids: optional segment ids; when given, positions whose q and kv
        ids differ are masked out in addition to `mask`.
      sinks: optional extra logit that takes part in the softmax normalisation
        (max and denominator) but contributes nothing to the output.
      mask_value: value written into masked-out logits.
      save_residuals: when true, also return the log-sum-exp of each row.
      custom_type: unused by the forward computation.
      logits_soft_cap: when not None, logits are replaced by
        `tanh(logits / cap) * cap` before masking.

    Returns:
      The float32 attention output, or `(output, (logsumexp,))` when
      `save_residuals` is true.
    """
    del custom_type  # Unused by the forward computation.

    logits = jnp.einsum("sd,td->st", q.astype(jnp.float32), k.astype(jnp.float32))

    if segment_ids is not None:
        mask = jnp.logical_and(mask, segment_ids.q[:, None] == segment_ids.kv[None, :])

    if logits_soft_cap is not None:
        logits = jnp.tanh(logits / logits_soft_cap)
        logits = logits * logits_soft_cap

    logits = jnp.where(mask, logits, mask_value)

    # Numerically stable softmax: subtract the row maximum (including the sink
    # logit, if any) before exponentiating.
    m = logits.max(axis=-1)
    sinks = None if sinks is None else sinks.astype(logits.dtype)
    m = m if sinks is None else jnp.maximum(m, sinks)
    s = jnp.exp(logits - m[..., None])
    # The sink only enlarges the denominator; it has no value row of its own.
    denom = s.sum(axis=-1) + (0 if sinks is None else jnp.exp(sinks - m))
    s = s / denom[..., None]

    o = jnp.einsum("st,td->sd", s, v.astype(jnp.float32))

    lse = m + jnp.log(denom)
    if save_residuals:
        return o, (lse,)
    return o
def attention_reference(
    mask: jax.Array,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: SegmentIds | None,
    sinks: jax.Array | None = None,
    *,
    mask_value: float = DEFAULT_MASK_VALUE,
    save_residuals: bool = False,
    custom_type: str = "flash",
    logits_soft_cap: float | None = None,
) -> SplashCustomReturnType:
    """Public entry point for the reference attention.

    Forwards every argument unchanged to `_attention_reference`; the options
    after `sinks` are keyword-only here.

    Returns:
      Whatever `_attention_reference` returns for the given `save_residuals`.
    """
    options = dict(
        mask_value=mask_value,
        save_residuals=save_residuals,
        custom_type=custom_type,
        logits_soft_cap=logits_soft_cap,
    )
    return _attention_reference(mask, q, k, v, segment_ids, sinks, **options)
def _attention_reference_custom_fwd(
    mask: jax.Array,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: SegmentIds | None,
    sinks: jax.Array | None,
    mask_value: float,
    save_residuals: bool,
    custom_type: str,
    logits_soft_cap: float | None,
):
    """Forward rule of the custom VJP: run the reference and stash what the backward needs.

    Raises:
      NotImplementedError: if `save_residuals` is requested.
    """
    if save_residuals:
        raise NotImplementedError("Higher-order AD not supported.")

    # Always ask the reference for the log-sum-exp; the backward rule uses it
    # to rebuild the probabilities.
    out, (logsumexp,) = _attention_reference(
        mask,
        q,
        k,
        v,
        segment_ids,
        sinks,
        mask_value=mask_value,
        save_residuals=True,
        custom_type=custom_type,
        logits_soft_cap=logits_soft_cap,
    )
    residuals = (mask, q, k, v, segment_ids, sinks, out, logsumexp)
    return out, residuals


def _attention_reference_custom_bwd(
    mask_value: float,
    save_residuals: bool,
    custom_type: str,
    logits_soft_cap: float | None,
    res,
    do: jax.Array,
) -> tuple[None, jax.Array, jax.Array, jax.Array, None, jax.Array | None]:
    """Backward rule of the custom VJP.

    Args:
      mask_value: value written into masked-out logits.
      save_residuals: unused.
      custom_type: `"flash"` derives the per-row correction term from
        `o * do`; any other value derives it from `dp * p`.
      logits_soft_cap: soft-cap applied in the forward pass, or None.
      res: residuals produced by `_attention_reference_custom_fwd`.
      do: cotangent of the attention output.

    Returns:
      Cotangents for `(mask, q, k, v, segment_ids, sinks)`; the entries for
      `mask` and `segment_ids` are None, and the `sinks` entry is None when no
      sinks were given.
    """
    del save_residuals
    mask, q, k, v, segment_ids, sinks, o, lse = res
    f32 = jnp.float32
    use_soft_cap = logits_soft_cap is not None

    # Recompute the (capped, masked) logits exactly as the forward pass did.
    uncapped_logits = jnp.einsum("qc,kc->qk", q, k, preferred_element_type=f32)
    logits = uncapped_logits
    if use_soft_cap:
        logits = jnp.tanh(uncapped_logits / logits_soft_cap)
        logits = logits * logits_soft_cap

    if segment_ids is not None:
        same_segment = segment_ids.q[:, None] == segment_ids.kv[None, :]
        mask = jnp.logical_and(mask, same_segment)
    logits = jnp.where(mask, logits, mask_value)

    probs = jnp.exp(logits - lse[..., None])
    grad_out = do.astype(f32)

    dv = jnp.einsum("pt,pd->td", probs, grad_out).astype(v.dtype)
    dprobs = jnp.einsum("pd,td->pt", grad_out, v.astype(f32))

    if custom_type == "flash":
        row_term = jnp.sum(o.astype(f32) * grad_out, axis=-1)[..., None]
    else:
        row_term = jnp.einsum("st,st->s", dprobs, probs)[:, None]
    dlogits = (dprobs - row_term) * probs

    if use_soft_cap:
        # Chain rule through tanh: multiply by (1 - t) and then by (1 + t),
        # i.e. by 1 - t**2.
        tanh_val = jnp.tanh(uncapped_logits / logits_soft_cap)
        scaled = dlogits * (1 - tanh_val)
        dlogits = scaled + scaled * tanh_val

    dq = jnp.einsum("st,td->sd", dlogits, k.astype(f32)).astype(q.dtype)
    dk = jnp.einsum("sd,st->td", q.astype(f32), dlogits).astype(k.dtype)

    if sinks is None:
        dsinks = None
    else:
        neg_sink_prob = -jnp.exp(sinks[..., None, None].astype(f32) - lse[..., None].astype(f32))
        dsinks = jnp.sum(neg_sink_prob.astype(o.dtype) * grad_out * o)

    return None, dq, dk, dv, None, dsinks


# Reference attention with the hand-written VJP above attached.
_attention_reference_custom = jax.custom_vjp(
    _attention_reference,
    nondiff_argnames=("mask_value", "save_residuals", "custom_type", "logits_soft_cap"),
)
_attention_reference_custom.defvjp(_attention_reference_custom_fwd, _attention_reference_custom_bwd)
def attention_reference_custom(
    mask: jax.Array,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: SegmentIds | None,
    sinks: jax.Array | None = None,
    *,
    mask_value: float = DEFAULT_MASK_VALUE,
    save_residuals: bool = False,
    custom_type: str = "flash",
    logits_soft_cap: float | None = None,
):
    """Public entry point for the reference attention with the custom VJP.

    Forwards every argument to `_attention_reference_custom`; the options
    after `sinks` are keyword-only here.
    """
    array_args = (mask, q, k, v, segment_ids, sinks)
    return _attention_reference_custom(
        *array_args,
        mask_value,
        save_residuals,
        custom_type=custom_type,
        logits_soft_cap=logits_soft_cap,
    )
def make_attention_reference(
    mask: mask_lib.Mask | np.ndarray,
    is_mqa: bool,
    backward_impl: str = "vanilla",
    **params: Any,
) -> Callable:
    """Build a jitted multi-head reference attention with `mask` pre-bound.

    Args:
      mask: mask with three dimensions; it is materialised with
        `jnp.array(mask[:, :, :])` and bound as the first argument.
      is_mqa: when true, `k` and `v` are shared across all query heads;
        otherwise they carry a leading kv-head axis.
      backward_impl: `"custom"` uses the custom-VJP reference with the
        `"flash"` backward, `"custom_vanilla"` uses it with the `"vanilla"`
        backward; any other value uses the plain reference.
      **params: extra keyword arguments forwarded to the chosen reference.

    Returns:
      A callable taking `(q, k, v, segment_ids=None, sinks=None, *,
      mask_value, save_residuals, logits_soft_cap)`.
    """
    custom_backward_types = {"custom": "flash", "custom_vanilla": "vanilla"}

    @partial(
        jax.jit,
        static_argnames=[
            "mask_value",
            "save_residuals",
            "logits_soft_cap",
        ],
    )
    def _wrapped(
        mask: jax.Array,
        q: jax.Array,
        k: jax.Array,
        v: jax.Array,
        segment_ids: SegmentIds | None = None,
        sinks: jax.Array | None = None,
        *,
        mask_value: float = DEFAULT_MASK_VALUE,
        save_residuals: bool = False,
        logits_soft_cap: float | None = None,
    ):
        if backward_impl in custom_backward_types:
            attn_impl = partial(
                attention_reference_custom,
                custom_type=custom_backward_types[backward_impl],
            )
        else:
            attn_impl = attention_reference

        func = partial(
            attn_impl,
            mask_value=mask_value,
            save_residuals=save_residuals,
            logits_soft_cap=logits_soft_cap,
            **params,
        )

        # vmap axes for (mask, q, k, v, segment_ids, sinks).
        shared_kv_axes = (0, 0, None, None, None, 0)
        per_kv_head_axes = (0, 0, 0, 0, None, 0)

        is_grouped = False
        if is_mqa:
            func = jax.vmap(func, in_axes=shared_kv_axes)
        else:
            kv_heads = k.shape[0]
            assert kv_heads == v.shape[0]
            q_heads, q_seq_len, head_dim = q.shape
            is_grouped = kv_heads < q_heads
            if is_grouped:
                assert q_heads % kv_heads == 0
                assert mask.shape[0] == q_heads
                q_heads_per_kv_head = q_heads // kv_heads
                # Split the q-head axis into (kv_heads, q_heads_per_kv_head).
                q = q.reshape((kv_heads, q_heads_per_kv_head, q_seq_len, head_dim))
                mask = mask.reshape((kv_heads, q_heads_per_kv_head, *mask.shape[1:]))
                if sinks is not None:
                    sinks = sinks.reshape((kv_heads, q_heads_per_kv_head))
                # Inner vmap: the q heads that share one kv head.
                func = jax.vmap(func, in_axes=shared_kv_axes)
            # Outer vmap: the kv heads.
            func = jax.vmap(func, in_axes=per_kv_head_axes)

        out = func(mask, q, k, v, segment_ids, sinks)
        if not is_grouped:
            return out

        def reshape_activations(activations):
            # Merge (kv_heads, q_heads_per_kv_head) back into one head axis.
            if activations.ndim != 4:
                return activations
            n_kv, n_q_per_kv, seq_len, dim = activations.shape
            return activations.reshape(n_kv * n_q_per_kv, seq_len, dim)

        def reshape_residuals(residuals):
            if residuals.ndim != 3:
                return residuals
            n_kv, n_q_per_kv, seq_len = residuals.shape
            return residuals.reshape(n_kv * n_q_per_kv, seq_len)

        if not save_residuals:
            return reshape_activations(out)

        assert isinstance(out, tuple)
        assert isinstance(out[1], tuple)
        return (reshape_activations(out[0]), (reshape_residuals(out[1][0]),))

    return functools.partial(_wrapped, jnp.array(mask[:, :, :]))
# Reference-attention factories with `is_mqa` pre-selected:
# `is_mqa=False` for the "mha" variant and `is_mqa=True` for the "mqa" variant.
make_masked_mha_reference = partial(make_attention_reference, is_mqa=False)
make_masked_mqa_reference = partial(make_attention_reference, is_mqa=True)
class QKVLayout(enum.IntEnum):
    """Physical layout of a Q/K/V operand: which of its last two axes is minormost."""

    # The head dimension is the last axis.
    HEAD_DIM_MINOR = enum.auto()
    # The last two axes are swapped relative to HEAD_DIM_MINOR.
    SEQ_MINOR = enum.auto()
def from_head_minor(vals: tuple[Any, ...], layout: QKVLayout):
    """Convert a head-dim-minor tuple of values to the given layout.

    Args:
      vals: values ordered for the `HEAD_DIM_MINOR` layout.
      layout: the target layout.

    Returns:
      `vals` itself for `HEAD_DIM_MINOR`; otherwise a new tuple with the last
      two entries swapped.
    """
    if layout != QKVLayout.HEAD_DIM_MINOR:
        leading = vals[:-2]
        return (*leading, vals[-1], vals[-2])
    return vals
@dataclasses.dataclass(frozen=True, slots=True)
class BlockSizes:
    """Tile sizes parameterizing SplashAttention kernels.

    Those parameters have negligible effect on numerics, but affect performance
    greatly.

    Note that changing the layouts only influences the physical layout that the
    kernel will enforce. The logical interface to blocksparse attention always
    takes the head dimension as the minormost one.
    """

    # Forward-pass tile sizes along the Q and KV sequences.
    block_q: int
    block_kv: int
    # Defaults to `block_kv` when left as None.
    block_kv_compute: int | None = None
    # Tile sizes with the `_dkv` suffix (presumably for the dK/dV backward kernel -- TODO confirm).
    block_q_dkv: int | None = None
    block_kv_dkv: int | None = None
    # Defaults to `block_kv_dkv` when left as None.
    block_kv_dkv_compute: int | None = None
    # Tile sizes for the dq kernel; must stay None when the fused backward kernel is used.
    block_q_dq: int | None = None
    block_kv_dq: int | None = None
    use_fused_bwd_kernel: bool = False
    q_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR
    k_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR
    v_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR

    def __post_init__(self):
        """Fill in derived defaults and validate the fused-kernel configuration.

        Raises:
          ValueError: if dq block sizes are given together with
            `use_fused_bwd_kernel=True`.
        """
        # The dataclass is frozen, so defaults are written via object.__setattr__.
        if self.block_kv_compute is None:
            object.__setattr__(self, "block_kv_compute", self.block_kv)
        if self.block_kv_dkv_compute is None:
            object.__setattr__(self, "block_kv_dkv_compute", self.block_kv_dkv)
        if self.use_fused_bwd_kernel:
            if self.block_q_dq is not None or self.block_kv_dq is not None:
                raise ValueError("Block sizes for dq kernel are not needed with a fused kernel.")

    @property
    def has_backward_blocks(self) -> bool:
        """Whether every block size required by the backward pass is set."""
        backward_blocks = (
            self.block_q_dkv,
            self.block_kv_dkv,
            self.block_kv_dkv_compute,
        )
        # The dq block sizes are only required when the backward is not fused.
        if not self.use_fused_bwd_kernel:
            backward_blocks += (self.block_q_dq, self.block_kv_dq)
        return all(b is not None for b in backward_blocks)

    @classmethod
    def get_default(cls):
        """Return an instance of `cls` with every block size set to 128."""
        # Construct through `cls` so subclasses get an instance of their own type.
        return cls(
            block_q=128,
            block_kv=128,
            block_kv_compute=128,
            block_q_dkv=128,
            block_kv_dkv=128,
            block_kv_dkv_compute=128,
            block_q_dq=128,
            block_kv_dq=128,
        )
def _next_nonzero(
    h,
    i,
    j,
    data_next_ref,
    block_mask_ref,
    m_next_ref,
    next_i=False,
):
    """Look up the block-sparsity metadata for grid position `(h, i, j)`.

    Args:
      h, i, j: indices into the metadata refs.
      data_next_ref: ref holding the data-block index to use, or None.
      block_mask_ref: ref holding the block-mask value, or None; it must be
        None exactly when `data_next_ref` is None.
      m_next_ref: optional ref holding the mask-block index.
      next_i: when no metadata is given, return `i` instead of `j` as the
        data index.

    Returns:
      `(next_data_index, next_mask_index, is_nonzero, should_not_mask)`.
      Without metadata this is `(i or j, None, True, False)`.
    """
    assert (data_next_ref is None) == (block_mask_ref is None)

    if data_next_ref is None and block_mask_ref is None:
        # No sparsity metadata: every block runs and is masked in-kernel.
        assert m_next_ref is None
        next_data = i if next_i else j
        return (
            next_data,
            None,
            True,
            False,
        )

    assert data_next_ref.shape == block_mask_ref.shape
    assert m_next_ref is None or data_next_ref.shape[0] == m_next_ref.shape[0]

    # A leading dimension of 1 means the metadata is shared by all heads.
    if data_next_ref.shape[0] == 1:
        h = 0

    def to_i32(x):
        return x.astype(jnp.int32)

    block_mask = to_i32(block_mask_ref[h, i, j])
    is_nonzero = block_mask > 0
    if m_next_ref is None:
        should_not_mask = True
        next_m = None
    else:
        # Only block-mask value 1 requires the element-wise mask to be applied.
        should_not_mask = block_mask != 1
        next_m = to_i32(m_next_ref[h, i, j])
    next_j = to_i32(data_next_ref[h, i, j])
    return next_j, next_m, is_nonzero, should_not_mask


def _apply_mask_and_soft_cap(
    qk: jax.Array,
    mask_value: float,
    should_not_mask,
    mask_ref,
    q_sequence_ref,
    q_segment_ids_ref,
    kv_segment_ids_ref,
    *,
    logits_soft_cap: float | None,
    k_slice: pl.Slice,
    k_offset: int | jax.Array,
    bq: int,
    k_in_lanes=True,
    mask_function=None,
) -> jax.Array:
    """Soft-cap the logits and overwrite masked-out entries with `mask_value`.

    Up to three masks are combined with logical AND: the materialised mask in
    `mask_ref`, the mask computed by `mask_function` from q/k positions, and
    the segment-id equality mask.

    Args:
      qk: logits for the current tile.
      mask_value: value written where the combined mask is false.
      should_not_mask: OR-ed into the materialised mask, so a true value
        disables it.
      mask_ref: materialised mask, or None. Mutually exclusive with
        `q_sequence_ref`.
      q_sequence_ref: q positions fed to `mask_function`; must be given exactly
        when `mask_function` is given.
      q_segment_ids_ref: q segment ids, or None to skip segment masking.
      kv_segment_ids_ref: kv segment ids; read only when `q_segment_ids_ref`
        is given.
      logits_soft_cap: when not None, logits become `tanh(qk / cap) * cap`.
      k_slice: slice selecting the kv positions of this tile inside the refs.
      k_offset: position of the first kv element of this tile, added to the
        iota passed to `mask_function`.
      bq: size of the q tile.
      k_in_lanes: true when kv positions run along the last axis of `qk`.
      mask_function: callable `(q_positions, k_positions) -> bool array`.

    Returns:
      The capped and masked logits, same shape as `qk`.

    Raises:
      ValueError: if `mask_function` returns a non-boolean array.
      NotImplementedError: if a tile size is not a multiple of `NUM_LANES` on
        the segment-id path.
    """
    assert mask_ref is None or q_sequence_ref is None
    assert (q_sequence_ref is None) == (mask_function is None)

    masks = []
    if mask_ref is not None:
        if k_in_lanes:
            mask = mask_ref[:, k_slice]
        else:
            mask = mask_ref[k_slice, :]
        masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(should_not_mask, mask.shape)))

    if mask_function is not None:
        # Build q/k position grids with the same shape as the tile.
        if k_in_lanes:
            assert q_sequence_ref.shape == (bq, NUM_LANES)
            k_sequence = k_offset + jax.lax.broadcasted_iota(jnp.int32, (bq, k_slice.size), 1)
            repeats, rem = divmod(k_slice.size, NUM_LANES)
            assert rem == 0
            q_sequence = pltpu.repeat(q_sequence_ref[...], repeats, axis=1)
        else:
            assert q_sequence_ref.shape == (NUM_SUBLANES, bq)
            k_sequence = k_offset + jax.lax.broadcasted_iota(jnp.int32, (k_slice.size, bq), 0)
            q_sequence = q_sequence_ref[:1, :]
            q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))

        assert q_sequence.shape == k_sequence.shape
        computed_mask = mask_function(q_sequence, k_sequence)
        if computed_mask.dtype != jnp.dtype(jnp.bool_):
            raise ValueError(f"Mask function must return a boolean-valued array, but got: {computed_mask.dtype}")
        masks.append(computed_mask)

    if q_segment_ids_ref is not None:
        if k_in_lanes:
            kv_ids = kv_segment_ids_ref[:1, k_slice]
            repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
            if rem:
                raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
            q_ids = pltpu.repeat(q_segment_ids_ref[:], repeats, axis=1)
        else:
            assert bq == q_segment_ids_ref.shape[-1]
            repeats, rem = divmod(bq, NUM_LANES)
            if rem:
                raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
            kv_ids = pltpu.repeat(kv_segment_ids_ref[k_slice, :], repeats, axis=1)
            q_ids = q_segment_ids_ref[:1, :]
        masks.append(q_ids == kv_ids)

    def cap_logits(logits):
        # Operates on its argument rather than on the enclosing `qk`.
        if logits_soft_cap is not None:
            logits = jnp.tanh(logits / logits_soft_cap)
            return logits * logits_soft_cap
        return logits

    if masks:
        mask = functools.reduce(jnp.logical_and, masks)
        qk = cap_logits(qk)
        qk = jnp.where(mask, qk, mask_value)
    else:
        qk = cap_logits(qk)
    return qk
def flash_attention_kernel(
    data_next_ref,
    block_mask_ref,
    mask_next_ref,
    q_ref,
    k_ref,
    v_ref,
    q_segment_ids_ref,
    kv_segment_ids_ref,
    sinks_ref,
    mask_ref,
    q_sequence_ref,
    m_scratch_ref,
    l_scratch_ref,
    o_scratch_ref,
    o_ref,
    logsumexp_ref=None,
    *,
    mask_value: float,
    grid_width: int,
    bq: int,
    bkv: int,
    bkv_compute: int,
    head_dim_v: int,
    q_layout: QKVLayout,
    k_layout: QKVLayout,
    v_layout: QKVLayout,
    logits_soft_cap: float | None,
    mask_function: MaskFunctionType | None,
):
    """Forward attention kernel body using a running-max / running-sum softmax.

    The kernel reads three program ids `(h, i, j)`. At `j == 0` the scratch
    accumulators are initialised, on every step whose block is marked non-zero
    the current K/V block is folded into them, and at `j == grid_width - 1`
    the normalised output (and optionally the log-sum-exp) is written out.

    Args:
      data_next_ref, block_mask_ref, mask_next_ref: block-sparsity metadata
        passed to `_next_nonzero`.
      q_ref, k_ref, v_ref: refs to the current Q, K and V blocks.
      q_segment_ids_ref, kv_segment_ids_ref: segment-id refs, or None.
      sinks_ref: optional per-head sink logits, indexed by `h`.
      mask_ref: optional materialised mask block.
      q_sequence_ref: optional q positions for `mask_function`.
      m_scratch_ref: running row maximum, shape `(bq, NUM_LANES)`.
      l_scratch_ref: running softmax denominator, shape `(bq, NUM_LANES)`.
      o_scratch_ref: running un-normalised output.
      o_ref: output block.
      logsumexp_ref: optional output for the log-sum-exp, shape `(bq, NUM_LANES)`.
      mask_value: value written into masked-out logits.
      grid_width: extent of the third grid axis.
      bq: q tile size.
      bkv: kv block size; must be a multiple of `bkv_compute`.
      bkv_compute: kv tile size per inner iteration; must be a multiple of
        `NUM_LANES`.
      head_dim_v: head dimension of V.
      q_layout, k_layout, v_layout: physical layouts of the Q/K/V blocks.
      logits_soft_cap: optional tanh soft-cap applied to the logits.
      mask_function: optional callable computing the mask from positions.

    Raises:
      NotImplementedError: if `bkv_compute` is not a multiple of `NUM_LANES`.
    """
    float32 = jnp.float32
    HEAD_DIM_MINOR = QKVLayout.HEAD_DIM_MINOR

    # Number of NUM_LANES-wide copies needed to cover the V head dimension.
    head_dim_v_repeats = pl.cdiv(head_dim_v, NUM_LANES)

    h, i, j = pl.program_id(0), pl.program_id(1), pl.program_id(2)

    @pl.when(j == 0)
    def init():
        o_scratch_ref[...] = jnp.zeros_like(o_scratch_ref)
        if sinks_ref is not None:
            # Start from the sink logit: max = sink, denominator = exp(sink - sink) = 1.
            sinks = sinks_ref[h].astype(m_scratch_ref.dtype)
            m_scratch_ref[...] = sinks * jnp.ones_like(m_scratch_ref)
            l_scratch_ref[...] = jnp.ones_like(l_scratch_ref)
        else:
            m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value)
            l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref)

    global_kv_index, _, should_run, should_not_mask = _next_nonzero(
        h,
        i,
        j,
        data_next_ref,
        block_mask_ref,
        mask_next_ref,
    )

    def body(kv_compute_index, _):
        # One inner iteration: fold a `bkv_compute`-wide slice of K/V into the accumulators.
        slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)
        m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...]
        assert m_prev.shape == (bq, NUM_LANES)
        assert l_prev.shape == (bq, NUM_LANES)

        q = q_ref[...] if q_layout == HEAD_DIM_MINOR else q_ref[...].T
        qk_dims = NT_DIM_NUMBERS if k_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS
        if k_layout == HEAD_DIM_MINOR:
            k = k_ref[slice_k, :]
        else:
            k = k_ref[:, slice_k]
        qk = lax.dot_general(q, k, qk_dims, preferred_element_type=float32)

        assert qk.shape == (bq, bkv_compute)
        apply_mask_and_soft_cap = functools.partial(
            _apply_mask_and_soft_cap,
            qk,
            mask_value,
            should_not_mask,
            mask_ref,
            q_sequence_ref,
            q_segment_ids_ref,
            kv_segment_ids_ref,
            logits_soft_cap=logits_soft_cap,
            k_slice=slice_k,
            k_offset=global_kv_index * bkv + kv_compute_index * bkv_compute,
            bq=bq,
            mask_function=mask_function,
        )

        qk = apply_mask_and_soft_cap()

        # New running maximum per row.
        m_curr = qk.max(axis=-1)[:, None]
        assert m_curr.shape == (bq, 1)
        m_next = jnp.maximum(m_prev, m_curr)
        assert m_next.shape == (bq, NUM_LANES)

        bkv_repeats, rem = divmod(bkv_compute, NUM_LANES)
        if rem != 0:
            raise NotImplementedError(f"{bkv_compute=} should be a multiple of {NUM_LANES}")

        # Exponentials relative to the new maximum.
        s_curr = jnp.exp(qk - pltpu.repeat(m_next, bkv_repeats, axis=1))
        assert s_curr.shape == (bq, bkv_compute)

        l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,))
        assert l_curr.shape == (bq, NUM_LANES)

        # Rescale the previous denominator to the new maximum before adding.
        alpha = jnp.exp(m_prev - m_next)
        l_next = l_curr + alpha * l_prev
        m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next

        sv_dims = NN_DIM_NUMBERS if v_layout == HEAD_DIM_MINOR else NT_DIM_NUMBERS
        if v_layout == HEAD_DIM_MINOR:
            v = v_ref[slice_k, :]
        else:
            v = v_ref[:, slice_k]
        v = v.astype(float32)
        o_curr = lax.dot_general(s_curr, v, sv_dims)

        # Rescale the previous output accumulator the same way, trimmed to its width.
        alpha_o = pltpu.repeat(alpha, head_dim_v_repeats, axis=1)[..., : o_scratch_ref.shape[-1]]
        o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr

    @pl.when(should_run)
    def run():
        assert bkv % bkv_compute == 0
        num_iters = k_ref.shape[0 if k_layout == HEAD_DIM_MINOR else 1] // bkv_compute
        lax.fori_loop(0, num_iters, body, None, unroll=True)

    @pl.when(j == grid_width - 1)
    def end():
        # Normalise the accumulated output by the softmax denominator.
        l = l_scratch_ref[...]
        l_inv = pltpu.repeat(1.0 / l, head_dim_v_repeats, axis=1)[..., : o_scratch_ref.shape[-1]]
        o_ref[...] = (o_scratch_ref[...] * l_inv).astype(o_ref.dtype)
        if logsumexp_ref is not None:
            assert logsumexp_ref.shape == (bq, NUM_LANES)
            logsumexp_ref[...] = (jnp.log(l) + m_scratch_ref[...]).astype(logsumexp_ref.dtype)

        # Clear the scratch accumulators after the results have been written.
        m_scratch_ref[...] = jnp.zeros_like(m_scratch_ref)
        l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref)
        o_scratch_ref[...] = jnp.zeros_like(o_scratch_ref)
@overload
def _splash_attention_forward(
    fwd_mask_info: mask_info_lib.MaskInfo,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: SegmentIds | None,
    sinks: jax.Array | None,
    mask_value: float,
    is_mqa: bool,
    block_sizes: BlockSizes,
    residual_checkpoint_name: str | None,
    save_residuals: Literal[False],
    mask_function: MaskFunctionType | None,
    logits_soft_cap: float | None = None,
    interpret: bool = False,
) -> jax.Array: ...


@overload
def _splash_attention_forward(
    fwd_mask_info: mask_info_lib.MaskInfo,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: SegmentIds | None,
    sinks: jax.Array | None,
    mask_value: float,
    is_mqa: bool,
    block_sizes: BlockSizes,
    residual_checkpoint_name: str | None,
    save_residuals: Literal[True],
    mask_function: MaskFunctionType | None,
    logits_soft_cap: float | None = None,
    interpret: bool = False,
) -> SplashCustomReturnType: ...


def _div(dividend: int, divisor: int):
    """Return ``dividend / divisor`` via ``lax.div``, skipping the op when ``divisor == 1``."""
    if divisor == 1:
        return dividend

    return lax.div(dividend, divisor)


def _splash_attention_forward(
    fwd_mask_info: mask_info_lib.MaskInfo,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: SegmentIds | None,
    sinks: jax.Array | None,
    mask_value: float,
    is_mqa: bool,
    block_sizes: BlockSizes,
    residual_checkpoint_name: str | None,
    save_residuals: bool,
    mask_function: MaskFunctionType | None,
    logits_soft_cap: float | None = None,
    interpret: bool = False,
) -> SplashCustomReturnType:
    """Validate the inputs and launch the forward attention Pallas kernel.

    Args:
        fwd_mask_info: Mask metadata. ``data_next``, ``block_mask`` and ``mask_next`` are
            passed as the three scalar-prefetch operands; ``partial_mask_blocks`` and
            ``q_sequence`` (mutually exclusive) are passed as regular operands.
        q: Queries of shape ``(num_q_heads, q_seq_len, head_dim_qk)``.
        k: Keys of shape ``(kv_seq_len, head_dim_qk)`` when ``is_mqa`` else
            ``(num_kv_heads, kv_seq_len, head_dim_qk)``.
        v: Values; same leading dimensions as ``k``, last dimension is ``head_dim_v``.
        segment_ids: Optional segment ids, ``q`` of shape ``(q_seq_len,)`` and ``kv`` of shape
            ``(kv_seq_len,)``.
        sinks: Optional array of shape ``(num_q_heads,)``; cast to float32 before the call.
        mask_value: Forwarded to the kernel.
        is_mqa: Whether ``k``/``v`` have no head dimension (single shared kv head).
        block_sizes: Provides ``block_q``, ``block_kv``, ``block_kv_compute`` and the q/k/v layouts.
        residual_checkpoint_name: If set, ``out`` (and ``lse`` when present) are tagged with
            ``ad_checkpoint.checkpoint_name`` under this name.
        save_residuals: Whether to also compute and return the log-sum-exp.
        mask_function: Forwarded to the kernel.
        logits_soft_cap: Forwarded to the kernel.
        interpret: Forwarded to ``pl.pallas_call``.

    Returns:
        ``out`` of shape ``(num_q_heads, q_seq_len, head_dim_v)`` in ``q.dtype``; when
        ``save_residuals`` is true, ``(out, (lse,))`` with ``lse`` of shape
        ``(num_q_heads, q_seq_len)`` in float32.

    Raises:
        ValueError: If shapes, dtypes or block sizes are inconsistent.
    """
    num_q_heads, q_seq_len, head_dim_qk = q.shape
    head_dim_v = v.shape[-1]
    bq, bkv = block_sizes.block_q, block_sizes.block_kv
    bkv_compute = block_sizes.block_kv_compute

    if is_mqa:
        expected_kv_rank = 2
        kv_head_dimension = 1
        kv_seq_len_dimension = 0
        num_kv_heads = 1
    else:
        expected_kv_rank = 3
        kv_head_dimension = 2
        kv_seq_len_dimension = 1
        num_kv_heads = k.shape[0]

    partial_mask_blocks = fwd_mask_info.partial_mask_blocks
    if partial_mask_blocks is not None and jnp.dtype(partial_mask_blocks.dtype) != np.bool_:
        raise ValueError(f"partial_mask_blocks must be of type np.bool_ but got {partial_mask_blocks.dtype}")

    if len(k.shape) != expected_kv_rank:
        # Name the mode that is actually active; the rank check runs for both MQA and MHA.
        attention_mode = "MQA" if is_mqa else "MHA"
        raise ValueError(
            f"Expected {expected_kv_rank}-dim 'key' tensor for {attention_mode}. Instead got a {len(k.shape)}-dim one."
        )

    if k.shape[kv_head_dimension] != head_dim_qk:
        raise ValueError(
            f"Expected 'key' head dimension to be: {head_dim_qk}. Instead got: {k.shape[kv_head_dimension]}."
        )

    if not is_mqa and num_q_heads % num_kv_heads != 0:
        # The condition requires the query head count to be divisible by the key head count.
        raise ValueError(
            f"In MHA, expected number of 'query' heads ({num_q_heads}) to be a"
            f" multiple of the number of 'key' heads ({num_kv_heads})"
        )

    if k.shape[:-1] != v.shape[:-1]:
        raise ValueError(f"Expected 'key' {k.shape} and 'value' {v.shape} to have the same leading dimensions.")

    if bkv % bkv_compute:
        raise ValueError(f"{bkv=} must be a multiple of {bkv_compute=}.")
    if bkv_compute % NUM_LANES:
        raise ValueError(f"{bkv_compute=} must be a multiple of {NUM_LANES}.")

    kv_seq_len = k.shape[kv_seq_len_dimension]
    q_heads_per_kv_head = num_q_heads // num_kv_heads

    if segment_ids is not None:
        if segment_ids.q.shape != (q_seq_len,):
            raise ValueError(f"Invalid shape for q segment_ids: {segment_ids.q.shape}. Expected: {(q_seq_len,)}")
        if segment_ids.kv.shape != (kv_seq_len,):
            raise ValueError(f"Invalid shape for kv segment_ids: {segment_ids.kv.shape}. Expected: {(kv_seq_len,)}")

    # Index maps receive the grid indices (h, i, j) followed by the scalar-prefetch refs.
    q_layout = block_sizes.q_layout

    def q_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref=None):
        del j, data_next_ref, mask_next_ref, block_mask_ref
        return from_head_minor((h, i, 0), q_layout)

    def out_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref=None):
        del j, data_next_ref, mask_next_ref, block_mask_ref
        return h, i, 0

    k_layout = block_sizes.k_layout

    def k_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref=None):
        next_j, *_ = _next_nonzero(h, i, j, data_next_ref, block_mask_ref, mask_next_ref)
        # With multiple kv heads, consecutive groups of q heads share one kv head.
        prefix = () if is_mqa else (_div(h, q_heads_per_kv_head),)
        return from_head_minor((*prefix, next_j, 0), k_layout)

    v_layout = block_sizes.v_layout

    def v_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref=None):
        next_j, *_ = _next_nonzero(h, i, j, data_next_ref, block_mask_ref, mask_next_ref)
        prefix = () if is_mqa else (_div(h, q_heads_per_kv_head),)
        return from_head_minor((*prefix, next_j, 0), v_layout)

    def mask_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref=None):
        _, next_m, *_ = _next_nonzero(h, i, j, data_next_ref, block_mask_ref, mask_next_ref)
        return next_m, 0, 0

    def q_segment_ids_index_map(h, i, j, *_):
        del h, j
        return i, 0

    def kv_segment_ids_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref=None):
        next_j, *_ = _next_nonzero(h, i, j, data_next_ref, block_mask_ref, mask_next_ref)
        return 0, next_j

    in_specs = [
        pl.BlockSpec(from_head_minor((None, bq, head_dim_qk), q_layout), q_index_map),
        pl.BlockSpec(
            from_head_minor((bkv, head_dim_qk) if is_mqa else (None, bkv, head_dim_qk), k_layout),
            k_index_map,
        ),
        pl.BlockSpec(
            from_head_minor((bkv, head_dim_v) if is_mqa else (None, bkv, head_dim_v), v_layout),
            v_index_map,
        ),
    ]
    if segment_ids is not None:
        in_specs += [
            pl.BlockSpec((bq, NUM_LANES), q_segment_ids_index_map),
            pl.BlockSpec((NUM_SUBLANES, bkv), kv_segment_ids_index_map),
        ]
        # Replicate the 1-D ids into 2-D arrays: q ids along lanes, kv ids along sublanes.
        q_segment_ids = jax.lax.broadcast_in_dim(segment_ids.q, (q_seq_len, NUM_LANES), (0,))
        kv_segment_ids = jax.lax.broadcast_in_dim(segment_ids.kv, (NUM_SUBLANES, kv_seq_len), (1,))
    else:
        in_specs += [None, None]
        q_segment_ids = kv_segment_ids = None

    if sinks is not None:
        assert sinks.shape == (num_q_heads,)
        in_specs += [pl.BlockSpec((num_q_heads,), lambda h, i, j, *_: (0,), memory_space=pltpu.SMEM)]
        sinks = sinks.astype(jnp.float32)
    else:
        in_specs += [None]

    if fwd_mask_info.partial_mask_blocks is not None:
        in_specs.append(pl.BlockSpec((None, bq, bkv), mask_index_map))
    else:
        in_specs.append(None)

    assert fwd_mask_info.partial_mask_blocks is None or fwd_mask_info.q_sequence is None

    if fwd_mask_info.q_sequence is not None:
        q_sequence = jax.lax.broadcast_in_dim(fwd_mask_info.q_sequence, (q_seq_len, NUM_LANES), (0,))
        # q_sequence is blocked exactly like the q segment ids, so it reuses that index map.
        in_specs.append(pl.BlockSpec((bq, NUM_LANES), q_segment_ids_index_map))
    else:
        q_sequence = None
        in_specs.append(None)

    # data_next, block_mask and mask_next.
    num_scalar_prefetch = 3

    # The first three outputs always map to block (0, 0) and are discarded after the call;
    # presumably they serve as the kernel's m / l / o accumulators — verify against the
    # kernel signature.
    out_shapes = [
        jax.ShapeDtypeStruct((bq, NUM_LANES), jnp.float32),
        jax.ShapeDtypeStruct((bq, NUM_LANES), jnp.float32),
        jax.ShapeDtypeStruct((bq, head_dim_v), jnp.float32),
        jax.ShapeDtypeStruct((num_q_heads, q_seq_len, head_dim_v), q.dtype),
    ]
    out_specs = [
        pl.BlockSpec((bq, NUM_LANES), lambda h, i, j, *_: (0, 0)),
        pl.BlockSpec((bq, NUM_LANES), lambda h, i, j, *_: (0, 0)),
        pl.BlockSpec((bq, head_dim_v), lambda h, i, j, *_: (0, 0)),
        pl.BlockSpec((None, bq, head_dim_v), out_index_map),
    ]
    if save_residuals:
        out_shapes += [
            jax.ShapeDtypeStruct((num_q_heads, q_seq_len, NUM_LANES), jnp.float32),
        ]

        def logsumexp_index_map(h, i, *_):
            return h, i, 0

        out_specs += [
            pl.BlockSpec((None, bq, NUM_LANES), logsumexp_index_map),
        ]
    else:
        out_shapes += [None]
        out_specs += [None]

    kernel_name = get_kernel_name(
        dataclasses.asdict(block_sizes),
        is_mqa=is_mqa,
        save_residuals=save_residuals,
        is_segmented=segment_ids is not None,
        phase="fwd",
    )

    if fwd_mask_info.data_next is not None:
        grid_width = fwd_mask_info.data_next.shape[-1]
    else:
        grid_width = kv_seq_len // bkv

    grid = (num_q_heads, q_seq_len // bq, grid_width)
    with jax.named_scope(kernel_name):
        all_out = pl.pallas_call(
            partial(
                flash_attention_kernel,
                mask_value=mask_value,
                grid_width=grid_width,
                bq=bq,
                bkv=bkv,
                bkv_compute=bkv_compute,
                head_dim_v=head_dim_v,
                q_layout=q_layout,
                k_layout=k_layout,
                v_layout=v_layout,
                logits_soft_cap=logits_soft_cap,
                mask_function=mask_function,
            ),
            grid_spec=pltpu.PrefetchScalarGridSpec(
                num_scalar_prefetch=num_scalar_prefetch,
                in_specs=in_specs,
                out_specs=out_specs,
                grid=grid,
            ),
            compiler_params=pltpu.CompilerParams(
                dimension_semantics=("parallel", "arbitrary", "arbitrary"),
            ),
            out_shape=out_shapes,
            name=kernel_name,
            interpret=interpret,
        )(
            fwd_mask_info.data_next,
            fwd_mask_info.block_mask,
            fwd_mask_info.mask_next,
            q if q_layout == QKVLayout.HEAD_DIM_MINOR else q.swapaxes(-1, -2),
            k if k_layout == QKVLayout.HEAD_DIM_MINOR else k.swapaxes(-1, -2),
            v if v_layout == QKVLayout.HEAD_DIM_MINOR else v.swapaxes(-1, -2),
            q_segment_ids,
            kv_segment_ids,
            sinks,
            fwd_mask_info.partial_mask_blocks,
            q_sequence,
        )

    (
        _,
        _,
        _,
        out,
        lse,
    ) = all_out

    if save_residuals:
        assert lse is not None
        # The log-sum-exp output carries a trailing NUM_LANES axis; keep index 0 of that axis.
        lse = lse[..., 0]

    if residual_checkpoint_name is not None:
        out = ad_checkpoint.checkpoint_name(out, name=residual_checkpoint_name)
        if lse is not None:
            lse = ad_checkpoint.checkpoint_name(lse, name=residual_checkpoint_name)

    if save_residuals:
        return out, (lse,)
    return out


@partial(
    jax.custom_vjp,
    nondiff_argnames=(
        "save_residuals",
        "mask_value",
        "is_mqa",
        "block_sizes",
        "residual_checkpoint_name",
        "mask_function",
        "logits_soft_cap",
        "interpret",
    ),
)
def _splash_attention_custom(
    fwd_mask_info: mask_info_lib.MaskInfo,
    dq_mask_info: mask_info_lib.MaskInfo | None,
    dkv_mask_info: mask_info_lib.MaskInfo | None,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: SegmentIds | None,
    sinks: jax.Array | None,
    save_residuals: bool,
    mask_value: float,
    is_mqa: bool,
    block_sizes: BlockSizes,
    residual_checkpoint_name: str | None,
    mask_function: MaskFunctionType | None,
    logits_soft_cap: float | None = None,
    interpret: bool = False,
) -> SplashCustomReturnType:
    """Primal of the custom-VJP attention op; delegates to ``_splash_attention_forward``.

    ``dq_mask_info`` and ``dkv_mask_info`` are unused here; they are part of the signature so
    that the VJP forward rule can stash them in the residuals for the backward pass.
    """
    del dq_mask_info, dkv_mask_info

    return _splash_attention_forward(
        fwd_mask_info,
        q,
        k,
        v,
        segment_ids,
        sinks=sinks,
        mask_value=mask_value,
        is_mqa=is_mqa,
        block_sizes=block_sizes,
        residual_checkpoint_name=residual_checkpoint_name,
        save_residuals=save_residuals,
        mask_function=mask_function,
        logits_soft_cap=logits_soft_cap,
        interpret=interpret,
    )


def _splash_attention_fwd(
    fwd_mask_info: mask_info_lib.MaskInfo,
    dq_mask_info: mask_info_lib.MaskInfo | None,
    dkv_mask_info: mask_info_lib.MaskInfo | None,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: SegmentIds | None,
    sinks: jax.Array | None,
    save_residuals: bool,
    mask_value: float,
    is_mqa: bool,
    block_sizes: BlockSizes,
    residual_checkpoint_name: str | None,
    mask_function: MaskFunctionType | None,
    logits_soft_cap: float | None = None,
    interpret: bool = False,
) -> tuple[
    tuple[jax.Array],
    SplashResidualsType,
]:
    """VJP forward rule: run the forward kernel with residuals and package them for backward.

    Returns:
        ``(out, residuals)`` where ``residuals`` is
        ``(q, k, v, segment_ids, sinks, out, lse, dq_mask_info, dkv_mask_info)``.

    Raises:
        NotImplementedError: If ``save_residuals`` is true (differentiating the
            residual-returning variant is not supported).
    """
    if save_residuals:
        raise NotImplementedError("Higher-order AD not supported")

    out, (lse,) = _splash_attention_forward(
        fwd_mask_info,
        q,
        k,
        v,
        segment_ids,
        sinks,
        mask_value=mask_value,
        is_mqa=is_mqa,
        block_sizes=block_sizes,
        residual_checkpoint_name=residual_checkpoint_name,
        save_residuals=True,
        mask_function=mask_function,
        logits_soft_cap=logits_soft_cap,
        interpret=interpret,
    )
    return out, (
        q,
        k,
        v,
        segment_ids,
        sinks,
        out,
        lse,
        dq_mask_info,
        dkv_mask_info,
    )


def _flash_attention_dq_kernel(
    data_next_ref,
    block_mask_ref,
    mask_next_ref,
    q_ref,
    k_ref,
    v_ref,
    q_segment_ids_ref,
    kv_segment_ids_ref,
    sinks_ref,
    logsumexp_ref,
    do_ref,
    di_ref,
    mask_ref,
    q_sequence_ref,
    dq_scratch_ref,
    dq_ref,
    *,
    mask_value: float,
    grid_width: int,
    bq: int,
    bkv: int,
    logits_soft_cap: float | None = None,
    q_layout: QKVLayout,
    k_layout: QKVLayout,
    v_layout: QKVLayout,
    mask_function: MaskFunctionType | None,
):
    """Pallas kernel body computing the query gradient for one q block.

    Grid indices are ``(h, i, j)``. ``dq_scratch_ref`` is zeroed at ``j == 0``, accumulates
    one contribution per step for which ``_next_nonzero`` reports ``should_run``, and is cast
    into ``dq_ref`` at ``j == grid_width - 1``. ``sinks_ref`` is accepted only so the operand
    list lines up with the call site; it is not read.
    """
    del sinks_ref
    float32 = jnp.float32
    HEAD_DIM_MINOR = QKVLayout.HEAD_DIM_MINOR

    h, i, j = pl.program_id(0), pl.program_id(1), pl.program_id(2)

    @pl.when(j == 0)
    def init():
        dq_scratch_ref[...] = jnp.zeros_like(dq_scratch_ref)

    global_kv_index, _, should_run, should_not_mask = _next_nonzero(
        h, i, j, data_next_ref, block_mask_ref, mask_next_ref
    )

    @pl.when(should_run)
    def run():
        q = q_ref[...] if q_layout == HEAD_DIM_MINOR else q_ref[...].T
        k = k_ref[...]
        v = v_ref[...]
        # lse and di arrive as (1, bq) blocks; turn them into (bq, 1) columns.
        lse = jnp.expand_dims(logsumexp_ref[0], -1)
        do = do_ref[...]
        di = jnp.expand_dims(di_ref[0], -1)

        qk_dims = NT_DIM_NUMBERS if k_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS
        qk_uncapped = lax.dot_general(q, k, qk_dims, preferred_element_type=float32)

        qk = _apply_mask_and_soft_cap(
            qk_uncapped,
            mask_value,
            should_not_mask,
            mask_ref,
            q_sequence_ref,
            q_segment_ids_ref,
            kv_segment_ids_ref,
            logits_soft_cap=logits_soft_cap,
            k_slice=pl.ds(0, bkv),
            k_offset=global_kv_index * bkv,
            bq=bq,
            mask_function=mask_function,
        )
        # Recompute the attention probabilities from the saved log-sum-exp.
        p = jnp.exp(qk - lse)
        dp_dims = NT_DIM_NUMBERS if v_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS
        dp = lax.dot_general(
            do.astype(v.dtype),
            v,
            dp_dims,
            preferred_element_type=jnp.float32,
        )
        ds = (dp - di) * p
        if logits_soft_cap is not None:
            # g + g * d == ds * (1 - d) * (1 + d) == ds * (1 - tanh^2(qk / cap));
            # presumably the derivative of the soft cap applied in _apply_mask_and_soft_cap.
            normalized = qk_uncapped / logits_soft_cap
            d = jnp.tanh(normalized)
            g = ds * (1 - d)
            ds = g + g * d

        dq_dims = NN_DIM_NUMBERS if k_layout == HEAD_DIM_MINOR else NT_DIM_NUMBERS
        dq_scratch_ref[...] += lax.dot_general(
            ds.astype(k.dtype),
            k,
            dq_dims,
            preferred_element_type=jnp.float32,
        )

    @pl.when(j == grid_width - 1)
    def end():
        dq_ref[...] = dq_scratch_ref[...].astype(dq_ref.dtype)
        dq_scratch_ref[...] = jnp.zeros_like(dq_scratch_ref)


def _splash_attention_bwd_dq(
    q,
    k,
    v,
    segment_ids,
    sinks,
    lse,
    do,
    di,
    *,
    bq: int,
    bkv: int,
    is_mqa: bool,
    mask_info: mask_info_lib.MaskInfo,
    mask_value: float,
    logits_soft_cap: float | None,
    q_layout: QKVLayout,
    k_layout: QKVLayout,
    v_layout: QKVLayout,
    mask_function: MaskFunctionType | None,
    interpret: bool,
):
    """Launch the standalone dq kernel and return the query gradient.

    Args:
        q, k, v: Same shape conventions as in ``_splash_attention_forward``.
        segment_ids: Optional segment ids (``.q`` and ``.kv``).
        sinks: Optional ``(num_q_heads,)`` array; passed through to the kernel operand list.
        lse: Log-sum-exp; a size-1 axis is inserted at position -2 before the call.
        do: Output cotangent, blocked as ``(None, bq, head_dim_v)``.
        di: Treated the same way as ``lse``.
        bq, bkv: Query / key-value block sizes for this kernel.
        mask_info: Mask metadata for the dq pass.

    Returns:
        ``dq`` with ``q.shape`` and ``q.dtype``.

    Raises:
        ValueError: If block sizes or head counts are inconsistent with the inputs.
    """
    num_q_heads, q_seq_len, head_dim_qk = q.shape
    head_dim_v = v.shape[-1]
    if is_mqa:
        kv_seq_len = k.shape[0]
        num_kv_heads = 1
    else:
        kv_seq_len = k.shape[1]
        num_kv_heads = k.shape[0]

    if bq > q_seq_len:
        raise ValueError(f"{bq=} should not be greater than {q_seq_len=}")
    if bkv > kv_seq_len:
        raise ValueError(f"{bkv=} should not be greater than {kv_seq_len=}")

    if not is_mqa and num_q_heads % num_kv_heads != 0:
        # The condition requires the query head count to be divisible by the key head count.
        raise ValueError(
            f"In MHA, expected number of 'query' heads ({num_q_heads}) to be a"
            f" multiple of the number of 'key' heads ({num_kv_heads})"
        )

    if k.shape[:-1] != v.shape[:-1]:
        raise ValueError(f"Expected 'key' {k.shape} and 'value' {v.shape} to have the same leading dimensions.")

    if bkv % NUM_LANES:
        raise ValueError(f"{bkv=} must be a multiple of {NUM_LANES}.")

    q_heads_per_kv_head = num_q_heads // num_kv_heads

    if mask_info.data_next is not None:
        grid_width = mask_info.data_next.shape[-1]
    else:
        grid_width = kv_seq_len // bkv

    grid = (num_q_heads, q_seq_len // bq, grid_width)

    def o_index_map(h, i, *_):
        return h, i, 0

    o_spec = pl.BlockSpec((None, bq, head_dim_v), o_index_map)

    def q_index_map(h, i, *_):
        return from_head_minor((h, i, 0), q_layout)

    q_spec = pl.BlockSpec(from_head_minor((None, bq, head_dim_qk), q_layout), q_index_map)

    def k_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref, *_):
        next_j, *_ = _next_nonzero(h, i, j, data_next_ref, block_mask_ref, mask_next_ref)
        prefix = () if is_mqa else (_div(h, q_heads_per_kv_head),)
        return from_head_minor((*prefix, next_j, 0), k_layout)

    k_spec = pl.BlockSpec(
        from_head_minor((bkv, head_dim_qk) if is_mqa else (None, bkv, head_dim_qk), k_layout),
        k_index_map,
    )

    def v_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref, *_):
        next_j, *_ = _next_nonzero(h, i, j, data_next_ref, block_mask_ref, mask_next_ref)
        prefix = () if is_mqa else (_div(h, q_heads_per_kv_head),)
        return from_head_minor((*prefix, next_j, 0), v_layout)

    v_spec = pl.BlockSpec(
        from_head_minor((bkv, head_dim_v) if is_mqa else (None, bkv, head_dim_v), v_layout),
        v_index_map,
    )

    def mask_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref, *_):
        _, next_m, *_ = _next_nonzero(h, i, j, data_next_ref, block_mask_ref, mask_next_ref)
        return next_m, 0, 0

    mask_spec = pl.BlockSpec((None, bq, bkv), mask_index_map)

    def q_segment_ids_index_map(h, i, j, *_):
        del h, j
        return i, 0

    if segment_ids is not None:

        def kv_segment_ids_index_map(h, i, j, data_next_ref, block_mask_ref, mask_next_ref, *_):
            next_j, *_ = _next_nonzero(h, i, j, data_next_ref, block_mask_ref, mask_next_ref)
            return 0, next_j

        q_segment_spec = pl.BlockSpec((bq, NUM_LANES), q_segment_ids_index_map)
        kv_segment_spec = pl.BlockSpec((NUM_SUBLANES, bkv), kv_segment_ids_index_map)
        q_segment_ids = jax.lax.broadcast_in_dim(segment_ids.q, (q_seq_len, NUM_LANES), (0,))
        kv_segment_ids = jax.lax.broadcast_in_dim(segment_ids.kv, (NUM_SUBLANES, kv_seq_len), (1,))
    else:
        q_segment_spec = kv_segment_spec = None
        q_segment_ids = kv_segment_ids = None

    if sinks is not None:
        assert sinks.shape == (num_q_heads,)
        sinks_spec = pl.BlockSpec((num_q_heads,), lambda h, *_: (0,), memory_space=pltpu.SMEM)
        sinks = sinks.astype(jnp.float32)
    else:
        sinks_spec = None

    do_spec = o_spec

    def logsumexp_index_map(h, i, *_):
        return h, 0, i

    # Insert a size-1 axis so lse / di can be blocked as (None, 1, bq).
    lse = jnp.expand_dims(lse, axis=-2)
    logsumexp_spec = pl.BlockSpec((None, 1, bq), logsumexp_index_map)
    assert lse.ndim == len(logsumexp_spec.block_shape)

    di = jnp.expand_dims(di, axis=-2)
    di_spec = pl.BlockSpec((None, 1, bq), logsumexp_index_map)
    assert di.ndim == len(di_spec.block_shape)

    in_specs = [
        q_spec,
        k_spec,
        v_spec,
        q_segment_spec,
        kv_segment_spec,
        sinks_spec,
        logsumexp_spec,
        do_spec,
        di_spec,
    ]
    if mask_info.partial_mask_blocks is not None:
        in_specs.append(mask_spec)
    else:
        in_specs.append(None)

    assert mask_info.partial_mask_blocks is None or mask_info.q_sequence is None

    if mask_info.q_sequence is not None:
        q_sequence = jax.lax.broadcast_in_dim(mask_info.q_sequence, (q_seq_len, NUM_LANES), (0,))
        in_specs.append(pl.BlockSpec((bq, NUM_LANES), q_segment_ids_index_map))
    else:
        q_sequence = None
        in_specs.append(None)

    # First output is the float32 accumulator the kernel receives as ``dq_scratch_ref``.
    out_shapes = [
        jax.ShapeDtypeStruct((bq, head_dim_qk), jnp.float32),
        jax.ShapeDtypeStruct(q.shape, q.dtype),
    ]
    out_specs = [
        pl.BlockSpec((bq, head_dim_qk), lambda *_: (0, 0)),
        pl.BlockSpec((None, bq, head_dim_qk), lambda h, i, *_: (h, i, 0)),
    ]

    kernel = functools.partial(
        _flash_attention_dq_kernel,
        grid_width=grid_width,
        mask_value=mask_value,
        bq=bq,
        bkv=bkv,
        logits_soft_cap=logits_soft_cap,
        q_layout=q_layout,
        k_layout=k_layout,
        v_layout=v_layout,
        mask_function=mask_function,
    )
    num_scalar_prefetch = 3

    kernel_name = get_kernel_name(
        dict(
            block_q_dq=bq,
            block_kv_dq=bkv,
            q_layout=q_layout,
            k_layout=k_layout,
            v_layout=v_layout,
        ),
        is_mqa=is_mqa,
        save_residuals=False,
        is_segmented=segment_ids is not None,
        phase="dq",
    )
    with jax.named_scope(kernel_name):
        _, dq = pl.pallas_call(
            kernel,
            grid_spec=pltpu.PrefetchScalarGridSpec(
                num_scalar_prefetch=num_scalar_prefetch,
                in_specs=in_specs,
                out_specs=out_specs,
                grid=grid,
            ),
            out_shape=out_shapes,
            compiler_params=pltpu.CompilerParams(
                dimension_semantics=("arbitrary", "arbitrary", "arbitrary"),
            ),
            name=kernel_name,
            interpret=interpret,
        )(
            mask_info.data_next,
            mask_info.block_mask,
            mask_info.mask_next,
            q if q_layout == QKVLayout.HEAD_DIM_MINOR else q.swapaxes(-1, -2),
            k if k_layout == QKVLayout.HEAD_DIM_MINOR else k.swapaxes(-1, -2),
            v if v_layout == QKVLayout.HEAD_DIM_MINOR else v.swapaxes(-1, -2),
            q_segment_ids,
            kv_segment_ids,
            sinks,
            lse,
            do,
            di,
            mask_info.partial_mask_blocks,
            q_sequence,
        )
    return dq


def _flash_attention_dkv_kernel(
    data_next_ref,
    block_mask_ref,
    mask_next_ref,
    q_ref,
    k_ref,
    v_ref,
    q_segment_ids_ref,
    kv_segment_ids_ref,
    sinks_ref,
    logsumexp_ref,
    do_ref,
    di_ref,
    mask_ref,
    q_sequence_ref,
    dq_scratch_ref,
    dk_scratch_ref,
    dv_scratch_ref,
    dq_ref,
    dk_ref,
    dv_ref,
    *,
    num_q_heads: int,
    num_kv_heads: int,
    mask_value: float,
    grid_width: int,
    bq: int,
    bkv_compute: int,
    is_mqa: bool,
    logits_soft_cap: float | None,
    q_layout: QKVLayout,
    k_layout: QKVLayout,
    v_layout: QKVLayout,
    bkv: int,
    mask_function: MaskFunctionType | None,
):
    """Pallas kernel body computing dk / dv (and optionally per-kv-block dq).

    Grid indices are ``(kv_index, q_head_index, q_index)``. ``dk_scratch_ref`` and
    ``dv_scratch_ref`` accumulate over q blocks and over all q heads that share a kv head;
    they are zeroed on the first such step and written out on the last. When ``dq_ref`` is
    provided, a dq contribution for this (kv block, head, q block) is produced as well,
    through ``dq_scratch_ref`` if that is provided and directly otherwise. ``sinks_ref`` is
    not read.
    """
    del sinks_ref
    HEAD_DIM_MINOR = QKVLayout.HEAD_DIM_MINOR
    kv_index, q_head_index, q_index = (
        pl.program_id(0),
        pl.program_id(1),
        pl.program_id(2),
    )
    should_initialize = q_index == 0

    q_heads_per_kv_heads = None
    q_head_index_per_kv_head = None

    # Zero the accumulators only on the first q head of each kv-head group.
    if is_mqa:
        should_initialize = jnp.logical_and(should_initialize, q_head_index == 0)
    elif num_kv_heads < num_q_heads:
        q_heads_per_kv_heads = num_q_heads // num_kv_heads
        q_head_index_per_kv_head = lax.rem(q_head_index, q_heads_per_kv_heads)
        should_initialize = jnp.logical_and(should_initialize, q_head_index_per_kv_head == 0)

    @pl.when(should_initialize)
    def init():
        dk_scratch_ref[...] = jnp.zeros_like(dk_scratch_ref)
        dv_scratch_ref[...] = jnp.zeros_like(dv_scratch_ref)

    _, _, should_run, should_not_mask = _next_nonzero(
        q_head_index,
        q_index,
        kv_index,
        data_next_ref,
        block_mask_ref,
        mask_next_ref,
        next_i=True,
    )

    def body(i, _):
        # Process one bkv_compute-sized slice of the kv block.
        slice_k = pl.ds(i * bkv_compute, bkv_compute)
        q = q_ref[...]

        def _load_kv(ref, layout):
            if layout == HEAD_DIM_MINOR:
                return ref[slice_k, :]
            return ref[:, slice_k].T

        k = _load_kv(k_ref, k_layout)
        v = _load_kv(v_ref, v_layout)
        # lse / di blocks have NUM_SUBLANES rows; only the first row is read.
        lse = logsumexp_ref[:1, :]
        do = do_ref[...]
        di = di_ref[:1, :]

        # Logits are computed as k @ q^T here, so kv positions index rows.
        qk_dims = NT_DIM_NUMBERS if q_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS
        qk_uncapped = lax.dot_general(k, q, qk_dims, preferred_element_type=jnp.float32)

        qk = _apply_mask_and_soft_cap(
            qk_uncapped,
            mask_value,
            should_not_mask,
            mask_ref,
            q_sequence_ref,
            q_segment_ids_ref,
            kv_segment_ids_ref,
            logits_soft_cap=logits_soft_cap,
            k_slice=slice_k,
            k_offset=kv_index * bkv + i * bkv_compute,
            bq=bq,
            k_in_lanes=False,
            mask_function=mask_function,
        )
        p = jnp.exp(qk - lse)
        dv = lax.dot(p.astype(do.dtype), do, preferred_element_type=jnp.float32)
        dv = dv.astype(dv_scratch_ref.dtype) + dv_scratch_ref[slice_k, :]
        dv_scratch_ref[slice_k, :] = dv

        dp = lax.dot_general(
            v,
            do,
            NT_DIM_NUMBERS,
            preferred_element_type=jnp.float32,
        )
        ds = (dp - di) * p
        if logits_soft_cap is not None:
            # g + g * d == ds * (1 - tanh^2(qk / cap)); see the same step in the dq kernel.
            normalized = qk_uncapped / logits_soft_cap
            d = jnp.tanh(normalized)
            g = ds * (1 - d)
            ds = g + g * d

        dk_dims = NN_DIM_NUMBERS if q_layout == HEAD_DIM_MINOR else NT_DIM_NUMBERS
        dk = lax.dot_general(ds.astype(do.dtype), q, dk_dims, preferred_element_type=jnp.float32)
        dk = dk.astype(dk_scratch_ref.dtype) + dk_scratch_ref[slice_k, :]
        dk_scratch_ref[slice_k, :] = dk

        if dq_scratch_ref is not None or dq_ref is not None:
            dq = lax.dot_general(
                ds.T.astype(k.dtype),
                k,
                NN_DIM_NUMBERS,
                preferred_element_type=jnp.float32,
            )
            if dq_scratch_ref is not None:
                # Several compute slices per kv block: accumulate across loop iterations.
                dq_scratch_ref[...] += dq
            else:
                # No scratch buffer: write this slice's result straight to the output.
                assert dq_ref is not None
                dq_ref[...] = dq.astype(dq_ref.dtype)

    # Reset the dq destination on every grid step so skipped blocks contribute zeros.
    if dq_scratch_ref is not None:
        dq_scratch_ref[...] = jnp.zeros_like(dq_scratch_ref)
    elif dq_scratch_ref is None and dq_ref is not None:
        dq_ref[...] = jnp.zeros_like(dq_ref)

    @pl.when(should_run)
    def run():
        num_iters = k_ref.shape[0 if k_layout is HEAD_DIM_MINOR else 1] // bkv_compute
        lax.fori_loop(0, num_iters, body, None, unroll=True)

        if dq_scratch_ref is not None:
            assert dq_ref is not None
            dq_ref[...] = dq_scratch_ref[...].astype(dq_ref.dtype)

    # Write dk / dv only after the last q block of the last q head in the kv-head group.
    should_write = q_index == grid_width - 1
    if is_mqa:
        should_write = jnp.logical_and(should_write, q_head_index == num_q_heads - 1)
    elif num_kv_heads < num_q_heads:
        should_write = jnp.logical_and(should_write, q_head_index_per_kv_head == q_heads_per_kv_heads - 1)

    @pl.when(should_write)
    def end():
        dk_ref[...] = dk_scratch_ref[...].astype(dk_ref.dtype)
        dv_ref[...] = dv_scratch_ref[...].astype(dv_ref.dtype)
        if dq_scratch_ref is not None:
            dq_scratch_ref[...] = jnp.zeros_like(dq_scratch_ref)

        dk_scratch_ref[...] = jnp.zeros_like(dk_scratch_ref)
        dv_scratch_ref[...] = jnp.zeros_like(dv_scratch_ref)


def _splash_attention_bwd_dkv(
    q,
    k,
    v,
    segment_ids,
    sinks,
    lse,
    do,
    di,
    *,
    bq: int,
    bkv: int,
    bkv_compute: int,
    is_mqa: bool,
    mask_info: mask_info_lib.MaskInfo,
    mask_value: float,
    logits_soft_cap: float | None,
    use_fused_bwd_kernel: bool,
    q_layout: QKVLayout,
    k_layout: QKVLayout,
    v_layout: QKVLayout,
    mask_function: MaskFunctionType | None,
    interpret: bool,
):
    """Launch the dk/dv kernel, optionally fused with the dq computation.

    Args:
        q, k, v: Same shape conventions as in ``_splash_attention_forward``.
        segment_ids: Optional segment ids (``.q`` and ``.kv``).
        sinks: Optional ``(num_q_heads,)`` array; passed through to the kernel operand list.
        lse, di: Arrays of shape ``(num_q_heads, q_seq_len)``.
        do: Output cotangent, blocked as ``(None, bq, head_dim_v)``.
        bq, bkv, bkv_compute: Query block, kv memory block and kv compute block sizes.
        mask_info: Mask metadata for the dkv pass.
        use_fused_bwd_kernel: When true the kernel also emits dq, one partial result per kv
            block, which is summed over the kv-block axis here.

    Returns:
        ``(dq, dk, dv)``; ``dq`` is ``None`` unless ``use_fused_bwd_kernel`` is true.

    Raises:
        ValueError: If block sizes or head counts are inconsistent with the inputs.
    """
    num_q_heads, q_seq_len, head_dim_qk = q.shape
    head_dim_v = v.shape[-1]
    if is_mqa:
        num_kv_heads, kv_seq_len = 1, k.shape[0]
    else:
        num_kv_heads, kv_seq_len, _ = k.shape

    if bq > q_seq_len:
        raise ValueError(f"{bq=} should not be greater than {q_seq_len=}")
    if bkv > kv_seq_len:
        raise ValueError(f"{bkv=} should not be greater than {kv_seq_len=}")
    if bkv_compute > bkv:
        raise ValueError(f"{bkv_compute=} should not be greater than {bkv=}")
    if bkv % bkv_compute:
        raise ValueError(f"{bkv=} should be a multiple of {bkv_compute=}")

    if not is_mqa and num_q_heads % num_kv_heads != 0:
        # The condition requires the query head count to be divisible by the key head count.
        raise ValueError(
            f"In MHA, expected number of 'query' heads ({num_q_heads}) to be a"
            f" multiple of the number of 'key' heads ({num_kv_heads})"
        )

    if k.shape[:-1] != v.shape[:-1]:
        raise ValueError(f"Expected 'key' {k.shape} and 'value' {v.shape} to have the same leading dimensions.")

    q_heads_per_kv_head = num_q_heads // num_kv_heads

    if mask_info.data_next is not None:
        grid_width = mask_info.data_next.shape[-2]
    else:
        grid_width = q_seq_len // bq

    # Unlike the forward / dq kernels, the kv block index is the outermost grid axis here.
    grid = (
        kv_seq_len // bkv,
        num_q_heads,
        grid_width,
    )

    def o_index_map(
        kv_index,
        head_index,
        q_index,
        data_next_ref,
        block_mask_ref,
        mask_next_ref=None,
    ):
        next_i, *_ = _next_nonzero(
            head_index,
            q_index,
            kv_index,
            data_next_ref,
            block_mask_ref,
            mask_next_ref,
            next_i=True,
        )
        return head_index, next_i, 0

    o_spec = pl.BlockSpec((None, bq, head_dim_v), o_index_map)

    def q_index_map(
        kv_index,
        head_index,
        q_index,
        data_next_ref,
        block_mask_ref,
        mask_next_ref=None,
    ):
        next_i, *_ = _next_nonzero(
            head_index,
            q_index,
            kv_index,
            data_next_ref,
            block_mask_ref,
            mask_next_ref,
            next_i=True,
        )
        return from_head_minor((head_index, next_i, 0), q_layout)

    q_spec = pl.BlockSpec(from_head_minor((None, bq, head_dim_qk), q_layout), q_index_map)

    def k_index_map(kv_index, head_index, *_):
        prefix = () if is_mqa else (_div(head_index, q_heads_per_kv_head),)
        return from_head_minor((*prefix, kv_index, 0), k_layout)

    k_spec = pl.BlockSpec(
        from_head_minor(
            (bkv, head_dim_qk) if is_mqa else (None, bkv, head_dim_qk),
            k_layout,
        ),
        k_index_map,
    )

    def v_index_map(kv_index, head_index, *_):
        prefix = () if is_mqa else (_div(head_index, q_heads_per_kv_head),)
        return from_head_minor((*prefix, kv_index, 0), v_layout)

    v_spec = pl.BlockSpec(
        from_head_minor(
            (bkv, head_dim_v) if is_mqa else (None, bkv, head_dim_v),
            v_layout,
        ),
        v_index_map,
    )

    if use_fused_bwd_kernel:

        def dq_index_map(kv_index, head_index, q_index, *_):
            return (kv_index, head_index, q_index, 0)

        dq_spec = pl.BlockSpec((None, None, bq, head_dim_qk), dq_index_map)
        # One partial dq per kv block; reduced over axis 0 after the call.
        dq_shape = jax.ShapeDtypeStruct((kv_seq_len // bkv, *q.shape), q.dtype)
        if bkv == bkv_compute:
            # Single compute slice per kv block: the kernel writes dq directly.
            dq_scratch_spec = dq_scratch_shape = None
        else:
            dq_scratch_spec = pl.BlockSpec((bq, head_dim_qk), lambda *_: (0, 0))
            dq_scratch_shape = jax.ShapeDtypeStruct((bq, head_dim_qk), jnp.float32)
    else:
        dq_spec = dq_shape = dq_scratch_spec = dq_scratch_shape = None

    def dkv_index_map(kv_index, head_index, *_):
        prefix = () if is_mqa else (_div(head_index, q_heads_per_kv_head),)
        return (*prefix, kv_index, 0)

    dk_spec = pl.BlockSpec(
        (bkv, head_dim_qk) if is_mqa else (None, bkv, head_dim_qk),
        dkv_index_map,
    )
    dv_spec = pl.BlockSpec(
        (bkv, head_dim_v) if is_mqa else (None, bkv, head_dim_v),
        dkv_index_map,
    )

    def mask_index_map(
        kv_index,
        head_index,
        q_index,
        data_next_ref,
        block_mask_ref,
        mask_next_ref,
    ):
        _, next_m, *_ = _next_nonzero(
            head_index,
            q_index,
            kv_index,
            data_next_ref,
            block_mask_ref,
            mask_next_ref,
            next_i=True,
        )
        return next_m, 0, 0

    # Mask blocks are (bkv, bq) here, i.e. transposed relative to the forward / dq kernels.
    mask_spec = pl.BlockSpec((None, bkv, bq), mask_index_map)

    def q_segment_ids_index_map(
        kv_index,
        head_index,
        q_index,
        data_next_ref,
        block_mask_ref,
        mask_next_ref=None,
    ):
        next_i, *_ = _next_nonzero(
            head_index,
            q_index,
            kv_index,
            data_next_ref,
            block_mask_ref,
            mask_next_ref,
            next_i=True,
        )
        return 0, next_i

    if segment_ids is not None:

        def kv_segment_ids_index_map(kv_index, *_):
            return kv_index, 0

        q_segment_spec = pl.BlockSpec((NUM_SUBLANES, bq), q_segment_ids_index_map)
        kv_segment_spec = pl.BlockSpec((bkv, NUM_LANES), kv_segment_ids_index_map)
        # Roles are swapped versus the forward pass: q ids along sublanes, kv ids along lanes.
        q_segment_ids = jax.lax.broadcast_in_dim(segment_ids.q, (NUM_SUBLANES, q_seq_len), (1,))
        kv_segment_ids = jax.lax.broadcast_in_dim(segment_ids.kv, (kv_seq_len, NUM_LANES), (0,))
    else:
        q_segment_spec = kv_segment_spec = None
        q_segment_ids = kv_segment_ids = None

    if sinks is not None:
        assert sinks.shape == (num_q_heads,)
        sinks_spec = pl.BlockSpec((num_q_heads,), lambda kv_index, h, *_: (0,), memory_space=pltpu.SMEM)
        sinks = sinks.astype(jnp.float32)
    else:
        sinks_spec = None

    do_spec = o_spec

    def logsumexp_index_map(
        kv_index,
        head_index,
        q_index,
        data_next_ref,
        block_mask_ref,
        mask_next_ref=None,
    ):
        next_i, *_ = _next_nonzero(
            head_index,
            q_index,
            kv_index,
            data_next_ref,
            block_mask_ref,
            mask_next_ref,
            next_i=True,
        )
        return head_index, 0, next_i

    assert lse.shape == di.shape == (num_q_heads, q_seq_len)

    # Replicate lse / di across NUM_SUBLANES rows so they can be blocked as (NUM_SUBLANES, bq).
    logsumexp_shape = (num_q_heads, NUM_SUBLANES, q_seq_len)
    lse = jnp.broadcast_to(jnp.expand_dims(lse, -2), logsumexp_shape)
    logsumexp_spec = pl.BlockSpec((None, NUM_SUBLANES, bq), logsumexp_index_map)
    assert lse.ndim == len(logsumexp_spec.block_shape)

    di = jnp.broadcast_to(jnp.expand_dims(di, -2), logsumexp_shape)
    di_spec = pl.BlockSpec((None, NUM_SUBLANES, bq), logsumexp_index_map)
    assert di.ndim == len(di_spec.block_shape)

    in_specs = [
        q_spec,
        k_spec,
        v_spec,
        q_segment_spec,
        kv_segment_spec,
        sinks_spec,
        logsumexp_spec,
        do_spec,
        di_spec,
    ]
    if mask_info.partial_mask_blocks is not None:
        in_specs.append(mask_spec)
    else:
        in_specs.append(None)

    if mask_info.q_sequence is not None:
        in_specs.append(pl.BlockSpec((NUM_SUBLANES, bq), q_segment_ids_index_map))
        q_sequence = jax.lax.broadcast_in_dim(mask_info.q_sequence, (NUM_SUBLANES, q_seq_len), (1,))
    else:
        q_sequence = None
        in_specs.append(None)

    # Order matches the kernel's trailing refs: dq/dk/dv scratch, then dq/dk/dv outputs.
    out_shapes = [
        dq_scratch_shape,
        jax.ShapeDtypeStruct((bkv, head_dim_qk), jnp.float32),
        jax.ShapeDtypeStruct((bkv, head_dim_v), jnp.float32),
        dq_shape,
        jax.ShapeDtypeStruct(k.shape, k.dtype),
        jax.ShapeDtypeStruct(v.shape, v.dtype),
    ]
    out_specs = [
        dq_scratch_spec,
        pl.BlockSpec((bkv, head_dim_qk), lambda *_: (0, 0)),
        pl.BlockSpec((bkv, head_dim_v), lambda *_: (0, 0)),
        dq_spec,
        dk_spec,
        dv_spec,
    ]

    kernel = functools.partial(
        _flash_attention_dkv_kernel,
        mask_value=mask_value,
        num_q_heads=num_q_heads,
        num_kv_heads=num_kv_heads,
        is_mqa=is_mqa,
        grid_width=grid_width,
        bq=bq,
        bkv_compute=bkv_compute,
        logits_soft_cap=logits_soft_cap,
        q_layout=q_layout,
        k_layout=k_layout,
        v_layout=v_layout,
        bkv=bkv,
        mask_function=mask_function,
    )
    num_scalar_prefetch = 3

    kernel_name = get_kernel_name(
        dict(
            block_q_dkv=bq,
            block_kv_dkv=bkv,
            block_kv_dkv_compute=bkv_compute,
            q_layout=q_layout,
            k_layout=k_layout,
            v_layout=v_layout,
        ),
        is_mqa=is_mqa,
        save_residuals=False,
        is_segmented=segment_ids is not None,
        phase="dkv",
    )
    with jax.named_scope(kernel_name):
        _, _, _, dq_unreduced, dk, dv = pl.pallas_call(
            kernel,
            grid_spec=pltpu.PrefetchScalarGridSpec(
                num_scalar_prefetch=num_scalar_prefetch,
                in_specs=in_specs,
                out_specs=out_specs,
                grid=grid,
            ),
            out_shape=out_shapes,
            compiler_params=pltpu.CompilerParams(dimension_semantics=("arbitrary", "arbitrary", "arbitrary")),
            name=kernel_name,
            interpret=interpret,
        )(
            mask_info.data_next,
            mask_info.block_mask,
            mask_info.mask_next,
            q if q_layout == QKVLayout.HEAD_DIM_MINOR else q.swapaxes(-1, -2),
            k if k_layout == QKVLayout.HEAD_DIM_MINOR else k.swapaxes(-1, -2),
            v if v_layout == QKVLayout.HEAD_DIM_MINOR else v.swapaxes(-1, -2),
            q_segment_ids,
            kv_segment_ids,
            sinks,
            lse,
            do,
            di,
            mask_info.partial_mask_blocks,
            q_sequence,
        )

    if use_fused_bwd_kernel:
        assert dq_unreduced is not None
        dq = dq_unreduced.sum(axis=0)
    else:
        assert dq_unreduced is None
        dq = None
    return dq, dk, dv


def _splash_attention_bwd(
    save_residuals: bool,
    mask_value: float,
    is_mqa: bool,
    block_sizes: BlockSizes,
    residual_checkpoint_name: str | None,
    mask_function: MaskFunctionType | None,
    logits_soft_cap: float | None,
    interpret: bool,
    res: SplashResidualsType,
    do: jax.Array,
) -> tuple[
    mask_info_lib.MaskInfo | None,
    mask_info_lib.MaskInfo | None,
    mask_info_lib.MaskInfo | None,
    jax.Array,
    jax.Array,
    jax.Array,
    SegmentIds | None,
    jax.Array | None,
]:
    """VJP backward rule for ``_splash_attention_custom``.

    The leading parameters are the non-differentiable arguments, in the order they are
    listed in ``nondiff_argnames``; ``res`` is the residual tuple built by
    ``_splash_attention_fwd`` and ``do`` is the cotangent of the attention output.

    Returns:
        One cotangent per differentiable primal argument, in order: ``None`` for the three
        mask infos, ``dq``, ``dk``, ``dv``, ``None`` for ``segment_ids``, and ``dsinks``
        (``None`` when no sinks were given).

    Raises:
        ValueError: If ``block_sizes`` does not define the backward block sizes.
    """
    del save_residuals, residual_checkpoint_name
    if not block_sizes.has_backward_blocks:
        raise ValueError("Need to specify backward blocks.")

    bq_dq, bkv_dq = block_sizes.block_q_dq, block_sizes.block_kv_dq
    bq_dkv, bkv_dkv_memory, bkv_dkv_compute = (
        block_sizes.block_q_dkv,
        block_sizes.block_kv_dkv,
        block_sizes.block_kv_dkv_compute,
    )
    use_fused_bwd_kernel = block_sizes.use_fused_bwd_kernel
    (
        q,
        k,
        v,
        segment_ids,
        sinks,
        o,
        lse,
        dq_mask_info,
        dkv_mask_info,
    ) = res

    # di[h, s] = sum_d o[h, s, d] * do[h, s, d], computed in float32.
    di = jnp.einsum("hsd,hsd->hs", o.astype(jnp.float32), do.astype(jnp.float32))

    dq, dk, dv = _splash_attention_bwd_dkv(
        q,
        k,
        v,
        segment_ids,
        sinks,
        lse,
        do,
        di,
        bq=bq_dkv,
        bkv=bkv_dkv_memory,
        bkv_compute=bkv_dkv_compute,
        is_mqa=is_mqa,
        mask_info=dkv_mask_info,
        mask_value=mask_value,
        logits_soft_cap=logits_soft_cap,
        use_fused_bwd_kernel=use_fused_bwd_kernel,
        q_layout=block_sizes.q_layout,
        k_layout=block_sizes.k_layout,
        v_layout=block_sizes.v_layout,
        mask_function=mask_function,
        interpret=interpret,
    )
    if not use_fused_bwd_kernel:
        # The dkv kernel did not produce dq; run the dedicated dq kernel.
        assert dq is None
        dq = _splash_attention_bwd_dq(
            q,
            k,
            v,
            segment_ids,
            sinks,
            lse,
            do,
            di,
            bq=bq_dq,
            bkv=bkv_dq,
            is_mqa=is_mqa,
            mask_info=dq_mask_info,
            mask_value=mask_value,
            logits_soft_cap=logits_soft_cap,
            q_layout=block_sizes.q_layout,
            k_layout=block_sizes.k_layout,
            v_layout=block_sizes.v_layout,
            mask_function=mask_function,
            interpret=interpret,
        )
    assert dq is not None

    dsinks = None
    if sinks is not None:
        # dsinks[h] = -sum_{s, d} exp(sinks[h] - lse[h, s]) * o[h, s, d] * do[h, s, d].
        # NOTE(review): the product is formed in o.dtype, so dsinks has o.dtype rather than
        # sinks.dtype — confirm this is intended.
        sinks_exp = -jnp.exp(sinks[..., None, None].astype(jnp.float32) - lse[..., None].astype(jnp.float32))
        dsinks = jnp.sum(sinks_exp.astype(o.dtype) * o * do, axis=(-1, -2))

    return (
        None,
        None,
        None,
        dq,
        dk,
        dv,
        None,
        dsinks,
    )


_splash_attention_custom.defvjp(_splash_attention_fwd, _splash_attention_bwd)


@partial(
    jax.jit,
    static_argnames=[
        "is_mqa",
        "block_sizes",
        "save_residuals",
        "mask_value",
        "logits_soft_cap",
        "residual_checkpoint_name",
        "mask_function",
        "interpret",
    ],
)
def _splash_attention(
    fwd_mask_info: mask_info_lib.MaskInfo,
    dq_mask_info: mask_info_lib.MaskInfo | None,
    dkv_mask_info: mask_info_lib.MaskInfo | None,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: SegmentIds | None = None,
    sinks: jax.Array | None = None,
    *,
    is_mqa: bool,
    block_sizes: BlockSizes | None,
    save_residuals: bool,
    mask_value: float,
    logits_soft_cap: float | None,
    residual_checkpoint_name: str | None,
    mask_function: MaskFunctionType | None,
    interpret: bool,
) -> SplashCustomReturnType:
    """Jitted entry point: normalise the mask infos and call the custom-VJP attention op.

    For dynamic masks, ``partial_mask_blocks`` has shape
    ``(head_count, q_blocks, kv_blocks, block_q, block_kv)``. This shape allows sharding
    across both the head count and query sequence dimensions.

    Note: The leading dimensions ``(head_count, q_blocks, kv_blocks)`` must be collapsed into
    a single dimension before being passed to the kernel; that is done here for all three
    mask infos.
    """

    def _collapse_partial_mask_blocks(mask_info: mask_info_lib.MaskInfo | None):
        # Flatten every leading axis, keeping only the trailing two block axes.
        if mask_info is None or mask_info.partial_mask_blocks is None:
            return mask_info

        return mask_info._replace(
            partial_mask_blocks=mask_info.partial_mask_blocks.reshape(-1, *mask_info.partial_mask_blocks.shape[-2:])
        )

    fwd_mask_info = _collapse_partial_mask_blocks(fwd_mask_info)
    dq_mask_info = _collapse_partial_mask_blocks(dq_mask_info)
    dkv_mask_info = _collapse_partial_mask_blocks(dkv_mask_info)

    return _splash_attention_custom(
        fwd_mask_info,
        dq_mask_info,
        dkv_mask_info,
        q,
        k,
        v,
        segment_ids,
        sinks,
        mask_value=mask_value,
        is_mqa=is_mqa,
        block_sizes=block_sizes,
        save_residuals=save_residuals,
        logits_soft_cap=logits_soft_cap,
        residual_checkpoint_name=residual_checkpoint_name,
        mask_function=mask_function,
        interpret=interpret,
    )
@jax.tree_util.register_pytree_node_class
class SplashAttentionKernel:
    """Callable bundle of splash-attention mask metadata and static options.

    The instance is registered as a JAX pytree node: the three mask-info
    objects are the children and the keyword options are the auxiliary data
    (see `tree_flatten`). Calling the instance forwards to
    `_splash_attention` with the stored mask infos and options.
    """

    def __init__(
        self,
        fwd_mask_info: mask_info_lib.MaskInfo,
        dq_mask_info: mask_info_lib.MaskInfo | None,
        dkv_mask_info: mask_info_lib.MaskInfo | None,
        **kwargs,
    ):
        """Store the mask infos and the keyword options passed on every call.

        Args:
            fwd_mask_info: Mask metadata for the forward kernel.
            dq_mask_info: Mask metadata for the dQ backward kernel, or None.
            dkv_mask_info: Mask metadata for the dK/dV backward kernel, or None.
            **kwargs: Options forwarded verbatim to `_splash_attention`.
        """
        self.kwargs = kwargs
        self.fwd_mask_info = fwd_mask_info
        self.dq_mask_info = dq_mask_info
        self.dkv_mask_info = dkv_mask_info

    def __call__(self, *args, **kwargs) -> SplashCustomReturnType:
        """Run `_splash_attention` with the stored mask infos and options.

        Call-time keyword arguments are merged with the stored ones; a name
        present in both raises `TypeError` (duplicate keyword argument).
        """
        return _splash_attention(
            self.fwd_mask_info,
            self.dq_mask_info,
            self.dkv_mask_info,
            *args,
            **kwargs,
            **self.kwargs,
        )

    def manual_sharding_spec(self, sharding: jax.sharding.NamedSharding):
        """Returns a value that can be used as a shard_map partition spec for the kernel.

        The result has the same pytree structure as this kernel, with every
        non-None mask-info field replaced by a partition spec.

        Args:
            sharding: Named sharding whose `spec` must have exactly two entries.

        Returns:
            A kernel of the same class as `self` holding partition specs in
            place of arrays, and the same keyword options.

        Raises:
            ValueError: If the sharding does not evenly divide `data_next`, or
                if it would shard the last (kv) dimension of `data_next`.
        """
        fwd_info = self.fwd_mask_info
        if fwd_info.data_next is not None:
            block_mask_shape = fwd_info.data_next.shape
            try:
                shard_shape = sharding.shard_shape(block_mask_shape)
            except ValueError as exc:
                raise ValueError("The sharding must divide the mask blocks evenly between devices") from exc
            # An unchanged last dimension means the kv axis is left unsharded.
            if block_mask_shape[-1] != shard_shape[-1]:
                raise ValueError("Sharding the kv sequence dimension is not supported")

        spec = sharding.spec
        assert len(spec) == 2
        replicated = jax.sharding.PartitionSpec()
        # Only dynamic masks carry per-shard partial blocks; otherwise they are replicated.
        partial_mask_blocks_spec = spec if fwd_info.is_dynamic_mask else replicated
        q_sequence_spec = jax.sharding.PartitionSpec(spec[1])
        mask_info_specs = mask_info_lib.MaskInfo(
            data_next=spec if fwd_info.data_next is not None else None,
            mask_next=spec if fwd_info.mask_next is not None else None,
            block_mask=spec if fwd_info.block_mask is not None else None,
            partial_mask_blocks=partial_mask_blocks_spec if fwd_info.partial_mask_blocks is not None else None,
            q_sequence=q_sequence_spec if fwd_info.q_sequence is not None else None,
        )
        # Build the spec tree with the runtime class so that a subclass yields
        # a spec whose pytree node type matches the kernel it describes.
        return type(self)(
            mask_info_specs,
            mask_info_specs if self.dq_mask_info is not None else None,
            mask_info_specs if self.dkv_mask_info is not None else None,
            **self.kwargs,
        )

    def tree_flatten(self):
        """Return `(children, aux_data)`: the three mask infos and the options dict."""
        return (
            (self.fwd_mask_info, self.dq_mask_info, self.dkv_mask_info),
            self.kwargs,
        )

    @classmethod
    def tree_unflatten(cls, kwargs, values):
        """Rebuild a kernel from `tree_flatten` output.

        Args:
            kwargs: The auxiliary data (keyword options).
            values: The three children `(fwd_mask_info, dq_mask_info, dkv_mask_info)`.

        Returns:
            An instance of `cls`, so subclasses round-trip to their own type.
        """
        fwd_mask_info, dq_mask_info, dkv_mask_info = values
        dq_mask_info = mask_info_lib.MaskInfo(*dq_mask_info) if dq_mask_info is not None else None
        dkv_mask_info = mask_info_lib.MaskInfo(*dkv_mask_info) if dkv_mask_info is not None else None
        return cls(
            mask_info_lib.MaskInfo(*fwd_mask_info),
            dq_mask_info,
            dkv_mask_info,
            **kwargs,
        )
def _make_splash_attention(
    mask: np.ndarray | jax.Array | mask_lib.MultiHeadMask,
    *,
    block_sizes: BlockSizes | None = None,
    is_mqa: bool,
    save_residuals: bool = False,
    mask_value: float = DEFAULT_MASK_VALUE,
    logits_soft_cap: float | None = None,
    downcast_smem_data: bool = True,
    head_shards: int,
    q_seq_shards: int,
    residual_checkpoint_name: str | None = None,
    interpret: bool = False,
):
    """Build a `SplashAttentionKernel` from a rank-3 mask.

    The mask is preprocessed into mask-info structures for the forward pass
    and, when `block_sizes.has_backward_blocks` is set, for the backward
    passes. A dQ mask info is only built when the fused backward kernel is
    disabled.

    Args:
        mask: A numpy array (wrapped per head into `NumpyMask`), a `jax.Array`
            (handled by the dynamic-mask processors) or a `MultiHeadMask`.
        block_sizes: Block configuration; `BlockSizes.get_default()` if None.
        is_mqa: Stored on the kernel and forwarded at call time.
        save_residuals: Stored on the kernel and forwarded at call time.
        mask_value: Stored on the kernel and forwarded at call time.
        logits_soft_cap: Stored on the kernel and forwarded at call time.
        downcast_smem_data: Forwarded to the mask processors.
        head_shards: Forwarded to the mask processors.
        q_seq_shards: Forwarded to the mask processors.
        residual_checkpoint_name: Stored on the kernel and forwarded at call time.
        interpret: Stored on the kernel and forwarded at call time.

    Returns:
        A `SplashAttentionKernel` holding the processed mask infos and options.

    Raises:
        ValueError: If `mask.shape` does not have exactly three dimensions.
    """
    if len(mask.shape) != 3:
        raise ValueError(f"Unexpected mask shape: {mask.shape}")

    if isinstance(mask, np.ndarray):
        mask = mask_lib.MultiHeadMask([mask_lib.NumpyMask(head_mask) for head_mask in mask])

    if block_sizes is None:
        block_sizes = BlockSizes.get_default()

    # Dynamic (traced) masks go through their own pair of preprocessing routines.
    if isinstance(mask, jax.Array):
        build_info = mask_info_lib.process_dynamic_mask
        build_dkv_info = mask_info_lib.process_dynamic_mask_dkv
    else:
        build_info = mask_info_lib.process_mask
        build_dkv_info = mask_info_lib.process_mask_dkv

    # Keyword arguments shared by every preprocessing call below.
    shared_kwargs = {
        "downcast_smem_data": downcast_smem_data,
        "head_shards": head_shards,
        "q_seq_shards": q_seq_shards,
    }

    forward_blocks = (block_sizes.block_q, block_sizes.block_kv)
    fwd_mask_info, mask_function_fwd = build_info(mask, forward_blocks, **shared_kwargs)
    fwd_mask_info = tree_util.tree_map(jnp.array, fwd_mask_info)

    dq_mask_info = None
    dkv_mask_info = None
    if block_sizes.has_backward_blocks:
        fused = block_sizes.use_fused_bwd_kernel
        if not fused:
            # The standalone dQ kernel needs mask metadata at its own block sizes.
            dq_blocks = (block_sizes.block_q_dq, block_sizes.block_kv_dq)
            dq_mask_info, mask_function_dq = build_info(mask, dq_blocks, **shared_kwargs)
            assert (mask_function_fwd is None) == (mask_function_dq is None)
            dq_mask_info = tree_util.tree_map(jnp.array, dq_mask_info)

        dkv_blocks = (block_sizes.block_q_dkv, block_sizes.block_kv_dkv)
        dkv_mask_info, mask_function_dkv = build_dkv_info(
            mask,
            dkv_blocks,
            **shared_kwargs,
            shrink_grid=not block_sizes.use_fused_bwd_kernel,
        )
        assert (mask_function_fwd is None) == (mask_function_dkv is None)
        dkv_mask_info = tree_util.tree_map(jnp.array, dkv_mask_info)

    return SplashAttentionKernel(
        fwd_mask_info,
        dq_mask_info,
        dkv_mask_info,
        block_sizes=block_sizes,
        is_mqa=is_mqa,
        save_residuals=save_residuals,
        mask_value=mask_value,
        logits_soft_cap=logits_soft_cap,
        residual_checkpoint_name=residual_checkpoint_name,
        mask_function=mask_function_fwd,
        interpret=interpret,
    )


# Factories with the attention flavour fixed.
make_splash_mha = partial(_make_splash_attention, is_mqa=False)
make_splash_mqa = partial(_make_splash_attention, is_mqa=True)

# Single-device variants: one head shard and one query-sequence shard.
make_splash_mha_single_device = partial(make_splash_mha, head_shards=1, q_seq_shards=1)
make_splash_mqa_single_device = partial(make_splash_mqa, head_shards=1, q_seq_shards=1)
@kernel_registry.register("blocksparse_attention", Platform.PALLAS, Backend.TPU)
@ejit(
    static_argnames=(
        "softmax_scale",
        "mask_builder",
        "sliding_window",
        "chunk_size",
        "causal",
        "fused_backward",
        "fwd_params",
        "bwd_params",
        "logits_soft_cap",
    )
)
@jaxtyping.jaxtyped(typechecker=beartype)
def blocksparse_attention(
    query: Float[Array, "batch num_heads seq_len head_dim"],
    key: Float[Array, "batch kv_num_heads kv_len head_dim"],
    value: Float[Array, "batch kv_num_heads kv_len vhead_dim"],
    q_segment_ids: Int[Array, "batch seq_len"] | None = None,
    kv_segment_ids: Int[Array, "batch kv_len"] | None = None,
    q_positions: Int[Array, "batch seq_len"] | None = None,
    kv_positions: Int[Array, "batch kv_len"] | None = None,
    softmax_aux: Float[Array, "num_sinks"] | None = None,
    bias: Float[Array, "batch num_heads seq_len head_dim"] | None = None,
    attention_mask: Bool[Array, "batch num_heads_or_1 seq_len kv_len"]
    | Int[Array, "batch num_heads_or_1 seq_len kv_len"]
    | None = None,
    sequence_parallelism_mesh_axis_name: str | None = None,
    logits_soft_cap: float | None = None,
    qkv_layouts: tuple["SparseMask"] | None = None,
    softmax_scale: float | None = None,
    fwd_params: FwdParams | None = None,
    bwd_params: BwdParams | None = None,
    mask_builder: Callable[[int, int, int, int, int], "Mask"] | Callable[[], "SparseMask"] | None = None,
    sliding_window: int | tuple[int, int] | None = None,
    chunk_size: int | None = None,
    causal: bool = True,
    fused_backward: bool = False,
) -> Float[Array, "batch num_heads seq_len vhead_dim"]:
    """Pallas TPU block-sparse attention kernel implementation.

    Computes attention over sparse block patterns using Pallas kernels
    optimized for TPU execution. A per-head mask is built once, a splash
    attention kernel is created from it, and the kernel is vmapped over the
    batch axis. The query is multiplied by `softmax_scale` before the kernel
    is called.

    Args:
        query: Query tensor [batch num_heads seq_len head_dim]
        key: Key tensor [batch kv_num_heads kv_len head_dim]
        value: Value tensor [batch kv_num_heads kv_len vhead_dim]
        q_segment_ids: Optional query segment ids [batch, seq_len]
        kv_segment_ids: Optional KV segment ids [batch, kv_len]
        q_positions: Optional query position indices [batch, seq_len]. Not
            implemented for TPU: rejected when given without `q_segment_ids`
            and otherwise not used by this function.
        kv_positions: Optional KV position indices [batch, kv_len]. Not
            implemented for TPU: rejected when given without `kv_segment_ids`
            and otherwise not used by this function.
        softmax_aux: Optional auxiliary softmax values for attention sinks;
            shared across the batch (not vmapped).
        bias: Optional attention bias (not implemented for TPU).
        attention_mask: Optional mask [batch num_heads_or_1 seq_len kv_len].
            Only used to infer segment ids (via `mask_to_segment_ids`) for
            whichever of `q_segment_ids` / `kv_segment_ids` is missing.
        sequence_parallelism_mesh_axis_name: Optional mesh axis name for
            sequence parallelism (not implemented for TPU).
        logits_soft_cap: Optional soft capping value for attention logits.
            When specified, applies tanh-based soft capping:
            logits_soft_cap * tanh(logits / logits_soft_cap). This prevents
            attention scores from becoming too large, improving numerical
            stability (Gemma-2 style). Gradients are computed with proper
            Jacobian.
        qkv_layouts: Optional pre-computed attention mask layouts (not
            implemented for TPU).
        softmax_scale: Attention score scaling factor
            (default: 1/sqrt(head_dim)).
        fwd_params: Forward block sizes; defaults to 512 for both the query
            and KV block size. Each is capped at the sequence length.
        bwd_params: Backward block sizes; defaults to 1024 for both the query
            and KV block size. `q_blocksize` sizes the query blocks and
            `kv_blocksize` the KV blocks of both backward kernels; if
            `q_blocksize` is None, `kv_blocksize` is used for the query
            blocks as well. Each is capped at the sequence length.
        mask_builder: Custom mask builder function, called once per head as
            `mask_builder(q_len, kv_len, num_heads, head_idx, num_reps)`.
        sliding_window: Sliding window size. Can be:
            - int: symmetric window (same size left and right)
            - tuple[int, int]: (left_window, right_window) for asymmetric
            - None: no sliding window
        chunk_size: Size of chunks for chunked causal attention (like Llama4)
            - int: enable chunked causal mask with specified chunk size
              (takes precedence over `sliding_window` and `causal`)
            - None: no chunking
        causal: Whether to use causal masking (default True)
        fused_backward: Whether to use fused backward kernel

    Returns:
        Attention output [batch num_heads seq_len vhead_dim]

    Raises:
        NotImplementedError: For `bias`, `sequence_parallelism_mesh_axis_name`,
            `qkv_layouts`, or positions supplied without segment ids.
        ValueError: If only one of the two segment-id arrays is available
            after inference from `attention_mask`.
    """
    if q_positions is not None and q_segment_ids is None:
        raise NotImplementedError("`q_positions` is not implemented for tpu-pallas (gpu-triton and xla only).")
    if kv_positions is not None and kv_segment_ids is None:
        raise NotImplementedError("`kv_positions` is not implemented for tpu-pallas (gpu-triton and xla only).")
    if bias is not None:
        raise NotImplementedError("`bias` is not implemented for tpu-pallas (gpu-triton and xla only).")
    if sequence_parallelism_mesh_axis_name is not None:
        raise NotImplementedError(
            "`sequence_parallelism_mesh_axis_name` is not implemented for tpu-pallas (gpu-triton and xla only)."
        )
    if qkv_layouts is not None:
        raise NotImplementedError("`qkv_layouts` is not implemented for tpu-pallas (gpu-triton and xla only).")

    query_length = query.shape[2]
    kv_length = value.shape[2]

    if fwd_params is None:
        fwd_params = FwdParams(q_blocksize=512, kv_blocksize=512, num_stages=2, num_warps=4)
    if bwd_params is None:
        bwd_params = BwdParams(q_blocksize=1024, kv_blocksize=1024, num_stages=2, num_warps=4)

    # Fill in whichever segment-id array is missing from the dense mask.
    if attention_mask is not None and (q_segment_ids is None or kv_segment_ids is None):
        from ejkernel.types.mask import mask_to_segment_ids

        inferred_q_seg, inferred_kv_seg = mask_to_segment_ids(attention_mask)
        if q_segment_ids is None:
            q_segment_ids = inferred_q_seg
        if kv_segment_ids is None:
            kv_segment_ids = inferred_kv_seg

    if q_segment_ids is not None and kv_segment_ids is None:
        raise ValueError("If `q_segment_ids` is provided, `kv_segment_ids` must also be provided.")
    if kv_segment_ids is not None and q_segment_ids is None:
        raise ValueError("If `kv_segment_ids` is provided, `q_segment_ids` must also be provided.")

    if mask_builder is None:

        def mask_builder(q_len: int, kv_len: int, num_heads: int, head_idx: int, num_reps: int) -> Mask:
            # Precedence: chunked causal > sliding window (optionally causal) > causal > full.
            if chunk_size is not None:
                return ChunkedCausalMask((q_len, kv_len), chunk_size=chunk_size)
            elif sliding_window is not None:
                if isinstance(sliding_window, int):
                    left_window = right_window = sliding_window
                else:
                    left_window, right_window = sliding_window
                local_mask = LocalMask(shape=(q_len, kv_len), window_size=(left_window, right_window), offset=0)
                if causal:
                    causal_mask = CausalMask((q_len, kv_len))
                    return causal_mask & local_mask
                else:
                    return local_mask
            elif causal:
                return CausalMask((q_len, kv_len))
            else:
                return FullMask((q_len, kv_len))

    # Query-side backward blocks follow the backward *query* block size. Fall
    # back to the KV block size when no query block size is configured so that
    # parameter objects specifying only `kv_blocksize` keep working.
    bwd_q_blocksize = bwd_params.q_blocksize if bwd_params.q_blocksize is not None else bwd_params.kv_blocksize

    block_sizes = BlockSizes(
        block_q=min(fwd_params.q_blocksize, query_length),
        block_kv_compute=min(fwd_params.kv_blocksize, kv_length),
        block_kv=min(fwd_params.kv_blocksize, kv_length),
        block_q_dkv=min(bwd_q_blocksize, query_length),
        block_kv_dkv=min(bwd_params.kv_blocksize, kv_length),
        block_kv_dkv_compute=min(bwd_params.kv_blocksize, kv_length),
        block_q_dq=min(bwd_q_blocksize, query_length),
        block_kv_dq=min(bwd_params.kv_blocksize, kv_length),
        use_fused_bwd_kernel=fused_backward,
    )
    if softmax_scale is None:
        softmax_scale = query.shape[-1] ** -0.5
    assert query_length != 1

    # One mask per query head.
    # NOTE(review): the fifth argument is named `num_reps` by the default
    # builder but receives `query.shape[-2]`, which is the sequence axis
    # according to the signature annotation — confirm the intended value.
    mask = MultiHeadMask(
        [mask_builder(query_length, kv_length, query.shape[-3], ox, query.shape[-2]) for ox in range(query.shape[-3])]
    )

    def attn_static_fn(
        q,
        k,
        v,
        q_segment_ids,
        kv_segment_ids,
        softmax_aux,
    ):
        # Runs on a single batch element (the batch axis is removed by vmap below).
        segment_ids = None
        if kv_segment_ids is not None and q_segment_ids is not None:
            segment_ids = SegmentIds(q_segment_ids, kv_segment_ids)
        return make_splash_mha(
            mask=mask,
            block_sizes=block_sizes,
            logits_soft_cap=logits_soft_cap,
            head_shards=1,
            q_seq_shards=1,
        )(
            q=q,
            k=k,
            v=v,
            segment_ids=segment_ids,
            sinks=softmax_aux,
        )

    # Map over the batch axis; absent segment ids and the sinks are not mapped.
    attn_fn = jax.vmap(
        attn_static_fn,
        in_axes=(
            0,
            0,
            0,
            0 if q_segment_ids is not None else None,
            0 if kv_segment_ids is not None else None,
            None,
        ),
    )

    return attn_fn(
        query * softmax_scale,
        key,
        value,
        q_segment_ids,
        kv_segment_ids,
        softmax_aux,
    )