ejkernel.modules.operations.ragged_decode_attention

Ragged Decode Attention module with automatic optimization.

This module implements ragged decode attention, an attention mechanism optimized for inference with variable-length sequences in the decode phase. Unlike standard attention, which requires padded sequences, ragged attention handles sequences of different lengths efficiently by using per-sequence start/end markers.

Ragged decode attention is particularly valuable for:
  • Inference workloads with batched sequences of varying lengths

  • Decoder-only models during generation

  • Serving scenarios requiring efficient batching

  • Situations where padding overhead is significant

The key idea is to use sequence_start and sequence_end arrays to define the valid attention range of each sequence, which removes the need for padding while keeping the computation vectorized.
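As a minimal illustration, the start/end markers can be turned into a per-sequence validity mask over the KV axis. This is a plain-JAX sketch (the helper name ragged_kv_mask is hypothetical, not part of ejkernel), assuming end is exclusive, as the parameter docs below state, and that the indices address each sequence's own row of the KV tensors:

```python
import jax.numpy as jnp


def ragged_kv_mask(sequence_start, sequence_end, seq_len):
    """Hypothetical helper: boolean [batch, seq_len] mask, True where start[b] <= pos < end[b]."""
    pos = jnp.arange(seq_len)[None, :]   # [1, seq_len] KV positions
    start = sequence_start[:, None]      # [batch, 1]
    end = sequence_end[:, None]          # [batch, 1], exclusive
    return (pos >= start) & (pos < end)


# Two sequences, each with a KV buffer of length 6 but different valid ranges.
mask = ragged_kv_mask(jnp.array([0, 2]), jnp.array([3, 6]), seq_len=6)
```

Positions outside the mask are simply excluded from the softmax, so no padding tokens ever need to be materialized.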

Key Features:
  • Efficient variable-length sequence handling without padding

  • Support for sliding window attention for long contexts

  • Optional logit soft capping for numerical stability

  • Attention sink support for improved long-context performance

  • Configurable block sizes for memory-compute tradeoffs
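Logit soft capping is commonly implemented as cap * tanh(logits / cap), which is close to the identity for small logits and saturates smoothly at ±cap. The sketch below assumes ejkernel uses this tanh form (the formula is not spelled out here), and soft_cap is a hypothetical helper:

```python
import jax.numpy as jnp


def soft_cap(logits, cap):
    """Hypothetical helper: smoothly bound logits to (-cap, cap) with a scaled tanh."""
    if cap is None:
        return logits  # soft capping disabled
    return cap * jnp.tanh(logits / cap)


x = jnp.array([-1000.0, -1.0, 0.0, 1.0, 1000.0])
y = soft_cap(x, 50.0)  # extreme logits are pulled to about +/-50, small ones barely change
```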

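Mathematical Foundation:
During decode, each sequence s contributes a single query token. For each sequence s and query head h:

output[s, h] = softmax(softmax_scale * Q[s, h] @ K[s, start[s]:end[s], g(h)].T) @ V[s, start[s]:end[s], g(h)]

Where start[s] (inclusive) and end[s] (exclusive) define the valid KV range for sequence s, and g(h) is the KV head read by query head h (g(h) = h when num_kv_heads == num_heads).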
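The formula above can be written as a small pure-JAX reference, which is handy for checking a kernel's output on small inputs. It is a sketch, not ejkernel's implementation: the name ragged_decode_attention_ref is hypothetical, query head h is assumed to read KV head h // (num_heads // num_kv_heads), sliding windows, soft capping and sinks are ignored, and the logits are multiplied by softmax_scale (consistent with the 0.125 = 1/sqrt(64) value used in the examples below), with 1.0 as the default to match the parameter docs:

```python
import jax
import jax.numpy as jnp


def ragged_decode_attention_ref(query, key, value, sequence_start, sequence_end,
                                softmax_scale=1.0):
    """Hypothetical reference: one query token per sequence.

    query:      [batch, num_heads, head_dim]
    key, value: [batch, seq_len, num_kv_heads, head_dim]
    sequence_*: [batch]; KV positions start <= pos < end are attended to.
    Assumes start < end for every sequence (an empty range would yield NaNs).
    """
    num_heads, num_kv_heads = query.shape[1], key.shape[2]
    group = num_heads // num_kv_heads
    # Grouped-query attention: repeat KV heads so query head h sees KV head h // group.
    k = jnp.repeat(key, group, axis=2)
    v = jnp.repeat(value, group, axis=2)
    logits = jnp.einsum("bhd,bshd->bhs", query, k) * softmax_scale
    pos = jnp.arange(key.shape[1])
    valid = (pos[None, :] >= sequence_start[:, None]) & (pos[None, :] < sequence_end[:, None])
    logits = jnp.where(valid[:, None, :], logits, -jnp.inf)  # drop out-of-range keys
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhs,bshd->bhd", probs, v)


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (2, 4, 8))      # 4 query heads
k = jax.random.normal(key_k, (2, 10, 2, 8))  # 2 KV heads, so group size 2
v = jax.random.normal(key_v, (2, 10, 2, 8))
starts, ends = jnp.array([0, 3]), jnp.array([5, 10])
out = ragged_decode_attention_ref(q, k, v, starts, ends, softmax_scale=8 ** -0.5)
```

Slicing sequence 1 down to its valid range 3:10 and attending with a full mask gives the same result, which is exactly the equivalence the ragged formulation relies on.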

class ejkernel.modules.operations.ragged_decode_attention.RaggedDecodeAttention

Bases: Kernel[RaggedDecodeAttentionConfig, Array]

Ragged Decode Attention with custom optimization logic.

Implements efficient attention for variable-length sequences during inference decode phase. Uses sequence start/end markers to define valid attention ranges without padding overhead.

Features:
  • No padding overhead for variable-length sequences

  • Sliding window attention for local context

  • Logit soft capping for numerical stability

  • Attention sink mechanism for long contexts

  • Multiple platform support (Triton/Pallas/CUDA/XLA)

  • Configurable block sizes for performance tuning
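How (left, right) is measured during decode is not pinned down here. A common convention, assumed in this sketch, is that the decode query sits at position end[s] - 1 and may attend to keys in [q_pos - left, q_pos + right], intersected with the ragged range. Under that assumption, right has no effect at decode time because nothing past end[s] is valid. decode_window_mask is a hypothetical helper:

```python
import jax.numpy as jnp


def decode_window_mask(sequence_start, sequence_end, seq_len, sliding_window=None):
    """Hypothetical helper: ragged range mask, optionally narrowed by a (left, right) window."""
    pos = jnp.arange(seq_len)[None, :]
    start = sequence_start[:, None]
    end = sequence_end[:, None]
    mask = (pos >= start) & (pos < end)
    if sliding_window is not None:
        left, right = sliding_window
        q_pos = end - 1  # assumption: the decode query is the last valid token
        mask = mask & (pos >= q_pos - left) & (pos <= q_pos + right)
    return mask


# q_pos = 7, so a left window of 2 keeps positions 5, 6 and 7.
m = decode_window_mask(jnp.array([0]), jnp.array([8]), seq_len=10, sliding_window=(2, 0))
```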

This implementation is particularly efficient for:
  • Batch inference with varying prompt/generation lengths

  • Serving workloads requiring dynamic batching

  • Decoder-only models in generation mode

candidate_cfgs(inv: Invocation[RaggedDecodeAttentionConfig, Array])

Generate candidate configurations for autotuning.

Creates multiple configurations optimized for different decode scenarios, from small batches with short contexts to larger batches with longer contexts.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of candidate configurations to benchmark during autotuning

Note

Decode attention typically has small query dimensions (batch size), so candidates focus on optimizing block sizes.

candidate_cfgs_gpu(inv: Invocation[RaggedDecodeAttentionConfig, Array])

Generate GPU/Triton candidate configurations for ragged decode attention (larger blocks and more warps).

  • Explores kv_blocksize up to 256 (when split_len allows)

  • Tries blocksize_heads in {4, 8, 16} if grouped-heads permit

  • Up to 8 warps (depending on the KV block size and head_dim)

  • Pipeline stages in {1, 2, 3} (kept low to stay within shared-memory limits)

  • Prefers split_len near {128, 256, 512}; ensures split_len % kv_blocksize == 0
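The candidate rules listed above can be sketched as a plain-Python enumeration. Everything here is hypothetical: gpu_candidates is not an ejkernel function, the dictionary keys reuse the names from the list (split_len, kv_blocksize, blocksize_heads) plus assumed num_warps/num_stages fields, and the shared-memory budget, the 8-warp threshold and the fallback head block are illustrative values only:

```python
from itertools import product

SMEM_BUDGET_BYTES = 100 * 1024  # hypothetical shared-memory budget per program


def gpu_candidates(seq_len, num_heads, num_kv_heads, head_dim, dtype_bytes=2):
    """Hypothetical enumeration following the GPU candidate rules above."""
    group = num_heads // num_kv_heads  # query heads sharing one KV head
    # Prefer split lengths near 128/256/512; fall back to seq_len for short contexts.
    splits = [s for s in (128, 256, 512) if s <= seq_len] or [seq_len]
    # Head blocks of 4/8/16 only when they evenly divide the group (hypothetical fallback: 1).
    head_blocks = [h for h in (4, 8, 16) if group % h == 0] or [1]
    candidates = []
    for split_len, kv_bs, bh, stages in product(splits, (32, 64, 128, 256), head_blocks, (1, 2, 3)):
        if kv_bs > split_len or split_len % kv_bs != 0:
            continue  # split_len must be a multiple of kv_blocksize
        # Rough smem estimate: one K tile and one V tile per pipeline stage.
        if stages * 2 * kv_bs * head_dim * dtype_bytes > SMEM_BUDGET_BYTES:
            continue
        # Allow 8 warps only for larger tiles (illustrative threshold).
        for warps in ((4, 8) if kv_bs * head_dim >= 128 * 64 else (4,)):
            candidates.append(
                dict(split_len=split_len, kv_blocksize=kv_bs, blocksize_heads=bh,
                     num_warps=warps, num_stages=stages)
            )
    return candidates


cands = gpu_candidates(seq_len=4096, num_heads=32, num_kv_heads=8, head_dim=128)
```

Filtering invalid combinations before benchmarking keeps the autotuning sweep small, which matters because each candidate has to be compiled and timed.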

create_shard_map_wrapper(query: Float[jaxlib._jax.Array, 'batch num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: ejkernel.modules.operations.configs.RaggedDecodeAttentionConfig | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None, check_vma: bool = False)

Create a shard_map wrapper for distributed execution.

Creates a wrapper function that applies shard_map to distribute the ragged decode attention computation across devices according to the provided sharding specifications.

Parameters
  • query – Query tensor [batch, num_heads, head_dim]

  • key – Key tensor [batch, seq_len, num_kv_heads, head_dim]

  • value – Value tensor [batch, seq_len, num_kv_heads, head_dim]

  • sequence_start – Start indices for valid KV range per sequence [batch]

  • sequence_end – End indices for valid KV range per sequence [batch]

  • softmax_scale – Scaling factor for attention scores

  • sliding_window – Optional (left, right) window sizes for local attention

  • logits_soft_cap – Optional soft cap to bound attention logits

  • softmax_aux – Optional attention sink logits

  • platform – Platform to use for execution

  • cfg – Configuration for the kernel

  • mesh – JAX mesh for distributed execution

  • in_specs – Partition specifications for input tensors

  • out_specs – Partition specifications for output tensor

  • check_vma – Whether to enable shard_map's varying-manual-axes (VMA) validity checks

Returns

Tuple of (shard_map_fn, call_args) where shard_map_fn is the wrapped function and call_args are the arguments to pass to it.
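Because each sequence attends only to its own KV range, splitting the batch across devices needs no cross-device communication. The sketch below shows that idea with plain JAX shard_map; it does not reproduce create_shard_map_wrapper, and it assumes batch-only sharding over a single mesh axis named "dp" (an illustrative choice), plain multi-head attention for brevity, and a batch size divisible by the device count:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX keeps it under jax.experimental
    from jax.experimental.shard_map import shard_map


def local_decode_attention(q, k, v, start, end):
    """Masked single-token attention for one batch shard (plain MHA for brevity)."""
    pos = jnp.arange(k.shape[1])
    valid = (pos[None, :] >= start[:, None]) & (pos[None, :] < end[:, None])
    logits = jnp.einsum("bhd,bshd->bhs", q, k) * q.shape[-1] ** -0.5
    logits = jnp.where(valid[:, None, :], logits, -jnp.inf)
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhs,bshd->bhd", probs, v)


# One mesh axis over all local devices; only the leading (batch) dimension is split.
mesh = Mesh(np.array(jax.devices()), ("dp",))
batch_spec = P("dp")
sharded_decode_attention = shard_map(
    local_decode_attention,
    mesh=mesh,
    in_specs=(batch_spec,) * 5,
    out_specs=batch_spec,
)

n_dev = len(jax.devices())
batch, heads, seq_len, head_dim = 2 * n_dev, 4, 16, 8
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (batch, heads, head_dim))
k = jax.random.normal(key_k, (batch, seq_len, heads, head_dim))
v = jax.random.normal(key_v, (batch, seq_len, heads, head_dim))
starts = jnp.zeros((batch,), jnp.int32)
ends = jnp.full((batch,), seq_len, jnp.int32).at[0].set(5)
out = sharded_decode_attention(q, k, v, starts, ends)
```

Sharding the head axis as well would also be communication-free, as long as each KV head stays on the same device as the query heads that read it.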

get_impl(cfg: RaggedDecodeAttentionConfig)

Get kernel implementation from registry.

Parameters

cfg – Configuration specifying platform and backend preferences

Returns

Callable kernel implementation for ragged decode attention

Raises

ValueError – If no matching implementation is found for the configuration
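A kernel registry of this kind is often a dictionary keyed by operation and platform. The sketch below is hypothetical (register, _REGISTRY and this get_impl signature are illustrative, not ejkernel's internals) and only shows the lookup-or-ValueError behavior described above:

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry mapping (operation, platform) -> implementation.
_REGISTRY: Dict[Tuple[str, str], Callable] = {}


def register(op: str, platform: str):
    """Decorator that records an implementation under (op, platform)."""
    def decorator(fn: Callable) -> Callable:
        _REGISTRY[(op, platform)] = fn
        return fn
    return decorator


@register("ragged_decode_attention", "xla")
def _ragged_decode_attention_xla(*args, **kwargs):
    return "xla implementation"  # placeholder body


def get_impl(op: str, platform: str) -> Callable:
    """Return the registered implementation or raise ValueError."""
    try:
        return _REGISTRY[(op, platform)]
    except KeyError:
        available = sorted(p for (o, p) in _REGISTRY if o == op)
        raise ValueError(
            f"no {op!r} implementation for platform {platform!r}; available: {available}"
        ) from None
```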

heuristic_cfg(inv: Invocation[RaggedDecodeAttentionConfig, Array]) → RaggedDecodeAttentionConfig

Provide default configuration optimized for decode attention.

Parameters

inv – Invocation object containing arguments and metadata

Returns

Default RaggedDecodeAttentionConfig with conservative block sizes suitable for typical decode scenarios (small query sizes, variable KV lengths)

run(query: Float[jaxlib._jax.Array, 'batch num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: RaggedDecodeAttentionConfig) → Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']

Execute ragged decode attention with variable-length sequences.

Computes attention for batched queries where each sequence has a different valid key-value range defined by sequence_start and sequence_end markers.

Parameters
  • query – Query tensor [batch, num_heads, head_dim] (typically a single decode step)

  • key – Key tensor [batch, seq_len, num_kv_heads, head_dim] (full context)

  • value – Value tensor [batch, seq_len, num_kv_heads, head_dim] (full context)

  • sequence_start – Start indices for valid KV range per sequence [batch]

  • sequence_end – End indices (exclusive) for valid KV range per sequence [batch]

  • softmax_scale – Scaling factor for attention scores (default: 1.0)

  • sliding_window – Optional (left, right) window sizes for local attention

  • logits_soft_cap – Optional soft cap to bound attention logits

  • softmax_aux – Optional attention sink logits for improved long-context performance

  • platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Kernel configuration object containing block_size parameter

Returns

Attention output [total_tokens, num_q_heads, head_dim]

Note

The sequence_start and sequence_end arrays define which KV positions are valid for each query. This enables efficient batching of sequences with different lengths without padding overhead.

Example

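>>> import jax.numpy as jnp
>>> from ejkernel.modules.operations.ragged_decode_attention import ragged_decode_attention
>>> q = jnp.ones((2, 8, 128), dtype=jnp.bfloat16)            # [batch, num_heads, head_dim]
>>> k = v = jnp.ones((2, 256, 8, 128), dtype=jnp.bfloat16)  # [batch, seq_len, num_kv_heads, head_dim]
>>> sequence_start = jnp.array([0, 50])
>>> sequence_end = jnp.array([50, 150])
>>> out = ragged_decode_attention(q, k, v, sequence_start, sequence_end)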
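Attention sinks (softmax_aux) are commonly implemented as extra logits that take part in the softmax normalization but have no value vector, so they absorb probability mass instead of contributing to the output. The sketch assumes that formulation and that the num_sinks logits are shared by every query head; neither detail is stated above, and softmax_with_sinks is a hypothetical helper:

```python
import jax
import jax.numpy as jnp


def softmax_with_sinks(logits, sink_logits):
    """Hypothetical helper: softmax over the last axis with extra value-less sink logits.

    logits: [..., seq_len]; sink_logits: [num_sinks], shared by all rows (assumption).
    """
    sinks = jnp.broadcast_to(sink_logits, logits.shape[:-1] + sink_logits.shape)
    full = jnp.concatenate([logits, sinks], axis=-1)
    probs = jax.nn.softmax(full, axis=-1)
    return probs[..., : logits.shape[-1]]  # drop the sink columns; their mass is discarded


# Four equal real logits plus one sink at the same value: each real key gets 1/5.
p = softmax_with_sinks(jnp.zeros((1, 2, 4)), jnp.array([0.0]))
```

If every real key is masked to -inf, the sink still receives all of the mass, so the row stays finite instead of producing NaNs.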
ejkernel.modules.operations.ragged_decode_attention.ragged_decode_attention(query: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim'] | jaxtyping.Float[jaxlib._jax.Array, 'batch 1 num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, /, *, softmax_scale: float | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.RaggedDecodeAttentionConfig | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec | None, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None) → Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']

Execute ragged decode attention with automatic optimization.

Efficiently computes attention for variable-length sequences during the decode phase, using start/end indices to define valid attention ranges without padding overhead.

Parameters
  • query – Query tensor [batch, num_heads, head_dim] (or [batch, 1, num_heads, head_dim]) for the current decode step

  • key – Full key context [batch, seq_len, num_kv_heads, head_dim]

  • value – Full value context [batch, seq_len, num_kv_heads, head_dim]

  • sequence_start – Start index of valid KV range per sequence [batch]

  • sequence_end – End index (exclusive) of valid KV range per sequence [batch]

  • softmax_scale – Attention score scaling factor (default: 1.0)

  • sliding_window – Optional (left, right) window sizes for local attention

  • logits_soft_cap – Optional soft cap for attention logits (improves stability)

  • softmax_aux – Optional attention sink values for long-context handling

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Optional config override (block_size is set via cfg)

Returns

Attention output [total_tokens, num_q_heads, head_dim]

Example

>>> # Basic usage
>>> out = ragged_decode_attention(q, k, v, starts, ends)
>>> # Sliding-window attention with an explicit config
>>> from ejkernel.modules.operations.configs import RaggedDecodeAttentionConfig
>>> cfg = RaggedDecodeAttentionConfig(block_size=128)
>>> out = ragged_decode_attention(
...     q, k, v, starts, ends,
...     sliding_window=(256, 256),
...     cfg=cfg
... )
>>> # Logit soft capping with a custom softmax scale
>>> out = ragged_decode_attention(
...     q, k, v, starts, ends,
...     logits_soft_cap=50.0,
...     softmax_scale=0.125
... )
>>> # Force a specific backend
>>> out = ragged_decode_attention(..., platform="triton")

Note

This function is optimized for decode, where each sequence contributes a single query token (so the query tensor is small, on the order of batch_size) and the valid KV length varies per sequence. For the prefill phase, where queries span many tokens, consider using flash_attention instead.