# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Chunked cumulative sum operations for attention mechanisms.
This module provides efficient chunked cumulative sum operations used in
various attention mechanisms, particularly linear attention variants that
require running sums over sequence dimensions.
Key Operations:
- chunk_local_cumsum: Cumsum within fixed-size chunks (resets at boundaries)
- chunk_global_cumsum: Cumsum along the whole time axis (restarts at each cu_seqlens sequence boundary when cu_seqlens is given)
Both operations support:
- Scalar (3D) and vector (4D) inputs
- Forward and reverse directions
- An optional multiplicative scale (``softmax_scale``) applied to the result
- Variable-length sequences via cu_seqlens
- Head-first or time-first tensor layouts
Typical Use Cases:
- Gated Linear Attention (GLA): Computing cumulative gate products
- Linear Attention: Accumulating key-value outer products
- RetNet: Computing retention decay factors
Example:
>>> from ejkernel.xla_utils import chunk_local_cumsum, chunk_global_cumsum
>>>
>>> # Local cumsum within 128-token chunks
>>> local_cumsum = chunk_local_cumsum(g, chunk_size=128)
>>>
>>> # Global cumsum respecting sequence boundaries
>>> global_cumsum = chunk_global_cumsum(s, cu_seqlens=cu_seqlens)
"""
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap
from jaxtyping import DTypeLike
[docs]@partial(jit, static_argnames=["chunk_size", "reverse", "head_first"])
def chunk_local_cumsum_scalar(
g: jnp.ndarray,
chunk_size: int,
reverse: bool = False,
softmax_scale: float | None = None,
head_first: bool = False,
output_dtype: DTypeLike | None = None,
) -> jnp.ndarray:
"""Compute local cumulative sum within chunks for 3D scalar inputs.
Performs cumulative sum within fixed-size chunks, resetting at each
chunk boundary. Supports forward and reverse directions.
Args:
g: Input tensor of shape [B, H, T] if head_first else [B, T, H].
chunk_size: Size of each chunk (must be power of 2).
reverse: If True, compute reverse cumsum within each chunk.
softmax_scale: Optional scaling factor applied to result.
head_first: If True, expects [B, H, T] layout; otherwise [B, T, H].
output_dtype: Optional output dtype (defaults to input dtype).
Returns:
Tensor with same shape containing chunked cumulative sums.
"""
if head_first:
B, H, T = g.shape
else:
B, T, H = g.shape
assert chunk_size & (chunk_size - 1) == 0, "chunk_size must be a power of 2"
output_dtype = output_dtype or g.dtype
if not head_first:
g = jnp.transpose(g, (0, 2, 1))
T = g.shape[2]
pad_length = (chunk_size - T % chunk_size) % chunk_size
if pad_length > 0:
g = jnp.pad(g, ((0, 0), (0, 0), (0, pad_length)), mode="constant")
T_padded = T + pad_length
num_chunks = T_padded // chunk_size
g_chunked = g.reshape(B, H, num_chunks, chunk_size)
if reverse:
g_flipped = jnp.flip(g_chunked, axis=-1)
cumsum_flipped = jnp.cumsum(g_flipped, axis=-1)
result_chunked = jnp.flip(cumsum_flipped, axis=-1)
else:
result_chunked = jnp.cumsum(g_chunked, axis=-1)
if softmax_scale is not None:
result_chunked *= softmax_scale
result = result_chunked.reshape(B, H, T_padded)
if pad_length > 0:
result = result[:, :, :T]
if not head_first:
result = jnp.transpose(result, (0, 2, 1))
return result.astype(output_dtype)
[docs]@partial(jit, static_argnames=["chunk_size", "reverse", "head_first"])
def chunk_local_cumsum_vector(
g: jnp.ndarray,
chunk_size: int,
reverse: bool = False,
softmax_scale: float | None = None,
head_first: bool = False,
output_dtype: DTypeLike | None = None,
) -> jnp.ndarray:
"""Compute local cumulative sum within chunks for 4D vector inputs.
Performs cumulative sum within fixed-size chunks for tensors with
an additional state dimension.
Args:
g: Input tensor of shape [B, H, T, S] if head_first else [B, T, H, S].
chunk_size: Size of each chunk (must be power of 2).
reverse: If True, compute reverse cumsum within each chunk.
softmax_scale: Optional scaling factor applied to result.
head_first: If True, expects [B, H, T, S] layout.
output_dtype: Optional output dtype (defaults to input dtype).
Returns:
Tensor with same shape containing chunked cumulative sums.
"""
if head_first:
B, H, T, _S = g.shape
else:
B, T, H, _S = g.shape
assert chunk_size & (chunk_size - 1) == 0, "chunk_size must be a power of 2"
output_dtype = output_dtype or g.dtype
if not head_first:
g = jnp.transpose(g, (0, 2, 1, 3))
T = g.shape[2]
pad_length = (chunk_size - T % chunk_size) % chunk_size
if pad_length > 0:
g = jnp.pad(g, ((0, 0), (0, 0), (0, pad_length), (0, 0)), mode="constant")
T_padded = T + pad_length
num_chunks = T_padded // chunk_size
g_chunked = g.reshape(B, H, num_chunks, chunk_size, g.shape[-1])
if reverse:
g_flipped = jnp.flip(g_chunked, axis=-2)
cumsum_flipped = jnp.cumsum(g_flipped, axis=-2)
result_chunked = jnp.flip(cumsum_flipped, axis=-2)
else:
result_chunked = jnp.cumsum(g_chunked, axis=-2)
if softmax_scale is not None:
result_chunked *= softmax_scale
result = result_chunked.reshape(B, H, T_padded, g.shape[-1])
if pad_length > 0:
result = result[:, :, :T, :]
if not head_first:
result = jnp.transpose(result, (0, 2, 1, 3))
return result.astype(output_dtype)
[docs]@partial(jit, static_argnames=["reverse", "head_first"])
def chunk_global_cumsum_scalar(
s: jnp.ndarray,
reverse: bool = False,
cu_seqlens: jnp.ndarray | None = None,
softmax_scale: float | None = None,
head_first: bool = False,
output_dtype: DTypeLike | None = None,
) -> jnp.ndarray:
"""Compute global cumulative sum across sequences for 3D scalar inputs.
Performs cumulative sum across the entire sequence, with optional
support for variable-length sequences via cu_seqlens.
Args:
s: Input tensor of shape [B, H, T] if head_first else [B, T, H].
reverse: If True, compute reverse cumsum.
cu_seqlens: Optional cumulative sequence lengths for packed sequences.
softmax_scale: Optional scaling factor applied to result.
head_first: If True, expects [B, H, T] layout.
output_dtype: Optional output dtype (defaults to input dtype).
Returns:
Tensor with same shape containing global cumulative sums.
"""
output_dtype = output_dtype or s.dtype
time_axis = 2 if head_first else 1
if reverse:
s_flipped = jnp.flip(s, axis=time_axis)
result = jnp.cumsum(s_flipped, axis=time_axis)
else:
result = jnp.cumsum(s, axis=time_axis)
if cu_seqlens is not None:
boundary_indices = cu_seqlens[1:-1] - 1
gather_indices_shape = [1] * s.ndim
gather_indices_shape[time_axis] = len(boundary_indices)
gather_indices = boundary_indices.reshape(gather_indices_shape)
correction_values = jnp.take_along_axis(result, gather_indices, axis=time_axis)
zero_pad_shape = list(correction_values.shape)
zero_pad_shape[time_axis] = 1
full_correction_map = jnp.concatenate(
[jnp.zeros(zero_pad_shape, dtype=result.dtype), correction_values],
axis=time_axis,
)
total_len = s.shape[time_axis]
seq_ids = jnp.cumsum(jnp.zeros(total_len, dtype=jnp.int32).at[cu_seqlens[1:]].set(1))
id_shape = [1] * s.ndim
id_shape[time_axis] = total_len
seq_ids = seq_ids.reshape(id_shape)
correction_tensor = jnp.take_along_axis(full_correction_map, seq_ids, axis=time_axis)
result -= correction_tensor
if reverse:
result = jnp.flip(result, axis=time_axis)
if softmax_scale is not None:
result *= softmax_scale
return result.astype(output_dtype)
[docs]@partial(jit, static_argnames=["reverse", "head_first"])
def chunk_global_cumsum_vector(
s: jnp.ndarray,
reverse: bool = False,
cu_seqlens: jnp.ndarray | None = None,
softmax_scale: float | None = None,
head_first: bool = False,
output_dtype: DTypeLike | None = None,
) -> jnp.ndarray:
"""
Perform global cumulative sum for vector values, with explicit support for cu_seqlens.
"""
output_dtype = output_dtype or s.dtype
time_axis = 2 if head_first else 1
if reverse:
s_flipped = jnp.flip(s, axis=time_axis)
result = jnp.cumsum(s_flipped, axis=time_axis)
else:
result = jnp.cumsum(s, axis=time_axis)
if cu_seqlens is not None:
boundary_indices = cu_seqlens[1:-1] - 1
gather_indices_shape = [1] * s.ndim
gather_indices_shape[time_axis] = len(boundary_indices)
gather_indices = boundary_indices.reshape(gather_indices_shape)
correction_values = jnp.take_along_axis(result, gather_indices, axis=time_axis)
zero_pad_shape = list(correction_values.shape)
zero_pad_shape[time_axis] = 1
full_correction_map = jnp.concatenate(
[jnp.zeros(zero_pad_shape, dtype=result.dtype), correction_values],
axis=time_axis,
)
total_len = s.shape[time_axis]
seq_ids = jnp.cumsum(jnp.zeros(total_len, dtype=jnp.int32).at[cu_seqlens[1:]].set(1))
id_shape = [1] * s.ndim
id_shape[time_axis] = total_len
seq_ids = seq_ids.reshape(id_shape)
correction_tensor = jnp.take_along_axis(full_correction_map, seq_ids, axis=time_axis)
result -= correction_tensor
if reverse:
result = jnp.flip(result, axis=time_axis)
if softmax_scale is not None:
result *= softmax_scale
return result.astype(output_dtype)
@partial(
    jit,
    static_argnames=[
        "chunk_size",
        "reverse",
        "softmax_scale",
        "head_first",
        "output_dtype",
        "is_vector",
    ],
)
def _chunk_local_cumsum_vmap_core(
    g_padded_batched: jnp.ndarray,
    mask: jnp.ndarray,
    chunk_size: int,
    reverse: bool,
    softmax_scale: float | None,
    head_first: bool,
    output_dtype: DTypeLike | None,
    is_vector: bool,
):
    """Map the local-cumsum kernel over axis 0 of ``g_padded_batched`` and apply ``mask``.

    Only ``g_padded_batched`` is mapped; every other kernel argument is passed
    unchanged to each mapped call. The stacked results are multiplied by
    ``mask``, which must broadcast against them.
    """
    kernel = chunk_local_cumsum_vector if is_vector else chunk_local_cumsum_scalar
    unmapped = (None,) * 5  # chunk_size, reverse, softmax_scale, head_first, output_dtype
    batched_kernel = vmap(kernel, in_axes=(0, *unmapped), out_axes=0)
    cumsums = batched_kernel(g_padded_batched, chunk_size, reverse, softmax_scale, head_first, output_dtype)
    return cumsums * mask
def chunk_local_cumsum(
    g: jnp.ndarray,
    chunk_size: int,
    reverse: bool = False,
    softmax_scale: float | None = None,
    cu_seqlens: jnp.ndarray | None = None,
    head_first: bool = False,
    output_dtype: DTypeLike | None = None,
    **kwargs,
) -> jnp.ndarray:
    """Compute local cumulative sum within fixed-size chunks.

    Main entry point for chunked local cumulative sum, automatically
    dispatching to scalar or vector implementations based on input rank.

    Args:
        g: Input tensor of shape [B, T, H] or [B, T, H, S]
            ([B, H, T] or [B, H, T, S] when ``head_first``).
        chunk_size: Size of each chunk (must be power of 2).
        reverse: If True, compute reverse cumsum within each chunk.
        softmax_scale: Optional scaling factor applied to result.
        cu_seqlens: Optional cumulative sequence lengths for packed sequences.
            Chunks then start at each sequence's first token.
        head_first: If True, expects head dimension before time dimension.
        output_dtype: Optional output dtype (defaults to input dtype).
        **kwargs: Additional keyword arguments (ignored).

    Returns:
        Tensor containing chunked cumulative sums. Without ``cu_seqlens`` it has
        the same shape as ``g``; with ``cu_seqlens`` only the tokens covered by
        the sequences are returned (in packed order), in the same layout as ``g``.

    Note:
        When cu_seqlens is provided, only batch size 1 is supported, and
        cu_seqlens must be concrete (not traced): the padded per-sequence
        length determines array shapes.
    """
    is_vector = g.ndim == 4
    # Cast here rather than inside the jitted kernels so any dtype-like value works.
    target_dtype = g.dtype if output_dtype is None else output_dtype
    if cu_seqlens is None:
        base_fn = chunk_local_cumsum_vector if is_vector else chunk_local_cumsum_scalar
        return base_fn(g, chunk_size, reverse, softmax_scale, head_first).astype(target_dtype)

    assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
    if head_first:
        # Work in the time-first layout [1, T, H, ...]; restored before returning.
        g = jnp.swapaxes(g, 1, 2)
    packed = g[0]
    total_len, other_dims = packed.shape[0], packed.shape[1:]
    cu_seqlens = jnp.asarray(cu_seqlens)
    seqlens = jnp.diff(cu_seqlens)
    num_seqs = seqlens.shape[0]
    # Must be a Python int because it fixes the padded array shape.
    max_seq_len = int(jnp.max(seqlens)) if num_seqs > 0 else 0

    offsets = jnp.arange(max_seq_len)
    valid = offsets[None, :] < seqlens[:, None]  # [num_seqs, max_seq_len]
    mask = valid.reshape(valid.shape + (1,) * len(other_dims))
    # One vectorised gather puts each sequence in its own zero-padded row.
    gather_idx = jnp.clip(cu_seqlens[:-1, None] + offsets[None, :], 0, max(total_len - 1, 0))
    padded = jnp.where(mask, jnp.take(packed, gather_idx, axis=0), jnp.zeros((), packed.dtype))

    result_padded = _chunk_local_cumsum_vmap_core(
        padded[None],
        mask,
        chunk_size,
        reverse,
        softmax_scale,
        False,  # padded rows are time-first regardless of the caller's layout
        None,
        is_vector,
    )
    # Boolean selection keeps the valid tokens in packed (row-major) order.
    result = result_padded[0][valid][None].astype(target_dtype)
    if head_first:
        result = jnp.swapaxes(result, 1, 2)
    return result
def chunk_global_cumsum(
    s: jnp.ndarray,
    reverse: bool = False,
    cu_seqlens: jnp.ndarray | None = None,
    softmax_scale: float | None = None,
    head_first: bool = False,
    output_dtype: DTypeLike | None = None,
) -> jnp.ndarray:
    """Compute a cumulative sum along the full time axis.

    Dispatches on rank: 4D inputs go to ``chunk_global_cumsum_vector``, all
    others to ``chunk_global_cumsum_scalar``. Every argument is forwarded
    positionally and unchanged.

    Args:
        s: Input tensor of shape [B, T, H] or [B, T, H, S].
        reverse: If True, compute reverse cumsum.
        cu_seqlens: Optional cumulative sequence lengths for packed sequences.
            When provided, cumsum resets at sequence boundaries.
        softmax_scale: Optional scaling factor applied to result.
        head_first: If True, expects head dimension before time dimension.
        output_dtype: Optional output dtype (defaults to input dtype).

    Returns:
        Tensor with same shape containing global cumulative sums.
    """
    impl = chunk_global_cumsum_vector if s.ndim == 4 else chunk_global_cumsum_scalar
    return impl(s, reverse, cu_seqlens, softmax_scale, head_first, output_dtype)