ejkernel.ops.execution.offline
Offline autotuning for pre-compiled JAX computations.
This module provides utilities for autotuning kernel configurations based on lowered JAX computations, that is, functions that have been traced and lowered to HLO but not yet compiled or executed. Tuning is driven by the operations that will actually run, rather than by heuristics.
- Key Functions:
autotune_lowered: Autotune all ejkernel operations found in a lowered computation
- The offline tuning workflow:
1. Lower a JAX function to obtain its HLO representation
2. Extract ejkernel operation labels from the HLO code
3. Match labels to recorded invocations in the registry
4. Run autotuning for each matched operation
5. Store the optimal configurations in the cache
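Steps 2 to 5 above can be illustrated on plain Python data. This is a toy sketch under stated assumptions, not ejkernel's implementation: the label pattern `ejk_op[...]`, the helpers `extract_labels` and `tune_from_text`, the dictionary-based registry and cache, and the cost function are all hypothetical stand-ins chosen for illustration.

```python
import re

# Hypothetical label format. The real marker ejkernel embeds in HLO
# is not documented on this page.
LABEL_RE = re.compile(r"ejk_op\[(\w+)\]")


def extract_labels(hlo_text):
    """Step 2: return the unique labels found, in first-seen order."""
    return list(dict.fromkeys(LABEL_RE.findall(hlo_text)))


def tune_from_text(hlo_text, registry, candidates, cost_fn, cache):
    """Steps 3 to 5: match labels, pick the cheapest config, cache it."""
    for label in extract_labels(hlo_text):
        invocation = registry.get(label)
        if invocation is None:
            continue  # operations without a recorded invocation are skipped
        best = min(candidates, key=lambda cfg: cost_fn(invocation, cfg))
        cache[label] = best
    return cache


# Toy data standing in for the text of a lowered computation
# (for example, the string returned by lowered.as_text()).
fake_hlo = 'call "ejk_op[attn_a1]" call "ejk_op[matmul_b2]" call "ejk_op[attn_a1]"'
registry = {"attn_a1": {"seq_len": 1024}}  # only one op was recorded
candidates = [{"block": 64}, {"block": 128}, {"block": 256}]


def toy_cost(invocation, cfg):
    # Deterministic stand-in for a real benchmark measurement.
    return abs(invocation["seq_len"] / 8 - cfg["block"])


cache = tune_from_text(fake_hlo, registry, candidates, toy_cost, {})
```

In this sketch `matmul_b2` appears in the text but has no recorded invocation, so it is left untuned, which mirrors the behavior described for unrecorded operations.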
Example
>>> lowered = jax.jit(my_function).lower(example_args)
>>> result = autotune_lowered(selector, lowered)
>>> with result:  # Apply optimized configurations
...     output = jax.jit(my_function)(real_args)
- ejkernel.ops.execution.offline.autotune_lowered(selector: ConfigSelectorChain, lowered) -> AutotuningResult
Autotune all ejkernel operations found in a lowered JAX computation.
Analyzes a lowered JAX computation to identify all ejkernel operations, then runs autotuning for each operation using recorded invocations from the global registry.
- Parameters
selector – ConfigSelectorChain with tuner and cache for optimization
lowered – JAX lowered computation containing ejkernel operations
- Returns
AutotuningResult containing optimal configurations for all tuned operations. Can be used as a context manager to temporarily apply the configurations.
Example
>>> lowered = jax.jit(my_model).lower(example_input)
>>> result = autotune_lowered(selector, lowered)
>>> with result:
...     # Runs with optimized configurations
...     output = jax.jit(my_model)(real_input)
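To make "temporarily apply the configurations" concrete, here is a minimal sketch of a context manager that installs overrides on entry and restores the previous state on exit. It is an assumption about the general pattern only: `ToyTuningResult` and `active_configs` are hypothetical names, and the real AutotuningResult may store and apply its configurations differently.

```python
class ToyTuningResult:
    """Hypothetical stand-in for a result usable in a `with` block."""

    def __init__(self, configs, active):
        self.configs = dict(configs)  # tuned configs, keyed by operation label
        self._active = active         # shared mapping consulted at call time
        self._saved = None

    def __enter__(self):
        # Remember the current state, then apply the tuned configurations.
        self._saved = dict(self._active)
        self._active.update(self.configs)
        return self

    def __exit__(self, exc_type, exc, tb):
        # Restore the previous state, even if the body raised.
        self._active.clear()
        self._active.update(self._saved)
        return False  # do not swallow exceptions


active_configs = {"attn_a1": {"block": 64}}
result = ToyTuningResult({"attn_a1": {"block": 128}}, active_configs)

with result:
    inside = dict(active_configs)  # tuned value is visible here
after = dict(active_configs)       # original value is back here
```

Restoring a saved snapshot in `__exit__` keeps the overrides scoped to the block, so code outside the `with` statement sees the configurations it had before.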
Note
Only operations that have previously been recorded in the invocation registry will be tuned. Set the EJKERNEL_OPS_RECORD=1 environment variable to enable invocation recording during initial runs.
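The flag can also be set from inside a script rather than the shell, before the initial run that records invocations. This is a minimal sketch: this page does not say when ejkernel reads the flag, so setting it before importing ejkernel is an assumption, and `recording_enabled` is an illustrative name only.

```python
import os

# Assumption: the flag is read from the process environment, so it is set
# here before any ejkernel module is imported.
os.environ["EJKERNEL_OPS_RECORD"] = "1"

# Sanity check that the current process sees the flag.
recording_enabled = os.environ.get("EJKERNEL_OPS_RECORD") == "1"
```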