ejkernel.modules.operations.native_sparse_attention
Native sparse attention module with automatic optimization.
This module implements native sparse attention using explicit block indices to define sparsity patterns. Unlike block-sparse attention, which uses mask builders, this implementation specifies the blocks to attend to directly through index arrays.
- This approach is particularly efficient when:
The sparse pattern is known ahead of time
Block indices can be precomputed and reused
Fine-grained control over sparsity is needed
The sparse pattern is defined by block_indices and block_counts arrays, allowing flexible sparse attention patterns like local windows, strided patterns, or custom document-structure-aware sparsity.
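As an illustration of a precomputed local-window pattern, the sketch below builds index arrays in the shapes given by the `run` signature on this page (`[batch, seq_len, num_kv_heads, num_selected_blocks]` and `[batch, seq_len, num_kv_heads]`). The helper name `local_window_block_indices` is hypothetical and not part of ejkernel. Two further assumptions: unused slots can be clamped to block 0, and `block_counts` tells the kernel how many leading entries are valid.

```python
import numpy as np


def local_window_block_indices(batch, seq_len, num_kv_heads, block_size, num_selected_blocks):
    """Hypothetical helper: every token selects its own key block and the blocks just before it."""
    token_block = np.arange(seq_len) // block_size            # block id of each token, [seq_len]
    offsets = np.arange(num_selected_blocks)                  # 0 = own block, 1 = previous block, ...
    idx = token_block[:, None] - offsets[None, :]             # [seq_len, num_selected_blocks]
    counts = np.minimum(token_block + 1, num_selected_blocks) # how many entries are real blocks
    idx = np.maximum(idx, 0)                                  # assumption: clamp unused slots to block 0
    block_indices = np.broadcast_to(
        idx[None, :, None, :], (batch, seq_len, num_kv_heads, num_selected_blocks)
    ).astype(np.int32)
    block_counts = np.broadcast_to(
        counts[None, :, None], (batch, seq_len, num_kv_heads)
    ).astype(np.int32)
    return block_indices, block_counts


block_indices, block_counts = local_window_block_indices(
    batch=2, seq_len=256, num_kv_heads=4, block_size=64, num_selected_blocks=2
)
```

Because the pattern depends only on shapes, the arrays can be built once and reused across calls. They would typically be converted with `jnp.asarray` before being passed to a JAX kernel.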
- class ejkernel.modules.operations.native_sparse_attention.NativeSparseAttention
Bases: Kernel[NativeSparseAttentionConfig, Array]
Native Sparse Attention with custom optimization logic.
Implements sparse attention using explicit block index specification. This provides direct control over which blocks participate in attention computation, enabling efficient sparse patterns without runtime mask building.
- Features:
Direct block index specification for sparsity
Configurable block size and block counts
Support for variable-length sequences
Token-level sparse patterns via token_indices
Multiple platform support (Triton/Pallas/CUDA/XLA)
- The sparsity is controlled by:
block_indices: Which blocks each query block attends to
block_counts: Number of key blocks per query block
token_indices: Fine-grained token-level sparsity (optional)
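To make the roles of `block_indices` and `block_counts` concrete, here is a naive NumPy reference for a single sequence. It is a sketch of one plausible reading of the semantics, not the library's kernel. It assumes causal masking, a default scale of `1/sqrt(head_dim)`, and grouped query heads that share a key/value head; it ignores `g_cmp`, `g_slc`, and `token_indices`. The function name and the explicit `block_size` argument are hypothetical.

```python
import numpy as np


def reference_sparse_attention(q, k, v, block_indices, block_counts, block_size):
    """Naive reference. q: [T, HQ, D]; k, v: [T, H, D]; block_indices: [T, H, S]; block_counts: [T, H]."""
    T, HQ, D = q.shape
    H = k.shape[1]
    group = HQ // H                       # query heads sharing one key/value head
    out = np.zeros_like(q)
    pos = np.arange(T)
    for t in range(T):
        for hq in range(HQ):
            h = hq // group
            allowed = np.zeros(T, dtype=bool)
            # only the first block_counts[t, h] entries are treated as valid
            for b in block_indices[t, h, : block_counts[t, h]]:
                allowed[b * block_size : (b + 1) * block_size] = True
            allowed &= pos <= t           # assumption: causal masking
            scores = (k[allowed, h] @ q[t, hq]) / np.sqrt(D)
            w = np.exp(scores - scores.max())
            out[t, hq] = (w / w.sum()) @ v[allowed, h]
    return out
```

When every token selects all blocks, this reduces to dense causal attention, which gives a simple sanity check. Each token must select at least one block containing a position it is allowed to see; otherwise the softmax is taken over an empty set.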
- candidate_cfgs(inv: Invocation[NativeSparseAttentionConfig, Array])
Generate candidate configurations for autotuning.
Creates a basic set of candidates for platform-agnostic tuning. Sparse attention benefits from consistent block sizes across dimensions.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List containing a single configuration that uses block_size=64 as the baseline
- candidate_cfgs_gpu(inv: Invocation[NativeSparseAttentionConfig, Array])
Generate GPU-optimized candidate configurations for autotuning.
Creates configurations tailored for GPU execution with the Triton backend. Tests various block sizes (32, 64, 128) and warp counts (4, 8) to find the best configuration for the specific GPU architecture.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of GPU-specific configurations with varying block sizes and warps
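The block sizes and warp counts mentioned above suggest a small search grid. The sketch below enumerates it with a stand-in config class. `SketchConfig` and its field names are hypothetical, and the source does not say whether the real method returns the full cross product or only a subset.

```python
from dataclasses import dataclass
from itertools import product


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in for NativeSparseAttentionConfig; field names are assumptions."""
    block_size: int
    num_warps: int
    platform: str = "triton"


def sketch_candidate_cfgs_gpu():
    block_sizes = (32, 64, 128)   # values taken from the description above
    warp_counts = (4, 8)
    # assumption: the autotuner times each candidate and keeps the fastest
    return [SketchConfig(block_size=b, num_warps=w) for b, w in product(block_sizes, warp_counts)]


candidates = sketch_candidate_cfgs_gpu()
```

A grid this small (six entries) keeps autotuning cheap while still covering the block sizes and warp counts named in the docstring.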
- candidate_cfgs_shard_map_gpu(inv: Invocation[NativeSparseAttentionConfig, Array])
Generate GPU-optimized candidate configurations for autotuning.
Creates configurations tailored for GPU execution with the Triton backend. Tests various block sizes (32, 64, 128) and warp counts (4, 8) to find the best configuration for the specific GPU architecture.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of GPU-specific configurations with varying block sizes and warps
- candidate_cfgs_shard_map_tpu(inv: Invocation[NativeSparseAttentionConfig, Array])
Generate TPU-optimized candidate configurations for autotuning.
Creates configurations tailored for TPU execution with Pallas backend. TPUs prefer larger block sizes (64, 128) for better vectorization.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of TPU-specific configurations optimized for matrix units
- candidate_cfgs_shard_map_xla(inv: Invocation[NativeSparseAttentionConfig, Array])
Generate XLA-optimized candidate configurations for autotuning.
Creates configurations for XLA compiler backend. XLA handles optimization internally, so we provide conservative block sizes that work well across different hardware targets.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of XLA-compatible configurations with standard block sizes
- candidate_cfgs_tpu(inv: Invocation[NativeSparseAttentionConfig, Array])
Generate TPU-optimized candidate configurations for autotuning.
Creates configurations tailored for TPU execution with Pallas backend. TPUs prefer larger block sizes (64, 128) for better vectorization.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of TPU-specific configurations optimized for matrix units
- candidate_cfgs_xla(inv: Invocation[NativeSparseAttentionConfig, Array])
Generate XLA-optimized candidate configurations for autotuning.
Creates configurations for XLA compiler backend. XLA handles optimization internally, so we provide conservative block sizes that work well across different hardware targets.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of XLA-compatible configurations with standard block sizes
- get_impl(cfg: NativeSparseAttentionConfig)
Get kernel implementation from registry.
- Parameters
cfg – Configuration specifying platform and backend
- Returns
Callable kernel implementation for native sparse attention
- Raises
ValueError – If no matching implementation is found
- heuristic_cfg(inv: Invocation[NativeSparseAttentionConfig, Array]) -> NativeSparseAttentionConfig
Provide default configuration with block sizes.
Selects balanced block sizes that work well for typical sparse patterns. The default configuration uses uniform block sizes for simplicity.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
Default configuration with block_size=64 for balanced performance
- run(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], g_cmp: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads'] | None = None, g_slc: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads'] | None = None, block_indices: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads num_selected_blocks'] | None = None, block_counts: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads'] | int = 16, softmax_scale: float | None = None, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: NativeSparseAttentionConfig) -> Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim']
Execute native sparse attention with explicit block indices.
- Parameters
query – Query tensor [batch, seq_len, num_q_heads, head_dim]
key – Key tensor [batch, seq_len, num_kv_heads, head_dim]
value – Value tensor [batch, seq_len, num_kv_heads, head_dim]
block_indices – Indices of the key blocks each query position attends to [batch, seq_len, num_kv_heads, num_selected_blocks]
block_counts – Number of key blocks to attend to: either an int, or an array of shape [batch, seq_len, num_kv_heads] (default: 16)
softmax_scale – Optional scaling factor for attention scores
cu_seqlens – Cumulative sequence lengths for variable-length sequences
platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Kernel configuration object
- Returns
Sparse attention output [batch, seq_len, num_heads, head_dim]
Note
When block_indices is None, a default pattern may be used depending on the implementation. Providing explicit indices gives full control over the sparsity pattern.
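For variable-length inputs, `cu_seqlens` has one more entry than there are sequences (the `num_seqs_plus_one` dimension in the signature). The sketch below builds it from per-sequence lengths. It assumes the common packed-sequence convention of a leading zero followed by running totals, which the source does not spell out.

```python
import numpy as np

seq_lens = np.array([5, 3, 8], dtype=np.int32)  # lengths of three packed sequences
# assumption: leading 0, then cumulative totals, so the array has num_seqs + 1 entries
cu_seqlens = np.concatenate([[0], np.cumsum(seq_lens)]).astype(np.int32)
# cu_seqlens is [0, 5, 8, 16]; sequence i occupies tokens cu_seqlens[i]:cu_seqlens[i + 1]
```

Under this convention, the last entry equals the total number of packed tokens.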
- ejkernel.modules.operations.native_sparse_attention.native_sparse_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], g_cmp: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads'] | None = None, g_slc: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads'] | None = None, block_indices: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads num_selected_blocks'] | None = None, block_counts: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len num_kv_heads'] | int = 16, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, /, *, softmax_scale: float | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.NativeSparseAttentionConfig | None = None) -> Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim']
Execute native sparse attention with automatic optimization.
Sparse attention computes attention only on specified blocks or patterns, reducing computational cost for long sequences.
- Parameters
query – Query tensor [batch, seq_len, num_q_heads, head_dim]
key – Key tensor [batch, seq_len, num_kv_heads, head_dim]
value – Value tensor [batch, seq_len, num_kv_heads, head_dim]
block_indices – Indices of blocks to attend to
block_counts – Number of blocks per query block (default: 16)
softmax_scale – Scaling factor for attention
cu_seqlens – Cumulative sequence lengths for variable-length sequences
platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
- Returns
Attention output with the same shape as query
Example
>>> out = native_sparse_attention(query, key, value)
>>> # block_counts sits before the "/" in the signature, so it is positional-only
>>> out = native_sparse_attention(query, key, value, None, None, None, 32)
>>> out = native_sparse_attention(query, key, value, platform="triton")