ejkernel.xla_utils.shardings
Sharding utilities for distributed JAX computation.
This module provides utilities for managing array shardings across distributed devices, with automatic correction of partition specifications based on array shapes and mesh configurations.
- Key Functions:
get_corrected_named_sharding: Create valid shardings based on shape/mesh constraints
reorder_sequence: Reorder sequence dimensions for ring attention patterns
- Sharding Correction:
The get_corrected_named_sharding function automatically adjusts PartitionSpecs to ensure validity based on:
- Axis names present in the current mesh
- Divisibility of array dimensions by mesh axis sizes
- Proper handling of multi-axis sharding
- Ring Attention Reordering:
The reorder_sequence function rearranges sequence dimensions to enable efficient ring attention communication patterns, alternating between forward and backward sequence chunks.
Example
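>>> import jax
>>> import numpy as np
>>> from ejkernel.xla_utils import get_corrected_named_sharding
>>> from jax.sharding import PartitionSpec, Mesh
>>>
>>> # Assumed layout for illustration: all local devices on 'dp', a size-1 'mp' axis
>>> devices = np.array(jax.devices()).reshape(-1, 1)
>>> mesh = Mesh(devices, ('dp', 'mp'))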
>>> shape = (8, 1024, 512)
>>> spec = PartitionSpec('dp', None, 'mp')
>>> sharding = get_corrected_named_sharding(shape, spec, mesh)
- ejkernel.xla_utils.shardings.get_corrected_named_sharding(shape: tuple[int, ...], partition_spec: PartitionSpec, mesh: Mesh) → NamedSharding
Calculates the corrected PartitionSpec based on shape and mesh and returns it wrapped in a NamedSharding.
This function takes an array shape and a desired PartitionSpec. It determines the effective PartitionSpec by correcting the input based on:
Axis names present in the current mesh.
Divisibility of array dimensions by the product of corresponding mesh axis sizes.
It does NOT correct based on mesh axes having size 1, allowing such axes to persist in the spec if explicitly provided and divisibility holds.
- Parameters
shape – The shape of the target JAX array.
partition_spec – The desired PartitionSpec.
mesh – The Mesh whose axis names and axis sizes are used to validate the PartitionSpec.
raise_mesh_error – If True, raises an error if no mesh is active. If False, returns a replicated NamedSharding on an empty mesh if no mesh is found.
- Returns
A NamedSharding object containing the current mesh and the corrected PartitionSpec.
- Raises
AssertionError – If no mesh is active and raise_mesh_error is True.
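The two correction rules described above (unknown axis names, divisibility by the product of mesh axis sizes) can be sketched in plain Python. This is an illustrative sketch, not the library's implementation: correct_spec is a hypothetical helper, the mesh is modeled as a dict of axis sizes, and falling back to None (replication) for a dimension that fails a check is an assumption about how the correction resolves.

```python
from math import prod


def correct_spec(shape, spec, mesh_axis_sizes):
    """Hypothetical helper: return a corrected spec, one entry per dimension.

    shape: array shape as a tuple of ints.
    spec: tuple of None, an axis name, or a tuple of axis names.
    mesh_axis_sizes: dict mapping mesh axis name -> axis size.
    """
    # Pad the spec with None so every array dimension has an entry.
    spec = tuple(spec) + (None,) * (len(shape) - len(spec))
    corrected = []
    for dim_size, entry in zip(shape, spec):
        if entry is None:
            corrected.append(None)
            continue
        # An entry is either a single axis name or a tuple of names.
        names = (entry,) if isinstance(entry, str) else tuple(entry)
        # Rule 1: keep only axis names that exist in the mesh.
        names = tuple(n for n in names if n in mesh_axis_sizes)
        if not names:
            corrected.append(None)
            continue
        # Rule 2: the dimension must be divisible by the product of the
        # sizes of all mesh axes it is sharded over (assumed fallback:
        # replicate the dimension if it is not).
        total = prod(mesh_axis_sizes[n] for n in names)
        if dim_size % total != 0:
            corrected.append(None)
            continue
        # Size-1 axes are deliberately not stripped here.
        corrected.append(names[0] if len(names) == 1 else names)
    return tuple(corrected)
```

For instance, with mesh_axis_sizes = {'dp': 2, 'mp': 4}, the spec ('dp', None, 'mp') is kept unchanged for shape (8, 1024, 512), while for shape (7, 1024, 512) the first entry becomes None because 7 is not divisible by 2.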
- ejkernel.xla_utils.shardings.reorder_sequence(tensor, cp_size: int, seq_dim: int = 1, to_contiguous: bool = False)
Reorder sequence dimension for ring attention communication patterns.
Rearranges the sequence dimension to enable efficient ring attention communication, alternating between forward and backward sequence chunks to minimize communication overhead during context parallel processing.
- Parameters
tensor – Input tensor with a sequence dimension to reorder.
cp_size – Context parallelism size (must be even).
seq_dim – Dimension index of the sequence axis (default: 1).
to_contiguous – If True, reorder for contiguous memory layout; if False, reorder for ring attention pattern.
- Returns
Tensor with reordered sequence dimension for ring communication.
- Raises
ValueError – If cp_size is not even, or if the sequence length is not divisible by 2 * cp_size.
Note
The reordering interleaves forward and backward chunks to enable efficient bidirectional communication in ring attention patterns.
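To make the interleaving concrete, here is a plain-Python sketch on a flat list of sequence positions. It assumes one common ring attention layout: the sequence is split into 2 * cp_size chunks and chunk i is placed next to chunk 2 * cp_size - 1 - i. The helpers ring_chunk_order and reorder_positions are hypothetical, the library's actual permutation may differ, and treating to_contiguous=True as the exact inverse is also an assumption.

```python
def ring_chunk_order(cp_size):
    """Chunk order that pairs chunk i with chunk 2 * cp_size - 1 - i."""
    if cp_size % 2 != 0:
        raise ValueError(f"cp_size={cp_size} must be even")
    order = []
    for i in range(cp_size):
        order.append(i)                    # forward chunk
        order.append(2 * cp_size - 1 - i)  # matching backward chunk
    return order


def reorder_positions(seq, cp_size, to_contiguous=False):
    """Reorder a flat list of sequence positions chunk by chunk."""
    order = ring_chunk_order(cp_size)
    num_chunks = 2 * cp_size
    if len(seq) % num_chunks != 0:
        raise ValueError(f"len(seq)={len(seq)} must be divisible by {num_chunks}")
    if to_contiguous:
        # Invert the permutation so the original order is restored.
        inverse = [0] * num_chunks
        for dst, src in enumerate(order):
            inverse[src] = dst
        order = inverse
    size = len(seq) // num_chunks
    chunks = [seq[k * size:(k + 1) * size] for k in range(num_chunks)]
    # Concatenate the chunks in the chosen order.
    return [x for k in order for x in chunks[k]]
```

With cp_size=2, reorder_positions(list(range(8)), 2) yields [0, 1, 6, 7, 2, 3, 4, 5], and calling it again on that result with to_contiguous=True restores [0, 1, 2, 3, 4, 5, 6, 7]. Under this layout, each consecutive pair of chunks combines one early and one late part of the sequence.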