ejkernel.modules.operations.ring_attention

Ring Attention module with automatic optimization.

This module implements Ring Attention, a distributed attention mechanism that enables efficient processing of extremely long sequences by distributing computation across multiple devices in a ring topology. Unlike standard attention, which requires all KV pairs to fit in a single device’s memory, Ring Attention shards the KV pairs across devices and streams them around the ring, overlapping the transfer of each KV block with attention computation on the current one.

Ring Attention is particularly valuable for:
  • Ultra-long sequence processing (100K+ tokens)

  • Training large language models with long contexts

  • Distributed inference scenarios

  • Memory-constrained environments requiring sequence parallelism

Key Innovation:

Ring Attention partitions the KV pairs across devices and uses a ring-based communication pattern to stream KV blocks through each device. Each device:
  1. Computes attention with its local KV block.

  2. Passes the KV block to the next device in the ring.

  3. Receives the next KV block from the previous device.

  4. Continues until all KV blocks have been processed.

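As a result, each device stores only O(N/D) query and KV tokens at a time instead of the full sequence, while the total computation remains the O(N^2) of exact attention (O(N^2/D) per device).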

Mathematical Foundation:

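For a sequence of length N split across D devices:
  • Each device holds N/D query tokens and starts with one KV block of N/D tokens.

  • KV blocks are rotated through the ring, so every device eventually sees all D blocks.

  • Attention is computed incrementally. For the local queries Q, the weights on KV block i are normalized over all D blocks (softmax scaling omitted):

    $\mathrm{softmax}_i = \exp(Q K_i^\top) \big/ \sum_{j=0}^{D-1} \mathrm{rowsum}\big(\exp(Q K_j^\top)\big)$

  • Because that denominator spans blocks that are never resident on the device at the same time, running statistics (row-wise max and sum) are maintained and rescaled after each block; subtracting the running max also keeps the exponentials numerically stable.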
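The incremental update can be illustrated with a single-process NumPy simulation (an illustrative sketch, not ejkernel's kernel). Each simulated device owns one query block, visits every KV block once, and keeps a running max `m`, a running denominator `l`, and an unnormalized output `acc`, rescaling them by `exp(m_old - m_new)` whenever the max grows. The function name `ring_attention_simulated`, the single-head layout, and the absence of masking are assumptions made for brevity.

```python
import numpy as np


def ring_attention_simulated(q, k, v, num_devices, scale=None):
    """Single-process simulation of non-causal ring attention for one head.

    q, k, v have shape [N, d] with N divisible by num_devices. The result
    should match softmax(scale * q @ k.T) @ v computed in one shot.
    """
    d = q.shape[1]
    scale = 1.0 / np.sqrt(d) if scale is None else scale
    q_blocks = np.split(q, num_devices)   # device i owns N/D queries
    k_blocks = np.split(k, num_devices)   # and initially KV block i
    v_blocks = np.split(v, num_devices)

    outputs = []
    for i, q_i in enumerate(q_blocks):
        m = np.full(q_i.shape[0], -np.inf)            # running row-wise max of logits
        l = np.zeros(q_i.shape[0])                    # running softmax denominator
        acc = np.zeros((q_i.shape[0], v.shape[1]))    # unnormalized output
        for step in range(num_devices):
            j = (i - step) % num_devices              # KV block resident at this step
            s = scale * (q_i @ k_blocks[j].T)
            m_new = np.maximum(m, s.max(axis=1))
            correction = np.exp(m - m_new)            # rescales earlier partial sums
            p = np.exp(s - m_new[:, None])
            l = correction * l + p.sum(axis=1)
            acc = correction[:, None] * acc + p @ v_blocks[j]
            m = m_new
        outputs.append(acc / l[:, None])              # normalize once all blocks are seen
    return np.concatenate(outputs)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 8)) for _ in range(3))
logits = q @ k.T / np.sqrt(8)
weights = np.exp(logits - logits.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)
assert np.allclose(ring_attention_simulated(q, k, v, num_devices=4), weights @ v)
```

The final assertion checks that the blockwise result matches one-shot softmax attention; the order in which KV blocks are visited does not affect the result.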

Communication Pattern:

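Device 0: KV_0 -> KV_{D-1} -> KV_{D-2} -> … -> KV_1
Device 1: KV_1 -> KV_0 -> KV_{D-1} -> … -> KV_2
Device i: KV_i -> KV_{i-1} -> … -> KV_{i+1} (indices mod D)

This order follows from each device sending its block to device i+1 and receiving from device i-1; the final attention output does not depend on the order.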
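The schedule can be reproduced with a few lines of Python, assuming (as in steps 2 and 3 of the procedure) that each device sends its current block to device (i + 1) mod D and receives from device (i - 1) mod D. `ring_schedule` is a hypothetical helper for illustration only.

```python
def ring_schedule(num_devices):
    """Return schedule[i]: the KV block indices device i processes, in order.

    Assumes each device forwards its current block to device (i + 1) % D and
    receives from device (i - 1) % D.
    """
    held = list(range(num_devices))          # device i starts with KV block i
    schedule = [[] for _ in range(num_devices)]
    for step in range(num_devices):
        for device, block in enumerate(held):
            schedule[device].append(block)   # compute attention with this block
        if step < num_devices - 1:           # D - 1 transfers are enough
            held = [held[(d - 1) % num_devices] for d in range(num_devices)]
    return schedule


print(ring_schedule(4))  # [[0, 3, 2, 1], [1, 0, 3, 2], [2, 1, 0, 3], [3, 2, 1, 0]]
```

In JAX, a rotation like this is typically written with `jax.lax.ppermute(x, axis_name, perm=[(i, (i + 1) % D) for i in range(D)])` inside `shard_map`; whether ejkernel's kernels use exactly this permutation is not stated here.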

Performance Characteristics:
  • Memory: O(N/D) per device vs O(N) for standard attention

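  • Computation: O(N^2/D) per device, i.e. the same O(N^2) total cost as standard attention, split across D devices

  • Communication: O(N) per device, since each device forwards D-1 KV blocks of N/D tokens; this transfer can be hidden behind computation when processing one block takes at least as long as receiving the next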
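A back-of-the-envelope calculation makes these costs concrete. The sizes below (128K tokens, 8 devices, 8 KV heads of dimension 128, bf16 storage) are assumptions chosen for illustration, and `ring_costs` is a hypothetical helper.

```python
def ring_costs(seq_len, num_devices, num_kv_heads, head_dim, bytes_per_elem=2, batch=1):
    """Return (tokens per device, bytes per KV block, bytes each device forwards)."""
    tokens_per_device = seq_len // num_devices
    # One KV block = keys + values for this device's slice of the sequence.
    kv_block_bytes = 2 * batch * tokens_per_device * num_kv_heads * head_dim * bytes_per_elem
    # Each device forwards a block at D - 1 of the D ring steps.
    sent_per_device = (num_devices - 1) * kv_block_bytes
    return tokens_per_device, kv_block_bytes, sent_per_device


tokens, block_bytes, sent_bytes = ring_costs(131072, num_devices=8, num_kv_heads=8, head_dim=128)
print(tokens, block_bytes / 2**20, sent_bytes / 2**20)  # 16384 64.0 448.0
```

With these numbers each device holds 16,384 tokens and forwards seven 64 MiB KV blocks (448 MiB) during one pass, which is 7/8 of the full 512 MiB of keys and values, i.e. O(N) per device.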

References

Liu et al., “Ring Attention with Blockwise Transformers for Near-Infinite Context” https://arxiv.org/abs/2310.01889

class ejkernel.modules.operations.ring_attention.RingAttention

Bases: Kernel[RingAttentionConfig, Array]

Ring Attention kernel with custom configuration heuristics and autotuning candidates.

Implements distributed attention using ring communication topology for processing ultra-long sequences across multiple devices with memory efficiency.

Features:
  • Distributed KV processing via ring communication

  • Overlapped computation and communication for efficiency

  • Causal and non-causal attention support

  • Sliding window attention for local patterns

  • Attention sink mechanism for long-context stability

  • Configurable chunk sizes for memory-computation tradeoffs

  • Gradient checkpointing support for training

  • Multiple platform support (Triton/Pallas/CUDA/XLA)
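The masking options listed above can be expressed as boolean predicates over query position i and key position j. The sketch below assumes common semantics (causal as j <= i, a (left, right) sliding window, and Llama 4 style chunking that keeps attention within fixed-size chunks); the exact conventions in ejkernel, including how an integer sliding_window is interpreted, may differ. `attention_mask` is a hypothetical helper.

```python
import numpy as np


def attention_mask(q_len, kv_len, causal=False, sliding_window=None, chunk_size=None):
    """Boolean [q_len, kv_len] mask; True means query i may attend to key j.

    Assumed semantics (illustrative, not taken from ejkernel):
      causal:                j <= i
      sliding_window=(l, r): i - l <= j <= i + r
      chunk_size=c:          i and j fall in the same chunk of c tokens
    """
    i = np.arange(q_len)[:, None]   # query positions (column vector)
    j = np.arange(kv_len)[None, :]  # key positions (row vector)
    mask = np.ones((q_len, kv_len), dtype=bool)
    if causal:
        mask &= j <= i
    if sliding_window is not None:
        left, right = sliding_window
        mask &= (j >= i - left) & (j <= i + right)
    if chunk_size is not None:
        mask &= (i // chunk_size) == (j // chunk_size)
    return mask


# Chunked causal attention with chunks of 2 tokens.
print(attention_mask(4, 4, causal=True, chunk_size=2).astype(int))
```

Under ring attention these predicates must be evaluated on global positions: a device that owns query block i of size N/D would offset its local indices by i * N/D, and likewise for each KV block it receives.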

The implementation maintains numerical stability through:
  • Online softmax with running max/sum statistics

  • Logit soft capping to prevent overflow

  • Float32 logit accumulation (configurable)
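Soft capping is usually implemented as cap * tanh(logits / cap), which is close to the identity for small logits and smoothly saturates at ±cap for large ones. The sketch assumes `logits_soft_cap` uses this common form; `soft_cap` is an illustrative helper, not an ejkernel function.

```python
import numpy as np


def soft_cap(logits, cap):
    """Squash logits smoothly into [-cap, cap]; near-identity for |x| << cap."""
    return cap * np.tanh(logits / cap)


x = np.array([-1000.0, -5.0, 0.0, 5.0, 1000.0], dtype=np.float32)
capped = soft_cap(x, 30.0)
assert np.all(np.abs(capped) <= 30.0)   # large logits are bounded by the cap
assert abs(capped[3] - 5.0) < 0.1       # small logits pass almost unchanged
```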

Typical Usage Patterns:
  • Multi-GPU training with sequence parallelism

  • Long-context inference on multiple devices

  • Blockwise transformer architectures

candidate_cfgs(inv: Invocation[RingAttentionConfig, Array])

Generate candidate configurations for autotuning.

Creates configurations optimized for different sequence lengths and device counts, balancing chunk size with communication overhead.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of candidate configurations to benchmark during autotuning

Note

Ring attention performance is sensitive to chunk sizes relative to sequence length per device and communication bandwidth.
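As an illustration of why chunk sizes should be chosen relative to the per-device sequence length, the hypothetical helper below (not part of ejkernel) keeps only block sizes that evenly tile each device's local slice. The candidate list is an assumption.

```python
def candidate_block_sizes(seq_len, num_devices, options=(128, 256, 512, 1024, 2048)):
    """Hypothetical helper: block sizes that evenly tile each device's local sequence.

    Blocks larger than the local length are dropped, and sizes that would leave
    a ragged remainder are skipped so every ring step processes full blocks.
    """
    local_len = seq_len // num_devices
    return [b for b in options if b <= local_len and local_len % b == 0]


print(candidate_block_sizes(8192, 8))   # local length 1024 -> [128, 256, 512, 1024]
print(candidate_block_sizes(6144, 8))   # local length 768  -> [128, 256]
```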

candidate_cfgs_shard_map_tpu(inv: Invocation[RingAttentionConfig, Array])

Generate TPU-optimized candidate configurations for autotuning the shard_map execution path.

TPU/Pallas kernels benefit from larger blocks for ring attention.

Parameters

inv – Invocation object with arguments and metadata

Returns

Iterable of TPU-optimized candidate configurations

candidate_cfgs_tpu(inv: Invocation[RingAttentionConfig, Array])

Generate TPU-optimized candidate configurations for autotuning.

TPU/Pallas kernels benefit from larger blocks for ring attention.

Parameters

inv – Invocation object with arguments and metadata

Returns

Iterable of TPU-optimized candidate configurations

create_shard_map_wrapper(query: Float[Array, 'batch seq_len_q num_heads head_dim'], key: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], value: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], q_segment_ids: Int[Array, 'batch seq_len_q'] | None = None, kv_segment_ids: Int[Array, 'batch seq_len_k'] | None = None, softmax_aux: Float[Array, 'num_heads num_sinks'] | Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, mask_builder: Callable[[int, int, int, int, int], Mask] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = False, logits_soft_cap: float | None = None, softmax_scale: float | None = None, axis_name: str | None = None, fused_backward: bool = False, platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] | None = None, cfg: RingAttentionConfig | None = None, mesh: Mesh | None = None, in_specs: tuple[PartitionSpec, ...] | None = None, out_specs: PartitionSpec | None = None, check_vma: bool = False)

Create a shard_map wrapper specifically for ring attention.

Ring attention naturally works with distributed execution, using collective communication across devices.

Parameters
  • query – Query tensor to be sharded

  • key – Key tensor to be sharded

  • value – Value tensor to be sharded

  • q_segment_ids – Optional query segment IDs

  • kv_segment_ids – Optional KV segment IDs

  • softmax_aux – Optional attention sink logits

  • bias – Optional bias tensor

  • mask_builder – Optional custom mask builder function

  • sliding_window – Window size for local attention

  • chunk_size – Chunk size for chunked causal attention

  • causal – Whether to use causal masking

  • logits_soft_cap – Soft cap value for attention logits

  • softmax_scale – Scaling factor for attention scores

  • axis_name – Axis name for ring communication

  • fused_backward – Whether to use fused backward kernel

  • platform – Target platform

  • cfg – Kernel configuration object

  • mesh – JAX device mesh

  • in_specs – Input partition specs

  • out_specs – Output partition spec

  • check_vma – Whether to enable shard_map’s varying-manual-axes (VMA) checks on how outputs are replicated across mesh axes

Returns

Tuple of (shard_map_fn, call_args)

get_impl(cfg: RingAttentionConfig)

Get kernel implementation from registry.

Parameters

cfg – Configuration specifying platform and backend preferences

Returns

Callable kernel implementation for ring attention

Raises

ValueError – If no matching implementation is found for the configuration
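A registry lookup of this kind can be sketched as a dictionary keyed by operation and platform. Everything below (`_REGISTRY`, `register`, the `"auto"` search order) is hypothetical and only illustrates the documented behavior of returning a callable or raising `ValueError`.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry keyed by (operation, platform); not ejkernel's real one.
_REGISTRY: Dict[Tuple[str, str], Callable] = {}
# Assumed search order when platform="auto".
_AUTO_ORDER = ("pallas", "triton", "cuda", "xla")


def register(op: str, platform: str):
    """Decorator that records an implementation under (op, platform)."""
    def decorator(fn: Callable) -> Callable:
        _REGISTRY[(op, platform)] = fn
        return fn
    return decorator


def get_impl(op: str, platform: str = "auto") -> Callable:
    """Return the implementation for op on platform, or raise ValueError."""
    candidates = _AUTO_ORDER if platform == "auto" else (platform,)
    for name in candidates:
        if (op, name) in _REGISTRY:
            return _REGISTRY[(op, name)]
    raise ValueError(f"No implementation of {op!r} registered for platform {platform!r}")


@register("ring_attention", "xla")
def ring_attention_xla(*args, **kwargs):
    return "xla implementation"
```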

heuristic_cfg(inv: Invocation[RingAttentionConfig, Array]) -> RingAttentionConfig

Provide default configuration optimized for ring attention.

Parameters

inv – Invocation object containing arguments and metadata

Returns

Default RingAttentionConfig with block sizes balanced for communication and computation overlap in distributed settings

run(query: Float[Array, 'batch seq_len_q num_heads head_dim'], key: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], value: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], q_segment_ids: Int[Array, 'batch seq_len_q'] | None = None, kv_segment_ids: Int[Array, 'batch seq_len_k'] | None = None, softmax_aux: Float[Array, 'num_heads num_sinks'] | Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, mask_builder: Callable[[int, int, int, int, int], Mask] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = False, logits_soft_cap: float | None = None, softmax_scale: float | None = None, axis_name: str | None = None, fused_backward: bool = False, platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] | None = None, *, cfg: RingAttentionConfig) -> Float[Array, 'batch seq_len_q num_heads head_dim']

Execute ring attention with distributed KV processing.

Computes attention across devices using a ring communication pattern, enabling efficient processing of sequences whose keys and values do not fit in a single device’s memory.

Parameters
  • query – Query tensor [batch, seq_len_q, num_heads, head_dim]

  • key – Key tensor [batch, seq_len_k, num_kv_heads, head_dim] (distributed)

  • value – Value tensor [batch, seq_len_k, num_kv_heads, head_dim] (distributed)

  • q_segment_ids – Optional query segment IDs [batch, seq_len_q]

  • kv_segment_ids – Optional KV segment IDs [batch, seq_len_k]

  • softmax_aux – Optional attention sink logits for long-context stability

  • bias – Optional attention bias [batch, num_heads, seq_len_q, seq_len_k]

  • mask_builder – Custom mask builder function(q_len, kv_len, num_heads, head_idx, num_reps) -> Mask

  • sliding_window – Window size for local attention (int or (left, right) tuple)

  • chunk_size – Chunk size for chunked causal attention (Llama4 style)

  • causal – Whether to use causal masking

  • logits_soft_cap – Soft cap value to bound attention logits

  • softmax_scale – Scaling factor for attention scores (default: 1/sqrt(head_dim))

  • axis_name – Name of the axis for collective operations (required for multi-device)

  • fused_backward – Whether to use fused backward kernel

  • platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Kernel configuration object

Returns

Attention output [batch, seq_len_q, num_heads, head_dim]
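For testing, the expected single-device semantics of `run` can be approximated by a dense NumPy reference. This sketch assumes grouped-query attention in which query head n uses KV head n // (num_heads // num_kv_heads), segment IDs that restrict attention to equal IDs, a top-left-aligned causal mask, and the documented default scale 1/sqrt(head_dim). `reference_attention` is a hypothetical name, and the sketch ignores sliding windows, sinks, bias, and soft capping.

```python
import numpy as np


def reference_attention(query, key, value, q_segment_ids=None, kv_segment_ids=None,
                        causal=False, softmax_scale=None):
    """Dense attention on [batch, seq, heads, head_dim] arrays (illustrative only)."""
    b, sq, h, d = query.shape
    _, sk, hkv, _ = key.shape
    scale = 1.0 / np.sqrt(d) if softmax_scale is None else softmax_scale
    rep = h // hkv
    # Assumed GQA layout: query head n reads KV head n // rep.
    key = np.repeat(key, rep, axis=2)
    value = np.repeat(value, rep, axis=2)
    logits = np.einsum("bqhd,bkhd->bhqk", query, key) * scale
    mask = np.ones((b, 1, sq, sk), dtype=bool)
    if causal:  # top-left aligned: key j visible to query i when j <= i
        mask &= np.tril(np.ones((sq, sk), dtype=bool))[None, None]
    if q_segment_ids is not None and kv_segment_ids is not None:
        mask &= q_segment_ids[:, None, :, None] == kv_segment_ids[:, None, None, :]
    logits = np.where(mask, logits, -np.inf)
    logits -= logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", weights, value)


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 6, 4, 8))   # 4 query heads
k = rng.standard_normal((2, 6, 2, 8))   # 2 KV heads, 2 query heads per KV head
v = rng.standard_normal((2, 6, 2, 8))
out = reference_attention(q, k, v, causal=True)
print(out.shape)  # (2, 6, 4, 8)
```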

Note

Ring attention requires proper device mesh setup with the specified axis_name. Each device processes a slice of the sequence and communicates KV pairs through the ring topology.
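When causal=True and the sequence is split into equal, contiguous blocks in device order (an assumption about the layout), each KV block a device receives is either fully visible, on the diagonal, or fully masked. The hypothetical helpers below classify each (query block, KV block) pair and count the non-skipped blocks per device.

```python
def causal_block_kind(q_block, kv_block):
    """Classify a (query block, KV block) pair under a causal mask."""
    if kv_block < q_block:
        return "full"       # every key precedes every query: no mask needed
    if kv_block == q_block:
        return "diagonal"   # needs the element-wise causal mask
    return "skip"           # every key follows every query: fully masked


def work_per_device(num_devices):
    """Number of KV blocks each device actually has to compute with."""
    return [sum(causal_block_kind(i, j) != "skip" for j in range(num_devices))
            for i in range(num_devices)]


print(work_per_device(4))  # [1, 2, 3, 4]
```

The counts grow with the device index, so contiguous sharding leaves later devices with more causal work; this is why some ring attention implementations reorder the sequence (for example striped or zig-zag layouts). Whether ejkernel does so is not stated here.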

Example

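>>> import jax
>>> import numpy as np
>>> from ejkernel.modules.operations.ring_attention import ring_attention
>>> # q, k, v: [batch, seq_len, heads, head_dim] arrays, sharded along seq_len
>>> devices = np.array(jax.devices())
>>> mesh = jax.sharding.Mesh(devices, axis_names=('sp',))
>>> with mesh:
...     out = ring_attention(q, k, v, axis_name='sp')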

ejkernel.modules.operations.ring_attention.ring_attention(query: Float[Array, 'batch seq_len_q num_heads head_dim'], key: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], value: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], softmax_aux: Float[Array, 'num_heads num_sinks'] | Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, /, *, mask_info: MaskInfo | None = None, mask_builder: Callable[[int, int, int, int, int], Mask] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = False, logits_soft_cap: float | None = None, softmax_scale: float | None = None, axis_name: str | None = None, fused_backward: bool = False, platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] | None = None, cfg: RingAttentionConfig | None = None, mesh: Mesh | None = None, in_specs: tuple[PartitionSpec | None, ...] | None = None, out_specs: PartitionSpec | None = None) -> Float[Array, 'batch seq_len_q num_heads head_dim']

Execute ring attention with automatic optimization.

Ring attention distributes attention computation across devices in a ring topology, enabling efficient processing of very long sequences through communication-efficient parallelization.

Parameters
  • query – Query tensor [batch, seq_len_q, num_heads, head_dim]

  • key – Key tensor [batch, seq_len_k, num_kv_heads, head_dim]

  • value – Value tensor [batch, seq_len_k, num_kv_heads, head_dim]

  • softmax_aux – Optional attention sink logits for long-context stability

  • bias – Optional attention bias tensor

  • mask_info – Optional MaskInfo containing attention mask and/or segment IDs

  • mask_builder – Custom mask builder function(q_len, kv_len, num_heads, head_idx, num_reps) -> Mask

  • sliding_window – Window size for local attention (int or (left, right) tuple)

  • chunk_size – Chunk size for chunked causal attention (Llama4 style)

  • causal – Whether to use causal masking

  • logits_soft_cap – Soft capping value for logits

  • softmax_scale – Scaling factor for attention scores (default: 1/sqrt(head_dim))

  • axis_name – Name of the axis for collective operations

  • fused_backward – Whether to use fused backward kernel

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Optional configuration override

  • mesh – JAX device mesh for shard_map execution (optional)

  • in_specs – Input partition specs for shard_map (optional)

  • out_specs – Output partition spec for shard_map (optional)

Returns

Attention output with same shape as query
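One common formulation of attention sinks, assumed here for softmax_aux, appends per-head sink logits to each row of attention logits before the softmax; the sinks absorb probability mass but have no value vectors, so the weights on real keys can sum to less than 1. A 1-D [num_sinks] input is assumed to be shared across heads. `softmax_with_sinks` is an illustrative helper, not ejkernel's implementation.

```python
import numpy as np


def softmax_with_sinks(logits, sink_logits):
    """Softmax over keys plus sink columns, returning weights for the real keys only.

    logits:      [num_heads, q_len, kv_len]
    sink_logits: [num_heads, num_sinks], or [num_sinks] shared by all heads
    """
    if sink_logits.ndim == 1:
        sink_logits = np.broadcast_to(sink_logits, (logits.shape[0],) + sink_logits.shape)
    # Repeat each head's sink logits for every query row.
    sinks = np.broadcast_to(sink_logits[:, None, :], logits.shape[:2] + sink_logits.shape[-1:])
    combined = np.concatenate([logits, sinks], axis=-1)
    combined = combined - combined.max(axis=-1, keepdims=True)   # stable softmax
    weights = np.exp(combined)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights[..., : logits.shape[-1]]   # sinks have no values: drop their columns


w = softmax_with_sinks(np.zeros((1, 1, 3)), np.zeros((1, 1)))
print(w.sum())  # 0.75: one quarter of the mass went to the sink
```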

Example

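>>> # Causal ring attention over the "sp" mesh axis
>>> out = ring_attention(query, key, value, causal=True, axis_name="sp")
>>> # Causal attention restricted to a sliding window of 1024
>>> out = ring_attention(
...     query, key, value,
...     causal=True,
...     sliding_window=1024,
...     axis_name="sp",
... )
>>> # Force the Triton backend (other arguments elided)
>>> out = ring_attention(..., platform="triton")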