# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jaxtyping
from beartype import beartype
from jaxtyping import Array, Float, Int32
from ..._registry import Backend, Platform, kernel_registry
from ._triton_impl_fwd import ragged_paged_attention_triton
@kernel_registry.register("ragged_page_attention_v3", Platform.TRITON, Backend.GPU)
@jaxtyping.jaxtyped(typechecker=beartype)
def ragged_page_attention_v3(
    queries: Float[Array, "total_tokens num_q_heads head_dim"],
    keys: Float[Array, "total_tokens num_kv_heads head_dim"],
    values: Float[Array, "total_tokens num_kv_heads head_dim"],
    kv_cache: Float[Array, "num_pages page_size num_kv_heads_x2_per_kv_packing kv_packing head_dim_padded"],
    kv_lens: Int32[Array, "max_num_seqs"],
    block_tables: Int32[Array, "max_num_seqs_times_pages_per_seq"],
    query_start_loc: Int32[Array, "max_num_seqs_plus_1"],
    distribution: Int32[Array, "3"],
    attention_sink: Float[Array, "num_q_heads"] | None = None,
    *,
    softmax_scale: float | None = 1.0,
    sliding_window: int | None = None,
    logits_soft_cap: float | None = None,
    q_scale: float | None = None,
    k_scale: float | None = None,
    v_scale: float | None = None,
    chunk_prefill_size: int | None = None,
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
) -> tuple[
    Float[Array, "total_tokens num_q_heads head_dim"],
    Float[Array, "num_pages page_size num_kv_heads_x2_per_kv_packing kv_packing head_dim_padded"],
]:
    """Triton/GPU entry point for ragged paged attention (v3).

    Registered in ``kernel_registry`` under the name
    ``"ragged_page_attention_v3"`` for ``Platform.TRITON`` / ``Backend.GPU``.
    Argument shapes and dtypes are checked at call time by
    ``jaxtyping.jaxtyped`` with ``beartype``. This function is a thin wrapper:
    it resolves the ``softmax_scale`` default and forwards every argument
    unchanged to ``ragged_paged_attention_triton``.

    Args:
        queries: Float array of shape ``(total_tokens, num_q_heads, head_dim)``.
        keys: Float array of shape ``(total_tokens, num_kv_heads, head_dim)``.
        values: Float array of shape ``(total_tokens, num_kv_heads, head_dim)``.
        kv_cache: Float array of shape ``(num_pages, page_size,
            num_kv_heads_x2_per_kv_packing, kv_packing, head_dim_padded)``.
        kv_lens: Int32 array of shape ``(max_num_seqs,)``; presumably the
            per-sequence KV lengths -- confirm against the Triton implementation.
        block_tables: Flat Int32 array (dimension named
            ``max_num_seqs_times_pages_per_seq``); presumably the flattened
            page table -- confirm against the Triton implementation.
        query_start_loc: Int32 array (dimension named ``max_num_seqs_plus_1``);
            presumably cumulative query offsets per sequence -- confirm.
        distribution: Int32 array of shape ``(3,)``; its meaning is defined by
            ``ragged_paged_attention_triton``.
        attention_sink: Optional float array of shape ``(num_q_heads,)``.
        softmax_scale: Scale forwarded to the kernel. Defaults to ``1.0``.
            Passing ``None`` selects ``queries.shape[-1] ** -0.5``.
        sliding_window: Optional int, forwarded unchanged.
        logits_soft_cap: Optional float, forwarded unchanged.
        q_scale: Optional float, forwarded unchanged.
        k_scale: Optional float, forwarded unchanged.
        v_scale: Optional float, forwarded unchanged.
        chunk_prefill_size: Optional int, forwarded unchanged.
        num_kv_pages_per_block: Optional int, forwarded unchanged.
        num_queries_per_block: Optional int, forwarded unchanged.
        vmem_limit_bytes: Optional int, forwarded unchanged.

    Returns:
        A 2-tuple: a float array shaped like ``queries`` and a float array
        shaped like ``kv_cache`` (presumably the attention output and the
        updated cache -- confirm against ``ragged_paged_attention_triton``).
    """
    # The annotation admits None so this fallback is reachable under the
    # beartype check; the default stays 1.0 to keep existing callers unchanged.
    if softmax_scale is None:
        softmax_scale = queries.shape[-1] ** -0.5

    return ragged_paged_attention_triton(
        queries,
        keys,
        values,
        kv_cache,
        kv_lens,
        block_tables,
        query_start_loc,
        distribution,
        attention_sink,
        softmax_scale=softmax_scale,
        sliding_window=sliding_window,
        logits_soft_cap=logits_soft_cap,
        q_scale=q_scale,
        k_scale=k_scale,
        v_scale=v_scale,
        chunk_prefill_size=chunk_prefill_size,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
        vmem_limit_bytes=vmem_limit_bytes,
    )