Source code for ejkernel.kernels._xla.ragged_page_attention_v3._interface

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Ragged paged attention v3 interface for mixed prefill and decode.

This module provides the public API for the third-generation ragged paged
attention that supports mixed prefill and decode operations in a single
batch. Includes KV cache update functionality.
"""

import jaxtyping
from beartype import beartype
from jaxtyping import Array, Float, Int32

from ..._registry import Backend, Platform, kernel_registry
from ._kernel import ragged_paged_attention as _ragged_paged_attention


@kernel_registry.register("ragged_page_attention_v3", Platform.XLA, Backend.ANY)
@jaxtyping.jaxtyped(typechecker=beartype)
def ragged_page_attention_v3(
    queries: Float[Array, "total_tokens num_q_heads head_dim"],
    keys: Float[Array, "total_tokens num_kv_heads head_dim"],
    values: Float[Array, "total_tokens num_kv_heads head_dim"],
    kv_cache: Float[Array, "num_pages page_size num_kv_heads_x2_per_kv_packing kv_packing head_dim_padded"],
    kv_lens: Int32[Array, "max_num_seqs"],
    block_tables: Int32[Array, "max_num_seqs_times_pages_per_seq"],
    query_start_loc: Int32[Array, "max_num_seqs_plus_1"],
    distribution: Int32[Array, "3"],
    attention_sink: Float[Array, "num_q_heads"] | None = None,
    *,
    softmax_scale: float | None = 1.0,
    sliding_window: int | None = None,
    logits_soft_cap: float | None = None,
    q_scale: float | None = None,
    k_scale: float | None = None,
    v_scale: float | None = None,
    chunk_prefill_size: int | None = None,
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
) -> tuple[
    Float[Array, "total_tokens num_q_heads head_dim"],
    Float[Array, "num_pages page_size num_kv_heads_x2_per_kv_packing kv_packing head_dim_padded"],
]:
    """Ragged paged attention that supports mixed prefill and decode.

    Thin, runtime type-checked wrapper: it resolves ``softmax_scale`` and
    forwards every argument unchanged to the XLA implementation in
    ``._kernel.ragged_paged_attention``.

    Args:
        queries: concatenated queries of all sequences, shaped
            ``(total_tokens, num_q_heads, head_dim)``.
        keys: new keys for the same tokens, shaped
            ``(total_tokens, num_kv_heads, head_dim)``.
        values: new values for the same tokens, shaped
            ``(total_tokens, num_kv_heads, head_dim)``.
        kv_cache: paged KV cache, shaped ``(num_pages, page_size,
            num_kv_heads_x2_per_kv_packing, kv_packing, head_dim_padded)``.
        kv_lens: padded kv lengths, one entry per sequence slot. Only the
            entries for live sequences are valid.
        block_tables: flattened page table; indicates which pages of the kv
            cache each sequence uses. Only the entries for live sequences
            are valid.
        query_start_loc: cumulative sum of the effective query lengths
            (``max_num_seqs + 1`` entries). Only the entries for live
            sequences, plus one, are valid.
        distribution: three int32 values describing how the batch is split.
            NOTE(review): presumably cumulative sequence counts per
            decode/prefill/mixed phase, with the last entry being the
            number of live sequences -- confirm against ``_kernel``.
        attention_sink: optional per-query-head values, shaped
            ``(num_q_heads,)``. NOTE(review): presumably attention-sink
            logits -- confirm against ``_kernel``.
        softmax_scale: scale applied to ``Q @ K^T``. Defaults to ``1.0``;
            pass ``None`` to use ``head_dim ** -0.5``.
        sliding_window: the sliding window size for the attention.
        logits_soft_cap: the logit soft cap for the attention.
        q_scale: optional scale forwarded to the kernel. NOTE(review):
            presumably a quantization scale for queries -- confirm.
        k_scale: optional scale forwarded to the kernel. NOTE(review):
            presumably a quantization scale for keys -- confirm.
        v_scale: optional scale forwarded to the kernel. NOTE(review):
            presumably a quantization scale for values -- confirm.
        chunk_prefill_size: optional chunk size forwarded to the kernel.
            NOTE(review): presumably the chunked-prefill size -- confirm.
        num_kv_pages_per_block: number of kv pages to be processed in one
            flash attention block in the kernel.
        num_queries_per_block: number of queries to be processed in one
            flash attention block in the kernel.
        vmem_limit_bytes: the vmem limit for the kernel.

    Returns:
        A 2-tuple ``(output, kv_cache)``: the attention output, shaped like
        ``queries``, and the KV cache returned by the kernel, shaped like
        the input ``kv_cache`` (per the module docstring, the kernel also
        performs the KV cache update).
    """
    # ``None`` selects the conventional 1/sqrt(head_dim) attention scaling.
    if softmax_scale is None:
        softmax_scale = queries.shape[-1] ** -0.5

    return _ragged_paged_attention(
        queries,
        keys,
        values,
        kv_cache,
        kv_lens,
        block_tables,
        query_start_loc,
        distribution,
        attention_sink,
        softmax_scale=softmax_scale,
        sliding_window=sliding_window,
        logits_soft_cap=logits_soft_cap,
        q_scale=q_scale,
        k_scale=k_scale,
        v_scale=v_scale,
        chunk_prefill_size=chunk_prefill_size,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
        vmem_limit_bytes=vmem_limit_bytes,
    )