ejkernel.kernels._xla.native_sparse_attention._interface


Native Sparse Attention interface with block selection and compression.

This module provides the public API for Native Sparse Attention (NSA), which combines compressed global attention with fine-grained sparse block selection. It includes a VJP for training support and top-k block selection algorithms.

ejkernel.kernels._xla.native_sparse_attention._interface.apply_native_sparse_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], block_indices: Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads num_selected_blocks'], block_counts: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads'] | int = 16, block_size: int = 64, softmax_scale: float | None = None, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, token_indices: jaxtyping.Int[jaxlib._jax.Array, 'total_tokens'] | None = None) -> Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim']

Applies block-sparse attention using a pre-computed sparsity pattern with JAX/XLA.

This function implements sparse attention in which each query attends only to a subset of key blocks specified by the sparsity pattern. With S selected blocks per query, the cost drops from O(N²) to O(N·S·block_size), which is linear in the sequence length N for fixed S and block_size.
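To pin down the semantics, here is a deliberately naive reference sketch. It is not the library's implementation: it builds the full score matrix and masks out every key outside the selected blocks, so it still costs O(N²). The helper name block_sparse_attention_reference is hypothetical, and it assumes q, k and v share one head count, seq_len is divisible by block_size, no causal mask is applied, slots of block_indices at or beyond block_counts are ignored, and a block listed twice is counted once.

```python
import jax
import jax.numpy as jnp


def block_sparse_attention_reference(q, k, v, block_indices, block_counts, block_size, softmax_scale=None):
    """Naive, dense-masked reference for per-token block-sparse attention.

    Hypothetical helper (not part of ejkernel). Assumes q, k, v all have shape
    (batch, seq_len, heads, head_dim) with the same head count, seq_len divisible
    by block_size, and no causal masking.
    """
    batch, seq_len, heads, head_dim = q.shape
    scale = head_dim ** -0.5 if softmax_scale is None else softmax_scale
    num_blocks = seq_len // block_size
    num_slots = block_indices.shape[-1]

    # Slot s is active if s < block_counts (an int or a (batch, seq_len, heads) tensor).
    active = jnp.arange(num_slots) < jnp.asarray(block_counts)[..., None]
    # hit[b, t, h, s, n] is True when slot s is active and points at key block n.
    hit = (block_indices[..., None] == jnp.arange(num_blocks)) & active[..., None]
    selected = hit.any(axis=-2)                      # (batch, seq_len, heads, num_blocks)

    # Expand the block mask to individual key positions.
    key_block = jnp.arange(seq_len) // block_size    # block id of every key token
    mask = selected[..., key_block]                  # (batch, q_len, heads, k_len)
    mask = mask.transpose(0, 2, 1, 3)                # (batch, heads, q_len, k_len)

    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    scores = jnp.where(mask, scores, -jnp.inf)       # drop keys outside selected blocks
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)
```

A sketch like this is only useful for checking small cases; the point of a block-sparse kernel is to skip the unselected blocks instead of masking them.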

Parameters
  • query – Query tensor of shape (batch, seq_len, num_q_heads, head_dim).

  • key – Key tensor of shape (batch, seq_len, num_kv_heads, head_dim).

  • value – Value tensor of shape (batch, seq_len, num_kv_heads, head_dim).

  • block_indices – Tensor of shape (batch, seq_len, num_kv_heads, num_selected_blocks) giving, for each query token and KV head, the indices of the key blocks to attend to.

  • block_counts – Number of selected key blocks each query attends to. Either an int (the same count for every query) or a tensor of shape (batch, seq_len, num_kv_heads) (a per-query count). Defaults to 16.

  • block_size – Number of tokens per block; the sequence is split into blocks of this size. Defaults to 64.

  • softmax_scale – Attention scaling factor. If None, defaults to 1/sqrt(head_dim).
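The annotations give query num_q_heads but give key, value, block_indices and block_counts num_kv_heads, which suggests grouped-query attention. The sketch below shows one common way per-KV-head tensors are broadcast to query heads. The helper expand_kv_heads is hypothetical, and the contiguous grouping (query head i uses KV head i // (num_q_heads // num_kv_heads)) is an assumption, not something this page confirms.

```python
import jax.numpy as jnp


def expand_kv_heads(x, num_q_heads):
    """Repeat axis 2 (the KV-head axis) so each query head gets its group's entry.

    Hypothetical helper; assumes contiguous grouping, i.e. query head i uses
    KV head i // (num_q_heads // num_kv_heads).
    """
    num_kv_heads = x.shape[2]
    if num_q_heads % num_kv_heads:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    # jnp.repeat with an int repeats each KV head consecutively: [a, b] -> [a, a, b, b].
    return jnp.repeat(x, num_q_heads // num_kv_heads, axis=2)
```

The same call works for key/value (4-D) and for a per-query block_counts tensor (3-D), since the head axis is axis 2 in both.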

Returns

Attention output of shape (batch, seq_len, num_q_heads, head_dim).

Notes

  • The sequence is divided into blocks of block_size tokens.

  • Each query attends only to its selected key blocks.

  • The sparsity pattern is determined by block_indices and block_counts.

  • Suited to long sequences, where attending to every key block would be too costly.

Examples

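>>> import jax.numpy as jnp
>>> batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
>>> block_size = 64
>>> num_blocks = seq_len // block_size
>>>
>>> q = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> k = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> v = jnp.ones((batch, seq_len, num_heads, head_dim))
>>>
>>> # Every query token attends to the first 4 key blocks.
>>> block_counts = 4
>>> block_indices = jnp.tile(
...     jnp.arange(block_counts)[None, None, None, :],
...     (batch, seq_len, num_heads, 1)
... )
>>> output = apply_native_sparse_attention(
...     q, k, v, block_indices, block_counts, block_size
... )
>>> output.shape
(2, 1024, 8, 64)
>>>
>>> # Local window: tokens in block i attend to key blocks i-window .. i+window.
>>> def create_local_pattern(num_blocks, window=2):
...     width = 2 * window + 1
...     indices, counts = [], []
...     for i in range(num_blocks):
...         local = list(range(max(0, i - window), min(num_blocks, i + window + 1)))
...         counts.append(len(local))
...         # Pad to a fixed width; slots beyond counts[i] are not attended.
...         indices.append(local + [0] * (width - len(local)))
...     return jnp.array(indices), jnp.array(counts)
>>>
>>> local_indices, local_counts = create_local_pattern(num_blocks, window=2)
>>> # Expand from one row per block to one row per token, then broadcast over batch and heads.
>>> local_indices = jnp.repeat(local_indices, block_size, axis=0)
>>> local_counts = jnp.repeat(local_counts, block_size, axis=0)
>>> local_indices = jnp.broadcast_to(
...     local_indices[None, :, None, :], (batch, seq_len, num_heads, local_indices.shape[-1])
... )
>>> local_counts = jnp.broadcast_to(local_counts[None, :, None], (batch, seq_len, num_heads))
>>> output = apply_native_sparse_attention(
...     q, k, v, local_indices, block_counts=local_counts, block_size=block_size
... )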

ejkernel.kernels._xla.native_sparse_attention._interface.native_sparse_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], g_cmp: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads'] | None = None, g_slc: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads'] | None = None, block_indices: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads num_selected_blocks'] | None = None, block_counts: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads'] | int = 16, block_size: int = 64, softmax_scale: float | None = None, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None) -> Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim']

Native Sparse Attention (NSA), implemented with XLA/JAX.

NSA is a sparse attention mechanism that combines two components:

  1. Compressed Attention: a coarse-grained attention over mean-pooled (compressed) key-value blocks, which provides a summary of the global context.

  2. Selected Attention: a fine-grained sparse attention in which each query attends to a small subset of the original key-value blocks.

The key idea is that the selection of blocks for the second component can be determined efficiently using the compressed representations from the first. The final output is a gated combination of these two components.
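One compact way to write the description above, with B = block_size, n = block_counts and M = seq_len / B. It assumes the gates are per-token, per-head scalars (as the annotated g_cmp/g_slc shape (batch, seq_len, num_q_heads) suggests), works per head, and omits the softmax scale and any causal masking; it is a reading of the prose, not a transcription of the implementation.

```latex
\tilde{k}_j = \frac{1}{B}\sum_{i=jB}^{(j+1)B-1} k_i,
\qquad
\tilde{v}_j = \frac{1}{B}\sum_{i=jB}^{(j+1)B-1} v_i,
\qquad j = 0, \dots, M-1

\mathcal{I}_t = \mathrm{TopK}_n\big(q_t^{\top}\tilde{k}_0,\ \dots,\ q_t^{\top}\tilde{k}_{M-1}\big)

o_t = g^{\mathrm{cmp}}_t \, o^{\mathrm{cmp}}_t + g^{\mathrm{slc}}_t \, o^{\mathrm{slc}}_t
```

Here o^cmp_t attends over the compressed pairs (k̃_j, ṽ_j), and o^slc_t attends over the original tokens of the blocks in 𝓘_t.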

Parameters
  • query – Query tensor of shape (batch, seq_len, num_q_heads, head_dim).

  • key – Key tensor of shape (batch, seq_len, num_kv_heads, head_dim).

  • value – Value tensor of shape (batch, seq_len, num_kv_heads, head_dim).

  • g_cmp – Optional gate for the compressed-attention output, shape (batch, seq_len, num_q_heads). If provided, the compressed attention component is computed.

  • g_slc – Optional gate for the selected-attention output, shape (batch, seq_len, num_q_heads).

  • block_indices – Optional tensor of pre-computed block indices for selected attention, shape (batch, seq_len, num_kv_heads, num_selected_blocks). If g_cmp is provided, this argument is ignored and block indices are computed dynamically via top-k selection over the compressed keys. If g_cmp is not provided, this argument is required.

  • block_counts – Number of blocks to select for each query. Either an int (the same count for every query) or a tensor of shape (batch, seq_len, num_kv_heads) (a per-query count). Defaults to 16.

  • block_size – The size of each attention block. Defaults to 64.

  • softmax_scale – Scale factor for attention scores. Defaults to 1 / sqrt(head_dim).

  • cu_seqlens – Cumulative sequence lengths of shape (num_seqs + 1,) for variable-length training. If provided, the batch size must be 1. Note: variable-length sequences are not yet fully supported in the XLA version.
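For variable-length inputs, sequences are packed into a single batch row and cu_seqlens marks where each one starts. The sketch below assumes the usual convention that sequence i occupies tokens cu_seqlens[i]:cu_seqlens[i + 1]; the helpers make_cu_seqlens and pack_sequences are hypothetical and not part of ejkernel. Given the note above, the XLA path may not handle such inputs fully.

```python
import jax.numpy as jnp


def make_cu_seqlens(seq_lens):
    """[l0, l1, ...] -> [0, l0, l0 + l1, ...], shape (num_seqs + 1,)."""
    lens = jnp.asarray(seq_lens, dtype=jnp.int32)
    # Prepend 0 so entry i is the start offset of sequence i in the packed batch.
    return jnp.concatenate([jnp.zeros((1,), dtype=jnp.int32), jnp.cumsum(lens)])


def pack_sequences(seqs):
    """Concatenate (len_i, heads, head_dim) arrays into one (1, total_tokens, heads, head_dim) batch."""
    return jnp.concatenate(seqs, axis=0)[None]
```

Packing keeps the batch dimension at 1, which matches the requirement stated for cu_seqlens.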

Returns

The output tensor of shape (batch, seq_len, num_q_heads, head_dim).

Notes

  • The XLA implementation uses pure JAX operations without custom kernels

  • For variable-length sequences (cu_seqlens), pooling is done with the mean_pooling function

  • The compressed attention component uses mean-pooled key/value blocks

  • Top-k block selection is based on attention scores from compressed keys
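The compression and top-k selection steps can be sketched in a few lines of JAX. This is a simplified illustration, not the library's code: the helper names are hypothetical, q and k are assumed to share one head count, seq_len must be divisible by block_size, and there is no causal restriction on which blocks may be chosen.

```python
import jax
import jax.numpy as jnp


def mean_pool_blocks(x, block_size):
    """(batch, seq_len, heads, dim) -> (batch, num_blocks, heads, dim).

    Assumes seq_len is divisible by block_size.
    """
    batch, seq_len, heads, dim = x.shape
    blocks = x.reshape(batch, seq_len // block_size, block_size, heads, dim)
    return blocks.mean(axis=2)


def select_blocks_topk(q, k, block_size, num_selected):
    """Return (batch, seq_len, heads, num_selected) block indices chosen by compressed scores.

    Simplified: one head count for q and k, no causal restriction, raw dot-product scores.
    """
    k_cmp = mean_pool_blocks(k, block_size)              # compressed keys
    scores = jnp.einsum("bqhd,bnhd->bqhn", q, k_cmp)     # score of every block for every token
    _, indices = jax.lax.top_k(scores, num_selected)     # top-k over the last (block) axis
    return indices
```

Ranking raw scores selects the same blocks as ranking the softmax probabilities of one row, because softmax is monotonic; the probabilities are only needed when the compressed attention output itself is computed.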

Examples

>>> import jax.numpy as jnp
>>> batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
>>> block_size = 64
>>> block_counts = 16
>>>
>>> q = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> k = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> v = jnp.ones((batch, seq_len, num_heads, head_dim))
>>>
>>> # Dynamic top-k selection: the gate is one value per token and query head.
>>> g_cmp = jnp.ones((batch, seq_len, num_heads))
>>> output = native_sparse_attention(
...     q, k, v, g_cmp=g_cmp, block_counts=block_counts, block_size=block_size
... )
>>> output.shape
(2, 1024, 8, 64)
>>>
>>> # Pre-computed selection: every token attends to the first block_counts key blocks.
>>> block_indices = jnp.tile(
...     jnp.arange(block_counts)[None, None, None, :],
...     (batch, seq_len, num_heads, 1)
... )
>>> output = native_sparse_attention(
...     q, k, v, block_indices=block_indices, block_counts=block_counts, block_size=block_size
... )
>>> output.shape
(2, 1024, 8, 64)