ejkernel.kernels._pallas.gpu.ragged_decode_attention._interface


ejkernel.kernels._pallas.gpu.ragged_decode_attention._interface.ragged_decode_attention(query: Float[jaxlib._jax.Array, 'batch num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None) → Float[jaxlib._jax.Array, 'batch num_q_heads head_dim']

Performs attention decoding over ragged sequences using a GPU-optimized kernel.

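This function is the public entry point for decode-time attention over variable-length (ragged) sequences, implemented with head-blocked GPU kernels. Each batch row supplies one query token per head, which attends only to the key/value positions in that row's valid range given by sequence_start and sequence_end. It supports multi-head attention (MHA), multi-query attention (MQA), and grouped-query attention (GQA) layouts; the number of query heads must be a multiple of the number of KV heads.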
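As a point of reference, the same computation can be written densely in plain JAX. The sketch below is not the Pallas kernel: `reference_ragged_decode_attention` is a hypothetical helper, and it assumes that `sequence_end` is exclusive (valid keys satisfy `start <= pos < end`) and that query head `h` reads KV head `h // (num_q_heads // num_kv_heads)`. A row with an empty range would produce NaN.

```python
import jax
import jax.numpy as jnp


def reference_ragged_decode_attention(query, key, value, sequence_start, sequence_end,
                                      softmax_scale=None):
    """Dense reference (hypothetical helper): one query token per row attends to its valid range.

    query: (batch, num_q_heads, head_dim)
    key/value: (batch, seq_len, num_kv_heads, head_dim)
    sequence_start/sequence_end: (batch,) integer bounds, assumed half-open [start, end).
    """
    batch, num_q_heads, head_dim = query.shape
    seq_len, num_kv_heads = key.shape[1], key.shape[2]
    if softmax_scale is None:
        softmax_scale = 1.0 / jnp.sqrt(head_dim)  # documented default
    # Query heads per KV head: 1 for MHA, num_q_heads for MQA, in between for GQA.
    group = num_q_heads // num_kv_heads
    # Repeat each KV head `group` times so query head h reads KV head h // group.
    k = jnp.repeat(key, group, axis=2)    # (batch, seq_len, num_q_heads, head_dim)
    v = jnp.repeat(value, group, axis=2)
    logits = jnp.einsum("bhd,bshd->bhs", query, k) * softmax_scale
    pos = jnp.arange(seq_len)
    # Assumption: a key position is valid when start <= pos < end.
    valid = (pos[None, :] >= sequence_start[:, None]) & (pos[None, :] < sequence_end[:, None])
    logits = jnp.where(valid[:, None, :], logits, -jnp.inf)  # mask out-of-range keys
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhs,bshd->bhd", weights, v)
```

Since MQA is the case `num_kv_heads == 1` and MHA the case `num_kv_heads == num_q_heads`, one `jnp.repeat` expansion covers all three layouts; an optimized kernel would avoid materializing the repeated keys.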

Parameters
  • query (chex.Array) – Query tensor of shape (batch_size, num_query_heads, head_dim).

  • key (chex.Array) – Key tensor of shape (batch_size, sequence_length, num_kv_heads, head_dim).

  • value (chex.Array) – Value tensor of shape (batch_size, sequence_length, num_kv_heads, head_dim).

  • sequence_start (chex.Array) – Integer start indices of the valid key/value range for each batch row, shape (batch_size,).

  • sequence_end (chex.Array) – Integer end indices of the valid key/value range for each batch row, shape (batch_size,).

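  • softmax_scale (float, optional) – Scaling factor applied to the attention logits before the softmax. Defaults to 1 / sqrt(head_dim) if not provided.

  • fwd_params (FwdParams, optional) – Optional forward-pass kernel configuration (ejkernel.ops.utils.datacarrier.FwdParams).

  • sliding_window (tuple[int, int], optional) – Optional pair of window bounds restricting which key positions each query attends to.

  • logits_soft_cap (float, optional) – Optional soft cap applied to the attention logits before the softmax.

  • softmax_aux (Float[Array, 'num_sinks'], optional) – Optional auxiliary softmax inputs of shape (num_sinks,), e.g. attention-sink logits.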

  • block_size_heads (int) – Size of the head dimension block. Affects tiling for attention computation.

  • block_size_keys (int) – Size of the key block per thread block.

  • num_key_splits (int) – Number of splits (tiles) in the key dimension.

  • num_warps (int, optional) – Number of GPU warps per thread block.

  • num_stages (int) – Pipeline stages for kernel execution.
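num_key_splits divides the key axis into tiles that can be processed independently. A common way to merge per-tile results, used by split-K decoding schemes such as Flash-Decoding, is a log-sum-exp rescaling of each partial output. The sketch below shows that merge for a single query vector in plain jax.numpy; `split_key_attention` is a hypothetical helper, and this page does not confirm that the kernel merges splits exactly this way.

```python
import jax
import jax.numpy as jnp


def split_key_attention(q, k, v, num_key_splits):
    """Single-query attention over `num_key_splits` key tiles, merged with log-sum-exp.

    q: (head_dim,), k/v: (seq_len, head_dim); seq_len must be divisible by num_key_splits.
    """
    seq_len, head_dim = k.shape
    scale = 1.0 / jnp.sqrt(head_dim)
    k_tiles = k.reshape(num_key_splits, -1, head_dim)
    v_tiles = v.reshape(num_key_splits, -1, head_dim)
    logits = jnp.einsum("d,nsd->ns", q, k_tiles) * scale   # (splits, tile_len)
    m = logits.max(axis=-1)                                  # per-split max logit
    p = jnp.exp(logits - m[:, None])                         # stabilized exponentials
    l = p.sum(axis=-1)                                       # per-split normalizer
    o = jnp.einsum("ns,nsd->nd", p, v_tiles) / l[:, None]    # per-split partial output
    # Merge: weight each partial output by its share of the global softmax mass.
    m_all = m.max()
    w = l * jnp.exp(m - m_all)
    return (w[:, None] * o).sum(axis=0) / w.sum()
```

Each split only needs to return its partial output, max logit, and normalizer, so splits can run on different thread blocks and be combined in a cheap final step.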

Returns

Attention output of shape (batch_size, num_query_heads, head_dim), one vector per query head for each batch row.

Return type

chex.Array

Raises
  • ValueError – If key and value have different head dimensions.

  • ValueError – If the number of query heads is not divisible by the number of KV heads.
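The two documented error conditions amount to simple shape checks. The sketch below expresses them in plain Python; `validate_decode_inputs` is a hypothetical helper for illustration, not part of ejkernel, and the exact error messages are assumptions.

```python
def validate_decode_inputs(query_shape, key_shape, value_shape):
    """Hypothetical helper mirroring the documented ValueError conditions.

    query_shape: (batch, num_q_heads, head_dim)
    key_shape/value_shape: (batch, seq_len, num_kv_heads, head_dim)
    Returns the number of query heads that share each KV head.
    """
    _, num_q_heads, _ = query_shape
    _, _, num_kv_heads, key_head_dim = key_shape
    value_head_dim = value_shape[-1]
    if key_head_dim != value_head_dim:
        raise ValueError(
            f"key head_dim ({key_head_dim}) != value head_dim ({value_head_dim})")
    if num_q_heads % num_kv_heads != 0:
        raise ValueError(
            f"num_q_heads ({num_q_heads}) is not divisible by num_kv_heads ({num_kv_heads})")
    return num_q_heads // num_kv_heads
```

A returned group size of 1 corresponds to MHA, a group size equal to num_q_heads (one KV head) to MQA, and anything in between to GQA.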