ejkernel.modules.operations.ring_attention
Ring Attention module with automatic optimization.
This module implements Ring Attention, a distributed attention mechanism that enables efficient processing of extremely long sequences by distributing computation across multiple devices in a ring topology. Unlike standard attention, which requires all KV pairs to fit in a single device’s memory, Ring Attention shards the KV pairs across devices and uses pipelining to overlap the transfer of KV blocks between devices with the attention computation.
- Ring Attention is particularly valuable for:
Ultra-long sequence processing (100K+ tokens)
Training large language models with long contexts
Distributed inference scenarios
Memory-constrained environments requiring sequence parallelism
- Key Innovation:
Ring Attention partitions the KV pairs across devices and uses a ring-based communication pattern to stream KV blocks through each device. Each device:
1. Computes attention between its local queries and the KV block it currently holds (initially its own)
2. Passes that KV block to the next device in the ring
3. Receives the next KV block from the previous device
4. Repeats until all KV blocks have been processed
This keeps memory at O(N/D) per device while the total computation remains O(N^2).
- Mathematical Foundation:
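For a sequence of length N split across D devices:
- Each device holds N/D query tokens
- KV blocks are rotated through the ring
- Attention is computed incrementally over the KV blocks K_i (softmax scaling omitted), where the denominator sums row-wise over every key in every block:
$$\mathrm{softmax}_i = \frac{\exp(Q K_i^\top)}{\sum_j \exp(Q K_j^\top)}$$
- Running statistics (row-wise max and sum) are maintained for numerical stability, so each block can be folded in as it arrives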
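The incremental softmax can be checked on a single device. The following minimal sketch (plain jax.numpy, one head, no masking, with the usual 1/sqrt(head_dim) scaling) folds in one KV block at a time while keeping a running max and running sum per query row, then compares the result with ordinary softmax attention. The helper names are hypothetical; this illustrates the technique, not ejkernel's kernel.

```python
import jax
import jax.numpy as jnp


def blockwise_attention(q, k, v, num_blocks):
    """Attention over KV blocks with an online softmax (q: [Lq, d], k/v: [Lk, d])."""
    scale = q.shape[-1] ** -0.5
    k_blocks = jnp.split(k, num_blocks, axis=0)
    v_blocks = jnp.split(v, num_blocks, axis=0)

    m = jnp.full((q.shape[0],), -jnp.inf)       # running row max
    l = jnp.zeros((q.shape[0],))                # running row sum of exp(logits - m)
    acc = jnp.zeros((q.shape[0], v.shape[-1]))  # unnormalised output

    for k_j, v_j in zip(k_blocks, v_blocks):
        logits = (q @ k_j.T) * scale
        m_new = jnp.maximum(m, logits.max(axis=-1))
        correction = jnp.exp(m - m_new)         # rescales earlier blocks to the new max
        p = jnp.exp(logits - m_new[:, None])
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ v_j
        m = m_new

    return acc / l[:, None]


def reference_attention(q, k, v):
    scale = q.shape[-1] ** -0.5
    return jax.nn.softmax((q @ k.T) * scale, axis=-1) @ v


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (8, 16))
k = jax.random.normal(key_k, (32, 16))
v = jax.random.normal(key_v, (32, 16))
out = blockwise_attention(q, k, v, num_blocks=4)
```

Because the running statistics are rescaled on every step, the order in which blocks arrive does not change the result, which is what lets each device start from its own block.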
- Communication Pattern:
Device 0: KV_0 -> KV_{D-1} -> … -> KV_1
Device 1: KV_1 -> KV_0 -> … -> KV_2
Device i: KV_i -> KV_{i-1} -> … -> KV_{i+1} (mod D)
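Which block a device holds at a given step follows directly from the rotation. The pure-Python sketch below tabulates it; the `direction` argument is a hypothetical knob covering both orientations (blocks moving toward device i+1 or toward device i-1), since in either case every device sees each of the D blocks exactly once in D steps.

```python
def ring_schedule(num_devices, direction=+1):
    """Return schedule[i][s]: index of the KV block held by device i at step s.

    direction=+1: blocks move to device (i + 1) mod D, so device i receives from i - 1.
    direction=-1: blocks move to device (i - 1) mod D, so device i receives from i + 1.
    """
    return [
        [(i - direction * s) % num_devices for s in range(num_devices)]
        for i in range(num_devices)
    ]


for device, row in enumerate(ring_schedule(4)):
    print(f"device {device}: {row}")
```

With four devices and `direction=+1`, device 0 visits blocks 0, 3, 2, 1 in that order; at every step the four devices hold four different blocks, so no block is ever duplicated in flight.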
- Performance Characteristics:
Memory: O(N/D) per device vs O(N) for standard attention
Computation: O(N^2/D) per device (O(N^2) in total, the same as standard attention)
Communication: O(N) per device (bandwidth-efficient with overlap)
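The asymptotic figures can be made concrete with a little arithmetic. This rough, hypothetical calculator assumes non-causal attention, counts the local Q shard plus the K/V block a device currently holds, and ignores the score tile (which fused kernels block further) as well as any double buffering.

```python
def ring_attention_costs(seq_len, num_devices, num_heads, num_kv_heads, head_dim,
                         bytes_per_elem=2):
    """Rough per-device costs of non-causal ring attention (illustrative only)."""
    assert seq_len % num_devices == 0
    block = seq_len // num_devices  # tokens per device, N / D

    # Local Q shard plus the K and V block currently held: O(N / D) memory.
    qkv_bytes = block * (num_heads + 2 * num_kv_heads) * head_dim * bytes_per_elem

    # Each of the D steps does QK^T and PV, each 2 * block^2 * head_dim FLOPs per head.
    flops = num_devices * 2 * (2 * block * block * head_dim) * num_heads

    # D - 1 transfers of one K block and one V block: O(N) bytes received per device.
    comm_bytes = (num_devices - 1) * 2 * block * num_kv_heads * head_dim * bytes_per_elem

    return {"qkv_bytes": qkv_bytes, "flops": flops, "comm_bytes": comm_bytes}


print(ring_attention_costs(131072, 8, num_heads=32, num_kv_heads=8, head_dim=128))
```

Multiplying the per-device FLOPs by D gives 4 * N^2 * head_dim * num_heads, the same total as standard attention, while doubling D halves the per-device memory term.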
References
Liu et al., “Ring Attention with Blockwise Transformers for Near-Infinite Context” https://arxiv.org/abs/2310.01889
- class ejkernel.modules.operations.ring_attention.RingAttention
Bases: Kernel[RingAttentionConfig, Array]
Ring Attention with custom optimization logic.
Implements distributed attention using ring communication topology for processing ultra-long sequences across multiple devices with memory efficiency.
- Features:
Distributed KV processing via ring communication
Overlapped computation and communication for efficiency
Causal and non-causal attention support
Sliding window attention for local patterns
Attention sink mechanism for long-context stability
Configurable chunk sizes for memory-computation tradeoffs
Gradient checkpointing support for training
Multiple platform support (Triton/Pallas/CUDA/XLA)
- The implementation maintains numerical stability through:
Online softmax with running max/sum statistics
Logit soft capping to prevent overflow
Float32 logit accumulation (configurable)
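Logit soft capping is commonly implemented as cap * tanh(logits / cap), which is close to the identity for small logits and saturates at ±cap. The sketch below assumes that formulation; this page does not state which form ejkernel uses, and `soft_cap` is a hypothetical helper.

```python
import jax.numpy as jnp


def soft_cap(logits, cap):
    """Smoothly squash logits into [-cap, cap] (tanh-based soft capping)."""
    # Compute in float32, mirroring float32 logit accumulation.
    logits = logits.astype(jnp.float32)
    return cap * jnp.tanh(logits / cap)


logits = jnp.array([-1e4, -10.0, 0.0, 0.5, 10.0, 1e4], dtype=jnp.float32)
capped = soft_cap(logits, cap=30.0)
print(capped)
```

Small logits such as 0.5 pass through almost unchanged, while extreme values are pinned near ±30, so the subsequent exp cannot overflow; the mapping is monotonic, so the ordering of scores is preserved.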
- Typical Usage Patterns:
Multi-GPU training with sequence parallelism
Long-context inference on multiple devices
Blockwise transformer architectures
- candidate_cfgs(inv: Invocation[RingAttentionConfig, Array])
Generate candidate configurations for autotuning.
Creates configurations optimized for different sequence lengths and device counts, balancing chunk size with communication overhead.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of candidate configurations to benchmark during autotuning
Note
Ring attention performance is sensitive to chunk sizes relative to sequence length per device and communication bandwidth.
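The candidate set that candidate_cfgs actually produces is not listed here. Purely as a hypothetical illustration of judging block sizes relative to the per-device sequence length, a filter could keep only sizes that evenly tile each device's shard (`candidate_block_sizes` and its default sizes are assumptions, not ejkernel's API):

```python
def candidate_block_sizes(seq_len, num_devices, sizes=(128, 256, 512, 1024, 2048)):
    """Hypothetical filter: keep block sizes that tile the per-device sequence."""
    per_device = seq_len // num_devices
    # A block larger than the shard, or one that leaves a ragged tail, is discarded.
    return [b for b in sizes if b <= per_device and per_device % b == 0]


print(candidate_block_sizes(8192, 8))  # 1024 tokens per device
```

Each surviving size would then be benchmarked; larger blocks mean fewer kernel launches per ring step, while smaller ones leave more room to overlap computation with the next KV transfer.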
- candidate_cfgs_shard_map_tpu(inv: Invocation[RingAttentionConfig, Array])
Generate TPU-optimized candidate configurations for autotuning.
TPU/Pallas kernels benefit from larger blocks for ring attention.
- Parameters
inv – Invocation object with arguments and metadata
- Returns
Iterable of TPU-optimized candidate configurations
- candidate_cfgs_tpu(inv: Invocation[RingAttentionConfig, Array])
Generate TPU-optimized candidate configurations for autotuning.
TPU/Pallas kernels benefit from larger blocks for ring attention.
- Parameters
inv – Invocation object with arguments and metadata
- Returns
Iterable of TPU-optimized candidate configurations
- create_shard_map_wrapper(query: Float[Array, 'batch seq_len_q num_heads head_dim'], key: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], value: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], q_segment_ids: Int[Array, 'batch seq_len_q'] | None = None, kv_segment_ids: Int[Array, 'batch seq_len_k'] | None = None, softmax_aux: Float[Array, 'num_heads num_sinks'] | Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, mask_builder: Callable[[int, int, int, int, int], Mask] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = False, logits_soft_cap: float | None = None, softmax_scale: float | None = None, axis_name: str | None = None, fused_backward: bool = False, platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] | None = None, cfg: RingAttentionConfig | None = None, mesh: Mesh | None = None, in_specs: tuple[PartitionSpec, ...] | None = None, out_specs: PartitionSpec | None = None, check_vma: bool = False)
Create a shard_map wrapper specifically for ring attention.
Ring attention naturally works with distributed execution, using collective communication across devices.
- Parameters
query – Query tensor to be sharded
key – Key tensor to be sharded
value – Value tensor to be sharded
q_segment_ids – Optional query segment IDs
kv_segment_ids – Optional KV segment IDs
softmax_aux – Optional attention sink logits
bias – Optional bias tensor
mask_builder – Optional custom mask builder function
sliding_window – Window size for local attention
chunk_size – Chunk size for chunked causal attention
causal – Whether to use causal masking
logits_soft_cap – Soft cap value for attention logits
softmax_scale – Scaling factor for attention scores
axis_name – Axis name for ring communication
fused_backward – Whether to use fused backward kernel
platform – Target platform
cfg – Kernel configuration object
mesh – JAX device mesh
in_specs – Input partition specs
out_specs – Output partition spec
check_vma – Whether to enable shard_map’s varying-manual-axes (VMA) checks (JAX’s check_vma option, formerly check_rep)
- Returns
Tuple of (shard_map_fn, call_args)
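For orientation, the sketch below shows the kind of program such a wrapper sets up, written directly with JAX’s shard_map and jax.lax.ppermute rather than with ejkernel. It is an assumption-laden illustration (non-causal, no GQA, no segment IDs, four simulated CPU devices), not what create_shard_map_wrapper returns. Each device keeps its query shard, folds in the resident KV block with an online softmax, and passes K and V to the next device.

```python
import os

# Simulate 4 devices on CPU; this must run before JAX initialises its backend.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX
    from jax.experimental.shard_map import shard_map

AXIS = "sp"
D = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=(AXIS,))
RING = [(i, (i + 1) % D) for i in range(D)]  # send to next, receive from previous


def local_ring_attention(q, k, v):
    """Per-device body. Shapes: [batch, seq/D, heads, head_dim]; non-causal, no GQA."""
    scale = q.shape[-1] ** -0.5
    b, s, h, _ = q.shape
    m = jnp.full((b, h, s), -jnp.inf, jnp.float32)  # running max per query row
    l = jnp.zeros((b, h, s), jnp.float32)           # running softmax denominator
    acc = jnp.zeros(q.shape, jnp.float32)           # unnormalised output
    for _ in range(D):  # D is static, so the loop is unrolled at trace time
        logits = jnp.einsum("bqhd,bkhd->bhqk", q, k).astype(jnp.float32) * scale
        m_new = jnp.maximum(m, logits.max(axis=-1))
        p = jnp.exp(logits - m_new[..., None])
        corr = jnp.exp(m - m_new)
        l = l * corr + p.sum(axis=-1)
        acc = acc * jnp.swapaxes(corr, 1, 2)[..., None] + jnp.einsum("bhqk,bkhd->bqhd", p, v)
        m = m_new
        # Rotate K/V one hop around the ring (the final rotation is unused here).
        k = jax.lax.ppermute(k, AXIS, RING)
        v = jax.lax.ppermute(v, AXIS, RING)
    return (acc / jnp.swapaxes(l, 1, 2)[..., None]).astype(q.dtype)


ring_attn = jax.jit(
    shard_map(
        local_ring_attention,
        mesh=mesh,
        in_specs=(P(None, AXIS), P(None, AXIS), P(None, AXIS)),  # shard the sequence axis
        out_specs=P(None, AXIS),
    )
)


def reference_attention(q, k, v):
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * q.shape[-1] ** -0.5
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, axis=-1), v)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (2, 8 * D, 4, 16)  # [batch, seq, heads, head_dim]; seq divisible by D
q, k, v = (jax.random.normal(key, shape) for key in (kq, kk, kv))
out = ring_attn(q, k, v)
```

A causal variant would additionally use jax.lax.axis_index to work out which global positions the resident KV block covers, and could skip blocks that lie entirely in the future.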
- get_impl(cfg: RingAttentionConfig)
Get kernel implementation from registry.
- Parameters
cfg – Configuration specifying platform and backend preferences
- Returns
Callable kernel implementation for ring attention
- Raises
ValueError – If no matching implementation is found for the configuration
- heuristic_cfg(inv: Invocation[RingAttentionConfig, Array]) -> RingAttentionConfig
Provide default configuration optimized for ring attention.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
Default RingAttentionConfig with block sizes balanced for communication and computation overlap in distributed settings
- run(query: Float[Array, 'batch seq_len_q num_heads head_dim'], key: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], value: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], q_segment_ids: Int[Array, 'batch seq_len_q'] | None = None, kv_segment_ids: Int[Array, 'batch seq_len_k'] | None = None, softmax_aux: Float[Array, 'num_heads num_sinks'] | Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, mask_builder: Callable[[int, int, int, int, int], Mask] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = False, logits_soft_cap: float | None = None, softmax_scale: float | None = None, axis_name: str | None = None, fused_backward: bool = False, platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] | None = None, *, cfg: RingAttentionConfig) -> Float[Array, 'batch seq_len_q num_heads head_dim']
Execute ring attention with distributed KV processing.
Computes attention across devices using a ring communication pattern, enabling efficient processing of sequences whose keys and values do not fit in a single device’s memory.
- Parameters
query – Query tensor [batch, seq_len_q, num_heads, head_dim]
key – Key tensor [batch, seq_len_k, num_kv_heads, head_dim] (distributed)
value – Value tensor [batch, seq_len_k, num_kv_heads, head_dim] (distributed)
q_segment_ids – Optional query segment IDs [batch, seq_len_q]
kv_segment_ids – Optional KV segment IDs [batch, seq_len_k]
softmax_aux – Optional attention sink logits for long-context stability
bias – Optional attention bias [batch, num_heads, seq_len_q, seq_len_k]
mask_builder – Custom mask builder function(q_len, kv_len, num_heads, head_idx, num_reps) -> Mask
sliding_window – Window size for local attention (int or (left, right) tuple)
chunk_size – Chunk size for chunked causal attention (Llama 4-style)
causal – Whether to use causal masking
logits_soft_cap – Soft cap value to bound attention logits
softmax_scale – Scaling factor for attention scores (default: 1/sqrt(head_dim))
axis_name – Name of the axis for collective operations (required for multi-device)
fused_backward – Whether to use fused backward kernel
platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Kernel configuration object
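The masking parameters compose into one boolean mask over (query, key) positions. The sketch below spells out one common reading of each; the exact semantics in ejkernel are assumptions here, in particular that an int sliding_window means a symmetric window, that bounds are inclusive, and that segment IDs are per-sequence 1-D arrays. In a ring setting each device would evaluate these conditions on global token indices. `attention_mask` is a hypothetical helper.

```python
import jax.numpy as jnp


def attention_mask(q_len, kv_len, causal=False, sliding_window=None, chunk_size=None,
                   q_segment_ids=None, kv_segment_ids=None):
    """Return a [q_len, kv_len] boolean mask; True means "may attend"."""
    q_pos = jnp.arange(q_len)[:, None]
    k_pos = jnp.arange(kv_len)[None, :]
    mask = jnp.ones((q_len, kv_len), dtype=bool)
    if causal:
        mask &= k_pos <= q_pos
    if sliding_window is not None:
        # Assumption: an int is a symmetric window, a tuple is (left, right).
        if isinstance(sliding_window, int):
            left, right = sliding_window, sliding_window
        else:
            left, right = sliding_window
        mask &= (q_pos - k_pos <= left) & (k_pos - q_pos <= right)
    if chunk_size is not None:
        # Chunked attention: queries only see keys in the same chunk
        # (combined with causal=True this gives chunked causal attention).
        mask &= (q_pos // chunk_size) == (k_pos // chunk_size)
    if q_segment_ids is not None and kv_segment_ids is not None:
        # Packed sequences: tokens only attend within their own segment.
        mask &= q_segment_ids[:, None] == kv_segment_ids[None, :]
    return mask


print(attention_mask(6, 6, causal=True, chunk_size=3).astype(int))
```

Because every condition is an AND, adding a parameter can only remove attention edges, never add them.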
- Returns
Attention output [batch, seq_len_q, num_heads, head_dim]
Note
Ring attention requires proper device mesh setup with the specified axis_name. Each device processes a slice of the sequence and communicates KV pairs through the ring topology.
Example
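>>> import numpy as np
>>> import jax
>>> from ejkernel.modules.operations.ring_attention import ring_attention
>>> devices = np.array(jax.devices())
>>> mesh = jax.sharding.Mesh(devices, axis_names=['sp'])
>>> with mesh:
...     out = ring_attention(q, k, v, axis_name='sp')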
- ejkernel.modules.operations.ring_attention.ring_attention(query: Float[Array, 'batch seq_len_q num_heads head_dim'], key: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], value: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], softmax_aux: Float[Array, 'num_heads num_sinks'] | Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, /, *, mask_info: MaskInfo | None = None, mask_builder: Callable[[int, int, int, int, int], Mask] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = False, logits_soft_cap: float | None = None, softmax_scale: float | None = None, axis_name: str | None = None, fused_backward: bool = False, platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] | None = None, cfg: RingAttentionConfig | None = None, mesh: Mesh | None = None, in_specs: tuple[PartitionSpec | None, ...] | None = None, out_specs: PartitionSpec | None = None) -> Float[Array, 'batch seq_len_q num_heads head_dim']
Execute ring attention with automatic optimization.
Ring attention distributes attention computation across devices in a ring topology, enabling efficient processing of very long sequences through communication-efficient parallelization.
- Parameters
query – Query tensor [batch, seq_len_q, num_heads, head_dim]
key – Key tensor [batch, seq_len_k, num_kv_heads, head_dim]
value – Value tensor [batch, seq_len_k, num_kv_heads, head_dim]
softmax_aux – Optional attention sink logits for long-context stability
bias – Optional attention bias tensor
mask_info – Optional MaskInfo containing attention mask and/or segment IDs
mask_builder – Custom mask builder function(q_len, kv_len, num_heads, head_idx, num_reps) -> Mask
sliding_window – Window size for local attention (int or (left, right) tuple)
chunk_size – Chunk size for chunked causal attention (Llama 4-style)
causal – Whether to use causal masking
logits_soft_cap – Soft capping value for logits
softmax_scale – Scaling factor for attention scores (default: 1/sqrt(head_dim))
axis_name – Name of the axis for collective operations
fused_backward – Whether to use fused backward kernel
platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Optional configuration override
mesh – JAX device mesh for shard_map execution (optional)
in_specs – Input partition specs for shard_map (optional)
out_specs – Output partition spec for shard_map (optional)
- Returns
Attention output with same shape as query
Example
>>> out = ring_attention(query, key, value, causal=True, axis_name="sp")
>>> out = ring_attention(
...     query, key, value,
...     causal=True,
...     sliding_window=1024,
...     axis_name="sp",
... )
>>> out = ring_attention(..., platform="triton")