ejkernel.xla_utils.utils

Utility functions for packed sequence processing in XLA.

This module provides efficient utilities for working with packed (variable-length) sequences in JAX/XLA computations, commonly used in attention mechanisms.

Packed sequences are represented using cumulative sequence lengths (cu_seqlens), which define the boundaries of each sequence in a flattened 1D tensor.

Key Concepts:
  • cu_seqlens: [0, len1, len1+len2, …], the cumulative sequence boundaries (entry i is the start offset of sequence i; the last entry is the total token count)

  • position_ids: Position within each sequence (0-indexed)

  • sequence_ids: Which sequence each token belongs to (0-indexed)

  • chunk_indices: (sequence_id, chunk_id) pairs for tiled processing with a fixed chunk size

Functions:

  • cdiv: Ceiling division for computing block counts

  • prepare_lens: Extract individual lengths from cumulative lengths

  • prepare_position_ids: Generate per-token position indices

  • prepare_sequence_ids: Generate per-token sequence membership

  • prepare_token_indices: Combined (seq_id, pos_id) pairs

  • prepare_chunk_indices: Chunk-level indices for tiled attention

  • prepare_chunk_offsets: Cumulative chunk counts per sequence

  • identity_dtype_convert: Create an identity function with dtype conversion in the backward pass

Example

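>>> import jax.numpy as jnp
>>> from ejkernel.xla_utils.utils import prepare_lens, prepare_position_ids
>>> cu_seqlens = jnp.array([0, 3, 5, 9])  # 3 sequences
>>> lens = prepare_lens(cu_seqlens)  # [3, 2, 4]
>>> pos_ids = prepare_position_ids(cu_seqlens)  # [0, 1, 2, 0, 1, 0, 1, 2, 3]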

ejkernel.xla_utils.utils.cdiv(a: Int[jaxlib._jax.Array, '...'], b: int) → Int[jaxlib._jax.Array, '...']

Computes the ceiling of a / b for integer inputs in a JAX-compatible way, for example the number of blocks of size b needed to cover a elements.
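A minimal sketch of integer ceiling division in `jax.numpy`, assuming a positive divisor. `cdiv_sketch` is a hypothetical name used for illustration; it is not necessarily how the library implements `cdiv`.

```python
import jax.numpy as jnp


def cdiv_sketch(a, b: int):
    """Hypothetical sketch: ceil(a / b) for integer arrays, assuming b > 0."""
    # Floor division of the negated numerator, negated back, rounds up
    # instead of down: -((-a) // b) == ceil(a / b).
    return -((-a) // b)


seq_lens = jnp.array([0, 1, 4, 5, 8])
print(cdiv_sketch(seq_lens, 4))  # [0 1 1 2 2]
```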

ejkernel.xla_utils.utils.identity_dtype_convert(dtype: dtype)

Create an identity function that converts gradients to a specific dtype.

Returns a function that passes inputs unchanged in the forward pass, but converts gradients to the specified dtype during backpropagation. This is useful for mixed-precision training where gradients need to be accumulated in a specific precision.

Parameters

dtype – The target dtype for gradient conversion.

Returns

A JAX function that acts as the identity in the forward pass but converts gradients to the specified dtype in the backward pass.

Example

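>>> import jax.numpy as jnp
>>> from ejkernel.xla_utils.utils import identity_dtype_convert
>>> bf16_tensor = jnp.ones((4,), dtype=jnp.bfloat16)
>>> convert_to_fp32 = identity_dtype_convert(jnp.float32)
>>> result = convert_to_fp32(bf16_tensor)  # Forward: returns the input unchanged
>>> # Backward: gradients flowing through result are converted to float32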
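One way to get this behavior is `jax.custom_vjp`: the forward rule returns its input, and the backward rule casts the incoming cotangent. The sketch below assumes that approach; `identity_dtype_convert_sketch` is a hypothetical name and may differ from the library's implementation. JAX validates the cotangents a custom_vjp backward rule returns against the primal inputs, so whether a cast to a dtype different from the input's is accepted may depend on the JAX version; the usage lines keep both in float32.

```python
import jax
import jax.numpy as jnp


def identity_dtype_convert_sketch(dtype):
    """Hypothetical sketch: identity forward, cast cotangents to `dtype` backward."""

    @jax.custom_vjp
    def identity(x):
        return x

    def identity_fwd(x):
        # No residuals are needed: the backward pass only casts the cotangent.
        return x, None

    def identity_bwd(_residuals, g):
        # custom_vjp expects a tuple with one cotangent per primal input.
        return (g.astype(dtype),)

    identity.defvjp(identity_fwd, identity_bwd)
    return identity


to_fp32 = identity_dtype_convert_sketch(jnp.float32)
x = jnp.arange(3, dtype=jnp.float32)
grads = jax.grad(lambda v: jnp.sum(to_fp32(v) * 2.0))(x)
print(grads.dtype, grads)  # float32 [2. 2. 2.]
```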
ejkernel.xla_utils.utils.prepare_chunk_indices(cu_seqlens: Int[jaxlib._jax.Array, 'num_seqs_plus_one'], chunk_size: int) → Int[jaxlib._jax.Array, 'total_chunks 2']

Generates (sequence_id, chunk_id) pairs for each chunk in the packed batch.

Parameters
  • cu_seqlens – A 1D array of cumulative sequence lengths.

  • chunk_size – The size of each chunk.

Returns

A 2D array of shape (total_chunks, 2) where each row is [sequence_id, chunk_id_in_sequence].
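A minimal sketch of the chunk-index computation, assuming each sequence of length L is split into cdiv(L, chunk_size) chunks (the last one possibly partial). `prepare_chunk_indices_sketch` is hypothetical. Because the number of rows depends on the values in `cu_seqlens`, the sketch runs eagerly rather than under `jax.jit`.

```python
import jax.numpy as jnp


def prepare_chunk_indices_sketch(cu_seqlens, chunk_size: int):
    """Hypothetical sketch: one [sequence_id, chunk_id_in_sequence] row per chunk."""
    lens = jnp.diff(cu_seqlens)                      # per-sequence lengths
    n_chunks = -((-lens) // chunk_size)              # ceil(len / chunk_size)
    offsets = jnp.concatenate([jnp.zeros(1, n_chunks.dtype), jnp.cumsum(n_chunks)])
    total_chunks = int(offsets[-1])                  # concrete value: eager only
    # Repeat each sequence id once per chunk of that sequence.
    seq_ids = jnp.repeat(jnp.arange(n_chunks.shape[0]), n_chunks,
                         total_repeat_length=total_chunks)
    # Global chunk index minus the index of the sequence's first chunk.
    chunk_ids = jnp.arange(total_chunks) - offsets[seq_ids]
    return jnp.stack([seq_ids, chunk_ids], axis=1)


cu_seqlens = jnp.array([0, 3, 5, 9])
print(prepare_chunk_indices_sketch(cu_seqlens, chunk_size=2).tolist())
# [[0, 0], [0, 1], [1, 0], [2, 0], [2, 1]]
```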

ejkernel.xla_utils.utils.prepare_chunk_offsets(cu_seqlens: Int[jaxlib._jax.Array, 'num_seqs_plus_one'], chunk_size: int) → Int[jaxlib._jax.Array, 'num_seqs_plus_one']

Computes the cumulative offsets of chunks in the packed batch.

Parameters
  • cu_seqlens – A 1D array of cumulative sequence lengths.

  • chunk_size – The size of each chunk.

Returns

A 1D array of cumulative chunk counts (e.g., [0, num_chunks_seq1, num_chunks_seq1 + num_chunks_seq2, …]).
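A minimal sketch, again assuming cdiv(L, chunk_size) chunks per sequence of length L; `prepare_chunk_offsets_sketch` is a hypothetical name. The result has one entry per sequence plus a leading 0.

```python
import jax.numpy as jnp


def prepare_chunk_offsets_sketch(cu_seqlens, chunk_size: int):
    """Hypothetical sketch: cumulative chunk counts, starting at 0."""
    lens = jnp.diff(cu_seqlens)                  # per-sequence lengths
    n_chunks = -((-lens) // chunk_size)          # ceil(len / chunk_size)
    return jnp.concatenate([jnp.zeros(1, n_chunks.dtype), jnp.cumsum(n_chunks)])


print(prepare_chunk_offsets_sketch(jnp.array([0, 3, 5, 9]), chunk_size=2))
# [0 2 3 5]
```

Unlike the chunk indices, the output shape here depends only on the input shape, so this version can also be traced with `jax.jit` as long as `chunk_size` is marked static.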

ejkernel.xla_utils.utils.prepare_cu_seqlens_from_mask(mask: Bool[jaxlib._jax.Array, 'batch seq_len'], out_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType] = jax.numpy.int32) → Int[jaxlib._jax.Array, 'batch_plus_one']

Creates cumulative sequence lengths from a boolean attention mask.

Parameters
  • mask – A 2D boolean attention mask of shape (batch_size, seq_len).

  • out_dtype – The desired dtype for the output array.

Returns

A 1D array of cumulative sequence lengths (e.g., [0, len1, len1+len2, …]).
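A minimal sketch, assuming True marks a valid (non-padding) token and that each row's length is its count of True entries; `prepare_cu_seqlens_from_mask_sketch` is hypothetical.

```python
import jax.numpy as jnp


def prepare_cu_seqlens_from_mask_sketch(mask, out_dtype=jnp.int32):
    """Hypothetical sketch: cumulative lengths from a (batch, seq_len) mask."""
    lens = jnp.sum(mask, axis=-1, dtype=out_dtype)   # valid tokens per row
    zero = jnp.zeros(1, dtype=out_dtype)
    return jnp.concatenate([zero, jnp.cumsum(lens, dtype=out_dtype)])


mask = jnp.array([
    [True, True, True, False],
    [True, True, False, False],
    [True, True, True, True],
])
print(prepare_cu_seqlens_from_mask_sketch(mask))  # [0 3 5 9]
```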

ejkernel.xla_utils.utils.prepare_lens(cu_seqlens: Int[jaxlib._jax.Array, 'num_seqs_plus_one']) → Int[jaxlib._jax.Array, 'num_seqs']

Calculates the lengths of individual sequences from cumulative sequence lengths.

Parameters

cu_seqlens – A 1D array of cumulative sequence lengths (e.g., [0, len1, len1+len2, …]).

Returns

A 1D array of sequence lengths (e.g., [len1, len2, …]).
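A minimal sketch: each length is the difference between adjacent boundaries. `prepare_lens_sketch` is a hypothetical name; `jnp.diff(cu_seqlens)` computes the same thing.

```python
import jax.numpy as jnp


def prepare_lens_sketch(cu_seqlens):
    """Hypothetical sketch: lengths are differences of adjacent boundaries."""
    return cu_seqlens[1:] - cu_seqlens[:-1]


print(prepare_lens_sketch(jnp.array([0, 3, 5, 9])))  # [3 2 4]
```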

ejkernel.xla_utils.utils.prepare_lens_from_mask(mask: Bool[jaxlib._jax.Array, 'batch seq_len']) → Int[jaxlib._jax.Array, 'batch']

Calculates the length of each sequence from a boolean attention mask.

Parameters

mask – A 2D boolean attention mask of shape (batch_size, seq_len).

Returns

A 1D array of sequence lengths with dtype int32.
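A minimal sketch, assuming every True entry in a row counts as one token of that row's sequence (whether or not the True entries form a prefix); `prepare_lens_from_mask_sketch` is hypothetical.

```python
import jax.numpy as jnp


def prepare_lens_from_mask_sketch(mask):
    """Hypothetical sketch: count True entries per row as int32 lengths."""
    return jnp.sum(mask, axis=-1, dtype=jnp.int32)


mask = jnp.array([
    [True, True, True, False],
    [True, True, False, False],
    [True, True, True, True],
])
print(prepare_lens_from_mask_sketch(mask))  # [3 2 4]
```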

ejkernel.xla_utils.utils.prepare_position_ids(cu_seqlens: Int[jaxlib._jax.Array, 'num_seqs_plus_one']) → Int[jaxlib._jax.Array, 'total_tokens']

Generates position IDs for a batch of packed sequences.

This creates a single 1D array like [0, 1, 2, 0, 1, 0, 1, 2, 3] for sequences of lengths [3, 2, 4].

Parameters

cu_seqlens – A 1D array of cumulative sequence lengths.

Returns

A 1D array of position IDs for the packed sequences.
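A minimal sketch using a sorted search over the sequence end boundaries to find each token's sequence, then subtracting that sequence's start offset. `prepare_position_ids_sketch` is hypothetical, and reading `cu_seqlens[-1]` as a Python int means it runs eagerly, not under `jax.jit`.

```python
import jax.numpy as jnp


def prepare_position_ids_sketch(cu_seqlens):
    """Hypothetical sketch: position of each token within its own sequence."""
    total_tokens = int(cu_seqlens[-1])          # concrete value: eager only
    tokens = jnp.arange(total_tokens)
    # Sequence i owns tokens in [cu_seqlens[i], cu_seqlens[i + 1]); a
    # right-sided search over the end boundaries yields that i.
    seq_ids = jnp.searchsorted(cu_seqlens[1:], tokens, side="right")
    # Subtract each sequence's start offset so counting restarts at 0.
    return tokens - cu_seqlens[seq_ids]


print(prepare_position_ids_sketch(jnp.array([0, 3, 5, 9])))
# [0 1 2 0 1 0 1 2 3]
```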

ejkernel.xla_utils.utils.prepare_sequence_ids(cu_seqlens: Int[jaxlib._jax.Array, 'num_seqs_plus_one']) → Int[jaxlib._jax.Array, 'total_tokens']

Generates sequence IDs (0-indexed) for a batch of packed sequences.

Parameters

cu_seqlens – A 1D array of cumulative sequence lengths.

Returns

A 1D array of sequence IDs, e.g., [0, 0, 0, 1, 1, 2, 2, 2, 2].
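A minimal sketch that repeats each sequence index once per token, assuming `cu_seqlens[0] == 0`; `prepare_sequence_ids_sketch` is hypothetical and, like the other value-dependent sketches, runs eagerly.

```python
import jax.numpy as jnp


def prepare_sequence_ids_sketch(cu_seqlens):
    """Hypothetical sketch: the owning sequence index for every token."""
    lens = jnp.diff(cu_seqlens)                  # per-sequence lengths
    total_tokens = int(cu_seqlens[-1])           # concrete value: eager only
    # Sequence i contributes lens[i] copies of the id i.
    return jnp.repeat(jnp.arange(lens.shape[0]), lens,
                      total_repeat_length=total_tokens)


print(prepare_sequence_ids_sketch(jnp.array([0, 3, 5, 9])))
# [0 0 0 1 1 2 2 2 2]
```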

ejkernel.xla_utils.utils.prepare_token_indices(cu_seqlens: Int[jaxlib._jax.Array, 'num_seqs_plus_one']) → Int[jaxlib._jax.Array, 'total_tokens 2']

Generates (sequence_id, position_id) pairs for each token in the packed batch.

Parameters

cu_seqlens – A 1D array of cumulative sequence lengths.

Returns

A 2D array of shape (total_tokens, 2) where each row is [sequence_id, position_id].
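A minimal sketch that computes sequence IDs and position IDs together and stacks them column-wise; `prepare_token_indices_sketch` is hypothetical and runs eagerly because the output length depends on `cu_seqlens[-1]`.

```python
import jax.numpy as jnp


def prepare_token_indices_sketch(cu_seqlens):
    """Hypothetical sketch: one [sequence_id, position_id] row per token."""
    total_tokens = int(cu_seqlens[-1])          # concrete value: eager only
    tokens = jnp.arange(total_tokens)
    # Index of the sequence whose [start, end) range contains each token.
    seq_ids = jnp.searchsorted(cu_seqlens[1:], tokens, side="right")
    # Offset of each token from the start of its sequence.
    pos_ids = tokens - cu_seqlens[seq_ids]
    return jnp.stack([seq_ids, pos_ids], axis=1)


print(prepare_token_indices_sketch(jnp.array([0, 3, 5])).tolist())
# [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]
```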