Source code for ejkernel.kernels._xla.unified_attention._interface

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unified attention interface for paged KV cache with mixed workloads.

This module provides the public XLA API for unified attention over ragged
batches backed by paged key-value caches. It supports causal attention with
optional sliding windows, ALiBi slopes, logit soft-capping, and attention sinks.
"""

from __future__ import annotations

import jaxtyping
from beartype import beartype
from jaxtyping import Array, Float, Int32

from ..._registry import Backend, Platform, kernel_registry
from ._xla_impl_fwd import _unified_attention_fwd


@kernel_registry.register("unified_attention", Platform.XLA, Backend.ANY)
@jaxtyping.jaxtyped(typechecker=beartype)
def unified_attention(
    queries: Float[Array, "total_tokens num_q_heads head_dim"],
    key_cache: Float[Array, "num_blocks block_size num_kv_heads head_dim"],
    value_cache: Float[Array, "num_blocks block_size num_kv_heads head_dim"],
    kv_lens: Int32[Array, "num_seqs"],
    block_tables: Int32[Array, "num_seqs max_blocks_per_seq"],
    query_start_loc: Int32[Array, "num_seqs_plus_1"],
    *,
    softmax_scale: float | None = None,
    causal: bool = True,
    sliding_window: int | None = None,
    logits_soft_cap: float | None = None,
    seq_threshold_3d: int | None = None,
    num_par_softmax_segments: int | None = None,
    alibi_slopes: Float[Array, "num_q_heads"] | None = None,
    qq_bias: Float[Array, "num_query_tokens num_query_tokens"] | None = None,
    attention_sink: Float[Array, "num_q_heads"] | None = None,
    num_warps: int | None = None,
    num_stages: int | None = None,
) -> Float[Array, "total_tokens num_q_heads head_dim"]:
    """Compute causal attention for a ragged batch against a paged KV cache (XLA).

    Registered as the ``unified_attention`` kernel for ``Platform.XLA`` /
    ``Backend.ANY``. Argument shapes are checked at call time by jaxtyping
    with beartype. The actual computation is delegated to
    ``_unified_attention_fwd``.

    Args:
        queries: Query tokens of all sequences packed along the first axis,
            shape ``(total_tokens, num_q_heads, head_dim)``.
        key_cache: Paged key cache, shape
            ``(num_blocks, block_size, num_kv_heads, head_dim)``.
        value_cache: Paged value cache with the same shape as ``key_cache``.
        kv_lens: Per-sequence key/value lengths, shape ``(num_seqs,)``.
        block_tables: Per-sequence cache block indices, shape
            ``(num_seqs, max_blocks_per_seq)``.
        query_start_loc: Start offsets of each sequence's queries along the
            token axis of ``queries``. Must hold exactly ``num_seqs + 1``
            entries (presumably the last one is ``total_tokens`` — confirm
            against ``_unified_attention_fwd``).
        softmax_scale: Scale applied to the attention logits. Defaults to
            ``head_dim ** -0.5``.
        causal: Must be ``True``; non-causal attention is not implemented here.
        sliding_window: Optional sliding-window size, forwarded unchanged.
        logits_soft_cap: Optional logit soft-cap value, forwarded unchanged.
        seq_threshold_3d: Ignored by this implementation.
        num_par_softmax_segments: Ignored by this implementation.
        alibi_slopes: Optional per-query-head ALiBi slopes, shape
            ``(num_q_heads,)``.
        qq_bias: Optional query-query bias, shape
            ``(num_query_tokens, num_query_tokens)``, forwarded unchanged.
        attention_sink: Optional per-query-head attention-sink values, shape
            ``(num_q_heads,)``.
        num_warps: Ignored by this implementation.
        num_stages: Ignored by this implementation.

    Returns:
        Attention output with shape ``(total_tokens, num_q_heads, head_dim)``.

    Raises:
        NotImplementedError: If ``causal`` is ``False``.
        ValueError: If ``query_start_loc`` does not have ``num_seqs + 1``
            entries, where ``num_seqs`` is the length of ``kv_lens``.
    """
    # Tuning knobs this XLA path does not use (presumably kept for signature
    # parity with other registered backends).
    del seq_threshold_3d, num_par_softmax_segments, num_warps, num_stages

    if not causal:
        raise NotImplementedError("unified_attention (XLA) only supports causal attention.")

    # The jaxtyping symbol "num_seqs_plus_1" is unrelated to "num_seqs", so the
    # +1 relationship is not enforced by the annotations; check it explicitly.
    # Shapes are static, so this is safe under tracing.
    num_seqs = kv_lens.shape[0]
    num_offsets = query_start_loc.shape[0]
    if num_offsets != num_seqs + 1:
        raise ValueError(
            "query_start_loc must have num_seqs + 1 entries "
            f"(num_seqs={num_seqs} from kv_lens), got {num_offsets}."
        )

    if softmax_scale is None:
        softmax_scale = queries.shape[-1] ** -0.5

    return _unified_attention_fwd(
        queries=queries,
        key_cache=key_cache,
        value_cache=value_cache,
        kv_lens=kv_lens,
        block_tables=block_tables,
        query_start_loc=query_start_loc,
        softmax_scale=float(softmax_scale),
        sliding_window=sliding_window,
        logits_soft_cap=logits_soft_cap,
        alibi_slopes=alibi_slopes,
        qq_bias=qq_bias,
        attention_sink=attention_sink,
    )