ejkernel.modules.operations.blocksparse_attention

Block-sparse attention module with automatic optimization.

This module implements block-sparse attention, which applies attention only to predefined blocks of the attention matrix, significantly reducing computational cost for long sequences while maintaining important attention patterns.

The block-sparse pattern is defined by a mask builder function that determines which blocks should be computed. This is particularly useful for document-level attention, local attention patterns, and sparse attention architectures.
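As an illustration of how a block pattern can follow document boundaries, the sketch below derives a block layout from per-token segment IDs. It is plain Python rather than ejkernel code: `document_block_layout` and its arguments are hypothetical names, and the library's mask builders work with its own Mask objects instead of nested lists.

```python
def document_block_layout(q_segment_ids, kv_segment_ids, block_size):
    """Return layout[i][j] = True when query block i and key block j
    contain at least one pair of tokens from the same document."""

    def block_segments(segment_ids):
        # Set of document IDs present in each block of `block_size` tokens.
        return [
            set(segment_ids[start:start + block_size])
            for start in range(0, len(segment_ids), block_size)
        ]

    q_blocks = block_segments(q_segment_ids)
    kv_blocks = block_segments(kv_segment_ids)
    # A block pair needs to be computed only if the two blocks share a document.
    return [[bool(qb & kb) for kb in kv_blocks] for qb in q_blocks]


# Two packed documents: tokens 0-5 belong to document 0, tokens 6-11 to document 1.
segment_ids = [0] * 6 + [1] * 6
layout = document_block_layout(segment_ids, segment_ids, block_size=4)
for row in layout:
    print("".join("x" if keep else "." for keep in row))
```

Blocks that straddle a document boundary still need token-level masking inside the block; the layout only decides which blocks can be skipped entirely.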

class ejkernel.modules.operations.blocksparse_attention.BlockSparseAttention

Bases: Kernel[BlockSparseAttentionConfig, Array]

Block-sparse attention kernel with custom optimization logic.

Implements attention computation over sparse block patterns, computing attention only for specified blocks rather than the full attention matrix. This reduces the cost from O(N^2) for dense attention to roughly O(N * B * S), where B is the average number of key blocks computed per query row and S is the block size.
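To make the saving concrete, the following sketch counts the block pairs that a banded (local) pattern would compute and compares that with the dense count. The helper name and the `band` parameter are illustrative assumptions, not library parameters.

```python
def count_blocks(seq_len, block_size, band):
    """Count block pairs for dense attention and for a banded block pattern.

    `band` is the number of neighbouring key blocks kept on each side of the
    diagonal (a hypothetical parameter used only for this illustration).
    """
    num_blocks = -(-seq_len // block_size)  # ceiling division
    dense = num_blocks * num_blocks
    sparse = sum(
        1
        for i in range(num_blocks)
        for j in range(num_blocks)
        if abs(i - j) <= band
    )
    return dense, sparse


dense, sparse = count_blocks(seq_len=8192, block_size=128, band=2)
print(dense, sparse, sparse / dense)
```

With these inputs there are 64 blocks per side, so the dense pattern has 4096 block pairs while the banded pattern keeps 314 of them.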

Features:
  • Configurable sparse block patterns via mask builder

  • Support for causal masking and sliding windows

  • Automatic platform/backend selection

  • Optional autotuning for optimal block sizes

  • Gradient support for training with custom VJP

  • Logit soft capping with tanh activation for numerical stability (Gemma-2 style)

  • Separate forward/backward block sizes for performance tuning
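The soft capping listed above can be sketched in a few lines, assuming the kernel uses the common `cap * tanh(logit / cap)` form associated with Gemma-2. The function name is hypothetical and the snippet works on plain floats rather than arrays.

```python
import math


def soft_cap(logit, cap):
    """Squash a logit smoothly into the range [-cap, cap] using tanh."""
    return cap * math.tanh(logit / cap)


for raw in (-200.0, -1.0, 0.0, 1.0, 200.0):
    print(raw, soft_cap(raw, cap=50.0))
```

Small logits pass through almost unchanged, while large ones saturate near the cap, which keeps the inputs to the softmax bounded.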

The mask builder function defines which blocks to compute, enabling patterns like:
  • Local attention (nearby tokens only)

  • Global + local (attending to special tokens + local context)

  • Strided patterns (every nth block)

  • Custom patterns based on document structure
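The patterns above can be pictured at block granularity with simple predicates. This is only a sketch: the functions take block indices and return booleans, which is simpler than the `(q_idx, k_idx, q_size, k_size, window)` signature the library expects, and every name and default value here is an assumption made for illustration.

```python
def local_block(q_block, k_block, window_blocks=1):
    """Keep key blocks within `window_blocks` of the query block."""
    return abs(q_block - k_block) <= window_blocks


def global_plus_local_block(q_block, k_block, num_global=1, window_blocks=1):
    """Keep local blocks plus the first `num_global` key blocks."""
    return k_block < num_global or local_block(q_block, k_block, window_blocks)


def strided_block(q_block, k_block, stride=2):
    """Keep the diagonal block and every `stride`-th key block."""
    return q_block == k_block or k_block % stride == 0


def render(predicate, num_blocks=6):
    """Draw the block layout, one text row per query block."""
    return [
        "".join("x" if predicate(q, k) else "." for k in range(num_blocks))
        for q in range(num_blocks)
    ]


for name, predicate in [
    ("local", local_block),
    ("global + local", global_plus_local_block),
    ("strided", strided_block),
]:
    print(name)
    print("\n".join(render(predicate)))
```

Each `x` marks a block that would be computed and each `.` a block that would be skipped.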

Example

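>>> from ejkernel.modules.operations import BlockSparseAttention
>>> from ejkernel.modules import create_default_executor
>>>
>>> executor = create_default_executor()
>>> attn = BlockSparseAttention()
>>>
>>> # Mask builder for a local pattern; the body is left as a placeholder.
>>> def local_mask(q_idx, k_idx, q_size, k_size, window):
...     # Build and return the Mask for this block here.
...     pass
>>>
>>> # query, key and value are attention inputs shaped as described under run().
>>> output = executor(
...     attn,
...     query, key, value,
...     mask_builder=local_mask,
...     chunk_size=128
... )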

candidate_cfgs(inv: Invocation[BlockSparseAttentionConfig, Array])

Generate candidate configurations for autotuning.

Creates multiple block size configurations for benchmarking to find the optimal tiling parameters for the given input shapes.

Parameters

inv – Invocation object with arguments and metadata

Returns

Iterable of candidate configurations to test during autotuning

Note

The autotuning system will benchmark each candidate and select the fastest one for the given input configuration.
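The benchmark-and-select idea can be sketched as a small timing loop. Everything here is a stand-in chosen for illustration: `TileConfig`, `pick_fastest` and `fake_kernel` are hypothetical names, and the loop does not reflect how the library's autotuner caches or schedules its measurements.

```python
import time
from dataclasses import dataclass


@dataclass(frozen=True)
class TileConfig:
    """Hypothetical stand-in for a candidate configuration."""
    q_block: int
    kv_block: int


def pick_fastest(candidates, run_once, repeats=3):
    """Time `run_once(cfg)` for each candidate and return the fastest one."""
    best_cfg, best_time = None, float("inf")
    for cfg in candidates:
        run_once(cfg)  # warm-up call so one-off setup cost is not measured
        start = time.perf_counter()
        for _ in range(repeats):
            run_once(cfg)
        elapsed = (time.perf_counter() - start) / repeats
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    return best_cfg


def fake_kernel(cfg):
    # Placeholder workload: pretend that smaller tiles need more loop steps.
    steps = (1 << 22) // (cfg.q_block * cfg.kv_block)
    total = 0
    for step in range(steps):
        total += step
    return total


candidates = [TileConfig(32, 32), TileConfig(64, 64), TileConfig(128, 128)]
print(pick_fastest(candidates, fake_kernel))
```

When timing real JAX computations, the timed call would also have to wait for the result (for example by calling `.block_until_ready()` on the output array), because dispatch is asynchronous.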

candidate_cfgs_gpu(inv: Invocation[BlockSparseAttentionConfig, Array])

Generate GPU-optimized candidate configurations for autotuning (Triton).

Heuristics:
  • q/kv blocks in {32, 64, 128, 256} depending on head_dim

  • If sliding_window is set, favor kv blocks ≲ window size (rounded)

  • num_warps: 2-8 based on head_dim and block sizes

  • num_stages: 2-4 (bigger when kv block is large)

  • Backward block sizes smaller to reduce register pressure
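A rough sketch of how heuristics of this kind could be turned into a candidate list is shown below. The thresholds, the tuple layout and the function name are assumptions made for illustration, the real generator's cut-offs are not documented here, and the sketch simply drops oversized key/value blocks instead of merely favoring smaller ones. Only an integer `sliding_window` is handled.

```python
from itertools import product


def gpu_candidates(head_dim, sliding_window=None):
    """Enumerate hypothetical (q_block, kv_block, num_warps, num_stages) tuples."""
    # Larger head dimensions use more registers per tile, so prefer smaller blocks.
    blocks = (32, 64, 128) if head_dim > 128 else (64, 128, 256)
    candidates = []
    for q_block, kv_block in product(blocks, blocks):
        # Skip key/value blocks that are wider than the sliding window.
        if sliding_window is not None and kv_block > sliding_window:
            continue
        # Assumed rule of thumb: big tiles get more warps and deeper pipelines.
        num_warps = 8 if q_block * kv_block >= 128 * 128 else 4
        num_stages = 3 if kv_block >= 128 else 2
        candidates.append((q_block, kv_block, num_warps, num_stages))
    return candidates


for cfg in gpu_candidates(head_dim=64, sliding_window=128):
    print(cfg)
```

Keeping the list short matters in practice, because every candidate has to be compiled and timed during autotuning.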

candidate_cfgs_shard_map_gpu(inv: Invocation[BlockSparseAttentionConfig, Array])

Generate GPU-optimized candidate configurations for autotuning (Triton).

Heuristics:
  • q/kv blocks in {32, 64, 128, 256} depending on head_dim

  • If sliding_window is set, favor kv blocks ≲ window size (rounded)

  • num_warps: 2-8 based on head_dim and block sizes

  • num_stages: 2-4 (bigger when kv block is large)

  • Backward block sizes smaller to reduce register pressure

candidate_cfgs_shard_map_tpu(inv: Invocation[BlockSparseAttentionConfig, Array])

Generate TPU-optimized candidate configurations for autotuning (Pallas).

candidate_cfgs_shard_map_xla(inv: Invocation[BlockSparseAttentionConfig, Array])

candidate_cfgs_tpu(inv: Invocation[BlockSparseAttentionConfig, Array])

Generate TPU-optimized candidate configurations for autotuning (Pallas).

candidate_cfgs_xla(inv: Invocation[BlockSparseAttentionConfig, Array])

create_shard_map_wrapper(query: Float[Array, 'batch num_heads seq_len head_dim'], key: Float[Array, 'batch kv_num_heads kv_len head_dim'], value: Float[Array, 'batch kv_num_heads kv_len vhead_dim'], softmax_aux: Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len head_dim'] | None = None, q_segment_ids: Int[Array, 'batch seq_len'] | None = None, kv_segment_ids: Int[Array, 'batch kv_len'] | None = None, q_positions: Int[Array, 'batch seq_len'] | None = None, kv_positions: Int[Array, 'batch kv_len'] | None = None, sequence_parallelism_mesh_axis_name: str | None = None, logits_soft_cap: float | None = None, qkv_layouts: tuple['SparseMask'] | None = None, softmax_scale: float | None = None, mask_builder: Callable[[int, int, int, int, int], 'Mask'] | Callable[[], 'SparseMask'] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = True, fused_backward: bool = False, platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] | None = None, cfg: BlockSparseAttentionConfig | None = None, mesh: Mesh | None = None, in_specs: tuple[PartitionSpec, ...] | None = None, out_specs: PartitionSpec | None = None, check_vma: bool = False)

Create a shard_map wrapper specifically for blocksparse attention.

Parameters
  • mesh – JAX device mesh

  • in_specs – Input partition specs (must match length of tensor args)

  • out_specs – Output partition spec

  • query – Query tensor to be sharded

  • key – Key tensor to be sharded

  • value – Value tensor to be sharded

  • All other args – Blocksparse attention parameters

Returns

Tuple of (shard_map_fn, call_args)

get_impl(cfg: BlockSparseAttentionConfig)

Get kernel implementation from registry based on configuration.

Parameters

cfg – Configuration specifying platform and backend preferences

Returns

Callable kernel implementation for block-sparse attention

Raises

ValueError – If no matching implementation is found for the configuration
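The registry lookup described above follows a familiar pattern, sketched here in plain Python. The registry layout, the `(platform, backend)` key and all names are assumptions for illustration and are not the library's actual registry API.

```python
_REGISTRY = {}


def register(platform, backend):
    """Decorator that stores an implementation under (platform, backend)."""
    def decorator(fn):
        _REGISTRY[(platform, backend)] = fn
        return fn
    return decorator


def lookup(platform, backend):
    """Return the registered implementation or raise ValueError."""
    try:
        return _REGISTRY[(platform, backend)]
    except KeyError:
        raise ValueError(
            f"no implementation registered for platform={platform!r}, "
            f"backend={backend!r}"
        ) from None


@register("xla", "any")
def attention_xla(query, key, value):
    # Placeholder body: a real kernel would compute attention here.
    return "xla result"


print(lookup("xla", "any")(None, None, None))
try:
    lookup("triton", "gpu")
except ValueError as err:
    print("lookup failed:", err)
```

Raising ValueError for an unknown combination mirrors the behaviour documented for get_impl, so callers can fall back to another platform explicitly.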

heuristic_cfg(inv: Invocation[BlockSparseAttentionConfig, Array]) -> BlockSparseAttentionConfig

Provide default configuration based on invocation context.

Selects optimal block sizes based on sequence length and head dimension.

Parameters

inv – Invocation object with arguments and metadata

Returns

Default configuration with block sizes

heuristic_cfg_gpu(inv: Invocation[BlockSparseAttentionConfig, Array]) -> BlockSparseAttentionConfig

Provide default configuration based on invocation context.

Selects optimal block sizes based on sequence length and head dimension.

Parameters

inv – Invocation object with arguments and metadata

Returns

Default configuration with block sizes

heuristic_cfg_tpu(inv: Invocation[BlockSparseAttentionConfig, Array]) -> BlockSparseAttentionConfig

Provide default configuration based on invocation context.

Selects optimal block sizes based on sequence length and head dimension.

Parameters

inv – Invocation object with arguments and metadata

Returns

Default configuration with block sizes

run(query: Float[Array, 'batch num_heads seq_len head_dim'], key: Float[Array, 'batch kv_num_heads kv_len head_dim'], value: Float[Array, 'batch kv_num_heads kv_len vhead_dim'], softmax_aux: Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len head_dim'] | None = None, q_segment_ids: Int[Array, 'batch seq_len'] | None = None, kv_segment_ids: Int[Array, 'batch kv_len'] | None = None, q_positions: Int[Array, 'batch seq_len'] | None = None, kv_positions: Int[Array, 'batch kv_len'] | None = None, sequence_parallelism_mesh_axis_name: str | None = None, logits_soft_cap: float | None = None, qkv_layouts: tuple['SparseMask'] | None = None, softmax_scale: float | None = None, mask_builder: Callable[[int, int, int, int, int], 'Mask'] | Callable[[], 'SparseMask'] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = True, fused_backward: bool = False, platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] | None = None, *, cfg: BlockSparseAttentionConfig) -> Float[Array, 'batch seq_len_q num_heads head_dim']

Execute block-sparse attention with the given configuration.

Parameters
  • query – Query tensor [batch, num_heads, seq_len, head_dim]

  • key – Key tensor [batch, kv_num_heads, kv_len, head_dim]

  • value – Value tensor [batch, kv_num_heads, kv_len, vhead_dim]

  • q_segment_ids – Segment IDs for queries to handle multiple sequences [batch, seq_len]

  • kv_segment_ids – Segment IDs for keys/values [batch, kv_len]

  • softmax_aux – Auxiliary values added to attention scores (e.g., for attention sinks)

  • logits_soft_cap – Optional soft cap value to bound attention logits

  • softmax_scale – Scaling factor for attention scores (default: 1/sqrt(head_dim))

  • mask_builder – Function that builds the sparse mask pattern. Takes (q_idx, k_idx, q_size, k_size, window_size) and returns a Mask object

  • sliding_window – Window size for local attention, int for symmetric or (left, right) tuple

  • chunk_size – Overall chunk size (alternative to separate query/key chunk sizes)

  • causal – Whether to apply causal masking (default: True)

  • fused_backward – Use fused backward pass for improved gradient computation

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Configuration object specifying platform/backend and kernel parameters

Returns

Attention output tensor [batch, seq_len_q, num_heads, head_dim]

Note

The mask_builder function is critical for defining sparsity patterns. It should return a mask indicating which blocks to compute.
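When checking a sparse pattern, it can help to compare against a small dense reference. The sketch below is one in plain Python for a single head, combining the segment, causal and sliding-window rules from the parameter list. It is not the library's kernel: the function names are hypothetical, queries and keys are assumed to share positions and segment IDs, and the `(left, right)` window is interpreted as "at most `left` tokens back and `right` tokens ahead", which is an assumption about the intended semantics.

```python
import math


def allowed(q_pos, k_pos, q_seg, k_seg, causal=True, sliding_window=None):
    """Token-level rule combining segment, causal and window constraints."""
    if q_seg != k_seg:
        return False  # tokens from different packed sequences never attend
    if causal and k_pos > q_pos:
        return False  # no attention to future positions
    if sliding_window is not None:
        # An int is treated as a symmetric window, a tuple as (left, right).
        left, right = (
            (sliding_window, sliding_window)
            if isinstance(sliding_window, int)
            else sliding_window
        )
        if k_pos < q_pos - left or k_pos > q_pos + right:
            return False
    return True


def reference_attention(query, key, value, segment_ids, **mask_kwargs):
    """Dense single-head attention over Python lists, used only as a reference."""
    scale = 1.0 / math.sqrt(len(query[0]))  # default softmax_scale
    output = []
    for i, q in enumerate(query):
        keep = [
            j for j in range(len(key))
            if allowed(i, j, segment_ids[i], segment_ids[j], **mask_kwargs)
        ]
        scores = [scale * sum(a * b for a, b in zip(q, key[j])) for j in keep]
        peak = max(scores)  # subtract the maximum for a stable softmax
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        output.append([
            sum(w * value[j][d] for w, j in zip(weights, keep)) / total
            for d in range(len(value[0]))
        ])
    return output


query = key = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
value = [[1.0], [2.0], [3.0], [4.0]]
segment_ids = [0, 0, 1, 1]
result = reference_attention(
    query, key, value, segment_ids, causal=True, sliding_window=1
)
print(result)
```

The first token of each segment can only attend to itself, so its output equals its own value row, which gives a quick sanity check for any sparse implementation run on the same inputs.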

ejkernel.modules.operations.blocksparse_attention.blocksparse_attention(query: Float[Array, 'batch num_heads seq_len head_dim'], key: Float[Array, 'batch kv_num_heads kv_len head_dim'], value: Float[Array, 'batch kv_num_heads kv_len vhead_dim'], softmax_aux: Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len head_dim'] | None = None, /, *, mask_info: MaskInfo | None = None, sequence_parallelism_mesh_axis_name: str | None = None, logits_soft_cap: float | None = None, qkv_layouts: tuple['SparseMask'] | None = None, softmax_scale: float | None = None, mask_builder: Callable[[int, int, int, int, int], 'Mask'] | Callable[[], 'SparseMask'] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = True, fused_backward: bool = False, purify: bool = False, platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] | None = None, cfg: BlockSparseAttentionConfig | None = None, mesh: Mesh | None = None, in_specs: tuple[PartitionSpec | None, ...] | None = None, out_specs: PartitionSpec | None = None) -> Float[Array, 'batch kv_num_heads kv_len vhead_dim']

Execute block-sparse attention with automatic optimization.

Performs efficient attention computation over sparse block patterns, significantly reducing memory and computation compared to dense attention while maintaining flexibility through custom mask builders.

Parameters
  • query – Query tensor [batch, num_heads, seq_len, head_dim]

  • key – Key tensor [batch, kv_num_heads, kv_len, head_dim]

  • value – Value tensor [batch, kv_num_heads, kv_len, vhead_dim]

  • mask_info – Optional MaskInfo containing attention mask, segment IDs, and position indices

  • q_positions – Optional query position indices [batch, seq_len] for positional embeddings. If None and mask_info is provided, will use positions from mask_info.

  • kv_positions – Optional key-value position indices [batch, kv_len] for positional embeddings. If None and mask_info is provided, will use positions from mask_info.

  • softmax_aux – Optional auxiliary attention values (e.g., attention sinks)

  • logits_soft_cap – Optional soft capping for attention logits

  • query_chunk_size – Query chunk size for block tiling (default: 128)

  • key_chunk_size – Key chunk size for block tiling (default: 128)

  • softmax_scale – Attention score scaling factor (default: 1/sqrt(head_dim))

  • mask_builder – Callable defining sparse pattern. Signature: (q_idx, k_idx, q_size, k_size, window) -> Mask

  • sliding_window – Window size for local attention (int or (left, right) tuple)

  • chunk_size – Alternative to separate query_chunk_size/key_chunk_size

  • causal – Apply causal masking (default: True)

  • fused_backward – Use fused backward pass (default: False)

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Optional configuration override

  • mesh – JAX device mesh for shard_map execution (optional)

  • in_specs – Input partition specs for shard_map (optional)

  • out_specs – Output partition spec for shard_map (optional)

Returns

Attention output [batch, kv_num_heads, kv_len, vhead_dim]
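The softmax_aux parameter mentions attention sinks. One common formulation, sketched below, adds the sink logits to the softmax denominator only, so that they absorb probability mass without contributing a value vector. Whether the kernels here follow exactly this formulation is an assumption, and `softmax_with_sinks` is a hypothetical helper operating on plain floats.

```python
import math


def softmax_with_sinks(scores, sink_logits):
    """Softmax over `scores` whose denominator also includes `sink_logits`.

    The sinks take part in normalisation but have no value vector, so the
    returned weights sum to at most 1.
    """
    peak = max(list(scores) + list(sink_logits))  # for numerical stability
    exp_scores = [math.exp(s - peak) for s in scores]
    exp_sinks = [math.exp(s - peak) for s in sink_logits]
    total = sum(exp_scores) + sum(exp_sinks)
    return [e / total for e in exp_scores]


print(softmax_with_sinks([1.0, 2.0, 3.0], sink_logits=[]))
print(softmax_with_sinks([1.0, 2.0, 3.0], sink_logits=[3.0]))
```

With no sinks this reduces to an ordinary softmax; adding a sink lowers every weight by the same factor, which lets a head attend to "nothing" when no key is relevant.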

Example

>>> from ejkernel.modules.operations import blocksparse_attention
>>>
>>> # Default causal block-sparse attention
>>> output = blocksparse_attention(query, key, value, causal=True)
>>>
>>> # Custom local + global pattern; create_local_global_mask is assumed to be defined elsewhere
>>> def local_plus_global(q_idx, k_idx, q_size, k_size, window):
...     return create_local_global_mask(q_idx, k_idx, window)
>>>
>>> output = blocksparse_attention(
...     query, key, value,
...     mask_builder=local_plus_global,
...     sliding_window=256
... )
>>>
>>> # Request a specific platform
>>> output = blocksparse_attention(
...     query, key, value,
...     platform="triton"
... )

Note

Block-sparse attention is particularly effective for:
  • Long document processing where full attention is prohibitive

  • Architectures with specific attention patterns (e.g., Longformer)

  • Scenarios where custom sparsity patterns are needed