ejkernel.modules.operations.page_attention
Page Attention module with automatic optimization.
This module implements Page Attention, a specialized attention mechanism designed for efficient KV cache management in serving and inference workloads. Page Attention organizes the KV cache in fixed-size blocks (pages), enabling:
Dynamic memory allocation without pre-allocating for max sequence length
Efficient memory sharing across sequences (e.g., for beam search or prefix caching)
Reduced memory fragmentation compared to contiguous allocation
Better GPU memory utilization through page-level management
- Page Attention is particularly valuable for:
LLM serving with variable-length sequences
Batch inference with dynamic batching
Memory-constrained deployment scenarios
Systems requiring efficient KV cache sharing
- Key Concepts:
Pages: Fixed-size blocks holding a portion of KV cache (e.g., 16 or 32 tokens)
Block Tables: Mapping from logical sequence positions to physical page indices
Context Lengths: Actual sequence lengths (excluding padding)
- The paged approach enables:
Near-zero memory waste (only last page per sequence may be partially filled)
Easy insertion/deletion of sequences without memory reshuffling
Natural support for prefix sharing in beam search
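The allocation properties listed above can be illustrated with a toy free-list allocator. This is a minimal sketch, not ejkernel code: the `PageAllocator` class and its methods are hypothetical names introduced only for illustration.

```python
import math


class PageAllocator:
    """Toy free-list allocator for fixed-size KV cache pages (illustrative only)."""

    def __init__(self, num_pages: int, page_size: int):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))  # physical page indices not in use
        self.block_tables = {}  # seq_id -> list of physical page indices
        self.context_lens = {}  # seq_id -> number of valid tokens

    def add_sequence(self, seq_id: int, num_tokens: int) -> None:
        """Reserve just enough pages to hold num_tokens tokens."""
        needed = math.ceil(num_tokens / self.page_size)
        if needed > len(self.free_pages):
            raise MemoryError("not enough free pages")
        self.block_tables[seq_id] = [self.free_pages.pop() for _ in range(needed)]
        self.context_lens[seq_id] = num_tokens

    def remove_sequence(self, seq_id: int) -> None:
        """Return a sequence's pages to the free list; no other data moves."""
        self.free_pages.extend(self.block_tables.pop(seq_id))
        del self.context_lens[seq_id]

    def wasted_slots(self) -> int:
        """Unused token slots, which can only occur in each sequence's last page."""
        return sum(
            len(pages) * self.page_size - self.context_lens[seq_id]
            for seq_id, pages in self.block_tables.items()
        )


alloc = PageAllocator(num_pages=8, page_size=16)
alloc.add_sequence(seq_id=0, num_tokens=32)  # fills exactly 2 pages
alloc.add_sequence(seq_id=1, num_tokens=24)  # 2 pages, last one half empty
wasted_before = alloc.wasted_slots()
alloc.remove_sequence(0)                     # pages go back to the free list
free_after = len(alloc.free_pages)
```

Removing a sequence only edits the free list and the block table, which is why no cache contents need to be reshuffled.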
- Mathematical Foundation:
- For query position i:
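output[i] = softmax(attn_scale * Q[i] @ K_ctx[i].T) @ V_ctx[i]
Where K_ctx[i] and V_ctx[i] are the keys and values gathered from the pages that block_tables assigns to sequence i, truncated to context_lens[i] tokens. The softmax is normalized once over all valid tokens, not separately per page.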
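The computation can be written as a naive NumPy reference that gathers each sequence's pages and applies ordinary softmax attention. This is a sketch for checking understanding, not the ejkernel kernel: `paged_attention_reference` is a hypothetical name, the cache layout follows the `[num_blocks, num_kv_heads, block_size, head_dim]` shape used by the functions documented on this page, and the mapping of query heads to KV heads by integer division is an assumption.

```python
import numpy as np


def paged_attention_reference(query, key_cache, value_cache, context_lens, block_tables):
    """Naive paged attention: gather each sequence's pages, then run softmax attention."""
    num_seqs, num_heads, head_dim = query.shape
    _, num_kv_heads, block_size, _ = key_cache.shape
    group = num_heads // num_kv_heads  # query heads per KV head (assumed grouping)
    scale = 1.0 / np.sqrt(head_dim)
    out = np.zeros_like(query)
    for s in range(num_seqs):
        n = int(context_lens[s])           # assumed to be at least 1
        num_pages = -(-n // block_size)    # ceil division
        pages = block_tables[s, :num_pages]
        # [pages, kv_heads, block, dim] -> [kv_heads, pages * block, dim], then drop padding
        k = key_cache[pages].transpose(1, 0, 2, 3).reshape(num_kv_heads, -1, head_dim)[:, :n]
        v = value_cache[pages].transpose(1, 0, 2, 3).reshape(num_kv_heads, -1, head_dim)[:, :n]
        for h in range(num_heads):
            logits = scale * (k[h // group] @ query[s, h])  # shape [n]
            weights = np.exp(logits - logits.max())
            weights /= weights.sum()       # one softmax over all valid tokens
            out[s, h] = weights @ v[h // group]
    return out


rng = np.random.default_rng(0)
key_cache = rng.standard_normal((6, 2, 4, 8))    # 6 pages, 2 KV heads, block_size 4
value_cache = rng.standard_normal((6, 2, 4, 8))
query = rng.standard_normal((2, 4, 8))           # 2 sequences, 4 query heads
block_tables = np.array([[3, 5, 0], [1, 4, 0]])  # trailing 0 entries are padding
context_lens = np.array([8, 6])
out = paged_attention_reference(query, key_cache, value_cache, context_lens, block_tables)
```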
- Memory Layout:
Instead of: [seq_len, num_heads, head_dim] (contiguous per sequence)
Use: [num_pages, page_size, num_heads, head_dim] (page-based allocation)
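The following sketch copies a contiguous per-sequence array into the page-based layout described above and reads it back. The helper names `pack_into_pages` and `unpack_from_pages` are hypothetical. Note that the functions documented on this page take caches as `[num_blocks, num_kv_heads, block_size, head_dim]`, so a real caller would need the corresponding axis order.

```python
import numpy as np


def pack_into_pages(kv, page_size, free_pages, cache):
    """Copy a contiguous [seq_len, num_heads, head_dim] array into free pages of cache.

    cache has the layout [num_pages, page_size, num_heads, head_dim].
    Returns the physical page indices used (one row of a block table).
    """
    pages = []
    for start in range(0, kv.shape[0], page_size):
        page = free_pages.pop(0)
        chunk = kv[start:start + page_size]
        cache[page, :chunk.shape[0]] = chunk  # the last page may be partially filled
        pages.append(page)
    return pages


def unpack_from_pages(cache, pages, seq_len):
    """Rebuild the contiguous view by concatenating pages and trimming padding."""
    num_heads, head_dim = cache.shape[2:]
    return cache[pages].reshape(-1, num_heads, head_dim)[:seq_len]


rng = np.random.default_rng(0)
kv = rng.standard_normal((10, 2, 4))  # seq_len=10, num_heads=2, head_dim=4
cache = np.zeros((8, 4, 2, 4))        # num_pages=8, page_size=4
pages = pack_into_pages(kv, page_size=4, free_pages=[6, 2, 5, 0], cache=cache)
restored = unpack_from_pages(cache, pages, seq_len=10)
```

The pages need not be adjacent or ordered in physical memory; only the order in the block table row matters.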
References
Kwon et al., “Efficient Memory Management for Large Language Model Serving with PagedAttention” https://arxiv.org/abs/2309.06180 (vLLM paper)
- class ejkernel.modules.operations.page_attention.PageAttention
Bases: Kernel[PageAttentionConfig, Array]
Page Attention with custom optimization logic.
Efficient attention over paged KV cache for serving workloads. Optimized for dynamic batching with variable context lengths.
- Features:
Paged KV cache management for memory efficiency
Support for variable context lengths per sequence
Automatic partitioning for long contexts
Multi-split attention for improved throughput
Optimized for inference and serving workloads
Logit soft capping for numerical stability
Configurable pages per compute block
TPU megacore mode support
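Logit soft capping, listed among the features above, is commonly defined as a scaled tanh applied to the attention logits before the softmax. The sketch below shows that common formulation; whether ejkernel applies exactly this form is an assumption, and `soft_cap` is a hypothetical helper name.

```python
import math


def soft_cap(logit: float, cap: float) -> float:
    """Smoothly squash a logit into [-cap, cap] using a scaled tanh."""
    return cap * math.tanh(logit / cap)


small = soft_cap(1.0, cap=50.0)     # nearly unchanged: tanh is close to linear near 0
large = soft_cap(1000.0, cap=50.0)  # saturates at the cap (up to floating-point rounding)
```

Small logits pass through almost unchanged, while extreme logits are bounded, which keeps the subsequent exponentials in a safe range.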
- The paged layout provides:
O(1) insertion/deletion of sequences
Efficient prefix sharing for beam search
Minimal memory fragmentation
Better batch utilization through dynamic allocation
- candidate_cfgs(inv: Invocation[PageAttentionConfig, Array])
Generate candidate configurations for autotuning.
Creates configurations optimized for different batch sizes and context lengths commonly seen in serving scenarios.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of candidate configurations to benchmark during autotuning
Note
Page attention doesn’t have tunable block sizes in the traditional sense. The num_splits and pages_per_compute_block are auto-determined.
- create_shard_map_wrapper(query: Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_blocks num_kv_heads block_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_blocks num_kv_heads block_size head_dim'], context_lens: Int[jaxlib._jax.Array, 'num_seqs'], block_tables: Int[jaxlib._jax.Array, 'num_seqs max_blocks'], attn_scale: float | None = None, max_context_len: int | None = None, num_splits: int = 0, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: ejkernel.modules.operations.configs.PageAttentionConfig | None = None, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, pages_per_compute_block: int | None = None, megacore_mode: str | None = None, inline_seq_dim: bool = True, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None, check_vma: bool = False)
Create a shard_map wrapper for distributed execution.
Creates a wrapper function that applies shard_map to distribute the page attention computation across devices according to the provided sharding specifications.
- Parameters
query – Query tensor [num_seqs, num_heads, head_dim]
key_cache – Paged key cache [num_blocks, num_kv_heads, block_size, head_dim]
value_cache – Paged value cache [num_blocks, num_kv_heads, block_size, head_dim]
context_lens – Context length per sequence [num_seqs]
block_tables – Block mapping table [num_seqs, max_blocks]
attn_scale – Attention scaling factor
max_context_len – Maximum context length across all sequences
num_splits – Number of splits for partitioned attention
platform – Platform to use for execution
cfg – Configuration for the kernel
mask_value – Value for masked positions
attn_logits_soft_cap – Soft cap value for attention logits
pages_per_compute_block – Pages per compute block
megacore_mode – Megacore execution mode
inline_seq_dim – Whether to inline sequence dimension
mesh – JAX mesh for distributed execution
in_specs – Partition specifications for input tensors
out_specs – Partition specifications for output tensor
check_vma – Whether shard_map should run its varying-manual-axes (VMA) validity checks on the wrapped function
- Returns
Tuple of (shard_map_fn, call_args) where shard_map_fn is the wrapped function and call_args are the arguments to pass to it.
- get_impl(cfg: PageAttentionConfig)
Get kernel implementation from registry.
- Parameters
cfg – Configuration specifying platform and backend preferences
- Returns
Callable kernel implementation for page attention
- Raises
ValueError – If no matching implementation is found for the configuration
- heuristic_cfg(inv: Invocation[PageAttentionConfig, Array]) -> PageAttentionConfig
Provide default configuration optimized for paged attention.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
Default PageAttentionConfig suitable for typical serving workloads with variable context lengths
- run(query: Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_blocks num_kv_heads block_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_blocks num_kv_heads block_size head_dim'], context_lens: Int[jaxlib._jax.Array, 'num_seqs'], block_tables: Int[jaxlib._jax.Array, 'num_seqs max_blocks'], attn_scale: float | None = None, max_context_len: int | None = None, num_splits: int = 0, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: PageAttentionConfig, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, pages_per_compute_block: int | None = None, megacore_mode: str | None = None, inline_seq_dim: bool = True) -> Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim']
Execute page attention over paged KV cache.
Computes attention where the KV cache is organized in fixed-size pages, with each sequence’s tokens potentially scattered across non-contiguous pages.
- Parameters
query – Query tensor [num_seqs, num_heads, head_dim] for current decode step
key_cache – Paged key cache [num_blocks, num_kv_heads, block_size, head_dim]
value_cache – Paged value cache [num_blocks, num_kv_heads, block_size, head_dim]
context_lens – Actual context length per sequence [num_seqs]
block_tables – Page index mapping [num_seqs, max_blocks] where block_tables[i, j] gives the physical page index for sequence i’s jth logical block
attn_scale – Attention score scaling factor (default: 1/sqrt(head_dim))
max_context_len – Maximum context length across all sequences
num_splits – Number of splits for partitioned attention (0 = auto, 1 = no split)
mask_value – Value used for masked positions (default: -2.381976426469702e+38, a large finite negative number that stands in for -inf)
attn_logits_soft_cap – Optional soft cap for attention logits
pages_per_compute_block – Number of pages to process per compute block
megacore_mode – TPU-specific megacore execution mode
inline_seq_dim – Whether to inline the sequence dimension
platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Kernel configuration object
- Returns
Attention output [num_seqs, num_heads, head_dim]
Note
- Block tables define the mapping from logical to physical pages:
logical_page_idx = position // block_size
physical_page_idx = block_tables[seq_idx, logical_page_idx]
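The mapping in this note can be wrapped in a small lookup that also returns the token's offset inside its page. This is an illustrative sketch: `locate_token` is a hypothetical helper, and `block_size = 16` is an assumed value.

```python
def locate_token(block_tables, seq_idx, position, block_size):
    """Return (physical_page_idx, offset_in_page) for a token position in a sequence."""
    logical_page_idx = position // block_size
    offset = position % block_size
    physical_page_idx = block_tables[seq_idx][logical_page_idx]
    return physical_page_idx, offset


# Block table for a two-sequence batch; block_size = 16 is an assumed value.
block_tables = [[3, 7, 0], [1, 5, 0]]
first = locate_token(block_tables, seq_idx=0, position=0, block_size=16)
later = locate_token(block_tables, seq_idx=1, position=20, block_size=16)
```

Here position 20 of sequence 1 falls in logical page 1, which the block table maps to physical page 5, at offset 4.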
Example
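>>> import jax.numpy as jnp
>>> # Each row lists one sequence's physical pages; assuming block_size = 16, the third column is padding.
>>> block_tables = jnp.array([[3, 7, 0], [1, 5, 0]])
>>> # Number of valid tokens per sequence.
>>> context_lens = jnp.array([32, 24])
>>> out = page_attention(q, k_cache, v_cache, context_lens, block_tables)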
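For intuition about `num_splits`: a context can be processed in independent chunks and the partial results merged exactly with a log-sum-exp correction. The sketch below shows that standard merge for one query vector in NumPy. How ejkernel implements its splits is not stated here, so this is an assumption, and `partial_attention` and `merge_splits` are hypothetical names.

```python
import numpy as np


def partial_attention(q, k, v):
    """Attention over one chunk: unnormalized output, max logit, and sum of exponentials."""
    logits = k @ q
    m = logits.max()
    p = np.exp(logits - m)
    return p @ v, m, p.sum()


def merge_splits(parts):
    """Combine per-split partial results into the exact full-softmax output."""
    global_max = max(m for _, m, _ in parts)
    numerator = sum(o * np.exp(m - global_max) for o, m, _ in parts)
    denominator = sum(s * np.exp(m - global_max) for _, m, s in parts)
    return numerator / denominator


rng = np.random.default_rng(0)
q = rng.standard_normal(8)
k = rng.standard_normal((64, 8))
v = rng.standard_normal((64, 8))

# Split the 64-token context into 4 chunks, process each independently, then merge.
parts = [partial_attention(q, kc, vc) for kc, vc in zip(np.split(k, 4), np.split(v, 4))]
merged = merge_splits(parts)

# Single-pass reference for comparison.
logits = k @ q
weights = np.exp(logits - logits.max())
reference = (weights / weights.sum()) @ v
```

Because each chunk only needs its own maximum and exponential sum, the chunks can run in parallel, which is the usual motivation for splitting long contexts.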
- ejkernel.modules.operations.page_attention.page_attention(query: Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_blocks num_kv_heads block_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_blocks num_kv_heads block_size head_dim'], context_lens: Int[jaxlib._jax.Array, 'num_seqs'], block_tables: Int[jaxlib._jax.Array, 'num_seqs max_blocks'], /, *, attn_scale: float | None = None, max_context_len: int | None = None, num_splits: int = 0, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, pages_per_compute_block: int | None = None, megacore_mode: str | None = None, inline_seq_dim: bool = True, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.PageAttentionConfig | None = None) -> Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim']
Execute page attention with automatic optimization.
Page attention performs efficient attention computation over paged KV cache for serving and inference workloads with dynamic batching.
- Parameters
query – Query tensor [num_seqs, num_heads, head_dim]
key_cache – Paged key cache [num_blocks, num_kv_heads, block_size, head_dim]
value_cache – Paged value cache [num_blocks, num_kv_heads, block_size, head_dim]
context_lens – Context length per sequence [num_seqs]
block_tables – Block mapping table [num_seqs, max_blocks]
attn_scale – Attention scaling factor
max_context_len – Maximum context length across all sequences
num_splits – Number of splits for partitioned attention (0=auto)
mask_value – Value for masked positions (default: -2.381976426469702e+38, a large finite negative number that stands in for -inf)
attn_logits_soft_cap – Soft cap value for attention logits
pages_per_compute_block – Pages per compute block
megacore_mode – Megacore execution mode
inline_seq_dim – Whether to inline sequence dimension
platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Optional configuration override
- Returns
Attention output [num_seqs, num_heads, head_dim]
Example
>>> # Basic usage
>>> out = page_attention(query, key_cache, value_cache, context_lens, block_tables)
>>>
>>> # Long contexts: partition the attention computation into 4 splits
>>> out = page_attention(
...     query, key_cache, value_cache, context_lens, block_tables,
...     num_splits=4, max_context_len=8192
... )
>>>
>>> # Soft-cap the attention logits
>>> out = page_attention(
...     query, key_cache, value_cache, context_lens, block_tables,
...     attn_logits_soft_cap=50.0
... )
>>>
>>> # Select a specific platform
>>> out = page_attention(..., platform="triton")
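Before calling the function it can help to check that the five positional inputs agree with the shapes in the signature. The validator below is a hypothetical helper, not part of ejkernel. It uses NumPy arrays for illustration, and the requirement that `num_heads` be a multiple of `num_kv_heads` is an assumption about grouped-query layouts.

```python
import numpy as np


def check_page_attention_inputs(query, key_cache, value_cache, context_lens, block_tables):
    """Hypothetical helper (not part of ejkernel): sanity-check the documented shapes."""
    num_seqs, num_heads, head_dim = query.shape
    num_blocks, num_kv_heads, block_size, cache_head_dim = key_cache.shape
    max_blocks = block_tables.shape[1]

    if value_cache.shape != key_cache.shape:
        raise ValueError("key_cache and value_cache must have the same shape")
    if cache_head_dim != head_dim:
        raise ValueError("head_dim differs between query and cache")
    if num_heads % num_kv_heads != 0:
        # Assumption: each KV head serves an integer number of query heads.
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    if context_lens.shape != (num_seqs,) or block_tables.shape[0] != num_seqs:
        raise ValueError("context_lens and block_tables need one row per sequence")
    if context_lens.max() > max_blocks * block_size:
        raise ValueError("block_tables has too few columns for the longest context")
    if block_tables.min() < 0 or block_tables.max() >= num_blocks:
        raise ValueError("block_tables contains an out-of-range page index")
    return {"num_seqs": num_seqs, "block_size": block_size, "max_blocks": max_blocks}


query = np.zeros((2, 8, 64), dtype=np.float32)            # [num_seqs, num_heads, head_dim]
key_cache = np.zeros((16, 2, 16, 64), dtype=np.float32)   # [num_blocks, num_kv_heads, block_size, head_dim]
value_cache = np.zeros_like(key_cache)
context_lens = np.array([32, 24], dtype=np.int32)
block_tables = np.array([[3, 7, 0], [1, 5, 0]], dtype=np.int32)
info = check_page_attention_inputs(query, key_cache, value_cache, context_lens, block_tables)
```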