# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Page Attention module with automatic optimization.
This module implements Page Attention, a specialized attention mechanism designed
for efficient KV cache management in serving and inference workloads. Page Attention
organizes the KV cache in fixed-size blocks (pages), enabling:
- Dynamic memory allocation without pre-allocating for max sequence length
- Efficient memory sharing across sequences (e.g., for beam search or prefix caching)
- Reduced memory fragmentation compared to contiguous allocation
- Better GPU memory utilization through page-level management
Page Attention is particularly valuable for:
- LLM serving with variable-length sequences
- Batch inference with dynamic batching
- Memory-constrained deployment scenarios
- Systems requiring efficient KV cache sharing
Key Concepts:
Pages: Fixed-size blocks holding a portion of KV cache (e.g., 16 or 32 tokens)
Block Tables: Mapping from logical sequence positions to physical page indices
Context Lengths: Actual sequence lengths (excluding padding)
The paged approach enables:
- Near-zero memory waste (only last page per sequence may be partially filled)
- Easy insertion/deletion of sequences without memory reshuffling
- Natural support for prefix sharing in beam search
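Mathematical Foundation:
For query position i, attention is taken over every valid token of the sequence,
gathered page by page:
output[i] = sum_{j in valid_pages} P[i, j] @ V[pages[j]]
with P[i, :] = softmax over all valid pages of (attn_scale * Q[i] @ K[pages[j]].T)
i.e. the softmax is normalised once across all valid positions, not separately per page.
Where valid_pages are determined by block_tables and context_lens.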
Memory Layout:
Instead of: [seq_len, num_kv_heads, head_dim] (contiguous per sequence)
Use: [num_blocks, num_kv_heads, block_size, head_dim] (page-based allocation;
a "block" is one page holding block_size tokens)
References:
Kwon et al., "Efficient Memory Management for Large Language Model Serving with PagedAttention"
https://arxiv.org/abs/2309.06180 (vLLM paper)
"""
from __future__ import annotations
import os
from typing import Literal
from jax import shard_map
from jax.sharding import Mesh, PartitionSpec
from jaxtyping import Array, Float, Int
from ejkernel.kernels._registry import Backend, kernel_registry
from ejkernel.ops import (
AutotunePolicy,
ConfigCache,
ConfigSelectorChain,
Executor,
Invocation,
Kernel,
Tuner,
)
from ejkernel.ops.config.persistent import PersistentCache
from ..base import detect_platform
from .configs import PageAttentionConfig
class PageAttention(Kernel[PageAttentionConfig, Array]):
    """Kernel wrapper for attention over a paged KV cache.

    The class plugs page attention into the ejkernel ``Kernel`` interface:
    ``get_impl`` looks up a registered implementation for the requested
    platform/backend, ``run`` invokes it, and ``heuristic_cfg`` /
    ``candidate_cfgs`` supply configurations to the executor.

    Features exposed through ``run``:
        - Paged KV cache (``key_cache`` / ``value_cache`` addressed through
          ``block_tables``)
        - Variable context length per sequence (``context_lens``)
        - Partitioned / multi-split attention (``num_splits``)
        - Logit soft capping (``attn_logits_soft_cap``)
        - Configurable pages per compute block (``pages_per_compute_block``)
        - TPU megacore mode (``megacore_mode``)

    Intended for inference and serving workloads with dynamic batching.
    """

    def __init__(self):
        """Register the kernel under the operation id ``"page_attention"``."""
        super().__init__(op_id="page_attention")

    def get_impl(self, cfg: PageAttentionConfig):
        """Look up the kernel implementation in the registry.

        Args:
            cfg: Configuration whose ``platform`` and ``backend`` fields select
                the implementation.

        Returns:
            The callable registered for ``"page_attention"`` on the resolved
            platform and ``cfg.backend``.

        Raises:
            ValueError: If no matching implementation is found.
                NOTE(review): raised by the registry/platform helpers, which
                are not visible in this file — confirm the exact type.
        """
        platform = detect_platform("page_attention", cfg.platform)
        return kernel_registry.get("page_attention", platform=platform, backend=cfg.backend)

    def run(
        self,
        query: Float[Array, "num_seqs num_heads head_dim"],
        key_cache: Float[Array, "num_blocks num_kv_heads block_size head_dim"],
        value_cache: Float[Array, "num_blocks num_kv_heads block_size head_dim"],
        context_lens: Int[Array, "num_seqs"],
        block_tables: Int[Array, "num_seqs max_blocks"],
        attn_scale: float | None = None,
        max_context_len: int | None = None,
        num_splits: int = 0,
        platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
        *,
        cfg: PageAttentionConfig,
        mask_value: float = -2.381976426469702e38,
        attn_logits_soft_cap: float | None = None,
        pages_per_compute_block: int | None = None,
        megacore_mode: str | None = None,
        inline_seq_dim: bool = True,
    ) -> Float[Array, "num_seqs num_heads head_dim"]:
        """Execute page attention over a paged KV cache.

        The KV cache is organised in fixed-size pages; a sequence's tokens may
        live in non-contiguous pages that are located through ``block_tables``.

        Args:
            query: Query tensor [num_seqs, num_heads, head_dim].
            key_cache: Paged key cache [num_blocks, num_kv_heads, block_size, head_dim].
            value_cache: Paged value cache [num_blocks, num_kv_heads, block_size, head_dim].
            context_lens: Actual context length per sequence [num_seqs].
            block_tables: Page index mapping [num_seqs, max_blocks];
                ``block_tables[i, j]`` is the physical page of sequence ``i``'s
                ``j``-th logical block.
            attn_scale: Attention score scaling factor. ``None`` lets the
                implementation choose (documented upstream as 1/sqrt(head_dim)).
            max_context_len: Maximum context length across all sequences.
            num_splits: Number of splits for partitioned attention
                (0 = auto, 1 = no split).
            platform: Optional platform override. When not ``None`` it
                replaces ``cfg.platform`` for this call.
            cfg: Kernel configuration. Only ``cfg.platform`` and ``cfg.backend``
                influence dispatch here; the tensor-level options passed to the
                implementation come from this method's own arguments.
            mask_value: Value used for masked positions. The default is a large
                finite negative float (about -2.38e38), not ``-inf``.
            attn_logits_soft_cap: Optional soft cap for attention logits.
            pages_per_compute_block: Number of pages to process per compute block.
            megacore_mode: TPU-specific megacore execution mode.
            inline_seq_dim: Whether to inline the sequence dimension.

        Returns:
            Attention output [num_seqs, num_heads, head_dim].

        Note:
            Block tables map logical to physical pages:
                logical_page_idx = position // block_size
                physical_page_idx = block_tables[seq_idx, logical_page_idx]

        Example:
            >>> # Two sequences; each row lists the physical pages of one sequence.
            >>> block_tables = jnp.array([[3, 7, 0], [1, 5, 0]])
            >>> context_lens = jnp.array([32, 24])
            >>> out = page_attention(q, k_cache, v_cache, context_lens, block_tables)
        """
        if platform is not None:
            # An explicit platform wins over the configured one; the tuning
            # fields are carried over unchanged.
            cfg = PageAttentionConfig(
                num_splits=cfg.num_splits,
                pages_per_compute_block=cfg.pages_per_compute_block,
                num_warps=cfg.num_warps,
                num_stages=cfg.num_stages,
                platform=platform,
                # NOTE(review): the backend constraint is dropped for "xla",
                # presumably because that implementation is backend-agnostic — confirm.
                backend=Backend.ANY if platform == "xla" else cfg.backend,
            )

        impl = self.get_impl(cfg)
        return impl(
            query=query,
            key_cache=key_cache,
            value_cache=value_cache,
            context_lens=context_lens,
            block_tables=block_tables,
            attn_scale=attn_scale,
            max_context_len=max_context_len,
            num_splits=num_splits,
            mask_value=mask_value,
            attn_logits_soft_cap=attn_logits_soft_cap,
            pages_per_compute_block=pages_per_compute_block,
            megacore_mode=megacore_mode,
            inline_seq_dim=inline_seq_dim,
        )

    def heuristic_cfg(self, inv: Invocation[PageAttentionConfig, Array]) -> PageAttentionConfig:
        """Return the default configuration for page attention.

        Args:
            inv: Invocation object containing arguments and metadata. It is not
                inspected; the same configuration is returned for every call.

        Returns:
            A ``PageAttentionConfig`` with automatic split selection
            (``num_splits=0``), no fixed ``pages_per_compute_block``, 4 warps,
            1 stage, ``platform="auto"`` and ``backend="any"``.
        """
        return PageAttentionConfig(
            num_splits=0,
            pages_per_compute_block=None,
            num_warps=4,
            num_stages=1,
            platform="auto",
            backend="any",
        )

    def candidate_cfgs(self, inv: Invocation[PageAttentionConfig, Array]):
        """Return the candidate configurations for autotuning.

        Args:
            inv: Invocation object containing arguments and metadata (unused).

        Returns:
            An empty list: this kernel proposes no candidates of its own.

        Note:
            Page attention doesn't have tunable block sizes in the traditional
            sense. The num_splits and pages_per_compute_block are auto-determined.
        """
        return []

    def create_shard_map_wrapper(
        self,
        query: Float[Array, "num_seqs num_heads head_dim"],
        key_cache: Float[Array, "num_blocks num_kv_heads block_size head_dim"],
        value_cache: Float[Array, "num_blocks num_kv_heads block_size head_dim"],
        context_lens: Int[Array, "num_seqs"],
        block_tables: Int[Array, "num_seqs max_blocks"],
        attn_scale: float | None = None,
        max_context_len: int | None = None,
        num_splits: int = 0,
        platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
        *,
        cfg: PageAttentionConfig | None = None,
        mask_value: float = -2.381976426469702e38,
        attn_logits_soft_cap: float | None = None,
        pages_per_compute_block: int | None = None,
        megacore_mode: str | None = None,
        inline_seq_dim: bool = True,
        mesh: Mesh | None = None,
        in_specs: tuple[PartitionSpec, ...] | None = None,
        out_specs: PartitionSpec | None = None,
        check_vma: bool = False,
    ):
        """Create a ``shard_map`` wrapper for distributed execution.

        The five tensor arguments become the positional inputs of the wrapped
        function; every other option is captured by closure and passed to
        ``run`` unchanged.

        Args:
            query: Query tensor [num_seqs, num_heads, head_dim].
            key_cache: Paged key cache [num_blocks, num_kv_heads, block_size, head_dim].
            value_cache: Paged value cache [num_blocks, num_kv_heads, block_size, head_dim].
            context_lens: Context length per sequence [num_seqs].
            block_tables: Block mapping table [num_seqs, max_blocks].
            attn_scale: Attention scaling factor.
            max_context_len: Maximum context length across all sequences.
            num_splits: Number of splits for partitioned attention.
            platform: Platform to use for execution.
            cfg: Configuration for the kernel. ``None`` falls back to
                ``heuristic_cfg``.
            mask_value: Value for masked positions.
            attn_logits_soft_cap: Soft cap value for attention logits.
            pages_per_compute_block: Pages per compute block.
            megacore_mode: Megacore execution mode.
            inline_seq_dim: Whether to inline sequence dimension.
            mesh: JAX mesh for distributed execution (required).
            in_specs: Partition specs for the inputs, in the order
                (query, key_cache, value_cache, context_lens, block_tables)
                (required).
            out_specs: Partition spec for the output tensor (required).
            check_vma: Forwarded to ``jax.shard_map`` as its ``check_vma``
                flag, which toggles JAX's additional validity checks on the
                sharded function.

        Returns:
            Tuple of (shard_map_fn, call_args) where shard_map_fn is the wrapped
            function and call_args are the arguments to pass to it.

        Raises:
            AssertionError: If ``mesh``, ``in_specs`` or ``out_specs`` is ``None``.
        """
        # Raised explicitly instead of via `assert` so the checks are not
        # stripped under `python -O`; the exception type and messages are the
        # same ones the assert statements produced.
        required = (("mesh", mesh), ("in_specs", in_specs), ("out_specs", out_specs))
        for label, value in required:
            if value is None:
                raise AssertionError(f"{label} must be provided for shard_map execution")

        def _wrapper(
            query,
            key_cache,
            value_cache,
            context_lens,
            block_tables,
        ):
            # Test for None explicitly so the fallback never depends on the
            # truthiness of a user-supplied config object.
            effective_cfg = cfg if cfg is not None else self.heuristic_cfg(None)
            return self.run(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                context_lens=context_lens,
                block_tables=block_tables,
                attn_scale=attn_scale,
                max_context_len=max_context_len,
                num_splits=num_splits,
                platform=platform,
                cfg=effective_cfg,
                mask_value=mask_value,
                attn_logits_soft_cap=attn_logits_soft_cap,
                pages_per_compute_block=pages_per_compute_block,
                megacore_mode=megacore_mode,
                inline_seq_dim=inline_seq_dim,
            )

        shard_map_fn = shard_map(
            _wrapper,
            mesh=mesh,
            in_specs=in_specs,
            out_specs=out_specs,
            check_vma=check_vma,
        )

        # Same order as `_wrapper`'s parameters (and therefore as `in_specs`).
        call_args = (
            query,
            key_cache,
            value_cache,
            context_lens,
            block_tables,
        )
        return shard_map_fn, call_args
# Module-level executor used by `page_attention` below. It is built once, when
# this module is imported, so the environment variable consulted here is read
# only at import time.
_page_attention_executor: Executor[PageAttentionConfig, Array] = Executor(
ConfigSelectorChain(
# Config cache created for this executor — presumably in-process only;
# confirm against ejkernel.ops.ConfigCache.
cache=ConfigCache(),
policy=AutotunePolicy(
allow_autotune=True,
# Cache-miss behaviour is taken from EJKERNEL_AUTOTUNE_POLICY and falls back
# to "autotune" when the variable is unset.
cache_miss_fallback=os.getenv("EJKERNEL_AUTOTUNE_POLICY", "autotune"),
# NOTE(review): backward validation is switched off, presumably because this
# op is only used for inference — confirm.
validate_backward=False,
),
# Presumably 5 warm-up runs followed by 100 timed iterations per candidate —
# confirm against ejkernel.ops.Tuner.
tuner=Tuner(warmup=5, iters=100),
# Persistent cache registered under the name "page-attention"; where it is
# stored is not visible from this file.
persistent=PersistentCache("page-attention"),
)
)
def page_attention(
    query: Float[Array, "num_seqs num_heads head_dim"],
    key_cache: Float[Array, "num_blocks num_kv_heads block_size head_dim"],
    value_cache: Float[Array, "num_blocks num_kv_heads block_size head_dim"],
    context_lens: Int[Array, "num_seqs"],
    block_tables: Int[Array, "num_seqs max_blocks"],
    /,
    *,
    attn_scale: float | None = None,
    max_context_len: int | None = None,
    num_splits: int = 0,
    mask_value: float = -2.381976426469702e38,
    attn_logits_soft_cap: float | None = None,
    pages_per_compute_block: int | None = None,
    megacore_mode: str | None = None,
    inline_seq_dim: bool = True,
    platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
    cfg: PageAttentionConfig | None = None,
) -> Float[Array, "num_seqs num_heads head_dim"]:
    """Run page attention through the module-level executor.

    A fresh ``PageAttention`` kernel is handed to ``_page_attention_executor``
    together with every argument of this function; ``cfg`` is forwarded under
    the keyword ``_cfg``. The five tensor arguments are positional-only, all
    options are keyword-only.

    Args:
        query: Query tensor [num_seqs, num_heads, head_dim].
        key_cache: Paged key cache [num_blocks, num_kv_heads, block_size, head_dim].
        value_cache: Paged value cache [num_blocks, num_kv_heads, block_size, head_dim].
        context_lens: Context length per sequence [num_seqs].
        block_tables: Block mapping table [num_seqs, max_blocks].
        attn_scale: Attention scaling factor.
        max_context_len: Maximum context length across all sequences.
        num_splits: Number of splits for partitioned attention (0=auto).
        mask_value: Value for masked positions. Defaults to a large finite
            negative float (about -2.38e38).
        attn_logits_soft_cap: Soft cap value for attention logits.
        pages_per_compute_block: Pages per compute block.
        megacore_mode: Megacore execution mode.
        inline_seq_dim: Whether to inline sequence dimension.
        platform: Specific platform to use ("triton", "pallas", "cuda", "xla"
            or "auto"); ``None`` leaves the choice to the configuration.
        cfg: Optional configuration override.

    Returns:
        Attention output [num_seqs, num_heads, head_dim].

    Example:
        >>> # Default settings
        >>> out = page_attention(query, key_cache, value_cache, context_lens, block_tables)
        >>>
        >>> # Explicit split count for long contexts
        >>> out = page_attention(
        ...     query, key_cache, value_cache, context_lens, block_tables,
        ...     num_splits=4, max_context_len=8192
        ... )
        >>>
        >>> # Soft-capped attention logits
        >>> out = page_attention(
        ...     query, key_cache, value_cache, context_lens, block_tables,
        ...     attn_logits_soft_cap=50.0
        ... )
        >>>
        >>> # Force a specific platform
        >>> out = page_attention(..., platform="triton")
    """
    kernel = PageAttention()

    # Everything is forwarded by keyword, in the same order as before.
    forwarded = {
        "query": query,
        "key_cache": key_cache,
        "value_cache": value_cache,
        "context_lens": context_lens,
        "block_tables": block_tables,
        "attn_scale": attn_scale,
        "max_context_len": max_context_len,
        "num_splits": num_splits,
        "mask_value": mask_value,
        "attn_logits_soft_cap": attn_logits_soft_cap,
        "pages_per_compute_block": pages_per_compute_block,
        "megacore_mode": megacore_mode,
        "inline_seq_dim": inline_seq_dim,
        "platform": platform,
        # The executor receives the user configuration under the name `_cfg`.
        "_cfg": cfg,
    }
    return _page_attention_executor(kernel, **forwarded)