ejkernel.kernels._xla.page_attention._interface
Paged attention interface for efficient KV cache management.
This module provides the public API for paged attention, where the KV cache is organized into fixed-size blocks. This enables efficient memory management for variable-length sequences during decoding/generation.
ejkernel.kernels._xla.page_attention._interface.page_attention(query: Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], context_lens: Int[jaxlib._jax.Array, 'num_seqs'], block_tables: Int[jaxlib._jax.Array, 'num_seqs max_blocks'], attn_scale: float | None = None, max_context_len: int | None = None, num_splits: int = 0, *, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, pages_per_compute_block: int | None = None, megacore_mode: str | None = None, inline_seq_dim: bool = True, sliding_window: int | None = None) → Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim']
Paged attention for efficient KV cache management using JAX/XLA.
This function implements paged attention, where the KV cache is organized into fixed-size blocks (pages). Each sequence has a row in block_tables that maps its logical block indices to physical block indices in the cache, so a sequence's KV entries need not be contiguous in memory. This enables efficient memory management for variable-length sequences and dynamic batching.
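To make the block-table lookup concrete, here is a minimal sketch of how one token's position could be resolved to a physical block and an in-block offset. The helper name `logical_to_physical` is hypothetical, and it assumes the num_blocks-first cache layout [num_blocks, num_kv_heads, block_size, head_dim] used in the parameter descriptions; the table values are borrowed from the example at the end of this page.

```python
import jax.numpy as jnp


def logical_to_physical(block_tables, seq_idx, token_pos, block_size):
    """Hypothetical helper: locate one token of one sequence in the paged cache.

    Returns (physical_block, offset). Assuming the num_blocks-first layout,
    the token's keys live at key_cache[physical_block, :, offset, :].
    """
    logical_block = token_pos // block_size  # which of the sequence's blocks
    offset = token_pos % block_size          # slot inside that block
    physical_block = block_tables[seq_idx, logical_block]
    return int(physical_block), int(offset)


block_tables = jnp.array([[0, 1, 2, -1], [3, 4, -1, -1]])
print(logical_to_physical(block_tables, 0, 40, 16))  # (2, 8)
print(logical_to_physical(block_tables, 1, 17, 16))  # (4, 1)
```

Token 40 of sequence 0 falls in its third logical block (40 // 16 = 2), which the table maps to physical block 2, at offset 8.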
Parameters
query – Query tensor of shape [num_seqs, num_heads, head_dim]. Each sequence has a single query token (typically for decode/generation).
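key_cache – Paged key cache of shape [num_blocks, num_kv_heads, block_size, head_dim]; the cache is divided into num_blocks blocks of block_size tokens each. The type annotation in the signature instead lists the axes as [num_kv_heads, total_num_pages, page_size, head_dim]; this description and the example below use the num_blocks-first layout.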
value_cache – Paged value cache of shape [num_blocks, num_kv_heads, block_size, head_dim]. Must have the same structure as key_cache.
context_lens – Context length of each sequence, shape [num_seqs]: the number of valid tokens in the KV cache for that sequence.
block_tables – Block table of shape [num_seqs, max_blocks]. Row i maps the logical block indices of sequence i to physical block indices in the cache.
attn_scale – Attention scaling factor. If None, defaults to 1/sqrt(head_dim).
max_context_len – Maximum context length (not used by the XLA implementation).
num_splits – Number of splits for partitioned attention (not used by the XLA implementation).
mask_value – Value used for masking (not used by the XLA implementation).
attn_logits_soft_cap – Soft cap for attention logits (not used by the XLA implementation).
pages_per_compute_block – Pages per compute block (not used by the XLA implementation).
megacore_mode – Megacore parallelization mode (not used by the XLA implementation).
inline_seq_dim – Whether to inline the sequence dimension (not used by the XLA implementation).
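The following sketch shows one way the block_tables and context_lens arguments might be assembled from per-sequence page lists. `build_page_inputs` is a hypothetical helper, not part of ejkernel; it pads unused entries with -1 to match the example at the end of this page and checks that each sequence has enough pages to hold its tokens.

```python
import jax.numpy as jnp


def build_page_inputs(pages_per_seq, lengths, block_size, max_blocks):
    """Hypothetical helper: pack per-sequence page lists into the
    block_tables / context_lens arrays that page_attention takes.

    Unused slots are padded with -1, following the example on this page.
    """
    rows = []
    for pages, length in zip(pages_per_seq, lengths):
        needed = -(-length // block_size)  # ceil(length / block_size)
        if len(pages) < needed:
            raise ValueError(f"{length} tokens need {needed} pages, got {len(pages)}")
        if len(pages) > max_blocks:
            raise ValueError(f"{len(pages)} pages exceed max_blocks={max_blocks}")
        rows.append(list(pages) + [-1] * (max_blocks - len(pages)))
    block_tables = jnp.array(rows, dtype=jnp.int32)
    context_lens = jnp.array(lengths, dtype=jnp.int32)
    return block_tables, context_lens


block_tables, context_lens = build_page_inputs([[0, 1, 2], [3, 4]], [48, 32], 16, 4)
```

With these inputs the helper reproduces the block_tables and context_lens arrays from the example: 48 tokens fill exactly three 16-token blocks and 32 tokens fill two.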
Returns
Attention output of shape [num_seqs, num_heads, head_dim].
Notes
- Supports Grouped Query Attention (GQA), where num_heads >= num_kv_heads and several query heads share each KV head.
- Each sequence can use a different number of blocks, determined by its context_lens entry.
- Blocks are addressed through block_tables, so a sequence's cache need not be contiguous; this avoids fragmentation.
- This is simpler than ragged_page_attention_v2, which handles multiple query tokens per sequence.
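For intuition about what the function computes, here is a minimal pure-JAX reference sketch of single-token paged attention with GQA. It is not ejkernel's XLA implementation; `paged_attention_reference` is a hypothetical name, and the sketch assumes the num_blocks-first cache layout, that num_heads is a multiple of num_kv_heads, that query head h reads KV head h // (num_heads // num_kv_heads), and that every context_lens entry is at least 1. It ignores soft capping, sliding windows, and the parameters listed as unused.

```python
import jax
import jax.numpy as jnp


def paged_attention_reference(query, key_cache, value_cache, context_lens,
                              block_tables, attn_scale=None):
    """Hypothetical reference for single-token paged attention (GQA).

    Assumed shapes: query [S, H, D]; key_cache/value_cache [NB, Hkv, B, D];
    context_lens [S]; block_tables [S, M]. Returns [S, H, D].
    """
    num_seqs, num_heads, head_dim = query.shape
    _, num_kv_heads, block_size, _ = key_cache.shape
    max_blocks = block_tables.shape[1]
    group = num_heads // num_kv_heads  # query heads per KV head
    if attn_scale is None:
        attn_scale = head_dim ** -0.5  # 1/sqrt(head_dim), as documented above

    # Gather each sequence's pages: [S, M, Hkv, B, D]. Padding entries (-1)
    # are clamped to block 0; their positions are masked out below.
    safe_tables = jnp.maximum(block_tables, 0)
    k = key_cache[safe_tables]
    v = value_cache[safe_tables]

    # Flatten pages into one token axis in logical order: [S, Hkv, M*B, D].
    num_tokens = max_blocks * block_size
    k = k.transpose(0, 2, 1, 3, 4).reshape(num_seqs, num_kv_heads, num_tokens, head_dim)
    v = v.transpose(0, 2, 1, 3, 4).reshape(num_seqs, num_kv_heads, num_tokens, head_dim)

    # Group query heads by the KV head they share: [S, Hkv, G, D].
    q = query.reshape(num_seqs, num_kv_heads, group, head_dim)
    logits = jnp.einsum("skgd,sktd->skgt", q, k) * attn_scale

    # Mask positions at or beyond each sequence's context length.
    positions = jnp.arange(num_tokens)
    valid = positions[None, :] < context_lens[:, None]  # [S, M*B]
    logits = jnp.where(valid[:, None, None, :], logits, jnp.finfo(logits.dtype).min)

    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("skgt,sktd->skgd", weights, v)
    return out.reshape(num_seqs, num_heads, head_dim)
```

Gathering whole pages and then masking keeps every shape static, which suits jax.jit; the trade-off is that padded slots are read and then discarded.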
Examples
>>> import jax.numpy as jnp
>>> from ejkernel.kernels._xla.page_attention._interface import page_attention
>>> num_seqs, num_heads, head_dim = 2, 8, 64
>>> num_kv_heads = 8
>>> num_blocks, block_size = 10, 16
>>>
>>> query = jnp.ones((num_seqs, num_heads, head_dim))
>>> key_cache = jnp.ones((num_blocks, num_kv_heads, block_size, head_dim))
>>> value_cache = jnp.ones((num_blocks, num_kv_heads, block_size, head_dim))
>>> context_lens = jnp.array([48, 32])
>>> block_tables = jnp.array([[0, 1, 2, -1], [3, 4, -1, -1]])
>>>
>>> output = page_attention(query, key_cache, value_cache,
...                         context_lens, block_tables)
>>> output.shape
(2, 8, 64)