# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Native Sparse Attention (NSA) implementation using Triton kernels.
This module implements Native Sparse Attention, a hybrid attention mechanism
that combines compressed attention over coarse-grained blocks with selective
fine-grained attention to important tokens. This approach achieves significant
computational savings while maintaining model quality.
NSA Architecture:
-----------------
NSA consists of two parallel attention pathways that are gated and combined:
1. **Compressed Attention** (g_cmp pathway):
- Keys and values are mean-pooled into blocks (e.g., 64 tokens -> 1 block)
- Each query attends to these compressed block representations
- Provides global context with O(N²/B) complexity for block size B
- Used to identify which blocks contain important information
2. **Selected Attention** (g_slc pathway):
- Based on compressed attention scores, select top-K blocks per query
- Perform full fine-grained attention only to selected blocks
- Provides detailed local attention with O(N*K*B) complexity
- Focuses computational resources on relevant regions
The final output is a gated combination:
output = g_slc * selected_attn + g_cmp * compressed_attn
where g_slc and g_cmp are learned gating values that balance the two pathways.
Key Benefits:
-------------
1. **Adaptive sparsity**: Automatically learns which blocks are important
2. **Flexible trade-off**: Balance between global context and local detail
3. **Reduced computation**: O(N²/B + N*K*B) instead of O(N²)
4. **Maintained quality**: Selective attention focuses on relevant tokens
Example complexity for N=4096, B=64, K=16 (counting query-key score computations):
- Standard attention: N² ≈ 16.8M
- NSA: N²/B + N*K*B ≈ 0.26M + 4.19M ≈ 4.5M (roughly a 3.8x reduction)
Implementation Details:
-----------------------
- Requires Grouped Query Attention (GQA) with group size multiple of 16
- Supports variable-length sequences via cu_seqlens
- Block indices can be pre-computed or learned via compression pathway
- Gradients flow through both compression and selection mechanisms
Example:
>>> import jax.numpy as jnp
>>> from ejkernel.kernels._triton.native_sparse_attention import native_sparse_attention
>>>
>>> batch, seq_len, num_q_heads, num_kv_heads, head_dim = 2, 2048, 32, 2, 64  # GQA group size 32 // 2 = 16
>>> q = jnp.ones((batch, seq_len, num_q_heads, head_dim))
>>> k = jnp.ones((batch, seq_len, num_kv_heads, head_dim))
>>> v = jnp.ones((batch, seq_len, num_kv_heads, head_dim))
>>>
>>>
>>> g_cmp = jnp.ones((batch, seq_len, num_q_heads))
>>> g_slc = jnp.ones((batch, seq_len, num_q_heads))
>>>
>>>
>>> output = native_sparse_attention(
... q, k, v,
... g_cmp=g_cmp,
... g_slc=g_slc,
... block_counts=16,
... block_size=64
... )
Reference:
Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention
https://arxiv.org/abs/2502.11089
"""
import warnings
from functools import partial
import jax
import jaxtyping
from beartype import beartype
from jax import numpy as jnp
from jaxtyping import Array, Float, Int
from ejkernel.xla_utils.utils import prepare_token_indices
from ..._registry import Backend, Platform, kernel_registry
from ..mean_pooling import mean_pooling
from ._compression import nsa_compression
from ._triton_impl_bwd import bwd_triton_impl
from ._triton_impl_fwd import fwd_triton_impl, nsa_topk
def _fwd_call(
    query: Float[Array, "batch seq_len num_heads head_dim"],
    key: Float[Array, "batch seq_len num_heads head_dim"],
    value: Float[Array, "batch seq_len num_heads head_dim"],
    block_indices: Int[Array, "batch seq_len num_kv_heads num_selected_blocks"],
    block_counts: Int[Array, "batch seq_len num_kv_heads"] | int,
    block_size: int,
    softmax_scale: float,
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
    token_indices: Int[Array, "total_tokens"] | None = None,
) -> tuple[
    Float[Array, "batch seq_len num_heads head_dim"],
    tuple[Float[Array, "..."], ...],
]:
    """Run the Triton NSA forward kernel and save what the backward rule needs.

    This is the ``fwd`` rule handed to ``_apply_nsa.defvjp``. It takes the
    same arguments as ``_apply_nsa`` and returns ``(output, residuals)``.

    Args:
        query: Query tensor.
        key: Key tensor (key/value heads; GQA is handled by the kernel).
        value: Value tensor.
        block_indices: Indices of the key/value blocks each query attends to.
        block_counts: Number of selected blocks per query (int or tensor).
        block_size: Number of tokens per key/value block.
        softmax_scale: Multiplier applied to attention scores.
        cu_seqlens: Optional cumulative sequence lengths for packed
            variable-length batches.
        token_indices: Optional per-token indices for packed batches.

    Returns:
        ``(o, (query, key, value, o, lse))`` where ``o`` and ``lse`` come from
        ``fwd_triton_impl`` (``lse`` is presumably the softmax log-sum-exp).
        The residual order must match the unpacking in ``_bwd_call``.
    """
    attn_out, log_sum_exp = fwd_triton_impl(
        q=query,
        k=key,
        v=value,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        softmax_scale=softmax_scale,
        cu_seqlens=cu_seqlens,
        token_indices=token_indices,
    )
    # Saved for `_bwd_call`, which forwards them unchanged to `bwd_triton_impl`.
    saved_for_bwd = (query, key, value, attn_out, log_sum_exp)
    return attn_out, saved_for_bwd
def _bwd_call(
    block_indices: Int[Array, "batch seq_len num_kv_heads num_selected_blocks"],
    block_counts: Int[Array, "batch seq_len num_kv_heads"] | int,
    block_size: int,
    softmax_scale: float,
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None,
    token_indices: Int[Array, "total_tokens"] | None,
    residual: tuple[Float[Array, "..."], ...],
    do: Float[Array, "batch seq_len num_heads head_dim"],
):
    """Compute NSA input gradients with the Triton backward kernel.

    This is the ``bwd`` rule handed to ``_apply_nsa.defvjp``. Following the
    ``jax.custom_vjp`` convention for ``nondiff_argnums``, the non-differentiable
    arguments come first, followed by the residuals from ``_fwd_call`` and the
    output cotangent.

    Args:
        block_indices: Block selection used in the forward pass.
        block_counts: Number of selected blocks per query (int or tensor).
        block_size: Number of tokens per key/value block.
        softmax_scale: Multiplier applied to attention scores.
        cu_seqlens: Optional cumulative sequence lengths for packed batches.
        token_indices: Optional per-token indices for packed batches.
        residual: ``(query, key, value, o, lse)`` as packed by ``_fwd_call``.
        do: Cotangent of the attention output.

    Returns:
        ``(dq, dk, dv)``: cotangents for the three differentiable inputs
        (positions 0-2 of ``_apply_nsa``).
    """
    q_saved, k_saved, v_saved, out_saved, lse_saved = residual
    grads = bwd_triton_impl(
        q=q_saved,
        k=k_saved,
        v=v_saved,
        o=out_saved,
        lse=lse_saved,
        do=do,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        softmax_scale=softmax_scale,
        cu_seqlens=cu_seqlens,
        token_indices=token_indices,
    )
    # Unpack to enforce exactly three gradients before handing them to JAX.
    grad_q, grad_k, grad_v = grads
    return grad_q, grad_k, grad_v
# Only query/key/value (positions 0-2) are differentiable; positions 3-8
# (block_indices .. token_indices) are passed through to `_bwd_call` as
# leading non-differentiable arguments.
# NOTE(review): JAX's custom_vjp docs advise against array-valued
# nondiff_argnums (block_indices, cu_seqlens, token_indices here); verify this
# still works when those arrays are tracers under an outer jit.
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8))
# block_size (5) and softmax_scale (6) are static (hashed Python values).
# NOTE(review): block_counts (4) is not static, so on the non-differentiated
# path an int block_counts is presumably converted to a traced scalar by jit;
# confirm fwd_triton_impl handles that case the same way as a Python int.
@partial(jax.jit, static_argnums=(5, 6))
def _apply_nsa(
    query: Float[Array, "batch seq_len num_q_heads head_dim"],
    key: Float[Array, "batch seq_len num_kv_heads head_dim"],
    value: Float[Array, "batch seq_len num_kv_heads head_dim"],
    block_indices: Int[Array, "batch seq_len num_kv_heads num_selected_blocks"],
    block_counts: Int[Array, "batch seq_len num_kv_heads"] | int,
    block_size: int,
    softmax_scale: float,
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
    token_indices: Int[Array, "total_tokens"] | None = None,
) -> Float[Array, "batch seq_len num_q_heads head_dim"]:
    """
    Core JIT-compiled NSA function with a custom VJP.
    Applies the sparse attention pattern defined by `block_indices`. The
    primal computation reuses `_fwd_call` and discards its residuals; the
    gradient is supplied by the `_fwd_call` / `_bwd_call` pair registered via
    `defvjp` below, which calls the Triton backward kernel instead of
    differentiating through the forward kernel.
    Args:
        query: Query tensor.
        key: Key tensor (key/value heads).
        value: Value tensor (key/value heads).
        block_indices: Sparsity pattern indicating which blocks to attend to
            (non-differentiable).
        block_counts: Number of blocks to attend to per query
            (non-differentiable, not static under jit).
        block_size: Size of each block (static, non-differentiable).
        softmax_scale: Attention scaling factor (static, non-differentiable).
        cu_seqlens: Cumulative sequence lengths for variable-length sequences.
        token_indices: Token indices for variable-length sequences.
    Returns:
        The sparse attention output tensor.
    """
    return _fwd_call(
        query=query,
        key=key,
        value=value,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        softmax_scale=softmax_scale,
        cu_seqlens=cu_seqlens,
        token_indices=token_indices,
    )[0]  # element 0 is the attention output; element 1 holds the residuals


# Register the forward (output + residuals) and backward (dq, dk, dv) rules.
_apply_nsa.defvjp(_fwd_call, _bwd_call)
# NOTE: this line previously read `[docs]@kernel_registry.register(...)`, which
# Python parses as the expression `[docs] @ kernel_registry.register(...)`:
# it raised NameError at import and never registered the kernel.
@kernel_registry.register("apply_native_sparse_attention", Platform.TRITON, Backend.GPU)
@jaxtyping.jaxtyped(typechecker=beartype)
def apply_native_sparse_attention(
    query: Float[Array, "batch seq_len num_q_heads head_dim"],
    key: Float[Array, "batch seq_len num_kv_heads head_dim"],
    value: Float[Array, "batch seq_len num_kv_heads head_dim"],
    block_indices: Int[Array, "batch seq_len num_kv_heads num_selected_blocks"],
    block_counts: Int[Array, "batch seq_len num_kv_heads"] | int = 16,
    block_size: int = 64,
    softmax_scale: float | None = None,
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
    token_indices: Int[Array, "total_tokens"] | None = None,
) -> Float[Array, "batch seq_len num_q_heads head_dim"]:
    """Apply Native Sparse Attention using a pre-computed sparse block pattern.

    User-facing wrapper around the JIT-compiled, custom-VJP ``_apply_nsa``.
    It fills in the default softmax scale and, for packed variable-length
    inputs, derives ``token_indices`` from ``cu_seqlens`` when not supplied.

    Args:
        query: Query tensor.
        key: Key tensor.
        value: Value tensor.
        block_indices: Indices of the key/value blocks that each query
            attends to.
        block_counts: Number of blocks each query attends to. Either an int
            (uniform sparsity) or a per-token, per-kv-head tensor.
        block_size: Number of tokens per key/value block.
        softmax_scale: Attention score scale. Defaults to
            ``1 / sqrt(query.shape[-1])``.
        cu_seqlens: Optional cumulative sequence lengths for variable-length
            sequences.
        token_indices: Optional pre-computed token indices for
            variable-length sequences. If ``None`` and ``cu_seqlens`` is
            provided, they are computed with ``prepare_token_indices``.

    Returns:
        The output tensor of the sparse attention computation.
    """
    if softmax_scale is None:
        softmax_scale = 1.0 / (query.shape[-1] ** 0.5)
    # Only packed (cu_seqlens) inputs need per-token indices.
    if token_indices is None and cu_seqlens is not None:
        token_indices = prepare_token_indices(cu_seqlens)
    return _apply_nsa(
        query=query,
        key=key,
        value=value,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        softmax_scale=softmax_scale,
        cu_seqlens=cu_seqlens,
        token_indices=token_indices,
    )
# NOTE: this line previously read `[docs]@kernel_registry.register(...)`, which
# Python parses as the expression `[docs] @ kernel_registry.register(...)`:
# it raised NameError at import and never registered the kernel.
@kernel_registry.register("native_sparse_attention", Platform.TRITON, Backend.GPU)
@jaxtyping.jaxtyped(typechecker=beartype)
def native_sparse_attention(
    query: Float[Array, "batch seq_len num_q_heads head_dim"],
    key: Float[Array, "batch seq_len num_kv_heads head_dim"],
    value: Float[Array, "batch seq_len num_kv_heads head_dim"],
    g_cmp: Float[Array, "batch seq_len num_q_heads"] | None = None,
    g_slc: Float[Array, "batch seq_len num_q_heads"] | None = None,
    block_indices: Int[Array, "batch seq_len num_kv_heads num_selected_blocks"] | None = None,
    block_counts: Int[Array, "batch seq_len num_kv_heads"] | int = 16,
    block_size: int = 64,
    softmax_scale: float | None = None,
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
) -> Float[Array, "batch seq_len num_q_heads head_dim"]:
    """Native Sparse Attention: a gated mix of compressed and selected attention.

    NSA combines two components:

    1. **Compressed attention**: coarse-grained attention over mean-pooled
       (compressed) key/value blocks, giving a global context summary.
    2. **Selected attention**: fine-grained sparse attention in which each
       query attends only to a subset of the original key/value blocks.

    When ``g_cmp`` is given, the blocks for the second component are chosen by
    top-k selection over the compressed representation. The output is
    ``o_slc * g_slc + o_cmp * g_cmp``, where each term is applied only when its
    gate is provided (without ``g_slc`` the selected output is used ungated).

    Args:
        query: Query tensor of shape ``(batch_size, sequence, query_heads, dimk)``.
        key: Key tensor of shape ``(batch_size, sequence, kvheads, dimk)``.
            GQA is enforced: ``query_heads`` must be divisible by ``kvheads``
            and the ratio (group size) must be a positive multiple of 16.
        value: Value tensor of shape ``(batch_size, sequence, kvheads, dimv)``.
        g_cmp: Optional gate for compressed attention, shape
            ``(batch_size, sequence, query_heads)``. If provided, the compressed
            component is computed and used to select blocks.
        g_slc: Optional gate for selected attention, shape
            ``(batch_size, sequence, query_heads)``.
        block_indices: Optional pre-computed block indices for selected
            attention, shape ``(batch_size, sequence, kvheads, S)`` where ``S``
            is the number of selected blocks. Ignored (with a warning) when
            ``g_cmp`` is provided; required when it is not.
        block_counts: Number of blocks to select per query. Defaults to 16.
        block_size: Number of tokens per attention block. Defaults to 64.
        softmax_scale: Attention score scale. Defaults to ``dimk ** -0.5``.
        cu_seqlens: Cumulative sequence lengths of shape ``(N + 1,)`` for
            variable-length training. If provided, ``batch_size`` must be 1.

    Returns:
        Output tensor of shape ``(batch_size, sequence, query_heads, dimv)``.

    Raises:
        AssertionError: On invalid head counts, a batch size other than 1 with
            ``cu_seqlens``, or missing ``block_indices`` without ``g_cmp``.
    """
    assert block_counts is not None, "block counts must be provided for selection"
    if softmax_scale is None:
        softmax_scale = key.shape[-1] ** -0.5
    if cu_seqlens is not None:
        assert query.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"

    num_q_heads, num_kv_heads = query.shape[2], key.shape[2]
    # Floor division alone would silently accept e.g. 33 query heads over 2
    # KV heads (group size 16) or yield a group size of 0, which passes `% 16`.
    assert num_q_heads % num_kv_heads == 0, (
        f"Number of query heads ({num_q_heads}) must be divisible by "
        f"number of key/value heads ({num_kv_heads})"
    )
    group_size = num_q_heads // num_kv_heads
    assert group_size > 0 and group_size % 16 == 0, f"Group size must be a multiple of 16 in NSA, got {group_size}"

    o_cmp = None
    if g_cmp is not None:
        # Mean-pooled K/V are consumed only by the compression branch, so they
        # are not computed at all when that branch is disabled.
        k_cmp = mean_pooling(key, block_size, cu_seqlens)
        v_cmp = mean_pooling(value, block_size, cu_seqlens)
        o_cmp, lse_cmp = nsa_compression(
            query=query,
            key=k_cmp,
            value=v_cmp,
            block_size=block_size,
            softmax_scale=softmax_scale,
            cu_seqlens=cu_seqlens,
        )
        if block_indices is not None:
            warnings.warn("`block_indices` will be ignored when `g_cmp` is provided", stacklevel=1)
        # Select the blocks for fine-grained attention from the compressed keys.
        block_indices = nsa_topk(
            q=query,
            k=k_cmp,
            lse=lse_cmp,
            block_counts=block_counts,
            block_size=block_size,
            softmax_scale=softmax_scale,
            cu_seqlens=cu_seqlens,
        )
    assert block_indices is not None, "if `g_cmp` is not passed, `block_indices` must be provided."

    o_slc = apply_native_sparse_attention(
        query=query,
        key=key,
        value=value,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        softmax_scale=softmax_scale,
        cu_seqlens=cu_seqlens,
    )

    # Gates are per (token, query head); broadcast them over the feature axis.
    o = o_slc
    if g_slc is not None:
        o = o_slc * jnp.expand_dims(g_slc, -1)
    if o_cmp is not None and g_cmp is not None:
        o = o + o_cmp * jnp.expand_dims(g_cmp, -1)
    return o