ejkernel.modules.operations.ragged_decode_attention
Ragged Decode Attention module with automatic optimization.
This module implements ragged decode attention, an attention mechanism optimized for the decode phase of inference with variable-length sequences. Unlike standard attention, which requires padded sequences, ragged attention handles sequences of different lengths by using per-sequence start/end markers.
- Ragged decode attention is particularly valuable for:
Inference workloads with batched sequences of varying lengths
Decoder-only models during generation
Serving scenarios requiring efficient batching
Situations where padding overhead is significant
The key innovation is using sequence_start and sequence_end arrays to define valid attention ranges per sequence, eliminating the need for padding while maintaining efficient vectorized computation.
- Key Features:
Efficient variable-length sequence handling without padding
Support for sliding window attention for long contexts
Optional logit soft capping for numerical stability
Attention sink support for improved long-context performance
Configurable block sizes for memory-compute tradeoffs
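To make the soft-capping feature concrete, here is a minimal NumPy sketch of one common formulation, `cap * tanh(logits / cap)`. This is an assumption about the general technique, not a statement of how the ejkernel kernels implement `logits_soft_cap`; the helper name `soft_cap` is hypothetical.

```python
import numpy as np


def soft_cap(logits: np.ndarray, cap: float) -> np.ndarray:
    """Smoothly bound logits to the open interval (-cap, cap).

    Hypothetical helper: assumes the tanh-based formulation of soft capping.
    """
    return cap * np.tanh(logits / cap)


logits = np.array([-500.0, -5.0, 0.0, 5.0, 500.0])
capped = soft_cap(logits, cap=50.0)
# Small logits pass through almost unchanged; large ones saturate near +/- cap.
```

Because `tanh` is close to the identity near zero, typical logits are barely altered, while outliers can no longer dominate the softmax.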
- Mathematical Foundation:
- For each query position i in sequence s:
output[i] = softmax(Q[i] @ K[start[s]:end[s]].T / scale) @ V[start[s]:end[s]]
Where start[s] and end[s] define the valid KV range for sequence s.
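The formula above can be sketched as a plain NumPy reference loop. This is an illustrative sketch, not the library's kernel: the function name is hypothetical, it assumes `num_heads == num_kv_heads` (no grouped-query handling), it assumes every sequence has `end > start`, and it assumes `softmax_scale` multiplies the logits (equivalent to dividing by `scale = 1 / softmax_scale` in the formula).

```python
import numpy as np


def ragged_decode_reference(q, k, v, start, end, softmax_scale):
    """Reference ragged decode attention (hypothetical helper).

    q: [batch, heads, dim]; k, v: [batch, seq_len, heads, dim];
    start, end: [batch] integer arrays, end exclusive.
    """
    out = np.zeros_like(q)
    for s in range(q.shape[0]):
        ks = k[s, start[s]:end[s]]  # [valid_len, heads, dim]
        vs = v[s, start[s]:end[s]]
        # One logit per (head, valid KV position).
        logits = np.einsum("hd,thd->ht", q[s], ks) * softmax_scale
        logits = logits - logits.max(axis=-1, keepdims=True)  # stabilize exp
        weights = np.exp(logits)
        weights /= weights.sum(axis=-1, keepdims=True)
        out[s] = np.einsum("ht,thd->hd", weights, vs)
    return out


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 8))
k = rng.standard_normal((2, 16, 4, 8))
v = rng.standard_normal((2, 16, 4, 8))
start = np.array([0, 3])
end = np.array([5, 16])
out = ragged_decode_reference(q, k, v, start, end, softmax_scale=8 ** -0.5)
```

KV positions outside `[start[s], end[s])` never enter the computation, so whatever values sit in the unused slots of `k` and `v` have no effect on the output.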
- class ejkernel.modules.operations.ragged_decode_attention.RaggedDecodeAttention
Bases: Kernel[RaggedDecodeAttentionConfig, Array]
Ragged Decode Attention with custom optimization logic.
Implements efficient attention for variable-length sequences during inference decode phase. Uses sequence start/end markers to define valid attention ranges without padding overhead.
- Features:
No padding overhead for variable-length sequences
Sliding window attention for local context
Logit soft capping for numerical stability
Attention sink mechanism for long contexts
Multiple platform support (Triton/Pallas/CUDA/XLA)
Configurable block sizes for performance tuning
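As an illustration of the attention sink feature, the sketch below shows one common formulation in which sink logits take part in the softmax normalization but contribute no value vector. Whether `softmax_aux` is used exactly this way in ejkernel is an assumption; the helper name `softmax_with_sinks` is hypothetical.

```python
import numpy as np


def softmax_with_sinks(logits: np.ndarray, sink_logits: np.ndarray) -> np.ndarray:
    """Softmax over KV logits with extra sink logits in the denominator.

    logits: [num_kv]; sink_logits: [num_sinks].
    Returns weights over the KV positions only (they sum to less than 1).
    Hypothetical helper illustrating an assumed sink formulation.
    """
    all_logits = np.concatenate([logits, sink_logits])
    all_logits = all_logits - all_logits.max()  # stabilize exp
    weights = np.exp(all_logits)
    weights /= weights.sum()
    return weights[: logits.shape[0]]  # drop the sink columns


kv_logits = np.array([1.0, 2.0, 3.0])
sinks = np.array([0.0])
weights = softmax_with_sinks(kv_logits, sinks)
```

Under this formulation the sinks absorb part of the probability mass, so a query is not forced to spread all of its attention over the real KV positions.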
- This implementation is particularly efficient for:
Batch inference with varying prompt/generation lengths
Serving workloads requiring dynamic batching
Decoder-only models in generation mode
- candidate_cfgs(inv: Invocation[RaggedDecodeAttentionConfig, Array])
Generate candidate configurations for autotuning.
Creates multiple configurations optimized for different decode scenarios, from small batches with short contexts to larger batches with longer contexts.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of candidate configurations to benchmark during autotuning
Note
Decode attention typically has small query dimensions (batch size), so candidates focus on optimizing block sizes.
- candidate_cfgs_gpu(inv: Invocation[RaggedDecodeAttentionConfig, Array])
GPU/Triton candidates for ragged decode attention (larger blocks and higher warp counts).
Explores kv_blocksize up to 256 (when split_len allows)
Tries blocksize_heads in {4, 8, 16} if grouped-heads permit
Warps up to 8 (depending on kv_block/head_dim)
Stages in {1, 2, 3} (kept low; smem-guarded)
Prefers split_len near {128, 256, 512}; ensures split_len % kv_blocksize == 0
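The constraints listed above can be sketched as a small candidate-grid generator. This is a hypothetical illustration of the filtering idea, not the library's code: the function name, the dictionary keys, the default value ranges, and the divisibility rule used for "grouped-heads permit" are all assumptions.

```python
from itertools import product


def gpu_candidates(
    q_heads_per_kv_head: int,
    split_lens=(128, 256, 512),
    kv_blocksizes=(64, 128, 256),
    heads_blocks=(4, 8, 16),
    warps=(4, 8),
    stages=(1, 2, 3),
):
    """Enumerate candidate configs that satisfy the documented constraints.

    Hypothetical helper; all names and default ranges are illustrative.
    """
    candidates = []
    for split_len, kv_bs, hb, w, st in product(
        split_lens, kv_blocksizes, heads_blocks, warps, stages
    ):
        # Documented constraint: split_len % kv_blocksize == 0.
        if split_len % kv_bs != 0:
            continue
        # Assumed rule: the head block must evenly divide the GQA group size.
        if hb > q_heads_per_kv_head or q_heads_per_kv_head % hb != 0:
            continue
        candidates.append(
            dict(
                split_len=split_len,
                kv_blocksize=kv_bs,
                blocksize_heads=hb,
                num_warps=w,
                num_stages=st,
            )
        )
    return candidates


cands = gpu_candidates(q_heads_per_kv_head=8)
```

Filtering invalid combinations before benchmarking keeps the autotuning search space small, which matters because every remaining candidate has to be compiled and timed.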
- create_shard_map_wrapper(query: Float[jaxlib._jax.Array, 'batch num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: ejkernel.modules.operations.configs.RaggedDecodeAttentionConfig | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None, check_vma: bool = False)
Create a shard_map wrapper for distributed execution.
Creates a wrapper function that applies shard_map to distribute the ragged decode attention computation across devices according to the provided sharding specifications.
- Parameters
query – Query tensor [batch, num_heads, head_dim]
key – Key tensor [batch, seq_len, num_kv_heads, head_dim]
value – Value tensor [batch, seq_len, num_kv_heads, head_dim]
sequence_start – Start indices for valid KV range per sequence [batch]
sequence_end – End indices for valid KV range per sequence [batch]
softmax_scale – Scaling factor for attention scores
sliding_window – Optional (left, right) window sizes for local attention
logits_soft_cap – Optional soft cap to bound attention logits
softmax_aux – Optional attention sink logits
platform – Platform to use for execution
cfg – Configuration for the kernel
mesh – JAX mesh for distributed execution
in_specs – Partition specifications for input tensors
out_specs – Partition specifications for output tensor
check_vma – Whether to enable shard_map's varying-manual-axes (VMA) validity check
- Returns
Tuple of (shard_map_fn, call_args) where shard_map_fn is the wrapped function and call_args are the arguments to pass to it.
- get_impl(cfg: RaggedDecodeAttentionConfig)
Get kernel implementation from registry.
- Parameters
cfg – Configuration specifying platform and backend preferences
- Returns
Callable kernel implementation for ragged decode attention
- Raises
ValueError – If no matching implementation is found for the configuration
- heuristic_cfg(inv: Invocation[RaggedDecodeAttentionConfig, Array]) -> RaggedDecodeAttentionConfig
Provide default configuration optimized for decode attention.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
Default KernelConfig with conservative block sizes suitable for typical decode scenarios (small query sizes, variable KV lengths)
- run(query: Float[jaxlib._jax.Array, 'batch num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: RaggedDecodeAttentionConfig) -> Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']
Execute ragged decode attention with variable-length sequences.
Computes attention for batched queries where each sequence has a different valid key-value range defined by sequence_start and sequence_end markers.
- Parameters
query – Query tensor [batch, num_heads, head_dim] (typically single decode step)
key – Key tensor [batch, seq_len, num_kv_heads, head_dim] (full context)
value – Value tensor [batch, seq_len, num_kv_heads, head_dim] (full context)
sequence_start – Start indices for valid KV range per sequence [batch]
sequence_end – End indices (exclusive) for valid KV range per sequence [batch]
softmax_scale – Optional scaling factor for attention scores (None by default in the signature; documented fallback value: 1.0)
sliding_window – Optional (left, right) window sizes for local attention
logits_soft_cap – Optional soft cap to bound attention logits
softmax_aux – Optional attention sink logits for improved long-context performance
platform – Optional platform override ("triton", "pallas", "cuda", "xla", or "auto")
cfg – Kernel configuration object containing block_size parameter
- Returns
Attention output [total_tokens, num_q_heads, head_dim]
Note
The sequence_start and sequence_end arrays define which KV positions are valid for each query. This enables efficient batching of sequences with different lengths without padding overhead.
Example
>>> sequence_start = jnp.array([0, 50])
>>> sequence_end = jnp.array([50, 150])
>>> out = ragged_decode_attention(q, k, v, sequence_start, sequence_end)
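To show what the start/end arrays in the example above select, here is a minimal NumPy sketch that builds the per-sequence validity mask over KV positions. The padded KV length `seq_len = 160` is a hypothetical value chosen for illustration, and the mask is a conceptual aid rather than something the kernel necessarily materializes.

```python
import numpy as np

sequence_start = np.array([0, 50])
sequence_end = np.array([50, 150])
seq_len = 160  # hypothetical length of the KV axis

positions = np.arange(seq_len)[None, :]  # [1, seq_len]
# Position p is valid for sequence s when start[s] <= p < end[s] (end exclusive).
valid = (positions >= sequence_start[:, None]) & (positions < sequence_end[:, None])
valid_lengths = valid.sum(axis=1)  # [50, 100]
```

Each row of `valid` marks the half-open range `[start, end)` for one sequence; every other KV position is ignored for that sequence's query.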
- ejkernel.modules.operations.ragged_decode_attention.ragged_decode_attention(query: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim'] | jaxtyping.Float[jaxlib._jax.Array, 'batch 1 num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, /, *, softmax_scale: float | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.RaggedDecodeAttentionConfig | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec | None, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None) -> Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']
Execute ragged decode attention with automatic optimization.
Efficiently computes attention for variable-length sequences during the decode phase, using start/end indices to define valid attention ranges without padding overhead.
- Parameters
query – Query tensor [batch, num_heads, head_dim] for current decode step
key – Full key context [batch, seq_len, num_kv_heads, head_dim]
value – Full value context [batch, seq_len, num_kv_heads, head_dim]
sequence_start – Start index of valid KV range per sequence [batch]
sequence_end – End index (exclusive) of valid KV range per sequence [batch]
softmax_scale – Optional attention score scaling factor (None by default in the signature; documented fallback value: 1.0)
sliding_window – Optional (left, right) window sizes for local attention
logits_soft_cap – Optional soft cap for attention logits (improves stability)
softmax_aux – Optional attention sink values for long-context handling
platform – Specific platform to use ("triton", "pallas", "cuda", "xla", or "auto")
cfg – Optional config override (block_size is set via cfg)
- Returns
Attention output [total_tokens, num_q_heads, head_dim]
Example
>>> out = ragged_decode_attention(q, k, v, starts, ends)
>>>
>>> from ejkernel.modules.operations.configs import RaggedDecodeAttentionConfig
>>> cfg = RaggedDecodeAttentionConfig(block_size=128)
>>> out = ragged_decode_attention(
...     q, k, v, starts, ends,
...     sliding_window=(256, 256),
...     cfg=cfg
... )
>>>
>>> out = ragged_decode_attention(
...     q, k, v, starts, ends,
...     logits_soft_cap=50.0,
...     softmax_scale=0.125
... )
>>>
>>> out = ragged_decode_attention(..., platform="triton")
Note
This function is optimized for decode scenarios where query size is small (typically batch_size) and KV length varies per sequence. For prefill phase with large queries, consider using standard flash_attention instead.