ejkernel.utils#
Utility functions for the ejKernel library.
This module provides a comprehensive collection of utility functions for kernel development, including mathematical operations, array manipulation, hardware detection, performance testing, and distributed synchronization utilities.
The utilities support both Triton and JAX-based kernel implementations, with a focus on GPU architectures (CDNA, RDNA) and distributed training scenarios.
- ejkernel.utils.arch_supports_fp8()[source]#
Check if current GPU architecture supports FP8 operations.
- Returns
True if running on AMD gfx942 architecture with FP8 support, False otherwise.
- ejkernel.utils.assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-06)[source]#
Assert that two arrays are close within tolerance.
Compares arrays using both absolute and relative error thresholds, with options for warnings vs assertions.
- Parameters
prefix – Message prefix for error reporting.
ref – Reference array for comparison.
tri – Array to test against reference.
ratio – Maximum allowed error ratio.
warning – If True, issue warning instead of assertion on failure.
err_atol – Absolute tolerance threshold. Defaults to 1e-6.
- Raises
AssertionError – If arrays differ beyond tolerance and warning=False.
Example
>>> ref = jnp.ones((10,))
>>> test = ref + 1e-7
>>> assert_close("Test", ref, test, ratio=0.01)
- ejkernel.utils.barrier_sync(timeout: float = 200)[source]#
Synchronize all JAX processes at a barrier point.
Blocks execution until all processes in the distributed JAX runtime reach this barrier. This is essential for ensuring consistency across distributed training, especially before/after collective operations or checkpointing.
The function uses a global counter to create unique barrier names, allowing multiple barriers to be used sequentially without conflicts.
- Parameters
timeout – Maximum time to wait for all processes to reach the barrier, in seconds. Defaults to 200 seconds (3.33 minutes). If the timeout is exceeded, a RuntimeError will be raised by the underlying JAX distributed client.
- Returns
None
- Raises
RuntimeError – If the JAX distributed client is not initialized. This typically means JAX was not started in distributed mode or the distributed runtime failed to initialize.
Note
This function is a no-op when running with a single process (jax.process_count() == 1), allowing code to work seamlessly in both single and multi-process environments.
Each call increments a global counter to ensure unique barrier names, preventing conflicts when multiple barriers are used in sequence.
The timeout is converted to milliseconds for the underlying JAX API.
Example
>>> model = train_step(model, batch)
>>> barrier_sync()
>>> if jax.process_index() == 0:
...     save_checkpoint(model)
>>> barrier_sync()

>>> barrier_sync(timeout=600)
Warning
Ensure all processes call barrier_sync() the same number of times and in the same order, or deadlocks may occur. Conditional barriers based on process rank should be avoided.
- ejkernel.utils.calculate_blocksize_and_wraps(n)[source]#
Calculate optimal block size and number of warps for Triton kernels.
Determines the appropriate block size (as power of 2) and number of warps based on the input size, with architecture-specific adjustments for HIP.
- Parameters
n – Input size to calculate block configuration for.
- Returns
Tuple of (block_size, num_warps) optimized for the input size.
- Raises
RuntimeError – If the required block size exceeds MAX_FUSED_SIZE (65536).
Example
>>> calculate_blocksize_and_wraps(1024)
(1024, 4)
>>> calculate_blocksize_and_wraps(10000)
(16384, 16)
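As a rough illustration of the selection logic, the sketch below rounds the input up to a power of two and picks a warp count from size thresholds. The function name and the thresholds are assumptions, chosen only so that the two documented examples are reproduced; the HIP-specific adjustment mentioned above is left out because its details are not given here.

```python
MAX_FUSED_SIZE = 65536  # limit quoted in the Raises entry above


def blocksize_and_warps_sketch(n: int) -> tuple[int, int]:
    """Pick a power-of-two block size and a warp count for n elements."""
    if n < 1:
        raise ValueError("n must be positive")
    block_size = 1 << (n - 1).bit_length()  # smallest power of two >= n
    if block_size > MAX_FUSED_SIZE:
        raise RuntimeError(f"block size {block_size} exceeds {MAX_FUSED_SIZE}")
    # Assumed thresholds, not taken from the ejKernel source.
    if block_size >= 32768:
        num_warps = 32
    elif block_size >= 8192:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8
    else:
        num_warps = 4
    return block_size, num_warps


print(blocksize_and_warps_sketch(1024))
print(blocksize_and_warps_sketch(10000))
```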
- ejkernel.utils.cdiv(a: int, b: int) int[source]#
- ejkernel.utils.cdiv(a: int, b: Array) Array
- ejkernel.utils.cdiv(a: Array, b: int) Array
- ejkernel.utils.cdiv(a: Array, b: Array) Array
Ceiling division operation.
Computes the ceiling division of a by b, which is equivalent to (a + b - 1) // b.
- Parameters
a – Dividend, can be an integer or a JAX array.
b – Divisor, can be an integer or a JAX array.
- Returns
The ceiling of a / b: an int when both inputs are ints, otherwise a JAX array.
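The identity quoted above is easy to check on plain integers. The sketch below uses a hypothetical name, `cdiv_sketch`, handles only Python ints with a positive divisor, and does not cover the array overloads.

```python
def cdiv_sketch(a: int, b: int) -> int:
    """Ceiling division for positive b, using the (a + b - 1) // b identity."""
    return (a + b - 1) // b


# Number of 128-element blocks needed to cover 1000 elements.
num_blocks = cdiv_sketch(1000, 128)
print(num_blocks)
```

A typical use is sizing a kernel launch grid, where a partially filled final block still needs to be launched.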
- ejkernel.utils.dtype_index(x: array) int[source]#
Get numeric index for array dtype.
Maps JAX array dtypes to numeric indices for use in kernel dispatch and configuration.
- Parameters
x – JAX array whose dtype to index.
- Returns
Integer index corresponding to the dtype: 1 for float16, 2 for bfloat16, 3 for float32.
- Raises
ValueError – If the dtype is not supported.
- ejkernel.utils.generate_block_indices(batch: int, num_query_blocks: int, heads: int, selected_blocks: int, block_size: int, seed: int = 42) Array[source]#
Generate random block indices for sparse attention benchmarks.
This function generates a tensor of block indices where each token attends to a random selection of previous key blocks. The indices are sorted in ascending order. Returns per-token format: each token in a query block gets the same block indices.
- Parameters
batch – Batch size.
num_query_blocks – Number of query blocks.
heads – Number of attention heads (typically kv_heads for GQA).
selected_blocks – Number of key blocks each query block should attend to.
block_size – Size of each block.
seed – Random seed for reproducibility.
- Returns
Array of shape (batch, seq_len, heads, selected_blocks), where seq_len = num_query_blocks * block_size, containing sorted block indices in per-token format. Positions beyond available blocks are filled with -1.
Example
>>> indices = generate_block_indices(batch=2, num_query_blocks=4, heads=8, selected_blocks=2, block_size=64)
>>> indices.shape
(2, 256, 8, 2)
- ejkernel.utils.get_abs_err(x, y)[source]#
Calculate maximum absolute error between two arrays.
- Parameters
x – First array.
y – Second array.
- Returns
Maximum absolute difference between the arrays.
- ejkernel.utils.get_err_ratio(x, y)[source]#
Calculate relative error ratio between two arrays.
Computes the root mean square error normalized by the RMS of the reference.
- Parameters
x – Reference array.
y – Array to compare against reference.
- Returns
Relative error ratio (RMSE / RMS of reference).
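The two error metrics above can be written out on plain Python lists. This is a sketch with hypothetical names (`abs_err_sketch`, `err_ratio_sketch`), not the library code: it assumes equal-length, non-empty inputs and would divide by zero for an all-zero reference, a case the library may treat differently.

```python
import math


def abs_err_sketch(x, y):
    """Maximum absolute element-wise difference."""
    return max(abs(a - b) for a, b in zip(x, y))


def err_ratio_sketch(ref, test):
    """RMSE between ref and test, divided by the RMS of ref."""
    n = len(ref)
    rmse = math.sqrt(sum((a - b) ** 2 for a, b in zip(ref, test)) / n)
    rms_ref = math.sqrt(sum(a * a for a in ref) / n)
    return rmse / rms_ref


ref = [3.0, 4.0]
test = [3.0, 4.5]
print(abs_err_sketch(ref, test), err_ratio_sketch(ref, test))
```

The absolute error reports the single worst element, while the ratio is scale-independent, which is why a comparison helper such as assert_close takes both an absolute tolerance and a ratio.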
- ejkernel.utils.get_input_shapes()[source]#
Generate test input shapes for benchmarking and testing.
Creates a list of input shape configurations with varying batch sizes and sequence lengths for comprehensive kernel testing.
- Returns
List of tuples containing (batch, ?, seq_len, ?, ?, ?) dimensions for testing various input configurations.
- ejkernel.utils.get_padded_headsize(size)[source]#
Calculate padded head size for optimal memory alignment.
Rounds up the head size to the next power of 2 with a minimum of 16 for better memory access patterns in attention kernels.
- Parameters
size – Original head size.
- Returns
Padded head size as the next power of 2, minimum 16.
Example
>>> get_padded_headsize(13)
16
>>> get_padded_headsize(20)
32
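The rounding rule stated above (next power of 2, minimum 16) can be sketched in a few lines. `padded_headsize_sketch` is a hypothetical name, and the code follows only the rule as documented rather than the library source.

```python
def padded_headsize_sketch(size: int) -> int:
    """Round size up to a power of two, never going below 16."""
    next_pow2 = 1 << max(size - 1, 0).bit_length()
    return max(next_pow2, 16)


for head_size in (13, 20, 64, 80):
    print(head_size, padded_headsize_sketch(head_size))
```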
- ejkernel.utils.get_qkv_shardings(layout: Literal['bhsd', 'bshd', 'thd'])[source]#
Get sharding specifications for attention tensors based on layout.
Returns PartitionSpecs for queries, keys, and values that are compatible with the given tensor layout and a standard (dp, fsdp, tp, sp) mesh.
- Parameters
layout – Tensor layout format:
- “bhsd”: [batch, heads, seq, dim]
- “bshd”: [batch, seq, heads, dim]
- “thd”: [tokens, heads, dim] for packed sequences
- Returns
Tuple of 6 PartitionSpecs (q_spec, k_spec, v_spec, sq_spec, sk_spec, sv_spec), where the ‘s’ prefix indicates sequence-parallel variants.
- Raises
ValueError – If layout is not one of the supported formats.
- ejkernel.utils.get_segments_shardings()[source]#
Get sharding specifications for segment ID tensors.
Returns PartitionSpecs for query and key/value segment IDs, compatible with a standard (dp, fsdp, tp, sp) mesh.
- Returns
Tuple of 4 PartitionSpecs (q_spec, kv_spec, sq_spec, skv_spec), where the ‘s’ prefix indicates sequence-parallel variants.
- ejkernel.utils.get_sharding(arr: Array)[source]#
Gets the sharding of an array.
- Parameters
arr – Array to get sharding from.
- Returns
Sharding of the array.
- ejkernel.utils.get_stride(shape: tuple[int, ...] | jax.jaxlib._jax.Array, index=0) int[source]#
Get the stride at a specific dimension index.
- Parameters
shape – Shape of the array or the array itself.
index – The dimension index to get the stride for. Defaults to 0.
- Returns
The stride value at the specified index.
- ejkernel.utils.get_strides(shape: tuple[int, ...] | jax.jaxlib._jax.Array) tuple[int, ...][source]#
Calculates strides for a given shape.
- Parameters
shape – Shape of the array.
- Returns
Tuple of strides.
- ejkernel.utils.get_tpu_generation() int[source]#
Returns the TPU generation as an integer (e.g., 3, 4, 5). Returns 0 if no TPU is detected or if the generation cannot be determined.
- ejkernel.utils.is_cdna()[source]#
Check if running on AMD CDNA architecture.
CDNA (Compute DNA) architectures include MI100, MI200 series GPUs.
- Returns
True if running on CDNA architecture (gfx940, gfx941, etc.), False otherwise.
- ejkernel.utils.is_fp8(x)[source]#
Check if an array uses FP8 dtype and if hardware supports it.
- Parameters
x – Array to check for FP8 dtype.
- Returns
True if array is FP8 and hardware supports it, False if not FP8.
- Raises
RuntimeError – If array is FP8 but hardware doesn’t support it.
- ejkernel.utils.is_hip()[source]#
Check if running on AMD HIP backend.
- Returns
True if the current Triton target uses HIP backend, False otherwise.
- ejkernel.utils.is_rdna()[source]#
Check if running on AMD RDNA architecture.
RDNA (Radeon DNA) architectures include RX 6000, 7000 series GPUs.
- Returns
True if running on RDNA architecture (gfx1030, gfx1100, etc.), False otherwise.
- ejkernel.utils.kw_strides(x: jax.jaxlib._jax.Array | None, *stride_names: str)[source]#
Generate stride keyword arguments for kernel calls.
Creates a dictionary mapping stride names to their corresponding values for use as keyword arguments in kernel invocations.
- Parameters
x – JAX array to get strides from, or None for zero strides.
*stride_names – Names for each dimension’s stride.
- Returns
Dictionary mapping “stride_{name}” to stride values.
Example
>>> arr = jnp.ones((2, 3, 4))
>>> kw_strides(arr, 'batch', 'seq', 'head')
{'stride_batch': 12, 'stride_seq': 4, 'stride_head': 1}
- ejkernel.utils.make_dummy_rpa_inputs(*, rng_seed: int = 0, num_seqs: int = 4, pages_per_seq: int = 3, page_size: int = 16, num_q_heads: int = 8, num_kv_heads: int = 2, head_dim: int = 80, kv_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, q_dtype: numpy.dtype | None = None, kv_len_max: int | None = None, total_q: int | None = None, total_num_pages: int | None = None, decode_prefill_mixed: tuple[int, int, int] | None = None)[source]#
- Returns a dict with:
queries: (sum_q, num_q_heads, head_dim) [q_dtype]
keys, values: (sum_q, num_kv_heads, head_dim) [kv_dtype]
kv_cache: (total_pages, page_size, x2_per_pack, pack, align(head_dim,128)) [kv_dtype]
kv_lens: (num_seqs,) [int32]
block_tables: (num_seqs * pages_per_seq,) [int32]
query_start_loc: (num_seqs + 1,) [int32]
distribution: (3,) [int32]
All constraints required by the kernel/validators are satisfied.
- Example (matches the large benchmark shapes discussed in ragged_page_attention_v3):
total_q=1024, num_q_heads=8, head_dim=128
num_kv_heads=4, kv_dtype=jnp.bfloat16 (packing=2)
page_size=64, pages_per_seq=16, total_num_pages=2**17
- ejkernel.utils.make_mesh(mesh_axis: tuple[int, int, int, int])[source]#
Create a JAX mesh with standard sharding axes.
Creates a device mesh with axes named for data parallelism (dp), fully-sharded data parallelism (fsdp), tensor parallelism (tp), and sequence parallelism (sp).
- Parameters
mesh_axis – Tuple of (dp, fsdp, tp, sp) axis sizes.
- Returns
JAX Mesh with named axes (“dp”, “fsdp”, “tp”, “sp”).
Example
>>> mesh = make_mesh((2, 1, 4, 1)) # 2 data parallel, 4 tensor parallel
- ejkernel.utils.narrow(x, dim: int, start: int, length: int)[source]#
Narrow a tensor along a specific dimension.
Extracts a contiguous slice of size length, starting at index start, along the specified dimension, similar to PyTorch’s narrow operation.
- Parameters
x – Input array to narrow.
dim – Dimension along which to narrow.
start – Starting index of the slice.
length – Length of the slice to extract.
- Returns
Narrowed array with reduced size along the specified dimension.
Example
>>> x = jnp.arange(20).reshape(4, 5)
>>> narrow(x, dim=1, start=1, length=3).shape
(4, 3)
- ejkernel.utils.next_power_of_2(x: int) int[source]#
Returns the next power of two greater than or equal to x.
- Parameters
x – A non-negative integer.
- Returns
The smallest power of 2 greater than or equal to x.
- Raises
ValueError – If x is negative.
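One common way to compute this for Python ints relies on `int.bit_length`. The sketch below uses a hypothetical name and assumes that an input of 0 maps to 1 (that is, 2**0); the library's behaviour for 0 is not stated above.

```python
def next_power_of_2_sketch(x: int) -> int:
    """Smallest power of two that is greater than or equal to x."""
    if x < 0:
        raise ValueError("x must be non-negative")
    if x == 0:
        return 1  # assumption: 2**0 is the smallest power of two >= 0
    return 1 << (x - 1).bit_length()


for value in (1, 5, 8, 1000):
    print(value, next_power_of_2_sketch(value))
```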
- ejkernel.utils.numeric_gen(*shape, dtype: str | numpy.dtype = <class 'jax.numpy.float16'>, method: str = 'normal')[source]#
Generate random numeric arrays for testing and debugging.
Creates random arrays using JAX’s random number generation with a global debug RNG state for reproducibility.
- Parameters
*shape – Dimensions of the array to generate.
dtype – Data type of the generated array. Defaults to float16.
method – Random generation method from jax.random. Defaults to “normal”.
- Returns
Random JAX array with specified shape and dtype.
- Raises
AssertionError – If the specified method is not available in jax.random.
Example
>>> arr = numeric_gen(2, 3, 4, dtype=jnp.float32, method="uniform")
>>> arr.shape
(2, 3, 4)
- ejkernel.utils.random_dense(*shape, dtype: str | numpy.dtype = <class 'jax.numpy.float16'>, limit: int | None = 1) Array[source]#
Generate a random dense array with uniform distribution.
Creates a random array with values uniformly distributed in [-limit, limit], optionally casting through bfloat16 for numerical stability.
- Parameters
*shape – Dimensions of the array to generate.
dtype – Output data type. Defaults to float16.
limit – Maximum absolute value. If None, defaults to 1/prod(shape).
- Returns
Random JAX array with specified shape and dtype.
Example
>>> arr = random_dense(2, 3, 4, dtype=jnp.float32)
>>> arr.shape
(2, 3, 4)
- ejkernel.utils.safe_autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, warmup=None, rep=None) Callable[[F], F][source]#
Safely apply Triton autotuning with fallback on failure.
Wraps a function with Triton’s autotuning capability, gracefully falling back to the original function if autotuning fails. This ensures kernel execution continues even if autotuning encounters issues.
- Parameters
configs – List of triton.Config objects to test during autotuning.
key – List of argument names whose values define the autotuning key.
prune_configs_by – Optional dict mapping metric names to pruning functions.
reset_to_zero – List of argument names to reset to zero between runs.
restore_value – List of argument names to restore after autotuning.
pre_hook – Optional function to call before each autotuning run.
post_hook – Optional function to call after each autotuning run.
warmup – Number of warmup runs before measuring performance.
rep – Number of repetitions for each configuration.
- Returns
A decorator that applies autotuning to the wrapped function.
Example
>>> @safe_autotune(
...     configs=[triton.Config({'BLOCK_SIZE': 128})],
...     key=['n_elements']
... )
... def kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
...     pass
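The fallback behaviour can be illustrated without Triton. In the sketch below, `safe_decorate` and `broken_tuner` are hypothetical names: a decorator is applied inside a try block and the undecorated function is returned if that fails. It covers only failures at decoration time; whether safe_autotune also guards failures at call time is not stated above.

```python
import warnings


def safe_decorate(decorator):
    """Wrap decorator so that a failure while applying it returns fn unchanged."""
    def wrapper(fn):
        try:
            return decorator(fn)
        except Exception as exc:  # broad on purpose: any failure triggers the fallback
            warnings.warn(f"decoration failed, using original function: {exc}")
            return fn
    return wrapper


def broken_tuner(fn):
    """Stand-in for a tuner that cannot be set up on this machine."""
    raise RuntimeError("no tuning backend available")


@safe_decorate(broken_tuner)
def add_one(x):
    return x + 1


print(add_one(1))
```

The trade-off of this pattern is that a misconfigured tuner degrades performance silently unless the warning is noticed, so logging the failure is worth keeping.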
- ejkernel.utils.strides_from_shape(shape: tuple[int, ...]) tuple[int, ...][source]#
Calculate the strides for a contiguous array with the given shape.
- Parameters
shape – A tuple of integers representing the dimensions of an array.
- Returns
A tuple of integers representing the strides of a contiguous array.
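For a contiguous row-major array, the stride of each dimension is the product of all dimension sizes to its right. The sketch below (hypothetical name `contiguous_strides`) assumes strides are counted in elements rather than bytes, which agrees with the kw_strides example on this page, where shape (2, 3, 4) gives strides 12, 4, 1.

```python
def contiguous_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Row-major (C-order) strides, measured in elements."""
    strides = []
    running = 1
    # Walk from the innermost dimension outwards, accumulating the product.
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


print(contiguous_strides((2, 3, 4)))
```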