ejkernel.kernels._triton.ring_attention._ring_kernel
Ring Flash Attention kernel: wraps the Triton flash attention kernel for use over a distributed ring topology.
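In a ring topology, each device attends its local queries against one key/value block at a time, and the partial results must be combined into the full softmax. The sketch below illustrates the generic log-sum-exp merge that makes this possible. It is not this module's code: `block_attention`, `merge_partial_attention` and `ring_attention_single_host` are hypothetical names, the ring is simulated on a single host with a Python loop, and inputs are assumed to be 2-D `[tokens, head_dim]` arrays without masking or bias.

```python
import jax
import jax.numpy as jnp


def block_attention(q, k, v):
    """Softmax attention of q against one K/V block; returns (o, lse)."""
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    s = (q @ k.T) * scale                    # [Tq, Tk] attention scores
    lse = jax.nn.logsumexp(s, axis=-1)       # [Tq] log of the softmax normaliser
    o = jnp.exp(s - lse[:, None]) @ v        # [Tq, D] normalised block output
    return o, lse


def merge_partial_attention(o1, lse1, o2, lse2):
    """Combine normalised outputs computed over two disjoint key blocks."""
    lse = jnp.logaddexp(lse1, lse2)          # normaliser over the union of blocks
    w1 = jnp.exp(lse1 - lse)[:, None]        # rescale block 1 to the new normaliser
    w2 = jnp.exp(lse2 - lse)[:, None]        # rescale block 2 to the new normaliser
    return w1 * o1 + w2 * o2, lse


def ring_attention_single_host(q, k_blocks, v_blocks):
    """Visit K/V blocks in ring order, folding each into a running (o, lse)."""
    o, lse = block_attention(q, k_blocks[0], v_blocks[0])
    for k_blk, v_blk in zip(k_blocks[1:], v_blocks[1:]):
        o_blk, lse_blk = block_attention(q, k_blk, v_blk)
        o, lse = merge_partial_attention(o, lse, o_blk, lse_blk)
    return o, lse


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (8, 16))
k = jax.random.normal(kk, (32, 16))
v = jax.random.normal(kv, (32, 16))

# Split K/V into 4 blocks, as if each of 4 devices held one shard.
o_ring, lse_ring = ring_attention_single_host(q, jnp.split(k, 4), jnp.split(v, 4))
o_ref, lse_ref = block_attention(q, k, v)   # unsplit reference
```

On real hardware the loop body would run on every device while K/V blocks are rotated between neighbours (for example with `jax.lax.ppermute`); the merge arithmetic stays the same, which is why a per-row `lse` is worth carrying alongside the output.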
- class ejkernel.kernels._triton.ring_attention._ring_kernel.RingFlashResiduals(q: jax.Array, k: jax.Array, v: jax.Array, bias: jax.Array | None, attention_mask: jax.Array | None, o: jax.Array, lse: jax.Array, dropout_seed: int | None)
Bases: NamedTuple

Residuals saved from the forward pass for use in the backward computation.
- attention_mask: jax.Array | None
Alias for field number 4
- bias: jax.Array | None
Alias for field number 3
- dropout_seed: int | None
Alias for field number 7
- k: Array
Alias for field number 1
- lse: Array
Alias for field number 6
- o: Array
Alias for field number 5
- q: Array
Alias for field number 0
- v: Array
Alias for field number 2
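Saving `o` and `lse` alongside the inputs lets a flash-attention-style backward pass recompute the softmax probabilities as `exp(scores - lse)` instead of storing them. The sketch below shows that pattern with `jax.custom_vjp`. It is an illustration under assumptions, not this module's implementation: `AttnResiduals` is a hypothetical, trimmed mirror of `RingFlashResiduals` (the `bias`, `attention_mask` and `dropout_seed` fields are omitted), the attention is single-device and unmasked, and the inputs are assumed to be 2-D `[tokens, head_dim]` arrays.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class AttnResiduals(NamedTuple):
    """Hypothetical trimmed mirror of RingFlashResiduals (no bias/mask/dropout)."""
    q: jax.Array
    k: jax.Array
    v: jax.Array
    o: jax.Array
    lse: jax.Array


def _attention_fwd_impl(q, k, v):
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    s = (q @ k.T) * scale                    # [Tq, Tk] scores
    lse = jax.nn.logsumexp(s, axis=-1)       # [Tq] per-row log normaliser
    o = jnp.exp(s - lse[:, None]) @ v        # [Tq, D] output
    return o, lse


@jax.custom_vjp
def attention(q, k, v):
    o, _ = _attention_fwd_impl(q, k, v)
    return o


def attention_fwd(q, k, v):
    o, lse = _attention_fwd_impl(q, k, v)
    return o, AttnResiduals(q, k, v, o, lse)  # residuals handed to the backward pass


def attention_bwd(res, do):
    q, k, v, o, lse = res
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    p = jnp.exp((q @ k.T) * scale - lse[:, None])     # recompute probabilities from lse
    dv = p.T @ do
    dp = do @ v.T
    delta = jnp.sum(do * o, axis=-1, keepdims=True)   # rowsum(dO * O) = rowsum(P * dP)
    ds = p * (dp - delta)                             # softmax backward
    dq = (ds @ k) * scale
    dk = (ds.T @ q) * scale
    return dq, dk, dv


attention.defvjp(attention_fwd, attention_bwd)


def ref_attention(q, k, v):
    """Plain attention differentiated by JAX autodiff, used as a reference."""
    s = (q @ k.T) / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(s, axis=-1) @ v


kq, kk, kv, kw = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(kq, (8, 16))
k = jax.random.normal(kk, (32, 16))
v = jax.random.normal(kv, (32, 16))
w = jax.random.normal(kw, (8, 16))   # arbitrary upstream cotangent

grads = jax.grad(lambda q, k, v: jnp.sum(attention(q, k, v) * w), argnums=(0, 1, 2))(q, k, v)
ref_grads = jax.grad(lambda q, k, v: jnp.sum(ref_attention(q, k, v) * w), argnums=(0, 1, 2))(q, k, v)
```

Because a `NamedTuple` is a JAX pytree, it can be returned directly as the residual object from the forward rule, and its fields can be unpacked positionally in the backward rule, matching the "Alias for field number N" ordering above.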