Source code for ejkernel.kernels._xla.ragged_decode_attention._interface

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Ragged decode attention interface for variable-length decoding.

This module provides the public API for attention during decoding with
variable-length sequences. Supports MQA/GQA configurations with sliding
window and attention sink capabilities.
"""

import jaxtyping
from beartype import beartype
from jaxtyping import Array, Float, Int

from ejkernel.ops import FwdParams

from ..._registry import Backend, Platform, kernel_registry
from ._xla_impl_fwd import ragged_decode_mqa_xla


@kernel_registry.register("ragged_decode_attention", Platform.XLA, Backend.ANY)
@jaxtyping.jaxtyped(typechecker=beartype)
def ragged_decode_attention(
    query: Float[Array, "batch num_q_heads head_dim"],
    key: Float[Array, "batch seq_len num_kv_heads head_dim"],
    value: Float[Array, "batch seq_len num_kv_heads head_dim"],
    sequence_start: Int[Array, "batch"],
    sequence_end: Int[Array, "batch"],
    softmax_scale: float | None = None,
    fwd_params: FwdParams | None = None,
    sliding_window: tuple[int, int] | None = None,
    logits_soft_cap: float | None = None,
    softmax_aux: Float[Array, "num_sinks"] | None = None,
) -> Float[Array, "batch num_q_heads head_dim"]:
    """Ragged MQA/GQA decoding with standard XLA operations.

    This function implements ragged attention for decoding scenarios where
    different sequences in a batch have different lengths. It supports
    Multi-Query Attention (MQA) and Grouped-Query Attention (GQA).

    The function is a thin, type-checked entry point: every argument is
    forwarded unchanged (by keyword) to ``ragged_decode_mqa_xla``. Argument
    shapes and Python types are validated at call time by
    ``jaxtyping.jaxtyped`` with ``beartype``.

    Args:
        query: Query tensor of shape [batch, num_q_heads, head_dim].
            Represents the current decoding position (single token per
            sequence).
        key: Key tensor of shape [batch, seq_len, num_kv_heads, head_dim].
            Contains all previous tokens in the KV cache.
        value: Value tensor of shape [batch, seq_len, num_kv_heads, head_dim].
            Contains all previous token values.
        sequence_start: Integer array of shape [batch]. Start indices for
            each sequence in the batch.
        sequence_end: Integer array of shape [batch]. End indices (exclusive)
            for each sequence in the batch.
        softmax_scale: Optional scale for attention scores, as a Python
            float. If None, uses 1/sqrt(head_dim).
        fwd_params: Optional forward-pass parameters, forwarded unchanged to
            the XLA implementation.
        sliding_window: Optional (left, right) window sizes for local
            attention. Limits attention to tokens within the window around
            the query position. None means full attention to all valid
            tokens.
        logits_soft_cap: Optional soft capping value for attention logits.
            Applies tanh-based soft capping:
            logits_soft_cap * tanh(logits / logits_soft_cap).
            This prevents attention scores from becoming too large.
        softmax_aux: Optional auxiliary logits for attention sinks, of shape
            [num_sinks]. Concatenated to attention logits before softmax to
            create attention sink behavior (e.g., always attending to initial
            tokens regardless of their position). The annotation on this
            entry point admits rank-1 arrays only; a per-head
            [num_heads, num_sinks] layout is rejected by the runtime type
            check.

    Returns:
        Output tensor of shape [batch, num_q_heads, head_dim] after attention.

    Examples:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> batch, seq_len, num_heads, head_dim = 2, 512, 8, 64
        >>>
        >>> sequence_start = jnp.array([0, 0], dtype=jnp.int32)
        >>> sequence_end = jnp.array([384, 512], dtype=jnp.int32)
        >>>
        >>> query = jax.random.normal(jax.random.key(0), (batch, num_heads, head_dim))
        >>> key = jax.random.normal(jax.random.key(1), (batch, seq_len, num_heads, head_dim))
        >>> value = jax.random.normal(jax.random.key(2), (batch, seq_len, num_heads, head_dim))
        >>>
        >>> # softmax_scale must be a Python float, not a JAX array.
        >>> output = ragged_decode_attention(
        ...     query, key, value,
        ...     sequence_start, sequence_end,
        ...     softmax_scale=head_dim**-0.5,
        ... )
        >>>
        >>> sinks = jnp.full((4,), 5.0)
        >>> output = ragged_decode_attention(
        ...     query, key, value,
        ...     sequence_start, sequence_end,
        ...     softmax_scale=head_dim**-0.5,
        ...     sliding_window=(256, 256),
        ...     logits_soft_cap=30.0,
        ...     softmax_aux=sinks,
        ... )

    Notes:
        - This is a pure XLA/JAX implementation suitable for CPU/GPU/TPU.
        - It is registered in the kernel registry under the name
          "ragged_decode_attention" for ``Platform.XLA`` / ``Backend.ANY``;
          for a TPU kernel with Pallas optimization, use the Pallas
          registration of the same operation instead (if available).
        - Supports both MQA (num_kv_heads=1) and GQA
          (num_kv_heads < num_q_heads).
        - Query position is assumed to be at sequence_end - 1 (current
          decode position).
    """
    # Pure delegation: keyword-forward everything so the call stays correct
    # even if the implementation's positional order differs from this one.
    return ragged_decode_mqa_xla(
        query=query,
        key=key,
        value=value,
        sequence_start=sequence_start,
        sequence_end=sequence_end,
        softmax_scale=softmax_scale,
        fwd_params=fwd_params,
        sliding_window=sliding_window,
        logits_soft_cap=logits_soft_cap,
        softmax_aux=softmax_aux,
    )