ejkernel.kernels._triton.native_sparse_attention._compression


Compressed attention computation for Native Sparse Attention (NSA).

This module implements the compressed attention component of Native Sparse Attention, where queries attend to mean-pooled (compressed) key-value blocks rather than individual tokens. This provides a coarse-grained global context that guides the selection of important blocks for fine-grained attention.

Compressed Attention Process:

  1. Keys and values are mean-pooled into blocks (e.g., 64 tokens -> 1 block)

  2. Each query computes attention over these compressed representations

  3. The resulting attention scores indicate block importance

  4. Scores are used to select top-K blocks for detailed attention

  5. The compressed attention output can also be used directly (gated)
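The five steps above can be sketched in plain JAX as a numerical reference. This is an illustrative sketch, not the library's Triton kernel. The helpers mean_pool_blocks and compressed_attention_reference are hypothetical. The sketch assumes seq_len is divisible by block_size and that queries and keys share num_heads. It also lets every query attend to every compressed block and does not model any block-level causal masking the kernel may apply.

```python
import jax
import jax.numpy as jnp


def mean_pool_blocks(x, block_size):
    # Hypothetical helper: (B, T, H, D) -> (B, T // block_size, H, D).
    # Assumes T is divisible by block_size (no ragged tail block).
    b, t, h, d = x.shape
    return x.reshape(b, t // block_size, block_size, h, d).mean(axis=2)


def compressed_attention_reference(q, k, v, block_size, softmax_scale):
    # Hypothetical non-causal reference: every query sees every compressed block.
    k_cmp = mean_pool_blocks(k, block_size)  # step 1: pool keys
    v_cmp = mean_pool_blocks(v, block_size)  # step 1: pool values
    # Step 2: scaled scores over compressed blocks, shape (B, T, H, N).
    scores = jnp.einsum("bthd,bnhd->bthn", q, k_cmp) * softmax_scale
    lse = jax.nn.logsumexp(scores, axis=-1)   # (B, T, H), the documented lse layout
    probs = jnp.exp(scores - lse[..., None])  # step 3: block importance, sums to 1 over N
    out = jnp.einsum("bthn,bnhd->bthd", probs, v_cmp)  # step 5: coarse output (B, T, H, D)
    return out, lse, probs  # probs would feed the top-K selection of step 4


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 256, 4, 32))
k = jax.random.normal(kk, (2, 256, 4, 32))
v = jax.random.normal(kv, (2, 256, 4, 32))
out, lse, probs = compressed_attention_reference(q, k, v, block_size=64, softmax_scale=32 ** -0.5)
```

With seq_len = 256 and block_size = 64, each query distributes its attention over 4 pooled blocks instead of 256 tokens.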

The implementation uses custom Triton kernels for efficient GPU computation and provides both forward and backward kernels, so the operation is fully differentiable through a custom VJP.

Key benefits:

  • O(N²/B) complexity for block size B, i.e. B times fewer attention scores than O(N²) full attention

  • Provides global context across the entire sequence

  • Enables learned sparse attention patterns

  • Integrates with variable-length sequence processing
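As a rough check of the O(N²/B) claim, the snippet below counts attention-score entries per batch element and head for the sizes used in this page's examples (N = 1024, B = 64). It is plain arithmetic that ignores head_dim and kernel constants; score_entries is a hypothetical helper.

```python
def score_entries(seq_len, block_size=None):
    """Hypothetical helper: query-key score entries per (batch, head)."""
    if block_size is None:
        return seq_len * seq_len                 # full attention: N x N
    return seq_len * (seq_len // block_size)     # compressed: N x (N / B)


full = score_entries(1024)
compressed = score_entries(1024, block_size=64)
print(full, compressed, full // compressed)  # 1048576 16384 64
```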

The compressed attention pass serves two roles in NSA:

  1. Its attention scores drive block selection (via top-K)

  2. Its output is a direct output pathway (gated with g_cmp)
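Both roles can be illustrated with a small JAX sketch. Everything here is an assumption for illustration: select_top_blocks and gate_compressed_output are hypothetical helpers, block importance is aggregated by summing probabilities over heads (one plausible choice), and g_cmp is assumed to be a per-token, per-head gate in [0, 1] produced by a sigmoid. The sketch recomputes normalized scores from the documented lse rather than calling the library.

```python
import jax
import jax.numpy as jnp


def select_top_blocks(q, k_cmp, lse, softmax_scale, top_k):
    # Hypothetical helper: indices of the top_k compressed blocks per query.
    # q: (B, T, H, D); k_cmp: (B, N, H, D); lse: (B, T, H) as documented.
    scores = jnp.einsum("bthd,bnhd->bthn", q, k_cmp) * softmax_scale
    probs = jnp.exp(scores - lse[..., None])   # normalized over the N blocks
    importance = probs.sum(axis=2)             # assumed aggregation over heads -> (B, T, N)
    _, idx = jax.lax.top_k(importance, top_k)  # top_k along the last (block) axis
    return idx                                 # (B, T, top_k)


def gate_compressed_output(o_cmp, g_cmp):
    # Hypothetical helper: scale o_cmp (B, T, H, D) by a gate g_cmp (B, T, H).
    return g_cmp[..., None] * o_cmp


kq, kk, ko, kg = jax.random.split(jax.random.PRNGKey(0), 4)
B, T, H, D, N = 2, 128, 4, 32, 8
q = jax.random.normal(kq, (B, T, H, D))
k_cmp = jax.random.normal(kk, (B, N, H, D))
o_cmp = jax.random.normal(ko, (B, T, H, D))  # stand-in for the compressed attention output
scale = D ** -0.5
lse = jax.nn.logsumexp(jnp.einsum("bthd,bnhd->bthn", q, k_cmp) * scale, axis=-1)

idx = select_top_blocks(q, k_cmp, lse, scale, top_k=2)
g_cmp = jax.nn.sigmoid(jax.random.normal(kg, (B, T, H)))  # assumed sigmoid gate
gated = gate_compressed_output(o_cmp, g_cmp)
```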

Functions:

  • nsa_compression: Main user-facing function with custom VJP

  • nsa_compression_fwd: Forward pass kernel wrapper

  • nsa_compression_bwd: Backward pass kernel wrapper

  • nsa_compression_fwd_kernel: Triton kernel for forward pass

  • nsa_compression_bwd_kernel_dq: Triton kernel for query gradients

  • nsa_compression_bwd_kernel_dkv: Triton kernel for key/value gradients

Example

>>> import jax.numpy as jnp
>>> from ejkernel.kernels._triton.native_sparse_attention._compression import nsa_compression
>>>
>>> batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
>>> block_size = 64
>>>
>>> q = jnp.ones((batch, seq_len, num_heads, head_dim))
>>>
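>>> # Pre-pooled keys/values: seq_len // block_size = 16 blocks
>>> k_compressed = jnp.ones((batch, seq_len // block_size, num_heads, head_dim))
>>> v_compressed = jnp.ones((batch, seq_len // block_size, num_heads, head_dim))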
>>>
>>> output, lse = nsa_compression(
...     q, k_compressed, v_compressed,
...     block_size=block_size,
...     softmax_scale=head_dim ** -0.5
... )
ejkernel.kernels._triton.native_sparse_attention._compression.nsa_compression(query: Array, key: Array, value: Array, block_size: int, softmax_scale: float, cu_seqlens: jax.jaxlib._jax.Array | None = None, token_indices: jax.jaxlib._jax.Array | None = None) → tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array]

Compute compressed attention over mean-pooled key-value blocks.

This function implements the compressed attention pathway of Native Sparse Attention, where each query token attends to compressed (mean-pooled) representations of key-value blocks. This provides O(N²/B) complexity while maintaining global context across the sequence.

The compressed attention serves two purposes:

  1. Block selection: attention scores indicate which blocks are important

  2. Direct output: can be used as a coarse-grained attention output (gated)

Parameters
  • query – Query tensor of shape (batch, seq_len, num_heads, head_dim). Each query attends to all compressed KV blocks.

  • key – Compressed key tensor of shape (batch, num_blocks, num_heads, head_dim). Keys have been mean-pooled from blocks of size block_size.

  • value – Compressed value tensor of shape (batch, num_blocks, num_heads, head_dim). Values have been mean-pooled from blocks of size block_size.

  • block_size – Size of each block in tokens. Keys/values should already be compressed at this granularity.

  • softmax_scale – Attention score scaling factor, typically 1/sqrt(head_dim).

  • cu_seqlens – Optional cumulative sequence lengths for variable-length sequences, shape (num_seqs + 1,). If provided, enables packed variable-length processing.

  • token_indices – Optional token indices for variable-length sequences, shape (total_tokens, 2). Each row contains [sequence_id, token_offset].
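The documented token_indices row format ([sequence_id, token_offset]) can be derived from cu_seqlens. The NumPy sketch below shows one way to build it, assuming cu_seqlens starts at 0 and describes a packed layout; token_indices_from_cu_seqlens is a hypothetical helper, not a library function, and the library may construct these indices differently.

```python
import numpy as np


def token_indices_from_cu_seqlens(cu_seqlens):
    # Hypothetical helper: (total_tokens, 2) rows of [sequence_id, token_offset].
    cu = np.asarray(cu_seqlens)
    lengths = np.diff(cu)                                      # tokens per sequence
    seq_ids = np.repeat(np.arange(len(lengths)), lengths)      # owning sequence of each token
    offsets = np.arange(cu[-1]) - np.repeat(cu[:-1], lengths)  # position inside its sequence
    return np.stack([seq_ids, offsets], axis=1)


cu_seqlens = np.array([0, 3, 5])  # two packed sequences of lengths 3 and 2
print(token_indices_from_cu_seqlens(cu_seqlens).tolist())
# [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]
```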

Returns

  • output: Compressed attention output, shape (batch, seq_len, num_heads, head_dim).

  • lse: Log-sum-exp of the scaled attention scores over the compressed blocks, shape (batch, seq_len, num_heads). It is an input to nsa_compression_bwd and can be used to renormalize scores for block selection.
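Storing a log-sum-exp rather than a raw softmax denominator keeps the normalizer finite even for large logits, and it is enough to recompute normalized weights later, which is consistent with nsa_compression_bwd taking lse as an input. The values below are illustrative only, not kernel output.

```python
import jax
import jax.numpy as jnp

# Large logits, e.g. unscaled dot products, for one query over three blocks.
scores = jnp.array([1000.0, 999.0, 998.0], dtype=jnp.float32)

naive = jnp.log(jnp.sum(jnp.exp(scores)))  # exp(1000) overflows float32, giving inf
lse = jax.nn.logsumexp(scores)             # max-shifted internally, so it stays finite
probs = jnp.exp(scores - lse)              # normalized block weights recovered from lse alone

print(naive, lse, probs.sum())
```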

Return type

tuple[Array, Array], i.e. (output, lse)

Note

The key and value tensors should already be mean-pooled to the block level. Use mean_pooling(k, block_size) and mean_pooling(v, block_size) to prepare them.

Example

>>> import jax.numpy as jnp
>>> from ejkernel.kernels._triton.mean_pooling import mean_pooling
>>> from ejkernel.kernels._triton.native_sparse_attention._compression import nsa_compression
>>>
>>> batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
>>> block_size = 64
>>>
>>> q = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> k = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> v = jnp.ones((batch, seq_len, num_heads, head_dim))
>>>
>>> # Pool keys and values into blocks of block_size tokens
>>> k_compressed = mean_pooling(k, block_size)
>>> v_compressed = mean_pooling(v, block_size)
>>>
>>> output, lse = nsa_compression(
...     q, k_compressed, v_compressed,
...     block_size=block_size,
...     softmax_scale=head_dim ** -0.5
... )
>>> print(output.shape)
(2, 1024, 8, 64)
>>> print(lse.shape)
(2, 1024, 8)
ejkernel.kernels._triton.native_sparse_attention._compression.nsa_compression_bwd(q: Array, k: Array, v: Array, o: Array, lse: Array, do: Array, block_size: int = 64, softmax_scale: float | None = None, cu_seqlens: jax.jaxlib._jax.Array | None = None, token_indices: jax.jaxlib._jax.Array | None = None)
ejkernel.kernels._triton.native_sparse_attention._compression.nsa_compression_fwd(q: Array, k: Array, v: Array, block_size: int, softmax_scale: float, cu_seqlens: jax.jaxlib._jax.Array | None = None, token_indices: jax.jaxlib._jax.Array | None = None) → tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array]
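nsa_compression ties nsa_compression_fwd and nsa_compression_bwd together through a custom VJP. The pure-JAX sketch below mirrors that wiring under stated assumptions: it operates on already-compressed keys and values, is non-causal, uses einsum instead of Triton kernels, and saves (q, k, v, o, lse) as residuals to match the nsa_compression_bwd inputs. compression_sketch and its helpers are hypothetical names, not the library's implementation. The backward pass recomputes probabilities from lse and applies the standard softmax VJP, then is checked against JAX's own autodiff.

```python
from functools import partial

import jax
import jax.numpy as jnp


def _scores(q, k, scale):
    # q: (B, T, H, D); k: (B, N, H, D) already compressed -> (B, T, H, N)
    return jnp.einsum("bthd,bnhd->bthn", q, k) * scale


@partial(jax.custom_vjp, nondiff_argnums=(3,))
def compression_sketch(q, k, v, scale):
    s = _scores(q, k, scale)
    lse = jax.nn.logsumexp(s, axis=-1)
    o = jnp.einsum("bthn,bnhd->bthd", jnp.exp(s - lse[..., None]), v)
    return o, lse


def _fwd(q, k, v, scale):
    o, lse = compression_sketch(q, k, v, scale)
    return (o, lse), (q, k, v, o, lse)  # residuals mirror nsa_compression_bwd's inputs


def _bwd(scale, res, cotangents):
    q, k, v, o, lse = res
    do, dlse = cotangents
    p = jnp.exp(_scores(q, k, scale) - lse[..., None])  # recompute probs from lse
    dv = jnp.einsum("bthn,bthd->bnhd", p, do)
    dp = jnp.einsum("bthd,bnhd->bthn", do, v)
    delta = jnp.sum(do * o, axis=-1, keepdims=True)       # rowsum(dP * P) == rowsum(dO * O)
    ds = p * (dp - delta) + p * dlse[..., None]           # softmax VJP plus lse VJP
    dq = jnp.einsum("bthn,bnhd->bthd", ds, k) * scale
    dk = jnp.einsum("bthn,bthd->bnhd", ds, q) * scale
    return dq, dk, dv


compression_sketch.defvjp(_fwd, _bwd)


def _reference(q, k, v, scale):
    # Same math with no custom rule, differentiated by JAX directly.
    s = _scores(q, k, scale)
    o = jnp.einsum("bthn,bnhd->bthd", jax.nn.softmax(s, axis=-1), v)
    return o, jax.nn.logsumexp(s, axis=-1)


def loss(fn, q, k, v):
    o, lse = fn(q, k, v, 8 ** -0.5)
    return jnp.sum(o ** 2) + jnp.sum(jnp.sin(lse))  # touches both outputs


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (1, 16, 2, 8))
k = jax.random.normal(kk, (1, 4, 2, 8))
v = jax.random.normal(kv, (1, 4, 2, 8))
g_custom = jax.grad(lambda q, k, v: loss(compression_sketch, q, k, v), argnums=(0, 1, 2))(q, k, v)
g_ref = jax.grad(lambda q, k, v: loss(_reference, q, k, v), argnums=(0, 1, 2))(q, k, v)
```

Saving lse instead of the full (seq_len, num_blocks) probability matrix keeps the residuals small; the probabilities are cheap to recompute in the backward pass.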