ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel

Implementation of Sparse Flash Attention, a.k.a. “Splash” attention.
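To illustrate the idea behind block-sparse attention, the NumPy sketch below splits a dense boolean mask into (block_q x block_kv) tiles and classifies each one. Tiles with no allowed entries can be skipped. Fully allowed tiles need no element-wise masking. Only the remaining tiles pay for masking. The helper `classify_blocks` and the `EMPTY`/`PARTIAL`/`FULL` codes are hypothetical. This module keeps its own per-block bookkeeping in `MaskInfo`, which may be encoded differently.

```python
import numpy as np

# Hypothetical tile codes for this sketch only.
EMPTY, PARTIAL, FULL = 0, 1, 2


def classify_blocks(mask: np.ndarray, block_q: int, block_kv: int) -> np.ndarray:
    """Classify each (block_q x block_kv) tile of a dense boolean mask.

    EMPTY tiles can be skipped, FULL tiles need no per-element masking,
    PARTIAL tiles must apply the mask element-wise.
    """
    q_len, kv_len = mask.shape
    if q_len % block_q or kv_len % block_kv:
        raise ValueError("block sizes must divide the sequence lengths")
    # View the mask as [q_blocks, block_q, kv_blocks, block_kv] tiles.
    tiles = mask.reshape(q_len // block_q, block_q, kv_len // block_kv, block_kv)
    any_set = tiles.any(axis=(1, 3))
    all_set = tiles.all(axis=(1, 3))
    return np.where(all_set, FULL, np.where(any_set, PARTIAL, EMPTY))


causal = np.tril(np.ones((8, 8), dtype=bool))
print(classify_blocks(causal, block_q=4, block_kv=4))
# [[1 0]
#  [2 1]]
```

For a causal 8x8 mask with 4x4 tiles, one of the four tiles is skipped and one needs no masking at all; the fraction of skipped tiles grows with sequence length.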

class ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.BlockSizes(block_q: int, block_kv: int, block_kv_compute: int | None = None, block_q_dkv: int | None = None, block_kv_dkv: int | None = None, block_kv_dkv_compute: int | None = None, block_q_dq: int | None = None, block_kv_dq: int | None = None, use_fused_bwd_kernel: bool = False, q_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR, k_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR, v_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR)

Bases: object

Tile sizes parameterizing SplashAttention kernels.

These parameters have a negligible effect on numerics but a large effect on performance.
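Since the tile sizes determine the shape of the kernel grid, one way to reason about them is to compute that grid. `forward_grid` below is a hypothetical helper, not part of this module. It assumes, without confirmation from this page, that tile sizes must divide the sequence lengths, that block_kv must be a multiple of block_kv_compute, and that an unset block_kv_compute falls back to block_kv.

```python
def forward_grid(q_len, kv_len, num_heads, block_q, block_kv, block_kv_compute=None):
    """Hypothetical helper: the grid implied by a set of forward tile sizes."""
    # Assumption: an unset compute block falls back to the full KV block.
    if block_kv_compute is None:
        block_kv_compute = block_kv
    # Assumption: tiles must evenly divide the sequence lengths.
    if q_len % block_q or kv_len % block_kv:
        raise ValueError("tile sizes must divide the sequence lengths")
    # Assumption: each KV tile is processed in whole compute sub-blocks.
    if block_kv % block_kv_compute:
        raise ValueError("block_kv must be a multiple of block_kv_compute")
    return {
        "grid": (num_heads, q_len // block_q, kv_len // block_kv),
        "compute_steps_per_kv_block": block_kv // block_kv_compute,
    }


print(forward_grid(4096, 4096, num_heads=8, block_q=512, block_kv=1024, block_kv_compute=512))
# {'grid': (8, 8, 4), 'compute_steps_per_kv_block': 2}
```

Doubling block_q halves the number of query steps but doubles the size of each tile held in on-chip memory, which is the kind of trade-off these parameters control.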

Note that changing the layouts only affects the physical layout that the kernel enforces. The logical interface to blocksparse_attention always takes the head dimension as the minormost (last) axis.
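A minimal NumPy sketch of what the two layouts could mean physically, assuming SEQ_MINOR simply stores the last two axes swapped ([..., head_dim, seq] instead of [..., seq, head_dim]). The integer values mirror the QKVLayout members documented on this page. `to_physical` and `swap_minor_indices` are hypothetical helpers; the latter is modeled on what from_head_minor(vals, layout) appears to do judging only by its signature (reordering a tuple of per-axis values).

```python
import numpy as np

HEAD_DIM_MINOR, SEQ_MINOR = 1, 2  # same values as the QKVLayout members


def to_physical(x: np.ndarray, layout: int) -> np.ndarray:
    """Hypothetical: physical array for a logical [..., seq, head_dim] input."""
    if layout == HEAD_DIM_MINOR:
        return x  # head_dim is already the last (minormost) axis
    return np.swapaxes(x, -1, -2)  # assumed SEQ_MINOR storage: [..., head_dim, seq]


def swap_minor_indices(vals: tuple, layout: int) -> tuple:
    """Hypothetical: reorder (..., seq_value, head_dim_value) for the given layout."""
    if layout == HEAD_DIM_MINOR:
        return vals
    return (*vals[:-2], vals[-1], vals[-2])


k = np.arange(2 * 3 * 4).reshape(2, 3, 4)        # logical [heads=2, seq=3, head_dim=4]
print(to_physical(k, SEQ_MINOR).shape)           # (2, 4, 3)
print(swap_minor_indices((0, 5, 1), SEQ_MINOR))  # (0, 1, 5)
```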

block_kv: int
block_kv_compute: int | None
block_kv_dkv: int | None
block_kv_dkv_compute: int | None
block_kv_dq: int | None
block_q: int
block_q_dkv: int | None
block_q_dq: int | None
classmethod get_default()
property has_backward_blocks: bool
k_layout: QKVLayout
q_layout: QKVLayout
use_fused_bwd_kernel: bool
v_layout: QKVLayout
class ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.QKVLayout(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: IntEnum

HEAD_DIM_MINOR = 1
SEQ_MINOR = 2
class ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.SegmentIds(q: jax.Array, kv: jax.Array)

Bases: NamedTuple

SegmentIds for Q and KV sequences.

SegmentIds are a mechanism to ensure that there is no cross-attention between segments (fractions of a sequence) that have been concatenated together into a single sequence. Each array holds one integer id per token, and only tokens with the same id are allowed to attend to each other.

The static mask (e.g. causal) is “and-ed” with the segment id mask to form the actual attention mask. It is important that this combined mask has no all-zero rows (along the kv dimension); otherwise the softmax would be invalid because its denominator would be 0. This condition holds for causal self-attention: the segment ids form a block-diagonal matrix whose diagonal is always set, so every row has at least one nonzero element. It is easy to break this condition in cross-attention (non-self-attention) configurations.

Attributes

q

Segment ids along the Q sequence.

Type

jax.Array

kv

Segment ids along the KV sequence.

Type

jax.Array

kv: Array

Alias for field number 1

q: Array

Alias for field number 0
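The masking rule described for SegmentIds can be checked with a small NumPy sketch: build the segment-id mask, AND it with a static mask, and test for all-zero rows. The helper names are hypothetical and only illustrate the rule; they are not part of this module.

```python
import numpy as np


def segment_mask(q_ids, kv_ids):
    """True where a query token and a key token share a segment id."""
    return np.asarray(q_ids)[:, None] == np.asarray(kv_ids)[None, :]


def combined_mask(static_mask, q_ids, kv_ids):
    # The static mask (e.g. causal) is AND-ed with the segment-id mask.
    return np.logical_and(static_mask, segment_mask(q_ids, kv_ids))


def has_empty_rows(mask):
    # A row with no True entry would make the softmax denominator zero.
    return bool((~mask.any(axis=-1)).any())


# Self-attention: two packed segments of lengths 3 and 2.
ids = np.array([0, 0, 0, 1, 1])
causal = np.tril(np.ones((5, 5), dtype=bool))
print(has_empty_rows(combined_mask(causal, ids, ids)))  # False: the diagonal is always kept

# Cross-attention: query segment 2 has no matching key segment.
q_ids = np.array([0, 2])
kv_ids = np.array([0, 0, 1])
print(has_empty_rows(combined_mask(np.ones((2, 3), dtype=bool), q_ids, kv_ids)))  # True
```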

class ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.SplashAttentionKernel(fwd_mask_info: MaskInfo, dq_mask_info: ejkernel.kernels._pallas.tpu.blocksparse_attention._info.MaskInfo | None, dkv_mask_info: ejkernel.kernels._pallas.tpu.blocksparse_attention._info.MaskInfo | None, **kwargs)

Bases: object

manual_sharding_spec(sharding: NamedSharding)

Returns a value that can be used as a shard_map partition spec for the kernel.

tree_flatten()
classmethod tree_unflatten(kwargs, values)
ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.attention_reference(mask: Array, q: Array, k: Array, v: Array, segment_ids: ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.SegmentIds | None, sinks: jax.jaxlib._jax.Array | None = None, *, mask_value: float = -2.381976426469702e+38, save_residuals: bool = False, custom_type: str = 'flash', logits_soft_cap: float | None = None) → Union[Array, tuple[jax.jaxlib._jax.Array, tuple[jax.jaxlib._jax.Array]]]
ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.attention_reference_custom(mask: Array, q: Array, k: Array, v: Array, segment_ids: ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.SegmentIds | None, sinks: jax.jaxlib._jax.Array | None = None, *, mask_value: float = -2.381976426469702e+38, save_residuals: bool = False, custom_type: str = 'flash', logits_soft_cap: float | None = None)
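For intuition about what a dense reference such as attention_reference computes, here is a single-head NumPy sketch. It applies the optional tanh soft cap, fills disallowed positions (static mask AND segment mask) with a mask value of -0.7 times the float32 maximum (about -2.38e38, matching the default in the signatures above), and takes a numerically stable softmax. It is not this module's implementation: sinks and custom_type are omitted, q is assumed to be pre-scaled, and returning (out, (logsumexp,)) when save_residuals=True is an assumption suggested by the documented Union return type. `soft_cap_grad` gives the derivative 1 - tanh²(x / cap) that the backward pass must propagate through the cap.

```python
import numpy as np

# Mask fill value: -0.7 * float32 max (about -2.38e38).
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)


def soft_cap(x, cap):
    # cap * tanh(x / cap): close to x for |x| << cap, saturates at +/-cap.
    return cap * np.tanh(x / cap)


def soft_cap_grad(x, cap):
    # d/dx [cap * tanh(x / cap)] = 1 - tanh(x / cap) ** 2
    return 1.0 - np.tanh(x / cap) ** 2


def dense_attention(mask, q, k, v, q_ids=None, kv_ids=None, *,
                    mask_value=DEFAULT_MASK_VALUE, logits_soft_cap=None,
                    save_residuals=False):
    """Hypothetical single-head reference.

    mask [q_len, kv_len] bool, q [q_len, d], k [kv_len, d], v [kv_len, dv].
    q is assumed to be pre-scaled (e.g. by 1/sqrt(d)).
    """
    logits = q @ k.T
    if logits_soft_cap is not None:
        logits = soft_cap(logits, logits_soft_cap)
    if q_ids is not None:
        # AND the static mask with the segment-id mask.
        mask = mask & (q_ids[:, None] == kv_ids[None, :])
    logits = np.where(mask, logits, mask_value)
    row_max = logits.max(axis=-1, keepdims=True)  # stabilizes exp()
    p = np.exp(logits - row_max)
    denom = p.sum(axis=-1, keepdims=True)
    out = (p / denom) @ v
    if save_residuals:
        # Per-row log-sum-exp of the masked logits.
        return out, ((row_max + np.log(denom))[:, 0],)
    return out


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((4, 8)) for _ in range(3))
causal = np.tril(np.ones((4, 4), dtype=bool))
out, (lse,) = dense_attention(causal, q / np.sqrt(8), k, v,
                              logits_soft_cap=30.0, save_residuals=True)
print(out.shape, lse.shape)  # (4, 8) (4,)
```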
ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.blocksparse_attention(query: Float[Array, 'batch num_heads seq_len head_dim'], key: Float[Array, 'batch kv_num_heads kv_len head_dim'], value: Float[Array, 'batch kv_num_heads kv_len vhead_dim'], q_segment_ids: Int[Array, 'batch seq_len'] | None = None, kv_segment_ids: Int[Array, 'batch kv_len'] | None = None, q_positions: Int[Array, 'batch seq_len'] | None = None, kv_positions: Int[Array, 'batch kv_len'] | None = None, softmax_aux: Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len head_dim'] | None = None, attention_mask: Bool[Array, 'batch num_heads_or_1 seq_len kv_len'] | Int[Array, 'batch num_heads_or_1 seq_len kv_len'] | None = None, sequence_parallelism_mesh_axis_name: str | None = None, logits_soft_cap: float | None = None, qkv_layouts: tuple['SparseMask'] | None = None, softmax_scale: float | None = None, fwd_params: FwdParams | None = None, bwd_params: BwdParams | None = None, mask_builder: Callable[[int, int, int, int, int], 'Mask'] | Callable[[], 'SparseMask'] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = True, fused_backward: bool = False) → Float[Array, 'batch num_heads seq_len vhead_dim']

Pallas TPU block-sparse attention kernel implementation.

Computes attention over sparse block patterns using Pallas kernels optimized for TPU execution.

Parameters
  • query – Query tensor [batch, num_heads, seq_len, head_dim]

  • key – Key tensor [batch, kv_num_heads, kv_len, head_dim]

  • value – Value tensor [batch, kv_num_heads, kv_len, vhead_dim]

  • q_segment_ids – Optional query segment ids [batch, seq_len]

  • kv_segment_ids – Optional KV segment ids [batch, kv_len]

  • q_positions – Optional query position indices [batch, seq_len] (not implemented for TPU)

  • kv_positions – Optional KV position indices [batch, kv_len] (not implemented for TPU)

  • softmax_aux – Optional auxiliary softmax values for attention sinks [num_sinks]

  • bias – Optional attention bias [batch, num_heads, seq_len, head_dim]

  • attention_mask – Optional boolean or integer attention mask [batch, num_heads or 1, seq_len, kv_len]

  • sequence_parallelism_mesh_axis_name – Optional mesh axis name for sequence parallelism

  • logits_soft_cap – Optional soft-capping value for the attention logits. When set, the logits are replaced by logits_soft_cap * tanh(logits / logits_soft_cap), which bounds them to (-logits_soft_cap, logits_soft_cap) and keeps attention scores from growing too large, improving numerical stability (Gemma-2 style). Gradients are propagated through the cap, whose derivative is 1 - tanh²(logits / logits_soft_cap).

  • qkv_layouts – Optional pre-computed attention mask layouts

  • softmax_scale – Attention score scaling factor (default: 1/sqrt(head_dim))

  • fwd_params – Optional forward-pass kernel parameters (FwdParams)

  • bwd_params – Optional backward-pass kernel parameters (BwdParams)

  • mask_builder – Optional custom mask builder: either a callable taking five ints and returning a Mask, or a zero-argument callable returning a SparseMask

  • sliding_window – Sliding window size. Can be:
      - int: symmetric window (same size on the left and right)
      - tuple[int, int]: (left_window, right_window) for an asymmetric window
      - None: no sliding window

  • chunk_size – Chunk size for chunked causal attention (as in Llama 4). Can be:
      - int: enable a chunked causal mask with the given chunk size
      - None: no chunking

  • causal – Whether to apply causal masking (default: True)

  • fused_backward – Whether to use the fused backward kernel (default: False)

Returns

Attention output [batch, num_heads, seq_len, vhead_dim]
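The causal, sliding_window and chunk_size options can be pictured as boolean predicates over (query, key) positions. The NumPy sketch below assumes an inclusive window bound (a window of w lets a query see w keys on that side) and that chunked attention is causal within each chunk. The helper names are hypothetical, and this is not how blocksparse_attention builds its masks internally.

```python
import numpy as np


def _positions(q_len, kv_len):
    # Broadcastable query (column) and key (row) position grids.
    return np.arange(q_len)[:, None], np.arange(kv_len)[None, :]


def causal_mask(q_len, kv_len):
    q, kv = _positions(q_len, kv_len)
    return kv <= q


def sliding_window_mask(q_len, kv_len, window):
    # int -> symmetric (left, right); tuple -> (left, right), as in the parameter list.
    left, right = (window, window) if isinstance(window, int) else window
    q, kv = _positions(q_len, kv_len)
    return (kv >= q - left) & (kv <= q + right)


def chunked_causal_mask(q_len, kv_len, chunk_size):
    # Causal, and a query only sees keys in its own chunk.
    q, kv = _positions(q_len, kv_len)
    return (kv <= q) & (q // chunk_size == kv // chunk_size)


local = causal_mask(6, 6) & sliding_window_mask(6, 6, (2, 0))
print(local.sum(axis=-1))                         # [1 2 3 3 3 3]
print(chunked_causal_mask(6, 6, 3).sum(axis=-1))  # [1 2 3 1 2 3]
```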

ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.flash_attention_kernel(data_next_ref, block_mask_ref, mask_next_ref, q_ref, k_ref, v_ref, q_segment_ids_ref, kv_segment_ids_ref, sinks_ref, mask_ref, q_sequence_ref, m_scratch_ref, l_scratch_ref, o_scratch_ref, o_ref, logsumexp_ref=None, *, mask_value: float, grid_width: int, bq: int, bkv: int, bkv_compute: int, head_dim_v: int, q_layout: QKVLayout, k_layout: QKVLayout, v_layout: QKVLayout, logits_soft_cap: float | None, mask_function: collections.abc.Callable[[...], jax.jaxlib._jax.Array] | None)
ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.from_head_minor(vals: tuple[Any, ...], layout: QKVLayout)
ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.get_kernel_name(block_metadata: Mapping[str, Any], is_mqa: bool, save_residuals: bool, is_segmented: bool, phase: str) → str

Returns a unique name for each SplashAttention kernel variant.
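One way to produce a unique, deterministic name per kernel variant is to fold every argument into the string and sort the block metadata so that dictionary order cannot change the result. This is a hypothetical sketch: the name format and the accepted phase values ("fwd", "dq", "dkv", guessed from the BlockSizes field names) are assumptions, not the actual output of get_kernel_name.

```python
from typing import Any, Mapping


def kernel_name_sketch(block_metadata: Mapping[str, Any], is_mqa: bool,
                       save_residuals: bool, is_segmented: bool, phase: str) -> str:
    """Hypothetical: derive a deterministic, variant-specific kernel name."""
    if phase not in ("fwd", "dq", "dkv"):  # assumed phase names
        raise ValueError(f"unknown phase {phase!r}")
    attention_type = "mqa" if is_mqa else "mha"
    segments = "_segmented" if is_segmented else ""
    residuals = "_residuals" if save_residuals else ""
    # Sorting makes the name independent of dict insertion order.
    blocks = "_".join(f"{k}={v}" for k, v in sorted(block_metadata.items()))
    return f"splash_{attention_type}_{phase}{segments}{residuals}_{blocks}"


print(kernel_name_sketch({"block_q": 512, "block_kv": 1024}, False, True, False, "fwd"))
# splash_mha_fwd_residuals_block_kv=1024_block_q=512
```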

ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.make_attention_reference(mask: ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.Mask | numpy.ndarray, is_mqa: bool, backward_impl: str = 'vanilla', **params: Any) → Callable
ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.make_masked_mha_reference(mask: mask_lib.Mask | np.ndarray, *, is_mqa: bool = False, backward_impl: str = 'vanilla', **params: Any) → Callable
ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.make_masked_mqa_reference(mask: mask_lib.Mask | np.ndarray, *, is_mqa: bool = True, backward_impl: str = 'vanilla', **params: Any) → Callable
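make_masked_mha_reference and make_masked_mqa_reference share one signature and differ only in the default of is_mqa (the same holds for make_splash_mha and make_splash_mqa). Assuming the usual multi-query convention, in which a single K/V head is shared by all query heads, the NumPy sketch below shows the shape difference and checks that MQA equals MHA with K and V broadcast across heads. The function names are hypothetical.

```python
import numpy as np


def softmax_attention(q, k, v, mask):
    """Single head: q [q_len, d], k [kv_len, d], v [kv_len, dv], mask [q_len, kv_len]."""
    logits = np.where(mask, q @ k.T, -np.inf)
    p = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def multi_head_attention(q, k, v, mask, is_mqa):
    """q [heads, q_len, d]. MHA: k, v [heads, kv_len, d]. MQA: k, v [kv_len, d], shared."""
    outs = []
    for h in range(q.shape[0]):
        k_h, v_h = (k, v) if is_mqa else (k[h], v[h])
        outs.append(softmax_attention(q[h], k_h, v_h, mask))
    return np.stack(outs)


rng = np.random.default_rng(0)
q = rng.standard_normal((4, 6, 8))  # 4 query heads
k = rng.standard_normal((6, 8))     # one shared K head (MQA)
v = rng.standard_normal((6, 8))     # one shared V head (MQA)
mask = np.tril(np.ones((6, 6), dtype=bool))

mqa = multi_head_attention(q, k, v, mask, is_mqa=True)
mha = multi_head_attention(q, np.broadcast_to(k, (4, 6, 8)),
                           np.broadcast_to(v, (4, 6, 8)), mask, is_mqa=False)
print(mqa.shape, np.allclose(mqa, mha))  # (4, 6, 8) True
```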
ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.make_splash_mha(mask: np.ndarray | jax.Array | mask_lib.MultiHeadMask, *, block_sizes: BlockSizes | None = None, is_mqa: bool = False, save_residuals: bool = False, mask_value: float = -2.381976426469702e+38, logits_soft_cap: float | None = None, downcast_smem_data: bool = True, head_shards: int, q_seq_shards: int, residual_checkpoint_name: str | None = None, interpret: bool = False)
ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.make_splash_mha_single_device(mask: np.ndarray | jax.Array | mask_lib.MultiHeadMask, *, block_sizes: BlockSizes | None = None, is_mqa: bool = False, save_residuals: bool = False, mask_value: float = -2.381976426469702e+38, logits_soft_cap: float | None = None, downcast_smem_data: bool = True, head_shards: int = 1, q_seq_shards: int = 1, residual_checkpoint_name: str | None = None, interpret: bool = False)
ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.make_splash_mqa(mask: np.ndarray | jax.Array | mask_lib.MultiHeadMask, *, block_sizes: BlockSizes | None = None, is_mqa: bool = True, save_residuals: bool = False, mask_value: float = -2.381976426469702e+38, logits_soft_cap: float | None = None, downcast_smem_data: bool = True, head_shards: int, q_seq_shards: int, residual_checkpoint_name: str | None = None, interpret: bool = False)
ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.make_splash_mqa_single_device(mask: np.ndarray | jax.Array | mask_lib.MultiHeadMask, *, block_sizes: BlockSizes | None = None, is_mqa: bool = True, save_residuals: bool = False, mask_value: float = -2.381976426469702e+38, logits_soft_cap: float | None = None, downcast_smem_data: bool = True, head_shards: int = 1, q_seq_shards: int = 1, residual_checkpoint_name: str | None = None, interpret: bool = False)