ejkernel.kernels._xla.page_attention._interface

Paged attention interface for efficient KV cache management.

This module provides the public API for paged attention, in which the KV cache is organized into fixed-size blocks. This layout enables efficient memory management for variable-length sequences during decoding/generation.
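To put a rough number on the memory claim, the following back-of-the-envelope sketch counts how many token slots a paged cache reserves, compared with a cache that reserves a fixed `max_len` slots per sequence up front. The helper names, the contiguous baseline, and the numbers are illustrative assumptions, not measurements of ejkernel.

```python
import math


def slots_reserved_paged(context_lens, block_size):
    # Hypothetical helper: each sequence holds only the blocks it needs,
    # so it wastes fewer than block_size slots.
    return sum(math.ceil(n / block_size) * block_size for n in context_lens)


def slots_reserved_contiguous(context_lens, max_len):
    # Hypothetical baseline: a contiguous cache reserves max_len slots per sequence.
    return len(context_lens) * max_len


context_lens = [48, 32, 5]
paged = slots_reserved_paged(context_lens, block_size=16)        # 48 + 32 + 16 = 96
contiguous = slots_reserved_contiguous(context_lens, max_len=512)  # 3 * 512 = 1536
```

The only per-sequence overhead in the paged case is the unused tail of the last block (here 11 slots for the 5-token sequence).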

ejkernel.kernels._xla.page_attention._interface.page_attention(query: Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], context_lens: Int[jaxlib._jax.Array, 'num_seqs'], block_tables: Int[jaxlib._jax.Array, 'num_seqs max_blocks'], attn_scale: float | None = None, max_context_len: int | None = None, num_splits: int = 0, *, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, pages_per_compute_block: int | None = None, megacore_mode: str | None = None, inline_seq_dim: bool = True, sliding_window: int | None = None) → Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim']

Paged attention for efficient KV cache management using JAX/XLA.

This function implements paged attention, in which the KV cache is organized into fixed-size blocks (pages). Each sequence has a row in block_tables that maps its logical block indices to physical block indices in the cache. This enables efficient memory management for variable-length sequences and dynamic batching.
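The block-table indirection amounts to a divide-and-remainder address translation. A minimal sketch, using the block table from the example at the end of this page; `physical_location` is a hypothetical helper, not an ejkernel function.

```python
import jax.numpy as jnp


def physical_location(block_tables, seq_idx, token_pos, block_size):
    """Translate a logical token position into (physical_block, slot).

    Hypothetical helper for illustration; not part of ejkernel.
    """
    logical_block = token_pos // block_size  # index into this sequence's block table
    slot = token_pos % block_size            # position inside that block
    physical_block = block_tables[seq_idx, logical_block]
    return physical_block, slot


block_tables = jnp.array([[0, 1, 2, -1], [3, 4, -1, -1]])

# Token 37 of sequence 0: logical block 2 -> physical block 2, slot 5.
blk0, slot0 = physical_location(block_tables, 0, 37, block_size=16)
# Token 20 of sequence 1: logical block 1 -> physical block 4, slot 4.
blk1, slot1 = physical_location(block_tables, 1, 20, block_size=16)
```

Under the cache layout given in the parameter descriptions, the key vector for that token and KV head h would then be key_cache[physical_block, h, slot].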

Parameters
  • query – Query tensor of shape [num_seqs, num_heads, head_dim]. Each sequence has a single query token (typically for decode/generation).

  • key_cache – Paged key cache of shape [num_blocks, num_kv_heads, block_size, head_dim]. The total KV cache is divided into blocks of size block_size.

  • value_cache – Paged value cache of shape [num_blocks, num_kv_heads, block_size, head_dim]. Must have the same structure as key_cache.

  • context_lens – Context length of each sequence, shape [num_seqs]. Gives the number of valid tokens in the KV cache for each sequence.

  • block_tables – Block table of shape [num_seqs, max_blocks]. For each sequence, maps logical block indices to physical block indices in the cache.

  • attn_scale – Attention scaling factor. If None, defaults to 1/sqrt(head_dim).

  • max_context_len – Maximum context length (not used in the XLA implementation).

  • num_splits – Number of splits for partitioned attention (not used in the XLA implementation).

  • mask_value – Value used for masking (not used in the XLA implementation).

  • attn_logits_soft_cap – Soft cap for attention logits (not used in the XLA implementation).

  • pages_per_compute_block – Pages per compute block (not used in the XLA implementation).

  • megacore_mode – Megacore parallelization mode (not used in the XLA implementation).

  • inline_seq_dim – Whether to inline the sequence dimension (not used in the XLA implementation).
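For the context_lens and block_tables arguments, a caller needs ceil(context_len / block_size) blocks per sequence. The sketch below builds both arrays with a naive sequential allocator. `allocate_block_tables` is hypothetical, and the -1 padding simply mirrors the Examples section; this page does not say whether page_attention requires that particular sentinel.

```python
import math

import jax.numpy as jnp


def allocate_block_tables(context_lens, block_size, max_blocks):
    """Assign physical blocks to sequences in order (hypothetical helper).

    Each sequence gets ceil(context_len / block_size) consecutive physical
    blocks from a simple counter; unused table entries are padded with -1.
    """
    tables = []
    next_free = 0  # naive allocator: hand out physical blocks sequentially
    for length in context_lens:
        n = math.ceil(length / block_size)
        if n > max_blocks:
            raise ValueError(f"{length} tokens need {n} blocks, but max_blocks={max_blocks}")
        row = list(range(next_free, next_free + n)) + [-1] * (max_blocks - n)
        next_free += n
        tables.append(row)
    return jnp.array(context_lens, dtype=jnp.int32), jnp.array(tables, dtype=jnp.int32)


# Reproduces the arrays used in the Examples section.
context_lens, block_tables = allocate_block_tables([48, 32], block_size=16, max_blocks=4)
```

A real serving system would draw blocks from a free list so that finished sequences can return them; the sequential counter only keeps the sketch short.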

Returns

Attention output of shape [num_seqs, num_heads, head_dim].

Notes

  • Supports Grouped Query Attention (GQA), where num_heads >= num_kv_heads and each KV head is shared by a group of num_heads // num_kv_heads query heads

  • Each sequence can use a different number of blocks based on context_lens

  • Blocks are addressed indirectly through block_tables, so a sequence's blocks need not be contiguous in memory, which avoids fragmentation

  • This is a simpler variant than ragged_page_attention_v2, which handles multiple query tokens per sequence
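As a reference for the semantics (not for performance), here is a naive pure-JAX sketch of single-token paged attention with GQA. It is not ejkernel's implementation: `paged_attention_reference` is a hypothetical name, it assumes the [num_blocks, num_kv_heads, block_size, head_dim] cache layout from the parameter descriptions, and it assumes query head h reads KV head h // (num_heads // num_kv_heads), a common GQA convention that this page does not confirm.

```python
import math

import jax
import jax.numpy as jnp


def paged_attention_reference(query, key_cache, value_cache, context_lens,
                              block_tables, attn_scale=None):
    """Naive paged decode attention (hypothetical sketch, not ejkernel's code).

    Assumed shapes:
      query:        [num_seqs, num_heads, head_dim]
      key_cache:    [num_blocks, num_kv_heads, block_size, head_dim]
      value_cache:  [num_blocks, num_kv_heads, block_size, head_dim]
      context_lens: [num_seqs]
      block_tables: [num_seqs, max_blocks]
    """
    num_seqs, num_heads, head_dim = query.shape
    _, num_kv_heads, block_size, _ = key_cache.shape
    max_blocks = block_tables.shape[1]
    assert num_heads % num_kv_heads == 0, "GQA needs num_heads divisible by num_kv_heads"
    group = num_heads // num_kv_heads  # query heads per KV head
    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(head_dim)
    total = max_blocks * block_size

    # Gather each sequence's blocks: [num_seqs, max_blocks, num_kv_heads, block_size, head_dim].
    # Padding entries (-1) wrap around to the last block; those positions are masked below.
    k = key_cache[block_tables]
    v = value_cache[block_tables]
    # Lay the blocks end to end: [num_seqs, num_kv_heads, max_blocks * block_size, head_dim].
    k = k.transpose(0, 2, 1, 3, 4).reshape(num_seqs, num_kv_heads, total, head_dim)
    v = v.transpose(0, 2, 1, 3, 4).reshape(num_seqs, num_kv_heads, total, head_dim)

    # Group query heads by the KV head they share: [num_seqs, num_kv_heads, group, head_dim].
    q = query.reshape(num_seqs, num_kv_heads, group, head_dim)
    scores = jnp.einsum("skgd,sktd->skgt", q, k) * attn_scale

    # Mask positions at or beyond each sequence's context length.
    valid = jnp.arange(total)[None, :] < context_lens[:, None]  # [num_seqs, total]
    scores = jnp.where(valid[:, None, None, :], scores, jnp.finfo(scores.dtype).min)

    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("skgt,sktd->skgd", weights, v)
    return out.reshape(num_seqs, num_heads, head_dim)
```

Gathering through block_tables before masking keeps the sketch short, but it materializes max_blocks * block_size positions per sequence even when context_lens is much smaller.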

Examples

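>>> import jax.numpy as jnp
>>> from ejkernel.kernels._xla.page_attention._interface import page_attention
>>>
>>> num_seqs, num_heads, head_dim = 2, 8, 64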
>>> num_kv_heads = 8
>>> num_blocks, block_size = 10, 16
>>>
>>> query = jnp.ones((num_seqs, num_heads, head_dim))
>>> key_cache = jnp.ones((num_blocks, num_kv_heads, block_size, head_dim))
>>> value_cache = jnp.ones((num_blocks, num_kv_heads, block_size, head_dim))
>>> context_lens = jnp.array([48, 32])
>>> block_tables = jnp.array([[0, 1, 2, -1], [3, 4, -1, -1]])
>>>
>>> output = page_attention(query, key_cache, value_cache,
...                         context_lens, block_tables)
>>> output.shape
(2, 8, 64)