ejkernel.ops.execution.tuning
Autotuning and benchmarking utilities for JAX kernel optimization.
This module provides comprehensive tools for automatic hyperparameter optimization of JAX functions through systematic benchmarking and configuration testing.
- Key Features:
Profiler-based timing with Python-level fallback for accuracy
Parallel compilation and testing of hyperparameter configurations
Statistical timing analysis with outlier removal
Thread-safe caching with configurable size limits
Support for distributed computation with sharding specifications
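Here is a minimal sketch of the "statistical timing analysis with outlier removal" feature. The module does not document which outlier rule it applies, so this example assumes the common 1.5×IQR rule. The function names `timed_mean_std` and `robust_mean_std` are hypothetical. The `warmup`, `rounds`, and `calls_per_round` arguments only mirror the intent of the `timing_warmup_iterations`, `timing_rounds`, and `calls_per_round` parameters of `autotune`.

```python
import statistics
import time
from typing import Any, Callable


def robust_mean_std(samples: list[float]) -> tuple[float, float]:
    """Drop samples outside [Q1 - 1.5*IQR, Q3 + 1.5*IQR], then return (mean, std)."""
    if len(samples) >= 4:
        q1, _, q3 = statistics.quantiles(samples, n=4)
        iqr = q3 - q1
        kept = [s for s in samples if q1 - 1.5 * iqr <= s <= q3 + 1.5 * iqr]
    else:
        kept = list(samples)  # too few samples to estimate quartiles reliably
    std = statistics.stdev(kept) if len(kept) > 1 else 0.0
    return statistics.fmean(kept), std


def timed_mean_std(
    fn: Callable[[], Any],
    warmup: int = 2,
    rounds: int = 5,
    calls_per_round: int = 3,
) -> tuple[float, float]:
    """Return (mean, std) of per-call time in seconds, with outliers removed."""
    for _ in range(warmup):
        fn()  # untimed: triggers compilation and warms caches
    per_call = []
    for _ in range(rounds):
        start = time.perf_counter()
        for _ in range(calls_per_round):
            out = fn()
            # Wait on asynchronous results (e.g. JAX arrays) before stopping the clock.
            if hasattr(out, "block_until_ready"):
                out.block_until_ready()
        per_call.append((time.perf_counter() - start) / calls_per_round)
    return robust_mean_std(per_call)
```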
- Classes:
Measurement: Container for a single performance measurement
AutotuneData: Container for all optimization measurements and results
Autotuner: Core autotuning engine for hyperparameter optimization
Entry: Cache entry for storing optimal configurations
AutotuningResult: Device-specific optimization results container
TimingResult: Statistical timing result with mean and standard deviation
FNAutotuner: Advanced class-based autotuner with profiler integration
- Functions:
autotune: Decorator for automatic hyperparameter optimization
autotune_recorded: Autotune all recorded invocations
benchmark: Simple function benchmarking utility
Example
>>> @autotune(hyperparams={'block_size': [64, 128, 256]})
... def matrix_op(x, y, block_size=128):
...     return compute(x, y, block_size)
>>>
>>> result = matrix_op(x, y)
>>> print(matrix_op.optimal_hyperparams)
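In the example above, `hyperparams` maps each parameter name to a list of candidate values. The following sketch shows one way to expand such a mapping into concrete configurations. It assumes a full Cartesian product truncated to at most `sample_num` entries. The helper `expand_grid` is hypothetical, and the library may select which combinations to keep differently.

```python
import itertools
import sys
from typing import Any


def expand_grid(hyperparams: dict[str, list[Any]], sample_num: int = sys.maxsize) -> list[dict[str, Any]]:
    """Expand {name: [values]} into config dicts (Cartesian product), capped at sample_num."""
    names = list(hyperparams)
    combos = itertools.product(*(hyperparams[n] for n in names))
    # islice keeps only the first sample_num combinations without materializing the rest.
    return [dict(zip(names, values)) for values in itertools.islice(combos, sample_num)]
```

The default `sample_num` of 9223372036854775807 in the signatures below equals `sys.maxsize` on 64-bit platforms, so the default effectively imposes no cap.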
- class ejkernel.ops.execution.tuning.AutotuneData(measurements: list[ejkernel.ops.execution.tuning.Measurement])
Bases: Generic[Cfg]
Container for all optimization measurements and results.
Stores performance measurements for all tested hyperparameter configurations and provides utilities to analyze the results.
- Type Parameters:
Cfg: Configuration type (e.g., dict, dataclass, etc.)
- measurements
List of all performance measurements taken
- Type
list[ejkernel.ops.execution.tuning.Measurement]
- property fastest_config: Cfg
Get the configuration with the fastest execution time.
Finds the measurement with the minimum execution time among all recorded measurements and returns its configuration.
- Returns
The configuration that achieved the lowest execution time
- Raises
ValueError – If no measurements are available
- measurements: list[ejkernel.ops.execution.tuning.Measurement]
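`fastest_config` reduces to a minimum over `seconds`. The following sketch re-creates only the documented fields for illustration. It is not the library's source.

```python
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

Cfg = TypeVar("Cfg")


@dataclass(frozen=True)
class Measurement:
    cfg: Any        # the configuration that was tested
    seconds: float  # its measured execution time


@dataclass
class AutotuneData(Generic[Cfg]):
    measurements: list[Measurement]

    @property
    def fastest_config(self) -> Cfg:
        """Configuration of the measurement with the lowest execution time."""
        if not self.measurements:
            raise ValueError("No measurements available")
        return min(self.measurements, key=lambda m: m.seconds).cfg
```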
- class ejkernel.ops.execution.tuning.Autotuner(warmup=1, iters=3)
Bases: Generic[Cfg]
Core autotuning engine for hyperparameter optimization.
This class provides the fundamental optimization algorithm that tests different configurations and measures their performance to find the optimal hyperparameter settings.
- Type Parameters:
Cfg: Configuration type for hyperparameters
- warmup
Number of warmup iterations before timing
- iters
Number of timing iterations for measurement accuracy
- autotune(make_fn, args, kwargs, candidates: Iterable[Cfg]) → AutotuneData[Cfg]
Optimize hyperparameters by testing candidate configurations.
Tests each candidate configuration by compiling and timing the function execution, then returns all measurements for analysis.
- Parameters
make_fn – Factory function that creates a function given a config
args – Positional arguments for the function being optimized
kwargs – Keyword arguments for the function being optimized
candidates – Iterable of candidate configurations to test
- Returns
AutotuneData containing all performance measurements
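Here is a minimal sketch of the loop that `Autotuner.autotune` describes. It builds one function per candidate via `make_fn`, runs `warmup` untimed calls, and then averages `iters` timed calls. The name `time_candidates`, the `_block` helper, and the choice to return a plain list of measurements instead of `AutotuneData` are assumptions made to keep the example self-contained.

```python
import time
from dataclasses import dataclass
from typing import Any, Callable, Iterable


@dataclass(frozen=True)
class Measurement:
    cfg: Any
    seconds: float


def _block(out: Any) -> Any:
    """Wait for asynchronous results (e.g. JAX arrays) before stopping the clock."""
    if hasattr(out, "block_until_ready"):
        out.block_until_ready()
    return out


def time_candidates(
    make_fn: Callable[[Any], Callable[..., Any]],
    args: tuple,
    kwargs: dict,
    candidates: Iterable[Any],
    warmup: int = 1,
    iters: int = 3,
) -> list[Measurement]:
    """Return one Measurement (mean seconds per call) for each candidate config."""
    measurements = []
    for cfg in candidates:
        fn = make_fn(cfg)  # e.g. a jax.jit-compiled function specialized to cfg
        for _ in range(warmup):
            _block(fn(*args, **kwargs))  # untimed; the first call pays compilation cost
        start = time.perf_counter()
        for _ in range(iters):
            _block(fn(*args, **kwargs))
        measurements.append(Measurement(cfg, (time.perf_counter() - start) / iters))
    return measurements
```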
- class ejkernel.ops.execution.tuning.AutotuningResult(device: str, entries: tuple[ejkernel.ops.execution.tuning.Entry, ...])
Bases: object
Result container for device-specific optimization results.
Stores all optimized configurations for a specific device and provides context manager functionality for temporary cache overlays.
- device
Device identifier these results apply to
- Type
str
- entries
Tuple of optimization entries (operation -> config mappings)
- Type
tuple[ejkernel.ops.execution.tuning.Entry, …]
- as_overlay()
Convert results to cache overlay mapping format.
Creates a dictionary mapping that can be used with the cache overlay system to temporarily apply these optimization results.
- Returns
Dictionary mapping (device, op_id, call_key) tuples to configurations
- device: str
- entries: tuple[ejkernel.ops.execution.tuning.Entry, ...]
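The mapping returned by `as_overlay` can be sketched as a dictionary comprehension over `entries`. The `overlay` context manager is a hypothetical stand-in for the cache overlay system. It assumes a plain dict cache that is restored on exit.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Iterator


@dataclass(frozen=True)
class Entry:
    op_id_v: str
    call_key: str
    cfg: Any


@dataclass(frozen=True)
class AutotuningResult:
    device: str
    entries: tuple[Entry, ...]

    def as_overlay(self) -> dict[tuple[str, str, str], Any]:
        # Key each config by (device, op_id, call_key), as documented for as_overlay().
        return {(self.device, e.op_id_v, e.call_key): e.cfg for e in self.entries}


@contextmanager
def overlay(cache: dict, result: AutotuningResult) -> Iterator[dict]:
    """Hypothetical helper: layer result's configs over cache, then restore the original."""
    saved = dict(cache)
    cache.update(result.as_overlay())
    try:
        yield cache
    finally:
        cache.clear()
        cache.update(saved)
```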
- class ejkernel.ops.execution.tuning.Entry(op_id_v: str, call_key: str, cfg: Any)
Bases: object
Cache entry for storing optimal configurations.
Represents a single cached optimization result with the operation identifier, call signature, and optimal configuration.
- op_id_v
Operation identifier for the optimized function
- Type
str
- call_key
Hash key representing the function call signature
- Type
str
- cfg
The optimal configuration found for this operation
- Type
Any
- call_key: str
- cfg: Any
- op_id_v: str
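`call_key` is described only as a hash of the function call signature. One plausible scheme, assumed here, hashes the shape and dtype of array-like arguments and the `repr` of everything else. Under that scheme, calls with same-shaped inputs share a cached configuration. The function name `call_key` is hypothetical.

```python
import hashlib
from typing import Any


def call_key(*args: Any, **kwargs: Any) -> str:
    """Hash shapes/dtypes of array-like arguments and the repr of all other values."""

    def describe(x: Any) -> Any:
        if hasattr(x, "shape") and hasattr(x, "dtype"):
            return ("array", tuple(x.shape), str(x.dtype))  # values ignored, only layout matters
        return ("value", repr(x))

    signature = (
        tuple(describe(a) for a in args),
        tuple(sorted((k, describe(v)) for k, v in kwargs.items())),  # order-independent kwargs
    )
    return hashlib.sha256(repr(signature).encode()).hexdigest()[:16]
```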
- class ejkernel.ops.execution.tuning.FNAutotuner(*, allow_fallback_timing: bool = True, profiling_samples: int = 5, must_find_profiler_fraction: float = 0.5, enable_detailed_logging: bool = False, find_optimal_layouts_automatically: bool = False, max_compilation_time_seconds: float = 300.0, timing_warmup_iterations: int = 2, timing_rounds: int = 5, calls_per_round: int = 3, cache_size_limit: int = 1000, profiler_prefix_filter: str = 'jit_', profiler_event_regex: str | None = None, profiler_min_duration_ns: float = 1000.0, profiler_max_events: int | None = 10000, profiler_verbose: bool = False)
Bases: object
Class-based JAX autotuner with profiler-first timing and Python fallback.
- PREFIX_FN = 'autotune_fn_{}'
- decorate(fn: Callable[..., Any], *, hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jaxlib._jax.Device | str | None = None, example_args: tuple[Any, ...] | None = None, example_kws: dict[str, Any] | None = None, sample_num: int = 9223372036854775807, event_filter_regex: str | None = None, timeout: float | None = None, cache_key: str | None = None)
- tune(fn: Callable[..., Any], *, args: tuple[Any, ...], kwargs: dict[str, Any], hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jaxlib._jax.Device | str | None = None, example_args: tuple[Any, ...] | None = None, example_kws: dict[Any, Any] | None = None, sample_num: int = 9223372036854775807, event_filter_regex: str | None = None, timeout: float | None = None) → tuple[collections.abc.Callable[..., Any], dict[str, Any], list[tuple[int, ejkernel.ops.execution.tuning.TimingResult]]]
Return (parameterized_fn, optimal_hyperparams, timing_results_sorted).
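The following sketch shows how the three return values of `tune` could be assembled from per-candidate `TimingResult`s. It rests on two assumptions: the `int` in `list[tuple[int, TimingResult]]` is the candidate's position in the search order, and the parameterized function binds the winning hyperparameters with `functools.partial`. The name `finalize` is hypothetical.

```python
import functools
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class TimingResult:
    hyperparams: dict[Any, Any]
    t_mean: float
    t_std: float


def finalize(
    fn: Callable[..., Any], results: list[TimingResult]
) -> tuple[Callable[..., Any], dict[Any, Any], list[tuple[int, TimingResult]]]:
    """Sort (index, result) pairs by mean time and bind the winner's hyperparams to fn."""
    ranked = sorted(enumerate(results), key=lambda pair: pair[1].t_mean)
    best = ranked[0][1].hyperparams
    return functools.partial(fn, **best), dict(best), ranked
```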
- class ejkernel.ops.execution.tuning.Measurement(cfg: Any, seconds: float)
Bases: object
Container for a single performance measurement.
Stores the configuration and corresponding execution time for a single hyperparameter combination during optimization.
- cfg
The hyperparameter configuration that was tested
- Type
Any
- seconds
Execution time in seconds for this configuration
- Type
float
- cfg: Any
- seconds: float
- class ejkernel.ops.execution.tuning.TimingResult(hyperparams: 'dict[Any, Any]', t_mean: 'float', t_std: 'float')
Bases: object
Statistical timing result with mean and standard deviation.
- hyperparams: dict[Any, Any]
- t_mean: float
- t_std: float
- ejkernel.ops.execution.tuning.autotune(fn: collections.abc.Callable[..., Any] | None = None, /, *, allow_fallback_timing: bool = True, profiling_samples: int = 5, must_find_profiler_fraction: float = 0.5, enable_detailed_logging: bool = False, find_optimal_layouts_automatically: bool = False, max_compilation_time_seconds: float = 300.0, timing_warmup_iterations: int = 2, timing_rounds: int = 5, calls_per_round: int = 3, cache_size_limit: int = 1000, profiler_prefix_filter: str = 'jit_', profiler_event_regex: str | None = None, profiler_min_duration_ns: float = 1000.0, profiler_max_events: int | None = 10000, profiler_verbose: bool = False, hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jaxlib._jax.Device | str | None = None, example_args: tuple[Any, ...] | None = None, example_kws: dict[str, Any] | None = None, sample_num: int = 9223372036854775807, event_filter_regex: str | None = None, timeout: float | None = None, cache_key: str | None = None)
Advanced JAX function autotuning decorator with comprehensive optimization features.
A flexible decorator that automatically optimizes JAX functions by testing different hyperparameter configurations. It times each configuration with the profiler, falling back to Python-level timing when needed. It also compiles and tests configurations in parallel and caches the optimal configuration per input signature.
- This decorator can be used in three different ways:
@autotune() -> decorator factory with custom parameters
@autotune -> plain decorator with default parameters
autotune(fn, kw=…) -> direct function call returning wrapped function
The decorator automatically optimizes hyperparameters on the first function call and caches results for subsequent calls with the same input signature.
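A small decorator can sketch both the three calling forms and the tune-on-first-call caching. `tune_on_first_call` is hypothetical and deliberately simplified. It times each configuration once with `time.perf_counter`, keys the cache on argument types only, and omits profiling, parallelism, and sharding.

```python
import functools
import itertools
import time
from typing import Any, Callable, Optional


def tune_on_first_call(
    fn: Optional[Callable[..., Any]] = None, /, *, hyperparams: Optional[dict[str, list[Any]]] = None
):
    """Usable as @tune_on_first_call, @tune_on_first_call(...), or tune_on_first_call(fn, ...)."""

    def wrap(f: Callable[..., Any]) -> Callable[..., Any]:
        if not callable(f):
            raise TypeError("fn must be callable")
        grid = hyperparams or {}
        names = list(grid)
        configs = [dict(zip(names, vals)) for vals in itertools.product(*grid.values())]
        cache: dict[Any, dict[str, Any]] = {}

        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Simplified input signature: argument types only (a real key would use shapes/dtypes).
            key = tuple(type(a).__name__ for a in args)
            if key not in cache:  # first call for this signature: time every config once
                timings = []
                for cfg in configs:
                    start = time.perf_counter()
                    f(*args, **kwargs, **cfg)
                    timings.append((time.perf_counter() - start, cfg))
                cache[key] = min(timings, key=lambda t: t[0])[1]
                wrapper.timing_results = timings
                wrapper.optimal_hyperparams = cache[key]
            return f(*args, **kwargs, **cache[key])

        return wrapper

    # fn is None when called with only keyword arguments, i.e. used as a decorator factory.
    return wrap if fn is None else wrap(fn)
```

The `fn is None` check is what lets one function act both as a plain decorator and as a decorator factory.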
- Key Features:
Profiler-based timing with Python-level fallback for maximum accuracy
Parallel compilation and testing of hyperparameter configurations
Optional discovery of optimal memory layouts for distributed computation (enabled with find_optimal_layouts_automatically, off by default)
Thread-safe caching with configurable size limits
Statistical timing analysis with outlier removal
Error handling with optional detailed logging (enable_detailed_logging)
Support for both concrete and abstract input specifications
- Parameters
fn – Function to autotune (when used as direct call)
allow_fallback_timing – Enable Python timing fallback if profiler fails
profiling_samples – Number of profiling iterations for statistical accuracy
must_find_profiler_fraction – Minimum fraction of configs needing profiler results (0.0-1.0)
enable_detailed_logging – Enable detailed error logging with full tracebacks
timing_warmup_iterations – Warmup iterations before timing measurements
timing_rounds – Number of timing rounds for statistical analysis
calls_per_round – Function calls per timing round
profiler_prefix_filter – Event name prefix filter for profiler (default: “jit_”)
profiler_event_regex – Optional regex filter for profiler events
profiler_min_duration_ns – Minimum event duration for profiler inclusion (nanoseconds)
profiler_max_events – Maximum events per profile to prevent memory issues
profiler_verbose – Enable verbose profiler output
find_optimal_layouts_automatically – Auto-discover optimal memory layouts
max_compilation_time_seconds – Maximum compilation time per configuration
cache_size_limit – Maximum cached optimization results
hyperparams – Dictionary mapping parameter names to candidate value lists
max_workers – Maximum parallel workers for optimization
sample_num – Maximum hyperparameter combinations to test
in_shardings – Input sharding specifications for distributed computation
out_shardings – Output sharding specifications for distributed computation
device – Target device or device string for computation
example_args – Concrete example arguments for abstract input shapes
example_kws – Concrete example kwargs for abstract input values
event_filter_regex – Optional regex filter for profiler events (alias for profiler_event_regex)
timeout – Optional compilation timeout override
cache_key – Optional custom cache key prefix for disambiguation
- Returns
When used as a decorator: the decorated function with automatic hyperparameter optimization. When used as a direct call: a wrapper function that performs optimization and then execution.
After its first execution, the returned function exposes two additional attributes:
timing_results: List of all timing measurements from optimization
optimal_hyperparams: Dictionary of optimal parameter values
- Raises
TypeError – If fn is not callable (when used as direct call)
ValueError – If hyperparameter specifications are invalid
RuntimeError – If all hyperparameter configurations fail to compile
Examples
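Basic usage with hyperparameter optimization:

```python
import jax.numpy as jnp
from ejkernel.ops.execution.tuning import autotune

@autotune(hyperparams={"block_size": [64, 128, 256, 512]})
def matrix_multiply(a, b, block_size=128):
    return jnp.dot(a, b)

x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
result = matrix_multiply(x, y)  # first call triggers tuning
print(f"Optimal config: {matrix_multiply.optimal_hyperparams}")
```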
Advanced usage with custom timing and sharding:

```python
import jax
from ejkernel.ops.execution.tuning import autotune

@autotune(
    hyperparams={
        "chunk_size": [32, 64, 128],
        "algorithm": ["parallel", "sequential"],
    },
    profiling_samples=10,
    max_workers=16,
    in_shardings=jax.sharding.PartitionSpec("data", None),
    enable_detailed_logging=True,
)
def compute_function(data, chunk_size=64, algorithm="parallel"):
    processed_data = ...  # computation using chunk_size and algorithm goes here
    return processed_data
```
Direct call usage:

```python
optimized_fn = autotune(
    my_function,
    hyperparams={"param": [1, 2, 3]},
    profiling_samples=5,
)
result = optimized_fn(input_data)
```
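Using with abstract input specifications:

```python
import jax.numpy as jnp
from ejkernel.ops.execution.tuning import autotune

@autotune(
    hyperparams={"tile_size": [16, 32, 64]},
    example_args=(jnp.zeros((1000, 1000)),),
    device="gpu",
)
def process_matrix(matrix, tile_size=32):
    result = ...  # tiled computation using tile_size goes here
    return result
```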
Note
First call triggers optimization and may take longer
Subsequent calls with same input signature use cached optimal parameters
Profiler-based timing requires TensorFlow backend; falls back to Python timing
For distributed computation, ensure proper mesh configuration before use
Cache keys are based on input signatures; use cache_key for manual disambiguation
- ejkernel.ops.execution.tuning.autotune_recorded(hyperparameter_selector, *, show_progress=False, repetition_count=1)
Record and replay optimal hyperparameters for functions.
This function provides caching and replay functionality for previously optimized hyperparameters, avoiding redundant tuning operations.
- Parameters
hyperparameter_selector – Function to select hyperparameters from recorded data
show_progress – Whether to display progress bars during optimization
repetition_count – Number of times to repeat the optimization process
- Returns
Decorated function with recorded hyperparameter optimization
- ejkernel.ops.execution.tuning.benchmark(fn, *args, warmup=1, iters=5, **kwargs) → float
Benchmark function execution time with JAX compilation.
Compiles the function with JAX and measures its execution time over multiple iterations, handling both static and dynamic arguments.
- Parameters
fn – Function to benchmark
*args – Positional arguments for the function
warmup – Number of warmup iterations before timing
iters – Number of timing iterations for measurement
**kwargs – Keyword arguments for the function
- Returns
Average execution time per iteration in seconds
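`benchmark` is documented as handling both static and dynamic arguments. One common way to do this in JAX is to pass non-array arguments to `jax.jit` through `static_argnums`. That approach is assumed here, not confirmed by the source. The hypothetical helper below computes those indices. Static arguments must be hashable, so the helper leaves unhashable values such as lists dynamic.

```python
from typing import Any


def static_argnums_for(args: tuple[Any, ...]) -> tuple[int, ...]:
    """Indices of positional args to mark static: hashable values that are not array-like."""
    static = []
    for i, a in enumerate(args):
        if hasattr(a, "shape") and hasattr(a, "dtype"):
            continue  # array-like: traced as a dynamic argument
        try:
            hash(a)
        except TypeError:
            continue  # unhashable (e.g. a list pytree): cannot be static, leave dynamic
        static.append(i)
    return tuple(static)
```

The result could then be used as `jax.jit(fn, static_argnums=static_argnums_for(args))`. Marking Python scalars static means each distinct value triggers a recompilation. That cost suits configuration values such as block sizes, but not per-call data.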