ejkernel.kernels._triton.ring_attention._ring_kernel

Ring Flash Attention Kernel: wraps Triton flash attention for a distributed ring topology.

class ejkernel.kernels._triton.ring_attention._ring_kernel.RingFlashResiduals(q: jax.Array, k: jax.Array, v: jax.Array, bias: jax.Array | None, attention_mask: jax.Array | None, o: jax.Array, lse: jax.Array, dropout_seed: int | None)

Bases: NamedTuple

Residuals saved from the forward pass for use in the backward computation.
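To illustrate the general pattern of saving residuals in a forward pass and consuming them in a backward pass, here is a minimal sketch built on `jax.custom_vjp`. It is not the library's kernel: `ToyResiduals`, `toy_attention`, `_toy_fwd`, and `_toy_bwd` are hypothetical names, the attention is a plain single-head softmax with no scaling, bias, mask, or dropout, and nothing is distributed. The only point is that a NamedTuple of saved arrays (here including `o` and `lse`) can be handed from the forward rule to the backward rule.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyResiduals(NamedTuple):
    # Hypothetical, reduced stand-in for RingFlashResiduals.
    q: jax.Array
    k: jax.Array
    v: jax.Array
    o: jax.Array
    lse: jax.Array


@jax.custom_vjp
def toy_attention(q, k, v):
    out, _ = _toy_fwd(q, k, v)
    return out


def _toy_fwd(q, k, v):
    s = q @ k.T                          # scores, shape (n_q, n_k)
    lse = jax.nn.logsumexp(s, axis=-1)   # log-sum-exp per query row
    p = jnp.exp(s - lse[:, None])        # softmax probabilities
    o = p @ v
    return o, ToyResiduals(q, k, v, o, lse)


def _toy_bwd(res, do):
    q, k, v, o, lse = res
    # Recompute probabilities from the saved lse instead of storing them.
    p = jnp.exp(q @ k.T - lse[:, None])
    dv = p.T @ do
    dp = do @ v.T
    # Row-wise sum of dp * p equals the row-wise sum of do * o.
    delta = jnp.sum(do * o, axis=-1, keepdims=True)
    ds = p * (dp - delta)
    return ds @ k, ds.T @ q, dv


toy_attention.defvjp(_toy_fwd, _toy_bwd)
```

Saving `lse` and `o` rather than the full probability matrix is the usual flash-attention trade-off: the backward pass recomputes the probabilities from `q`, `k`, and `lse`.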

attention_mask: jax.Array | None

Alias for field number 4

bias: jax.Array | None

Alias for field number 3

dropout_seed: int | None

Alias for field number 7

k: jax.Array

Alias for field number 1

lse: jax.Array

Alias for field number 6

o: jax.Array

Alias for field number 5

q: jax.Array

Alias for field number 0

v: jax.Array

Alias for field number 2
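The "field number" aliases above are the positional indices of the NamedTuple, in the order given by the class signature. The sketch below shows that correspondence using a local mirror of the class so that it runs without ejkernel installed. `ResidualsMirror` is a hypothetical name, and the array shapes are arbitrary placeholders, not the layout the kernel expects.

```python
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp


class ResidualsMirror(NamedTuple):
    # Hypothetical local copy of the documented field order.
    q: jax.Array
    k: jax.Array
    v: jax.Array
    bias: Optional[jax.Array]
    attention_mask: Optional[jax.Array]
    o: jax.Array
    lse: jax.Array
    dropout_seed: Optional[int]


x = jnp.zeros((2, 4))  # placeholder shape, chosen only for illustration
res = ResidualsMirror(
    q=x, k=x, v=x,
    bias=None, attention_mask=None,
    o=x, lse=jnp.zeros((2,)),
    dropout_seed=None,
)

# Named access and positional access refer to the same object.
same_lse = res[6] is res.lse
mask_index = res._fields.index("attention_mask")

# NamedTuples are JAX pytrees; None fields contribute no leaves.
leaves = jax.tree_util.tree_leaves(res)
```

Because the tuple is a pytree, it can be passed through JAX transformations as a unit, and the optional fields can be left as `None` without adding leaves.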