ejkernel.ops.execution.tuning

Autotuning and benchmarking utilities for JAX kernel optimization.

This module provides comprehensive tools for automatic hyperparameter optimization of JAX functions through systematic benchmarking and configuration testing.

Key Features:
  • Profiler-based timing with Python-level fallback for accuracy

  • Parallel compilation and testing of hyperparameter configurations

  • Statistical timing analysis with outlier removal

  • Thread-safe caching with configurable size limits

  • Support for distributed computation with sharding specifications
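The module does not say how its cache is bounded or locked. As one possibility, the sketch below assumes a least-recently-used eviction policy guarded by a single lock. The class name BoundedConfigCache and its get/put methods are hypothetical and not part of ejkernel.

```python
import threading
from collections import OrderedDict
from typing import Any, Hashable


class BoundedConfigCache:
    """Hypothetical thread-safe LRU cache mapping call keys to tuned configs."""

    def __init__(self, size_limit: int = 1000) -> None:
        if size_limit <= 0:
            raise ValueError("size_limit must be positive")
        self._size_limit = size_limit
        self._data = OrderedDict()
        self._lock = threading.Lock()  # one lock serializes all readers and writers

    def get(self, key: Hashable, default: Any = None) -> Any:
        with self._lock:
            if key not in self._data:
                return default
            self._data.move_to_end(key)  # mark as most recently used
            return self._data[key]

    def put(self, key: Hashable, value: Any) -> None:
        with self._lock:
            self._data[key] = value
            self._data.move_to_end(key)
            # Evict least recently used entries once the size limit is exceeded.
            while len(self._data) > self._size_limit:
                self._data.popitem(last=False)

    def __len__(self) -> int:
        with self._lock:
            return len(self._data)
```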

Classes:

  • Measurement: Container for a single performance measurement

  • AutotuneData: Container for all optimization measurements and results

  • Autotuner: Core autotuning engine for hyperparameter optimization

  • Entry: Cache entry for storing optimal configurations

  • AutotuningResult: Device-specific optimization results container

  • TimingResult: Statistical timing result with mean and standard deviation

  • FNAutotuner: Advanced class-based autotuner with profiler integration

Functions:

  • autotune: Decorator for automatic hyperparameter optimization

  • autotune_recorded: Autotune all recorded invocations

  • benchmark: Simple function benchmarking utility

Example

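>>> import jax.numpy as jnp
>>> from ejkernel.ops.execution.tuning import autotune
>>>
>>> @autotune(hyperparams={'block_size': [64, 128, 256]})
... def matrix_op(x, y, block_size=128):
...     return compute(x, y, block_size)  # compute: your block-tiled kernel
>>>
>>> x = y = jnp.ones((512, 512))
>>> result = matrix_op(x, y)
>>> print(matrix_op.optimal_hyperparams)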

class ejkernel.ops.execution.tuning.AutotuneData(measurements: list[ejkernel.ops.execution.tuning.Measurement])

Bases: Generic[Cfg]

Container for all optimization measurements and results.

Stores performance measurements for all tested hyperparameter configurations and provides utilities to analyze the results.

Type Parameters:

Cfg: Configuration type (e.g., a dict or a dataclass)

measurements

List of all performance measurements taken

Type

list[ejkernel.ops.execution.tuning.Measurement]

property fastest_config: Cfg

Get the configuration with the fastest execution time.

Finds the measurement with the minimum execution time among all recorded measurements and returns its configuration.

Returns

The configuration that achieved the lowest execution time

Raises

ValueError – If no measurements are available
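A minimal sketch of the fastest_config behaviour described above, written as a stand-alone dataclass. AutotuneDataSketch is a hypothetical name; only the measurements, cfg and seconds fields mirror the documented attributes.

```python
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

Cfg = TypeVar("Cfg")


@dataclass(frozen=True)
class Measurement:
    cfg: Any
    seconds: float


@dataclass
class AutotuneDataSketch(Generic[Cfg]):
    measurements: list[Measurement]

    @property
    def fastest_config(self) -> Cfg:
        if not self.measurements:
            raise ValueError("No measurements available")
        # min() returns the first measurement with the smallest execution time.
        return min(self.measurements, key=lambda m: m.seconds).cfg
```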

class ejkernel.ops.execution.tuning.Autotuner(warmup=1, iters=3)

Bases: Generic[Cfg]

Core autotuning engine for hyperparameter optimization.

This class provides the fundamental optimization algorithm that tests different configurations and measures their performance to find the optimal hyperparameter settings.

Type Parameters:

Cfg: Configuration type for hyperparameters

warmup

Number of warmup iterations before timing

iters

Number of timing iterations for measurement accuracy

autotune(make_fn, args, kwargs, candidates: Iterable[Cfg]) -> AutotuneData[Cfg]

Optimize hyperparameters by testing candidate configurations.

Tests each candidate configuration by compiling and timing the function execution, then returns all measurements for analysis.

Parameters
  • make_fn – Factory function that creates a function given a config

  • args – Positional arguments for the function being optimized

  • kwargs – Keyword arguments for the function being optimized

  • candidates – Iterable of candidate configurations to test

Returns

AutotuneData containing all performance measurements
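The sketch below shows how a make_fn factory, the warmup and iters settings, and the candidate iterable could fit together. It is an illustration under stated assumptions, not the library's code: autotune_sketch is hypothetical, candidates are timed sequentially with time.perf_counter, and outputs are synchronized with block_until_ready() when the result object provides it (as jax.Array does).

```python
import time
from dataclasses import dataclass
from typing import Any, Callable, Iterable


@dataclass(frozen=True)
class Measurement:
    cfg: Any
    seconds: float


def _block(out: Any) -> Any:
    # jax.Array exposes block_until_ready(); other values are returned unchanged.
    wait = getattr(out, "block_until_ready", None)
    return wait() if callable(wait) else out


def autotune_sketch(
    make_fn: Callable[[Any], Callable[..., Any]],
    args: tuple,
    kwargs: dict,
    candidates: Iterable[Any],
    warmup: int = 1,
    iters: int = 3,
) -> list[Measurement]:
    measurements = []
    for cfg in candidates:
        fn = make_fn(cfg)  # build the variant; with jax.jit it compiles on first call
        for _ in range(warmup):
            _block(fn(*args, **kwargs))  # warmup calls are not timed
        start = time.perf_counter()
        for _ in range(iters):
            _block(fn(*args, **kwargs))
        measurements.append(Measurement(cfg, (time.perf_counter() - start) / iters))
    return measurements
```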

class ejkernel.ops.execution.tuning.AutotuningResult(device: str, entries: tuple[ejkernel.ops.execution.tuning.Entry, ...])

Bases: object

Result container for device-specific optimization results.

Stores all optimized configurations for a specific device and provides context manager functionality for temporary cache overlays.

device

Device identifier these results apply to

Type

str

entries

Tuple of optimization entries (operation -> config mappings)

Type

tuple[ejkernel.ops.execution.tuning.Entry, …]

as_overlay()

Convert results to cache overlay mapping format.

Creates a dictionary mapping that can be used with the cache overlay system to temporarily apply these optimization results.

Returns

Dictionary mapping (device, op_id, call_key) tuples to configurations
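One way the documented overlay mapping and context-manager behaviour could work together is sketched below. It assumes a contextvars-based overlay that config lookups consult first; _ACTIVE_OVERLAY, lookup and ResultSketch are hypothetical names, and only the (device, op_id, call_key) key layout comes from the description above.

```python
import contextvars
from dataclasses import dataclass, field
from typing import Any

# Hypothetical overlay consulted by config lookups before the regular cache.
_ACTIVE_OVERLAY = contextvars.ContextVar("active_overlay", default={})


@dataclass(frozen=True)
class Entry:
    op_id_v: str
    call_key: str
    cfg: Any


@dataclass
class ResultSketch:
    device: str
    entries: tuple[Entry, ...]
    _tokens: list = field(default_factory=list, repr=False)

    def as_overlay(self) -> dict[tuple[str, str, str], Any]:
        # Key every entry by (device, op_id, call_key).
        return {(self.device, e.op_id_v, e.call_key): e.cfg for e in self.entries}

    def __enter__(self) -> "ResultSketch":
        # Layer this result's configs on top of any overlay already active.
        merged = {**_ACTIVE_OVERLAY.get(), **self.as_overlay()}
        self._tokens.append(_ACTIVE_OVERLAY.set(merged))
        return self

    def __exit__(self, *exc_info: object) -> bool:
        _ACTIVE_OVERLAY.reset(self._tokens.pop())  # restore the previous overlay
        return False


def lookup(device: str, op_id: str, call_key: str, default: Any = None) -> Any:
    """Hypothetical lookup that honours the active overlay."""
    return _ACTIVE_OVERLAY.get().get((device, op_id, call_key), default)
```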

class ejkernel.ops.execution.tuning.Entry(op_id_v: str, call_key: str, cfg: Any)

Bases: object

Cache entry for storing optimal configurations.

Represents a single cached optimization result with the operation identifier, call signature, and optimal configuration.

op_id_v

Operation identifier for the optimized function

Type

str

call_key

Hash key representing the function call signature

Type

str

cfg

The optimal configuration found for this operation

Type

Any

class ejkernel.ops.execution.tuning.FNAutotuner(*, allow_fallback_timing: bool = True, profiling_samples: int = 5, must_find_profiler_fraction: float = 0.5, enable_detailed_logging: bool = False, find_optimal_layouts_automatically: bool = False, max_compilation_time_seconds: float = 300.0, timing_warmup_iterations: int = 2, timing_rounds: int = 5, calls_per_round: int = 3, cache_size_limit: int = 1000, profiler_prefix_filter: str = 'jit_', profiler_event_regex: str | None = None, profiler_min_duration_ns: float = 1000.0, profiler_max_events: int | None = 10000, profiler_verbose: bool = False)

Bases: object

Advanced class-based JAX autotuner with profiler-first timing and Python fallback.

Optimizes the hyperparameters of JAX functions, timing them with JAX's native profiling infrastructure when available and falling back automatically to Python-level timing otherwise. Supports parallel compilation, statistical timing analysis, and thread-safe caching of optimization results.

Key Features:
  • Profiler-based timing for accurate GPU/TPU measurements

  • Python-level timing fallback when the profiler is unavailable

  • Parallel compilation of hyperparameter configurations

  • Statistical analysis with outlier removal

  • Thread-safe caching of optimal configurations

  • Optional automatic memory layout optimization

The autotuner works by:
  1. Generating all hyperparameter combinations

  2. Compiling each configuration in parallel

  3. Timing execution using profiler or Python fallback

  4. Selecting the configuration with the best performance

  5. Caching results for future calls
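The five steps above can be sketched in plain Python as follows. This is a simplified, hypothetical tune_sketch: compilation and timing are folded into a caller-supplied time_config callable, configurations are evaluated sequentially rather than compiled in parallel, failing configurations are skipped, and caching (step 5) is left out.

```python
import itertools
import random
from typing import Any, Callable


def tune_sketch(
    time_config: Callable[[dict[str, Any]], float],
    hyperparams: dict[str, list[Any]],
    sample_num: int,
    seed: int = 0,
) -> tuple[dict[str, Any], list[tuple[int, float]]]:
    if sample_num < 0:
        raise ValueError("sample_num must be non-negative")
    names = list(hyperparams)
    # 1. Every combination of candidate values (Cartesian product).
    combos = [dict(zip(names, values)) for values in itertools.product(*hyperparams.values())]
    # Randomly subsample when sample_num is below the number of combinations.
    if sample_num < len(combos):
        combos = random.Random(seed).sample(combos, sample_num)
    # 2-3. Compile and time each configuration (both hidden inside time_config).
    timings = []
    for i, cfg in enumerate(combos):
        try:
            timings.append((i, time_config(cfg)))
        except Exception:
            continue  # skip configurations that fail
    if not timings:
        raise RuntimeError("all hyperparameter configurations failed")
    # 4. Sort by time; the fastest configuration comes first.
    timings.sort(key=lambda item: item[1])
    return combos[timings[0][0]], timings
```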

allow_fallback_timing

Whether to use Python timing when the profiler fails

profiling_samples

Number of profiling iterations for statistics

must_find_profiler_fraction

Minimum fraction of configurations that must yield profiler results

enable_detailed_logging

Enable verbose error logging

find_optimal_layouts_automatically

Auto-discover optimal memory layouts

max_compilation_time_seconds

Maximum compilation time per config

timing_warmup_iterations

Warmup iterations before timing

timing_rounds

Number of timing rounds for statistics

calls_per_round

Function calls per timing round

cache_size_limit

Maximum cached optimization results

profiler

Profiler instance for trace capture and analysis

PREFIX_FN = 'autotune_fn_{}'
decorate(fn: Callable[..., Any], *, hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jaxlib._jax.Device | str | None = None, example_args: tuple[Any, ...] | None = None, example_kws: dict[str, Any] | None = None, sample_num: int = 9223372036854775807, event_filter_regex: str | None = None, timeout: float | None = None, cache_key: str | None = None)

Create a decorated version of a function with automatic hyperparameter tuning.

Wraps a function so that the first call triggers hyperparameter optimization, and subsequent calls with the same input signature use cached optimal values.
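The documentation does not define exactly what an input signature contains. A plausible sketch, assuming the key is built from array shapes and dtypes (so new values with the same shapes reuse the cached configuration) plus the repr of non-array arguments; call_signature_key is a hypothetical helper.

```python
from typing import Any


def call_signature_key(args: tuple, kwargs: dict[str, Any]) -> tuple:
    """Hypothetical cache key built from shapes/dtypes, not array values."""

    def describe(x: Any) -> Any:
        shape = getattr(x, "shape", None)
        dtype = getattr(x, "dtype", None)
        if shape is not None and dtype is not None:
            return ("array", tuple(shape), str(dtype))  # array-like: structure only
        return ("static", repr(x))  # other values are keyed by their repr

    return (
        tuple(describe(a) for a in args),
        # Sort kwargs by name so keyword order does not change the key.
        tuple(sorted((k, describe(v)) for k, v in kwargs.items())),
    )
```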

Parameters
  • fn – Function to decorate with autotuning capabilities

  • hyperparams – Dictionary mapping hyperparameter names to lists of candidate values to test during optimization

  • max_workers – Maximum number of parallel workers for compilation

  • in_shardings – Input sharding specifications for distributed computation

  • out_shardings – Output sharding specifications for distributed computation

  • device – Target device or device string for computation

  • example_args – Concrete example arguments when function args are abstract

  • example_kws – Concrete example kwargs when function kwargs are abstract

  • sample_num – Maximum number of hyperparameter combinations to test

  • event_filter_regex – Optional regex filter for profiler events

  • timeout – Optional compilation timeout override

  • cache_key – Optional custom cache key prefix for disambiguation

Returns

Decorated function with automatic hyperparameter optimization. The returned function has additional attributes after its first execution:
  • timing_results: List of all timing measurements from optimization

  • optimal_hyperparams: Dictionary of optimal parameter values

Raises

TypeError – If fn is not callable

tune(fn: Callable[..., Any], *, args: tuple[Any, ...], kwargs: dict[str, Any], hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jaxlib._jax.Device | str | None = None, example_args: tuple[Any, ...] | None = None, example_kws: dict[Any, Any] | None = None, sample_num: int = 9223372036854775807, event_filter_regex: str | None = None, timeout: float | None = None) -> tuple[collections.abc.Callable[..., Any], dict[str, Any], list[tuple[int, ejkernel.ops.execution.tuning.TimingResult]]]

Tune hyperparameters for a function and return optimal configuration.

Optimizes hyperparameters by compiling the candidate configurations in parallel, measuring the performance of each, and returning the best-performing configuration.

Parameters
  • fn – Function to optimize (must accept hyperparameters as keyword args)

  • args – Positional arguments for the function (can be abstract shapes)

  • kwargs – Keyword arguments for the function (excluding hyperparameters)

  • hyperparams – Dictionary mapping hyperparameter names to lists of candidate values to test. Each combination of values is a candidate configuration, subject to the sample_num limit.

  • max_workers – Maximum number of parallel workers for compilation

  • in_shardings – Input sharding specifications for distributed computation

  • out_shardings – Output sharding specifications for distributed computation

  • device – Target device or device string for computation

  • example_args – Concrete example arguments when args are abstract

  • example_kws – Concrete example kwargs when kwargs are abstract

  • sample_num – Maximum number of hyperparameter combinations to test. If this is smaller than the total number of combinations, a random subset is tested.

  • event_filter_regex – Optional regex filter for profiler events

  • timeout – Optional compilation timeout override

Returns

  • parameterized_fn: JIT-compiled function with optimal hyperparameters

  • optimal_hyperparams: Dictionary of optimal hyperparameter values

  • timing_results_sorted: List of (index, TimingResult) sorted by performance

Return type

Tuple of (parameterized_fn, optimal_hyperparams, timing_results_sorted)

Raises
  • TypeError – If fn is not callable

  • ValueError – If max_workers <= 0, sample_num < 0, or the hyperparameters are invalid

  • RuntimeError – If all hyperparameter configurations fail to compile

class ejkernel.ops.execution.tuning.Measurement(cfg: Any, seconds: float)

Bases: object

Container for a single performance measurement.

Stores the configuration and corresponding execution time for a single hyperparameter combination during optimization.

cfg

The hyperparameter configuration that was tested

Type

Any

seconds

Execution time in seconds for this configuration

Type

float

class ejkernel.ops.execution.tuning.TimingResult(hyperparams: dict[Any, Any], t_mean: float, t_std: float)

Bases: object

Statistical timing result for a single hyperparameter configuration.

Stores the measured execution time statistics for a specific set of hyperparameters, including both mean and standard deviation for reliability analysis.

hyperparams

Dictionary of hyperparameter names to their tested values

Type

dict[Any, Any]

t_mean

Mean execution time in seconds across timing iterations

Type

float

t_std

Standard deviation of execution times in seconds

Type

float
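The documentation mentions statistical timing with outlier removal but not the rule used. The sketch below assumes a median/MAD filter (rounds more than k median absolute deviations from the median are dropped) and converts each round total to a per-call time using calls_per_round; summarize_rounds and k are hypothetical.

```python
import statistics


def summarize_rounds(round_totals: list[float], calls_per_round: int, k: float = 3.0) -> tuple[float, float]:
    """Return (t_mean, t_std) in seconds per call after a median/MAD outlier filter."""
    if calls_per_round <= 0 or not round_totals:
        raise ValueError("need at least one round and a positive calls_per_round")
    per_call = [t / calls_per_round for t in round_totals]
    med = statistics.median(per_call)
    mad = statistics.median([abs(t - med) for t in per_call])
    kept = per_call
    if mad > 0:
        # Drop rounds lying more than k MADs away from the median.
        kept = [t for t in per_call if abs(t - med) <= k * mad]
    t_mean = statistics.fmean(kept)
    t_std = statistics.stdev(kept) if len(kept) > 1 else 0.0
    return t_mean, t_std
```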

ejkernel.ops.execution.tuning.autotune(fn: collections.abc.Callable[..., Any] | None = None, /, *, allow_fallback_timing: bool = True, profiling_samples: int = 5, must_find_profiler_fraction: float = 0.5, enable_detailed_logging: bool = False, find_optimal_layouts_automatically: bool = False, max_compilation_time_seconds: float = 300.0, timing_warmup_iterations: int = 2, timing_rounds: int = 5, calls_per_round: int = 3, cache_size_limit: int = 1000, profiler_prefix_filter: str = 'jit_', profiler_event_regex: str | None = None, profiler_min_duration_ns: float = 1000.0, profiler_max_events: int | None = 10000, profiler_verbose: bool = False, hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jaxlib._jax.Device | str | None = None, example_args: tuple[Any, ...] | None = None, example_kws: dict[str, Any] | None = None, sample_num: int = 9223372036854775807, event_filter_regex: str | None = None, timeout: float | None = None, cache_key: str | None = None)

Advanced JAX function autotuning decorator with comprehensive optimization features.

A flexible decorator that optimizes JAX functions by testing different hyperparameter configurations. It uses profiler-based timing with a Python-level fallback, compiles configurations in parallel, and caches the optimal configuration for each input signature.

This decorator can be used in three different ways:
  1. @autotune() -> decorator factory with custom parameters

  2. @autotune -> plain decorator with default parameters

  3. autotune(fn, kw=…) -> direct function call returning wrapped function
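A minimal sketch of how a single function can support all three calling conventions, using a positional-only fn argument as in the signature above. tuned is a hypothetical stand-in that only records its options and performs no tuning.

```python
import functools
from typing import Any, Callable, Optional


def tuned(fn: Optional[Callable[..., Any]] = None, /, **options: Any):
    """Hypothetical decorator accepting @tuned, @tuned(...), and tuned(fn, ...)."""

    def wrap(f: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # A real autotuner would search options["hyperparams"] on the first call.
            return f(*args, **kwargs)

        wrapper.options = dict(options)
        return wrapper

    if fn is None:
        return wrap  # @tuned(...): return a decorator
    if not callable(fn):
        raise TypeError("fn must be callable")
    return wrap(fn)  # @tuned or tuned(fn, ...): return the wrapped function
```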

The decorator automatically optimizes hyperparameters on the first function call and caches results for subsequent calls with the same input signature.

Key Features:
  • Profiler-based timing with a Python-level fallback

  • Parallel compilation and testing of hyperparameter configurations

  • Optional automatic memory layout discovery for distributed computation (find_optimal_layouts_automatically)

  • Thread-safe caching with configurable size limits

  • Statistical timing analysis with outlier removal

  • Comprehensive error handling and optional detailed logging

  • Support for both concrete and abstract input specifications

Parameters
  • fn – Function to autotune (when used as direct call)

  • allow_fallback_timing – Enable Python timing fallback if profiler fails

  • profiling_samples – Number of profiling iterations for statistical accuracy

  • must_find_profiler_fraction – Minimum fraction of configs needing profiler results (0.0-1.0)

  • enable_detailed_logging – Enable detailed error logging with full tracebacks

  • timing_warmup_iterations – Warmup iterations before timing measurements

  • timing_rounds – Number of timing rounds for statistical analysis

  • calls_per_round – Function calls per timing round

  • profiler_prefix_filter – Event name prefix filter for the profiler (default: 'jit_')

  • profiler_event_regex – Optional regex filter for profiler events

  • profiler_min_duration_ns – Minimum event duration for profiler inclusion (nanoseconds)

  • profiler_max_events – Maximum events per profile to prevent memory issues

  • profiler_verbose – Enable verbose profiler output

  • find_optimal_layouts_automatically – Auto-discover optimal memory layouts

  • max_compilation_time_seconds – Maximum compilation time per configuration

  • cache_size_limit – Maximum cached optimization results

  • hyperparams – Dictionary mapping parameter names to candidate value lists

  • max_workers – Maximum parallel workers for optimization

  • sample_num – Maximum hyperparameter combinations to test

  • in_shardings – Input sharding specifications for distributed computation

  • out_shardings – Output sharding specifications for distributed computation

  • device – Target device or device string for computation

  • example_args – Concrete example arguments for abstract input shapes

  • example_kws – Concrete example kwargs for abstract input values

  • event_filter_regex – Optional regex filter for profiler events (alias for profiler_event_regex)

  • timeout – Optional compilation timeout override

  • cache_key – Optional custom cache key prefix for disambiguation
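How the profiler filters are combined is not specified here. The sketch below assumes they are applied together: an event is kept only if its name starts with the prefix, matches the optional regex (via re.search), and lasts at least the minimum duration, and collection stops after max_events. TraceEvent and filter_events are hypothetical.

```python
import re
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass(frozen=True)
class TraceEvent:
    name: str
    duration_ns: float


def filter_events(
    events: Iterable[TraceEvent],
    prefix: str = "jit_",
    regex: Optional[str] = None,
    min_duration_ns: float = 1000.0,
    max_events: Optional[int] = 10000,
) -> list[TraceEvent]:
    pattern = re.compile(regex) if regex is not None else None
    kept = []
    for ev in events:
        if not ev.name.startswith(prefix):
            continue  # wrong prefix
        if pattern is not None and not pattern.search(ev.name):
            continue  # regex supplied but not matched
        if ev.duration_ns < min_duration_ns:
            continue  # too short to be a meaningful kernel event
        kept.append(ev)
        if max_events is not None and len(kept) >= max_events:
            break  # cap memory use on very large traces
    return kept
```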

Returns

  • When used as a decorator: the decorated function with automatic hyperparameter optimization

  • When called directly: a wrapper function that performs optimization and execution

The returned function has additional attributes after its first execution:
  • timing_results: List of all timing measurements from optimization

  • optimal_hyperparams: Dictionary of optimal parameter values

Raises
  • TypeError – If fn is not callable (when used as direct call)

  • ValueError – If hyperparameter specifications are invalid

  • RuntimeError – If all hyperparameter configurations fail to compile

Examples

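Basic usage with hyperparameter optimization:

```python
import jax.numpy as jnp
from ejkernel.ops.execution.tuning import autotune

@autotune(hyperparams={'block_size': [64, 128, 256, 512]})
def matrix_multiply(a, b, block_size=128):
    return jnp.dot(a, b)  # a real kernel would use block_size for tiling

x = y = jnp.ones((1024, 1024))
result = matrix_multiply(x, y)
print(f"Optimal config: {matrix_multiply.optimal_hyperparams}")
```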

Advanced usage with custom timing and sharding:

```python
import jax

@autotune(
    hyperparams={
        'chunk_size': [32, 64, 128],
        'algorithm': ['parallel', 'sequential'],
    },
    profiling_samples=10,
    max_workers=16,
    in_shardings=jax.sharding.PartitionSpec('data', None),
    enable_detailed_logging=True,
)
def compute_function(data, chunk_size=64, algorithm='parallel'):
    processed_data = data  # placeholder for the chunked computation
    return processed_data
```

Direct call usage:

```python
optimized_fn = autotune(
    my_function,
    hyperparams={'param': [1, 2, 3]},
    profiling_samples=5,
)
result = optimized_fn(input_data)
```

Using with abstract input specifications:

```python
@autotune(
    hyperparams={'tile_size': [16, 32, 64]},
    example_args=(jnp.zeros((1000, 1000)),),
    device='gpu',
)
def process_matrix(matrix, tile_size=32):
    result = matrix  # placeholder for the tiled computation
    return result
```

Note

  • The first call triggers optimization and may take longer

  • Subsequent calls with the same input signature use the cached optimal parameters

  • Profiler-based timing requires a TensorFlow backend; otherwise timing falls back to Python-level measurement (if allow_fallback_timing is enabled)

  • For distributed computation, configure the device mesh before use

  • Cache keys are based on input signatures; use cache_key for manual disambiguation

ejkernel.ops.execution.tuning.autotune_recorded(hyperparameter_selector, *, show_progress=False, repetition_count=1)

Autotune all kernel invocations recorded for the current device.

This function iterates through all kernel invocations that have been recorded in the global registry for the current device and runs autotuning to find optimal configurations for each unique operation/call-key combination.

The autotuning process:
  1. Retrieves all recorded invocations for the current device

  2. For each recorded kernel call, prepares arguments and generates candidates

  3. Benchmarks each candidate configuration

  4. Stores the optimal configuration in both memory and persistent caches

  5. Returns results as an AutotuningResult for context manager usage
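The process above can be sketched with plain dictionaries standing in for the memory and persistent caches. Invocation, time_candidate, autotune_recorded_sketch and the returned (device, entries) pair are hypothetical simplifications of the recorded-invocation registry and AutotuningResult.

```python
from dataclasses import dataclass
from typing import Any, Callable, Iterable


@dataclass(frozen=True)
class Invocation:
    op_id: str
    call_key: str
    candidates: tuple[Any, ...]


def autotune_recorded_sketch(
    device: str,
    invocations: Iterable[Invocation],
    time_candidate: Callable[[Invocation, Any], float],
    memory_cache: dict,
    persistent_cache: dict,
) -> tuple[str, tuple[tuple[str, str, Any], ...]]:
    entries = []
    seen = set()
    for inv in invocations:
        key = (inv.op_id, inv.call_key)
        if key in seen:
            continue  # tune each unique operation/call-key pair once
        seen.add(key)
        # Benchmark every candidate and keep the fastest one.
        best = min(inv.candidates, key=lambda cfg: time_candidate(inv, cfg))
        memory_cache[(device, *key)] = best
        persistent_cache[(device, *key)] = best
        entries.append((inv.op_id, inv.call_key, best))
    return device, tuple(entries)
```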

Parameters
  • hyperparameter_selector – ConfigSelectorChain instance with cache and persistent storage for storing optimization results

  • show_progress – Whether to display progress bars during optimization (currently unused, reserved for future implementation)

  • repetition_count – Number of times to repeat the optimization process (currently unused, reserved for future implementation)

Returns

AutotuningResult containing all optimal configurations found. Can be used as a context manager to temporarily apply the configurations.

Example

>>> from ejkernel.ops.config import ConfigCache, ConfigSelectorChain
>>> from ejkernel.ops.execution.tuning import autotune_recorded
>>> cache = ConfigCache()
>>> selector = ConfigSelectorChain(cache)
>>>
>>> # Record invocations by running with EJKERNEL_OPS_RECORD=1
>>> # Then autotune all recorded operations:
>>> result = autotune_recorded(selector)
>>> with result:
...     # Runs with optimized configurations
...     output = my_model(input_data)

Note

Requires invocations to have been recorded beforehand by running with the EJKERNEL_OPS_RECORD=1 environment variable set.

ejkernel.ops.execution.tuning.benchmark(fn, *args, warmup=1, iters=5, **kwargs) -> float

Benchmark function execution time with JAX compilation.

Compiles the function with JAX and measures its execution time over multiple iterations, handling both static and dynamic arguments.

Parameters
  • fn – Function to benchmark

  • *args – Positional arguments for the function

  • warmup – Number of warmup iterations before timing

  • iters – Number of timing iterations for measurement

  • **kwargs – Keyword arguments for the function

Returns

Average execution time per iteration in seconds
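A minimal timing loop matching the documented warmup/iters behaviour. Unlike the real benchmark, this hypothetical benchmark_sketch does not JIT-compile fn or separate static arguments itself; with JAX you would pass an already-jitted function such as jax.jit(f). It synchronizes via block_until_ready() when the result provides it.

```python
import time
from typing import Any, Callable


def benchmark_sketch(fn: Callable[..., Any], *args: Any, warmup: int = 1, iters: int = 5, **kwargs: Any) -> float:
    if iters <= 0:
        raise ValueError("iters must be positive")

    def run() -> Any:
        out = fn(*args, **kwargs)
        # Wait for asynchronous JAX dispatch to finish before stopping the clock.
        wait = getattr(out, "block_until_ready", None)
        return wait() if callable(wait) else out

    for _ in range(warmup):
        run()  # warmup calls absorb compilation and are not timed
    start = time.perf_counter()
    for _ in range(iters):
        run()
    return (time.perf_counter() - start) / iters  # average seconds per call
```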