ejkernel.modules.operations.page_attention

Page Attention module with automatic optimization.

This module implements Page Attention, a specialized attention mechanism designed for efficient KV cache management in serving and inference workloads. Page Attention organizes the KV cache in fixed-size blocks (pages), enabling:

  • Dynamic memory allocation without pre-allocating for max sequence length

  • Efficient memory sharing across sequences (e.g., for beam search or prefix caching)

  • Reduced memory fragmentation compared to contiguous allocation

  • Better GPU memory utilization through page-level management

Page Attention is particularly valuable for:
  • LLM serving with variable-length sequences

  • Batch inference with dynamic batching

  • Memory-constrained deployment scenarios

  • Systems requiring efficient KV cache sharing

Key Concepts:

  • Pages: Fixed-size blocks, each holding a portion of the KV cache (e.g., 16 or 32 tokens)

  • Block Tables: Mapping from logical sequence positions to physical page indices

  • Context Lengths: Actual sequence lengths (excluding padding)

The paged approach enables:
  • Near-zero memory waste (only last page per sequence may be partially filled)

  • Easy insertion/deletion of sequences without memory reshuffling

  • Natural support for prefix sharing in beam search
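The "only the last page may be partially filled" property can be checked with a little arithmetic. The sketch below is illustrative only: the helper names `pages_needed` and `wasted_slots` and the page size of 16 are assumptions, not part of ejkernel.

```python
import math


def pages_needed(context_len: int, page_size: int) -> int:
    """Number of fixed-size pages required to hold context_len tokens."""
    return math.ceil(context_len / page_size)


def wasted_slots(context_len: int, page_size: int) -> int:
    """Unused token slots, all of which sit in the last page of the sequence."""
    return pages_needed(context_len, page_size) * page_size - context_len


page_size = 16  # assumed page size for illustration
for context_len in (24, 32, 1000):
    print(
        context_len,
        pages_needed(context_len, page_size),
        wasted_slots(context_len, page_size),
    )
# 24 2 8
# 32 2 0
# 1000 63 8
```

The waste per sequence is always smaller than one page, regardless of how long the sequence is.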

Mathematical Foundation:
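For the query of sequence i:

K_i = concat over j of K[block_tables[i, j]], truncated to the first context_lens[i] tokens
V_i = concat over j of V[block_tables[i, j]], truncated to the first context_lens[i] tokens
output[i] = softmax(attn_scale * Q[i] @ K_i.T) @ V_i

The softmax is normalized over all valid tokens of the sequence, not within each page separately. Which pages and tokens are valid is determined by block_tables and context_lens.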
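The computation above can be sketched as a gather-based reference in JAX. This is a minimal, unoptimized illustration and not the ejkernel kernel: the function name `paged_attention_reference` is hypothetical, the cache layout follows the parameter documentation on this page, and the grouped-query head mapping (query head h reads KV head h // group) is an assumption.

```python
import jax
import jax.numpy as jnp


def paged_attention_reference(query, key_cache, value_cache,
                              context_lens, block_tables, attn_scale=None):
    """Reference paged attention (hypothetical helper, for illustration).

    query:        [num_seqs, num_heads, head_dim]
    key_cache:    [num_blocks, num_kv_heads, block_size, head_dim]
    value_cache:  [num_blocks, num_kv_heads, block_size, head_dim]
    context_lens: [num_seqs]
    block_tables: [num_seqs, max_blocks]
    """
    num_seqs, num_heads, head_dim = query.shape
    _, num_kv_heads, block_size, _ = key_cache.shape
    max_blocks = block_tables.shape[1]
    max_tokens = max_blocks * block_size
    if attn_scale is None:
        attn_scale = head_dim ** -0.5

    def gather(cache):
        # [num_seqs, max_blocks, num_kv_heads, block_size, head_dim]
        pages = cache[block_tables]
        # Move heads forward, then flatten (page, slot) into one token axis.
        pages = pages.transpose(0, 2, 1, 3, 4)
        return pages.reshape(num_seqs, num_kv_heads, max_tokens, head_dim)

    k = gather(key_cache)
    v = gather(value_cache)

    # Assumed GQA mapping: each KV head serves `group` consecutive query heads.
    group = num_heads // num_kv_heads
    k = jnp.repeat(k, group, axis=1)
    v = jnp.repeat(v, group, axis=1)

    logits = jnp.einsum("shd,shtd->sht", query, k) * attn_scale

    # Mask token slots at or beyond each sequence's context length.
    valid = jnp.arange(max_tokens)[None, :] < context_lens[:, None]
    logits = jnp.where(valid[:, None, :], logits, jnp.finfo(logits.dtype).min)

    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("sht,shtd->shd", weights, v)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (2, 4, 8))
key_cache = jax.random.normal(kk, (8, 2, 16, 8))
value_cache = jax.random.normal(kv, (8, 2, 16, 8))
block_tables = jnp.array([[3, 7, 0], [1, 5, 0]])
context_lens = jnp.array([32, 24])

out = paged_attention_reference(
    query, key_cache, value_cache, context_lens, block_tables
)
print(out.shape)  # (2, 4, 8)
```

Gathering every page up front costs extra memory, which is exactly what an optimized kernel avoids; the sketch is only useful as a correctness reference.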

Memory Layout:

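Instead of: [seq_len, num_heads, head_dim] (contiguous per sequence)
Use: [num_pages, page_size, num_heads, head_dim] (page-based allocation)

The functions documented on this page take the caches with the head axis before the in-page token axis: [num_blocks, num_kv_heads, block_size, head_dim].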

References

Kwon et al., “Efficient Memory Management for Large Language Model Serving with PagedAttention” https://arxiv.org/abs/2309.06180 (vLLM paper)

class ejkernel.modules.operations.page_attention.PageAttention

Bases: Kernel[PageAttentionConfig, Array]

Page Attention with custom optimization logic.

Efficient attention over paged KV cache for serving workloads. Optimized for dynamic batching with variable context lengths.

Features:
  • Paged KV cache management for memory efficiency

  • Support for variable context lengths per sequence

  • Automatic partitioning for long contexts

  • Multi-split attention for improved throughput

  • Optimized for inference and serving workloads

  • Logit soft capping for numerical stability

  • Configurable pages per compute block

  • TPU megacore mode support
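Of the features listed above, logit soft capping is the easiest to illustrate in isolation. The sketch below uses the tanh form that is commonly used for soft capping; whether ejkernel applies exactly this function for `attn_logits_soft_cap` is an assumption, and `soft_cap` is a hypothetical helper name.

```python
import jax.numpy as jnp


def soft_cap(logits, cap):
    """Smoothly squash logits into the open interval (-cap, cap)."""
    return cap * jnp.tanh(logits / cap)


logits = jnp.array([-200.0, -10.0, 0.0, 10.0, 200.0])
capped = soft_cap(logits, 50.0)
print(capped)
```

Small logits pass through almost unchanged, while very large magnitudes saturate near the cap, which bounds the inputs to the softmax.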

The paged layout provides:
  • O(1) insertion/deletion of sequences

  • Efficient prefix sharing for beam search

  • Minimal memory fragmentation

  • Better batch utilization through dynamic allocation
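The cheap insertion and deletion of sequences comes from managing pages through a free list. The toy allocator below is a sketch under stated assumptions: `PageAllocator` and its methods are hypothetical and not part of ejkernel, which only consumes the resulting block tables.

```python
class PageAllocator:
    """Toy free-list allocator for KV cache pages (hypothetical helper)."""

    def __init__(self, num_pages: int, page_size: int):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))
        self.block_tables = {}  # seq_id -> list of physical page indices

    def add_sequence(self, seq_id, context_len: int) -> None:
        # Ceiling division: pages needed to hold context_len tokens.
        needed = -(-context_len // self.page_size)
        if needed > len(self.free_pages):
            raise MemoryError("not enough free pages")
        # Each pop is constant time; no existing page is moved.
        self.block_tables[seq_id] = [
            self.free_pages.pop() for _ in range(needed)
        ]

    def remove_sequence(self, seq_id) -> None:
        # Return the pages to the free list; other sequences are untouched.
        self.free_pages.extend(self.block_tables.pop(seq_id))


alloc = PageAllocator(num_pages=8, page_size=16)
alloc.add_sequence("a", 32)
alloc.add_sequence("b", 24)
alloc.remove_sequence("a")
alloc.add_sequence("c", 40)
print(alloc.block_tables)  # {'b': [5, 4], 'c': [6, 7, 3]}
```

Sequence "c" reuses the pages freed by "a" plus one more, ending up on non-contiguous pages, while the pages of "b" never move.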

candidate_cfgs(inv: Invocation[PageAttentionConfig, Array])

Generate candidate configurations for autotuning.

Creates configurations optimized for different batch sizes and context lengths commonly seen in serving scenarios.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of candidate configurations to benchmark during autotuning

Note

Page attention doesn’t have tunable block sizes in the traditional sense. The num_splits and pages_per_compute_block are auto-determined.
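To make the role of num_splits concrete: when the context is processed in several splits, each split can return an unnormalized partial result together with its softmax statistics, and the partial results are then merged. The sketch below shows the standard log-sum-exp merge for a single head. That ejkernel merges its splits in exactly this way is an assumption, and `partial_attention` and `merge_splits` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def partial_attention(q, k, v):
    """Attention over one split: returns (max logit, exp-sum, unnormalized out)."""
    logits = k @ q                 # [tokens_in_split]
    m = jnp.max(logits)
    p = jnp.exp(logits - m)        # numerically stable exponentials
    return m, jnp.sum(p), p @ v    # out has shape [head_dim]


def merge_splits(parts):
    """Combine per-split results into the exact full-context softmax output."""
    maxes = jnp.stack([m for m, _, _ in parts])
    sums = jnp.stack([s for _, s, _ in parts])
    outs = jnp.stack([o for _, _, o in parts])
    # Rescale every split to the global maximum before summing.
    scale = jnp.exp(maxes - jnp.max(maxes))
    return (scale[:, None] * outs).sum(axis=0) / (scale * sums).sum()


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (8,))
k = jax.random.normal(kk, (64, 8))
v = jax.random.normal(kv, (64, 8))

num_splits = 4
parts = [
    partial_attention(q, k_part, v_part)
    for k_part, v_part in zip(jnp.split(k, num_splits), jnp.split(v, num_splits))
]
merged = merge_splits(parts)
```

Because the merge is exact, splitting changes only how the work is scheduled, not the result (up to floating-point rounding).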

create_shard_map_wrapper(query: Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_blocks num_kv_heads block_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_blocks num_kv_heads block_size head_dim'], context_lens: Int[jaxlib._jax.Array, 'num_seqs'], block_tables: Int[jaxlib._jax.Array, 'num_seqs max_blocks'], attn_scale: float | None = None, max_context_len: int | None = None, num_splits: int = 0, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: ejkernel.modules.operations.configs.PageAttentionConfig | None = None, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, pages_per_compute_block: int | None = None, megacore_mode: str | None = None, inline_seq_dim: bool = True, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None, check_vma: bool = False)

Create a shard_map wrapper for distributed execution.

Creates a wrapper function that applies shard_map to distribute the page attention computation across devices according to the provided sharding specifications.

Parameters
  • query – Query tensor [num_seqs, num_heads, head_dim]

  • key_cache – Paged key cache [num_blocks, num_kv_heads, block_size, head_dim]

  • value_cache – Paged value cache [num_blocks, num_kv_heads, block_size, head_dim]

  • context_lens – Context length per sequence [num_seqs]

  • block_tables – Block mapping table [num_seqs, max_blocks]

  • attn_scale – Attention scaling factor

  • max_context_len – Maximum context length across all sequences

  • num_splits – Number of splits for partitioned attention

  • platform – Platform to use for execution

  • cfg – Configuration for the kernel

  • mask_value – Value for masked positions

  • attn_logits_soft_cap – Soft cap value for attention logits

  • pages_per_compute_block – Pages per compute block

  • megacore_mode – Megacore execution mode

  • inline_seq_dim – Whether to inline sequence dimension

  • mesh – JAX mesh for distributed execution

  • in_specs – Partition specifications for input tensors

  • out_specs – Partition specifications for output tensor

  • check_vma – Whether to enable shard_map's check_vma validation (JAX's checks on varying manual axes)

Returns

Tuple of (shard_map_fn, call_args) where shard_map_fn is the wrapped function and call_args are the arguments to pass to it.

get_impl(cfg: PageAttentionConfig)

Get kernel implementation from registry.

Parameters

cfg – Configuration specifying platform and backend preferences

Returns

Callable kernel implementation for page attention

Raises

ValueError – If no matching implementation is found for the configuration

heuristic_cfg(inv: Invocation[PageAttentionConfig, Array]) -> PageAttentionConfig

Provide default configuration optimized for paged attention.

Parameters

inv – Invocation object containing arguments and metadata

Returns

Default PageAttentionConfig with settings suitable for typical serving workloads with variable context lengths

run(query: Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_blocks num_kv_heads block_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_blocks num_kv_heads block_size head_dim'], context_lens: Int[jaxlib._jax.Array, 'num_seqs'], block_tables: Int[jaxlib._jax.Array, 'num_seqs max_blocks'], attn_scale: float | None = None, max_context_len: int | None = None, num_splits: int = 0, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: PageAttentionConfig, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, pages_per_compute_block: int | None = None, megacore_mode: str | None = None, inline_seq_dim: bool = True) -> Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim']

Execute page attention over paged KV cache.

Computes attention where the KV cache is organized in fixed-size pages, with each sequence’s tokens potentially scattered across non-contiguous pages.

Parameters
  • query – Query tensor [num_seqs, num_heads, head_dim] for current decode step

  • key_cache – Paged key cache [num_blocks, num_kv_heads, block_size, head_dim]

  • value_cache – Paged value cache [num_blocks, num_kv_heads, block_size, head_dim]

  • context_lens – Actual context length per sequence [num_seqs]

  • block_tables – Page index mapping [num_seqs, max_blocks] where block_tables[i, j] gives the physical page index for sequence i’s jth logical block

  • attn_scale – Attention score scaling factor (default: 1/sqrt(head_dim))

  • max_context_len – Maximum context length across all sequences

  • num_splits – Number of splits for partitioned attention (0 = auto, 1 = no split)

  • mask_value – Value used for masked positions (default: -2.381976426469702e+38, a large finite negative value rather than -inf)

  • attn_logits_soft_cap – Optional soft cap for attention logits

  • pages_per_compute_block – Number of pages to process per compute block

  • megacore_mode – TPU-specific megacore execution mode

  • inline_seq_dim – Whether to inline the sequence dimension

  • platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Kernel configuration object

Returns

Attention output [num_seqs, num_heads, head_dim]

Note

Block tables define the mapping from logical to physical pages:

logical_page_idx = position // block_size
physical_page_idx = block_tables[seq_idx, logical_page_idx]
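The mapping above can be traced by hand with a small lookup. In this sketch the helper `locate_token` is hypothetical, the block size of 16 is an assumed value, and the in-page slot `position % block_size` is an assumption that follows naturally from the floor division.

```python
def locate_token(block_tables, seq_idx, position, block_size):
    """Return (physical_page_idx, slot) for one token of one sequence."""
    logical_page_idx = position // block_size
    slot = position % block_size  # assumed offset of the token inside its page
    physical_page_idx = block_tables[seq_idx][logical_page_idx]
    return physical_page_idx, slot


block_tables = [[3, 7, 0], [1, 5, 0]]
block_size = 16  # assumed for illustration

print(locate_token(block_tables, 0, 20, block_size))  # (7, 4)
print(locate_token(block_tables, 1, 5, block_size))   # (1, 5)
```

Token 20 of sequence 0 falls in logical block 1, which the block table maps to physical page 7.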

Example

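>>> # block_tables[i, j] is the physical page of sequence i's j-th logical block
>>> block_tables = jnp.array([[3, 7, 0], [1, 5, 0]])
>>> # Actual (unpadded) number of context tokens per sequence
>>> context_lens = jnp.array([32, 24])
>>> out = page_attention(q, k_cache, v_cache, context_lens, block_tables)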

ejkernel.modules.operations.page_attention.page_attention(query: Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_blocks num_kv_heads block_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_blocks num_kv_heads block_size head_dim'], context_lens: Int[jaxlib._jax.Array, 'num_seqs'], block_tables: Int[jaxlib._jax.Array, 'num_seqs max_blocks'], /, *, attn_scale: float | None = None, max_context_len: int | None = None, num_splits: int = 0, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, pages_per_compute_block: int | None = None, megacore_mode: str | None = None, inline_seq_dim: bool = True, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.PageAttentionConfig | None = None) -> Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim']

Execute page attention with automatic optimization.

Page attention performs efficient attention computation over paged KV cache for serving and inference workloads with dynamic batching.

Parameters
  • query – Query tensor [num_seqs, num_heads, head_dim]

  • key_cache – Paged key cache [num_blocks, num_kv_heads, block_size, head_dim]

  • value_cache – Paged value cache [num_blocks, num_kv_heads, block_size, head_dim]

  • context_lens – Context length per sequence [num_seqs]

  • block_tables – Block mapping table [num_seqs, max_blocks]

  • attn_scale – Attention scaling factor

  • max_context_len – Maximum context length across all sequences

  • num_splits – Number of splits for partitioned attention (0=auto)

  • mask_value – Value for masked positions (default: -2.381976426469702e+38, a large finite negative value rather than -inf)

  • attn_logits_soft_cap – Soft cap value for attention logits

  • pages_per_compute_block – Pages per compute block

  • megacore_mode – Megacore execution mode

  • inline_seq_dim – Whether to inline sequence dimension

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Optional configuration override

Returns

Attention output [num_seqs, num_heads, head_dim]

Example

>>> # Basic call with default settings
>>> out = page_attention(query, key_cache, value_cache, context_lens, block_tables)
>>>
>>> # Explicit split count and maximum context length
>>> out = page_attention(
...     query, key_cache, value_cache, context_lens, block_tables,
...     num_splits=4, max_context_len=8192
... )
>>>
>>> # Soft-capped attention logits
>>> out = page_attention(
...     query, key_cache, value_cache, context_lens, block_tables,
...     attn_logits_soft_cap=50.0
... )
>>>
>>> # Force a specific platform
>>> out = page_attention(..., platform="triton")