ejkernel.kernels._triton.ragged_decode_attention._interface
- ejkernel.kernels._triton.ragged_decode_attention._interface.ragged_decode_attention(query: Float[jaxlib._jax.Array, 'batch num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None) -> Float[jaxlib._jax.Array, 'batch num_q_heads head_dim']
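Ragged decode attention for GPUs (Triton), functionally matching the TPU/Pallas version. Each batch row holds a single query token that attends only to the keys and values in its own range [sequence_start, sequence_end). Shapes below use B = batch, S = seq_len, HQ = num_q_heads, HKV = num_kv_heads, D = head_dim, NS = num_sinks.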
- Parameters
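query – [B, HQ, D], one query token per batch row
key – [B, S, HKV, D]
value – [B, S, HKV, D]
sequence_start – [B] int32, start of each row's valid KV range (inclusive)
sequence_end – [B] int32, end of each row's valid KV range (exclusive)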
softmax_scale – optional scale applied to the query-key logits before the softmax
fwd_params – optional FwdParams with forward-kernel tuning settings, such as the tile size along the sequence axis
sliding_window – optional (left, right) window sizes; None means full attention over the valid range
logits_soft_cap – optional cap applied to the logits through tanh (soft capping)
softmax_aux – optional attention sinks, either:
  - [HKV, NS] (one set of NS sinks per KV head), or
  - [NS] (broadcast to every KV head)
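How the optional arguments combine for one query vector can be sketched in NumPy as below. This is not the Triton kernel and not part of ejkernel: `decode_row_reference` is a hypothetical helper, and several semantics are assumptions rather than documented behavior: the 1/sqrt(D) default scale, the query sitting at position sequence_end - 1 for the sliding window, soft capping as cap * tanh(logits / cap) applied after scaling, sinks entering only the softmax denominator, and a zero output when nothing can be attended to.

```python
import numpy as np


def decode_row_reference(q, k, v, start, end, softmax_scale=None,
                         sliding_window=None, logits_soft_cap=None, sinks=None):
    """Hypothetical single-row reference: one query vector vs. one KV head.

    q: [D]; k, v: [S, D]; start inclusive, end exclusive; sinks: optional [NS].
    """
    d = q.shape[-1]
    # Assumed default scale when softmax_scale is None.
    scale = 1.0 / np.sqrt(d) if softmax_scale is None else softmax_scale
    logits = (k @ q) * scale  # [S]
    if logits_soft_cap is not None:
        # Assumed tanh soft cap: maps logits smoothly into (-cap, cap).
        logits = logits_soft_cap * np.tanh(logits / logits_soft_cap)

    pos = np.arange(k.shape[0])
    valid = (pos >= start) & (pos < end)  # ragged range [start, end)
    if sliding_window is not None:
        left, right = sliding_window
        q_pos = end - 1  # assumption: the decode token is the last valid position
        valid &= (pos >= q_pos - left) & (pos <= q_pos + right)
    logits = np.where(valid, logits, -np.inf)

    # Assumption: sink logits join the softmax normalizer but carry no value.
    sink_logits = np.asarray([] if sinks is None else sinks, dtype=logits.dtype)
    all_logits = np.concatenate([logits, sink_logits])
    m = all_logits.max()
    if not np.isfinite(m):
        return np.zeros_like(q)  # assumed output when nothing can be attended to
    w = np.exp(all_logits - m)  # masked positions contribute exp(-inf) = 0
    probs = w[: k.shape[0]] / w.sum()  # drop sink weights after normalizing
    return probs @ v
```

Under these assumptions a large sink logit drains probability mass away from every key, so the output shrinks toward zero instead of being forced to average the visible values.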
- Returns
Attention output, [B, HQ, D]
- Return type
Float[jaxlib._jax.Array, 'batch num_q_heads head_dim']
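For checking shapes end to end, the core path (ragged masking plus grouped-query head sharing, without sliding window, soft cap or sinks) can be sketched in NumPy as follows. `ragged_decode_reference` is a hypothetical helper, not an ejkernel API; it assumes query head h reads KV head h // (HQ // HKV), a 1/sqrt(D) default scale, and zero output for rows with an empty range, none of which the documentation above confirms.

```python
import numpy as np


def ragged_decode_reference(query, key, value, sequence_start, sequence_end,
                            softmax_scale=None):
    """Hypothetical batched reference (no sliding window, soft cap or sinks).

    query: [B, HQ, D]; key, value: [B, S, HKV, D]; sequence_start/end: [B].
    Returns [B, HQ, D].
    """
    _, hq, d = query.shape
    s, hkv = key.shape[1], key.shape[2]
    if hq % hkv != 0:
        raise ValueError("sketch assumes HQ is a multiple of HKV")
    group = hq // hkv
    # Assumed default scale when softmax_scale is None.
    scale = 1.0 / np.sqrt(d) if softmax_scale is None else softmax_scale

    # Assumed GQA mapping: query head h reads KV head h // group.
    k = np.repeat(key, group, axis=2)    # [B, S, HQ, D]
    v = np.repeat(value, group, axis=2)  # [B, S, HQ, D]

    logits = np.einsum("bhd,bshd->bhs", query, k) * scale  # [B, HQ, S]
    pos = np.arange(s)
    start = np.asarray(sequence_start)[:, None]
    end = np.asarray(sequence_end)[:, None]
    valid = (pos[None, :] >= start) & (pos[None, :] < end)  # [B, S]
    logits = np.where(valid[:, None, :], logits, -np.inf)

    m = logits.max(axis=-1, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)  # empty rows: avoid (-inf) - (-inf)
    w = np.exp(logits - m)
    denom = w.sum(axis=-1, keepdims=True)
    # Assumption: a row with an empty [start, end) range returns zeros.
    probs = np.divide(w, denom, out=np.zeros_like(w), where=denom > 0)
    return np.einsum("bhs,bshd->bhd", probs, v)
```

Comparing such a reference against the kernel on small random inputs, with tolerances suited to the kernel's dtype, is one way to sanity-check the [B, HQ, D] output and the ragged masking.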