ejkernel.kernels._pallas.tpu.ring_attention._ring_splash

Ring Attention implementation using Splash Attention kernels.

This module provides ring attention by wrapping JAX’s splash attention kernels with a ring communication topology for distributed attention computation.

class ejkernel.kernels._pallas.tpu.ring_attention._ring_splash.RingSplashAttentionKernel(fwd_mask_info: MaskInfo, dkv_mask_info: ejkernel.kernels._pallas.tpu.blocksparse_attention._info.MaskInfo | None, ring_axis: str = 'sp', **kwargs)

Bases: object

tree_flatten()

classmethod tree_unflatten(aux_data, children)

class ejkernel.kernels._pallas.tpu.ring_attention._ring_splash.SegmentIds(q: jax.Array, kv: jax.Array)

Bases: NamedTuple

SegmentIds for Q and KV sequences.

kv: Array

Alias for field number 1

q: Array

Alias for field number 0

ejkernel.kernels._pallas.tpu.ring_attention._ring_splash.make_ring_attention(mask: numpy.ndarray | jax.jaxlib._jax.Array | ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.Mask, *, block_sizes: ejkernel.kernels._pallas.tpu.blocksparse_attention._kernel.BlockSizes | None = None, is_mqa: bool = False, mask_value: float = -2.381976426469702e+38, logits_soft_cap: float | None = None, ring_axis: str = 'sp', q_seq_shards: int = 1) -> RingSplashAttentionKernel

ejkernel.kernels._pallas.tpu.ring_attention._ring_splash.ring_splash_attention(fwd_mask_info: MaskInfo, dkv_mask_info: ejkernel.kernels._pallas.tpu.blocksparse_attention._info.MaskInfo | None, q: Array, k: Array, v: Array, segment_ids: ejkernel.kernels._pallas.tpu.ring_attention._ring_splash.SegmentIds | None = None, sinks: jax.jaxlib._jax.Array | None = None, *, is_mqa: bool, block_sizes: BlockSizes, mask_value: float = -2.381976426469702e+38, mask_function: collections.abc.Callable[[...], jax.jaxlib._jax.Array] | None = None, logits_soft_cap: float | None = None, ring_axis: str = 'sp', causal: bool = False) -> Array