ejkernel.xla_utils.cumsum

Chunked cumulative sum operations for attention mechanisms.

This module provides chunked cumulative sum operations used by several attention mechanisms, particularly linear attention variants that need running sums along the sequence (time) dimension.

Key Operations:
  • chunk_local_cumsum: Cumsum within fixed-size chunks (resets at each chunk boundary)

  • chunk_global_cumsum: Cumsum over the entire sequence (resets at sequence boundaries when cu_seqlens is given)
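The difference between the two semantics can be illustrated with a minimal pure-JAX sketch on a single 1D stream. The helpers reference_local_cumsum and reference_global_cumsum are hypothetical and not part of ejkernel; the sketch assumes the stream length is a multiple of chunk_size and that chunks start at position 0.

```python
import jax.numpy as jnp


def reference_local_cumsum(x, chunk_size):
    """Cumsum that restarts at every chunk boundary (hypothetical helper)."""
    chunks = x.reshape(-1, chunk_size)  # [num_chunks, chunk_size]
    return jnp.cumsum(chunks, axis=1).reshape(x.shape)


def reference_global_cumsum(x):
    """Plain running sum over the whole stream (hypothetical helper)."""
    return jnp.cumsum(x)


x = jnp.ones(8)
print(reference_local_cumsum(x, chunk_size=4))  # [1. 2. 3. 4. 1. 2. 3. 4.]
print(reference_global_cumsum(x))               # [1. 2. 3. 4. 5. 6. 7. 8.]
```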

Both operations support:
  • Scalar (3D) and vector (4D) inputs

  • Forward and reverse directions

  • Optional scaling of the result (softmax_scale)

  • Variable-length sequences via cu_seqlens

  • Head-first or time-first tensor layouts

Typical Use Cases:
  • Gated Linear Attention (GLA): Computing cumulative gate products (as cumulative sums of log-gates)

  • Linear Attention: Accumulating key-value outer products

  • RetNet: Computing retention decay factors
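For the GLA case, one common way to obtain cumulative gate products is to take a cumulative sum of log-gates and exponentiate, which avoids multiplying many small numbers directly. The plain-JAX sketch below only checks that identity on random gates; it does not call ejkernel, and the shapes are arbitrary.

```python
import jax
import jax.numpy as jnp

# Random gates in (0, 1), shape [T, H] = [16, 2]
gates = jax.nn.sigmoid(jax.random.normal(jax.random.PRNGKey(0), (16, 2)))
log_gates = jnp.log(gates)

# Cumulative product computed as exp(cumulative sum of log-gates)
via_cumsum = jnp.exp(jnp.cumsum(log_gates, axis=0))

# Direct cumulative product, for comparison
direct = jnp.cumprod(gates, axis=0)

print(jnp.allclose(via_cumsum, direct, rtol=1e-5, atol=1e-6))
```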

Example

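>>> import jax
>>> import jax.numpy as jnp
>>> from ejkernel.xla_utils import chunk_local_cumsum, chunk_global_cumsum
>>>
>>> # Log-space gates of shape [B, T, H]
>>> g = jax.nn.log_sigmoid(jax.random.normal(jax.random.PRNGKey(0), (2, 256, 4)))
>>> # Local cumsum within 128-token chunks
>>> local_cumsum = chunk_local_cumsum(g, chunk_size=128)
>>>
>>> # Two packed sequences of lengths 100 and 156 (batch size 1)
>>> s = jnp.ones((1, 256, 4))
>>> cu_seqlens = jnp.array([0, 100, 256], dtype=jnp.int32)
>>> # Global cumsum respecting sequence boundaries
>>> global_cumsum = chunk_global_cumsum(s, cu_seqlens=cu_seqlens)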
ejkernel.xla_utils.cumsum.chunk_global_cumsum(s: Array, reverse: bool = False, cu_seqlens: jax.jaxlib._jax.Array | None = None, softmax_scale: float | None = None, head_first: bool = False, output_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None) → Array

Compute a global cumulative sum along the time dimension.

Main entry point for the global cumulative sum; dispatches, based on input rank, to chunk_global_cumsum_scalar (3D input) or chunk_global_cumsum_vector (4D input).

Parameters
  • s – Input tensor of shape [B, T, H] or [B, T, H, S] ([B, H, T] or [B, H, T, S] if head_first).

  • reverse – If True, compute the cumulative sum in reverse (from the last time step backward).

  • cu_seqlens – Optional cumulative sequence lengths for packed sequences. When provided, the cumsum resets at each sequence boundary.

  • softmax_scale – Optional scaling factor applied to the result.

  • head_first – If True, expects head dimension before time dimension.

  • output_dtype – Optional output dtype (defaults to input dtype).

Returns

A tensor with the same shape as s containing the global cumulative sums.

Note

With cu_seqlens, the cumulative sum respects sequence boundaries and does not accumulate across the different sequences packed along the time dimension.
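The boundary behaviour described in the note can be sketched in plain JAX for a packed input of shape [1, T, H] (batch size 1, time-first layout). reference_segmented_cumsum is a hypothetical helper, not ejkernel's implementation, and it assumes cu_seqlens is an increasing integer array that starts at 0 and ends at T.

```python
import jax.numpy as jnp


def reference_segmented_cumsum(s, cu_seqlens):
    """Cumsum along axis 1 of a [1, T, H] array, restarting at each cu_seqlens boundary."""
    T = s.shape[1]
    total = jnp.cumsum(s, axis=1)  # running sum over the whole packed stream
    # padded[:, t] is the sum of s[:, :t]; padded[:, 0] is zero
    padded = jnp.concatenate([jnp.zeros_like(total[:, :1]), total], axis=1)
    t = jnp.arange(T)
    seq_idx = jnp.searchsorted(cu_seqlens, t, side="right") - 1  # sequence containing t
    starts = cu_seqlens[seq_idx]  # first position of that sequence
    return total - padded[:, starts]  # drop everything before the sequence start


s = jnp.ones((1, 5, 2))
cu_seqlens = jnp.array([0, 3, 5])
print(reference_segmented_cumsum(s, cu_seqlens)[0, :, 0])  # [1. 2. 3. 1. 2.]
```

Subtracting prefix totals keeps the sketch short, but in low precision it can lose accuracy on long packed streams; a segmented scan (for example with jax.lax.associative_scan and a reset flag) avoids that cancellation.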

ejkernel.xla_utils.cumsum.chunk_global_cumsum_scalar(s: Array, reverse: bool = False, cu_seqlens: jax.jaxlib._jax.Array | None = None, softmax_scale: float | None = None, head_first: bool = False, output_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None) → Array

Compute a global cumulative sum along the time dimension for 3D (scalar) inputs.

Performs a cumulative sum over the entire sequence, with optional support for variable-length sequences via cu_seqlens.

Parameters
  • s – Input tensor of shape [B, H, T] if head_first else [B, T, H].

  • reverse – If True, compute the cumulative sum in reverse (from the last time step backward).

  • cu_seqlens – Optional cumulative sequence lengths for packed sequences.

  • softmax_scale – Optional scaling factor applied to the result.

  • head_first – If True, expects [B, H, T] layout.

  • output_dtype – Optional output dtype (defaults to input dtype).

Returns

A tensor with the same shape as s containing the global cumulative sums.
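The reverse, head_first and softmax_scale options can be illustrated with a hypothetical pure-JAX reference, reference_global_cumsum_scalar, that ignores cu_seqlens. It assumes softmax_scale multiplies the result, which is one reading of "scaling factor applied to the result"; ejkernel's source is the authority on the exact behaviour.

```python
import jax
import jax.numpy as jnp


def reference_global_cumsum_scalar(s, reverse=False, softmax_scale=None, head_first=False):
    """Hypothetical pure-JAX reference for a 3D input (no cu_seqlens)."""
    time_axis = 2 if head_first else 1  # [B, H, T] vs [B, T, H]
    out = jax.lax.cumsum(s, axis=time_axis, reverse=reverse)
    if softmax_scale is not None:
        out = out * softmax_scale  # assumption: multiplicative scaling of the result
    return out


s = jnp.ones((1, 4, 1))  # [B, T, H]
print(reference_global_cumsum_scalar(s)[0, :, 0])                # [1. 2. 3. 4.]
print(reference_global_cumsum_scalar(s, reverse=True)[0, :, 0])  # [4. 3. 2. 1.]
```

Because only the time axis changes with head_first, the same helper covers both layouts without transposing.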

ejkernel.xla_utils.cumsum.chunk_global_cumsum_vector(s: Array, reverse: bool = False, cu_seqlens: jax.jaxlib._jax.Array | None = None, softmax_scale: float | None = None, head_first: bool = False, output_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None) → Array

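Compute a global cumulative sum along the time dimension for 4D vector inputs of shape [B, T, H, S] ([B, H, T, S] if head_first), with explicit support for variable-length sequences via cu_seqlens. The parameters have the same meaning as for chunk_global_cumsum_scalar.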

ejkernel.xla_utils.cumsum.chunk_local_cumsum(g: Array, chunk_size: int, reverse: bool = False, softmax_scale: float | None = None, cu_seqlens: jax.jaxlib._jax.Array | None = None, head_first: bool = False, output_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None, **kwargs) → Array

Compute a local cumulative sum within fixed-size chunks.

Main entry point for the chunked local cumulative sum; dispatches, based on input rank, to chunk_local_cumsum_scalar (3D input) or chunk_local_cumsum_vector (4D input).

Parameters
  • g – Input tensor of shape [B, T, H] or [B, T, H, S] ([B, H, T] or [B, H, T, S] if head_first).

  • chunk_size – Size of each chunk (must be a power of 2).

  • reverse – If True, compute the cumulative sum in reverse within each chunk.

  • softmax_scale – Optional scaling factor applied to the result.

  • cu_seqlens – Optional cumulative sequence lengths for packed sequences.

  • head_first – If True, expects head dimension before time dimension.

  • output_dtype – Optional output dtype (defaults to input dtype).

  • **kwargs – Additional keyword arguments (ignored).

Returns

A tensor with the same shape as g containing the chunked cumulative sums.

Note

When cu_seqlens is provided, only batch size 1 is supported. The function handles variable-length sequences by padding and masking.
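A hypothetical reference for chunk-local cumsum over packed sequences, reference_varlen_local_cumsum, is sketched below for a [1, T, H] input. It assumes chunk boundaries are counted from the start of each sequence in cu_seqlens, and it uses prefix-sum subtraction rather than the padding and masking described in the note, so treat it as a way to reason about expected values, not as ejkernel's implementation.

```python
import jax.numpy as jnp


def reference_varlen_local_cumsum(g, chunk_size, cu_seqlens):
    """Chunk-local cumsum of a packed [1, T, H] array (hypothetical helper).

    Assumes chunks restart at the start of every sequence in cu_seqlens.
    """
    T = g.shape[1]
    total = jnp.cumsum(g, axis=1)
    padded = jnp.concatenate([jnp.zeros_like(total[:, :1]), total], axis=1)
    t = jnp.arange(T)
    seq_start = cu_seqlens[jnp.searchsorted(cu_seqlens, t, side="right") - 1]
    # First position of the chunk containing t, counted from its sequence start
    chunk_start = seq_start + ((t - seq_start) // chunk_size) * chunk_size
    return total - padded[:, chunk_start]


g = jnp.ones((1, 7, 1))
cu_seqlens = jnp.array([0, 5, 7])  # sequences of length 5 and 2
print(reference_varlen_local_cumsum(g, 2, cu_seqlens)[0, :, 0])  # [1. 2. 1. 2. 1. 1. 2.]
```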

ejkernel.xla_utils.cumsum.chunk_local_cumsum_scalar(g: Array, chunk_size: int, reverse: bool = False, softmax_scale: float | None = None, head_first: bool = False, output_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None) → Array

Compute a local cumulative sum within chunks for 3D (scalar) inputs.

Performs a cumulative sum within fixed-size chunks, resetting at each chunk boundary. Supports forward and reverse directions.

Parameters
  • g – Input tensor of shape [B, H, T] if head_first else [B, T, H].

  • chunk_size – Size of each chunk (must be a power of 2).

  • reverse – If True, compute the cumulative sum in reverse within each chunk.

  • softmax_scale – Optional scaling factor applied to the result.

  • head_first – If True, expects [B, H, T] layout; otherwise [B, T, H].

  • output_dtype – Optional output dtype (defaults to input dtype).

Returns

A tensor with the same shape as g containing the chunked cumulative sums.
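The reset-at-every-chunk behaviour of the scalar variant can be sketched with a reshape: pad T up to a multiple of chunk_size, fold time into [num_chunks, chunk_size], and scan along the chunk axis. reference_local_cumsum_scalar is a hypothetical helper for the time-first [B, T, H] layout; it does not model softmax_scale or output_dtype.

```python
import jax
import jax.numpy as jnp


def reference_local_cumsum_scalar(g, chunk_size, reverse=False):
    """Chunk-local cumsum of a time-first [B, T, H] array (hypothetical helper)."""
    B, T, H = g.shape
    pad = (-T) % chunk_size  # zeros needed to fill the last chunk
    g = jnp.pad(g, ((0, 0), (0, pad), (0, 0)))
    chunks = g.reshape(B, -1, chunk_size, H)  # [B, num_chunks, chunk_size, H]
    out = jax.lax.cumsum(chunks, axis=2, reverse=reverse)  # restart at every chunk
    return out.reshape(B, T + pad, H)[:, :T]  # drop the padding again


g = jnp.ones((1, 6, 1))
print(reference_local_cumsum_scalar(g, 4)[0, :, 0])                # [1. 2. 3. 4. 1. 2.]
print(reference_local_cumsum_scalar(g, 4, reverse=True)[0, :, 0])  # [4. 3. 2. 1. 2. 1.]
```

Zero padding is harmless in both directions, since appending zeros does not change any partial sum that survives the final slice.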

ejkernel.xla_utils.cumsum.chunk_local_cumsum_vector(g: Array, chunk_size: int, reverse: bool = False, softmax_scale: float | None = None, head_first: bool = False, output_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None) → Array

Compute a local cumulative sum within chunks for 4D vector inputs.

Performs a cumulative sum within fixed-size chunks for tensors with an additional trailing state dimension S; each state channel is summed independently along time.

Parameters
  • g – Input tensor of shape [B, H, T, S] if head_first else [B, T, H, S].

  • chunk_size – Size of each chunk (must be a power of 2).

  • reverse – If True, compute the cumulative sum in reverse within each chunk.

  • softmax_scale – Optional scaling factor applied to the result.

  • head_first – If True, expects [B, H, T, S] layout.

  • output_dtype – Optional output dtype (defaults to input dtype).

Returns

A tensor with the same shape as g containing the chunked cumulative sums.
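For the vector variant, each of the S state channels is scanned independently along time. The hypothetical helper reference_local_cumsum_vector below handles both layouts by swapping the head and time axes when head_first is True; it assumes T is a multiple of chunk_size and omits reverse, softmax_scale and output_dtype.

```python
import jax.numpy as jnp


def reference_local_cumsum_vector(g, chunk_size, head_first=False):
    """Chunk-local cumsum of a 4D array along its time axis (hypothetical helper)."""
    if head_first:
        g = jnp.swapaxes(g, 1, 2)  # [B, H, T, S] -> [B, T, H, S]
    B, T, H, S = g.shape
    assert T % chunk_size == 0, "sketch assumes T is a multiple of chunk_size"
    chunks = g.reshape(B, T // chunk_size, chunk_size, H, S)
    # Scan inside each chunk; every (h, s) channel is independent
    out = jnp.cumsum(chunks, axis=2).reshape(B, T, H, S)
    return jnp.swapaxes(out, 1, 2) if head_first else out


g = jnp.ones((1, 4, 2, 3))  # [B, T, H, S]
out = reference_local_cumsum_vector(g, chunk_size=2)
print(out[0, :, 0, 0])  # [1. 2. 1. 2.]
```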