# Source code for ejkernel.modules.operations.ring_attention

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Ring Attention module with automatic optimization.

This module implements Ring Attention, a distributed attention mechanism that enables
efficient processing of extremely long sequences by distributing computation across
multiple devices in a ring topology. Unlike standard attention, which requires all KV
pairs to fit in a single device's memory, Ring Attention shards the KV pairs across
devices and overlaps their ring communication with computation through pipelining.

Ring Attention is particularly valuable for:
    - Ultra-long sequence processing (100K+ tokens)
    - Training large language models with long contexts
    - Distributed inference scenarios
    - Memory-constrained environments requiring sequence parallelism

Key Innovation:
    Ring Attention partitions the KV pairs across devices and uses a ring-based
    communication pattern to stream KV blocks through each device. Each device:
    1. Computes attention with its local KV block
    2. Passes the KV block to the next device in the ring
    3. Receives the next KV block from the previous device
    4. Continues until all KV blocks have been processed

    This reduces memory to O(N/D) per device while the total computation remains O(N^2).

Mathematical Foundation:
    For a sequence of length N split across D devices:
    - Each device holds N/D query tokens
    - KV pairs are rotated through the ring
    - Attention is computed incrementally: softmax_i = exp(QK_i^T) / sum_j(exp(QK_j^T))
    - Running statistics (max, sum) are maintained for numerical stability

Communication Pattern:
    Device 0: KV_0 -> KV_1 -> ... -> KV_{D-1}
    Device 1: KV_1 -> KV_2 -> ... -> KV_0
    Device i: KV_i -> KV_{i+1} -> ... -> KV_{i-1} (mod D)

Performance Characteristics:
    - Memory: O(N/D) per device vs O(N) for standard attention
    - Computation: O(N^2/D) per device (same asymptotic cost)
    - Communication: O(N) per device (bandwidth-efficient with overlap)

References:
    Liu et al., "Ring Attention with Blockwise Transformers for Near-Infinite Context"
    https://arxiv.org/abs/2310.01889
