ejkernel.kernels._triton.blocksparse_attention._interface

Block-sparse Attention implementation using Triton kernels.

This module provides an efficient implementation of block-sparse attention, where the attention pattern is defined by a sparse mask at the block level rather than at the token level. This enables significant computational and memory savings while maintaining expressive attention patterns.

Block-sparse attention operates on fixed-size blocks of queries and keys, where entire blocks either attend to each other or are masked out. This coarse-grained sparsity enables:

1. Reduced computational complexity: the block mask has O(N²/B²) entries for sequence length N and block size B, and work scales with the number of active blocks
2. Memory efficiency: attention is computed only for active blocks
3. Flexible attention patterns: causal, local, strided, and custom patterns
4. Better cache utilization: block-level operations improve memory access locality
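To make the savings concrete, the sketch below counts how many block pairs a causal pattern keeps for the sequence length and block size used in this page's example. The helper `causal_block_mask` is hypothetical and written for illustration only; it is not part of ejkernel and assumes the sequence length is a multiple of the block size.

```python
import jax.numpy as jnp


def causal_block_mask(seq_len: int, block_size: int) -> jnp.ndarray:
    """Hypothetical helper: boolean [num_q_blocks, num_kv_blocks] mask.

    True means the (query block, key block) pair is computed.
    Assumes seq_len is a multiple of block_size.
    """
    num_blocks = seq_len // block_size
    q_idx = jnp.arange(num_blocks)[:, None]
    kv_idx = jnp.arange(num_blocks)[None, :]
    # A query block needs every key block at or before its own position.
    return kv_idx <= q_idx


mask = causal_block_mask(seq_len=2048, block_size=128)
num_blocks = mask.shape[0]           # 16 blocks per side
active = int(mask.sum())             # 136 = 16 * 17 / 2 active block pairs
density = active / num_blocks**2     # 0.53125 of the dense block grid
```

Under these assumptions a causal pattern skips a little under half of the block grid; sliding-window or strided patterns typically skip far more.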

Key concepts:

- Blocks: Fixed-size chunks of the sequence (typically 64 or 128 tokens)
- Sparse Mask: Binary mask indicating which query blocks attend to which key blocks
- QKV Layouts: Pre-computed sparsity patterns defining block connectivity
- Load Balancing: Ensures even distribution of work across GPU threads

Common sparse patterns supported:

- Causal masking (lower triangular)
- Sliding window attention (local context)
- Strided patterns (e.g., every k-th block)
- Custom patterns via mask_builder
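The patterns listed above can be pictured as boolean grids over block indices. The following is a minimal sketch in plain JAX; the function names are hypothetical, the window is expressed in blocks rather than tokens, and none of this reflects how ejkernel builds its `SparseMask` internally.

```python
import jax.numpy as jnp


def block_indices(num_blocks: int):
    """Column vector of query-block indices and row vector of key-block indices."""
    q = jnp.arange(num_blocks)[:, None]
    kv = jnp.arange(num_blocks)[None, :]
    return q, kv


def causal_blocks(num_blocks: int) -> jnp.ndarray:
    """Lower-triangular pattern: each query block sees itself and earlier blocks."""
    q, kv = block_indices(num_blocks)
    return kv <= q


def sliding_window_blocks(num_blocks: int, left: int, right: int) -> jnp.ndarray:
    """Local pattern: `left` blocks behind and `right` blocks ahead (in block units)."""
    q, kv = block_indices(num_blocks)
    return (kv >= q - left) & (kv <= q + right)


def strided_blocks(num_blocks: int, stride: int) -> jnp.ndarray:
    """Strided pattern: key blocks whose distance from the query block is a multiple of stride."""
    q, kv = block_indices(num_blocks)
    return (q - kv) % stride == 0


# Patterns compose with boolean operators, e.g. causal local-plus-strided attention.
combined = causal_blocks(8) & (sliding_window_blocks(8, 1, 0) | strided_blocks(8, 4))
```

Combining masks with `&` and `|` is one simple way to express a custom pattern before handing it to whatever mask-building mechanism the library expects.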

The implementation uses the SparseMask dataclass to define attention patterns and automatically handles gradient computation through custom VJP definitions.
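For readers unfamiliar with custom VJPs, the sketch below shows the general `jax.custom_vjp` pattern on a toy scaled-score function. It is an illustration of the mechanism only: the function `scaled_scores`, the constant `SCALE`, and the hand-written backward rule are assumptions made for this example and are unrelated to the actual kernels.

```python
import jax
import jax.numpy as jnp

SCALE = 0.125  # hypothetical fixed scale, equal to 1/sqrt(64)


@jax.custom_vjp
def scaled_scores(q, k):
    """Toy forward function: scaled dot-product scores for 2D q and k."""
    return (q @ k.T) * SCALE


def scaled_scores_fwd(q, k):
    # Return the primal output plus the residuals needed by the backward pass.
    return scaled_scores(q, k), (q, k)


def scaled_scores_bwd(residuals, g):
    # g is the cotangent of the scores, shape [q_len, kv_len].
    q, k = residuals
    dq = (g @ k) * SCALE
    dk = (g.T @ q) * SCALE
    return dq, dk


scaled_scores.defvjp(scaled_scores_fwd, scaled_scores_bwd)

q = jnp.arange(6.0).reshape(2, 3)
k = jnp.arange(12.0).reshape(4, 3)


def loss(q, k):
    return jnp.sum(scaled_scores(q, k) ** 2)


# jax.grad now routes through the hand-written backward rule.
dq, dk = jax.grad(loss, argnums=(0, 1))(q, k)
```

A hand-written backward rule like this is what lets a library substitute a dedicated backward kernel for the gradient that automatic differentiation would otherwise derive.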

Example

>>> import jax.numpy as jnp
>>> from ejkernel.kernels._triton.blocksparse_attention import blocksparse_attention
>>>
>>> batch, num_heads, seq_len, head_dim = 2, 12, 2048, 64
>>> q = jnp.ones((batch, num_heads, seq_len, head_dim))
>>> k = jnp.ones((batch, num_heads, seq_len, head_dim))
>>> v = jnp.ones((batch, num_heads, seq_len, head_dim))
>>>
>>> # Causal block-sparse attention over 128-token blocks
>>> output = blocksparse_attention(
...     q, k, v,
...     chunk_size=128,
...     causal=True
... )
>>>
>>> # Sliding-window attention over 64-token blocks
>>> output = blocksparse_attention(
...     q, k, v,
...     sliding_window=(256, 256),
...     chunk_size=64
... )

Reference:

Sparse Attention patterns and efficient implementations: https://arxiv.org/abs/1904.10509

ejkernel.kernels._triton.blocksparse_attention._interface.blocksparse_attention(query: Float[Array, 'batch num_heads seq_len head_dim'], key: Float[Array, 'batch kv_num_heads kv_len head_dim'], value: Float[Array, 'batch kv_num_heads kv_len vhead_dim'], q_segment_ids: Int[Array, 'batch seq_len'] | None = None, kv_segment_ids: Int[Array, 'batch kv_len'] | None = None, q_positions: Int[Array, 'batch seq_len'] | None = None, kv_positions: Int[Array, 'batch kv_len'] | None = None, softmax_aux: Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len head_dim'] | None = None, attention_mask: Bool[Array, 'batch num_heads_or_1 seq_len kv_len'] | Int[Array, 'batch num_heads_or_1 seq_len kv_len'] | None = None, sequence_parallelism_mesh_axis_name: str | None = None, logits_soft_cap: float | None = None, qkv_layouts: tuple['SparseMask'] | None = None, softmax_scale: float | None = None, fwd_params: FwdParams | None = None, bwd_params: BwdParams | None = None, mask_builder: Callable[[int, int, int, int, int], 'Mask'] | Callable[[], 'SparseMask'] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = True, fused_backward: bool = False) → Float[Array, 'batch num_heads seq_len vhead_dim']

Triton block-sparse attention kernel implementation.

Computes attention over sparse block patterns using optimized Triton kernels for GPU execution.

Parameters
  • query – Query tensor [batch, num_heads, seq_len, head_dim]

  • key – Key tensor [batch, kv_num_heads, kv_len, head_dim]

  • value – Value tensor [batch, kv_num_heads, kv_len, vhead_dim]

  • q_segment_ids – Optional segment IDs for queries [batch, seq_len]. If not provided and attention_mask is given, will be inferred from attention_mask.

  • kv_segment_ids – Optional segment IDs for keys/values [batch, kv_len]. If not provided and attention_mask is given, will be inferred from attention_mask.

  • q_positions – Optional position indices for queries [batch, seq_len]

  • kv_positions – Optional position indices for keys/values [batch, kv_len]

  • softmax_aux – Optional auxiliary softmax values (e.g., attention sinks)

  • bias – Optional attention bias [batch, num_heads, seq_len, head_dim]

  • attention_mask – Optional attention mask [batch, 1, seq_len, kv_len] or [batch, num_heads, seq_len, kv_len]. Used to infer q_segment_ids and kv_segment_ids automatically when they are not provided. Positions marked True/1 may attend to each other; False/0 marks masked positions. This lets callers pass an attention mask without building segment IDs by hand.

  • sequence_parallelism_mesh_axis_name – Optional axis name for sequence parallelism

  • logits_soft_cap – Optional soft capping value for attention logits. When specified, applies tanh-based soft capping: logits_soft_cap * tanh(logits / logits_soft_cap). This bounds attention scores to (-logits_soft_cap, logits_soft_cap), improving numerical stability (Gemma-2 style). The backward pass accounts for the Jacobian of the tanh transform.

  • qkv_layouts – Optional pre-computed attention mask layouts

  • softmax_scale – Attention score scaling factor (default: 1/sqrt(head_dim))

  • mask_builder – Function to build custom sparse mask patterns

  • sliding_window – Window size for local attention: an int for a symmetric window, or a (left, right) tuple

  • chunk_size – Alternative to separate q_blocksize/kv_blocksize

  • causal – Whether to apply causal masking (default: True)

  • fused_backward – Use fused backward pass (default: False)
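
The soft-capping formula given for logits_soft_cap can be checked in isolation. The sketch below is a standalone JAX illustration, not the kernel's code; `soft_cap` is a hypothetical helper, and the cap value of 30.0 is chosen only for the example. It also verifies the derivative of the transform, 1 - tanh²(logits / cap), against automatic differentiation.

```python
import jax
import jax.numpy as jnp


def soft_cap(logits, cap):
    """Hypothetical helper: tanh-based soft capping, cap * tanh(logits / cap)."""
    return cap * jnp.tanh(logits / cap)


cap = 30.0  # example value only
logits = jnp.array([-100.0, -1.0, 0.0, 1.0, 100.0])

# Small logits pass through almost unchanged; large ones saturate near +/- cap.
capped = soft_cap(logits, cap)

# Elementwise derivative via autodiff, compared with the closed form.
autodiff_grad = jax.vmap(jax.grad(soft_cap), in_axes=(0, None))(logits, cap)
analytic_grad = 1.0 - jnp.tanh(logits / cap) ** 2
```

Because the derivative shrinks toward zero as logits saturate, a backward pass that ignored this factor would overstate gradients for large scores.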

Returns

Attention output [batch, num_heads, seq_len, vhead_dim]
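
As a way to reason about shapes and defaults, the following dense reference in plain JAX mirrors the documented behaviour for the simplest case. It is a sketch under stated assumptions, not the Triton kernel: `dense_reference_attention` is a hypothetical name, it assumes kv_num_heads equals num_heads and kv_len equals seq_len, and it ignores segment IDs, bias, sinks, and soft capping.

```python
import math

import jax
import jax.numpy as jnp


def dense_reference_attention(q, k, v, causal=True, softmax_scale=None):
    """Hypothetical dense reference: softmax(q k^T * scale) v with optional causal mask."""
    head_dim = q.shape[-1]
    # Documented default scale is 1/sqrt(head_dim).
    scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
    if causal:
        q_len, kv_len = logits.shape[-2:]
        allowed = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))
        # Masked positions get -inf so they receive zero softmax weight.
        logits = jnp.where(allowed, logits, -jnp.inf)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bhkv->bhqv", weights, v)


q = jnp.ones((2, 4, 16, 8))    # [batch, num_heads, seq_len, head_dim]
k = jnp.ones((2, 4, 16, 8))    # [batch, kv_num_heads, kv_len, head_dim]
v = jnp.ones((2, 4, 16, 12))   # [batch, kv_num_heads, kv_len, vhead_dim]
out = dense_reference_attention(q, k, v)  # last axis follows vhead_dim, not head_dim
```

A reference like this can serve as a baseline when comparing kernel outputs on small inputs, provided the same masking and scaling options are used on both sides.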

Examples

>>> output = blocksparse_attention(q, k, v)