# Source code for ejkernel.kernels._triton.ring_attention._ring_kernel

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Ring Flash Attention Kernel - wraps Triton flash attention for distributed ring topology."""

from __future__ import annotations

from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import lax

from ejkernel.ops import BwdParams, FwdParams

from ..flash_attention._triton_impl_bwd import _bwd_attention_kernel_call
from ..flash_attention._triton_impl_fwd import _fwd_attention_kernel_call

# Natural logarithm of 2, used to move log-sum-exp values between base-2
# (kernel side) and natural-log (ring combination side): multiply by LN2 to go
# from log2 to ln, divide by LN2 to go back.
LN2: float = 0.6931471805599453


class RingFlashResiduals(NamedTuple):
    """Tensors the ring-attention forward pass hands to the backward pass.

    Field order is significant: consumers may unpack this tuple positionally.
    """

    # Forward-pass inputs, kept so gradients can be recomputed per ring step.
    q: jax.Array
    k: jax.Array
    v: jax.Array
    bias: jax.Array | None
    attention_mask: jax.Array | None
    # Attention output combined over every ring step.
    o: jax.Array
    # Log-sum-exp combined over every ring step, in natural-log space.
    lse: jax.Array
    # Dropout seed that was used by the forward pass.
    dropout_seed: int | None
def _ring_axis_size(axis_name: str | None) -> int:
    """Return the number of devices on the ring (1 when no ring axis is given)."""
    if axis_name is None:
        return 1
    # psum of the Python constant 1 over the named axis gives the ring size.
    return lax.psum(1, axis_name)


def _ring_permutation(axis_size: int) -> list[tuple[int, int]]:
    """Return ``ppermute`` pairs that pass each shard from device ``i`` to ``i + 1``."""
    return [(i, (i + 1) % axis_size) for i in range(axis_size)]


def _ring_block_kind(step: jax.Array, axis_name: str) -> jax.Array:
    """Classify the K/V shard visited at ring ``step`` for causal attention.

    K/V move from device ``i`` to ``i + 1`` after every step, so at step ``s``
    device ``d`` holds the shard that originated on device ``(d - s) mod n``.
    NOTE(review): this assumes device ``d`` owns the ``d``-th contiguous chunk
    of the sequence; confirm against the sharding used by callers.

    Returns:
        ``0`` for the local (diagonal) shard, which needs the causal mask;
        ``1`` for a shard from an earlier device, which is fully visible;
        ``2`` for a shard from a later device, which is fully masked.
    """
    my_index = lax.axis_index(axis_name)
    return jnp.where(step == 0, 0, jnp.where(step <= my_index, 1, 2))


@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13))
def ring_flash_attention_call(
    query: jax.Array,
    key: jax.Array,
    value: jax.Array,
    attention_mask: jax.Array | None,
    bias: jax.Array | None,
    softmax_scale: float | None,
    dropout_prob: float,
    causal: bool,
    dropout_seed: int | None,
    fwd_params: FwdParams | None,
    bwd_params: BwdParams | None,
    sliding_window: int | tuple[int, int] | None,
    logits_soft_cap: float | None,
    axis_name: str | None,
) -> jax.Array:
    """Ring flash attention with custom VJP for efficient gradients.

    Args:
        query: Query tensor [batch, seq_len_q, num_heads, head_dim]
        key: Key tensor [batch, seq_len_k, num_kv_heads, head_dim]
        value: Value tensor [batch, seq_len_k, num_kv_heads, head_dim]
        attention_mask: Optional attention mask
        bias: Optional attention bias
        softmax_scale: Scale for attention scores
        dropout_prob: Dropout probability
        causal: Whether to use causal masking. Under a ring axis, each
            visited K/V shard is masked according to its position relative to
            the local query shard.
        dropout_seed: Random seed for dropout
        fwd_params: Forward pass block size parameters
        bwd_params: Backward pass block size parameters
        sliding_window: Sliding window size
        logits_soft_cap: Soft cap value for logits
        axis_name: Name of axis for ring communication (``None`` disables the ring)

    Returns:
        Output tensor [batch, seq_len_q, num_heads, head_dim] in ``query.dtype``.
    """
    o, _ = _ring_flash_attention_fwd(
        query,
        key,
        value,
        attention_mask,
        bias,
        softmax_scale,
        dropout_prob,
        causal,
        dropout_seed,
        fwd_params,
        bwd_params,
        sliding_window,
        logits_soft_cap,
        axis_name,
    )
    return o


def _ring_flash_attention_fwd(
    query: jax.Array,
    key: jax.Array,
    value: jax.Array,
    attention_mask: jax.Array | None,
    bias: jax.Array | None,
    softmax_scale: float | None,
    dropout_prob: float,
    causal: bool,
    dropout_seed: int | None,
    fwd_params: FwdParams | None,
    bwd_params: BwdParams | None,
    sliding_window: int | tuple[int, int] | None,
    logits_soft_cap: float | None,
    axis_name: str | None,
) -> tuple[jax.Array, RingFlashResiduals]:
    """Forward pass of ring flash attention.

    The local queries attend to one K/V shard per ring step. Partial results
    are merged with an online softmax in float32, and K/V are then rotated to
    the next device.

    NOTE(review): ``attention_mask``, ``bias``, ``sliding_window`` and the
    dropout seed are passed unchanged to every ring step, i.e. in local
    coordinates. Confirm this matches how callers lay them out.

    Returns:
        ``(o, residuals)``. ``o`` is cast to ``query.dtype``, and
        ``residuals.lse`` is the combined natural-log LSE in float32.
    """
    batch, q_seq_len, num_heads = query.shape[:3]
    axis_size = _ring_axis_size(axis_name)
    perm = _ring_permutation(axis_size)
    # With K/V travelling around a ring, the causal mask depends on which
    # shard is being visited, so it must be chosen per step.
    ring_causal = causal and axis_name is not None

    def attend(k_curr, v_curr, chunk_causal):
        o_chunk, lse_chunk_log2 = _fwd_attention_kernel_call(
            q=query,
            k=k_curr,
            v=v_curr,
            attention_mask=attention_mask,
            bias=bias,
            softmax_scale=softmax_scale,
            dropout_prob=dropout_prob,
            causal=chunk_causal,
            dropout_seed=dropout_seed,
            fwd_params=fwd_params,
            bwd_params=bwd_params,
            cum_seqlens_q=None,
            cum_seqlens_k=None,
            sliding_window=sliding_window,
            logits_soft_cap=logits_soft_cap,
            softmax_aux=None,  # Attention sinks not supported in ring mode yet
        )
        # Kernel LSE is in log2 space, shaped (batch, heads, max_seqlen_q_rounded);
        # convert to natural log and drop the padding.
        lse_chunk = (lse_chunk_log2 * LN2)[..., :q_seq_len]
        return o_chunk.astype(jnp.float32), lse_chunk.astype(jnp.float32)

    def skip(k_curr, v_curr):
        # Fully masked shard: zero output and -inf LSE leave the accumulators untouched.
        del k_curr, v_curr
        return (
            jnp.zeros_like(query, dtype=jnp.float32),
            jnp.full((batch, num_heads, q_seq_len), -jnp.inf, dtype=jnp.float32),
        )

    def to_bshd(x):
        # [batch, heads, seq] -> [batch, seq, heads, 1] to broadcast against o.
        return jnp.transpose(x, (0, 2, 1))[..., None]

    def scan_ring(carry, idx):
        o_acc, lse_acc, k_curr, v_curr = carry

        if ring_causal:
            o_chunk, lse_chunk = lax.switch(
                _ring_block_kind(idx, axis_name),
                (
                    partial(attend, chunk_causal=True),
                    partial(attend, chunk_causal=False),
                    skip,
                ),
                k_curr,
                v_curr,
            )
        else:
            o_chunk, lse_chunk = attend(k_curr, v_curr, causal)

        # Online softmax combination.
        lse_max = jnp.maximum(lse_acc, lse_chunk)
        # exp(-inf - (-inf)) is NaN; anchor fully masked rows at 0 so both
        # weights become exactly 0 instead.
        lse_max_safe = jnp.where(jnp.isneginf(lse_max), 0.0, lse_max)
        alpha = jnp.exp(lse_acc - lse_max_safe)
        beta = jnp.exp(lse_chunk - lse_max_safe)
        sum_weights = alpha + beta
        # Avoid division by zero for rows with no visible keys so far.
        sum_weights_safe = jnp.where(sum_weights == 0, 1.0, sum_weights)

        o_next = (to_bshd(alpha) * o_acc + to_bshd(beta) * o_chunk) / to_bshd(sum_weights_safe)
        # Rows with zero total weight keep lse == -inf (log(1) == 0).
        lse_next = lse_max + jnp.log(sum_weights_safe)

        # Rotate K, V to the next device in the ring.
        if axis_name is not None:
            k_next = lax.ppermute(k_curr, axis_name, perm)
            v_next = lax.ppermute(v_curr, axis_name, perm)
        else:
            k_next, v_next = k_curr, v_curr

        return (o_next, lse_next, k_next, v_next), None

    # Accumulate in float32 so the scan carry dtype is stable for bf16/fp16 inputs.
    init = (
        jnp.zeros_like(query, dtype=jnp.float32),
        jnp.full((batch, num_heads, q_seq_len), -jnp.inf, dtype=jnp.float32),
        key,
        value,
    )
    (o_acc, lse, _, _), _ = lax.scan(scan_ring, init, jnp.arange(axis_size))
    o = o_acc.astype(query.dtype)

    residuals = RingFlashResiduals(
        q=query,
        k=key,
        v=value,
        bias=bias,
        attention_mask=attention_mask,
        o=o,
        lse=lse,
        dropout_seed=dropout_seed,
    )
    return o, residuals


def _ring_flash_attention_bwd(
    softmax_scale: float | None,
    dropout_prob: float,
    causal: bool,
    dropout_seed: int | None,
    fwd_params: FwdParams | None,
    bwd_params: BwdParams | None,
    sliding_window: int | tuple[int, int] | None,
    logits_soft_cap: float | None,
    axis_name: str | None,
    res: RingFlashResiduals,
    do: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array, None, None]:
    """Backward pass of ring flash attention.

    Each ring step runs the flash-attention backward kernel against the
    visited K/V shard, using the globally combined ``o`` and LSE from the
    forward pass. dQ accumulates locally. dK/dV travel with their K/V shard
    and, after ``axis_size`` rotations, are back on the owning device.
    Per-step causal masking mirrors the forward pass.

    Returns:
        ``(dq, dk, dv, None, None)``. The two ``None`` values are the
        cotangents for ``attention_mask`` and ``bias``.
    """
    # The stored dropout seed duplicates the nondiff ``dropout_seed`` argument.
    q, k, v, bias, attention_mask, o, lse, _ = res

    axis_size = _ring_axis_size(axis_name)
    perm = _ring_permutation(axis_size)
    ring_causal = causal and axis_name is not None

    # Convert LSE back to log2 for backward kernel (it expects log2 space).
    lse_log2 = lse / LN2

    def grads(k_curr, v_curr, chunk_causal):
        dq_chunk, dk_chunk, dv_chunk = _bwd_attention_kernel_call(
            dO=do,
            q=q,
            k=k_curr,
            v=v_curr,
            bias=bias,
            attention_mask=attention_mask,
            o=o,
            M=lse_log2,
            dropout_prob=dropout_prob,
            causal=chunk_causal,
            fwd_params=fwd_params,
            bwd_params=bwd_params,
            dropout_seed=dropout_seed,
            softmax_scale=softmax_scale,
            sliding_window=sliding_window,
            cum_seqlens_k=None,
            cum_seqlens_q=None,
            logits_soft_cap=logits_soft_cap,
        )
        return (
            dq_chunk.astype(jnp.float32),
            dk_chunk.astype(jnp.float32),
            dv_chunk.astype(jnp.float32),
        )

    def skip(k_curr, v_curr):
        # Fully masked shard contributes no gradient.
        return (
            jnp.zeros_like(q, dtype=jnp.float32),
            jnp.zeros_like(k_curr, dtype=jnp.float32),
            jnp.zeros_like(v_curr, dtype=jnp.float32),
        )

    def scan_ring_bwd(carry, idx):
        dq_acc, dk_acc, dv_acc, k_curr, v_curr = carry

        if ring_causal:
            dq_chunk, dk_chunk, dv_chunk = lax.switch(
                _ring_block_kind(idx, axis_name),
                (
                    partial(grads, chunk_causal=True),
                    partial(grads, chunk_causal=False),
                    skip,
                ),
                k_curr,
                v_curr,
            )
        else:
            dq_chunk, dk_chunk, dv_chunk = grads(k_curr, v_curr, causal)

        dq_acc = dq_acc + dq_chunk
        dk_acc = dk_acc + dk_chunk
        dv_acc = dv_acc + dv_chunk

        # Rotate K, V together with their gradient accumulators.
        if axis_name is not None:
            k_next = lax.ppermute(k_curr, axis_name, perm)
            v_next = lax.ppermute(v_curr, axis_name, perm)
            dk_acc = lax.ppermute(dk_acc, axis_name, perm)
            dv_acc = lax.ppermute(dv_acc, axis_name, perm)
        else:
            k_next, v_next = k_curr, v_curr

        return (dq_acc, dk_acc, dv_acc, k_next, v_next), None

    init = (
        jnp.zeros_like(q, dtype=jnp.float32),
        jnp.zeros_like(k, dtype=jnp.float32),
        jnp.zeros_like(v, dtype=jnp.float32),
        k,
        v,
    )
    (dq, dk, dv, _, _), _ = lax.scan(scan_ring_bwd, init, jnp.arange(axis_size))

    # Cast back to input dtypes.
    return dq.astype(q.dtype), dk.astype(k.dtype), dv.astype(v.dtype), None, None


ring_flash_attention_call.defvjp(_ring_flash_attention_fwd, _ring_flash_attention_bwd)