ejkernel.kernels._pallas.tpu.ragged_decode_attention._interface


ejkernel.kernels._pallas.tpu.ragged_decode_attention._interface.ragged_decode_attention(query: Float[jaxlib._jax.Array, 'batch num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None) → Float[jaxlib._jax.Array, 'batch num_q_heads head_dim']

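Entry point for ragged multi-query attention (MQA) decoding, backed by a TPU Flash Attention kernel written in Pallas. Each batch element contributes a single query token (one decode step) that attends only to the key/value positions delimited by its sequence_start and sequence_end.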
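For reference, the NumPy sketch below spells out what one decode step computes for a batch of ragged sequences. It is not the Pallas kernel, and the helper name `ragged_decode_reference` is hypothetical. It assumes that `sequence_end` is exclusive (valid positions satisfy `sequence_start[b] <= t < sequence_end[b]`), that query head `h` reads KV head `h // (num_q_heads // num_kv_heads)`, and that every sequence has at least one valid position. The optional sliding window, soft cap, and sinks are left out.

```python
import numpy as np


def ragged_decode_reference(query, key, value, sequence_start, sequence_end, softmax_scale=None):
    """Hypothetical reference for ragged decode attention (not the TPU kernel).

    Assumptions: end indices are exclusive, query heads are grouped over KV
    heads as h // (num_q_heads // num_kv_heads), and each sequence has at
    least one valid position (otherwise the softmax yields NaN).
    """
    batch, num_q_heads, head_dim = query.shape
    _, seq_len, num_kv_heads, _ = key.shape
    assert num_q_heads % num_kv_heads == 0, "num_q_heads must be a multiple of num_kv_heads"
    group = num_q_heads // num_kv_heads
    scale = 1.0 if softmax_scale is None else softmax_scale

    # Repeat each KV head `group` times so KV heads line up with query heads.
    k = np.repeat(key, group, axis=2)    # [batch, seq_len, num_q_heads, head_dim]
    v = np.repeat(value, group, axis=2)  # [batch, seq_len, num_q_heads, head_dim]

    # logits[b, h, t] = scale * <query[b, h], k[b, t, h]>
    logits = scale * np.einsum("bhd,bthd->bht", query, k)

    # Mask out positions outside each sequence's [start, end) range.
    pos = np.arange(seq_len)
    valid = (pos[None, :] >= sequence_start[:, None]) & (pos[None, :] < sequence_end[:, None])
    logits = np.where(valid[:, None, :], logits, -np.inf)

    # Numerically stable softmax over the sequence axis.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)

    # Weighted sum of values: [batch, num_q_heads, head_dim].
    return np.einsum("bht,bthd->bhd", weights, v)


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 8))
k = rng.standard_normal((2, 16, 2, 8))
v = rng.standard_normal((2, 16, 2, 8))
out = ragged_decode_reference(q, k, v, np.array([0, 3]), np.array([5, 16]))
print(out.shape)  # (2, 4, 8)
```

Because padded positions are masked to -inf before the softmax, whatever is stored outside a sequence's range has no effect on its output.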

Parameters
  • query – Query tensor of shape [batch, num_q_heads, head_dim], one token per sequence.

  • key – Key tensor of shape [batch, seq_len, num_kv_heads, head_dim].

  • value – Value tensor of shape [batch, seq_len, num_kv_heads, head_dim].

  • sequence_start – int32 array of shape [batch], start indices of sequences.

  • sequence_end – int32 array of shape [batch], end indices of sequences.

  • softmax_scale – Optional multiplier applied to the attention logits. If None, a scale of 1 is used.

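  • fwd_params – Optional FwdParams (ejkernel.ops.utils.datacarrier.FwdParams) controlling kernel tiling, such as the block size. If None, defaults are used (block size 256).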

  • sliding_window – Optional (left, right) sliding window sizes. If specified, limits attention to tokens within the window. None means full attention.

  • logits_soft_cap – Optional soft capping value for attention logits. Applies tanh-based soft capping: logits_soft_cap * tanh(logits / logits_soft_cap).

  • softmax_aux – Optional auxiliary logits for attention sinks. Shape [num_heads, num_sinks] or [num_sinks]. Concatenated to attention logits before softmax to create attention sink behavior.
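The two optional logit transforms documented in the parameter list can be illustrated in isolation. This is a minimal NumPy sketch with hypothetical helper names (`soft_cap`, `softmax_with_sinks`). It assumes that sink logits only enlarge the softmax denominator and that their probabilities are discarded afterwards, since a sink has no value vector; that reading is an assumption, not taken from the kernel source.

```python
import numpy as np


def soft_cap(logits, cap):
    # logits_soft_cap * tanh(logits / logits_soft_cap): smooth and bounded to [-cap, cap].
    return cap * np.tanh(logits / cap)


def softmax_with_sinks(logits, sink_logits):
    """Softmax over the last axis of `logits` with sink logits appended.

    `sink_logits` may be [num_sinks] or [num_heads, num_sinks] (for logits of
    shape [..., num_heads, seq_len]); both broadcast to [..., num_sinks].
    Assumption: sink probabilities are dropped after normalisation.
    """
    sinks = np.broadcast_to(sink_logits, logits.shape[:-1] + (sink_logits.shape[-1],))
    full = np.concatenate([logits, sinks], axis=-1)
    full = full - full.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(full)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Keep only the probabilities of real tokens; their sum is now below 1.
    return probs[..., : logits.shape[-1]]


logits = np.array([[2.0, 0.0, -1.0]])
print(soft_cap(logits * 10, cap=5.0))                      # every entry lies within [-5, 5]
print(softmax_with_sinks(logits, np.array([1.0])).sum())  # less than 1: the rest went to the sink
```

With a soft cap, very large logits saturate instead of dominating the softmax; with a sink, a query can spread less than all of its attention mass over real tokens.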

Returns

Output tensor of shape [batch, num_q_heads, head_dim] holding the attention output of each query head for the current decode step.