# Module: ejkernel.kernels._triton.unified_attention._interface

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""vLLM-style unified (paged) attention implemented in Triton.

This is a JAX/Triton port of vLLM's `triton_unified_attention.py`, adapted to
ejkernel's `triton_call` interface.

Core inputs:
- `queries`: packed ragged queries, shape `[total_tokens, num_q_heads, head_dim]`
- `key_cache`/`value_cache`: paged KV cache, shape `[num_blocks, block_size, num_kv_heads, head_dim]`
- `query_start_loc`: cumulative query offsets, shape `[num_seqs + 1]` (int32)
- `kv_lens`: KV lengths per sequence, shape `[num_seqs]` (int32)
- `block_tables`: mapping `[num_seqs, max_blocks_per_seq]` (int32)

Supported features (inference-only):
- causal masking (required)
- optional sliding window via `sliding_window` (window length)
- optional logit softcap (`logits_soft_cap`)
- optional attention sink (`attention_sink`): contributes to softmax normalizer only
- optional ALiBi slopes (`alibi_slopes`)
- optional query-query bias (`qq_bias`) for TreeAttention-like decode
"""

from __future__ import annotations

import jaxtyping
from beartype import beartype
from jaxtyping import Array, Float, Int32

from ..._registry import Backend, Platform, kernel_registry
from ._triton_impl_fwd import unified_attention_triton


@kernel_registry.register("unified_attention", Platform.TRITON, Backend.GPU)
@jaxtyping.jaxtyped(typechecker=beartype)
def unified_attention(
    queries: Float[Array, "total_tokens num_q_heads head_dim"],
    key_cache: Float[Array, "num_blocks block_size num_kv_heads head_dim"],
    value_cache: Float[Array, "num_blocks block_size num_kv_heads head_dim"],
    kv_lens: Int32[Array, "num_seqs"],
    block_tables: Int32[Array, "num_seqs max_blocks_per_seq"],
    query_start_loc: Int32[Array, "num_seqs_plus_1"],
    *,
    softmax_scale: float | None = None,
    causal: bool = True,
    sliding_window: int | None = None,
    logits_soft_cap: float | None = None,
    seq_threshold_3d: int | None = None,
    num_par_softmax_segments: int | None = None,
    alibi_slopes: Float[Array, "num_q_heads"] | None = None,
    qq_bias: Float[Array, "num_query_tokens num_query_tokens"] | None = None,
    attention_sink: Float[Array, "num_q_heads"] | None = None,
    num_warps: int | None = None,
    num_stages: int | None = None,
) -> Float[Array, "total_tokens num_q_heads head_dim"]:
    """Run vLLM-style unified (paged) attention through the Triton backend.

    Registered in ``kernel_registry`` under ``"unified_attention"`` for
    ``Platform.TRITON`` / ``Backend.GPU``. Array shapes are checked at call
    time by jaxtyping + beartype against the annotations below. All
    arguments are forwarded unchanged, by keyword, to
    ``unified_attention_triton``.

    Args:
        queries: Packed ragged queries, ``[total_tokens, num_q_heads, head_dim]``.
        key_cache: Paged key cache,
            ``[num_blocks, block_size, num_kv_heads, head_dim]``.
        value_cache: Paged value cache, same shape as ``key_cache``.
        kv_lens: KV length of each sequence, ``[num_seqs]`` (int32).
        block_tables: Per-sequence mapping into the paged cache,
            ``[num_seqs, max_blocks_per_seq]`` (int32).
        query_start_loc: Cumulative query offsets into ``queries``,
            ``[num_seqs + 1]`` (int32).
        softmax_scale: Scale applied to the attention logits. ``None`` is
            forwarded as-is; presumably the implementation then picks a
            default such as ``1/sqrt(head_dim)``. TODO confirm.
        causal: Causal-masking flag, forwarded to the implementation. The
            module docs describe causal masking as required.
        sliding_window: Optional sliding-window length.
        logits_soft_cap: Optional soft cap applied to the attention logits.
        seq_threshold_3d: Tuning knob forwarded to the implementation.
            NOTE(review): appears to choose between kernel variants; verify
            in ``unified_attention_triton``.
        num_par_softmax_segments: Tuning knob forwarded to the
            implementation. NOTE(review): appears to set how many parallel
            softmax segments are used; verify.
        alibi_slopes: Optional per-query-head ALiBi slopes, ``[num_q_heads]``.
        qq_bias: Optional query-query bias,
            ``[num_query_tokens, num_query_tokens]``, for TreeAttention-like
            decode.
        attention_sink: Optional per-query-head attention-sink logits,
            ``[num_q_heads]``. Per the module docs, these contribute only to
            the softmax normalizer.
        num_warps: Triton launch parameter. ``None`` leaves the choice to
            the implementation.
        num_stages: Triton launch parameter. ``None`` leaves the choice to
            the implementation.

    Returns:
        Attention output with the same shape as ``queries``,
        ``[total_tokens, num_q_heads, head_dim]``.
    """
    # Pure pass-through: all validation, defaulting and kernel selection live
    # in the Triton implementation.
    return unified_attention_triton(
        queries=queries,
        key_cache=key_cache,
        value_cache=value_cache,
        block_tables=block_tables,
        kv_lens=kv_lens,
        query_start_loc=query_start_loc,
        softmax_scale=softmax_scale,
        causal=causal,
        sliding_window=sliding_window,
        logits_soft_cap=logits_soft_cap,
        seq_threshold_3d=seq_threshold_3d,
        num_par_softmax_segments=num_par_softmax_segments,
        alibi_slopes=alibi_slopes,
        qq_bias=qq_bias,
        attention_sink=attention_sink,
        num_warps=num_warps,
        num_stages=num_stages,
    )