ejkernel.ops.execution.offline


Offline autotuning for lowered JAX computations.

This module provides utilities for autotuning kernel configurations based on lowered JAX computations, that is, functions that have been traced and lowered to StableHLO but not yet compiled or executed. Tuning against the lowered program targets the operations that will actually run, instead of relying on heuristics.

Key Functions:

autotune_lowered: Autotune all ejkernel operations found in a lowered computation

The offline tuning workflow:
  1. Lower a JAX function to obtain its HLO representation

  2. Extract ejkernel operation labels from the HLO code

  3. Match labels to recorded invocations in the registry

  4. Run autotuning for each matched operation

  5. Store the optimal configurations in the selector's cache
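The workflow above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not ejkernel's implementation: the label pattern (`ejkernel.<op>` appearing somewhere in the HLO text), the dict-based registry, and the `tune` callable are hypothetical stand-ins, and step 1 is represented by a hard-coded HLO string. With a real `jax.stages.Lowered` object, the text could be obtained with `lowered.as_text()`.

```python
import re
from typing import Any, Callable, Dict, List

# Hypothetical label pattern: assume each ejkernel op leaves a marker such as
# `ejkernel.flash_attention` in the HLO text (e.g. inside op_name metadata).
LABEL_RE = re.compile(r"ejkernel\.([A-Za-z_][A-Za-z0-9_]*)")


def extract_labels(hlo_text: str) -> List[str]:
    """Step 2: collect unique op labels in order of first appearance."""
    seen: Dict[str, None] = {}
    for match in LABEL_RE.finditer(hlo_text):
        seen.setdefault(match.group(1), None)
    return list(seen)


def autotune_from_hlo(
    hlo_text: str,
    registry: Dict[str, Any],         # hypothetical: label -> recorded invocation
    tune: Callable[[str, Any], Any],  # hypothetical: returns the best config
    cache: Dict[str, Any],
) -> Dict[str, Any]:
    """Steps 2-5 of the offline workflow, with toy stand-ins for each piece."""
    results: Dict[str, Any] = {}
    for label in extract_labels(hlo_text):     # step 2: find labels
        invocation = registry.get(label)       # step 3: match against registry
        if invocation is None:
            continue                           # unrecorded ops are skipped
        if label not in cache:
            cache[label] = tune(label, invocation)  # step 4: autotune
        results[label] = cache[label]          # step 5: config kept in cache
    return results


# Usage with a fake HLO snippet standing in for step 1.
hlo = """
  %0 = stablehlo.custom_call @k0(%a) {op_name = "jit(f)/ejkernel.flash_attention"}
  %1 = stablehlo.custom_call @k1(%0) {op_name = "jit(f)/ejkernel.rms_norm"}
  %2 = stablehlo.custom_call @k0(%1) {op_name = "jit(f)/ejkernel.flash_attention"}
"""
registry = {"flash_attention": {"shape": (8, 128)}}  # rms_norm was never recorded
cache: Dict[str, Any] = {}
best = autotune_from_hlo(hlo, registry, lambda label, inv: {"block_q": 128}, cache)
print(best)  # {'flash_attention': {'block_q': 128}}
```

This sketch reuses a configuration already present in the cache instead of re-tuning it; that is a design choice of the illustration, and the real module may behave differently.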

Example

>>> import jax
>>> from ejkernel.ops.execution.offline import autotune_lowered
>>> lowered = jax.jit(my_function).lower(example_args)
>>> result = autotune_lowered(selector, lowered)
>>> with result:  # Apply optimized configurations
...     output = jax.jit(my_function)(real_args)
ejkernel.ops.execution.offline.autotune_lowered(selector: ConfigSelectorChain, lowered) → AutotuningResult

Autotune all ejkernel operations found in a lowered JAX computation.

Analyzes a lowered JAX computation to identify all ejkernel operations, then runs autotuning for each operation using recorded invocations from the global registry.

Parameters
  • selector – ConfigSelectorChain providing the tuner used for the search and the cache that stores the results

  • lowered – Lowered JAX computation (for example, the result of jax.jit(fn).lower(*args)) containing ejkernel operations

Returns

AutotuningResult containing the optimal configuration for each tuned operation. It can be used as a context manager to apply these configurations temporarily, only for the duration of the with block.
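The context-manager behaviour described above can be illustrated with a small, self-contained sketch. The class `AutotuningResultSketch`, the `active_config` helper, and the `ContextVar`-based override store are hypothetical; they show one plausible way to apply configurations temporarily and are not ejkernel's actual `AutotuningResult` implementation.

```python
from __future__ import annotations

from contextvars import ContextVar, Token
from typing import Any, Dict, List, Optional

# Hypothetical store for the configuration overrides that are currently applied.
_OVERRIDES: ContextVar[Dict[str, Any]] = ContextVar("_OVERRIDES", default={})


def active_config(op_name: str) -> Optional[Any]:
    """Return the override for `op_name` if one is currently applied."""
    return _OVERRIDES.get().get(op_name)


class AutotuningResultSketch:
    """Hypothetical stand-in: holds tuned configs and applies them inside `with`."""

    def __init__(self, configs: Dict[str, Any]):
        self.configs = dict(configs)
        self._tokens: List[Token] = []  # one token per active `with` entry

    def __enter__(self) -> AutotuningResultSketch:
        # Layer these configs on top of any overrides that are already active.
        merged = {**_OVERRIDES.get(), **self.configs}
        self._tokens.append(_OVERRIDES.set(merged))
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Restore the previous overrides, even if the block raised.
        _OVERRIDES.reset(self._tokens.pop())


result = AutotuningResultSketch({"flash_attention": {"block_q": 128}})
print(active_config("flash_attention"))      # None
with result:
    print(active_config("flash_attention"))  # {'block_q': 128}
print(active_config("flash_attention"))      # None
```

Using a `ContextVar` keeps the overrides scoped to the current thread or async task, and the token stack lets nested `with` blocks unwind in the correct order.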

Example

>>> lowered = jax.jit(my_model).lower(example_input)
>>> result = autotune_lowered(selector, lowered)
>>> with result:
...     # Runs with optimized configurations
...     output = jax.jit(my_model)(real_input)

Note

Only operations that have previously been recorded in the invocation registry are tuned. Set the EJKERNEL_OPS_RECORD=1 environment variable to enable invocation recording during initial runs.
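To make the recording requirement concrete, here is a hypothetical sketch of an environment-gated invocation registry. Only the variable name `EJKERNEL_OPS_RECORD` comes from the note above; the `record_invocations` decorator, the `INVOCATION_REGISTRY` dict, and the choice to check the variable on every call are illustrative assumptions, not ejkernel's code. Because the real library may read the variable only once (for example at import time), exporting it in the shell before starting Python is the safer way to enable recording.

```python
import functools
import os
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical registry: op label -> list of recorded (args, kwargs) calls.
INVOCATION_REGISTRY: Dict[str, List[Tuple[tuple, dict]]] = {}


def recording_enabled() -> bool:
    """Assumption: recording is on only when EJKERNEL_OPS_RECORD == "1"."""
    return os.environ.get("EJKERNEL_OPS_RECORD") == "1"


def record_invocations(label: str) -> Callable:
    """Hypothetical decorator that logs calls under `label` while recording is on."""
    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if recording_enabled():
                INVOCATION_REGISTRY.setdefault(label, []).append((args, kwargs))
            return fn(*args, **kwargs)
        return wrapper
    return decorator


@record_invocations("toy_op")
def toy_op(x: float, scale: float = 2.0) -> float:
    return x * scale  # toy body; a real kernel would operate on arrays


os.environ.pop("EJKERNEL_OPS_RECORD", None)
toy_op(2.0)                      # recording off: not stored
os.environ["EJKERNEL_OPS_RECORD"] = "1"
toy_op(3.0, scale=0.5)           # recording on: stored for later tuning
print(INVOCATION_REGISTRY)       # {'toy_op': [((3.0,), {'scale': 0.5})]}
```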