Source code for ejkernel.kernels._xla.page_attention._interface

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Paged attention interface for efficient KV cache management.

This module provides the public API for paged attention, where the KV cache
is organized into fixed-size blocks (pages). This enables efficient memory
management for variable-length sequences during decoding/generation.
"""

import jax.numpy as jnp
import jaxtyping
import numpy as np
from beartype import beartype
from jaxtyping import Array, Float, Int

from ..._registry import Backend, Platform, kernel_registry
from ._xla_impl_fwd import _page_attention_fwd

# Large-magnitude negative fill value for masked attention logits, derived
# from float32's largest finite value. The 0.7 factor presumably leaves
# headroom so arithmetic on masked logits does not overflow to -inf
# (NOTE(review): confirm intent).
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)


@kernel_registry.register("page_attention", Platform.XLA, Backend.ANY)
@jaxtyping.jaxtyped(typechecker=beartype)
def page_attention(
    query: Float[Array, "num_seqs num_heads head_dim"],
    key_cache: Float[Array, "num_kv_heads total_num_pages page_size head_dim"],
    value_cache: Float[Array, "num_kv_heads total_num_pages page_size head_dim"],
    context_lens: Int[Array, "num_seqs"],
    block_tables: Int[Array, "num_seqs max_blocks"],
    attn_scale: float | None = None,
    max_context_len: int | None = None,
    num_splits: int = 0,
    *,
    mask_value: float = DEFAULT_MASK_VALUE,
    attn_logits_soft_cap: float | None = None,
    pages_per_compute_block: int | None = None,
    megacore_mode: str | None = None,
    inline_seq_dim: bool = True,
    sliding_window: int | None = None,
) -> Float[Array, "num_seqs num_heads head_dim"]:
    """Paged attention for efficient KV cache management using JAX/XLA.

    The KV cache is stored as fixed-size pages. Each sequence owns a row of
    ``block_tables`` that maps its logical block indices to physical page
    indices in the cache. This enables efficient memory management for
    variable-length sequences and dynamic batching.

    Args:
        query: Query tensor of shape [num_seqs, num_heads, head_dim]. Each
            sequence has a single query token (typically decode/generation).
        key_cache: Paged key cache of shape
            [num_kv_heads, total_num_pages, page_size, head_dim], as declared
            by the type annotation. Axis 2 (``page_size``) is used as the
            block size.
        value_cache: Paged value cache with the same layout as ``key_cache``.
        context_lens: Context length for each sequence, shape [num_seqs].
            Indicates how many tokens are valid in the KV cache for each
            sequence.
        block_tables: Block table of shape [num_seqs, max_blocks]. For each
            sequence, maps logical block indices to physical page indices in
            the cache.
        attn_scale: Attention scaling factor. If None, defaults to
            ``1 / sqrt(head_dim)``.
        max_context_len: Not supported by the XLA implementation; must be None.
        num_splits: Not supported by the XLA implementation; must be 0.
        mask_value: Accepted for API compatibility with other backends; it is
            not forwarded to the XLA kernel.
        attn_logits_soft_cap: Not supported by the XLA implementation; must
            be None.
        pages_per_compute_block: Not supported by the XLA implementation;
            must be None.
        megacore_mode: Not supported by the XLA implementation; must be None.
        inline_seq_dim: Only ``True`` is supported by the XLA implementation.
        sliding_window: Not supported by the XLA implementation; must be None.

    Returns:
        Attention output of shape [num_seqs, num_heads, head_dim].

    Raises:
        NotImplementedError: If any option listed above as unsupported is set
            to a non-default value.

    Notes:
        - Supports Grouped Query Attention (GQA) where num_heads >= num_kv_heads.
        - Each sequence can use a different number of blocks based on
          context_lens.
        - Blocks are indexed via block_tables to avoid fragmentation.
        - This is a simpler version than ragged_page_attention_v2, which
          handles multiple query tokens per sequence.

    Examples:
        >>> num_seqs, num_heads, head_dim = 2, 8, 64
        >>> num_kv_heads = 8
        >>> num_blocks, block_size = 10, 16
        >>>
        >>> query = jnp.ones((num_seqs, num_heads, head_dim))
        >>> key_cache = jnp.ones((num_kv_heads, num_blocks, block_size, head_dim))
        >>> value_cache = jnp.ones((num_kv_heads, num_blocks, block_size, head_dim))
        >>> context_lens = jnp.array([48, 32])
        >>> block_tables = jnp.array([[0, 1, 2, -1], [3, 4, -1, -1]])
        >>>
        >>> output = page_attention(query, key_cache, value_cache,
        ...                         context_lens, block_tables)
        >>> output.shape
        (2, 8, 64)
    """
    # These options only exist for specialized kernels. Reject non-default
    # values explicitly instead of silently computing a different attention
    # than the caller asked for.
    if max_context_len is not None:
        raise NotImplementedError("max_context_len is not supported in XLA implementation")
    if num_splits != 0:
        raise NotImplementedError("num_splits is not supported in XLA implementation")
    if pages_per_compute_block is not None:
        raise NotImplementedError("pages_per_compute_block is not supported in XLA implementation")
    if megacore_mode is not None:
        raise NotImplementedError("megacore_mode is not supported in XLA implementation")
    if not inline_seq_dim:
        raise NotImplementedError("inline_seq_dim=False is not supported in XLA implementation")
    if attn_logits_soft_cap is not None:
        raise NotImplementedError("attn_logits_soft_cap is not supported in XLA implementation")
    # Previously accepted and silently ignored, which returned full-context
    # attention when windowed attention was requested.
    if sliding_window is not None:
        raise NotImplementedError("sliding_window is not supported in XLA implementation")

    if attn_scale is None:
        # head_dim is a static shape, so the default scale can be a plain
        # Python float. That is the same type callers pass explicitly, rather
        # than a strongly typed float32 device array (which may affect dtype
        # promotion inside the kernel; NOTE(review): confirm against
        # _page_attention_fwd).
        attn_scale = float(query.shape[-1]) ** -0.5

    # Axis 2 is page_size in the annotated cache layout.
    block_size = key_cache.shape[2]
    return _page_attention_fwd(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        context_lens=context_lens,
        block_tables=block_tables,
        attn_scale=attn_scale,
        block_size=block_size,
    )