# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for packed sequence processing in XLA.
This module provides efficient utilities for working with packed (variable-length)
sequences in JAX/XLA computations, commonly used in attention mechanisms.
Packed sequences are represented using cumulative sequence lengths (cu_seqlens),
which define the boundaries of each sequence in a flattened 1D tensor.
Key Concepts:
- cu_seqlens: [0, len1, len1+len2, ...] - cumulative start positions
- position_ids: Position within each sequence (0-indexed)
- sequence_ids: Which sequence each token belongs to (0-indexed)
- chunk_indices: For tiled processing with fixed chunk sizes
Functions:
cdiv: Ceiling division for computing block counts
prepare_lens: Extract individual lengths from cumulative lengths
prepare_position_ids: Generate per-token position indices
prepare_sequence_ids: Generate per-token sequence membership
prepare_token_indices: Combined (seq_id, pos_id) pairs
prepare_chunk_indices: Chunk-level indices for tiled attention
prepare_chunk_offsets: Cumulative chunk counts per sequence
identity_dtype_convert: Create identity function with dtype conversion on backward
Example:
>>> cu_seqlens = jnp.array([0, 3, 5, 9]) # 3 sequences
>>> lens = prepare_lens(cu_seqlens) # [3, 2, 4]
>>> pos_ids = prepare_position_ids(cu_seqlens) # [0, 1, 2, 0, 1, 0, 1, 2, 3]
"""
import jax
import jax.numpy as jnp
from jax import Array
from jaxtyping import Bool, DTypeLike, Int
def cdiv(a: Int[Array, "..."], b: int) -> Int[Array, "..."]:
    """Divide ``a`` by ``b`` and round the quotient up.

    Only integer arithmetic is used, so this works on Python ints as well as
    element-wise on JAX integer arrays.

    Args:
        a: Dividend (integer scalar or integer array).
        b: Divisor.

    Returns:
        ``(a + b - 1) // b``, which equals ``ceil(a / b)`` for a positive ``b``.
    """
    biased = a + b - 1
    return biased // b
def prepare_lens(cu_seqlens: Int[Array, "num_seqs_plus_one"]) -> Int[Array, "num_seqs"]:
    """
    Recover per-sequence lengths from cumulative sequence lengths.

    Each length is the difference between two consecutive boundaries.

    Args:
        cu_seqlens: 1D cumulative lengths, e.g. ``[0, len1, len1+len2, ...]``.

    Returns:
        1D array ``[len1, len2, ...]`` with one entry per sequence.
    """
    starts = cu_seqlens[:-1]
    ends = cu_seqlens[1:]
    return ends - starts
def prepare_lens_from_mask(mask: Bool[Array, "batch seq_len"]) -> Int[Array, "batch"]:
    """
    Count the ``True`` entries along the last axis of a boolean mask.

    Args:
        mask: Boolean attention mask of shape ``(batch_size, seq_len)``.

    Returns:
        int32 array of shape ``(batch_size,)`` holding the number of valid
        tokens in each row.
    """
    valid_token_counts = mask.sum(axis=-1, dtype=jnp.int32)
    return valid_token_counts
def prepare_cu_seqlens_from_mask(
    mask: Bool[Array, "batch seq_len"], out_dtype: DTypeLike = jnp.int32
) -> Int[Array, "batch_plus_one"]:
    """
    Build cumulative sequence lengths from a boolean attention mask.

    The number of ``True`` entries in each row is taken as that row's
    sequence length. The running total of these lengths is then prefixed
    with a leading 0.

    Args:
        mask: Boolean attention mask of shape ``(batch_size, seq_len)``.
        out_dtype: dtype used for the running sum (and therefore the result).

    Returns:
        1D array ``[0, len1, len1+len2, ...]`` of length ``batch_size + 1``.
    """
    # Per-row count of valid (True) tokens, accumulated in int32.
    row_lengths = mask.sum(axis=-1, dtype=jnp.int32)
    running_total = row_lengths.cumsum(axis=0, dtype=out_dtype)
    # Insert the leading zero boundary in front of the running total.
    return jnp.pad(running_total, (1, 0))
def prepare_position_ids(cu_seqlens: Int[Array, "num_seqs_plus_one"]) -> Int[Array, "total_tokens"]:
    """
    Build the within-sequence position of every token in a packed batch.

    For sequences of lengths ``[3, 2, 4]`` the result is
    ``[0, 1, 2, 0, 1, 0, 1, 2, 3]``. Zero-length sequences contribute no
    tokens.

    The output length is ``cu_seqlens[-1]``, a data-dependent value, so
    ``cu_seqlens`` must hold concrete values (``jnp.arange`` cannot take a
    traced stop).

    Args:
        cu_seqlens: 1D cumulative sequence lengths starting at 0.

    Returns:
        1D array of position IDs with dtype ``cu_seqlens.dtype``.
    """
    seq_starts = cu_seqlens[:-1]
    seq_lens = cu_seqlens[1:] - seq_starts
    # Give each token the start offset of the sequence that owns it...
    token_starts = jnp.repeat(seq_starts, repeats=seq_lens)
    # ...and subtract it from the token's flat index to get a local position.
    global_positions = jnp.arange(cu_seqlens[-1], dtype=cu_seqlens.dtype)
    return global_positions - token_starts
def prepare_sequence_ids(cu_seqlens: Int[Array, "num_seqs_plus_one"]) -> Int[Array, "total_tokens"]:
    """
    Generates sequence IDs (0-indexed) for a batch of packed sequences.

    Each sequence index is repeated once per token of that sequence.
    Zero-length sequences therefore produce no tokens, but they still occupy
    their index, so the IDs of later sequences stay aligned with
    ``cu_seqlens``.

    Args:
        cu_seqlens: A 1D array of cumulative sequence lengths starting at 0.

    Returns:
        A 1D array of sequence IDs in JAX's default integer dtype, e.g.
        ``[0, 0, 0, 1, 1, 2, 2, 2, 2]`` for ``cu_seqlens = [0, 3, 5, 9]``.
    """
    lens = cu_seqlens[1:] - cu_seqlens[:-1]
    # Repeating each index by its length is used instead of counting tokens at
    # position 0: an empty sequence has no position-0 token, so the counting
    # approach under-numbers every sequence that follows it.
    seq_indices = jnp.arange(lens.shape[0])
    return jnp.repeat(seq_indices, repeats=lens)
def prepare_token_indices(cu_seqlens: Int[Array, "num_seqs_plus_one"]) -> Int[Array, "total_tokens 2"]:
    """
    Generates (sequence_id, position_id) pairs for each token in the packed batch.

    Zero-length sequences produce no rows, but they keep their index, so a
    ``sequence_id`` always refers to the matching entry in ``cu_seqlens``.

    Args:
        cu_seqlens: A 1D array of cumulative sequence lengths starting at 0.

    Returns:
        A 2D array of shape (total_tokens, 2) where each row is
        [sequence_id, position_id], cast to ``cu_seqlens.dtype``.
    """
    seq_starts = cu_seqlens[:-1]
    lens = cu_seqlens[1:] - seq_starts
    # Local position = flat token index minus the start offset of its sequence.
    position_ids = jnp.arange(cu_seqlens[-1], dtype=cu_seqlens.dtype) - jnp.repeat(
        seq_starts, repeats=lens
    )
    # Repeat each sequence index by its length so that empty sequences do not
    # shift the numbering of later sequences.
    sequence_ids = jnp.repeat(jnp.arange(lens.shape[0]), repeats=lens)
    stacked = jnp.stack([sequence_ids, position_ids], axis=1)
    return stacked.astype(cu_seqlens.dtype)
def prepare_chunk_indices(cu_seqlens: Int[Array, "num_seqs_plus_one"], chunk_size: int) -> Int[Array, "total_chunks 2"]:
    """
    Generates (sequence_id, chunk_id) pairs for each chunk in the packed batch.

    Each sequence of length ``L`` is split into ``ceil(L / chunk_size)``
    chunks. Zero-length sequences produce no chunks, but they keep their
    index, so ``sequence_id`` always refers to the matching entry in
    ``cu_seqlens``.

    Args:
        cu_seqlens: A 1D array of cumulative sequence lengths starting at 0.
        chunk_size: The size of each chunk.

    Returns:
        A 2D array of shape (total_chunks, 2) where each row is
        [sequence_id, chunk_id_in_sequence], cast to ``cu_seqlens.dtype``.
    """
    lens = cu_seqlens[1:] - cu_seqlens[:-1]
    # Ceiling division: a trailing partial chunk still counts as a chunk.
    num_chunks_per_seq = (lens + chunk_size - 1) // chunk_size
    total_chunks = num_chunks_per_seq.sum()
    cu_chunks = jnp.pad(num_chunks_per_seq.cumsum(), (1, 0))
    # Local chunk id = flat chunk index minus the first chunk index of its sequence.
    start_offsets = jnp.repeat(cu_chunks[:-1], repeats=num_chunks_per_seq)
    chunk_ids = jnp.arange(total_chunks) - start_offsets
    # Repeat each sequence index by its chunk count instead of counting
    # chunk_id == 0: sequences with no chunks would otherwise be skipped in
    # the numbering.
    sequence_ids = jnp.repeat(jnp.arange(lens.shape[0]), repeats=num_chunks_per_seq)
    stacked = jnp.stack([sequence_ids, chunk_ids], axis=1)
    return stacked.astype(cu_seqlens.dtype)
def prepare_chunk_offsets(
    cu_seqlens: Int[Array, "num_seqs_plus_one"], chunk_size: int
) -> Int[Array, "num_seqs_plus_one"]:
    """
    Compute where each sequence's chunks start in the flat list of chunks.

    Each sequence of length ``L`` contributes ``ceil(L / chunk_size)`` chunks.
    The result is the running total of those counts, prefixed with 0.

    Args:
        cu_seqlens: 1D cumulative sequence lengths.
        chunk_size: Number of tokens per chunk.

    Returns:
        1D array ``[0, n1, n1 + n2, ...]``, where ``n_i`` is the number of
        chunks in sequence ``i``.
    """
    seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
    # Ceiling division: a trailing partial chunk still counts as a chunk.
    chunks_per_seq = (seq_lens + chunk_size - 1) // chunk_size
    leading_zero = jnp.zeros((1,), dtype=cu_seqlens.dtype)
    return jnp.concatenate([leading_zero, chunks_per_seq]).cumsum(axis=-1)
def identity_dtype_convert(dtype: jnp.dtype):
    """Build an identity op whose backward pass casts cotangents to ``dtype``.

    The returned function gives back its input unchanged. Its custom VJP
    rule returns the incoming gradient cast with ``.astype(dtype)``. This is
    intended for mixed-precision training, where gradients should flow back
    in a specific precision.

    NOTE(review): ``jax.custom_vjp`` expects the backward rule to produce
    outputs matching the primal input's shape/dtype. Confirm that this works
    (or is intended) when ``dtype`` differs from the input's dtype.

    Args:
        dtype: Target dtype for the gradient on the backward pass.

    Returns:
        A function that is the identity in the forward pass and casts the
        gradient to ``dtype`` in the backward pass.

    Example:
        >>> convert_to_fp32 = identity_dtype_convert(jnp.float32)
        >>> result = convert_to_fp32(bf16_tensor)  # Forward: unchanged
        >>> # Backward: gradients will be converted to float32
    """

    @jax.custom_vjp
    def passthrough(x):
        return x

    def passthrough_fwd(x):
        # No residuals are needed: the backward rule only casts the gradient.
        return x, None

    def passthrough_bwd(_residuals, cotangent):
        return (cotangent.astype(dtype),)

    passthrough.defvjp(passthrough_fwd, passthrough_bwd)
    return passthrough