ejkernel.ops.execution.tuning#
Autotuning and benchmarking utilities for JAX kernel optimization.
This module provides comprehensive tools for automatic hyperparameter optimization of JAX functions through systematic benchmarking and configuration testing.
- Key Features:
Profiler-based timing with Python-level fallback for accuracy
Parallel compilation and testing of hyperparameter configurations
Statistical timing analysis with outlier removal
Thread-safe caching with configurable size limits
Support for distributed computation with sharding specifications
- Classes:
Measurement: Container for a single performance measurement
AutotuneData: Container for all optimization measurements and results
Autotuner: Core autotuning engine for hyperparameter optimization
Entry: Cache entry for storing optimal configurations
AutotuningResult: Device-specific optimization results container
TimingResult: Statistical timing result with mean and standard deviation
FNAutotuner: Advanced class-based autotuner with profiler integration
- Functions:
autotune: Decorator for automatic hyperparameter optimization
autotune_recorded: Autotune all recorded invocations
benchmark: Simple function benchmarking utility
Example
>>> @autotune(hyperparams={'block_size': [64, 128, 256]})
... def matrix_op(x, y, block_size=128):
...     return compute(x, y, block_size)
>>>
>>> result = matrix_op(x, y)
>>> print(matrix_op.optimal_hyperparams)
- class ejkernel.ops.execution.tuning.AutotuneData(measurements: list[ejkernel.ops.execution.tuning.Measurement])[source]#
Bases: Generic[Cfg]
Container for all optimization measurements and results.
Stores performance measurements for all tested hyperparameter configurations and provides utilities to analyze the results.
- Type Parameters:
Cfg: Configuration type (e.g., dict, dataclass, etc.)
- measurements#
List of all performance measurements taken
- Type
list[ejkernel.ops.execution.tuning.Measurement]
- property fastest_config: Cfg#
Get the configuration with the fastest execution time.
Finds the measurement with the minimum execution time among all recorded measurements and returns its configuration.
- Returns
The configuration that achieved the lowest execution time
- Raises
ValueError – If no measurements are available
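As a rough sketch of the behaviour described above (not the library's source), `fastest_config` can be read as a minimum over the `seconds` field. The `Measurement` and `AutotuneData` classes below are stand-in re-implementations written for illustration, and the error message text is an assumption; the documentation only promises a `ValueError`.

```python
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

Cfg = TypeVar("Cfg")


@dataclass(frozen=True)
class Measurement:
    """Stand-in for tuning.Measurement: one config and its execution time."""
    cfg: Any
    seconds: float


@dataclass
class AutotuneData(Generic[Cfg]):
    """Stand-in for tuning.AutotuneData (illustrative only)."""
    measurements: list

    @property
    def fastest_config(self) -> Cfg:
        if not self.measurements:
            # Message text is an assumption, not taken from the library.
            raise ValueError("no measurements available")
        # The fastest configuration is the one with the smallest time.
        return min(self.measurements, key=lambda m: m.seconds).cfg


data = AutotuneData([
    Measurement({"block_size": 64}, 0.012),
    Measurement({"block_size": 128}, 0.008),
    Measurement({"block_size": 256}, 0.010),
])
best = data.fastest_config
```

Here `best` is the `block_size=128` configuration, since it has the smallest recorded time.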
- measurements: list[ejkernel.ops.execution.tuning.Measurement]#
- class ejkernel.ops.execution.tuning.Autotuner(warmup=1, iters=3)[source]#
Bases: Generic[Cfg]
Core autotuning engine for hyperparameter optimization.
This class provides the fundamental optimization algorithm that tests different configurations and measures their performance to find the optimal hyperparameter settings.
- Type Parameters:
Cfg: Configuration type for hyperparameters
- warmup#
Number of warmup iterations before timing
- iters#
Number of timing iterations for measurement accuracy
- autotune(make_fn, args, kwargs, candidates: Iterable[Cfg]) AutotuneData[Cfg][source]#
Optimize hyperparameters by testing candidate configurations.
Tests each candidate configuration by compiling and timing the function execution, then returns all measurements for analysis.
- Parameters
make_fn – Factory function that creates a function given a config
args – Positional arguments for the function being optimized
kwargs – Keyword arguments for the function being optimized
candidates – Iterable of candidate configurations to test
- Returns
AutotuneData containing all performance measurements
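The `make_fn` factory pattern can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the library's implementation: `autotune_sketch`, the toy `make_fn`, and the `chunk` hyperparameter are hypothetical, and results are returned as `(cfg, seconds)` pairs rather than an `AutotuneData`. With JAX arrays, a real timer would also have to wait for asynchronously dispatched results before stopping the clock.

```python
import time


def autotune_sketch(make_fn, args, kwargs, candidates, warmup=1, iters=3):
    """Time make_fn(cfg)(*args, **kwargs) for every candidate config."""
    results = []
    for cfg in candidates:
        fn = make_fn(cfg)  # build the function specialised for this config
        for _ in range(warmup):
            fn(*args, **kwargs)  # untimed runs absorb one-off setup cost
        start = time.perf_counter()
        for _ in range(iters):
            fn(*args, **kwargs)
        seconds = (time.perf_counter() - start) / iters  # mean time per call
        results.append((cfg, seconds))
    return results


def make_fn(cfg):
    # Hypothetical factory: cfg selects the chunk size of a chunked sum.
    chunk = cfg["chunk"]

    def fn(xs):
        return sum(sum(xs[i:i + chunk]) for i in range(0, len(xs), chunk))

    return fn


candidates = [{"chunk": 8}, {"chunk": 512}]
measurements = autotune_sketch(make_fn, (list(range(10_000)),), {}, candidates)
fastest_cfg, fastest_seconds = min(measurements, key=lambda m: m[1])
```

Separating the factory from the timing loop means each candidate is built once and then timed repeatedly, so construction cost does not pollute the measurement.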
- class ejkernel.ops.execution.tuning.AutotuningResult(device: str, entries: tuple[ejkernel.ops.execution.tuning.Entry, ...])[source]#
Bases: object
Result container for device-specific optimization results.
Stores all optimized configurations for a specific device and provides context manager functionality for temporary cache overlays.
- device#
Device identifier these results apply to
- Type
str
- entries#
Tuple of optimization entries (operation -> config mappings)
- Type
tuple[ejkernel.ops.execution.tuning.Entry, …]
- as_overlay()[source]#
Convert results to cache overlay mapping format.
Creates a dictionary mapping that can be used with the cache overlay system to temporarily apply these optimization results.
- Returns
Dictionary mapping (device, op_id, call_key) tuples to configurations
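The shape of the overlay mapping can be illustrated with stand-in dataclasses. This is a sketch based only on the field names and the return description above; the device string, operation identifiers, and call keys are made-up values.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Entry:
    """Stand-in for tuning.Entry."""
    op_id_v: str
    call_key: str
    cfg: Any


@dataclass(frozen=True)
class AutotuningResult:
    """Stand-in for tuning.AutotuningResult (illustrative only)."""
    device: str
    entries: tuple

    def as_overlay(self) -> dict:
        # One key per entry: (device, op_id, call_key) -> configuration.
        return {(self.device, e.op_id_v, e.call_key): e.cfg for e in self.entries}


result = AutotuningResult(
    device="gpu:0",  # hypothetical device identifier
    entries=(
        Entry("flash_attention@v1", "a1b2", {"block_q": 128}),
        Entry("matmul@v1", "c3d4", {"block_size": 256}),
    ),
)
overlay = result.as_overlay()
```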
- device: str#
- entries: tuple[ejkernel.ops.execution.tuning.Entry, ...]#
- class ejkernel.ops.execution.tuning.Entry(op_id_v: str, call_key: str, cfg: Any)[source]#
Bases: object
Cache entry for storing optimal configurations.
Represents a single cached optimization result with the operation identifier, call signature, and optimal configuration.
- op_id_v#
Operation identifier for the optimized function
- Type
str
- call_key#
Hash key representing the function call signature
- Type
str
- cfg#
The optimal configuration found for this operation
- Type
Any
- call_key: str#
- cfg: Any#
- op_id_v: str#
- class ejkernel.ops.execution.tuning.FNAutotuner(*, allow_fallback_timing: bool = True, profiling_samples: int = 5, must_find_profiler_fraction: float = 0.5, enable_detailed_logging: bool = False, find_optimal_layouts_automatically: bool = False, max_compilation_time_seconds: float = 300.0, timing_warmup_iterations: int = 2, timing_rounds: int = 5, calls_per_round: int = 3, cache_size_limit: int = 1000, profiler_prefix_filter: str = 'jit_', profiler_event_regex: str | None = None, profiler_min_duration_ns: float = 1000.0, profiler_max_events: int | None = 10000, profiler_verbose: bool = False)[source]#
Bases: object
Advanced class-based JAX autotuner with profiler-first timing and Python fallback.
Provides comprehensive hyperparameter optimization for JAX functions using JAX’s native profiling infrastructure when available, with automatic fallback to Python-level timing. Supports parallel compilation, statistical timing analysis, and intelligent caching of optimization results.
- Key Features:
Profiler-based timing for accurate GPU/TPU measurements
Python-level timing fallback when profiler unavailable
Parallel compilation of hyperparameter configurations
Statistical analysis with outlier removal
Thread-safe caching of optimal configurations
Optional automatic memory layout optimization
- The autotuner works by:
Generating all hyperparameter combinations
Compiling each configuration in parallel
Timing execution using profiler or Python fallback
Selecting the configuration with best performance
Caching results for future calls
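The steps above can be sketched in plain Python. This is a simplified illustration, not the FNAutotuner implementation: it skips the parallel compilation step, uses only wall-clock timing, and the helper names (`expand_grid`, `time_config`, `tune_once`), the `chunked_sum` workload, and the string cache key are all hypothetical.

```python
import itertools
import time


def expand_grid(hyperparams):
    """Step 1: Cartesian product of candidate values as a list of config dicts."""
    names = list(hyperparams)
    value_lists = [hyperparams[name] for name in names]
    return [dict(zip(names, combo)) for combo in itertools.product(*value_lists)]


def time_config(fn, args, config, rounds=3):
    """Step 3 (Python-timing variant): best wall-clock time over a few rounds."""
    best = float("inf")
    for _ in range(rounds):
        start = time.perf_counter()
        fn(*args, **config)
        best = min(best, time.perf_counter() - start)
    return best


def tune_once(fn, args, hyperparams, cache, cache_key):
    """Steps 4-5: pick the fastest config and remember it under cache_key."""
    if cache_key in cache:
        return cache[cache_key]  # cached result: no re-timing
    configs = expand_grid(hyperparams)
    best = min(configs, key=lambda cfg: time_config(fn, args, cfg))
    cache[cache_key] = best
    return best


def chunked_sum(xs, chunk=64):
    # Hypothetical workload whose speed depends on the tuned keyword.
    return sum(sum(xs[i:i + chunk]) for i in range(0, len(xs), chunk))


cache = {}
data = list(range(20_000))
grid = {"chunk": [16, 256, 4096]}
first = tune_once(chunked_sum, (data,), grid, cache, "chunked_sum/20000")
second = tune_once(chunked_sum, (data,), grid, cache, "chunked_sum/20000")
```

The second call returns the cached configuration without timing anything, which mirrors the "first call tunes, later calls reuse" behaviour described for the decorator.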
- allow_fallback_timing#
Whether to use Python timing when profiler fails
- profiling_samples#
Number of profiling iterations for statistics
- must_find_profiler_fraction#
Minimum fraction of configs needing profiler results
- enable_detailed_logging#
Enable verbose error logging
- find_optimal_layouts_automatically#
Auto-discover optimal memory layouts
- max_compilation_time_seconds#
Maximum compilation time per config
- timing_warmup_iterations#
Warmup iterations before timing
- timing_rounds#
Number of timing rounds for statistics
- calls_per_round#
Function calls per timing round
- cache_size_limit#
Maximum cached optimization results
- profiler#
Profiler instance for trace capture and analysis
- PREFIX_FN = 'autotune_fn_{}'#
- decorate(fn: Callable[[...], Any], *, hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jaxlib._jax.Device | str | None = None, example_args: tuple[Any, ...] | None = None, example_kws: dict[str, Any] | None = None, sample_num: int = 9223372036854775807, event_filter_regex: str | None = None, timeout: float | None = None, cache_key: str | None = None)[source]#
Create a decorated version of a function with automatic hyperparameter tuning.
Wraps a function so that the first call triggers hyperparameter optimization, and subsequent calls with the same input signature use cached optimal values.
- Parameters
fn – Function to decorate with autotuning capabilities
hyperparams – Dictionary mapping hyperparameter names to lists of candidate values to test during optimization
max_workers – Maximum number of parallel workers for compilation
in_shardings – Input sharding specifications for distributed computation
out_shardings – Output sharding specifications for distributed computation
device – Target device or device string for computation
example_args – Concrete example arguments when function args are abstract
example_kws – Concrete example kwargs when function kwargs are abstract
sample_num – Maximum number of hyperparameter combinations to test
event_filter_regex – Optional regex filter for profiler events
timeout – Optional compilation timeout override
cache_key – Optional custom cache key prefix for disambiguation
- Returns
Decorated function with automatic hyperparameter optimization. The returned function has additional attributes after first execution:
timing_results: List of all timing measurements from optimization
optimal_hyperparams: Dictionary of optimal parameter values
- Raises
TypeError – If fn is not callable
- tune(fn: Callable[[...], Any], *, args: tuple[Any, ...], kwargs: dict[str, Any], hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jaxlib._jax.Device | str | None = None, example_args: tuple[Any, ...] | None = None, example_kws: dict[Any, Any] | None = None, sample_num: int = 9223372036854775807, event_filter_regex: str | None = None, timeout: float | None = None) tuple[collections.abc.Callable[..., Any], dict[str, Any], list[tuple[int, ejkernel.ops.execution.tuning.TimingResult]]][source]#
Tune hyperparameters for a function and return optimal configuration.
Performs comprehensive hyperparameter optimization by testing all candidate configurations in parallel, measuring their performance, and returning the best-performing configuration.
- Parameters
fn – Function to optimize (must accept hyperparameters as keyword args)
args – Positional arguments for the function (can be abstract shapes)
kwargs – Keyword arguments for the function (excluding hyperparameters)
hyperparams – Dictionary mapping hyperparameter names to lists of candidate values to test. Each combination will be evaluated.
max_workers – Maximum number of parallel workers for compilation
in_shardings – Input sharding specifications for distributed computation
out_shardings – Output sharding specifications for distributed computation
device – Target device or device string for computation
example_args – Concrete example arguments when args are abstract
example_kws – Concrete example kwargs when kwargs are abstract
sample_num – Maximum number of hyperparameter combinations to test. If fewer than total combinations, samples randomly.
event_filter_regex – Optional regex filter for profiler events
timeout – Optional compilation timeout override
- Returns
Tuple of (parameterized_fn, optimal_hyperparams, timing_results_sorted):
parameterized_fn: JIT-compiled function with optimal hyperparameters
optimal_hyperparams: Dictionary of optimal hyperparameter values
timing_results_sorted: List of (index, TimingResult) sorted by performance
- Raises
TypeError – If fn is not callable
ValueError – If max_workers <= 0, sample_num < 0, or hyperparameters invalid
RuntimeError – If all hyperparameter configurations fail to compile
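To make the `sample_num` behaviour concrete, here is a hedged sketch of how a candidate list might be capped. The documentation says only that combinations are sampled randomly when `sample_num` is smaller than the grid; uniform sampling without replacement, the `seed` argument, and the name `select_candidates` are assumptions made for this example. The default value, 9223372036854775807, is 2^63 - 1, so any practical grid is tested in full.

```python
import itertools
import random


def select_candidates(hyperparams, sample_num, seed=0):
    """Return the full grid, or a random subset of at most sample_num configs."""
    if sample_num < 0:
        raise ValueError("sample_num must be non-negative")
    names = list(hyperparams)
    grid = [
        dict(zip(names, combo))
        for combo in itertools.product(*(hyperparams[name] for name in names))
    ]
    if sample_num >= len(grid):
        return grid  # nothing to cut: every combination is evaluated
    rng = random.Random(seed)  # seeded only to keep the sketch reproducible
    return rng.sample(grid, sample_num)  # sampling without replacement


space = {"block_size": [64, 128, 256], "num_warps": [4, 8]}
full = select_candidates(space, 9223372036854775807)
subset = select_candidates(space, 4)
```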
- class ejkernel.ops.execution.tuning.Measurement(cfg: Any, seconds: float)[source]#
Bases: object
Container for a single performance measurement.
Stores the configuration and corresponding execution time for a single hyperparameter combination during optimization.
- cfg#
The hyperparameter configuration that was tested
- Type
Any
- seconds#
Execution time in seconds for this configuration
- Type
float
- cfg: Any#
- seconds: float#
- class ejkernel.ops.execution.tuning.TimingResult(hyperparams: dict[Any, Any], t_mean: float, t_std: float)[source]#
Bases: object
Statistical timing result for a single hyperparameter configuration.
Stores the measured execution time statistics for a specific set of hyperparameters, including both mean and standard deviation for reliability analysis.
- hyperparams#
Dictionary of hyperparameter names to their tested values
- Type
dict[Any, Any]
- t_mean#
Mean execution time in seconds across timing iterations
- Type
float
- t_std#
Standard deviation of execution times in seconds
- Type
float
- hyperparams: dict[Any, Any]#
- t_mean: float#
- t_std: float#
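The following sketch shows one way raw timing samples could be reduced to a `TimingResult` and then ranked as `(index, TimingResult)` pairs. The documentation mentions outlier removal but does not say which rule is used, so the min/max trimming below is purely an assumption, as are the `summarize` helper and the sample values.

```python
import statistics
from dataclasses import dataclass


@dataclass(frozen=True)
class TimingResult:
    """Stand-in for tuning.TimingResult."""
    hyperparams: dict
    t_mean: float
    t_std: float


def summarize(hyperparams, samples):
    """Reduce raw timings (seconds) to mean and standard deviation."""
    kept = sorted(samples)
    if len(kept) >= 5:
        # Assumed outlier rule: drop the single fastest and slowest sample.
        kept = kept[1:-1]
    std = statistics.stdev(kept) if len(kept) > 1 else 0.0
    return TimingResult(hyperparams, statistics.fmean(kept), std)


raw = {
    0: ({"block_size": 64}, [0.010, 0.011, 0.010, 0.012, 0.050]),
    1: ({"block_size": 128}, [0.008, 0.008, 0.009, 0.008, 0.030]),
}
results = [(index, summarize(hp, samples)) for index, (hp, samples) in raw.items()]
# Rank configurations from fastest to slowest mean time.
results_sorted = sorted(results, key=lambda item: item[1].t_mean)
```

Trimming keeps a single slow run (for example, one disturbed by a cold cache) from dominating the mean, at the cost of discarding some data.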
- ejkernel.ops.execution.tuning.autotune(fn: collections.abc.Callable[[...], Any] | None = None, /, *, allow_fallback_timing: bool = True, profiling_samples: int = 5, must_find_profiler_fraction: float = 0.5, enable_detailed_logging: bool = False, find_optimal_layouts_automatically: bool = False, max_compilation_time_seconds: float = 300.0, timing_warmup_iterations: int = 2, timing_rounds: int = 5, calls_per_round: int = 3, cache_size_limit: int = 1000, profiler_prefix_filter: str = 'jit_', profiler_event_regex: str | None = None, profiler_min_duration_ns: float = 1000.0, profiler_max_events: int | None = 10000, profiler_verbose: bool = False, hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jaxlib._jax.Device | str | None = None, example_args: tuple[Any, ...] | None = None, example_kws: dict[str, Any] | None = None, sample_num: int = 9223372036854775807, event_filter_regex: str | None = None, timeout: float | None = None, cache_key: str | None = None)[source]#
Advanced JAX function autotuning decorator with comprehensive optimization features.
A flexible decorator that automatically optimizes JAX functions by testing different hyperparameter configurations. Uses profiler-based timing for accuracy with Python fallback, supports parallel execution, and provides intelligent caching.
- This decorator can be used in three different ways:
@autotune() -> decorator factory with custom parameters
@autotune -> plain decorator with default parameters
autotune(fn, kw=…) -> direct function call returning wrapped function
The decorator automatically optimizes hyperparameters on the first function call and caches results for subsequent calls with the same input signature.
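The three call forms rely on a common Python pattern: an optional positional-only first argument that distinguishes "called with a function" from "called with options only". The toy decorator below, named `tuned` here, illustrates that pattern only; it does no tuning, and its `options` attribute is a hypothetical addition for inspection.

```python
import functools


def tuned(fn=None, /, **options):
    """Toy decorator supporting @tuned, @tuned(...), and tuned(fn, ...)."""

    def wrap(target):
        if not callable(target):
            raise TypeError("fn must be callable")

        @functools.wraps(target)
        def wrapper(*args, **kwargs):
            return target(*args, **kwargs)

        wrapper.options = options  # hypothetical attribute, for inspection only
        return wrapper

    if fn is None:
        return wrap  # @tuned(...): return the actual decorator
    return wrap(fn)  # @tuned or tuned(fn, ...): wrap immediately


@tuned
def add_one(x):
    return x + 1


@tuned(hyperparams={"k": [1, 2]})
def double(x):
    return x * 2


minus_one = tuned(lambda x: x - 1, hyperparams={"k": [3]})
```

Making `fn` positional-only keeps it from colliding with keyword options, which is why the signature above uses `/`.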
- Key Features:
Profiler-based timing with Python-level fallback for maximum accuracy
Parallel compilation and testing of hyperparameter configurations
Automatic optimal memory layout discovery for distributed computation
Thread-safe caching with configurable size limits
Statistical timing analysis with outlier removal
Comprehensive error handling and detailed logging options
Support for both concrete and abstract input specifications
- Parameters
fn – Function to autotune (when used as direct call)
allow_fallback_timing – Enable Python timing fallback if profiler fails
profiling_samples – Number of profiling iterations for statistical accuracy
must_find_profiler_fraction – Minimum fraction of configs needing profiler results (0.0-1.0)
enable_detailed_logging – Enable detailed error logging with full tracebacks
timing_warmup_iterations – Warmup iterations before timing measurements
timing_rounds – Number of timing rounds for statistical analysis
calls_per_round – Function calls per timing round
profiler_prefix_filter – Event name prefix filter for profiler (default: “jit_”)
profiler_event_regex – Optional regex filter for profiler events
profiler_min_duration_ns – Minimum event duration for profiler inclusion (nanoseconds)
profiler_max_events – Maximum events per profile to prevent memory issues
profiler_verbose – Enable verbose profiler output
find_optimal_layouts_automatically – Auto-discover optimal memory layouts
max_compilation_time_seconds – Maximum compilation time per configuration
cache_size_limit – Maximum cached optimization results
hyperparams – Dictionary mapping parameter names to candidate value lists
max_workers – Maximum parallel workers for optimization
sample_num – Maximum hyperparameter combinations to test
in_shardings – Input sharding specifications for distributed computation
out_shardings – Output sharding specifications for distributed computation
device – Target device or device string for computation
example_args – Concrete example arguments for abstract input shapes
example_kws – Concrete example kwargs for abstract input values
event_filter_regex – Optional regex filter for profiler events (alias for profiler_event_regex)
timeout – Optional compilation timeout override
cache_key – Optional custom cache key prefix for disambiguation
- Returns
When used as a decorator: decorated function with automatic hyperparameter optimization.
When used directly: wrapper function that performs optimization and execution.
The returned function has additional attributes after first execution:
timing_results: List of all timing measurements from optimization
optimal_hyperparams: Dictionary of optimal parameter values
- Raises
TypeError – If fn is not callable (when used as direct call)
ValueError – If hyperparameter specifications are invalid
RuntimeError – If all hyperparameter configurations fail to compile
Examples
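Basic usage with hyperparameter optimization:
```python
import jax.numpy as jnp
from ejkernel.ops.execution.tuning import autotune

@autotune(hyperparams={'block_size': [64, 128, 256, 512]})
def matrix_multiply(a, b, block_size=128):
    # block_size is the tuned keyword; this toy body does not use it.
    return jnp.dot(a, b)

result = matrix_multiply(x, y)
print(f"Optimal config: {matrix_multiply.optimal_hyperparams}")
```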
Advanced usage with custom timing and sharding:
```python
@autotune(
    hyperparams={
        'chunk_size': [32, 64, 128],
        'algorithm': ['parallel', 'sequential'],
    },
    profiling_samples=10,
    max_workers=16,
    in_shardings=jax.sharding.PartitionSpec('data', None),
    enable_detailed_logging=True,
)
def compute_function(data, chunk_size=64, algorithm='parallel'):
    # Computation of processed_data is elided in this example.
    return processed_data
```
Direct call usage:
```python
optimized_fn = autotune(
    my_function,
    hyperparams={'param': [1, 2, 3]},
    profiling_samples=5,
)
result = optimized_fn(input_data)
```
Using with abstract input specifications:
```python
@autotune(
    hyperparams={'tile_size': [16, 32, 64]},
    example_args=(jnp.zeros((1000, 1000)),),
    device='gpu',
)
def process_matrix(matrix, tile_size=32):
    # Computation of result is elided in this example.
    return result
```
Note
First call triggers optimization and may take longer
Subsequent calls with same input signature use cached optimal parameters
Profiler-based timing requires TensorFlow backend; falls back to Python timing
For distributed computation, ensure proper mesh configuration before use
Cache keys are based on input signatures; use cache_key for manual disambiguation
- ejkernel.ops.execution.tuning.autotune_recorded(hyperparameter_selector, *, show_progress=False, repetition_count=1)[source]#
Autotune all kernel invocations recorded for the current device.
This function iterates through all kernel invocations that have been recorded in the global registry for the current device and runs autotuning to find optimal configurations for each unique operation/call-key combination.
- The autotuning process:
Retrieves all recorded invocations for the current device
For each recorded kernel call, prepares arguments and generates candidates
Benchmarks each candidate configuration
Stores the optimal configuration in both memory and persistent caches
Returns results as an AutotuningResult for context manager usage
- Parameters
hyperparameter_selector – ConfigSelectorChain instance with cache and persistent storage for storing optimization results
show_progress – Whether to display progress bars during optimization (currently unused, reserved for future implementation)
repetition_count – Number of times to repeat the optimization process (currently unused, reserved for future implementation)
- Returns
AutotuningResult containing all optimal configurations found. Can be used as a context manager to temporarily apply the configurations.
Example
>>> from ejkernel.ops.config import ConfigCache, ConfigSelectorChain
>>> cache = ConfigCache()
>>> selector = ConfigSelectorChain(cache)
>>>
>>> # Record invocations by running with EJKERNEL_OPS_RECORD=1
>>> # Then autotune all recorded operations:
>>> result = autotune_recorded(selector)
>>> with result:
...     # Runs with optimized configurations
...     output = my_model(input_data)
Note
Requires invocations to be previously recorded using the EJKERNEL_OPS_RECORD=1 environment variable during initial runs.
- ejkernel.ops.execution.tuning.benchmark(fn, *args, warmup=1, iters=5, **kwargs) float[source]#
Benchmark function execution time with JAX compilation.
Compiles the function with JAX and measures its execution time over multiple iterations, handling both static and dynamic arguments.
- Parameters
fn – Function to benchmark
*args – Positional arguments for the function
warmup – Number of warmup iterations before timing
iters – Number of timing iterations for measurement
**kwargs – Keyword arguments for the function
- Returns
Average execution time per iteration in seconds