ejkernel.ops.execution.batch

Batch processing utilities for vectorized and parallel execution.

This module provides utilities for efficiently executing kernels over batched data using JAX’s vmap and pmap transformations while maintaining the benefits of automatic configuration selection.

Key Functions:
  • vmap_with_config: Vectorized execution with shared configuration selection

  • pmap_with_config: Parallel execution across devices with shared configuration selection

The Challenge:

JAX’s vmap and pmap transformations apply functions element-wise across batch dimensions. However, configuration selection can be expensive and shouldn’t be repeated for every element in a batch.

The Solution:

These utilities select a configuration once using a representative sample (typically the first element), then apply that configuration to all elements in the batch using the appropriate JAX transformation.
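A minimal sketch of the select-once-then-vmap pattern in plain JAX. It assumes a hypothetical select_config function and a kernel that accepts a config keyword argument. Neither name is part of ejkernel's API (the documented functions take an executor instead), and vmap_with_shared_config is an illustrative name only. Selection runs in ordinary Python before jax.vmap is applied, so it happens once per call regardless of batch size.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for configuration selection (e.g. an autotuner).
# It records every call so we can see how often selection happens.
selection_calls = []

def select_config(sample):
    selection_calls.append(sample.shape)
    # Toy rule: pick a block size from the per-element problem size.
    return {"block": 128 if sample.shape[-1] >= 128 else 32}

def kernel(x, *, config):
    # Toy kernel: the config drives a Python-level choice at trace time.
    return x * 2.0 if config["block"] == 128 else x + 1.0

def vmap_with_shared_config(kernel, select_config):
    """Hypothetical sketch, not ejkernel's implementation."""
    def wrapped(batch):
        sample = batch[0]               # representative sample: first element
        config = select_config(sample)  # selected once, outside vmap
        return jax.vmap(lambda x: kernel(x, config=config))(batch)
    return wrapped

batch = jnp.ones((8, 256))
out = vmap_with_shared_config(kernel, select_config)(batch)
```

After the call, selection_calls holds a single entry even though the batch has eight elements.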

Benefits:
  • Amortized configuration selection cost across batch elements

  • Consistent configuration for all elements in a batch

  • Full compatibility with JAX transformations

  • Automatic handling of different input axis specifications

  • Support for both single-device vectorization (vmap) and multi-device parallelism (pmap)

Example Usage:
>>> # Vectorize over axis 0 of every argument
>>> vmapped_fn = vmap_with_config(executor, kernel, in_axes=0)
>>> batch_result = vmapped_fn(batch_input)
>>> # Run in parallel across devices (leading axis is the device axis)
>>> pmapped_fn = pmap_with_config(executor, kernel, in_axes=0)
>>> device_result = pmapped_fn(device_sharded_input)

Note

Configuration selection uses the first element along the specified axis as a representative sample. This assumes that the optimal configuration is the same for every element in the batch, which typically holds for homogeneous data.

ejkernel.ops.execution.batch.pmap_with_config(executor, kernel, in_axes=0, axis_name='devices')

Parallel execution across devices with shared configuration selection.

Creates a parallel version of kernel execution for multi-device computation where configuration is selected once using data from the first device, then applied to all devices via jax.pmap.

This enables efficient multi-device execution while avoiding redundant configuration selection on each device.
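A sketch of the same idea for pmap, assuming the leading axis of the input is the device axis. select_config, kernel, and pmap_with_shared_config are hypothetical names, not ejkernel APIs. The configuration is chosen from device 0's slice in Python, then closed over by the function that jax.pmap replicates. It sizes the input with jax.local_device_count(), so it also runs on a single CPU device.

```python
import jax
import jax.numpy as jnp

def select_config(sample):
    # Hypothetical selector: choose a scale factor from the per-device shape.
    return {"scale": 2.0 if sample.shape[0] % 2 == 0 else 1.0}

def kernel(x, *, config):
    return x * config["scale"]

def pmap_with_shared_config(kernel, select_config, axis_name="devices"):
    """Hypothetical sketch, not ejkernel's implementation."""
    def wrapped(sharded):
        # Index 0 along the device axis is device 0's slice.
        config = select_config(sharded[0])
        return jax.pmap(lambda x: kernel(x, config=config),
                        axis_name=axis_name)(sharded)
    return wrapped

n = jax.local_device_count()
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
y = pmap_with_shared_config(kernel, select_config)(x)
```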

Parameters
  • executor – Executor instance for running the kernel

  • kernel – Kernel to execute in parallel across devices

  • in_axes – Input axes specification for pmap (default: 0):
      - int: Same axis for all arguments (typically the device axis)
      - tuple/list: Per-argument axis specification
      - None: Broadcast the argument to all devices

  • axis_name – Name for the parallel axis (default: 'devices'). Used for collective operations and debugging.
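The in_axes and axis_name parameters appear to be forwarded to jax.pmap, so their semantics can be illustrated with plain jax.pmap (this snippet does not call ejkernel). It broadcasts a weight vector with in_axes=None and uses the axis name in a jax.lax.psum collective. per_device is an illustrative function name.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

def per_device(x, w):
    # x: this device's shard; w: the same array on every device (in_axes=None).
    partial = jnp.sum(x * w)
    # axis_name lets collectives refer to the mapped device axis.
    return jax.lax.psum(partial, axis_name="devices")

f = jax.pmap(per_device, in_axes=(0, None), axis_name="devices")
x = jnp.ones((n, 3))
w = jnp.array([1.0, 2.0, 3.0])
total = f(x, w)  # every device receives the sum over all devices
```

Each device computes a partial sum of 6.0, so every entry of total equals 6.0 times the number of devices.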

Returns

Function that performs parallel execution with shared config selection

Example

>>> cache = ConfigCache()
>>> selector = ConfigSelectorChain(cache)
>>> executor = Executor(selector)
>>> pmapped_matmul = pmap_with_config(executor, matmul_kernel, in_axes=0)
>>> # One shard per device; this example assumes exactly four devices
>>> devices = jax.devices()
>>> x_sharded = jax.device_put_sharded([x1, x2, x3, x4], devices)
>>> y_sharded = jax.device_put_sharded([y1, y2, y3, y4], devices)
>>> result_sharded = pmapped_matmul(x_sharded, y_sharded)

Note

Configuration selection uses data from the first device (index 0) as a representative sample. This assumes optimal configuration is consistent across devices, which is typically true for homogeneous hardware setups.

ejkernel.ops.execution.batch.vmap_with_config(executor, kernel, in_axes=0) → Callable[..., Any]

Vectorized execution with shared configuration selection.

Creates a vectorized version of kernel execution where configuration is selected once using a representative sample, then applied to all elements in the batch via jax.vmap.

This avoids the overhead of repeating configuration selection for every batch element: the configuration chosen for the sample is reused for all elements, which performs well as long as that configuration suits the whole batch.
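The amortization can be illustrated by counting selector calls. This is a toy sketch: select_config and kernel are hypothetical stand-ins rather than ejkernel APIs, and the selector only counts how often it runs.

```python
import jax
import jax.numpy as jnp

calls = {"n": 0}

def select_config(sample):
    # Hypothetical selector; counts invocations to show amortization.
    calls["n"] += 1
    return {"offset": 1.0}

def kernel(x, *, config):
    return x + config["offset"]

batch = jnp.zeros((16, 4))

# Per-element selection: one selection per batch element.
calls["n"] = 0
per_element = jnp.stack([kernel(x, config=select_config(x)) for x in batch])
per_element_calls = calls["n"]

# Shared selection: one selection, then a single vmapped call.
calls["n"] = 0
cfg = select_config(batch[0])
shared = jax.vmap(lambda x: kernel(x, config=cfg))(batch)
shared_calls = calls["n"]
```

With 16 elements, the per-element path invokes the selector 16 times and the shared path invokes it once; both produce the same values.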

Parameters
  • executor – Executor instance for running the kernel

  • kernel – Kernel to execute over each element of the batch

  • in_axes – Input axes specification for vmap (default: 0):
      - int: Same axis for all arguments
      - tuple/list: Per-argument axis specification
      - None: Broadcast the argument (no vectorization)
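One way to build a representative sample from each in_axes form is to take index 0 along every mapped axis and pass broadcast arguments through unchanged. representative_sample is a hypothetical helper, not part of ejkernel, and it handles only flat argument tuples, whereas jax.vmap also accepts nested in_axes containers.

```python
import jax.numpy as jnp

def representative_sample(args, in_axes=0):
    """Hypothetical helper: slice index 0 along each argument's mapped axis."""
    if not isinstance(in_axes, (tuple, list)):
        in_axes = (in_axes,) * len(args)  # int or None applies to all args
    sample = []
    for arg, axis in zip(args, in_axes):
        if axis is None:
            sample.append(arg)  # broadcast argument: used as-is
        else:
            sample.append(jnp.take(arg, 0, axis=axis))
    return tuple(sample)

x = jnp.ones((8, 32, 64))  # batched along axis 0
w = jnp.ones((64, 16))     # shared weights, not batched
b = jnp.ones((16, 5))      # batched along axis 1
sx, sw, sb = representative_sample((x, w, b), in_axes=(0, None, 1))
```

Here sx has shape (32, 64), sw keeps its full shape (64, 16), and sb has shape (16,).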

Returns

Function that performs vectorized execution with shared config selection

Example

>>> cache = ConfigCache()
>>> selector = ConfigSelectorChain(cache)
>>> executor = Executor(selector)
>>> vmapped_matmul = vmap_with_config(executor, matmul_kernel, in_axes=0)
>>> batch_x = jnp.array([x1, x2, x3, ...])
>>> batch_y = jnp.array([y1, y2, y3, ...])
>>> batch_result = vmapped_matmul(batch_x, batch_y)

Note

The representative sample is obtained by taking the first element along each specified axis. This sample is used only to drive configuration selection; it is not computed as an extra batch entry, and the vmapped call still processes every element, including the first.
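A small check in plain JAX, independent of ejkernel: the sample is simply index 0 along the mapped axis, and the vmapped output still contains that element's result. matmul is an illustrative function standing in for a real kernel.

```python
import jax
import jax.numpy as jnp

def matmul(x, y):
    return x @ y

batch_x = jnp.arange(3 * 2 * 2, dtype=jnp.float32).reshape(3, 2, 2)
batch_y = jnp.ones((3, 2, 2))

# The representative sample is just the first slice along axis 0 ...
sample_x, sample_y = batch_x[0], batch_y[0]

# ... and the vmapped call still computes every element, including index 0.
batch_result = jax.vmap(matmul, in_axes=0)(batch_x, batch_y)
```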