# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unified attention interface for paged KV cache with mixed workloads.
This module provides the public API for unified attention that handles
ragged batches with paged key-value caches. Supports sliding window,
ALiBi slopes, and attention sink features.
"""
from __future__ import annotations
import jaxtyping
from beartype import beartype
from jaxtyping import Array, Float, Int32
from ..._registry import Backend, Platform, kernel_registry
from ._xla_impl_fwd import _unified_attention_fwd
@kernel_registry.register("unified_attention", Platform.XLA, Backend.ANY)
@jaxtyping.jaxtyped(typechecker=beartype)
def unified_attention(
    queries: Float[Array, "total_tokens num_q_heads head_dim"],
    key_cache: Float[Array, "num_blocks block_size num_kv_heads head_dim"],
    value_cache: Float[Array, "num_blocks block_size num_kv_heads head_dim"],
    kv_lens: Int32[Array, "num_seqs"],
    block_tables: Int32[Array, "num_seqs max_blocks_per_seq"],
    query_start_loc: Int32[Array, "num_seqs_plus_1"],
    *,
    softmax_scale: float | None = None,
    causal: bool = True,
    sliding_window: int | None = None,
    logits_soft_cap: float | None = None,
    seq_threshold_3d: int | None = None,
    num_par_softmax_segments: int | None = None,
    alibi_slopes: Float[Array, "num_q_heads"] | None = None,
    qq_bias: Float[Array, "num_query_tokens num_query_tokens"] | None = None,
    attention_sink: Float[Array, "num_q_heads"] | None = None,
    num_warps: int | None = None,
    num_stages: int | None = None,
) -> Float[Array, "total_tokens num_q_heads head_dim"]:
    """Run unified paged-KV attention through the XLA forward implementation.

    Registered in ``kernel_registry`` as ``"unified_attention"`` for
    ``Platform.XLA`` / ``Backend.ANY``. Argument shapes and dtypes are checked
    at call time by ``jaxtyping`` with ``beartype``. After validating the
    options this function delegates to ``_unified_attention_fwd``.

    Args:
        queries: Query tensor, ``(total_tokens, num_q_heads, head_dim)``.
        key_cache: Paged key cache,
            ``(num_blocks, block_size, num_kv_heads, head_dim)``.
        value_cache: Paged value cache, same layout as ``key_cache``.
        kv_lens: Per-sequence int32 vector of length ``num_seqs``
            (presumably the number of cached KV tokens per sequence --
            confirm against ``_unified_attention_fwd``).
        block_tables: Int32 table ``(num_seqs, max_blocks_per_seq)``
            (presumably mapping each sequence to its cache blocks -- confirm
            against ``_unified_attention_fwd``).
        query_start_loc: Int32 vector of length ``num_seqs + 1``
            (presumably cumulative query offsets into ``queries`` -- confirm
            against ``_unified_attention_fwd``).
        softmax_scale: Scale forwarded to the forward implementation. When
            ``None`` it defaults to ``head_dim ** -0.5``, taken from the last
            axis of ``queries``.
        causal: Must be ``True``; the XLA path rejects non-causal attention.
        sliding_window: Optional sliding-window size, forwarded unchanged.
        logits_soft_cap: Optional logits soft cap, forwarded unchanged.
        seq_threshold_3d: Accepted but ignored by this implementation.
        num_par_softmax_segments: Accepted but ignored by this implementation.
        alibi_slopes: Optional per-query-head ALiBi slopes, forwarded
            unchanged.
        qq_bias: Optional ``(num_query_tokens, num_query_tokens)`` bias,
            forwarded unchanged.
        attention_sink: Optional per-query-head attention-sink values,
            forwarded unchanged.
        num_warps: Accepted but ignored by this implementation.
        num_stages: Accepted but ignored by this implementation.

    Returns:
        The result of ``_unified_attention_fwd``, annotated as
        ``(total_tokens, num_q_heads, head_dim)``.

    Raises:
        NotImplementedError: If ``causal`` is ``False``.
    """
    # These tuning knobs are part of the shared signature but are not consumed
    # on this path; drop them explicitly so the intent is obvious.
    del seq_threshold_3d, num_par_softmax_segments, num_warps, num_stages

    if not causal:
        raise NotImplementedError("unified_attention (XLA) only supports causal attention.")

    if softmax_scale is None:
        # Standard 1/sqrt(head_dim) scaling derived from the query's last axis.
        softmax_scale = queries.shape[-1] ** -0.5

    return _unified_attention_fwd(
        queries=queries,
        key_cache=key_cache,
        value_cache=value_cache,
        kv_lens=kv_lens,
        block_tables=block_tables,
        query_start_loc=query_start_loc,
        softmax_scale=float(softmax_scale),
        sliding_window=sliding_window,
        logits_soft_cap=logits_soft_cap,
        alibi_slopes=alibi_slopes,
        qq_bias=qq_bias,
        attention_sink=attention_sink,
    )