ejkernel.modules.operations.ragged_page_attention_v2
Ragged Page Attention module with automatic optimization.
This module implements ragged page attention, combining the benefits of both ragged (variable-length) sequence processing and paged KV cache management. This approach is particularly efficient for serving scenarios where sequences have variable lengths and KV cache is organized in fixed-size pages.
- Ragged page attention addresses key challenges in LLM inference:
Variable-length sequences without padding overhead
Efficient memory management through paged KV cache
Dynamic batching with different sequence lengths
Memory sharing for beam search and prefix caching
- Key Concepts:
Ragged Layout: Sequences are concatenated without padding, with start locations tracking where each sequence begins
Pages: Fixed-size blocks holding portions of KV cache
Block Tables: Mapping from logical sequence positions to physical pages
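As a minimal sketch of the ragged layout, the start locations can be derived from per-sequence query lengths with a cumulative sum. NumPy stands in for jax.numpy here, and the lengths and variable names other than query_start_loc are illustrative assumptions, not part of the ejkernel API.

```python
import numpy as np

# Hypothetical per-sequence query lengths for a batch of three requests.
query_lens = np.array([10, 15, 3])

# Cumulative offsets: sequence s owns rows query_start_loc[s]:query_start_loc[s + 1].
query_start_loc = np.concatenate([[0], np.cumsum(query_lens)])

# Map every flattened token index back to the sequence that owns it.
token_ids = np.arange(query_start_loc[-1])
seq_of_token = np.searchsorted(query_start_loc, token_ids, side="right") - 1

# A padded layout would allocate num_seqs * max_len rows instead.
padded_rows = len(query_lens) * int(query_lens.max())
ragged_rows = int(query_start_loc[-1])
```

Comparing ragged_rows with padded_rows shows how many query rows the ragged layout avoids allocating for this batch.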
- The combination provides:
Zero padding overhead (ragged layout)
Flexible memory allocation (paged cache)
Efficient batching of variable-length sequences
Support for dynamic sequence management
- Memory Layout:
Queries: [total_tokens, num_heads, head_dim] (ragged, no padding)
KV Cache: [num_pages, page_size, num_heads, head_dim] (paged)
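The paged layout can be illustrated by gathering one sequence's cached keys back into logical order. This is a sketch under stated assumptions: it uses a key-only cache named k_pages (the real kernel takes a combined kv_pages array), NumPy instead of jax.numpy, and a hypothetical helper gather_context.

```python
import numpy as np

num_pages, page_size, num_heads, head_dim = 8, 4, 2, 16
rng = np.random.default_rng(0)
k_pages = rng.standard_normal((num_pages, page_size, num_heads, head_dim)).astype(np.float32)

# Sequence 0 holds 10 cached tokens spread over physical pages 5, 2 and 7.
block_tables = np.array([[5, 2, 7]])
context_lens = np.array([10])

def gather_context(pages, block_row, context_len):
    """Return the first `context_len` cached tokens of one sequence, in logical order."""
    page_size = pages.shape[1]
    n_pages = -(-int(context_len) // page_size)   # ceil division
    gathered = pages[block_row[:n_pages]]         # [n_pages, page_size, heads, dim]
    flat = gathered.reshape(n_pages * page_size, *pages.shape[2:])
    return flat[:context_len]                     # drop unused slots in the last page

k_seq0 = gather_context(k_pages, block_tables[0], context_lens[0])
```

Logical token t of a sequence lives in physical page block_tables[s, t // page_size] at slot t % page_size, which is why physical pages do not need to be contiguous.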
- Mathematical Foundation:
- For token i in sequence s:
start_idx = query_start_loc[s]
end_idx = query_start_loc[s + 1]
output[i] = attention(Q[start_idx:end_idx], K[pages[s]], V[pages[s]])
Because neither the queries nor the cached keys and values are padded to a maximum length, this is the most memory-efficient attention variant for serving workloads.
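The indexing above can be written out as a slow NumPy reference. This is a sketch, not the kernel: it assumes separate k_pages and v_pages arrays instead of the combined kv_pages, equal numbers of query and KV heads, causal masking, and that each sequence's queries are the last tokens of its context. The function name ragged_paged_attention_ref is hypothetical.

```python
import numpy as np

def ragged_paged_attention_ref(q, k_pages, v_pages, context_lens, block_tables, query_start_loc):
    """Slow reference: causal attention per sequence over pages gathered via block_tables."""
    page_size = k_pages.shape[1]
    out = np.zeros_like(q)
    scale = 1.0 / np.sqrt(q.shape[-1])
    for s in range(len(context_lens)):
        start, end = int(query_start_loc[s]), int(query_start_loc[s + 1])
        ctx = int(context_lens[s])
        n_pages = -(-ctx // page_size)                              # ceil division
        pages = block_tables[s, :n_pages]
        k = k_pages[pages].reshape(-1, *k_pages.shape[2:])[:ctx]    # [ctx, heads, dim]
        v = v_pages[pages].reshape(-1, *v_pages.shape[2:])[:ctx]
        logits = np.einsum("qhd,khd->hqk", q[start:end], k) * scale
        # Assumption: the queries are the last (end - start) tokens of the context.
        q_pos = ctx - (end - start) + np.arange(end - start)
        causal = np.arange(ctx)[None, :] <= q_pos[:, None]
        logits = np.where(causal[None], logits, -np.inf)
        probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
        probs /= probs.sum(axis=-1, keepdims=True)
        out[start:end] = np.einsum("hqk,khd->qhd", probs, v)
    return out

rng = np.random.default_rng(0)
q = rng.standard_normal((5, 2, 8)).astype(np.float32)       # two sequences: 3 + 2 tokens
k_pages = rng.standard_normal((6, 4, 2, 8)).astype(np.float32)
v_pages = rng.standard_normal((6, 4, 2, 8)).astype(np.float32)
context_lens = np.array([6, 2])
block_tables = np.array([[3, 0], [5, 1]])
query_start_loc = np.array([0, 3, 5])
out = ragged_paged_attention_ref(q, k_pages, v_pages, context_lens, block_tables, query_start_loc)
```

A reference like this is mainly useful for checking a fused kernel's output on small inputs; it materializes the gathered keys and values, which the paged kernel is designed to avoid.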
- class ejkernel.modules.operations.ragged_page_attention_v2.RaggedPageAttentionv2
Bases: Kernel[RaggedPageAttentionv2Config, Array]
Ragged Page Attention with custom optimization logic.
Combines ragged (variable-length) sequence processing with paged KV cache management for maximum memory efficiency in serving workloads.
- Features:
Zero padding overhead through ragged layout
Efficient paged KV cache management
Support for variable context lengths per sequence
Sliding window attention for long contexts
Logit soft capping for numerical stability
Attention sink mechanism for improved long-context performance
Configurable block sizes and memory limits
Multiple platform support (Triton/Pallas/CUDA/XLA)
- This implementation is particularly efficient for:
LLM serving with dynamic batching
Variable-length inference workloads
Memory-constrained deployment
Scenarios requiring efficient KV cache sharing
The ragged layout eliminates padding overhead while paged cache enables flexible memory management and sharing.
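Two of the listed features, logit soft capping and sliding window attention, are commonly formulated as below. This is a sketch of the usual definitions, not ejkernel's code: the tanh form of the cap and the exact window boundary are assumptions, and both helper names are hypothetical.

```python
import numpy as np

def soft_cap(logits, cap):
    """Smoothly bound logits to the range [-cap, cap] using tanh."""
    return cap * np.tanh(logits / cap)

def sliding_window_mask(q_pos, kv_len, window):
    """True where a key is causal and among the last `window` positions up to the query."""
    kv_pos = np.arange(kv_len)[None, :]
    q_pos = np.asarray(q_pos)[:, None]
    return (kv_pos <= q_pos) & (kv_pos > q_pos - window)

capped = soft_cap(np.array([1000.0, -1000.0, 0.0]), cap=50.0)
mask = sliding_window_mask(q_pos=[5], kv_len=8, window=3)
```

Soft capping keeps extreme logits from saturating the softmax, while the window mask limits each query to a fixed number of recent keys regardless of context length.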
- candidate_cfgs(inv: Invocation[RaggedPageAttentionv2Config, Array])
Generate candidate configurations for autotuning.
Creates configurations optimized for ragged attention scenarios with various batch sizes and sequence lengths.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of candidate configurations to benchmark during autotuning
Note
Ragged attention performance depends on the distribution of sequence lengths and the page size. Candidates are chosen to work well across common serving scenarios.
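One way the sequence-length distribution and page size interact is through partially filled pages. The sketch below, with a hypothetical helper paged_cache_slots and made-up context lengths, counts allocated versus used cache slots; it illustrates memory utilization only and says nothing about kernel timing.

```python
import numpy as np

def paged_cache_slots(context_lens, page_size):
    """Return (allocated_slots, used_slots) when each sequence rounds up to whole pages."""
    context_lens = np.asarray(context_lens)
    pages = -(-context_lens // page_size)      # ceil division per sequence
    return int(pages.sum() * page_size), int(context_lens.sum())

context_lens = [17, 130, 512, 5]
for page_size in (16, 64, 128):
    allocated, used = paged_cache_slots(context_lens, page_size)
    print(page_size, allocated, used, f"{used / allocated:.0%} utilized")
```

Smaller pages waste fewer slots on short sequences but require more pages per sequence, which is one reason several candidates are worth benchmarking.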
- candidate_cfgs_gpu(inv: Invocation[RaggedPageAttentionv2Config, Array])
Generate candidate configurations for autotuning on GPU (Triton).
Heuristics:
- candidate_cfgs_shard_map_gpu(inv: Invocation[RaggedPageAttentionv2Config, Array])
Generate candidate configurations for autotuning on GPU (Triton).
Heuristics:
- candidate_cfgs_shard_map_tpu(inv: Invocation[RaggedPageAttentionv2Config, Array])
Generate candidate configurations for autotuning on TPU (Pallas backend).
Heuristics:
For small head_dim, larger BLOCK_M is fine (64-128).
For large head_dim (>=160), prefer smaller BLOCK_M (32-64).
More KV pages per block helps small page_size (<=32).
Constrain S_block = page_size * num_kv_pages_per_block <= 256 to keep tiles reasonable.
- candidate_cfgs_tpu(inv: Invocation[RaggedPageAttentionv2Config, Array])
Generate candidate configurations for autotuning on TPU (Pallas backend).
Heuristics:
For small head_dim, larger BLOCK_M is fine (64-128).
For large head_dim (>=160), prefer smaller BLOCK_M (32-64).
More KV pages per block helps small page_size (<=32).
Constrain S_block = page_size * num_kv_pages_per_block <= 256 to keep tiles reasonable.
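The heuristics above could be turned into a candidate list roughly as follows. This is an illustrative sketch only: the helper tpu_candidates, the specific pages-per-block choices, and the mapping of BLOCK_M to num_queries_per_block are assumptions, not the library's actual candidate generator.

```python
from itertools import product

def tpu_candidates(head_dim, page_size, max_s_block=256):
    """Hypothetical helper: enumerate block-size candidates that satisfy the heuristics."""
    # Large head_dim prefers smaller query blocks.
    block_ms = (32, 64) if head_dim >= 160 else (64, 128)
    # Small pages benefit from processing more pages per block (illustrative values).
    pages_per_block = (4, 8, 16) if page_size <= 32 else (1, 2, 4)
    return [
        {"num_queries_per_block": bm, "num_kv_pages_per_block": n}
        for bm, n in product(block_ms, pages_per_block)
        if page_size * n <= max_s_block          # S_block constraint
    ]

small_pages = tpu_candidates(head_dim=128, page_size=16)
large_pages = tpu_candidates(head_dim=256, page_size=128)
```

The S_block filter drops combinations such as page_size=128 with 4 pages per block, since that tile would cover 512 KV positions.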
- create_shard_map_wrapper(queries: jaxtyping.Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], kv_pages: jaxtyping.Float[jaxlib._jax.Array, 'num_pages page_size num_combined_kv_heads head_dim'], context_lens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs'], block_tables: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs pages_per_seq'], query_start_loc: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'], num_seqs: jax.jaxlib._jax.Array | int, softmax_scale: float | None = None, logits_soft_cap: float | None = None, compute_dtype: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType] = jax.numpy.bfloat16, optimized: bool = False, sliding_window: int | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, mask_value: float | None = None, vmem_limit_bytes: int | None = None, platform: typing.Optional[typing.Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.RaggedPageAttentionv2Config | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None, check_vma: bool = False)
Create a shard_map wrapper specifically for ragged page attention.
Ragged page attention handles variable-length sequences with paged KV cache, ideal for serving scenarios.
- Parameters
queries – Flattened queries [total_tokens, num_q_heads, head_dim]
kv_pages – Paged KV cache [num_pages, page_size, num_combined_kv_heads, head_dim]
context_lens – Context lengths [num_seqs]
block_tables – Block mapping [num_seqs, pages_per_seq]
query_start_loc – Start locations [num_seqs + 1]
num_seqs – Number of sequences
All other args – Remaining ragged page attention parameters, which are fixed (bound) in the wrapper
mesh – JAX device mesh
in_specs – Input partition specs (for queries, kv_pages, context_lens, block_tables, query_start_loc, num_seqs, softmax_aux)
out_specs – Output partition spec
- Returns
Tuple of (shard_map_fn, call_args)
- get_impl(cfg: RaggedPageAttentionv2Config)
Get kernel implementation from registry.
- Parameters
cfg – Configuration specifying platform and backend preferences
- Returns
Callable kernel implementation for ragged page attention
- Raises
ValueError – If no matching implementation is found for the configuration
- heuristic_cfg(inv: Invocation[RaggedPageAttentionv2Config, Array]) -> RaggedPageAttentionv2Config
Provide default configuration optimized for ragged page attention.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
Default configuration with conservative block sizes suitable for typical ragged attention workloads with variable sequence lengths
- run(queries: jaxtyping.Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], kv_pages: jaxtyping.Float[jaxlib._jax.Array, 'num_pages page_size num_combined_kv_heads head_dim'], context_lens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs'], block_tables: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs pages_per_seq'], query_start_loc: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'], num_seqs: jax.jaxlib._jax.Array | int, platform: typing.Optional[typing.Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, softmax_scale: float | None = None, logits_soft_cap: float | None = None, compute_dtype: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType] = jax.numpy.bfloat16, optimized: bool = False, sliding_window: int | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, mask_value: float | None = None, vmem_limit_bytes: int | None = None, *, cfg: ejkernel.modules.operations.configs.RaggedPageAttentionv2Config) -> Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']
Execute ragged page attention over variable-length sequences.
Computes attention where queries are in ragged (concatenated) format and KV cache is organized in pages, providing maximum memory efficiency.
- Parameters
queries – Ragged query tensor [total_tokens, num_q_heads, head_dim]; all sequences are concatenated without padding
kv_pages – Paged KV cache [num_pages, page_size, num_combined_kv_heads, head_dim]; combined key-value cache in page format
context_lens – Actual context length per sequence [num_seqs]
block_tables – Page mapping [num_seqs, pages_per_seq], mapping logical pages to physical page indices
query_start_loc – Start indices for each sequence in queries [num_seqs + 1]; query_start_loc[i] to query_start_loc[i+1] defines sequence i
num_seqs – Number of sequences in the batch
softmax_scale – Scaling factor for attention scores (default: 1/sqrt(head_dim))
logits_soft_cap – Optional soft cap to bound attention logits
compute_dtype – Data type for computation (default: bfloat16)
optimized – Use optimized kernel implementation
sliding_window – Window size for local attention (None for full attention)
softmax_aux – Optional attention sink logits for long-context handling
mask_value – Value to use for masked positions (default: -inf)
vmem_limit_bytes – Memory limit for vector memory in bytes (TPU-specific)
platform – Optional platform override ("triton", "pallas", "cuda", "xla", or "auto")
cfg – Kernel configuration object containing num_kv_pages_per_block and num_queries_per_block
- Returns
Attention output [total_tokens, num_q_heads, head_dim] in ragged format
Note
The ragged format eliminates all padding overhead. Combined with paged KV cache, this provides the most memory-efficient attention implementation for serving workloads with variable-length sequences.
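The shape relationships among the arguments can be checked before calling the kernel. The helper below is a hypothetical sketch, not part of ejkernel; it assumes that arrays may be padded beyond num_seqs and that only the first num_seqs entries are meaningful.

```python
import numpy as np

def check_ragged_inputs(queries, kv_pages, context_lens, block_tables, query_start_loc, num_seqs):
    """Hypothetical sanity checks; raises AssertionError on inconsistent inputs."""
    total_tokens = queries.shape[0]
    num_pages, page_size = kv_pages.shape[:2]
    assert query_start_loc.shape == (context_lens.shape[0] + 1,)
    assert num_seqs <= context_lens.shape[0]
    starts = query_start_loc[: num_seqs + 1]
    assert starts[0] == 0 and np.all(np.diff(starts) >= 0)   # offsets are non-decreasing
    assert starts[-1] <= total_tokens
    assert queries.shape[-1] == kv_pages.shape[-1]            # matching head_dim
    pages_needed = -(-context_lens[:num_seqs] // page_size)   # ceil division
    assert np.all(pages_needed <= block_tables.shape[1])
    for s in range(num_seqs):
        used = block_tables[s, : pages_needed[s]]
        assert np.all((0 <= used) & (used < num_pages))       # valid physical page ids

queries = np.zeros((25, 4, 8), dtype=np.float32)
kv_pages = np.zeros((6, 16, 8, 8), dtype=np.float32)
context_lens = np.array([20, 40])
block_tables = np.array([[0, 1, 2], [3, 4, 5]])
query_start_loc = np.array([0, 10, 25])
check_ragged_inputs(queries, kv_pages, context_lens, block_tables, query_start_loc, num_seqs=2)
```

Checks like these run on host-side NumPy arrays; inside a jitted function the same conditions would have to be expressed differently.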
Example
>>> # Two sequences: tokens 0-9 and 10-24 (num_seqs is positional-only)
>>> query_start_loc = jnp.array([0, 10, 25])
>>> out = ragged_page_attention_v2(
...     queries, kv_pages, context_lens,
...     block_tables, query_start_loc, 2
... )
- ejkernel.modules.operations.ragged_page_attention_v2.ragged_page_attention_v2(queries: jaxtyping.Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], kv_pages: jaxtyping.Float[jaxlib._jax.Array, 'num_pages page_size num_combined_kv_heads head_dim'], context_lens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs'], block_tables: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs pages_per_seq'], query_start_loc: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'], num_seqs: jax.jaxlib._jax.Array | int, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, /, *, softmax_scale: float | None = None, logits_soft_cap: float | None = None, compute_dtype: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType] = jax.numpy.bfloat16, optimized: bool = False, sliding_window: int | None = None, mask_value: float | None = None, vmem_limit_bytes: int | None = None, platform: typing.Optional[typing.Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.RaggedPageAttentionv2Config | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec | None, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None) -> Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']
Execute ragged page attention with automatic optimization.
Ragged page attention efficiently handles variable-length sequences in a single batch using flattened token representation and page-based KV cache.
- Parameters
queries – Flattened query tensor [total_tokens, num_q_heads, head_dim]
kv_pages – Paged KV cache [num_pages, page_size, num_combined_kv_heads, head_dim]
context_lens – Context length per sequence [num_seqs]
block_tables – Block mapping table [num_seqs, pages_per_seq]
query_start_loc – Start locations for each sequence [num_seqs + 1]
num_seqs – Number of sequences in the batch
softmax_scale – Softmax scaling factor
logits_soft_cap – Soft capping value for logits
compute_dtype – Computation dtype (default: bfloat16)
optimized – Use optimized implementation
sliding_window – Sliding window size for local attention
softmax_aux – Attention sink logits
mask_value – Value for masked positions
vmem_limit_bytes – Vector memory limit in bytes (TPU-specific)
platform – Specific platform to use ("triton", "pallas", "cuda", "xla", or "auto")
cfg – Optional config override (num_kv_pages_per_block and num_queries_per_block are set via cfg)
mesh – JAX device mesh for shard_map execution (optional)
in_specs – Input partition specs for shard_map (optional)
out_specs – Output partition spec for shard_map (optional)
- Returns
Attention output [total_tokens, num_q_heads, head_dim]
Example
>>> # Basic usage
>>> out = ragged_page_attention_v2(
...     queries, kv_pages, context_lens, block_tables,
...     query_start_loc, num_seqs
... )
>>>
>>> # Sliding window attention
>>> out = ragged_page_attention_v2(
...     queries, kv_pages, context_lens, block_tables,
...     query_start_loc, num_seqs, sliding_window=256
... )
>>>
>>> # Optimized implementation with logit soft capping
>>> out = ragged_page_attention_v2(
...     queries, kv_pages, context_lens, block_tables,
...     query_start_loc, num_seqs, optimized=True, logits_soft_cap=50.0
... )
>>>
>>> # Select a specific platform
>>> out = ragged_page_attention_v2(..., platform="triton")
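The softmax_aux parameter carries per-head attention sink logits. One common formulation, sketched below in NumPy, lets the sink logit join the softmax denominator without contributing a value vector. Whether ejkernel applies the sink exactly this way is an assumption, and softmax_with_sink is a hypothetical helper.

```python
import numpy as np

def softmax_with_sink(logits, sink_logit):
    """Softmax over keys where a per-head sink logit is added to the denominator only.

    logits: [num_heads, q_len, kv_len]; sink_logit: [num_heads].
    """
    sink = sink_logit[:, None, None]
    m = np.maximum(logits.max(axis=-1, keepdims=True), sink)   # shared max for stability
    exp_logits = np.exp(logits - m)
    denom = exp_logits.sum(axis=-1, keepdims=True) + np.exp(sink - m)
    return exp_logits / denom

# Head 0 has a sink logit of 0; head 1 has the sink disabled (-inf).
probs = softmax_with_sink(np.zeros((2, 3, 4)), np.array([0.0, -np.inf]))
```

With a sink, attention weights over real keys sum to less than one, so a query can effectively attend to "nothing" when no key is relevant.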