ejkernel.kernels._xla.native_sparse_attention._interface
Native Sparse Attention interface with block selection and compression.
This module provides the public API for Native Sparse Attention (NSA), which combines compressed global attention with fine-grained sparse block selection. It also includes a VJP for training support and top-k block-selection algorithms.
- ejkernel.kernels._xla.native_sparse_attention._interface.apply_native_sparse_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], block_indices: Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads num_selected_blocks'], block_counts: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads'] | int = 16, block_size: int = 64, softmax_scale: float | None = None, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, token_indices: jaxtyping.Int[jaxlib._jax.Array, 'total_tokens'] | None = None) → Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim']
Applies block-sparse attention using a pre-computed sparsity pattern with JAX/XLA.
This function implements sparse attention in which each query attends only to the key blocks listed for it in block_indices. This reduces the computational cost from O(N²) to O(N·S·B), where N is the sequence length, S is the number of selected blocks per query, and B is block_size.
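As a rough illustration of the complexity claim, the sketch below counts the query-key pairs scored per head when each query scores num_selected_blocks × block_size keys, using the sizes from the first example (N = 1024, S = 4, B = 64). It ignores causal masking and padding, and the helper name attention_pair_counts is hypothetical.

```python
def attention_pair_counts(seq_len: int, num_selected_blocks: int, block_size: int) -> tuple[int, int]:
    """Query-key pairs scored per head: dense vs. block-sparse (no causal masking)."""
    dense = seq_len * seq_len                              # O(N^2)
    sparse = seq_len * num_selected_blocks * block_size    # O(N * S * B)
    return dense, sparse


# N = 1024 tokens, S = 4 selected blocks per query, B = 64 tokens per block.
dense, sparse = attention_pair_counts(1024, 4, 64)
print(dense, sparse, dense // sparse)  # 1048576 262144 4
```

With these sizes each query scores a quarter of the keys; the saving grows linearly with N when S and B stay fixed.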
- Parameters
query – Query tensor of shape (batch, seq_len, num_q_heads, head_dim).
key – Key tensor of shape (batch, seq_len, num_kv_heads, head_dim).
value – Value tensor of shape (batch, seq_len, num_kv_heads, head_dim).
block_indices – Tensor of shape (batch, seq_len, num_kv_heads, num_selected_blocks) listing, for each query token and KV head, the indices of the key blocks it attends to.
block_counts – Number of selected key blocks each query attends to. Either an int (uniform sparsity for all tokens) or a tensor of shape (batch, seq_len, num_kv_heads) (per-token sparsity).
block_size – Size of each key/value block, in tokens.
softmax_scale – Attention scaling factor. If None, defaults to 1/sqrt(head_dim).
cu_seqlens – Optional cumulative sequence lengths of shape (N+1) for packed variable-length inputs.
- Returns
Attention output of shape (batch, seq_len, num_q_heads, head_dim).
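The two accepted forms of block_counts can be pictured as a mask over the last axis of block_indices. This is a minimal sketch under the assumption that only the first block_counts entries of each token's index list are used; ejkernel may apply the count differently, and selection_mask is a hypothetical helper.

```python
import jax.numpy as jnp


def selection_mask(block_indices, block_counts):
    """Boolean mask marking which entries of block_indices are active.

    block_indices: (batch, seq_len, num_kv_heads, num_selected_blocks)
    block_counts:  int, or array of shape (batch, seq_len, num_kv_heads)
    Assumes only the first block_counts entries per token are used.
    """
    slot = jnp.arange(block_indices.shape[-1])          # (num_selected_blocks,)
    counts = jnp.asarray(block_counts)[..., None]       # (1,) or (batch, seq_len, num_kv_heads, 1)
    return jnp.broadcast_to(slot < counts, block_indices.shape)
```

An int count yields the same mask for every token, while a tensor lets each token and KV head keep a different prefix of its index list.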
Notes
- The key/value sequence is divided into blocks of block_size tokens.
- Each query attends only to the key blocks selected for it.
- Sparsity is determined by block_indices and block_counts.
- Useful for long-range attention with reduced computation.
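The Notes above can be made concrete with a slow, gather-based reference for the selected-attention path. This is a minimal sketch, not ejkernel's implementation: it assumes causal masking, that consecutive query heads share a KV head (num_q_heads divisible by num_kv_heads), an integer block_counts equal to the last axis of block_indices, and that every token selects at least one block starting at or before its own position. It materializes every gathered key and value, so it only suits small correctness checks; reference_selected_attention is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def reference_selected_attention(query, key, value, block_indices, block_size, softmax_scale=None):
    """Gather-based reference for selected-block attention (illustrative sketch).

    query:         (batch, seq_len, num_q_heads, head_dim)
    key, value:    (batch, seq_len, num_kv_heads, head_dim)
    block_indices: (batch, seq_len, num_kv_heads, num_selected_blocks)
    """
    batch, seq_len, num_q_heads, head_dim = query.shape
    num_kv_heads = key.shape[2]
    group = num_q_heads // num_kv_heads  # query heads sharing one KV head (assumed layout)
    if softmax_scale is None:
        softmax_scale = head_dim ** -0.5

    # Absolute key positions covered by the selected blocks: (b, t, h_kv, S * block_size).
    key_pos = block_indices[..., None] * block_size + jnp.arange(block_size)
    key_pos = key_pos.reshape(batch, seq_len, num_kv_heads, -1)
    # Causal mask: the query at position t may only see key positions <= t.
    valid = key_pos <= jnp.arange(seq_len)[None, :, None, None]
    key_pos = jnp.clip(key_pos, 0, seq_len - 1)

    # Gather selected keys/values per (batch, kv head): (b, t, h_kv, L, head_dim).
    b_idx = jnp.arange(batch)[:, None, None, None]
    h_idx = jnp.arange(num_kv_heads)[None, None, :, None]
    k_sel = key[b_idx, key_pos, h_idx]
    v_sel = value[b_idx, key_pos, h_idx]

    # Group query heads by the KV head they share: (b, t, h_kv, group, head_dim).
    q = query.reshape(batch, seq_len, num_kv_heads, group, head_dim)
    scores = jnp.einsum("bthgd,bthld->bthgl", q, k_sel) * softmax_scale
    scores = jnp.where(valid[:, :, :, None, :], scores, -jnp.inf)
    probs = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("bthgl,bthld->bthgd", probs, v_sel)
    return out.reshape(batch, seq_len, num_q_heads, head_dim)
```

Passing every block index to every token reproduces dense causal attention, which makes the sketch a convenient oracle for tests. Duplicate indices are not deduplicated, so a repeated block would be counted twice in the softmax.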
Examples
>>> import jax.numpy as jnp
>>> batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
>>> block_size = 64
>>> num_blocks = seq_len // block_size
>>> query = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> key = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> value = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> # Uniform sparsity: every token selects key blocks 0-3.
>>> block_counts = 4
>>> block_indices = jnp.tile(
...     jnp.arange(block_counts)[None, None, None, :],
...     (batch, seq_len, num_heads, 1)
... )
>>> output = apply_native_sparse_attention(
...     query, key, value, block_indices, block_counts, block_size
... )
>>> output.shape
(2, 1024, 8, 64)
>>> # Local pattern: each token selects the blocks around its own block.
>>> def create_local_pattern(num_blocks, window=2):
...     indices = []
...     for i in range(num_blocks):
...         local = list(range(max(0, i - window), min(num_blocks, i + window + 1)))
...         # Pad to a fixed width of 2 * window + 1 entries.
...         local = local + [0] * (window * 2 + 1 - len(local))
...         indices.append(local)
...     return jnp.array(indices)
>>> local_blocks = create_local_pattern(num_blocks, window=2)  # (num_blocks, 5)
>>> # Expand from one row per block to one row per token.
>>> local_indices = local_blocks[jnp.arange(seq_len) // block_size]  # (seq_len, 5)
>>> local_indices = jnp.tile(local_indices[None, :, None, :], (batch, 1, num_heads, 1))
>>> output = apply_native_sparse_attention(
...     query, key, value, local_indices, block_counts=5, block_size=block_size
... )
- ejkernel.kernels._xla.native_sparse_attention._interface.native_sparse_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], g_cmp: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads'] | None = None, g_slc: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads'] | None = None, block_indices: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads num_selected_blocks'] | None = None, block_counts: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads'] | int = 16, block_size: int = 64, softmax_scale: float | None = None, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None) → Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim']
Native Sparse Attention (NSA) with XLA/JAX implementation.
NSA is a sparse attention mechanism that combines two components:
1. Compressed Attention: a coarse-grained attention over mean-pooled (compressed) key-value blocks, which provides a global context summary.
2. Selected Attention: a fine-grained, sparse attention in which each query attends to a small subset of the original key-value blocks.
The key idea is that the selection of blocks for the second component can be determined efficiently using the compressed representations from the first. The final output is a gated combination of these two components.
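The gated combination can be sketched as a per-token, per-head weighted sum of the two branch outputs. This assumes the gate shapes from the signature, (batch, seq_len, num_q_heads), broadcast over head_dim; how ejkernel treats a missing g_slc is not documented here, so this hypothetical combine_branches helper simply requires it.

```python
import jax.numpy as jnp


def combine_branches(o_slc, g_slc, o_cmp=None, g_cmp=None):
    """Gate and sum the NSA branch outputs (illustrative sketch).

    o_slc, o_cmp: (batch, seq_len, num_q_heads, head_dim)
    g_slc, g_cmp: (batch, seq_len, num_q_heads), one scalar gate per token and query head.
    """
    out = o_slc * g_slc[..., None]              # broadcast the gate over head_dim
    if o_cmp is not None:
        out = out + o_cmp * g_cmp[..., None]    # add the gated compressed branch
    return out
```

Because each gate is a scalar per token and head, the model can shift weight between global context and fine-grained detail position by position.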
- Parameters
query – Query tensor of shape (batch_size, sequence, num_q_heads, head_dim).
key – Key tensor of shape (batch_size, sequence, num_kv_heads, head_dim).
value – Value tensor of shape (batch_size, sequence, num_kv_heads, head_dim).
g_cmp – Optional gate for the compressed attention branch, shape (batch_size, sequence, num_q_heads). If provided, the compressed attention component is computed.
g_slc – Optional gate for the selected attention branch, shape (batch_size, sequence, num_q_heads).
block_indices – Optional tensor of precomputed block indices for selected attention, shape (batch_size, sequence, num_kv_heads, num_selected_blocks). If g_cmp is provided, this argument is ignored and block indices are computed dynamically via top-k selection over the compressed keys. If g_cmp is not provided, this argument is required.
block_counts – Number of blocks each query selects. Either an int (uniform sparsity for all tokens) or a tensor of shape (batch_size, sequence, num_kv_heads) (per-token sparsity). Defaults to 16.
block_size – The size of each attention block. Defaults to 64.
softmax_scale – Scale factor for attention scores. Defaults to 1 / sqrt(head_dim).
cu_seqlens – Cumulative sequence lengths of shape (N+1) for variable-length training. If provided, batch size must be 1. Note: variable-length sequences are not yet fully supported in the XLA version.
- Returns
The output tensor of shape (batch_size, sequence, num_q_heads, head_dim).
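To illustrate what cu_seqlens encodes: for packed sequences of lengths 3 and 4, cu_seqlens is [0, 3, 7], and packed token t belongs to the sequence i with cu_seqlens[i] <= t < cu_seqlens[i + 1]. The sketch below recovers each token's sequence id and in-sequence position; describe_packed_tokens is hypothetical and does not reflect the format of the token_indices argument of apply_native_sparse_attention.

```python
import jax.numpy as jnp


def describe_packed_tokens(cu_seqlens):
    """Map each packed token to (sequence id, position within its sequence)."""
    cu_seqlens = jnp.asarray(cu_seqlens)
    total_tokens = int(cu_seqlens[-1])
    t = jnp.arange(total_tokens)
    # Sequence i owns tokens cu_seqlens[i] <= t < cu_seqlens[i + 1].
    seq_ids = jnp.searchsorted(cu_seqlens, t, side="right") - 1
    positions = t - cu_seqlens[seq_ids]
    return seq_ids, positions
```

For example, describe_packed_tokens([0, 3, 7]) yields sequence ids [0, 0, 0, 1, 1, 1, 1] and positions [0, 1, 2, 0, 1, 2, 3], which is why packed inputs are passed with a batch size of 1.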
Notes
- The XLA implementation uses pure JAX operations without custom kernels.
- For variable-length sequences (cu_seqlens), it uses the mean_pooling function.
- The compressed attention component uses mean-pooled key/value blocks.
- Top-k block selection is based on attention scores against the compressed keys.
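The compression and top-k selection steps from the Notes can be sketched as follows. This is a simplified, hypothetical version: it assumes seq_len is divisible by block_size, scores each query against the mean-pooled keys with a softmax, sums the resulting probabilities across query heads that share a KV head so the group selects one common set of blocks, and lets a token see any block that starts at or before its position. ejkernel's exact scoring, causal rule, and tie-breaking may differ, and both helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def mean_pool_blocks(x, block_size):
    """(batch, seq_len, heads, dim) -> (batch, seq_len // block_size, heads, dim)."""
    batch, seq_len, heads, dim = x.shape
    num_blocks = seq_len // block_size  # assumes seq_len % block_size == 0
    return x.reshape(batch, num_blocks, block_size, heads, dim).mean(axis=2)


def select_top_k_blocks(query, key, block_size, num_selected):
    """Pick num_selected key blocks per token by scoring against mean-pooled keys.

    query: (batch, seq_len, num_q_heads, dim); key: (batch, seq_len, num_kv_heads, dim).
    Returns indices of shape (batch, seq_len, num_kv_heads, num_selected).
    """
    batch, seq_len, num_q_heads, dim = query.shape
    num_kv_heads = key.shape[2]
    group = num_q_heads // num_kv_heads
    k_cmp = mean_pool_blocks(key, block_size)                      # (b, n_blk, h_kv, d)
    q = query.reshape(batch, seq_len, num_kv_heads, group, dim)
    scores = jnp.einsum("bthgd,bnhd->bthgn", q, k_cmp) * dim ** -0.5
    # Causal rule (assumed): a token may see blocks starting at or before its position.
    block_start = jnp.arange(k_cmp.shape[1]) * block_size
    causal = block_start[None, :] <= jnp.arange(seq_len)[:, None]  # (t, n_blk)
    scores = jnp.where(causal[None, :, None, None, :], scores, -jnp.inf)
    # Sum probabilities over the query heads of each KV group: (b, t, h_kv, n_blk).
    importance = jax.nn.softmax(scores, axis=-1).sum(axis=3)
    _, idx = jax.lax.top_k(importance, num_selected)
    return idx
```

The returned indices have the (batch, seq_len, num_kv_heads, num_selected) layout that block_indices is annotated with, so in principle they can stand in for a hand-built pattern.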
Examples
>>> import jax.numpy as jnp
>>> batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
>>> block_size = 64
>>> block_counts = 16
>>> query = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> key = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> value = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> # Dynamic selection: g_cmp enables the compressed branch and top-k selection.
>>> g_cmp = jnp.ones((batch, seq_len, num_heads))
>>> output = native_sparse_attention(
...     query, key, value, g_cmp=g_cmp, block_counts=block_counts, block_size=block_size
... )
>>> output.shape
(2, 1024, 8, 64)
>>> # Static selection: pass precomputed block indices instead of g_cmp.
>>> block_indices = jnp.tile(
...     jnp.arange(block_counts)[None, None, None, :],
...     (batch, seq_len, num_heads, 1)
... )
>>> output = native_sparse_attention(
...     query, key, value, block_indices=block_indices, block_counts=block_counts, block_size=block_size
... )
>>> output.shape
(2, 1024, 8, 64)