ejkernel.kernels._xla.ragged_decode_attention._interface

Ragged decode attention interface for variable-length decoding.

This module provides the public API for attention during decoding with variable-length sequences. Supports MQA/GQA configurations with sliding window and attention sink capabilities.

ejkernel.kernels._xla.ragged_decode_attention._interface.ragged_decode_attention(query: Float[jaxlib._jax.Array, 'batch num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: Float[jaxlib._jax.Array, 'num_sinks'] | None = None) → Float[jaxlib._jax.Array, 'batch num_q_heads head_dim']

Ragged MQA/GQA decoding with standard XLA operations.

This function implements ragged attention for decoding scenarios where different sequences in a batch have different lengths. It supports Multi-Query Attention (MQA) and Grouped-Query Attention (GQA).
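As a rough illustration of what "ragged" means here, the sketch below computes decode attention with plain jax.numpy operations. It is not the ejkernel implementation: the name ragged_decode_reference is hypothetical, and two assumptions are made that the text above does not state, namely that sequence_start is inclusive and that consecutive query heads share one KV head.

```python
import jax
import jax.numpy as jnp


def ragged_decode_reference(query, key, value, sequence_start, sequence_end,
                            softmax_scale=None):
    """Hypothetical reference sketch; not the ejkernel implementation."""
    batch, num_q_heads, head_dim = query.shape
    _, seq_len, num_kv_heads, _ = key.shape
    group = num_q_heads // num_kv_heads  # query heads per KV head (1 for MHA)
    if softmax_scale is None:
        softmax_scale = head_dim ** -0.5

    # Assumption: consecutive query heads share a KV head.
    q = query.reshape(batch, num_kv_heads, group, head_dim)
    logits = jnp.einsum("bkgd,bskd->bkgs", q, key) * softmax_scale

    # Keep only cache positions in [sequence_start, sequence_end) per batch row.
    pos = jnp.arange(seq_len)[None, :]
    valid = (pos >= sequence_start[:, None]) & (pos < sequence_end[:, None])
    masked = jnp.where(valid[:, None, None, :], logits,
                       jnp.finfo(logits.dtype).min)

    weights = jax.nn.softmax(masked, axis=-1)
    out = jnp.einsum("bkgs,bskd->bkgd", weights, value)
    return out.reshape(batch, num_q_heads, head_dim)


# Small GQA case: 4 query heads share 2 KV heads; rows have different lengths.
k_q, k_k, k_v = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(k_q, (2, 4, 16))
key = jax.random.normal(k_k, (2, 8, 2, 16))
value = jax.random.normal(k_v, (2, 8, 2, 16))
sequence_start = jnp.array([0, 2], dtype=jnp.int32)
sequence_end = jnp.array([5, 8], dtype=jnp.int32)
out = ragged_decode_reference(query, key, value, sequence_start, sequence_end)
```

Masking with the most negative finite value instead of -inf keeps the softmax finite even if a row has an empty range, at the cost of returning a uniform average over the cache in that degenerate case.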

Parameters
  • query – Query tensor of shape [batch, num_heads, head_dim], where num_heads is the number of query heads (num_q_heads in the signature). Represents the current decoding position (single token per sequence).

  • key – Key tensor of shape [batch, seq_len, num_kv_heads, head_dim]. Contains all previous tokens in the KV cache.

  • value – Value tensor of shape [batch, seq_len, num_kv_heads, head_dim]. Contains all previous token values.

  • sequence_start – int32 array of shape [batch]. Start indices for each sequence in the batch.

  • sequence_end – int32 array of shape [batch]. End indices (exclusive) for each sequence in the batch.

  • softmax_scale – Optional scale for attention scores. If None, uses 1/sqrt(head_dim).

  • sliding_window – Optional (left, right) window sizes for local attention. Limits attention to tokens within the window around the query position. None means full attention to all valid tokens.

  • logits_soft_cap – Optional soft capping value for attention logits. Applies tanh-based soft capping: logits_soft_cap * tanh(logits / logits_soft_cap). This prevents attention scores from becoming too large.

  • softmax_aux – Optional auxiliary logits for attention sinks. Shape [num_heads, num_sinks] or [num_sinks]. Concatenated to attention logits before softmax to create attention sink behavior (e.g., always attending to initial tokens regardless of their position).

Returns

Output tensor of shape [batch, num_heads, head_dim] after attention.

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> from ejkernel.kernels._xla.ragged_decode_attention._interface import ragged_decode_attention
>>>
>>> batch, seq_len, num_heads, head_dim = 2, 512, 8, 64
>>>
>>> # Valid KV-cache range for each sequence; the end index is exclusive
>>> sequence_start = jnp.array([0, 0], dtype=jnp.int32)
>>> sequence_end = jnp.array([384, 512], dtype=jnp.int32)
>>>
>>> # One query token per sequence; here num_kv_heads == num_heads
>>> query = jax.random.normal(jax.random.key(0), (batch, num_heads, head_dim))
>>> key = jax.random.normal(jax.random.key(1), (batch, seq_len, num_heads, head_dim))
>>> value = jax.random.normal(jax.random.key(2), (batch, seq_len, num_heads, head_dim))
>>>
>>> # Basic ragged decoding; softmax_scale is passed as a Python float
>>> output = ragged_decode_attention(
...     query, key, value,
...     sequence_start, sequence_end,
...     softmax_scale=head_dim ** -0.5
... )
>>>
>>> # With a sliding window, logit soft capping, and attention sinks
>>> sinks = jnp.ones((num_heads, 4)) * 5.0
>>> output = ragged_decode_attention(
...     query, key, value,
...     sequence_start, sequence_end,
...     softmax_scale=head_dim ** -0.5,
...     sliding_window=(256, 256),
...     logits_soft_cap=30.0,
...     softmax_aux=sinks
... )

Notes

  • This is a pure XLA/JAX implementation suitable for CPU/GPU/TPU

  • For TPU with Pallas optimization, use the Pallas implementation of ragged_decode_attention instead of this XLA one

  • Supports both MQA (num_kv_heads=1) and GQA (num_kv_heads < num_heads)

  • Query position is assumed to be at sequence_end - 1 (current decode position)
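To make the last note concrete, the sketch below builds the boolean mask of attended positions when the query sits at sequence_end - 1 and a (left, right) sliding window is applied. The function decode_window_mask is a hypothetical helper, and treating both window bounds as inclusive is an assumption; the actual kernel may count the window differently.

```python
import jax.numpy as jnp


def decode_window_mask(sequence_start, sequence_end, seq_len, sliding_window=None):
    """Return a [batch, seq_len] boolean mask of attended cache positions."""
    pos = jnp.arange(seq_len)[None, :]
    # Ragged part: only positions in [sequence_start, sequence_end) are valid.
    valid = (pos >= sequence_start[:, None]) & (pos < sequence_end[:, None])
    if sliding_window is not None:
        left, right = sliding_window
        # The query is assumed to sit at the last valid position of each row.
        q_pos = (sequence_end - 1)[:, None]
        # Assumption: both window bounds are inclusive.
        in_window = (pos >= q_pos - left) & (pos <= q_pos + right)
        valid = valid & in_window
    return valid


sequence_start = jnp.array([0, 0], dtype=jnp.int32)
sequence_end = jnp.array([384, 512], dtype=jnp.int32)
full_mask = decode_window_mask(sequence_start, sequence_end, 512)
window_mask = decode_window_mask(sequence_start, sequence_end, 512, (256, 256))
```

During decoding there are no valid tokens to the right of the query, so under these assumptions only the left half of the window has any effect.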