ejkernel.kernels._triton.ragged_decode_attention._interface


ejkernel.kernels._triton.ragged_decode_attention._interface.ragged_decode_attention(query: Float[jaxlib._jax.Array, 'batch num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None) → Float[jaxlib._jax.Array, 'batch num_q_heads head_dim']

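Ragged decode attention for GPUs, implemented in Triton and functionally equivalent to the TPU/Pallas version. Each batch row attends only to the keys in its own range [sequence_start, sequence_end).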

Parameters
  • query – [B, HQ, D], one query vector per head for the current decode step (B = batch, HQ = num_q_heads, D = head_dim)

  • key – [B, S, HKV, D] (S = seq_len, HKV = num_kv_heads)

  • value – [B, S, HKV, D]

  • sequence_start – [B] int32, first valid key position of each row (inclusive)

  • sequence_end – [B] int32, one past the last valid key position of each row (exclusive)

  • softmax_scale – multiplier applied to the query-key logits before the softmax; None selects the default scale

  • fwd_params – optional FwdParams with forward-kernel tuning settings, such as the tile size along the sequence axis

  • sliding_window – optional (left, right) local-attention window; None means full attention over [sequence_start, sequence_end)

  • logits_soft_cap – optional soft cap c; logits are squashed as c * tanh(logits / c)

  • softmax_aux – optional attention sinks, given as either:
      - [HKV, NS] (one set of sinks per kv head), or
      - [NS] (broadcast to every kv head)

Returns

[B, HQ, D] attention output, one vector per query head

Return type

Float[jaxlib._jax.Array, 'batch num_q_heads head_dim']
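To make the parameter descriptions above concrete, here is a minimal NumPy sketch of the computation as we read it. It is not part of ejkernel: the name ragged_decode_attention_reference is hypothetical, and several details are assumptions rather than documented behavior. Specifically, it assumes HQ is a multiple of HKV with query head h reading kv head h // (HQ // HKV), that softmax_scale defaults to 1/sqrt(D), that the soft cap is applied after scaling, that the sliding window is measured from the last valid position (the decode token), and that each sink adds exp(sink) to the softmax denominator without contributing a value.

```python
import numpy as np


def ragged_decode_attention_reference(
    query,            # [B, HQ, D]
    key,              # [B, S, HKV, D]
    value,            # [B, S, HKV, D]
    sequence_start,   # [B] int, inclusive
    sequence_end,     # [B] int, exclusive
    softmax_scale=None,
    sliding_window=None,
    logits_soft_cap=None,
    softmax_aux=None,
):
    """Hypothetical slow reference; every default here is an assumption."""
    B, HQ, D = query.shape
    _, S, HKV, _ = key.shape
    assert HQ % HKV == 0, "assumes grouped-query attention: HQ divisible by HKV"
    group = HQ // HKV
    # Assumption: default scale is 1/sqrt(head_dim).
    scale = 1.0 / np.sqrt(D) if softmax_scale is None else softmax_scale

    # Normalize sinks to [HKV, NS]; a 1-D [NS] array is shared by every kv head.
    sinks = None
    if softmax_aux is not None:
        sinks = np.asarray(softmax_aux, dtype=np.float32)
        if sinks.ndim == 1:
            sinks = np.broadcast_to(sinks, (HKV, sinks.shape[0]))

    out = np.zeros((B, HQ, D), dtype=np.float32)
    pos = np.arange(S)
    for b in range(B):
        start, end = int(sequence_start[b]), int(sequence_end[b])
        valid = (pos >= start) & (pos < end)  # ragged range [start, end)
        if sliding_window is not None:
            left, right = sliding_window
            q_pos = end - 1  # assumption: decode query sits at the last valid position
            valid &= (pos >= q_pos - left) & (pos <= q_pos + right)
        for h in range(HQ):
            g = h // group  # kv head shared by this group of query heads
            k_bg = key[b, :, g, :].astype(np.float32)
            logits = (k_bg @ query[b, h].astype(np.float32)) * scale
            if logits_soft_cap is not None:
                # Assumption: tanh soft cap applied after scaling.
                logits = logits_soft_cap * np.tanh(logits / logits_soft_cap)
            logits = np.where(valid, logits, -np.inf)
            extra = sinks[g] if sinks is not None else np.empty(0, np.float32)
            m = max(logits.max(), extra.max(initial=-np.inf))
            if not np.isfinite(m):
                continue  # no valid keys and no sinks: leave zeros (assumption)
            p = np.exp(logits - m)  # masked positions become exp(-inf) = 0
            # Sinks only enlarge the denominator; they carry no value vector.
            denom = p.sum() + np.exp(extra - m).sum()
            out[b, h] = (p @ value[b, :, g, :].astype(np.float32)) / denom
    return out
```

A sketch like this can act as a slow oracle when checking the Triton kernel on small shapes, for example by comparing outputs with np.allclose under a loose tolerance for low-precision inputs. The per-row, per-head loops trade speed for readability so that each masking rule maps to one line.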