ejkernel.xla_utils.cumsum
Chunked cumulative sum operations for attention mechanisms.
This module provides efficient chunked cumulative sum operations used in various attention mechanisms, particularly linear attention variants that require running sums over sequence dimensions.
- Key Operations:
chunk_local_cumsum: Cumsum within fixed-size chunks (resets at boundaries)
chunk_global_cumsum: Cumsum across the entire sequence (resets at sequence boundaries when cu_seqlens is given)
- Both operations support:
Scalar (3D) and vector (4D) inputs
Forward and reverse directions
Optional scaling of the result via softmax_scale
Variable-length sequences via cu_seqlens
Head-first or time-first tensor layouts
- Typical Use Cases:
Gated Linear Attention (GLA): Computing cumulative gate products
Linear Attention: Accumulating key-value outer products
RetNet: Computing retention decay factors
Example
>>> from ejkernel.xla_utils import chunk_local_cumsum, chunk_global_cumsum
>>>
>>> # Local cumsum within 128-token chunks
>>> local_cumsum = chunk_local_cumsum(g, chunk_size=128)
>>>
>>> # Global cumsum respecting sequence boundaries
>>> global_cumsum = chunk_global_cumsum(s, cu_seqlens=cu_seqlens)
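To make the difference between the two operations concrete, here is a minimal reference sketch in plain jax.numpy. It does not call ejkernel; the toy shapes and the names global_ref and local_ref are illustrative assumptions, not part of the library.

```python
import jax.numpy as jnp

# Toy input: batch 1, 8 time steps, 1 head, laid out as [B, T, H].
g = jnp.arange(1.0, 9.0).reshape(1, 8, 1)

# Global: a single running sum over the whole time axis.
global_ref = jnp.cumsum(g, axis=1)

# Local: split time into chunks of 4, cumsum inside each chunk, merge back.
chunk_size = 4
chunks = g.reshape(1, 8 // chunk_size, chunk_size, 1)
local_ref = jnp.cumsum(chunks, axis=2).reshape(1, 8, 1)

print(global_ref[0, :, 0])  # values: 1, 3, 6, 10, 15, 21, 28, 36
print(local_ref[0, :, 0])   # values: 1, 3, 6, 10, 5, 11, 18, 26
```

The local result restarts from the raw input value at position 4, which is the first element of the second chunk.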
- ejkernel.xla_utils.cumsum.chunk_global_cumsum(s: Array, reverse: bool = False, cu_seqlens: jax.jaxlib._jax.Array | None = None, softmax_scale: float | None = None, head_first: bool = False, output_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None) -> Array
Compute global cumulative sum across sequences.
Main entry point for global cumulative sum, automatically dispatching to scalar or vector implementations based on input rank.
- Parameters
s – Input tensor of shape [B, T, H] or [B, T, H, S].
reverse – If True, compute reverse cumsum.
cu_seqlens – Optional cumulative sequence lengths for packed sequences. When provided, cumsum resets at sequence boundaries.
softmax_scale – Optional scaling factor applied to result.
head_first – If True, expects head dimension before time dimension.
output_dtype – Optional output dtype (defaults to input dtype).
- Returns
Tensor with same shape containing global cumulative sums.
Note
With cu_seqlens, the cumulative sum respects sequence boundaries and does not accumulate across different sequences in the batch.
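The boundary behavior described in the note can be sketched as follows. This is a hypothetical reference written with plain jax.numpy, not ejkernel's implementation. It assumes a packed input of shape [1, T, H] and a cu_seqlens array of N + 1 offsets starting at 0; the function name packed_cumsum_ref is made up for illustration.

```python
import jax.numpy as jnp

def packed_cumsum_ref(s, cu_seqlens):
    """Hypothetical reference: s is [1, T, H], cu_seqlens is [N + 1] with cu_seqlens[0] == 0."""
    T = s.shape[1]
    total = jnp.cumsum(s, axis=1)  # running sum that ignores boundaries
    # Index of the sequence each token belongs to.
    seq_id = jnp.searchsorted(cu_seqlens, jnp.arange(T), side="right") - 1
    starts = cu_seqlens[seq_id]  # start offset of each token's sequence
    # padded[:, k] holds the sum of the first k tokens (0 for k == 0).
    padded = jnp.concatenate([jnp.zeros_like(total[:, :1]), total], axis=1)
    # Subtract everything accumulated before the token's own sequence began.
    return total - padded[:, starts]

s = jnp.ones((1, 8, 2))
cu_seqlens = jnp.array([0, 3, 8])  # two packed sequences of length 3 and 5
print(packed_cumsum_ref(s, cu_seqlens)[0, :, 0])  # values: 1, 2, 3, 1, 2, 3, 4, 5
```

Subtracting a prefix sum is compact but can lose precision for long sequences in low-precision dtypes; a scan that resets at each boundary avoids that at the cost of more code.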
- ejkernel.xla_utils.cumsum.chunk_global_cumsum_scalar(s: Array, reverse: bool = False, cu_seqlens: jax.jaxlib._jax.Array | None = None, softmax_scale: float | None = None, head_first: bool = False, output_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None) -> Array
Compute global cumulative sum across sequences for 3D scalar inputs.
Performs cumulative sum across the entire sequence, with optional support for variable-length sequences via cu_seqlens.
- Parameters
s – Input tensor of shape [B, H, T] if head_first else [B, T, H].
reverse – If True, compute reverse cumsum.
cu_seqlens – Optional cumulative sequence lengths for packed sequences.
softmax_scale – Optional scaling factor applied to result.
head_first – If True, expects [B, H, T] layout.
output_dtype – Optional output dtype (defaults to input dtype).
- Returns
Tensor with same shape containing global cumulative sums.
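As a rough illustration of how the reverse, softmax_scale, head_first, and output_dtype parameters could interact, here is a reference sketch in jax.numpy. It makes two assumptions that the text above does not confirm: that a reverse cumsum is a suffix sum (each position sums itself and everything after it), and that softmax_scale is applied by multiplying the result. The function name global_cumsum_scalar_ref is hypothetical, and cu_seqlens handling is omitted.

```python
import jax.numpy as jnp

def global_cumsum_scalar_ref(s, reverse=False, softmax_scale=None,
                             head_first=False, output_dtype=None):
    """Hypothetical reference for 3D inputs: [B, H, T] if head_first else [B, T, H]."""
    axis = 2 if head_first else 1  # position of the time axis
    if reverse:
        # Assumed meaning of reverse: out[t] = sum of s[t:], via flip-cumsum-flip.
        out = jnp.flip(jnp.cumsum(jnp.flip(s, axis=axis), axis=axis), axis=axis)
    else:
        out = jnp.cumsum(s, axis=axis)
    if softmax_scale is not None:
        out = out * softmax_scale  # assumed: the scale multiplies the result
    return out.astype(s.dtype if output_dtype is None else output_dtype)

s = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape(1, 4, 1)
print(global_cumsum_scalar_ref(s, reverse=True)[0, :, 0])  # values: 10, 9, 7, 4
```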
- ejkernel.xla_utils.cumsum.chunk_global_cumsum_vector(s: Array, reverse: bool = False, cu_seqlens: jax.jaxlib._jax.Array | None = None, softmax_scale: float | None = None, head_first: bool = False, output_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None) -> Array
Compute global cumulative sum across sequences for 4D vector inputs, with support for variable-length sequences via cu_seqlens.
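For the 4D case, the only structural change from the scalar version is the trailing feature dimension, which is summed independently per channel. A minimal sketch using jax.lax.cumsum, assuming the layouts [B, T, H, S] and [B, H, T, S] used elsewhere on this page; the function name global_cumsum_vector_ref is hypothetical, and cu_seqlens and scaling are left out.

```python
import jax.numpy as jnp
from jax import lax

def global_cumsum_vector_ref(s, reverse=False, head_first=False):
    """Hypothetical reference: s is [B, H, T, S] if head_first else [B, T, H, S]."""
    time_axis = 2 if head_first else 1
    # lax.cumsum supports a reverse flag directly, so no explicit flips are needed.
    return lax.cumsum(s, axis=time_axis, reverse=reverse)

s = jnp.ones((1, 3, 2, 4))  # B=1, T=3, H=2, S=4
print(global_cumsum_vector_ref(s)[0, :, 0, 0])                # values: 1, 2, 3
print(global_cumsum_vector_ref(s, reverse=True)[0, :, 0, 0])  # values: 3, 2, 1
```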
- ejkernel.xla_utils.cumsum.chunk_local_cumsum(g: Array, chunk_size: int, reverse: bool = False, softmax_scale: float | None = None, cu_seqlens: jax.jaxlib._jax.Array | None = None, head_first: bool = False, output_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None, **kwargs) -> Array
Compute local cumulative sum within fixed-size chunks.
Main entry point for chunked local cumulative sum, automatically dispatching to scalar or vector implementations based on input rank.
- Parameters
g – Input tensor of shape [B, T, H] or [B, T, H, S].
chunk_size – Size of each chunk (must be a power of 2).
reverse – If True, compute reverse cumsum within each chunk.
softmax_scale – Optional scaling factor applied to result.
cu_seqlens – Optional cumulative sequence lengths for packed sequences.
head_first – If True, expects head dimension before time dimension.
output_dtype – Optional output dtype (defaults to input dtype).
**kwargs – Additional keyword arguments (ignored).
- Returns
Tensor with same shape containing chunked cumulative sums.
Note
When cu_seqlens is provided, only batch size 1 is supported. The function handles variable-length sequences by padding and masking.
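One way to picture the padding approach for packed input is the sketch below. It is an illustrative reference, not the library's code: it assumes chunks restart at the beginning of every packed sequence, takes cu_seqlens as a plain Python list so it can loop over sequences, and uses the hypothetical names chunked_cumsum_1seq and packed_local_cumsum_ref.

```python
import jax.numpy as jnp

def chunked_cumsum_1seq(x, chunk_size):
    """x: [T, H]. Zero-pad T to a multiple of chunk_size, cumsum per chunk, drop the padding."""
    T, H = x.shape
    pad = (-T) % chunk_size
    xp = jnp.pad(x, ((0, pad), (0, 0)))
    out = jnp.cumsum(xp.reshape(-1, chunk_size, H), axis=1).reshape(-1, H)
    return out[:T]  # discard the padded tail

def packed_local_cumsum_ref(g, chunk_size, cu_seqlens):
    """g: [1, T, H]; cu_seqlens: Python list of offsets such as [0, 3, 8]."""
    assert g.shape[0] == 1, "packed input is assumed to use batch size 1"
    pieces = [
        chunked_cumsum_1seq(g[0, start:end], chunk_size)
        for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])
    ]
    return jnp.concatenate(pieces, axis=0)[None]  # restore the batch axis

g = jnp.ones((1, 8, 1))
print(packed_local_cumsum_ref(g, 2, [0, 3, 8])[0, :, 0])  # values: 1, 2, 1, 1, 2, 1, 2, 1
```

The Python loop keeps the sketch readable but is not jit-friendly for traced cu_seqlens; a production kernel would need a vectorized formulation.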
- ejkernel.xla_utils.cumsum.chunk_local_cumsum_scalar(g: Array, chunk_size: int, reverse: bool = False, softmax_scale: float | None = None, head_first: bool = False, output_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None) -> Array
Compute local cumulative sum within chunks for 3D scalar inputs.
Performs cumulative sum within fixed-size chunks, resetting at each chunk boundary. Supports forward and reverse directions.
- Parameters
g – Input tensor of shape [B, H, T] if head_first else [B, T, H].
chunk_size – Size of each chunk (must be a power of 2).
reverse – If True, compute reverse cumsum within each chunk.
softmax_scale – Optional scaling factor applied to result.
head_first – If True, expects [B, H, T] layout; otherwise [B, T, H].
output_dtype – Optional output dtype (defaults to input dtype).
- Returns
Tensor with same shape containing chunked cumulative sums.
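The chunk-reset behavior for 3D inputs can be expressed with a reshape, as in the sketch below. It is a reference for the semantics only, under the assumptions that T is divisible by chunk_size and that reverse means a suffix sum within each chunk; the name local_cumsum_scalar_ref is hypothetical.

```python
import jax.numpy as jnp

def local_cumsum_scalar_ref(g, chunk_size, reverse=False, head_first=False):
    """Hypothetical reference: g is [B, H, T] if head_first else [B, T, H]."""
    if head_first:
        g = jnp.swapaxes(g, 1, 2)  # work in [B, T, H] internally
    B, T, H = g.shape
    assert T % chunk_size == 0, "this sketch assumes T is a multiple of chunk_size"
    c = g.reshape(B, T // chunk_size, chunk_size, H)
    if reverse:
        # Assumed meaning of reverse: suffix sums inside each chunk.
        c = jnp.flip(jnp.cumsum(jnp.flip(c, axis=2), axis=2), axis=2)
    else:
        c = jnp.cumsum(c, axis=2)
    out = c.reshape(B, T, H)
    return jnp.swapaxes(out, 1, 2) if head_first else out

g = jnp.arange(1.0, 7.0).reshape(1, 6, 1)
print(local_cumsum_scalar_ref(g, 2)[0, :, 0])                # values: 1, 3, 3, 7, 5, 11
print(local_cumsum_scalar_ref(g, 2, reverse=True)[0, :, 0])  # values: 3, 2, 7, 4, 11, 6
```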
- ejkernel.xla_utils.cumsum.chunk_local_cumsum_vector(g: Array, chunk_size: int, reverse: bool = False, softmax_scale: float | None = None, head_first: bool = False, output_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None) -> Array
Compute local cumulative sum within chunks for 4D vector inputs.
Performs cumulative sum within fixed-size chunks for tensors with an additional state dimension.
- Parameters
g – Input tensor of shape [B, H, T, S] if head_first else [B, T, H, S].
chunk_size – Size of each chunk (must be a power of 2).
reverse – If True, compute reverse cumsum within each chunk.
softmax_scale – Optional scaling factor applied to result.
head_first – If True, expects [B, H, T, S] layout.
output_dtype – Optional output dtype (defaults to input dtype).
- Returns
Tensor with same shape containing chunked cumulative sums.
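A within-chunk cumulative sum can also be written as a product with a triangular mask, a formulation that chunked attention kernels commonly use. The sketch below applies it to 4D input in the [B, T, H, S] layout. Whether ejkernel uses this formulation is not stated here, so treat it as one possible reference; the name local_cumsum_vector_ref is hypothetical, and T is assumed to be divisible by chunk_size.

```python
import jax.numpy as jnp

def local_cumsum_vector_ref(g, chunk_size, reverse=False):
    """Hypothetical reference: g is [B, T, H, S] with T divisible by chunk_size."""
    B, T, H, S = g.shape
    c = g.reshape(B, T // chunk_size, chunk_size, H, S)
    ones = jnp.ones((chunk_size, chunk_size), dtype=g.dtype)
    # mask[i, j] == 1 when input position j contributes to output position i:
    # j <= i for a forward cumsum, j >= i for a reverse one.
    mask = jnp.triu(ones) if reverse else jnp.tril(ones)
    out = jnp.einsum("ij,bnjhs->bnihs", mask, c)
    return out.reshape(B, T, H, S)

g = jnp.arange(24.0).reshape(2, 4, 1, 3)
expected = jnp.cumsum(g.reshape(2, 2, 2, 1, 3), axis=2).reshape(2, 4, 1, 3)
print(bool(jnp.allclose(local_cumsum_vector_ref(g, 2), expected)))
```

The mask form trades a little extra arithmetic for a single dense contraction per chunk, which tends to map well onto matrix units.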