ejkernel.kernels._pallas.tpu.ragged_decode_attention._interface
- ejkernel.kernels._pallas.tpu.ragged_decode_attention._interface.ragged_decode_attention(query: Float[jaxlib._jax.Array, 'batch num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None) → Float[jaxlib._jax.Array, 'batch num_q_heads head_dim']
Ragged MQA decoding entry point with TPU-accelerated Flash Attention.
- Parameters
query – Query tensor of shape [batch, num_q_heads, head_dim].
key – Key tensor of shape [batch, seq_len, num_kv_heads, head_dim].
value – Value tensor of shape [batch, seq_len, num_kv_heads, head_dim].
sequence_start – int32 array of shape [batch], start indices of sequences.
sequence_end – int32 array of shape [batch], end indices of sequences.
softmax_scale – Optional scale for attention logits. If None (the default), a scale of 1 is used.
block_size – Block size used for kernel tiling. Default is 256.
sliding_window – Optional (left, right) sliding window sizes. If specified, limits attention to tokens within the window. None means full attention.
logits_soft_cap – Optional soft capping value for attention logits. Applies tanh-based soft capping: logits_soft_cap * tanh(logits / logits_soft_cap).
softmax_aux – Optional auxiliary logits for attention sinks. Shape [num_heads, num_sinks] or [num_sinks]. Concatenated to attention logits before softmax to create attention sink behavior.
- Returns
Output tensor of shape [batch, num_q_heads, head_dim] after attention decoding.
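The semantics described above can be sketched in plain JAX. This is a minimal reference for checking shapes and masking, not the Pallas TPU kernel itself. The function name `ragged_decode_attention_reference` is hypothetical. It also assumes that the valid range is half-open (`sequence_start <= position < sequence_end`), that query heads are grouped evenly over KV heads, and that soft capping is applied after scaling. `fwd_params`, `sliding_window`, and `softmax_aux` are left out.

```python
import jax
import jax.numpy as jnp


def ragged_decode_attention_reference(
    query,            # [batch, num_q_heads, head_dim]
    key,              # [batch, seq_len, num_kv_heads, head_dim]
    value,            # [batch, seq_len, num_kv_heads, head_dim]
    sequence_start,   # [batch], int32
    sequence_end,     # [batch], int32
    softmax_scale=None,
    logits_soft_cap=None,
):
    """Hypothetical pure-JAX reference; not the library's implementation."""
    batch, num_q_heads, head_dim = query.shape
    _, seq_len, num_kv_heads, _ = key.shape
    # Assumption: query head h reads KV head h // group (MQA/GQA grouping).
    group = num_q_heads // num_kv_heads
    if softmax_scale is None:
        softmax_scale = 1.0  # documented default

    q = query.reshape(batch, num_kv_heads, group, head_dim)
    # logits: [batch, num_kv_heads, group, seq_len]
    logits = jnp.einsum("bkgd,bskd->bkgs", q, key) * softmax_scale

    if logits_soft_cap is not None:
        # tanh-based soft capping, as in the parameter description.
        logits = logits_soft_cap * jnp.tanh(logits / logits_soft_cap)

    # Assumption: each batch row attends to positions in [start, end).
    positions = jnp.arange(seq_len)[None, :]
    valid = (positions >= sequence_start[:, None]) & (
        positions < sequence_end[:, None]
    )
    mask_value = jnp.finfo(logits.dtype).min
    logits = jnp.where(valid[:, None, None, :], logits, mask_value)

    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bkgs,bskd->bkgd", weights, value)
    return out.reshape(batch, num_q_heads, head_dim)


# Example call with illustrative shapes: batch=2, 4 query heads, 2 KV heads.
k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(k0, (2, 4, 8))
key = jax.random.normal(k1, (2, 6, 2, 8))
value = jax.random.normal(k2, (2, 6, 2, 8))
sequence_start = jnp.array([0, 2], dtype=jnp.int32)
sequence_end = jnp.array([6, 5], dtype=jnp.int32)

out = ragged_decode_attention_reference(
    query, key, value, sequence_start, sequence_end
)
print(out.shape)  # (2, 4, 8)
```

A reference like this can be compared against the kernel's output on small inputs. Any mismatch would first call the assumptions above into question, in particular the range convention and the head grouping.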