ejkernel.modules.operations.native_sparse_attention

Native sparse attention module with automatic optimization.

This module implements native sparse attention using explicit block indices to define sparsity patterns. Unlike block-sparse attention, which relies on mask builders, this implementation specifies the blocks to attend to directly through index arrays.

This approach is particularly efficient when:
  • The sparse pattern is known ahead of time

  • Block indices can be precomputed and reused

  • Fine-grained control over sparsity is needed

The sparse pattern is defined by block_indices and block_counts arrays, allowing flexible sparse attention patterns like local windows, strided patterns, or custom document-structure-aware sparsity.
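As an illustration of how a local-window pattern can be precomputed, the sketch below builds per-position block indices and counts in plain Python. The helper name `local_window_block_indices`, the most-recent-first ordering, and the use of 0 as padding are assumptions made for this example and are not part of ejkernel. The nested lists would still need to be converted to integer arrays and broadcast over the batch and KV-head dimensions before being passed to the kernel.

```python
def local_window_block_indices(seq_len, block_size, window_blocks):
    """Return (indices, counts) for a causal local-window pattern.

    indices[t] lists the key blocks that query position t attends to,
    most recent block first, padded with 0 up to window_blocks entries.
    counts[t] is the number of valid (non-padding) entries in indices[t].
    """
    indices, counts = [], []
    for t in range(seq_len):
        own_block = t // block_size
        # Oldest block that still falls inside the window.
        first_block = max(0, own_block - window_blocks + 1)
        selected = list(range(own_block, first_block - 1, -1))
        counts.append(len(selected))
        # Pad so every row has the same length (assumed padding value: 0).
        indices.append(selected + [0] * (window_blocks - len(selected)))
    return indices, counts


indices, counts = local_window_block_indices(seq_len=8, block_size=2, window_blocks=2)
```

Because the result depends only on `seq_len`, `block_size`, and `window_blocks`, it can be computed once and reused across calls with the same shapes.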

class ejkernel.modules.operations.native_sparse_attention.NativeSparseAttention

Bases: Kernel[NativeSparseAttentionConfig, Array]

Native Sparse Attention with custom optimization logic.

Implements sparse attention using explicit block index specification. This provides direct control over which blocks participate in attention computation, enabling efficient sparse patterns without runtime mask building.

Features:
  • Direct block index specification for sparsity

  • Configurable block size and block counts

  • Support for variable-length sequences

  • Token-level sparse patterns via token_indices

  • Multiple platform support (Triton/Pallas/CUDA/XLA)

The sparsity is controlled by:
  • block_indices: Which blocks each query block attends to

  • block_counts: Number of key blocks per query block

  • token_indices: Fine-grained token-level sparsity (optional)
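To make the role of block_indices and block_counts concrete, here is a dependency-free reference sketch for a single query position and a single head. It is not the ejkernel kernel: the function name, the causal masking of future positions, and the default scale of 1/sqrt(head_dim) are assumptions for illustration, and the selected blocks are assumed to be distinct.

```python
import math


def sparse_attend_one_query(q, keys, values, t, selected_blocks,
                            block_count, block_size, scale=None):
    """Attention output for query position t over selected key blocks only."""
    if scale is None:
        scale = 1.0 / math.sqrt(len(q))  # assumed default scaling
    # Gather key positions from the first block_count selected blocks,
    # skipping positions past the sequence end or after t (assumed causal).
    positions = []
    for block in selected_blocks[:block_count]:
        start = block * block_size
        for pos in range(start, min(start + block_size, len(keys))):
            if pos <= t:
                positions.append(pos)
    dim = len(values[0])
    if not positions:
        return [0.0] * dim
    scores = [scale * sum(a * b for a, b in zip(q, keys[pos])) for pos in positions]
    # Numerically stable softmax over the gathered positions.
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [
        sum(w * values[pos][d] for w, pos in zip(weights, positions)) / total
        for d in range(dim)
    ]
```

Entries of `selected_blocks` beyond `block_count` are ignored, which is how a fixed-width index array can carry a variable number of blocks per query.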

candidate_cfgs(inv: Invocation[NativeSparseAttentionConfig, Array])

Generate candidate configurations for autotuning.

Creates a basic set of candidates for platform-agnostic tuning. Sparse attention benefits from consistent block sizes across dimensions.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List with a single configuration that uses block_size=64 as the baseline

candidate_cfgs_gpu(inv: Invocation[NativeSparseAttentionConfig, Array])

Generate GPU-optimized candidate configurations for autotuning.

Creates configurations tailored for GPU execution with the Triton backend. Tests block sizes (32, 64, 128) and warp counts (4, 8) to find the best configuration for the specific GPU architecture.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of GPU-specific configurations with varying block sizes and warps
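A candidate list of this kind can be sketched as a grid over the values named above. `SketchConfig` and its field names are hypothetical stand-ins for NativeSparseAttentionConfig, and the full cross product is an assumption: the real method may prune or reorder combinations.

```python
from dataclasses import dataclass
from itertools import product


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in for NativeSparseAttentionConfig."""
    block_size: int
    num_warps: int


def sketch_gpu_candidates():
    """Enumerate every (block_size, num_warps) pair as one candidate."""
    block_sizes = (32, 64, 128)
    warp_counts = (4, 8)
    return [SketchConfig(bs, w) for bs, w in product(block_sizes, warp_counts)]


candidates = sketch_gpu_candidates()
```

An autotuner would then time the kernel once per candidate and keep the fastest configuration.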

candidate_cfgs_shard_map_gpu(inv: Invocation[NativeSparseAttentionConfig, Array])

Generate GPU-optimized candidate configurations for autotuning.

Creates configurations tailored for GPU execution with the Triton backend. Tests block sizes (32, 64, 128) and warp counts (4, 8) to find the best configuration for the specific GPU architecture.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of GPU-specific configurations with varying block sizes and warps

candidate_cfgs_shard_map_tpu(inv: Invocation[NativeSparseAttentionConfig, Array])

Generate TPU-optimized candidate configurations for autotuning.

Creates configurations tailored for TPU execution with Pallas backend. TPUs prefer larger block sizes (64, 128) for better vectorization.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of TPU-specific configurations optimized for matrix units

candidate_cfgs_shard_map_xla(inv: Invocation[NativeSparseAttentionConfig, Array])

Generate XLA-optimized candidate configurations for autotuning.

Creates configurations for XLA compiler backend. XLA handles optimization internally, so we provide conservative block sizes that work well across different hardware targets.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of XLA-compatible configurations with standard block sizes

candidate_cfgs_tpu(inv: Invocation[NativeSparseAttentionConfig, Array])

Generate TPU-optimized candidate configurations for autotuning.

Creates configurations tailored for TPU execution with Pallas backend. TPUs prefer larger block sizes (64, 128) for better vectorization.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of TPU-specific configurations optimized for matrix units

candidate_cfgs_xla(inv: Invocation[NativeSparseAttentionConfig, Array])

Generate XLA-optimized candidate configurations for autotuning.

Creates configurations for XLA compiler backend. XLA handles optimization internally, so we provide conservative block sizes that work well across different hardware targets.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of XLA-compatible configurations with standard block sizes

get_impl(cfg: NativeSparseAttentionConfig)

Get kernel implementation from registry.

Parameters

cfg – Configuration specifying platform and backend

Returns

Callable kernel implementation for native sparse attention

Raises

ValueError – If no matching implementation is found

heuristic_cfg(inv: Invocation[NativeSparseAttentionConfig, Array]) → NativeSparseAttentionConfig

Provide default configuration with block sizes.

Selects balanced block sizes that work well for typical sparse patterns. The default configuration uses uniform block sizes for simplicity.

Parameters

inv – Invocation object containing arguments and metadata

Returns

Default configuration with block_size=64 for balanced performance

run(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], g_cmp: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads'] | None = None, g_slc: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads'] | None = None, block_indices: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads num_selected_blocks'] | None = None, block_counts: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads'] | int = 16, softmax_scale: float | None = None, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: NativeSparseAttentionConfig) → Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim']

Execute native sparse attention with explicit block indices.

Parameters
  • query – Query tensor [batch, seq_len, num_heads, head_dim]

  • key – Key tensor [batch, seq_len, num_kv_heads, head_dim]

  • value – Value tensor [batch, seq_len, num_kv_heads, head_dim]

  • block_indices – Indices of the key blocks each query position attends to [batch, seq_len, num_kv_heads, num_selected_blocks]

  • block_counts – Number of selected key blocks per query position: either a single int (default: 16) or an array [batch, seq_len, num_kv_heads]

  • softmax_scale – Optional scaling factor for attention scores

  • cu_seqlens – Cumulative sequence lengths for variable-length sequences

  • platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Kernel configuration object

Returns

Sparse attention output [batch, seq_len, num_heads, head_dim]

Note

When block_indices is None, a default pattern may be used depending on the implementation. Providing explicit indices gives full control over the sparsity pattern.
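The cu_seqlens argument has length num_seqs + 1, which matches the usual cumulative-offset layout for packed variable-length batches: entry i is the start offset of sequence i and the last entry is the total token count. A minimal sketch of that layout, with hypothetical helper names, assuming the leading offset is 0:

```python
from itertools import accumulate


def make_cu_seqlens(lengths):
    """Turn per-sequence lengths into cumulative offsets, starting at 0."""
    return [0, *accumulate(lengths)]


def split_packed(tokens, cu_seqlens):
    """Recover the individual sequences from a packed token list."""
    return [tokens[s:e] for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:])]


cu = make_cu_seqlens([3, 5, 2])
parts = split_packed(list(range(10)), cu)
```

Here three sequences of lengths 3, 5, and 2 are packed into one run of 10 tokens, and `cu` marks where each one starts and ends.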

ejkernel.modules.operations.native_sparse_attention.native_sparse_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], g_cmp: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads'] | None = None, g_slc: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads'] | None = None, block_indices: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads num_selected_blocks'] | None = None, block_counts: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads'] | int = 16, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, /, *, softmax_scale: float | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.NativeSparseAttentionConfig | None = None) → Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim']

Execute native sparse attention with automatic optimization.

Sparse attention computes attention only on specified blocks or patterns, reducing computational cost for long sequences.
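The size of that reduction can be estimated with simple arithmetic: each query scores at most block_size × block_counts keys instead of all seq_len keys. The helper names below are hypothetical, and the estimate is a rough upper bound that ignores causal masking and any cost from the g_cmp and g_slc inputs.

```python
def attended_keys_per_query(block_size, block_count):
    """Upper bound on the keys one query scores under block selection."""
    return block_size * block_count


def dense_to_sparse_ratio(seq_len, block_size, block_count):
    """Ratio of dense to sparse score computations per query."""
    sparse = min(seq_len, attended_keys_per_query(block_size, block_count))
    return seq_len / sparse


# Baseline block_size=64 with the default block_counts=16.
keys_per_query = attended_keys_per_query(64, 16)
ratio_long = dense_to_sparse_ratio(8192, 64, 16)
ratio_short = dense_to_sparse_ratio(512, 64, 16)
```

With these values a query touches at most 1024 keys, so the estimated saving only appears once seq_len exceeds 1024 and grows linearly with sequence length beyond that point.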

Parameters
  • query – Query tensor [batch, seq_len, num_heads, head_dim]

  • key – Key tensor [batch, seq_len, num_kv_heads, head_dim]

  • value – Value tensor [batch, seq_len, num_kv_heads, head_dim]

  • block_indices – Indices of blocks to attend to

  • block_counts – Number of blocks per query block (default: 16)

  • softmax_scale – Scaling factor for attention

  • cu_seqlens – Cumulative sequence lengths for variable-length sequences

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

Returns

Attention output with same shape as query

Example

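>>> # Default pattern and block count
>>> out = native_sparse_attention(query, key, value)
>>>
>>> # Custom block count; block_counts is positional-only in the signature,
>>> # so g_cmp, g_slc and block_indices are passed explicitly as None
>>> out = native_sparse_attention(query, key, value, None, None, None, 32)
>>>
>>> # Force a specific platform
>>> out = native_sparse_attention(query, key, value, platform="triton")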