ejkernel.kernels._triton.ragged_decode_attention._triton_impl_fwd
- ejkernel.kernels._triton.ragged_decode_attention._triton_impl_fwd.inner_decode_triton(query_tensor: Array, key_tensor: Array, value_tensor: Array, sequence_start: Array, sequence_end: Array, softmax_scale: float | None = None, fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jax.jaxlib._jax.Array | None = None) -> Array
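- GPU (Triton) inner decode. Mirrors `inner_decode_tpu`:
  - Group the query heads: Q of shape [B, HQ, D] is reshaped to [B, HKV, G, D], where G = HQ / HKV is the number of query heads that share each KV head.
  - Rearrange K/V to [B, HKV, S, D].
  - For each KV head, run the MQA ragged decode to get an output of shape [B, G, D].
  - Reshape the result back to [B, HQ, D].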
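The steps above can be illustrated with a pure-NumPy reference. This is a sketch of the intended math, not the Triton kernel, and it rests on several assumptions not stated in this page: key/value arrive as [B, S, HKV, D]; `sequence_start`/`sequence_end` give a half-open range [start, end) of valid key positions per batch row, with at least one valid key; the query heads served by KV head h are contiguous; `softmax_scale` defaults to 1/sqrt(D); and `sliding_window`, `logits_soft_cap`, `softmax_aux` and `fwd_params` are ignored. The helper names `ragged_mqa_reference` and `inner_decode_reference` are hypothetical.

```python
import numpy as np


def ragged_mqa_reference(q, k, v, starts, ends, softmax_scale):
    """Hypothetical MQA ragged decode for ONE KV head.

    q: [B, G, D] (G query heads sharing this KV head)
    k, v: [B, S, D]
    starts, ends: [B]; assumed valid keys are starts[b] <= s < ends[b]
    returns: [B, G, D]
    """
    _, S, _ = k.shape
    logits = np.einsum("bgd,bsd->bgs", q, k) * softmax_scale  # [B, G, S]
    pos = np.arange(S)
    valid = (pos[None, :] >= starts[:, None]) & (pos[None, :] < ends[:, None])  # [B, S]
    # Mask out keys outside each row's ragged range.
    logits = np.where(valid[:, None, :], logits, -np.inf)
    # Numerically stable softmax (assumes every row has >= 1 valid key).
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("bgs,bsd->bgd", weights, v)


def inner_decode_reference(query, key, value, starts, ends, softmax_scale=None):
    """Hypothetical reference for the grouped decode steps.

    query: [B, HQ, D]; key/value: [B, S, HKV, D] (assumed layout).
    """
    B, HQ, D = query.shape
    HKV = key.shape[2]
    assert HQ % HKV == 0, "HQ must be a multiple of HKV"
    G = HQ // HKV
    if softmax_scale is None:
        softmax_scale = 1.0 / np.sqrt(D)  # assumed default scale
    q = query.reshape(B, HKV, G, D)  # group Q into [B, HKV, G, D]
    k = key.transpose(0, 2, 1, 3)  # K -> [B, HKV, S, D]
    v = value.transpose(0, 2, 1, 3)  # V -> [B, HKV, S, D]
    # One MQA ragged decode per KV head, each producing [B, G, D].
    per_head = [
        ragged_mqa_reference(q[:, h], k[:, h], v[:, h], starts, ends, softmax_scale)
        for h in range(HKV)
    ]
    out = np.stack(per_head, axis=1)  # [B, HKV, G, D]
    return out.reshape(B, HQ, D)  # back to [B, HQ, D]
```

Looping over KV heads keeps the sketch close to the described steps; a real kernel would typically fuse this loop into its launch grid rather than run it in Python.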
- ejkernel.kernels._triton.ragged_decode_attention._triton_impl_fwd.ragged_decode_mqa_triton(q: Array, k: Array, v: Array, starts: Array, ends: Array, softmax_scale: float, block_size: int, sliding_window: tuple[int, int] | None, logits_soft_cap: float, aux: jax.jaxlib._jax.Array | None) -> Array