ejkernel.kernels._xla.ragged_decode_attention._interface
Ragged decode attention interface for variable-length decoding.
This module provides the public API for attention during decoding with variable-length sequences. It supports MQA/GQA configurations, with optional sliding-window attention and attention sinks.
- ejkernel.kernels._xla.ragged_decode_attention._interface.ragged_decode_attention(query: Float[jaxlib._jax.Array, 'batch num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None) → Float[jaxlib._jax.Array, 'batch num_q_heads head_dim']
Ragged MQA/GQA decoding with standard XLA operations.
This function implements ragged attention for decoding scenarios where different sequences in a batch have different lengths. It supports Multi-Query Attention (MQA) and Grouped-Query Attention (GQA).
- Parameters
query – Query tensor of shape [batch, num_heads, head_dim], where num_heads is the number of query heads (num_q_heads in the signature). Represents the current decoding position (a single token per sequence).
key – Key tensor of shape [batch, seq_len, num_kv_heads, head_dim]. Contains all previous tokens in the KV cache.
value – Value tensor of shape [batch, seq_len, num_kv_heads, head_dim]. Contains all previous token values.
sequence_start – int32 array of shape [batch]. Start indices for each sequence in the batch.
sequence_end – int32 array of shape [batch]. End indices (exclusive) for each sequence in the batch.
softmax_scale – Optional scale for attention scores. If None, uses 1/sqrt(head_dim).
sliding_window – Optional (left, right) window sizes for local attention. Limits attention to tokens within the window around the query position. None means full attention to all valid tokens.
logits_soft_cap – Optional soft capping value for attention logits. Applies tanh-based soft capping: logits_soft_cap * tanh(logits / logits_soft_cap). This prevents attention scores from becoming too large.
softmax_aux – Optional auxiliary logits for attention sinks. Shape [num_heads, num_sinks] or [num_sinks]. Concatenated to attention logits before softmax to create attention sink behavior (e.g., always attending to initial tokens regardless of their position).
- Returns
Output tensor of shape [batch, num_heads, head_dim] after attention.
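As a rough illustration of how the logits_soft_cap and softmax_aux parameters described above act on attention logits, the sketch below applies tanh soft capping and appends sink logits before the softmax. The helper names (soft_cap, softmax_with_sinks) are hypothetical and not part of ejkernel. Dropping the sink columns after normalization is an assumption about how sinks are commonly handled, not a confirmed detail of this kernel.

```python
import jax
import jax.numpy as jnp


def soft_cap(logits, cap):
    # tanh-based soft capping: cap * tanh(logits / cap), bounded by +/- cap.
    return cap * jnp.tanh(logits / cap)


def softmax_with_sinks(logits, sinks):
    # logits: [num_heads, seq_len]; sinks: [num_heads, num_sinks] (assumed layout).
    combined = jnp.concatenate([logits, sinks], axis=-1)
    probs = jax.nn.softmax(combined, axis=-1)
    # Assumption: sinks absorb probability mass but carry no value vector,
    # so only the columns belonging to real tokens are kept.
    return probs[..., : logits.shape[-1]]


# Large logits are squashed into the range [-30, 30].
capped = soft_cap(jnp.array([100.0, -100.0]), 30.0)

# Three equal token logits plus one equal sink logit: each gets weight 1/4,
# so the token weights sum to 0.75 instead of 1.
weights = softmax_with_sinks(jnp.zeros((1, 3)), jnp.zeros((1, 1)))
```

With sinks present, the weights over real tokens no longer sum to one, which is what lets a head attend to "nothing" when no cached token is relevant.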
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> from ejkernel.kernels._xla.ragged_decode_attention._interface import ragged_decode_attention
>>>
>>> batch, seq_len, num_heads, head_dim = 2, 512, 8, 64
>>>
>>> # Valid KV-cache range for each sequence in the batch
>>> sequence_start = jnp.array([0, 0], dtype=jnp.int32)
>>> sequence_end = jnp.array([384, 512], dtype=jnp.int32)
>>>
>>> query = jax.random.normal(jax.random.key(0), (batch, num_heads, head_dim))
>>> key = jax.random.normal(jax.random.key(1), (batch, seq_len, num_heads, head_dim))
>>> value = jax.random.normal(jax.random.key(2), (batch, seq_len, num_heads, head_dim))
>>>
>>> # Basic ragged decode attention
>>> output = ragged_decode_attention(
...     query, key, value,
...     sequence_start, sequence_end,
...     softmax_scale=1.0 / head_dim ** 0.5,
... )
>>>
>>> # With sliding window, logit soft capping, and attention sinks
>>> sinks = jnp.ones((num_heads, 4)) * 5.0
>>> output = ragged_decode_attention(
...     query, key, value,
...     sequence_start, sequence_end,
...     softmax_scale=1.0 / head_dim ** 0.5,
...     sliding_window=(256, 256),
...     logits_soft_cap=30.0,
...     softmax_aux=sinks,
... )
Notes
- This is a pure XLA/JAX implementation suitable for CPU, GPU, and TPU.
- For TPU with Pallas optimization, use the Pallas-based ragged_decode_attention kernel instead.
- Supports both MQA (num_kv_heads = 1) and GQA (num_kv_heads < num_heads).
- The query position is assumed to be sequence_end - 1 (the current decode position).
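The behavior described in these notes can be sketched as a small pure-JAX reference. This is an illustration under stated assumptions, not the library's implementation: the function name ragged_decode_reference is hypothetical, consecutive query heads are assumed to share one KV head, the sliding window is assumed to be inclusive on both sides of the query position, and soft capping and sinks are omitted.

```python
import jax
import jax.numpy as jnp


def ragged_decode_reference(query, key, value, sequence_start, sequence_end,
                            softmax_scale=None, sliding_window=None):
    # query: [batch, num_q_heads, head_dim]
    # key, value: [batch, seq_len, num_kv_heads, head_dim]
    batch, num_q_heads, head_dim = query.shape
    seq_len, num_kv_heads = key.shape[1], key.shape[2]
    if softmax_scale is None:
        softmax_scale = head_dim ** -0.5  # default 1/sqrt(head_dim)
    group = num_q_heads // num_kv_heads
    # Assumed GQA layout: each KV head serves `group` consecutive query heads.
    q = query.reshape(batch, num_kv_heads, group, head_dim)
    logits = jnp.einsum("bkgd,bskd->bkgs", q, key) * softmax_scale

    # Ragged mask: keep positions in [sequence_start, sequence_end).
    pos = jnp.arange(seq_len)[None, :]
    mask = (pos >= sequence_start[:, None]) & (pos < sequence_end[:, None])
    if sliding_window is not None:
        left, right = sliding_window
        q_pos = sequence_end[:, None] - 1  # current decode position
        # Assumption: window bounds are inclusive.
        mask = mask & (pos >= q_pos - left) & (pos <= q_pos + right)

    # A large finite negative avoids NaNs if a row is fully masked.
    logits = jnp.where(mask[:, None, None, :], logits, -1e30)
    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bkgs,bskd->bkgd", weights, value)
    return out.reshape(batch, num_q_heads, head_dim)


# GQA demo: 4 query heads share 2 KV heads; sequences have lengths 5 and 8.
batch, seq_len, num_q_heads, num_kv_heads, head_dim = 2, 8, 4, 2, 16
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch, num_q_heads, head_dim))
k = jax.random.normal(kk, (batch, seq_len, num_kv_heads, head_dim))
v = jax.random.normal(kv, (batch, seq_len, num_kv_heads, head_dim))
start = jnp.array([0, 0], dtype=jnp.int32)
end = jnp.array([5, 8], dtype=jnp.int32)
out = ragged_decode_reference(q, k, v, start, end)
```

A reference like this is mainly useful for checking that cache entries outside [sequence_start, sequence_end) have no effect on the output.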