"""

from __future__ import annotations

import os
import typing
from typing import Literal

from jax import shard_map
from jax.sharding import Mesh, PartitionSpec
from jaxtyping import Array, Float, Int

from ejkernel.kernels._registry import Backend, kernel_registry
from ejkernel.ops import (
    AutotunePolicy,
    BwdParams,
    ConfigCache,
    ConfigSelectorChain,
    Executor,
    FwdParams,
    Invocation,
    Kernel,
    Tuner,
)
from ejkernel.ops.config.persistent import PersistentCache
from ejkernel.types.mask import MaskInfo

from ..base import detect_platform
from .configs import RingAttentionConfig

if typing.TYPE_CHECKING:
    from collections.abc import Callable

    from ejkernel.kernels._pallas.tpu.blocksparse_attention._masks import Mask


class RingAttention(Kernel[RingAttentionConfig, Array]):
    """Ring Attention with custom optimization logic.

    Implements distributed attention using a ring communication topology for
    processing ultra-long sequences across multiple devices. This class is a
    dispatcher: it selects a kernel from ``kernel_registry`` (see
    :meth:`get_impl`), supplies autotuning candidates, and can wrap execution in
    ``jax.shard_map`` (see :meth:`create_shard_map_wrapper`). The attention math
    itself lives in the registered backend implementations.

    Features exposed through :meth:`run`:
        - Distributed KV processing via ring communication (``axis_name``)
        - Causal and non-causal attention (``causal``)
        - Sliding window attention for local patterns (``sliding_window``)
        - Attention sinks for long-context stability (``softmax_aux``)
        - Chunked causal attention (``chunk_size``)
        - Logit soft capping to bound attention logits (``logits_soft_cap``)
        - Optional fused backward kernel (``fused_backward``)
        - Platform selection among Triton/Pallas/CUDA/XLA (``platform``)

    NOTE(review): earlier docs also advertised overlapped communication and
    computation, online softmax with running max/sum statistics, configurable
    float32 logit accumulation and gradient checkpointing. Those properties
    belong to the backend kernels and cannot be confirmed from this wrapper.

    Typical Usage Patterns:
        - Multi-GPU training with sequence parallelism
        - Long-context inference on multiple devices
        - Blockwise transformer architectures
    """

    def __init__(self):
        """Initialize the Ring Attention module.

        Registers this kernel under the ``"ring_attention"`` operation id, which
        is the same key :meth:`get_impl` uses for its registry lookup.
        """
        super().__init__(op_id="ring_attention")

    def create_shard_map_wrapper(
        self,
        query: Float[Array, "batch seq_len_q num_heads head_dim"],
        key: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
        value: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
        q_segment_ids: Int[Array, "batch seq_len_q"] | None = None,
        kv_segment_ids: Int[Array, "batch seq_len_k"] | None = None,
        softmax_aux: Float[Array, "num_heads num_sinks"] | Float[Array, "num_sinks"] | None = None,
        bias: Float[Array, "batch num_heads seq_len_q seq_len_k"] | None = None,
        mask_builder: Callable[[int, int, int, int, int], Mask] | None = None,
        sliding_window: int | tuple[int, int] | None = None,
        chunk_size: int | None = None,
        causal: bool = False,
        logits_soft_cap: float | None = None,
        softmax_scale: float | None = None,
        axis_name: str | None = None,
        fused_backward: bool = False,
        platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
        cfg: RingAttentionConfig | None = None,
        mesh: Mesh | None = None,
        in_specs: tuple[PartitionSpec | None, ...] | None = None,
        out_specs: PartitionSpec | None = None,
        check_vma: bool = False,
    ):
        """Create a ``jax.shard_map`` wrapper around :meth:`run` for ring attention.

        Ring attention is inherently distributed: inside the mapped function each
        device works on its local shard and the backend kernel communicates over
        ``axis_name``. This method only builds the mapped callable and its
        positional arguments; it does not execute anything.

        Args:
            query: Query tensor to be sharded.
            key: Key tensor to be sharded.
            value: Value tensor to be sharded.
            q_segment_ids: Optional query segment IDs.
            kv_segment_ids: Optional KV segment IDs.
            softmax_aux: Optional attention sink logits.
            bias: Optional bias tensor.
            mask_builder: Optional custom mask builder function.
            sliding_window: Window size for local attention.
            chunk_size: Chunk size for chunked causal attention.
            causal: Whether to use causal masking.
            logits_soft_cap: Soft cap value for attention logits.
            softmax_scale: Scaling factor for attention scores.
            axis_name: Axis name for ring communication.
            fused_backward: Whether to use the fused backward kernel.
            platform: Target platform.
            cfg: Kernel configuration object. It is forwarded unchanged to
                :meth:`run`, which dereferences it, so a real config is
                presumably always supplied by the caller.
            mesh: JAX device mesh. Required.
            in_specs: Partition specs for the positional call arguments, in the
                order ``(query, key, value, softmax_aux, bias, q_segment_ids,
                kv_segment_ids)``. Required, and it must contain exactly 7 entries.
            out_specs: Partition spec for the attention output. Required.
            check_vma: Forwarded to ``jax.shard_map``'s ``check_vma`` flag. It
                enables JAX's varying-manual-axes (VMA) checking of the mapped
                function.

        Returns:
            Tuple ``(shard_map_fn, call_args)``. Execute with
            ``shard_map_fn(*call_args)``.

        Raises:
            ValueError: If ``mesh``, ``in_specs`` or ``out_specs`` is missing, or
                if ``len(in_specs)`` differs from the number of call arguments.
        """
        # Explicit checks rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let e.g. ``mesh=None`` reach ``shard_map``.
        if mesh is None:
            raise ValueError("mesh must be provided for shard_map execution")
        if in_specs is None:
            raise ValueError("in_specs must be provided for shard_map execution")
        if out_specs is None:
            raise ValueError("out_specs must be provided for shard_map execution")

        def _wrapped_ring_attn(
            query: Float[Array, "batch seq_len_q num_heads head_dim"],
            key: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
            value: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
            softmax_aux: Float[Array, "num_heads num_sinks"] | Float[Array, "num_sinks"] | None = None,
            bias: Float[Array, "batch num_heads seq_len_q seq_len_k"] | None = None,
            q_segment_ids: Int[Array, "batch seq_len_q"] | None = None,
            kv_segment_ids: Int[Array, "batch seq_len_k"] | None = None,
        ) -> Float[Array, "batch seq_len_q num_heads head_dim"]:
            # The positional order above must match ``call_args`` (and therefore
            # ``in_specs``). Non-array options are captured from the closure.
            return self.run(
                query=query,
                key=key,
                value=value,
                q_segment_ids=q_segment_ids,
                kv_segment_ids=kv_segment_ids,
                softmax_aux=softmax_aux,
                bias=bias,
                mask_builder=mask_builder,
                sliding_window=sliding_window,
                chunk_size=chunk_size,
                causal=causal,
                logits_soft_cap=logits_soft_cap,
                softmax_scale=softmax_scale,
                axis_name=axis_name,
                fused_backward=fused_backward,
                platform=platform,
                cfg=cfg,
            )

        call_args = (
            query,
            key,
            value,
            softmax_aux,
            bias,
            q_segment_ids,
            kv_segment_ids,
        )
        if len(in_specs) != len(call_args):
            raise ValueError(f"in_specs length {len(in_specs)} != call_args length {len(call_args)}")

        shard_map_fn = shard_map(
            _wrapped_ring_attn,
            mesh=mesh,
            in_specs=in_specs,
            out_specs=out_specs,
            check_vma=check_vma,
        )
        return shard_map_fn, call_args

    def get_impl(self, cfg: RingAttentionConfig):
        """Look up the ring attention kernel implementation in the registry.

        Args:
            cfg: Configuration whose ``platform`` (resolved through
                ``detect_platform``) and ``backend`` select the implementation.

        Returns:
            The callable registered under ``"ring_attention"`` for the resolved
            platform and backend.

        Raises:
            Whatever ``kernel_registry.get`` raises when nothing matches.
            NOTE(review): this was previously documented as ``ValueError``;
            confirm against the registry implementation.
        """
        platform = detect_platform("ring_attention", cfg.platform)
        return kernel_registry.get("ring_attention", platform=platform, backend=cfg.backend)

    def run(
        self,
        query: Float[Array, "batch seq_len_q num_heads head_dim"],
        key: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
        value: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
        q_segment_ids: Int[Array, "batch seq_len_q"] | None = None,
        kv_segment_ids: Int[Array, "batch seq_len_k"] | None = None,
        softmax_aux: Float[Array, "num_heads num_sinks"] | Float[Array, "num_sinks"] | None = None,
        bias: Float[Array, "batch num_heads seq_len_q seq_len_k"] | None = None,
        mask_builder: Callable[[int, int, int, int, int], Mask] | None = None,
        sliding_window: int | tuple[int, int] | None = None,
        chunk_size: int | None = None,
        causal: bool = False,
        logits_soft_cap: float | None = None,
        softmax_scale: float | None = None,
        axis_name: str | None = None,
        fused_backward: bool = False,
        platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
        *,
        cfg: RingAttentionConfig,
    ) -> Float[Array, "batch seq_len_q num_heads head_dim"]:
        """Execute ring attention with distributed KV processing.

        Resolves the kernel implementation for ``cfg`` (optionally overriding its
        platform) and forwards every argument, plus the forward/backward block
        parameters from ``cfg``, to that implementation.

        Args:
            query: Query tensor [batch, seq_len_q, num_heads, head_dim].
            key: Key tensor [batch, seq_len_k, num_kv_heads, head_dim] (distributed).
            value: Value tensor [batch, seq_len_k, num_kv_heads, head_dim] (distributed).
            q_segment_ids: Optional query segment IDs [batch, seq_len_q].
            kv_segment_ids: Optional KV segment IDs [batch, seq_len_k].
            softmax_aux: Optional attention sink logits for long-context stability.
            bias: Optional attention bias [batch, num_heads, seq_len_q, seq_len_k].
            mask_builder: Custom mask builder
                function(q_len, kv_len, num_heads, head_idx, num_reps) -> Mask.
            sliding_window: Window size for local attention (int or (left, right) tuple).
            chunk_size: Chunk size for chunked causal attention (Llama4 style).
            causal: Whether to use causal masking.
            logits_soft_cap: Soft cap value to bound attention logits.
            softmax_scale: Scaling factor for attention scores (default: 1/sqrt(head_dim)).
            axis_name: Name of the axis for collective operations (required for multi-device).
            fused_backward: Whether to use the fused backward kernel.
            platform: Optional platform override ("triton", "pallas", "cuda", "xla", "auto").
                When given, it replaces ``cfg.platform``.
            cfg: Kernel configuration object.

        Returns:
            Attention output [batch, seq_len_q, num_heads, head_dim].

        Note:
            Ring attention requires a device mesh that defines ``axis_name``. Each
            device processes a slice of the sequence and communicates KV pairs
            through the ring topology.

        Example:
            >>> mesh = jax.sharding.Mesh(devices, axis_names=["sp"])
            >>> with mesh:
            ...     out = ring_attention(q, k, v, axis_name="sp")
        """
        if platform is not None:
            # An explicit platform overrides the configured one. The XLA path is
            # always requested with ``Backend.ANY``; other platforms keep the
            # configured backend.
            cfg = RingAttentionConfig(
                fwd_params=cfg.fwd_params,
                bwd_params=cfg.bwd_params,
                platform=platform,
                backend=Backend.ANY if platform == "xla" else cfg.backend,
            )
        impl = self.get_impl(cfg)
        return impl(
            query=query,
            key=key,
            value=value,
            q_segment_ids=q_segment_ids,
            kv_segment_ids=kv_segment_ids,
            softmax_aux=softmax_aux,
            bias=bias,
            mask_builder=mask_builder,
            sliding_window=sliding_window,
            chunk_size=chunk_size,
            causal=causal,
            logits_soft_cap=logits_soft_cap,
            softmax_scale=softmax_scale,
            axis_name=axis_name,
            fwd_params=cfg.fwd_params,
            bwd_params=cfg.bwd_params,
            fused_backward=fused_backward,
        )

    def heuristic_cfg(self, inv: Invocation[RingAttentionConfig, Array]) -> RingAttentionConfig:
        """Provide the default configuration for ring attention.

        Args:
            inv: Invocation object containing arguments and metadata (unused; the
                default does not depend on the call).

        Returns:
            RingAttentionConfig with 512x512 query/KV blocks, 2 stages and 4 warps
            for both the forward and backward passes, platform ``"auto"`` and
            backend ``"any"``.
        """
        return RingAttentionConfig(
            fwd_params=FwdParams(q_blocksize=512, kv_blocksize=512, num_stages=2, num_warps=4),
            bwd_params=BwdParams(q_blocksize=512, kv_blocksize=512, num_stages=2, num_warps=4),
            platform="auto",
            backend="any",
        )

    def candidate_cfgs(self, inv: Invocation[RingAttentionConfig, Array]):
        """Generate candidate configurations for autotuning.

        The candidate set is fixed and does not depend on ``inv``. It consists of
        square query/KV blocks of 128, 256 and 512, each with 2 stages and 4 warps,
        on platform ``"auto"`` with backend ``"any"``.

        Args:
            inv: Invocation object containing arguments and metadata (unused).

        Returns:
            List of candidate configurations to benchmark during autotuning.

        Note:
            Ring attention performance is sensitive to block sizes relative to the
            per-device sequence length and the communication bandwidth.
        """
        return [
            RingAttentionConfig(
                fwd_params=FwdParams(q_blocksize=block_q, kv_blocksize=block_k, num_stages=2, num_warps=4),
                bwd_params=BwdParams(q_blocksize=block_q, kv_blocksize=block_k, num_stages=2, num_warps=4),
                platform="auto",
                backend="any",
            )
            for block_q, block_k in ((128, 128), (256, 256), (512, 512))
        ]

    def candidate_cfgs_tpu(self, inv: Invocation[RingAttentionConfig, Array]):
        """Generate TPU-oriented candidate configurations for autotuning.

        All candidates target platform ``"pallas"`` with backend ``"tpu"``. Block
        sizes go up to 512, since larger blocks are expected to suit TPU/Pallas
        ring attention.

        Args:
            inv: Invocation object with arguments and metadata (unused).

        Returns:
            List of TPU candidate configurations.
        """
        # Each entry is (q_blocksize, kv_blocksize, num_warps, num_stages).
        block_configs = [
            (128, 128, 4, 1),
            (256, 256, 8, 2),
            (512, 512, 8, 2),
        ]
        return [
            RingAttentionConfig(
                fwd_params=FwdParams(
                    q_blocksize=block_q, kv_blocksize=block_k, num_stages=num_stages, num_warps=num_warps
                ),
                bwd_params=BwdParams(
                    q_blocksize=block_q, kv_blocksize=block_k, num_stages=num_stages, num_warps=num_warps
                ),
                platform="pallas",
                backend="tpu",
            )
            for block_q, block_k, num_warps, num_stages in block_configs
        ]

    # The shard_map execution path on TPU reuses the same candidate set.
    candidate_cfgs_shard_map_tpu = candidate_cfgs_tpu
# Shared executor used by every ``ring_attention`` call.
#   * ``ConfigCache`` plus ``PersistentCache("ring-attention", ...)`` hold the
#     selected configs. The persistent layer presumably survives across
#     processes (TODO confirm).
#   * Autotuning is allowed. The action taken on a cache miss is read from the
#     EJKERNEL_AUTOTUNE_POLICY environment variable and defaults to "autotune".
#   * ``validate_backward=True`` is passed to the policy. Presumably candidates
#     are also checked on the backward pass (confirm in AutotunePolicy).
#   * The tuner is configured with warmup=5 and iters=100.
_ring_executor: Executor[RingAttentionConfig, Array] = Executor(
    ConfigSelectorChain(
        cache=ConfigCache(),
        policy=AutotunePolicy(
            allow_autotune=True,
            cache_miss_fallback=os.environ.get("EJKERNEL_AUTOTUNE_POLICY", "autotune"),
            validate_backward=True,
        ),
        tuner=Tuner(warmup=5, iters=100),
        persistent=PersistentCache("ring-attention", cfg_type=RingAttentionConfig),
    )
)
def ring_attention(
    query: Float[Array, "batch seq_len_q num_heads head_dim"],
    key: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
    value: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
    softmax_aux: Float[Array, "num_heads num_sinks"] | Float[Array, "num_sinks"] | None = None,
    bias: Float[Array, "batch num_heads seq_len_q seq_len_k"] | None = None,
    /,
    *,
    mask_info: MaskInfo | None = None,
    mask_builder: Callable[[int, int, int, int, int], Mask] | None = None,
    sliding_window: int | tuple[int, int] | None = None,
    chunk_size: int | None = None,
    causal: bool = False,
    logits_soft_cap: float | None = None,
    softmax_scale: float | None = None,
    axis_name: str | None = None,
    fused_backward: bool = False,
    platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
    cfg: RingAttentionConfig | None = None,
    mesh: Mesh | None = None,
    in_specs: tuple[PartitionSpec | None, ...] | None = None,
    out_specs: PartitionSpec | None = None,
) -> Float[Array, "batch seq_len_q num_heads head_dim"]:
    """Execute ring attention with automatic optimization.

    Ring attention distributes attention computation across devices in a ring
    topology, enabling efficient processing of very long sequences through
    communication-efficient parallelization. The call is dispatched through the
    module-level autotuning executor.

    shard_map execution is selected only when ``mesh``, ``in_specs`` and
    ``out_specs`` are all given. In that case the partition specs for the
    query/KV segment IDs are appended to ``in_specs``. They come from
    ``mask_info.get_shardings`` when ``mask_info`` is given, and are ``None``
    otherwise.

    Args:
        query: Query tensor [batch, seq_len_q, num_heads, head_dim].
        key: Key tensor [batch, seq_len_k, num_kv_heads, head_dim].
        value: Value tensor [batch, seq_len_k, num_kv_heads, head_dim].
        softmax_aux: Optional attention sink logits for long-context stability.
        bias: Optional attention bias tensor.
        mask_info: Optional MaskInfo from which query/KV segment IDs are taken
            (via ``get_or_compute_segment_ids``).
        mask_builder: Custom mask builder
            function(q_len, kv_len, num_heads, head_idx, num_reps) -> Mask.
        sliding_window: Window size for local attention (int or (left, right) tuple).
        chunk_size: Chunk size for chunked causal attention (Llama4 style).
        causal: Whether to use causal masking.
        logits_soft_cap: Soft capping value for logits.
        softmax_scale: Scaling factor for attention scores (default: 1/sqrt(head_dim)).
        axis_name: Name of the axis for collective operations. In shard_map mode
            with ``mask_info``, it must equal ``mask_info.sequence_axis_name``.
        fused_backward: Whether to use the fused backward kernel.
        platform: Specific platform to use ("triton", "pallas", "cuda", "xla" or "auto").
        cfg: Optional configuration override.
        mesh: JAX device mesh for shard_map execution (optional).
        in_specs: Partition specs for the five positional inputs
            ``(query, key, value, softmax_aux, bias)`` (optional). The two
            segment-ID specs are appended automatically.
        out_specs: Output partition spec for shard_map (optional).

    Returns:
        Attention output with the same shape as ``query``.

    Raises:
        ValueError: In shard_map mode, if ``mask_info.sequence_axis_name`` does
            not match ``axis_name``.

    Example:
        >>> # Basic causal ring attention over the "sp" mesh axis.
        >>> out = ring_attention(query, key, value, causal=True, axis_name="sp")
        >>> # Causal attention restricted to a sliding window.
        >>> out = ring_attention(
        ...     query, key, value,
        ...     causal=True,
        ...     sliding_window=1024,
        ...     axis_name="sp",
        ... )
        >>> # Force a specific platform.
        >>> out = ring_attention(query, key, value, platform="triton")
    """
    q_segment_ids = None
    kv_segment_ids = None
    if mask_info is not None:
        q_segment_ids, kv_segment_ids = mask_info.get_or_compute_segment_ids()

    method = None
    if mesh is not None and in_specs is not None and out_specs is not None:
        method = "shard_map"
        if mask_info is None:
            # No segment IDs are passed, so their shard_map specs are None.
            in_specs = (*in_specs, None, None)
        else:
            # Explicit check instead of ``assert``, which ``python -O`` strips and
            # which would let mismatched axes through silently.
            if mask_info.sequence_axis_name != axis_name:
                raise ValueError(
                    "mismatch between two sequence axis names! "
                    f"mask_info.sequence_axis_name={mask_info.sequence_axis_name!r} "
                    f"!= axis_name={axis_name!r}"
                )
            shardings = mask_info.get_shardings(True, mesh=mesh)
            in_specs = (*in_specs, shardings.q_segment_ids, shardings.kv_segment_ids)

    return _ring_executor(
        RingAttention(),
        query=query,
        key=key,
        value=value,
        softmax_aux=softmax_aux,
        bias=bias,
        q_segment_ids=q_segment_ids,
        kv_segment_ids=kv_segment_ids,
        mask_builder=mask_builder,
        sliding_window=sliding_window,
        chunk_size=chunk_size,
        causal=causal,
        logits_soft_cap=logits_soft_cap,
        softmax_scale=softmax_scale,
        axis_name=axis_name,
        fused_backward=fused_backward,
        platform=platform,
        method=method,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=out_specs,
        _cfg=cfg,
    )