ejkernel.modules.operations.ragged_page_attention_v2

Ragged Page Attention module with automatic optimization.

This module implements ragged page attention, which combines ragged (variable-length) sequence processing with paged KV cache management. The approach is particularly efficient for serving scenarios where sequences have variable lengths and the KV cache is organized in fixed-size pages.

Ragged page attention addresses key challenges in LLM inference:
  • Variable-length sequences without padding overhead

  • Efficient memory management through paged KV cache

  • Dynamic batching with different sequence lengths

  • Memory sharing for beam search and prefix caching

Key Concepts:
  • Ragged Layout: Sequences are concatenated without padding, with start locations tracking where each sequence begins

  • Pages: Fixed-size blocks holding portions of the KV cache

  • Block Tables: Mapping from logical sequence positions to physical pages
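To make the ragged layout and the block-table lookup concrete, here is a minimal pure-Python sketch. The helper names `build_query_start_loc` and `locate_kv_slot` are hypothetical and not part of ejkernel. The sketch assumes only that start locations are a prefix sum of per-sequence query lengths and that each block-table row lists physical page indices in logical order.

```python
from itertools import accumulate


def build_query_start_loc(query_lens: list[int]) -> list[int]:
    """Prefix sum of query lengths; entry s is where sequence s starts."""
    return [0, *accumulate(query_lens)]


def locate_kv_slot(block_table_row: list[int], position: int, page_size: int) -> tuple[int, int]:
    """Map a logical KV position to (physical page index, offset within the page)."""
    logical_page, offset = divmod(position, page_size)
    return block_table_row[logical_page], offset


# Two sequences with 10 and 15 query tokens share one flat token axis.
query_start_loc = build_query_start_loc([10, 15])

# Position 37 with 16-token pages falls in logical page 2 at offset 5;
# the (made-up) block table row maps logical page 2 to physical page 9.
slot = locate_kv_slot([7, 2, 9], position=37, page_size=16)
```

Because the translation goes through the block table, two sequences can point at the same physical page, which is what makes prefix sharing possible.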

The combination provides:
  • Zero padding overhead (ragged layout)

  • Flexible memory allocation (paged cache)

  • Efficient batching of variable-length sequences

  • Support for dynamic sequence management
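A small arithmetic sketch of what the two properties buy. The sequence lengths and page size below are made-up illustration values, not library defaults.

```python
import math

seq_lens = [10, 15, 200]  # tokens per sequence (illustrative values)
page_size = 16            # tokens per KV page (illustrative value)

# A dense [batch, max_len] layout pads every sequence to the longest one.
padded_tokens = len(seq_lens) * max(seq_lens)

# The ragged layout stores exactly the tokens that exist.
ragged_tokens = sum(seq_lens)

# A paged cache allocates whole pages, so waste is bounded by one page per sequence.
pages_needed = [math.ceil(n / page_size) for n in seq_lens]
```

In this example the padded layout holds 600 token slots against 225 in the ragged one, and the paged cache needs 1, 1, and 13 pages for the three sequences.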

Memory Layout:

  • Queries: [total_tokens, num_heads, head_dim] (ragged, no padding)

  • KV Cache: [num_pages, page_size, num_heads, head_dim] (paged)

Mathematical Foundation:
For token i in sequence s:

    start_idx = query_start_loc[s]
    end_idx = query_start_loc[s + 1]
    output[i] = attention(Q[start_idx:end_idx], K[pages[s]], V[pages[s]])
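The formula above can be spelled out as a loop-based NumPy reference. This is a sketch under stated assumptions, not the ejkernel kernel: keys and values live in separate page arrays, query and KV head counts are equal, each sequence's queries are the last tokens of its context with causal masking, and every context length is at least the query length. The function name is hypothetical.

```python
import numpy as np


def ragged_paged_attention_reference(q, k_pages, v_pages, context_lens, block_tables, query_start_loc):
    """q: [total_tokens, heads, dim]; k_pages / v_pages: [num_pages, page_size, heads, dim]."""
    out = np.zeros_like(q)
    scale = 1.0 / np.sqrt(q.shape[-1])
    page_size = k_pages.shape[1]
    for s, kv_len in enumerate(context_lens):
        start, end = query_start_loc[s], query_start_loc[s + 1]
        q_len = end - start
        num_pages = -(-kv_len // page_size)  # ceil division
        pages = block_tables[s, :num_pages]
        # Gather this sequence's pages and flatten them into one contiguous context.
        k = k_pages[pages].reshape(-1, *k_pages.shape[2:])[:kv_len]
        v = v_pages[pages].reshape(-1, *v_pages.shape[2:])[:kv_len]
        logits = np.einsum("qhd,khd->hqk", q[start:end], k) * scale
        # Assumed causal rule: query j sits at context position kv_len - q_len + j.
        q_pos = (kv_len - q_len) + np.arange(q_len)[:, None]
        k_pos = np.arange(kv_len)[None, :]
        logits = np.where(k_pos <= q_pos, logits, -np.inf)
        weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        out[start:end] = np.einsum("hqk,khd->qhd", weights, v)
    return out


rng = np.random.default_rng(0)
q = rng.standard_normal((5, 2, 4)).astype(np.float32)  # 2 + 3 query tokens
k_pages = rng.standard_normal((6, 4, 2, 4)).astype(np.float32)
v_pages = rng.standard_normal((6, 4, 2, 4)).astype(np.float32)
context_lens = np.array([6, 3])
block_tables = np.array([[5, 1], [3, 0]])  # unused trailing entries are ignored
query_start_loc = np.array([0, 2, 5])
out = ragged_paged_attention_reference(q, k_pages, v_pages, context_lens, block_tables, query_start_loc)
```

A reference like this is slow but useful for checking a fused kernel on small inputs.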

This layout minimizes memory overhead for serving workloads: queries carry no padding, and KV memory is allocated one page at a time.

class ejkernel.modules.operations.ragged_page_attention_v2.RaggedPageAttentionv2

Bases: Kernel[RaggedPageAttentionv2Config, Array]

Ragged Page Attention with custom optimization logic.

Combines ragged (variable-length) sequence processing with paged KV cache management for maximum memory efficiency in serving workloads.

Features:
  • Zero padding overhead through ragged layout

  • Efficient paged KV cache management

  • Support for variable context lengths per sequence

  • Sliding window attention for long contexts

  • Logit soft capping for numerical stability

  • Attention sink mechanism for improved long-context performance

  • Configurable block sizes and memory limits

  • Multiple platform support (Triton/Pallas/CUDA/XLA)

This implementation is particularly efficient for:
  • LLM serving with dynamic batching

  • Variable-length inference workloads

  • Memory-constrained deployment

  • Scenarios requiring efficient KV cache sharing

The ragged layout eliminates padding overhead while paged cache enables flexible memory management and sharing.

candidate_cfgs(inv: Invocation[RaggedPageAttentionv2Config, Array])

Generate candidate configurations for autotuning.

Creates configurations optimized for ragged attention scenarios with various batch sizes and sequence lengths.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of candidate configurations to benchmark during autotuning

Note

Ragged attention performance depends on the distribution of sequence lengths and the page size. Candidates are chosen to work well across common serving scenarios.

candidate_cfgs_gpu(inv: Invocation[RaggedPageAttentionv2Config, Array])

Generate candidate configurations for autotuning on GPU (Triton).

Heuristics:

candidate_cfgs_shard_map_gpu(inv: Invocation[RaggedPageAttentionv2Config, Array])

Generate candidate configurations for autotuning on GPU (Triton).

Heuristics:

candidate_cfgs_shard_map_tpu(inv: Invocation[RaggedPageAttentionv2Config, Array])

Generate candidate configurations for autotuning on TPU (Pallas backend).

Heuristics:
  • For small head_dim, larger BLOCK_M is fine (64-128).

  • For large head_dim (>=160), prefer smaller BLOCK_M (32-64).

  • More KV pages per block helps small page_size (<=32).

  • Constrain S_block = page_size * num_kv_pages_per_block <= 256 to keep tiles reasonable.

candidate_cfgs_tpu(inv: Invocation[RaggedPageAttentionv2Config, Array])

Generate candidate configurations for autotuning on TPU (Pallas backend).

Heuristics:
  • For small head_dim, larger BLOCK_M is fine (64-128).

  • For large head_dim (>=160), prefer smaller BLOCK_M (32-64).

  • More KV pages per block helps small page_size (<=32).

  • Constrain S_block = page_size * num_kv_pages_per_block <= 256 to keep tiles reasonable.
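The TPU heuristics can be read as a small filter over a candidate grid. The sketch below is one possible reading, not the library's actual candidate list: the function name, the dictionary keys, and the specific page-count options are assumptions chosen for illustration.

```python
def tpu_candidates(head_dim: int, page_size: int) -> list[dict[str, int]]:
    """Enumerate (BLOCK_M, pages-per-block) pairs that satisfy the stated heuristics."""
    # Large heads use more memory per row, so prefer smaller query blocks.
    block_ms = (32, 64) if head_dim >= 160 else (64, 128)
    # Small pages benefit from processing more pages per block (option values are assumed).
    page_counts = (4, 8, 16) if page_size <= 32 else (1, 2, 4)
    candidates = []
    for block_m in block_ms:
        for num_pages in page_counts:
            # Keep the KV tile S_block = page_size * num_kv_pages_per_block within 256.
            if page_size * num_pages <= 256:
                candidates.append({"block_m": block_m, "num_kv_pages_per_block": num_pages})
    return candidates


small_pages = tpu_candidates(head_dim=128, page_size=16)
large_pages = tpu_candidates(head_dim=256, page_size=128)
```

With these assumed grids, 16-token pages keep all six combinations, while 128-token pages drop the 4-pages-per-block option because it would exceed the 256-token tile limit.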

create_shard_map_wrapper(queries: jaxtyping.Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], kv_pages: jaxtyping.Float[jaxlib._jax.Array, 'num_pages page_size num_combined_kv_heads head_dim'], context_lens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs'], block_tables: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs pages_per_seq'], query_start_loc: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'], num_seqs: jax.jaxlib._jax.Array | int, softmax_scale: float | None = None, logits_soft_cap: float | None = None, compute_dtype: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType] = jax.numpy.bfloat16, optimized: bool = False, sliding_window: int | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, mask_value: float | None = None, vmem_limit_bytes: int | None = None, platform: typing.Optional[typing.Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.RaggedPageAttentionv2Config | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None, check_vma: bool = False)

Create a shard_map wrapper specifically for ragged page attention.

Ragged page attention handles variable-length sequences with paged KV cache, ideal for serving scenarios.

Parameters
  • queries – Flattened queries [total_tokens, num_q_heads, head_dim]

  • kv_pages – Paged KV cache [num_pages, page_size, num_combined_kv_heads, head_dim]

  • context_lens – Context lengths [num_seqs]

  • block_tables – Block mapping [num_seqs, pages_per_seq]

  • query_start_loc – Start locations [num_seqs + 1]

  • num_seqs – Number of sequences

  • All other args – Ragged page attention parameters that are held fixed (bound) in the wrapper

  • mesh – JAX device mesh

  • in_specs – Input partition specs (for queries, kv_pages, context_lens, block_tables, query_start_loc, num_seqs, softmax_aux)

  • out_specs – Output partition spec

Returns

Tuple of (shard_map_fn, call_args)

get_impl(cfg: RaggedPageAttentionv2Config)

Get kernel implementation from registry.

Parameters

cfg – Configuration specifying platform and backend preferences

Returns

Callable kernel implementation for ragged page attention

Raises

ValueError – If no matching implementation is found for the configuration

heuristic_cfg(inv: Invocation[RaggedPageAttentionv2Config, Array]) → RaggedPageAttentionv2Config

Provide default configuration optimized for ragged page attention.

Parameters

inv – Invocation object containing arguments and metadata

Returns

Default configuration with conservative block sizes suitable for typical ragged attention workloads with variable sequence lengths

run(queries: jaxtyping.Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], kv_pages: jaxtyping.Float[jaxlib._jax.Array, 'num_pages page_size num_combined_kv_heads head_dim'], context_lens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs'], block_tables: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs pages_per_seq'], query_start_loc: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'], num_seqs: jax.jaxlib._jax.Array | int, platform: typing.Optional[typing.Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, softmax_scale: float | None = None, logits_soft_cap: float | None = None, compute_dtype: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType] = jax.numpy.bfloat16, optimized: bool = False, sliding_window: int | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, mask_value: float | None = None, vmem_limit_bytes: int | None = None, *, cfg: ejkernel.modules.operations.configs.RaggedPageAttentionv2Config) → Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']

Execute ragged page attention over variable-length sequences.

Computes attention where queries are in ragged (concatenated) format and KV cache is organized in pages, providing maximum memory efficiency.

Parameters
  • queries – Ragged query tensor [total_tokens, num_q_heads, head_dim]; all sequences are concatenated without padding

  • kv_pages – Paged KV cache [num_pages, page_size, num_combined_kv_heads, head_dim]; combined key-value cache in page format

  • context_lens – Actual context length per sequence [num_seqs]

  • block_tables – Page mapping [num_seqs, pages_per_seq] from logical pages to physical page indices

  • query_start_loc – Start indices for each sequence in queries [num_seqs + 1]; the range query_start_loc[i] to query_start_loc[i+1] defines sequence i

  • num_seqs – Number of sequences in the batch

  • softmax_scale – Scaling factor for attention scores (default: 1/sqrt(head_dim))

  • logits_soft_cap – Optional soft cap to bound attention logits

  • compute_dtype – Data type for computation (default: bfloat16)

  • optimized – Use optimized kernel implementation

  • sliding_window – Window size for local attention (None for full attention)

  • softmax_aux – Optional attention sink logits for long-context handling

  • mask_value – Value to use for masked positions (default: -inf)

  • vmem_limit_bytes – Memory limit for vector memory in bytes (TPU-specific)

  • platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Kernel configuration object containing num_kv_pages_per_block and num_queries_per_block

Returns

Attention output [total_tokens, num_q_heads, head_dim] in ragged format

Note

The ragged format eliminates all padding overhead. Combined with paged KV cache, this provides the most memory-efficient attention implementation for serving workloads with variable-length sequences.
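The logits_soft_cap and sliding_window parameters can be illustrated with two small NumPy helpers. These are sketches under assumptions: the soft cap uses the common `cap * tanh(logits / cap)` form, and the window keeps the `window` most recent keys up to and including the query position. The exact conventions inside the ejkernel kernels are not stated in this page, and both helper names are hypothetical.

```python
import numpy as np


def soft_cap(logits: np.ndarray, cap: float) -> np.ndarray:
    """Smoothly bound logits to [-cap, cap] (assumed form: cap * tanh(logits / cap))."""
    return cap * np.tanh(logits / cap)


def sliding_window_mask(q_len: int, kv_len: int, window: int) -> np.ndarray:
    """Boolean [q_len, kv_len] mask: causal, with at most `window` recent keys visible."""
    # Queries are assumed to be the last q_len tokens of the context.
    q_pos = (kv_len - q_len) + np.arange(q_len)[:, None]
    k_pos = np.arange(kv_len)[None, :]
    return (k_pos <= q_pos) & (k_pos > q_pos - window)


capped = soft_cap(np.array([-1000.0, 0.0, 1000.0]), cap=50.0)
mask = sliding_window_mask(q_len=2, kv_len=6, window=3)
```

Positions where the mask is False would receive mask_value before the softmax, so each query attends to at most three keys in this example.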

Example

>>> import jax.numpy as jnp
>>> # Two sequences holding 10 and 15 query tokens
>>> query_start_loc = jnp.array([0, 10, 25])
>>> out = ragged_page_attention_v2(
...     queries, kv_pages, context_lens,
...     block_tables, query_start_loc, num_seqs=2
... )

ejkernel.modules.operations.ragged_page_attention_v2.ragged_page_attention_v2(queries: jaxtyping.Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], kv_pages: jaxtyping.Float[jaxlib._jax.Array, 'num_pages page_size num_combined_kv_heads head_dim'], context_lens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs'], block_tables: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs pages_per_seq'], query_start_loc: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'], num_seqs: jax.jaxlib._jax.Array | int, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, /, *, softmax_scale: float | None = None, logits_soft_cap: float | None = None, compute_dtype: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType] = jax.numpy.bfloat16, optimized: bool = False, sliding_window: int | None = None, mask_value: float | None = None, vmem_limit_bytes: int | None = None, platform: typing.Optional[typing.Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.RaggedPageAttentionv2Config | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec | None, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None) → Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']

Execute ragged page attention with automatic optimization.

Ragged page attention efficiently handles variable-length sequences in a single batch using flattened token representation and page-based KV cache.

Parameters
  • queries – Flattened query tensor [total_tokens, num_q_heads, head_dim]

  • kv_pages – Paged KV cache [num_pages, page_size, num_combined_kv_heads, head_dim]

  • context_lens – Context length per sequence [num_seqs]

  • block_tables – Block mapping table [num_seqs, pages_per_seq]

  • query_start_loc – Start locations for each sequence [num_seqs + 1]

  • num_seqs – Number of sequences in the batch

  • softmax_scale – Softmax scaling factor

  • logits_soft_cap – Soft capping value for logits

  • compute_dtype – Computation dtype (default: bfloat16)

  • optimized – Use optimized implementation

  • sliding_window – Sliding window size for local attention

  • softmax_aux – Attention sink logits

  • mask_value – Value for masked positions

  • vmem_limit_bytes – Vector memory (VMEM) limit in bytes (TPU-specific)

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Optional config override (num_kv_pages_per_block and num_queries_per_block are set via cfg)

  • mesh – JAX device mesh for shard_map execution (optional)

  • in_specs – Input partition specs for shard_map (optional)

  • out_specs – Output partition spec for shard_map (optional)

Returns

Attention output [total_tokens, num_q_heads, head_dim]
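For experimentation, shape-consistent dummy inputs for the parameters listed above can be assembled with NumPy and converted with `jnp.asarray` before the call. This is a sketch under assumptions: all sizes are made up, float32 is used for simplicity, and the "combined" KV head axis is assumed to hold keys and values together as `2 * num_kv_heads`, which this page does not confirm.

```python
import numpy as np

num_q_heads, num_kv_heads, head_dim = 8, 2, 64    # illustrative sizes
page_size, num_pages, pages_per_seq = 16, 32, 4   # illustrative sizes
query_lens = np.array([10, 15])                    # query tokens per sequence
context_lens = np.array([40, 15], dtype=np.int32)  # KV tokens per sequence

# Start locations are the prefix sum of the query lengths, with a leading 0.
query_start_loc = np.concatenate([[0], np.cumsum(query_lens)]).astype(np.int32)
total_tokens = int(query_start_loc[-1])
num_seqs = len(query_lens)

rng = np.random.default_rng(0)
queries = rng.standard_normal((total_tokens, num_q_heads, head_dim)).astype(np.float32)
# Assumption: keys and values share the head axis, giving 2 * num_kv_heads combined heads.
kv_pages = rng.standard_normal((num_pages, page_size, 2 * num_kv_heads, head_dim)).astype(np.float32)

# Each row lists the physical pages of one sequence; unused trailing entries stay 0.
block_tables = np.zeros((num_seqs, pages_per_seq), dtype=np.int32)
block_tables[0, :3] = [4, 9, 1]  # 40 tokens need 3 pages of 16
block_tables[1, :1] = [7]        # 15 tokens need 1 page
```

Every context length must fit in `pages_per_seq * page_size` tokens, which is 64 here.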

Example

>>> # Basic usage
>>> out = ragged_page_attention_v2(
...     queries, kv_pages, context_lens, block_tables,
...     query_start_loc, num_seqs
... )
>>>
>>> # Sliding-window (local) attention
>>> out = ragged_page_attention_v2(
...     queries, kv_pages, context_lens, block_tables,
...     query_start_loc, num_seqs, sliding_window=256
... )
>>>
>>> # Optimized kernel with logit soft capping
>>> out = ragged_page_attention_v2(
...     queries, kv_pages, context_lens, block_tables,
...     query_start_loc, num_seqs, optimized=True, logits_soft_cap=50.0
... )
>>>
>>> # Explicit platform selection
>>> out = ragged_page_attention_v2(..., platform="triton")