ejkernel.kernels._triton.blocksparse_attention._interface
Block-sparse Attention implementation using Triton kernels.
This module provides an efficient implementation of block-sparse attention, where the attention pattern is defined by a sparse mask at the block level rather than at the token level. This enables significant computational and memory savings while maintaining expressive attention patterns.
Block-sparse attention operates on fixed-size blocks of queries and keys, where entire blocks either attend to each other or are masked out. This coarse-grained sparsity enables:
1. Reduced computation: for sequence length N and block size B, the block grid has (N/B)² entries, and only the active blocks are computed
2. Memory efficiency: attention is computed only for active blocks
3. Flexible attention patterns: causal, local, strided, and custom patterns
4. Better cache utilization: block-level operations improve memory access locality
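To make the savings concrete, the following sketch counts how many block pairs a causal pattern keeps active. It is illustrative only: `block_counts` is a hypothetical helper, not part of ejkernel, and it assumes equal query and key block sizes and a sequence length divisible by the block size.

```python
import jax.numpy as jnp


def block_counts(seq_len: int, block_size: int) -> tuple[int, int]:
    """Return (total block pairs, causally active block pairs).

    Hypothetical helper for illustration; assumes seq_len % block_size == 0.
    """
    num_blocks = seq_len // block_size
    # Block-level causal mask: query block i may attend to key block j when j <= i.
    causal = jnp.tril(jnp.ones((num_blocks, num_blocks), dtype=bool))
    return num_blocks * num_blocks, int(causal.sum())


total, active = block_counts(seq_len=2048, block_size=128)
print(f"{active} of {total} block pairs are active")
```

For 2048 tokens and 128-token blocks the grid is 16 x 16, and a causal pattern keeps the lower triangle including the diagonal.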
Key concepts:
- Blocks: Fixed-size chunks of the sequence (typically 64 or 128 tokens)
- Sparse Mask: Binary mask indicating which query blocks attend to which key blocks
- QKV Layouts: Pre-computed sparsity patterns defining block connectivity
- Load Balancing: Ensures even distribution of work across GPU threads
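The relationship between a block-level mask and the token-level mask it implies can be sketched as follows. `expand_block_mask` is a hypothetical helper written for this illustration; the argument names `q_blocksize` and `kv_blocksize` are borrowed from the examples in this page, and the kernel itself never needs to materialize the token-level mask.

```python
import jax.numpy as jnp


def expand_block_mask(block_mask, q_blocksize: int, kv_blocksize: int):
    """Expand a [num_q_blocks, num_kv_blocks] mask to token resolution.

    Hypothetical helper: every entry of the block mask is repeated over a
    q_blocksize x kv_blocksize tile of tokens.
    """
    rows = jnp.repeat(block_mask, q_blocksize, axis=0)
    return jnp.repeat(rows, kv_blocksize, axis=1)


# Two query blocks and two key blocks; the upper-right block is masked out.
block_mask = jnp.array([[True, False], [True, True]])
token_mask = expand_block_mask(block_mask, q_blocksize=2, kv_blocksize=3)
print(token_mask.shape)
```

Storing only the block mask is what keeps the sparsity description small: its size shrinks by a factor of `q_blocksize * kv_blocksize` relative to a token-level mask.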
Common sparse patterns supported:
- Causal masking (lower triangular)
- Sliding window attention (local context)
- Strided patterns (e.g., every k-th block)
- Custom patterns via mask_builder
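The patterns listed above can be written down as block-level boolean masks. The sketch below is not how ejkernel constructs its SparseMask objects; the function names are hypothetical, window sizes are expressed in blocks rather than tokens, and the strided pattern follows one common convention (block offsets that are multiples of the stride).

```python
import jax.numpy as jnp


def _block_indices(num_blocks: int):
    # Column vector of query block indices, row vector of key block indices.
    q = jnp.arange(num_blocks)[:, None]
    k = jnp.arange(num_blocks)[None, :]
    return q, k


def causal_blocks(num_blocks: int):
    """Query block i attends to key blocks j <= i (lower triangular)."""
    q, k = _block_indices(num_blocks)
    return k <= q


def sliding_window_blocks(num_blocks: int, left: int, right: int):
    """Query block i attends to key blocks in [i - left, i + right]."""
    q, k = _block_indices(num_blocks)
    return (k >= q - left) & (k <= q + right)


def strided_blocks(num_blocks: int, stride: int):
    """Query block i attends to key blocks whose offset is a multiple of stride."""
    q, k = _block_indices(num_blocks)
    return (q - k) % stride == 0


# Patterns compose with elementwise logic, e.g. a causal local window.
local_causal = causal_blocks(8) & sliding_window_blocks(8, left=2, right=0)
print(local_causal.astype(int))
```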
The implementation uses the SparseMask dataclass to define attention patterns and automatically handles gradient computation through custom VJP definitions.
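For readers unfamiliar with the mechanism, here is a minimal sketch of a custom VJP in JAX, using scaled attention scores as a toy function. It only illustrates the `jax.custom_vjp` pattern; the function `scores` and its hand-written backward rule are assumptions made for this example and do not reflect the actual forward and backward kernels in this module.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def scores(q, k):
    """Scaled dot-product scores for q [S, D] and k [T, D] -> [S, T]."""
    return (q @ k.T) / (q.shape[-1] ** 0.5)


def scores_fwd(q, k):
    # Forward pass: return the output plus the residuals needed for backward.
    return scores(q, k), (q, k)


def scores_bwd(residuals, g):
    # Backward pass: g is the cotangent of the output, shape [S, T].
    q, k = residuals
    scale = 1.0 / (q.shape[-1] ** 0.5)
    return (g @ k) * scale, (g.T @ q) * scale


scores.defvjp(scores_fwd, scores_bwd)

q = jnp.arange(6.0).reshape(2, 3)
k = jnp.arange(12.0).reshape(4, 3) / 10.0
dq, dk = jax.grad(lambda q, k: scores(q, k).sum(), argnums=(0, 1))(q, k)
print(dq.shape, dk.shape)
```

Registering the backward rule explicitly is what lets a hand-written kernel participate in `jax.grad` without JAX differentiating through the kernel body.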
Example
>>> import jax.numpy as jnp
>>> from ejkernel.kernels._triton.blocksparse_attention import blocksparse_attention
>>>
>>> batch, num_heads, seq_len, head_dim = 2, 12, 2048, 64
>>> q = jnp.ones((batch, num_heads, seq_len, head_dim))
>>> k = jnp.ones((batch, num_heads, seq_len, head_dim))
>>> v = jnp.ones((batch, num_heads, seq_len, head_dim))
>>>
>>> # Causal block-sparse attention
>>> output = blocksparse_attention(
... q, k, v,
... chunk_size=128,
... causal=True
... )
>>>
>>> # Sliding-window (local) attention
>>> output = blocksparse_attention(
... q, k, v,
... sliding_window=(256, 256),
... chunk_size=64
... )
- Reference:
Sparse Attention patterns and efficient implementations https://arxiv.org/abs/1904.10509
- ejkernel.kernels._triton.blocksparse_attention._interface.blocksparse_attention(query: Float[Array, 'batch num_heads seq_len head_dim'], key: Float[Array, 'batch kv_num_heads kv_len head_dim'], value: Float[Array, 'batch kv_num_heads kv_len vhead_dim'], q_segment_ids: Int[Array, 'batch seq_len'] | None = None, kv_segment_ids: Int[Array, 'batch kv_len'] | None = None, q_positions: Int[Array, 'batch seq_len'] | None = None, kv_positions: Int[Array, 'batch kv_len'] | None = None, softmax_aux: Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len head_dim'] | None = None, attention_mask: Bool[Array, 'batch num_heads_or_1 seq_len kv_len'] | Int[Array, 'batch num_heads_or_1 seq_len kv_len'] | None = None, sequence_parallelism_mesh_axis_name: str | None = None, logits_soft_cap: float | None = None, qkv_layouts: tuple['SparseMask'] | None = None, softmax_scale: float | None = None, fwd_params: FwdParams | None = None, bwd_params: BwdParams | None = None, mask_builder: Callable[[int, int, int, int, int], 'Mask'] | Callable[[], 'SparseMask'] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = True, fused_backward: bool = False) → Float[Array, 'batch num_heads seq_len vhead_dim']
Triton block-sparse attention kernel implementation.
Computes attention over sparse block patterns using optimized Triton kernels for GPU execution.
- Parameters
query – Query tensor [batch, num_heads, seq_len, head_dim]
key – Key tensor [batch, kv_num_heads, kv_len, head_dim]
value – Value tensor [batch, kv_num_heads, kv_len, vhead_dim]
q_segment_ids – Optional segment IDs for queries [batch, seq_len]. If not provided and attention_mask is given, will be inferred from attention_mask.
kv_segment_ids – Optional segment IDs for keys/values [batch, kv_len]. If not provided and attention_mask is given, will be inferred from attention_mask.
q_positions – Optional position indices for queries [batch, seq_len]
kv_positions – Optional position indices for keys/values [batch, kv_len]
softmax_aux – Optional auxiliary softmax values (e.g., attention sinks) [num_sinks]
bias – Optional attention bias [batch, num_heads, seq_len, head_dim]
attention_mask – Optional attention mask [batch, num_heads_or_1, seq_len, kv_len]. True/1 marks positions that may attend to each other; False/0 marks masked positions. When q_segment_ids and kv_segment_ids are not provided, they are inferred from this mask, so an attention mask can be used without creating segment IDs manually.
sequence_parallelism_mesh_axis_name – Optional axis name for sequence parallelism
logits_soft_cap – Optional soft capping value for attention logits. When specified, applies tanh-based soft capping: logits_soft_cap * tanh(logits / logits_soft_cap). This bounds the attention scores to (-logits_soft_cap, logits_soft_cap), improving numerical stability (Gemma-2 style). The backward pass accounts for the Jacobian of the tanh transform.
qkv_layouts – Optional pre-computed attention mask layouts
softmax_scale – Attention score scaling factor (default: 1/sqrt(head_dim))
mask_builder – Function to build custom sparse mask patterns
sliding_window – Window size for local attention: an int for a symmetric window, or a (left, right) tuple
chunk_size – Alternative to separate q_blocksize/kv_blocksize
causal – Whether to apply causal masking (default: True)
fused_backward – Use fused backward pass (default: False)
- Returns
Attention output [batch, num_heads, seq_len, vhead_dim]
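The tanh soft capping described for logits_soft_cap can be sketched in plain JAX as follows. This is a standalone illustration of the formula given above, not the kernel's code; `soft_cap` and `soft_cap_grad` are hypothetical names, and the analytic derivative shown is simply what differentiating `cap * tanh(x / cap)` yields.

```python
import jax
import jax.numpy as jnp


def soft_cap(logits, cap: float):
    """Apply cap * tanh(logits / cap), bounding values to [-cap, cap]."""
    return cap * jnp.tanh(logits / cap)


def soft_cap_grad(logits, cap: float):
    """Elementwise derivative of soft_cap: 1 - tanh(logits / cap) ** 2."""
    return 1.0 - jnp.tanh(logits / cap) ** 2


x = jnp.array([-100.0, 0.0, 5.0, 100.0])
capped = soft_cap(x, 30.0)

# Autodiff agrees with the analytic Jacobian (diagonal for an elementwise map).
auto = jax.grad(lambda z: soft_cap(z, 30.0).sum())(x)
print(capped, auto)
```

Small logits pass through almost unchanged (the derivative is close to 1 near zero), while large logits saturate toward the cap and receive a gradient close to zero.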
Examples
>>> output = blocksparse_attention(q, k, v)
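When validating kernel output, a dense reference is often useful. The sketch below is one possible reference under stated assumptions: `reference_attention` is a hypothetical function, it assumes the same number of query and key/value heads, it interprets segment IDs in the usual way (tokens attend only within the same segment), and it ignores bias, soft capping, sliding windows, and attention sinks.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, q_segment_ids=None, kv_segment_ids=None,
                        causal=True, softmax_scale=None):
    """Dense attention: q [B,H,S,D], k [B,H,T,D], v [B,H,T,Dv] -> [B,H,S,Dv]."""
    # Default scale is 1/sqrt(head_dim), matching the documented default.
    scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** -0.5
    logits = jnp.einsum("bhsd,bhtd->bhst", q, k) * scale
    seq_len, kv_len = logits.shape[-2:]

    mask = jnp.ones((1, 1, seq_len, kv_len), dtype=bool)
    if causal:
        mask = mask & jnp.tril(jnp.ones((seq_len, kv_len), dtype=bool))[None, None]
    if q_segment_ids is not None and kv_segment_ids is not None:
        # Assumed convention: attend only where segment IDs match.
        same = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :]
        mask = mask & same[:, None]

    # A large finite negative avoids NaNs that -inf can cause in softmax.
    logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhst,bhtd->bhsd", weights, v)


kq, kk, kv_key = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (1, 2, 4, 8))
k = jax.random.normal(kk, (1, 2, 4, 8))
v = jax.random.normal(kv_key, (1, 2, 4, 8))
segments = jnp.array([[0, 0, 1, 1]])
out = reference_attention(q, k, v, segments, segments)
print(out.shape)
```

Comparing this against the kernel on small inputs, with a tolerance suited to the dtype, is a reasonable first sanity check before moving to longer sequences.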