ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel
Implementation of Sparse Flash Attention, a.k.a. “Splash” attention.
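Block sparsity is what distinguishes Splash from plain flash attention. The attention mask is cut into block_q x block_kv tiles. Tiles with no allowed entries can be skipped entirely, and only partially masked tiles need the mask applied element-wise. The NumPy sketch below classifies the tiles of a dense mask under those assumptions. The classify_blocks helper and its EMPTY/PARTIAL/FULL encoding are hypothetical illustrations, not this module's internal MaskInfo format.

```python
import numpy as np

# Hypothetical tile encoding used only for this illustration.
EMPTY, PARTIAL, FULL = 0, 1, 2

def classify_blocks(mask: np.ndarray, block_q: int, block_kv: int) -> np.ndarray:
    """Classify every (block_q x block_kv) tile of a dense boolean mask."""
    q_len, kv_len = mask.shape
    assert q_len % block_q == 0 and kv_len % block_kv == 0
    # View the mask as [num_q_blocks, block_q, num_kv_blocks, block_kv].
    tiles = mask.reshape(q_len // block_q, block_q, kv_len // block_kv, block_kv)
    any_set = tiles.any(axis=(1, 3))  # tile has at least one allowed entry
    all_set = tiles.all(axis=(1, 3))  # tile is allowed everywhere
    return np.where(all_set, FULL, np.where(any_set, PARTIAL, EMPTY))

causal = np.tril(np.ones((512, 512), dtype=bool))
blocks = classify_blocks(causal, block_q=128, block_kv=128)
print(blocks)
# Only non-EMPTY tiles have to be visited by a block-sparse kernel.
print("visited tiles:", int((blocks != EMPTY).sum()), "of", blocks.size)
```

For a 512-token causal mask with 128-sized blocks, 6 of the 16 tiles are empty. That is the work a block-sparse kernel can avoid.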
- class ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.BlockSizes(block_q: int, block_kv: int, block_kv_compute: int | None = None, block_q_dkv: int | None = None, block_kv_dkv: int | None = None, block_kv_dkv_compute: int | None = None, block_q_dq: int | None = None, block_kv_dq: int | None = None, use_fused_bwd_kernel: bool = False, q_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR, k_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR, v_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR)
Bases: object
Tile sizes parameterizing SplashAttention kernels.
These parameters have a negligible effect on numerics but a large effect on performance.
Note that changing the layouts only changes the physical layout that the kernel enforces. The logical interface to blocksparse_attention always takes the head dimension as the minor-most (last) dimension.
- block_kv: int
- block_kv_compute: int | None
- block_kv_dkv: int | None
- block_kv_dkv_compute: int | None
- block_kv_dq: int | None
- block_q: int
- block_q_dkv: int | None
- block_q_dq: int | None
- property has_backward_blocks: bool
- use_fused_bwd_kernel: bool
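The attribute list does not say how the optional fields interact. The following is a hedged sketch that assumes the semantics of the upstream JAX Splash attention BlockSizes. Under that assumption, compute block sizes default to their load block sizes, and dQ blocks are unnecessary when the fused backward kernel is used. With those rules, has_backward_blocks could be expressed as follows. BlockSizesSketch is a hypothetical stand-in and omits the layout fields.

```python
from __future__ import annotations

import dataclasses

@dataclasses.dataclass
class BlockSizesSketch:
    """Hypothetical stand-in for BlockSizes (defaults assumed, not verified)."""
    block_q: int
    block_kv: int
    block_kv_compute: int | None = None
    block_q_dkv: int | None = None
    block_kv_dkv: int | None = None
    block_kv_dkv_compute: int | None = None
    block_q_dq: int | None = None
    block_kv_dq: int | None = None
    use_fused_bwd_kernel: bool = False

    def __post_init__(self):
        # Assumption: compute block sizes fall back to the load block sizes.
        if self.block_kv_compute is None:
            self.block_kv_compute = self.block_kv
        if self.block_kv_dkv_compute is None:
            self.block_kv_dkv_compute = self.block_kv_dkv

    @property
    def has_backward_blocks(self) -> bool:
        # dKV blocks are always required; dQ blocks only without the fused kernel.
        blocks = [self.block_q_dkv, self.block_kv_dkv, self.block_kv_dkv_compute]
        if not self.use_fused_bwd_kernel:
            blocks += [self.block_q_dq, self.block_kv_dq]
        return all(b is not None for b in blocks)

fwd_only = BlockSizesSketch(block_q=128, block_kv=128)
fused = BlockSizesSketch(128, 128, block_q_dkv=128, block_kv_dkv=128,
                         use_fused_bwd_kernel=True)
print(fwd_only.has_backward_blocks, fused.has_backward_blocks)  # False True
```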
- class ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.QKVLayout(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
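Bases: IntEnum
Physical memory layout of a Q, K, or V tensor. HEAD_DIM_MINOR keeps the head dimension minor-most ([..., seq_len, head_dim]), while SEQ_MINOR makes the sequence dimension minor-most ([..., head_dim, seq_len]).
- HEAD_DIM_MINOR = 1
- SEQ_MINOR = 2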
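This sketch assumes the two layouts mean [..., seq_len, head_dim] and [..., head_dim, seq_len], respectively. Under that assumption, converting a logical (head-dim-minor) array into the SEQ_MINOR physical layout is just a swap of the last two axes. QKVLayoutSketch and to_physical are hypothetical names used only for this illustration.

```python
import enum

import numpy as np

class QKVLayoutSketch(enum.IntEnum):
    # Local copy of the two documented members.
    HEAD_DIM_MINOR = 1
    SEQ_MINOR = 2

def to_physical(x: np.ndarray, layout: QKVLayoutSketch) -> np.ndarray:
    """Map a logical [..., seq_len, head_dim] array to the assumed physical layout."""
    if layout == QKVLayoutSketch.HEAD_DIM_MINOR:
        return x
    return np.swapaxes(x, -1, -2)  # -> [..., head_dim, seq_len]

q = np.zeros((8, 1024, 128))  # [num_heads, seq_len, head_dim]
print(to_physical(q, QKVLayoutSketch.SEQ_MINOR).shape)  # (8, 128, 1024)
```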
- class ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.SegmentIds(q: jax.Array, kv: jax.Array)
Bases: NamedTuple
SegmentIds for Q and KV sequences.
SegmentIds prevent cross-attention between segments (portions of a sequence) that have been concatenated together into a single sequence. Each array holds one integer id per token, and only tokens with the same id are allowed to attend to each other.
The static mask (e.g. causal) is “and-ed” with the segment id mask to form the actual attention mask. It is important that the resulting mask does not have any all-zero rows (along the kv dimension). Otherwise the softmax would be invalid, because its denominator would be 0. This condition holds for causal self-attention, because the segment ids then form a block-diagonal matrix, so at least one element in each row (the diagonal) is set. It is easy to break this condition in non-self-attention configurations.
- q
segment ids along the Q sequence
- Type
jax.jaxlib._jax.Array
- kv
segment ids along the KV sequence
- Type
jax.jaxlib._jax.Array
- kv: Array
Alias for field number 1
- q: Array
Alias for field number 0
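The all-zero-row condition above can be checked directly on a dense mask. The NumPy sketch below uses two hypothetical helpers, combined_mask and has_empty_rows. It ANDs a static mask with the segment-id equality mask. It then shows that packed causal self-attention is safe, while a cross-attention layout with an unmatched query segment is not.

```python
import numpy as np

def combined_mask(static_mask, q_segment_ids, kv_segment_ids):
    """AND a static [q_len, kv_len] mask with the segment-id equality mask."""
    segment_mask = q_segment_ids[:, None] == kv_segment_ids[None, :]
    return static_mask & segment_mask

def has_empty_rows(mask):
    # A query row with no allowed key would make the softmax denominator 0.
    return bool((~mask.any(axis=1)).any())

# Two packed sequences (lengths 3 and 2) with causal self-attention.
seg = np.array([0, 0, 0, 1, 1])
causal = np.tril(np.ones((5, 5), dtype=bool))
mask = combined_mask(causal, seg, seg)
print(has_empty_rows(mask))  # False: the diagonal is always allowed

# Cross-attention where query segment 2 has no matching key segment.
q_seg = np.array([0, 0, 2])
kv_seg = np.array([0, 0, 1, 1])
full = np.ones((3, 4), dtype=bool)
print(has_empty_rows(combined_mask(full, q_seg, kv_seg)))  # True
```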
- class ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.SplashAttentionKernel(fwd_mask_info: MaskInfo, dq_mask_info: ejkernel.kernels._pallas.tpu.blocksparse_attention._info.MaskInfo | None, dkv_mask_info: ejkernel.kernels._pallas.tpu.blocksparse_attention._info.MaskInfo | None, **kwargs)
Bases: object
- manual_sharding_spec(sharding: NamedSharding)
Returns a value that can be used as a shard_map partition spec for the kernel.
- ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.attention_reference(mask: Array, q: Array, k: Array, v: Array, segment_ids: ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.SegmentIds | None, sinks: jax.jaxlib._jax.Array | None = None, *, mask_value: float = -2.381976426469702e+38, save_residuals: bool = False, custom_type: str = 'flash', logits_soft_cap: float | None = None) → Union[Array, tuple[jax.jaxlib._jax.Array, tuple[jax.jaxlib._jax.Array]]]
- ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.attention_reference_custom(mask: Array, q: Array, k: Array, v: Array, segment_ids: ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.SegmentIds | None, sinks: jax.jaxlib._jax.Array | None = None, *, mask_value: float = -2.381976426469702e+38, save_residuals: bool = False, custom_type: str = 'flash', logits_soft_cap: float | None = None)
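The reference functions carry no docstring here. The following is a minimal single-head NumPy sketch of what a dense reference with this signature typically computes. It assumes the behavior mirrors upstream JAX Splash attention:

- logits are not scaled (q is assumed to be pre-scaled);
- segment ids are ANDed into the mask;
- the optional soft cap is applied before masking;
- masked logits are replaced by mask_value.

The documented default mask_value of -2.381976426469702e+38 equals -0.7 times the float32 maximum. attention_reference_sketch is a hypothetical name. sinks and save_residuals are omitted.

```python
import numpy as np

# Matches the documented default: -0.7 * float32 max.
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)

def attention_reference_sketch(mask, q, k, v, q_segment_ids=None, kv_segment_ids=None,
                               mask_value=DEFAULT_MASK_VALUE, logits_soft_cap=None):
    """Dense single-head attention: mask [q_len, kv_len] (bool), q [q_len, d],
    k [kv_len, d], v [kv_len, dv]. No softmax scaling is applied here."""
    logits = q.astype(np.float32) @ k.astype(np.float32).T
    if q_segment_ids is not None:
        # Only tokens with equal segment ids may attend to each other.
        mask = mask & (q_segment_ids[:, None] == kv_segment_ids[None, :])
    if logits_soft_cap is not None:
        # Gemma-2 style soft cap: squashes logits into (-cap, cap).
        logits = logits_soft_cap * np.tanh(logits / logits_soft_cap)
    logits = np.where(mask, logits, mask_value)
    # Numerically stable softmax over the kv axis.
    m = logits.max(axis=-1, keepdims=True)
    p = np.exp(logits - m)
    return (p @ v.astype(np.float32)) / p.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((4, 8)) for _ in range(3))
causal = np.tril(np.ones((4, 4), dtype=bool))
out = attention_reference_sketch(causal, q, k, v, logits_soft_cap=30.0)
print(out.shape)  # (4, 8)
```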
- ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.blocksparse_attention(query: Float[Array, 'batch num_heads seq_len head_dim'], key: Float[Array, 'batch kv_num_heads kv_len head_dim'], value: Float[Array, 'batch kv_num_heads kv_len vhead_dim'], q_segment_ids: Int[Array, 'batch seq_len'] | None = None, kv_segment_ids: Int[Array, 'batch kv_len'] | None = None, q_positions: Int[Array, 'batch seq_len'] | None = None, kv_positions: Int[Array, 'batch kv_len'] | None = None, softmax_aux: Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len head_dim'] | None = None, attention_mask: Bool[Array, 'batch num_heads_or_1 seq_len kv_len'] | Int[Array, 'batch num_heads_or_1 seq_len kv_len'] | None = None, sequence_parallelism_mesh_axis_name: str | None = None, logits_soft_cap: float | None = None, qkv_layouts: tuple['SparseMask'] | None = None, softmax_scale: float | None = None, fwd_params: FwdParams | None = None, bwd_params: BwdParams | None = None, mask_builder: Callable[[int, int, int, int, int], 'Mask'] | Callable[[], 'SparseMask'] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = True, fused_backward: bool = False) → Float[Array, 'batch num_heads seq_len vhead_dim']
Pallas TPU block-sparse attention kernel implementation.
Computes attention over sparse block patterns using Pallas kernels optimized for TPU execution.
- Parameters
query – Query tensor [batch, num_heads, seq_len, head_dim]
key – Key tensor [batch, kv_num_heads, kv_len, head_dim]
value – Value tensor [batch, kv_num_heads, kv_len, vhead_dim]
q_segment_ids – Optional query segment ids [batch, seq_len]
kv_segment_ids – Optional KV segment ids [batch, kv_len]
q_positions – Optional query position indices [batch, seq_len] (not implemented for TPU)
kv_positions – Optional KV position indices [batch, kv_len] (not implemented for TPU)
softmax_aux – Optional auxiliary softmax values for attention sinks
bias – Optional attention bias [batch, num_heads, seq_len, head_dim]
sequence_parallelism_mesh_axis_name – Optional mesh axis name for sequence parallelism
logits_soft_cap – Optional soft-capping value for attention logits. When set, logits are transformed as logits_soft_cap * tanh(logits / logits_soft_cap), which bounds them to the open interval (-logits_soft_cap, logits_soft_cap) and improves numerical stability (Gemma-2 style). The backward pass differentiates exactly through this tanh transformation.
qkv_layouts – Optional pre-computed attention mask layouts
softmax_scale – Attention score scaling factor (default: 1/sqrt(head_dim))
mask_builder – Custom mask builder function
sliding_window – Sliding window size. An int gives a symmetric window (same size to the left and right); a tuple[int, int] gives (left_window, right_window) for an asymmetric window; None disables the sliding window.
chunk_size – Chunk size for chunked causal attention (as in Llama 4). An int enables a chunked causal mask with the given chunk size; None disables chunking.
causal – Whether to use causal masking (default True)
fused_backward – Whether to use the fused backward kernel
- Returns
Attention output [batch num_heads seq_len vhead_dim]
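The causal, sliding_window and chunk_size options describe dense masks that the kernel handles block-sparsely. This page does not state how they combine. The hypothetical dense equivalent below rests on three assumptions. The options are combined with a logical AND. Query and key positions both start at 0. Window bounds are inclusive.

```python
import numpy as np

def dense_mask(q_len, kv_len, causal=True, sliding_window=None, chunk_size=None):
    """Hypothetical dense equivalent of the causal/sliding_window/chunk_size options."""
    q = np.arange(q_len)[:, None]
    k = np.arange(kv_len)[None, :]
    mask = np.ones((q_len, kv_len), dtype=bool)
    if causal:
        mask &= k <= q
    if sliding_window is not None:
        if isinstance(sliding_window, int):
            left, right = sliding_window, sliding_window  # symmetric window
        else:
            left, right = sliding_window  # (left_window, right_window)
        # Assumption: both window bounds are inclusive.
        mask &= (k >= q - left) & (k <= q + right)
    if chunk_size is not None:
        # Chunked attention: a query only sees keys in its own chunk.
        mask &= (q // chunk_size) == (k // chunk_size)
    return mask

print(dense_mask(4, 4, sliding_window=(1, 0)).astype(int))
print(dense_mask(4, 4, chunk_size=2).astype(int))
```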
- ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.flash_attention_kernel(data_next_ref, block_mask_ref, mask_next_ref, q_ref, k_ref, v_ref, q_segment_ids_ref, kv_segment_ids_ref, sinks_ref, mask_ref, q_sequence_ref, m_scratch_ref, l_scratch_ref, o_scratch_ref, o_ref, logsumexp_ref=None, *, mask_value: float, grid_width: int, bq: int, bkv: int, bkv_compute: int, head_dim_v: int, q_layout: QKVLayout, k_layout: QKVLayout, v_layout: QKVLayout, logits_soft_cap: float | None, mask_function: collections.abc.Callable[[...], jax.jaxlib._jax.Array] | None)
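The m_scratch_ref, l_scratch_ref and o_scratch_ref arguments point to the usual flash-attention online softmax. That recurrence keeps three running quantities while streaming over KV blocks: a row maximum m, a denominator l and an unnormalized output o. The following unmasked NumPy sketch of that recurrence assumes this is what the kernel accumulates. online_softmax_attention is a hypothetical name. Masking, soft capping and sinks are left out.

```python
import numpy as np

def online_softmax_attention(q, k, v, block_kv):
    """Flash-style pass over KV blocks with running max m, sum l and output o."""
    q_len, kv_len = q.shape[0], k.shape[0]
    m = np.full((q_len, 1), -np.inf)   # running row maximum
    l = np.zeros((q_len, 1))           # running softmax denominator
    o = np.zeros((q_len, v.shape[1]))  # running unnormalized output
    for start in range(0, kv_len, block_kv):
        s = q @ k[start:start + block_kv].T  # logits for this KV block
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        alpha = np.exp(m - m_new)            # rescales the old statistics
        p = np.exp(s - m_new)
        l = alpha * l + p.sum(axis=-1, keepdims=True)
        o = alpha * o + p @ v[start:start + block_kv]
        m = m_new
    return o / l

rng = np.random.default_rng(1)
q = rng.standard_normal((6, 4))
k = rng.standard_normal((10, 4))
v = rng.standard_normal((10, 3))
out = online_softmax_attention(q, k, v, block_kv=4)  # last block holds 2 keys
```

Rescaling by alpha = exp(m_old - m_new) is what lets each KV block be processed once without materializing a full row of logits.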
- ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.from_head_minor(vals: tuple[Any, ...], layout: QKVLayout)
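from_head_minor has no docstring. Assume it follows the upstream JAX Splash helper of the same name. In that case it reorders a tuple written in head-dim-minor order (for example, block shapes or index-map results) into the given physical layout. The self-contained sketch below uses a local copy of the enum. from_head_minor_sketch is a hypothetical name.

```python
import enum
from typing import Any

class QKVLayoutSketch(enum.IntEnum):
    # Local copy of the two documented members.
    HEAD_DIM_MINOR = 1
    SEQ_MINOR = 2

def from_head_minor_sketch(vals: tuple[Any, ...], layout: QKVLayoutSketch) -> tuple[Any, ...]:
    """Reorder a head-dim-minor tuple for the given physical layout."""
    if layout == QKVLayoutSketch.HEAD_DIM_MINOR:
        return vals
    # SEQ_MINOR: swap the last two entries, (seq, head_dim) -> (head_dim, seq).
    return (*vals[:-2], vals[-1], vals[-2])

print(from_head_minor_sketch(("heads", "seq", "head_dim"), QKVLayoutSketch.SEQ_MINOR))
```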
- ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.get_kernel_name(block_metadata: Mapping[str, Any], is_mqa: bool, save_residuals: bool, is_segmented: bool, phase: str) → str
Returns a unique name for all SplashAttention kernel variants.
- ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.make_attention_reference(mask: ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.Mask | numpy.ndarray, is_mqa: bool, backward_impl: str = 'vanilla', **params: Any) → Callable
- ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.make_masked_mha_reference(mask: mask_lib.Mask | np.ndarray, *, is_mqa: bool = False, backward_impl: str = 'vanilla', **params: Any) → Callable
- ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.make_masked_mqa_reference(mask: mask_lib.Mask | np.ndarray, *, is_mqa: bool = True, backward_impl: str = 'vanilla', **params: Any) → Callable
- ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.make_splash_mha(mask: np.ndarray | jax.Array | mask_lib.MultiHeadMask, *, block_sizes: BlockSizes | None = None, is_mqa: bool = False, save_residuals: bool = False, mask_value: float = -2.381976426469702e+38, logits_soft_cap: float | None = None, downcast_smem_data: bool = True, head_shards: int, q_seq_shards: int, residual_checkpoint_name: str | None = None, interpret: bool = False)
- ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.make_splash_mha_single_device(mask: np.ndarray | jax.Array | mask_lib.MultiHeadMask, *, block_sizes: BlockSizes | None = None, is_mqa: bool = False, save_residuals: bool = False, mask_value: float = -2.381976426469702e+38, logits_soft_cap: float | None = None, downcast_smem_data: bool = True, head_shards: int = 1, q_seq_shards: int = 1, residual_checkpoint_name: str | None = None, interpret: bool = False)
- ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.make_splash_mqa(mask: np.ndarray | jax.Array | mask_lib.MultiHeadMask, *, block_sizes: BlockSizes | None = None, is_mqa: bool = True, save_residuals: bool = False, mask_value: float = -2.381976426469702e+38, logits_soft_cap: float | None = None, downcast_smem_data: bool = True, head_shards: int, q_seq_shards: int, residual_checkpoint_name: str | None = None, interpret: bool = False)
- ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.make_splash_mqa_single_device(mask: np.ndarray | jax.Array | mask_lib.MultiHeadMask, *, block_sizes: BlockSizes | None = None, is_mqa: bool = True, save_residuals: bool = False, mask_value: float = -2.381976426469702e+38, logits_soft_cap: float | None = None, downcast_smem_data: bool = True, head_shards: int = 1, q_seq_shards: int = 1, residual_checkpoint_name: str | None = None, interpret: bool = False)
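The make_splash_mqa* factories default to is_mqa=True. This sketch assumes is_mqa carries its usual meaning, with one key/value head shared by every query head. It also assumes unbatched per-call shapes, as in upstream JAX Splash attention: q [num_heads, q_len, head_dim] and k/v [kv_len, head_dim]. Under those assumptions, an MQA call reduces to MHA with K and V broadcast across heads. The mha and mqa functions below are hypothetical, unmasked NumPy illustrations.

```python
import numpy as np

def mha(q, k, v):
    """Unmasked softmax attention per head: q [h, s, d], k/v [h, t, d]."""
    s = np.einsum("hsd,htd->hst", q, k)
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return np.einsum("hst,htd->hsd", p, v)

def mqa(q, k, v):
    """MQA: a single K/V head, k/v [t, d], shared by every query head."""
    h = q.shape[0]
    return mha(q, np.broadcast_to(k, (h,) + k.shape), np.broadcast_to(v, (h,) + v.shape))

rng = np.random.default_rng(2)
q = rng.standard_normal((4, 5, 8))  # [num_heads, q_len, head_dim]
k = rng.standard_normal((7, 8))     # [kv_len, head_dim]
v = rng.standard_normal((7, 8))
out = mqa(q, k, v)
print(out.shape)  # (4, 5, 8)
```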