ejkernel.kernels._triton.native_sparse_attention._interface
Native Sparse Attention (NSA) implementation using Triton kernels.
This module implements Native Sparse Attention, a hybrid attention mechanism that combines compressed attention over coarse-grained blocks with selective fine-grained attention to important tokens. This approach achieves significant computational savings while maintaining model quality.
NSA Architecture:
NSA consists of two parallel attention pathways that are gated and combined:
Compressed Attention (g_cmp pathway):
- Keys and values are mean-pooled into blocks (e.g., 64 tokens -> 1 block)
- Each query attends to these compressed block representations
- Provides global context with O(N²/B) complexity for block size B
- Used to identify which blocks contain important information
Selected Attention (g_slc pathway):
- Based on compressed attention scores, select top-K blocks per query
- Perform full fine-grained attention only to selected blocks
- Provides detailed local attention with O(N*K*B) complexity
- Focuses computational resources on relevant regions
The final output is a gated combination:
output = g_slc * selected_attn + g_cmp * compressed_attn
where g_slc and g_cmp are learned gating values that balance the two pathways.
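The two pathways and their gated combination can be illustrated with a small dense reference. This is a minimal sketch in plain NumPy, not the Triton kernel: it handles a single head, ignores causal masking and GQA, and assumes the sequence length is a multiple of the block size. The names `softmax` and `nsa_reference` are hypothetical helpers introduced only for this illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax; masked (-inf) entries get weight 0.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def nsa_reference(q, k, v, g_cmp, g_slc, block_size, top_k):
    """Single-head, non-causal sketch. q, k, v: (N, D); g_cmp, g_slc: (N,)."""
    n, d = q.shape
    num_blocks = n // block_size  # assumption: N is a multiple of block_size
    scale = d ** -0.5

    # Compressed pathway: mean-pool keys/values per block, attend to the pools.
    k_cmp = k.reshape(num_blocks, block_size, d).mean(axis=1)
    v_cmp = v.reshape(num_blocks, block_size, d).mean(axis=1)
    scores_cmp = (q @ k_cmp.T) * scale            # (N, num_blocks)
    out_cmp = softmax(scores_cmp) @ v_cmp         # (N, D)

    # Block selection: keep the top_k highest-scoring blocks for each query.
    top_blocks = np.argsort(-scores_cmp, axis=-1)[:, :top_k]  # (N, top_k)

    # Selected pathway: full attention restricted to tokens in the chosen blocks.
    block_of_token = np.arange(n) // block_size
    mask = (block_of_token[None, :, None] == top_blocks[:, None, :]).any(-1)  # (N, N)
    scores_slc = np.where(mask, (q @ k.T) * scale, -np.inf)
    out_slc = softmax(scores_slc) @ v             # (N, D)

    # Gated combination of the two pathways.
    return g_slc[:, None] * out_slc + g_cmp[:, None] * out_cmp

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((256, 16)) for _ in range(3))
gates = np.full(256, 0.5)
out = nsa_reference(q, k, v, gates, gates, block_size=16, top_k=4)
print(out.shape)
```

The dense `(N, N)` mask is used only for readability; a real kernel gathers the selected blocks instead of materializing the full score matrix.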
Key Benefits:
Adaptive sparsity: Automatically learns which blocks are important
Flexible trade-off: Balance between global context and local detail
Reduced computation: O(N²/B + N*K*B) instead of O(N²)
Maintained quality: Selective attention focuses on relevant tokens
Example complexity for N=4096, B=64, K=16 (counting query-key score evaluations):
- Standard attention: N² ≈ 16.8M operations
- NSA: N²/B + N*K*B ≈ 0.26M + 4.2M ≈ 4.5M operations (roughly 3.8x reduction)
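The cost expressions above can be evaluated directly for any configuration. The snippet below is a rough sketch that counts query-key score evaluations only and ignores constant factors, head counts, and the head dimension; `dense_cost` and `nsa_cost` are hypothetical helper names.

```python
def dense_cost(n):
    # Standard attention: every query scores every key.
    return n * n

def nsa_cost(n, block_size, top_k):
    # Compressed pathway: every query scores one pooled key per block.
    compressed = n * n // block_size
    # Selected pathway: every query scores top_k blocks of block_size tokens.
    selected = n * top_k * block_size
    return compressed + selected

n, block_size, top_k = 4096, 64, 16
dense = dense_cost(n)
sparse = nsa_cost(n, block_size, top_k)
print(f"dense={dense:,} nsa={sparse:,} ratio={dense / sparse:.2f}x")
```

Because the selected term grows linearly in N while the dense cost grows quadratically, the ratio improves as the sequence gets longer for fixed B and K.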
Implementation Details:
Requires Grouped Query Attention (GQA) with group size multiple of 16
Supports variable-length sequences via cu_seqlens
Block indices can be pre-computed or computed dynamically by top-k selection in the compression pathway
Gradients flow through both compression and selection mechanisms
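For variable-length input, cu_seqlens is assumed here to follow the common packed-sequence convention: prefix sums of the individual sequence lengths with a leading zero, so that sequence i occupies positions cu_seqlens[i]:cu_seqlens[i + 1] of a single packed batch. A minimal NumPy sketch (`make_cu_seqlens` is a hypothetical helper, not part of this module):

```python
import numpy as np

def make_cu_seqlens(lengths):
    """Prefix sums with a leading zero: sequence i spans cu[i]:cu[i + 1]."""
    lengths = np.asarray(lengths, dtype=np.int32)
    return np.concatenate(
        [np.zeros(1, dtype=np.int32), np.cumsum(lengths, dtype=np.int32)]
    )

lengths = [5, 3, 8]
cu_seqlens = make_cu_seqlens(lengths)  # array([ 0,  5,  8, 16], dtype=int32)

# Packed layout: all sequences concatenated along the sequence axis, batch size 1.
total_tokens, num_kv_heads, head_dim = int(cu_seqlens[-1]), 2, 64
packed_k = np.zeros((1, total_tokens, num_kv_heads, head_dim), dtype=np.float32)

# Recover the individual sequences from the packed tensor.
segments = [
    packed_k[:, cu_seqlens[i]:cu_seqlens[i + 1]] for i in range(len(lengths))
]
print([s.shape[1] for s in segments])  # [5, 3, 8]
```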
Example
>>> import jax.numpy as jnp
>>> from ejkernel.kernels._triton.native_sparse_attention import native_sparse_attention
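>>>
>>> # GQA: num_q_heads / num_kv_heads must be a multiple of 16 (here 32 / 2 = 16)
>>> batch, seq_len, num_q_heads, num_kv_heads, head_dim = 2, 2048, 32, 2, 64
>>> q = jnp.ones((batch, seq_len, num_q_heads, head_dim))
>>> k = jnp.ones((batch, seq_len, num_kv_heads, head_dim))
>>> v = jnp.ones((batch, seq_len, num_kv_heads, head_dim))
>>>
>>> # Per-token, per-head gates for the compressed and selected pathways
>>> g_cmp = jnp.ones((batch, seq_len, num_q_heads))
>>> g_slc = jnp.ones((batch, seq_len, num_q_heads))
>>>
>>> # block_indices is omitted: with g_cmp given, blocks are chosen by top-k selection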
>>> output = native_sparse_attention(
... q, k, v,
... g_cmp=g_cmp,
... g_slc=g_slc,
... block_counts=16,
... block_size=64
... )
- Reference:
Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention https://arxiv.org/abs/2502.11089
- ejkernel.kernels._triton.native_sparse_attention._interface.apply_native_sparse_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], block_indices: Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads num_selected_blocks'], block_counts: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads'] | int = 16, block_size: int = 64, softmax_scale: float | None = None, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, token_indices: jaxtyping.Int[jaxlib._jax.Array, 'total_tokens'] | None = None) -> Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim']
Applies NativeSparseAttention using a pre-computed sparse block pattern.
This function is a user-facing wrapper around the core JIT-compiled _apply_nsa function. It optionally prepares token indices for variable-length sequence processing.
- Parameters
query – Query tensor.
key – Key tensor.
value – Value tensor.
block_indices – A tensor specifying the indices of the key/value blocks that each query should attend to.
block_counts – The number of blocks each query attends to. Can be an integer (for uniform sparsity) or a tensor.
block_size – The size of each key/value block.
softmax_scale – The scaling factor for the attention scores.
cu_seqlens – Optional cumulative sequence lengths for variable-length sequences.
token_indices – Optional pre-computed token indices for variable-length sequences. If None and cu_seqlens is provided, they are computed internally.
- Returns
The output tensor from the sparse attention computation.
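As an illustration of a pre-computed sparse pattern, the sketch below builds block_indices and a per-query block_counts tensor for a causal sliding window in which each query attends to its own block and the blocks immediately before it. It only constructs the index arrays in NumPy and does not call the kernel. Two points are assumptions rather than documented behaviour: that valid indices should be listed first, and that slots beyond block_counts are ignored. `sliding_block_indices` is a hypothetical helper.

```python
import numpy as np

def sliding_block_indices(batch, seq_len, num_kv_heads, block_size, num_selected):
    """Each query attends to its own block and up to num_selected - 1 earlier ones."""
    token_block = np.arange(seq_len) // block_size          # block of each query
    offsets = np.arange(num_selected)
    idx = token_block[:, None] - offsets[None, :]           # current block first
    counts = np.minimum(token_block + 1, num_selected)      # valid slots per query
    idx = np.where(idx >= 0, idx, 0)                        # padding slots (assumed ignored)

    # Broadcast the same pattern over batch and KV heads.
    block_indices = np.broadcast_to(
        idx[None, :, None, :], (batch, seq_len, num_kv_heads, num_selected)
    ).astype(np.int32)
    block_counts = np.broadcast_to(
        counts[None, :, None], (batch, seq_len, num_kv_heads)
    ).astype(np.int32)
    return block_indices, block_counts

block_indices, block_counts = sliding_block_indices(
    batch=2, seq_len=256, num_kv_heads=2, block_size=64, num_selected=4
)
print(block_indices.shape, block_counts.shape)  # (2, 256, 2, 4) (2, 256, 2)
print(block_indices[0, 130, 0], block_counts[0, 130, 0])  # [2 1 0 0] 3
```

The arrays would be converted with jnp.asarray before being passed as block_indices and block_counts.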
- ejkernel.kernels._triton.native_sparse_attention._interface.native_sparse_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], g_cmp: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads'] | None = None, g_slc: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads'] | None = None, block_indices: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads num_selected_blocks'] | None = None, block_counts: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads'] | int = 16, block_size: int = 64, softmax_scale: float | None = None, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None) -> Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim']
NSA is a sparse attention mechanism that combines two components:
1. Compressed Attention: A coarse-grained attention over mean-pooled (compressed) key-value blocks. This provides a global context summary.
2. Selected Attention: A fine-grained, sparse attention where each query attends to a small subset of the original key-value blocks.
The key idea is that the selection of blocks for the second component can be determined efficiently using the compressed representations from the first. The final output is a gated combination of these two components.
- Parameters
query – Query tensor of shape (batch_size, sequence, query_heads, dimk).
key – Key tensor of shape (batch_size, sequence, kvheads, dimk). GQA is enforced, where the ratio of query heads (query_heads) to key/value heads (kvheads) must be a multiple of 16.
value – Value tensor of shape (batch_size, sequence, kvheads, dimv).
g_cmp – Optional gate tensor for compressed attention, shape (batch_size, sequence, query_heads). If provided, the compressed attention component is computed.
g_slc – Optional gate tensor for selected attention, shape (batch_size, sequence, query_heads).
block_indices – Optional tensor of pre-computed block indices for selected attention, shape (batch_size, sequence, kvheads, S). S is the number of selected blocks (block_counts). If g_cmp is provided, this argument is ignored, and block indices are computed dynamically via top-k selection over the compressed keys. If g_cmp is NOT provided, this argument is required.
block_counts – Number of blocks to select for each query, given either as a single integer or as a tensor of shape (batch_size, sequence, kvheads). Defaults to 16.
block_size – The size of each attention block. Defaults to 64.
softmax_scale – Scale factor for attention scores. Defaults to 1 / sqrt(dimk), i.e. dimk**-0.5.
cu_seqlens – Cumulative sequence lengths of shape (N+1,), where N is the number of packed sequences, for variable-length training. If provided, batch_size must be 1.
- Returns
The output tensor of shape (batch_size, sequence, query_heads, dimv).
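The argument constraints listed above can be checked before calling the kernel. The following is a minimal sketch of such a pre-flight check based only on the rules stated in this parameter list; `check_nsa_args` is a hypothetical helper, and the library may perform its own, different validation.

```python
def check_nsa_args(query_shape, key_shape, *, g_cmp_given, block_indices_given,
                   cu_seqlens_given):
    """Raise ValueError if the documented NSA constraints are violated."""
    batch_size, _, query_heads, _ = query_shape
    kvheads = key_shape[2]

    # GQA: the query-to-KV head ratio must be a multiple of 16.
    if query_heads % kvheads != 0 or (query_heads // kvheads) % 16 != 0:
        raise ValueError(
            f"query_heads / kvheads must be a multiple of 16, "
            f"got {query_heads} / {kvheads}"
        )

    # Without g_cmp there is no compressed pathway to select blocks from.
    if not g_cmp_given and not block_indices_given:
        raise ValueError("block_indices is required when g_cmp is not provided")

    # Variable-length mode packs all sequences into a single batch entry.
    if cu_seqlens_given and batch_size != 1:
        raise ValueError("batch_size must be 1 when cu_seqlens is provided")

# Passes: 32 query heads over 2 KV heads gives a group size of 16.
check_nsa_args((1, 2048, 32, 64), (1, 2048, 2, 64),
               g_cmp_given=True, block_indices_given=False, cu_seqlens_given=True)
```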