ejkernel.xla_utils.shardings


Sharding utilities for distributed JAX computation.

This module provides utilities for managing array shardings across distributed devices, with automatic correction of partition specifications based on array shapes and mesh configurations.

Key Functions:

  • get_corrected_named_sharding: Create valid shardings based on shape and mesh constraints.

  • reorder_sequence: Reorder sequence dimensions for ring attention patterns.

Sharding Correction:

The get_corrected_named_sharding function automatically adjusts PartitionSpecs so that they are valid, based on:

  • Axis names present in the given mesh.

  • Divisibility of array dimensions by mesh axis sizes.

  • Multi-axis entries (a tuple of axis names on one dimension), which are checked against the product of their axis sizes.
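A minimal pure-Python sketch of these rules, under stated assumptions: the mesh is modeled as a plain dict of axis sizes, the PartitionSpec as a tuple, and a failed check replaces the whole entry with None (replication). The name correct_spec is hypothetical; ejkernel's actual implementation may differ, for example by dropping axes from a multi-axis entry one at a time.

```python
from typing import Sequence, Union

# One PartitionSpec entry: None (replicated), an axis name, or a tuple of names.
SpecEntry = Union[None, str, tuple[str, ...]]


def correct_spec(
    shape: Sequence[int],
    spec: Sequence[SpecEntry],
    axis_sizes: dict[str, int],
) -> tuple[SpecEntry, ...]:
    """Hypothetical sketch of PartitionSpec correction (not ejkernel's code)."""
    # Pad the spec with None so every array dimension has an entry.
    padded = list(spec) + [None] * (len(shape) - len(spec))
    corrected: list[SpecEntry] = []
    for dim, entry in zip(shape, padded):
        if entry is None:
            corrected.append(None)
            continue
        axes = (entry,) if isinstance(entry, str) else tuple(entry)
        # Rule 1: keep only axis names that exist in the mesh.
        axes = tuple(a for a in axes if a in axis_sizes)
        # Rule 2: the dimension must divide evenly by the product of axis sizes.
        total = 1
        for a in axes:
            total *= axis_sizes[a]
        if not axes or dim % total != 0:
            corrected.append(None)  # assumption: fall back to replication
            continue
        # Rule 3: a single axis stays a plain name, several stay a tuple.
        corrected.append(axes[0] if len(axes) == 1 else axes)
    return tuple(corrected)


sizes = {"dp": 2, "mp": 4}
print(correct_spec((8, 1024, 510), ("dp", None, "mp"), sizes))
# ('dp', None, None)
print(correct_spec((8, 1024), (("dp", "mp"), "sp"), sizes))
# (('dp', 'mp'), None)
```

In the first call the last dimension is replicated because 510 is not divisible by 4; in the second, 'sp' is dropped because the mesh has no such axis, while the multi-axis entry survives since 8 is divisible by 2 * 4.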

Ring Attention Reordering:

The reorder_sequence function rearranges the sequence dimension to enable efficient ring attention communication patterns, alternating between forward and backward sequence chunks.
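To make the forward/backward alternation concrete, here is a small pure-Python sketch. It assumes the zigzag layout commonly used for ring attention: the sequence is cut into 2 * cp_size equal chunks and rank r receives forward chunk r plus backward chunk 2 * cp_size - 1 - r. ring_chunk_order is a hypothetical helper, and ejkernel's exact ordering may differ.

```python
def ring_chunk_order(cp_size: int) -> list[int]:
    """Chunk order for an assumed zigzag ring-attention layout (hypothetical).

    The sequence is split into 2 * cp_size equal chunks; rank r receives
    the forward chunk r followed by the backward chunk 2 * cp_size - 1 - r.
    """
    if cp_size % 2 != 0:
        raise ValueError("cp_size must be even")
    num_chunks = 2 * cp_size
    order: list[int] = []
    for rank in range(cp_size):
        order += [rank, num_chunks - 1 - rank]
    return order


for cp in (2, 4):
    order = ring_chunk_order(cp)
    # Pair up consecutive entries to show which chunks each rank owns.
    owned = [tuple(order[i:i + 2]) for i in range(0, len(order), 2)]
    print(cp, order, owned)
# 2 [0, 3, 1, 2] [(0, 3), (1, 2)]
# 4 [0, 7, 1, 6, 2, 5, 3, 4] [(0, 7), (1, 6), (2, 5), (3, 4)]
```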

Example

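>>> import numpy as np
>>> import jax
>>> from jax.sharding import PartitionSpec, Mesh
>>> from ejkernel.xla_utils.shardings import get_corrected_named_sharding
>>>
>>> # Arrange all available devices into a 2D ('dp', 'mp') grid.
>>> devices = np.array(jax.devices()).reshape(1, -1)
>>> mesh = Mesh(devices, ('dp', 'mp'))
>>> shape = (8, 1024, 512)
>>> spec = PartitionSpec('dp', None, 'mp')
>>> sharding = get_corrected_named_sharding(shape, spec, mesh)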

ejkernel.xla_utils.shardings.get_corrected_named_sharding(shape: tuple[int, ...], partition_spec: PartitionSpec, mesh: Mesh) → NamedSharding

Calculates the corrected PartitionSpec for the given shape and mesh and returns it as a NamedSharding.

This function takes an array shape and a desired PartitionSpec. It determines the effective PartitionSpec by correcting the input based on:

  • Axis names present in the given mesh.

  • Divisibility of array dimensions by the product of corresponding mesh axis sizes.

It does NOT correct based on mesh axes having size 1: such an axis is kept in the spec if it is explicitly provided and the divisibility check holds.
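Written as a condition, the divisibility rule for a single spec entry reads as follows. This is a sketch assuming the entry maps a dimension of size d to mesh axes a_1 to a_k with sizes s(a_i), and that a failed check replicates the whole dimension; the symbols and the example sizes s(dp) = 2 and s(mp) = 4 are introduced here for illustration only.

```latex
\begin{align*}
&\text{keep } (a_1,\dots,a_k) \text{ on a dimension of size } d
  \iff a_i \in \mathrm{axes}(\mathrm{mesh})\ \forall i
  \ \text{and}\ d \bmod \textstyle\prod_{i=1}^{k} s(a_i) = 0 \\[4pt]
&1024 \bmod (2 \cdot 4) = 0 \;\Rightarrow\; (\mathrm{dp}, \mathrm{mp}) \text{ kept} \\
&510 \bmod 4 = 2 \;\Rightarrow\; \mathrm{mp} \text{ dropped (dimension replicated)} \\
&3 \bmod 1 = 0 \;\Rightarrow\; \mathrm{dp} \text{ kept when } s(\mathrm{dp}) = 1
\end{align*}
```

The last line shows why size-1 axes persist: they contribute a factor of 1 to the product, so they can never cause the check to fail on their own.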

Parameters
  • shape – The shape of the target JAX array.

  • partition_spec – The desired PartitionSpec.

  • mesh – The Mesh whose axis names and axis sizes are used to correct partition_spec.

Returns

A NamedSharding object combining mesh with the corrected PartitionSpec.

ejkernel.xla_utils.shardings.reorder_sequence(tensor, cp_size: int, seq_dim: int = 1, to_contiguous: bool = False)

Reorder sequence dimension for ring attention communication patterns.

Rearranges the sequence dimension to enable efficient ring attention communication, alternating between forward and backward sequence chunks to minimize communication overhead during context parallel processing.

Parameters
  • tensor – Input tensor with a sequence dimension to reorder.

  • cp_size – Context parallelism size (must be even).

  • seq_dim – Dimension index of the sequence axis (default: 1).

  • to_contiguous – If True, reorder for contiguous memory layout; if False, reorder for ring attention pattern.

Returns

Tensor of the same shape as tensor, with its sequence dimension reordered.

Raises

ValueError – If cp_size is not even, or if the sequence length is not divisible by 2 * cp_size.

Note

The reordering interleaves forward and backward chunks to enable efficient bidirectional communication in ring attention patterns.
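The sketch below shows how such a reorder can be applied to an array along seq_dim and inverted when to_contiguous is True. It assumes the zigzag chunk order in which rank r holds chunk r and chunk 2 * cp_size - 1 - r, and uses NumPy rather than JAX so it runs anywhere. reorder_sequence_sketch is a hypothetical stand-in that mirrors the documented parameters and ValueError conditions; it is not ejkernel's implementation, whose exact permutation may differ.

```python
import numpy as np


def reorder_sequence_sketch(x: np.ndarray, cp_size: int, seq_dim: int = 1,
                            to_contiguous: bool = False) -> np.ndarray:
    """Hypothetical NumPy sketch of a zigzag sequence reorder (not ejkernel's code)."""
    if cp_size % 2 != 0:
        raise ValueError(f"cp_size must be even, got {cp_size}")
    seq_len = x.shape[seq_dim]
    num_chunks = 2 * cp_size
    if seq_len % num_chunks != 0:
        raise ValueError(
            f"seq_len={seq_len} is not divisible by 2*cp_size={num_chunks}")
    # Assumed order: rank r holds forward chunk r, then backward chunk 2*cp_size-1-r.
    order = [c for r in range(cp_size) for c in (r, num_chunks - 1 - r)]
    if to_contiguous:
        # Invert the permutation to restore the original sequence order.
        order = list(np.argsort(order))
    chunks = np.split(x, num_chunks, axis=seq_dim)
    return np.concatenate([chunks[i] for i in order], axis=seq_dim)


x = np.arange(8)[None, :]               # batch of 1, sequence length 8
y = reorder_sequence_sketch(x, cp_size=2)
print(y)                                # [[0 1 6 7 2 3 4 5]]
print(reorder_sequence_sketch(y, cp_size=2, to_contiguous=True))  # [[0 1 2 3 4 5 6 7]]
```

Splitting into equal chunks and concatenating them in a permuted order keeps the shape unchanged, and applying the inverse permutation (via np.argsort) makes the two directions exact round trips.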