ejkernel.kernels._pallas.gpu.ragged_decode_attention._interface
- ejkernel.kernels._pallas.gpu.ragged_decode_attention._interface.ragged_decode_attention(query: Float[jaxlib._jax.Array, 'batch num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None) → Float[jaxlib._jax.Array, 'batch num_q_heads head_dim']
Performs attention decoding over ragged sequences using a GPU-optimized kernel.
This function is the public API for decode-time attention over variable-length (ragged) sequences using head-blocked GPU kernels. It supports multi-head attention (MHA), multi-query attention (MQA), and grouped-query attention (GQA) layouts.
- Parameters
query (chex.Array) – Query tensor of shape (batch_size, num_query_heads, head_dim).
key (chex.Array) – Key tensor of shape (batch_size, sequence_length, num_kv_heads, head_dim).
value (chex.Array) – Value tensor of shape (batch_size, sequence_length, num_kv_heads, head_dim).
sequence_start (chex.Array) – Start indices of the valid sequence range for each batch element, shape (batch_size,).
sequence_end (chex.Array) – End indices of the valid sequence range for each batch element, shape (batch_size,).
softmax_scale (float, optional) – Optional scaling factor for the attention softmax. Defaults to 1 / sqrt(head_dim) if not provided.
block_size_heads (int) – Size of the head dimension block. Affects tiling for attention computation.
block_size_keys (int) – Size of the key block per thread block.
num_key_splits (int) – Number of splits (tiles) in the key dimension.
num_warps (int, optional) – Number of GPU warps per thread block.
num_stages (int) – Pipeline stages for kernel execution.
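The num_key_splits parameter suggests that keys are processed in several tiles whose partial results are merged afterwards. The following is a minimal sketch of one common way to merge such partial softmax results (rescaling by a running maximum), written in plain JAX for a single head. The function name attend_with_key_splits is hypothetical, and whether the ejkernel kernel merges its splits in exactly this way is an assumption, not something this page states.

```python
import jax
import jax.numpy as jnp


def attend_with_key_splits(q, k, v, num_key_splits):
    """Single-head attention computed per key chunk, then merged.

    Hypothetical helper for illustration only.
    q: (head_dim,); k, v: (seq_len, head_dim).
    Assumes seq_len is divisible by num_key_splits.
    """
    seq_len, head_dim = k.shape
    scale = 1.0 / (head_dim ** 0.5)
    chunk = seq_len // num_key_splits
    k_chunks = k.reshape(num_key_splits, chunk, head_dim)
    v_chunks = v.reshape(num_key_splits, chunk, head_dim)

    # Per-chunk logits and a numerically stable, unnormalized softmax.
    logits = jnp.einsum("d,csd->cs", q, k_chunks) * scale  # (splits, chunk)
    chunk_max = logits.max(axis=-1)                        # (splits,)
    p = jnp.exp(logits - chunk_max[:, None])               # (splits, chunk)
    chunk_sum = p.sum(axis=-1)                             # (splits,)
    chunk_out = jnp.einsum("cs,csd->cd", p, v_chunks)      # (splits, head_dim)

    # Merge: bring every chunk onto the same (global) maximum, then normalize.
    rescale = jnp.exp(chunk_max - chunk_max.max())         # (splits,)
    numer = (chunk_out * rescale[:, None]).sum(axis=0)
    denom = (chunk_sum * rescale).sum()
    return numer / denom


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q_demo = jax.random.normal(kq, (8,))
k_demo = jax.random.normal(kk, (16, 8))
v_demo = jax.random.normal(kv, (16, 8))

split_out = attend_with_key_splits(q_demo, k_demo, v_demo, num_key_splits=4)
# Dense computation over all keys at once, for comparison.
dense_out = jax.nn.softmax((k_demo @ q_demo) / (8 ** 0.5)) @ v_demo
```

Under these assumptions, split_out and dense_out agree up to floating-point error, which is why the number of key splits can be treated as a performance knob rather than something that changes the result.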
- Returns
Output tensor of shape (batch_size, num_query_heads, head_dim) after attention is applied.
- Return type
chex.Array
- Raises
ValueError – If key and value have different head dimensions.
ValueError – If query heads are not divisible by the number of KV heads.
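To make the documented shapes and checks concrete, here is a minimal plain-JAX reference sketch of ragged decode attention. It is not the Pallas kernel and does not call ejkernel. The name ragged_decode_attention_reference is hypothetical, and two details are assumptions rather than documented behavior: that the valid keys form the half-open range [sequence_start, sequence_end), and that query head h reads KV head h // (num_q_heads // num_kv_heads). The sliding_window, logits_soft_cap, and softmax_aux arguments are left out.

```python
import jax
import jax.numpy as jnp


def ragged_decode_attention_reference(
    query, key, value, sequence_start, sequence_end, softmax_scale=None
):
    """Plain-JAX sketch of ragged decode attention (illustration only)."""
    batch, num_q_heads, head_dim = query.shape
    _, seq_len, num_kv_heads, _ = key.shape
    if key.shape[-1] != value.shape[-1]:
        raise ValueError("key and value must have the same head dimension")
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    if softmax_scale is None:
        softmax_scale = 1.0 / (head_dim ** 0.5)

    # Group query heads so each group shares one KV head (MHA: group == 1,
    # MQA: num_kv_heads == 1). The head-to-group mapping is an assumption.
    group = num_q_heads // num_kv_heads
    q = query.reshape(batch, num_kv_heads, group, head_dim)
    logits = jnp.einsum("bkgd,bskd->bkgs", q, key) * softmax_scale

    # Assumption: positions in [start, end) are valid; all others are masked.
    # A row with start == end has no valid key and would produce NaN here.
    pos = jnp.arange(seq_len)[None, :]
    mask = (pos >= sequence_start[:, None]) & (pos < sequence_end[:, None])
    logits = jnp.where(mask[:, None, None, :], logits, -jnp.inf)

    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bkgs,bskd->bkgd", weights, value)
    return out.reshape(batch, num_q_heads, head_dim)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (2, 8, 64))     # (batch, num_q_heads, head_dim)
key = jax.random.normal(kk, (2, 16, 2, 64))   # (batch, seq_len, num_kv_heads, head_dim)
value = jax.random.normal(kv, (2, 16, 2, 64))
start = jnp.array([0, 4], dtype=jnp.int32)
end = jnp.array([16, 10], dtype=jnp.int32)

out = ragged_decode_attention_reference(query, key, value, start, end)
```

A reference like this can be useful for checking a kernel's output on small inputs, provided the masking convention above matches the kernel's actual behavior.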