ejkernel.ops.execution.tuning

Autotuning and benchmarking utilities for JAX kernel optimization.

This module provides tools for automatic hyperparameter optimization of JAX functions. It benchmarks candidate configurations systematically and selects the fastest one.

Key Features:
  • Profiler-based timing with Python-level fallback for accuracy

  • Parallel compilation and testing of hyperparameter configurations

  • Statistical timing analysis with outlier removal

  • Thread-safe caching with configurable size limits

  • Support for distributed computation with sharding specifications

Classes:

  • Measurement: Container for a single performance measurement

  • AutotuneData: Container for all optimization measurements and results

  • Autotuner: Core autotuning engine for hyperparameter optimization

  • Entry: Cache entry for storing optimal configurations

  • AutotuningResult: Device-specific optimization results container

  • TimingResult: Statistical timing result with mean and standard deviation

  • FNAutotuner: Advanced class-based autotuner with profiler integration

Functions:

  • autotune: Decorator for automatic hyperparameter optimization

  • autotune_recorded: Autotune all recorded invocations

  • benchmark: Simple function benchmarking utility

Example

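>>> import jax.numpy as jnp
>>> from ejkernel.ops.execution.tuning import autotune
>>> @autotune(hyperparams={'block_size': [64, 128, 256]})
... def matrix_op(x, y, block_size=128):
...     return jnp.dot(x, y)  # a real kernel would tile by block_size
...
>>> x = y = jnp.ones((512, 512))
>>> result = matrix_op(x, y)
>>> print(matrix_op.optimal_hyperparams)
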
class ejkernel.ops.execution.tuning.AutotuneData(measurements: list[ejkernel.ops.execution.tuning.Measurement])

Bases: Generic[Cfg]

Container for all optimization measurements and results.

Stores performance measurements for all tested hyperparameter configurations and provides utilities to analyze the results.

Type Parameters:

Cfg: Configuration type (e.g., a dict or a dataclass)

measurements

List of all performance measurements taken

Type

list[ejkernel.ops.execution.tuning.Measurement]

property fastest_config: Cfg

Get the configuration with the fastest execution time.

Finds the measurement with the minimum execution time among all recorded measurements and returns its configuration.

Returns

The configuration that achieved the lowest execution time

Raises

ValueError – If no measurements are available
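The selection rule described above can be sketched as follows. This is a minimal stand-alone illustration, not ejkernel's source: it re-declares simplified Measurement and AutotuneData dataclasses (assumed to carry only the documented fields) so that it runs without the library.

```python
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

Cfg = TypeVar("Cfg")


@dataclass(frozen=True)
class Measurement:
    cfg: Any
    seconds: float


@dataclass
class AutotuneData(Generic[Cfg]):
    # Simplified stand-in for the documented container (sketch only).
    measurements: list[Measurement]

    @property
    def fastest_config(self) -> Cfg:
        # Documented contract: no measurements -> ValueError.
        if not self.measurements:
            raise ValueError("no measurements available")
        # Pick the configuration whose measurement has the smallest time.
        return min(self.measurements, key=lambda m: m.seconds).cfg


data = AutotuneData([
    Measurement({"block_size": 64}, 0.012),
    Measurement({"block_size": 128}, 0.008),
    Measurement({"block_size": 256}, 0.010),
])
print(data.fastest_config)  # {'block_size': 128}
```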

measurements: list[ejkernel.ops.execution.tuning.Measurement]
class ejkernel.ops.execution.tuning.Autotuner(warmup=1, iters=3)

Bases: Generic[Cfg]

Core autotuning engine for hyperparameter optimization.

This class provides the fundamental optimization algorithm that tests different configurations and measures their performance to find the optimal hyperparameter settings.

Type Parameters:

Cfg: Configuration type for hyperparameters

warmup

Number of warmup iterations before timing

iters

Number of timing iterations for measurement accuracy

autotune(make_fn, args, kwargs, candidates: Iterable[Cfg]) → AutotuneData[Cfg]

Optimize hyperparameters by testing candidate configurations.

Tests each candidate configuration by compiling and timing the function execution, then returns all measurements for analysis.

Parameters
  • make_fn – Factory function that creates a function given a config

  • args – Positional arguments for the function being optimized

  • kwargs – Keyword arguments for the function being optimized

  • candidates – Iterable of candidate configurations to test

Returns

AutotuneData containing all performance measurements
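A minimal sketch of the warmup-then-time loop this method describes, assuming each candidate is timed by wall clock and averaged over iters calls. The name autotune_sketch and the plain list it returns (instead of AutotuneData) are hypothetical simplifications; the real method may compile and time differently.

```python
import time
from dataclasses import dataclass
from typing import Any, Callable, Iterable


@dataclass(frozen=True)
class Measurement:
    cfg: Any
    seconds: float


def _call_and_wait(fn: Callable[..., Any], args: tuple, kwargs: dict) -> Any:
    out = fn(*args, **kwargs)
    block = getattr(out, "block_until_ready", None)  # present on jax.Array
    if callable(block):
        block()
    return out


def autotune_sketch(
    make_fn: Callable[[Any], Callable[..., Any]],
    args: tuple,
    kwargs: dict,
    candidates: Iterable[Any],
    warmup: int = 1,
    iters: int = 3,
) -> list[Measurement]:
    """Hypothetical stand-in for Autotuner.autotune."""
    measurements = []
    for cfg in candidates:
        fn = make_fn(cfg)  # build the function specialized for this configuration
        for _ in range(warmup):  # untimed calls absorb compilation and caching cost
            _call_and_wait(fn, args, kwargs)
        start = time.perf_counter()
        for _ in range(iters):
            _call_and_wait(fn, args, kwargs)
        measurements.append(Measurement(cfg, (time.perf_counter() - start) / iters))
    return measurements


def make_fn(n: int) -> Callable[[list], int]:
    # Hypothetical "kernel" whose cost grows with the config value n.
    return lambda xs: sum(xs[:n])


results = autotune_sketch(make_fn, (list(range(10_000)),), {}, candidates=[10, 1_000, 10_000])
best = min(results, key=lambda m: m.seconds).cfg
```

Keeping every measurement, rather than only the winner, matches the documented return value and leaves the final choice to the caller (for example, through fastest_config).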

class ejkernel.ops.execution.tuning.AutotuningResult(device: str, entries: tuple[ejkernel.ops.execution.tuning.Entry, ...])

Bases: object

Result container for device-specific optimization results.

Stores all optimized configurations for a specific device and provides context manager functionality for temporary cache overlays.

device

Device identifier these results apply to

Type

str

entries

Tuple of optimization entries (operation -> config mappings)

Type

tuple[ejkernel.ops.execution.tuning.Entry, …]

as_overlay()

Convert results to cache overlay mapping format.

Creates a dictionary mapping that can be used with the cache overlay system to temporarily apply these optimization results.

Returns

Dictionary mapping (device, op_id, call_key) tuples to configurations
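The mapping returned by as_overlay can be pictured as one dictionary comprehension over entries. The sketch below re-declares simplified Entry and AutotuningResult dataclasses with the documented fields; the device string and operation names are hypothetical, and the library's own method may do more.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Entry:
    op_id_v: str
    call_key: str
    cfg: Any


@dataclass(frozen=True)
class AutotuningResult:
    device: str
    entries: tuple[Entry, ...]

    def as_overlay(self) -> dict[tuple[str, str, str], Any]:
        # One key per entry: (device, op_id, call_key) -> chosen configuration.
        return {(self.device, e.op_id_v, e.call_key): e.cfg for e in self.entries}


result = AutotuningResult(
    device="gpu:0",  # hypothetical device identifier
    entries=(
        Entry("flash_attention", "key-a", {"block_q": 128}),  # hypothetical op ids and keys
        Entry("matmul", "key-b", {"block_size": 256}),
    ),
)
overlay = result.as_overlay()
print(overlay[("gpu:0", "matmul", "key-b")])  # {'block_size': 256}
```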

device: str
entries: tuple[ejkernel.ops.execution.tuning.Entry, ...]
class ejkernel.ops.execution.tuning.Entry(op_id_v: str, call_key: str, cfg: Any)

Bases: object

Cache entry for storing optimal configurations.

Represents a single cached optimization result with the operation identifier, call signature, and optimal configuration.

op_id_v

Operation identifier for the optimized function

Type

str

call_key

Hash key representing the function call signature

Type

str

cfg

The optimal configuration found for this operation

Type

Any

call_key: str
cfg: Any
op_id_v: str
class ejkernel.ops.execution.tuning.FNAutotuner(*, allow_fallback_timing: bool = True, profiling_samples: int = 5, must_find_profiler_fraction: float = 0.5, enable_detailed_logging: bool = False, find_optimal_layouts_automatically: bool = False, max_compilation_time_seconds: float = 300.0, timing_warmup_iterations: int = 2, timing_rounds: int = 5, calls_per_round: int = 3, cache_size_limit: int = 1000, profiler_prefix_filter: str = 'jit_', profiler_event_regex: str | None = None, profiler_min_duration_ns: float = 1000.0, profiler_max_events: int | None = 10000, profiler_verbose: bool = False)

Bases: object

Class-based JAX autotuner with profiler-first timing and Python fallback.

PREFIX_FN = 'autotune_fn_{}'
decorate(fn: Callable[..., Any], *, hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jaxlib._jax.Device | str | None = None, example_args: tuple[Any, ...] | None = None, example_kws: dict[str, Any] | None = None, sample_num: int = 9223372036854775807, event_filter_regex: str | None = None, timeout: float | None = None, cache_key: str | None = None)
tune(fn: Callable[..., Any], *, args: tuple[Any, ...], kwargs: dict[str, Any], hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jaxlib._jax.Device | str | None = None, example_args: tuple[Any, ...] | None = None, example_kws: dict[Any, Any] | None = None, sample_num: int = 9223372036854775807, event_filter_regex: str | None = None, timeout: float | None = None) → tuple[collections.abc.Callable[..., Any], dict[str, Any], list[tuple[int, ejkernel.ops.execution.tuning.TimingResult]]]

Return (parameterized_fn, optimal_hyperparams, timing_results_sorted), where timing_results_sorted is a list of (int, TimingResult) pairs.
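A sketch of one plausible way to expand hyperparams into candidates, cap them at sample_num, and produce (int, TimingResult) pairs. Assumptions not confirmed by this page: candidates are the Cartesian product of the value lists, the cap keeps the first sample_num combinations, the int is the candidate index, and the list is sorted by t_mean. The helper names expand_grid and rank are hypothetical, and the timings in the usage lines are made-up illustration values.

```python
import itertools
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class TimingResult:
    hyperparams: dict
    t_mean: float
    t_std: float


def expand_grid(hyperparams: dict[str, list[Any]], sample_num: int) -> list[dict[str, Any]]:
    names = list(hyperparams)
    combos = itertools.product(*(hyperparams[n] for n in names))
    # Assumption: keep the first `sample_num` combinations (the library may sample differently).
    return [dict(zip(names, values)) for values in itertools.islice(combos, sample_num)]


def rank(results: list[TimingResult]) -> list[tuple[int, TimingResult]]:
    # Pair each result with its candidate index, then order fastest first by mean time.
    return sorted(enumerate(results), key=lambda pair: pair[1].t_mean)


grid = expand_grid({"block_size": [64, 128], "algorithm": ["a", "b"]}, sample_num=3)
# [{'block_size': 64, 'algorithm': 'a'}, {'block_size': 64, 'algorithm': 'b'}, {'block_size': 128, 'algorithm': 'a'}]

timings = [TimingResult(g, t, 0.0) for g, t in zip(grid, [0.9, 0.5, 0.7])]  # illustration values
ranked = rank(timings)
best_index, best = ranked[0]
print(best_index, best.hyperparams)  # 1 {'block_size': 64, 'algorithm': 'b'}
```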

class ejkernel.ops.execution.tuning.Measurement(cfg: Any, seconds: float)

Bases: object

Container for a single performance measurement.

Stores the configuration and corresponding execution time for a single hyperparameter combination during optimization.

cfg

The hyperparameter configuration that was tested

Type

Any

seconds

Execution time in seconds for this configuration

Type

float

cfg: Any
seconds: float
class ejkernel.ops.execution.tuning.TimingResult(hyperparams: dict[Any, Any], t_mean: float, t_std: float)

Bases: object

Statistical timing result with mean and standard deviation for one hyperparameter configuration.

hyperparams: dict[Any, Any]
t_mean: float
t_std: float
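A sketch of how timing_warmup_iterations, timing_rounds, and calls_per_round could combine into t_mean and t_std. The helper names are hypothetical, and the outlier rule (drop the fastest and slowest round) is an illustrative assumption; this page states only that outliers are removed, not how.

```python
import statistics
import time
from typing import Any, Callable


def time_rounds(fn: Callable[[], Any], warmup: int = 2, rounds: int = 5, calls_per_round: int = 3) -> list[float]:
    """Per-call seconds for each round (defaults mirror the autotuner's timing parameters)."""
    for _ in range(warmup):  # untimed warmup calls
        fn()
    samples = []
    for _ in range(rounds):
        start = time.perf_counter()
        for _ in range(calls_per_round):
            fn()
        samples.append((time.perf_counter() - start) / calls_per_round)
    return samples


def summarize(samples: list[float]) -> tuple[float, float]:
    """Return (t_mean, t_std) after an illustrative outlier rule."""
    # Assumed rule: drop the fastest and slowest round when at least 3 rounds exist.
    kept = sorted(samples)[1:-1] if len(samples) >= 3 else list(samples)
    t_mean = statistics.fmean(kept)
    t_std = statistics.stdev(kept) if len(kept) > 1 else 0.0
    return t_mean, t_std


t_mean, t_std = summarize([1.0, 1.1, 0.9, 5.0, 1.0])  # the 5.0 outlier is discarded
live_mean, live_std = summarize(time_rounds(lambda: sum(range(10_000))))
```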
ejkernel.ops.execution.tuning.autotune(fn: collections.abc.Callable[..., Any] | None = None, /, *, allow_fallback_timing: bool = True, profiling_samples: int = 5, must_find_profiler_fraction: float = 0.5, enable_detailed_logging: bool = False, find_optimal_layouts_automatically: bool = False, max_compilation_time_seconds: float = 300.0, timing_warmup_iterations: int = 2, timing_rounds: int = 5, calls_per_round: int = 3, cache_size_limit: int = 1000, profiler_prefix_filter: str = 'jit_', profiler_event_regex: str | None = None, profiler_min_duration_ns: float = 1000.0, profiler_max_events: int | None = 10000, profiler_verbose: bool = False, hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jaxlib._jax.Device | str | None = None, example_args: tuple[Any, ...] | None = None, example_kws: dict[str, Any] | None = None, sample_num: int = 9223372036854775807, event_filter_regex: str | None = None, timeout: float | None = None, cache_key: str | None = None)

Decorator that automatically tunes the hyperparameters of a JAX function.

A flexible decorator that optimizes a JAX function by benchmarking different hyperparameter configurations. It uses profiler-based timing for accuracy, falls back to Python-level timing when needed, compiles and tests configurations in parallel, and caches the optimal configuration for each input signature.

This decorator can be used in three different ways:
  1. @autotune(...) -> decorator factory called with custom parameters

  2. @autotune -> plain decorator with default parameters

  3. autotune(fn, kw=…) -> direct function call returning wrapped function
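The positional-only fn=None in the signature, together with the documented TypeError for a non-callable fn, points to the common "optional function argument" decorator pattern. A minimal sketch, using a hypothetical decorator dual_use with a hypothetical label option standing in for autotune's keyword parameters:

```python
import functools
from typing import Any, Callable, Optional


def dual_use(fn: Optional[Callable[..., Any]] = None, /, *, label: str = "default"):
    """Hypothetical decorator usable as @dual_use, @dual_use(...), or dual_use(fn, ...)."""

    def wrap(f: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # autotune would tune and cache here; this sketch just forwards the call.
            return f(*args, **kwargs)

        wrapper.label = label  # stand-in for the decorator's keyword configuration
        return wrapper

    if fn is None:
        # @dual_use(...) form: return the decorator itself.
        return wrap
    if not callable(fn):
        raise TypeError("fn must be callable")
    # @dual_use or dual_use(fn, ...) form: wrap immediately.
    return wrap(fn)


@dual_use
def plain():
    return 1


@dual_use(label="custom")
def configured():
    return 2


direct = dual_use(lambda: 3, label="direct")
print(plain.label, configured.label, direct.label)  # default custom direct
```

Making fn positional-only keeps it from colliding with any keyword option and lets a single name serve all three call styles.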

The decorator automatically optimizes hyperparameters on the first function call and caches results for subsequent calls with the same input signature.
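Caching by input signature can be sketched as a key built from argument shapes and dtypes, plus a lock-protected, size-limited store. The helper names, the repr-based key for non-array values, and least-recently-used eviction are assumptions for illustration; this page states only that keys derive from input signatures and that the cache is thread-safe with a size limit (cache_size_limit).

```python
import threading
from collections import OrderedDict
from typing import Any, Hashable, Optional


def signature_key(args: tuple, kwargs: dict) -> Hashable:
    """Hypothetical key: arrays by (shape, dtype), other values by type and repr."""

    def describe(x: Any) -> Hashable:
        shape = getattr(x, "shape", None)
        if shape is not None:
            return ("array", tuple(shape), str(getattr(x, "dtype", None)))
        return ("value", type(x).__name__, repr(x))

    positional = tuple(describe(a) for a in args)
    keyword = tuple(sorted((k, describe(v)) for k, v in kwargs.items()))
    return positional, keyword


class BoundedCache:
    """Thread-safe mapping with least-recently-used eviction (assumed policy)."""

    def __init__(self, limit: int = 1000) -> None:
        self._limit = limit
        self._data: "OrderedDict[Hashable, Any]" = OrderedDict()
        self._lock = threading.Lock()

    def get(self, key: Hashable) -> Optional[Any]:
        with self._lock:
            if key not in self._data:
                return None
            self._data.move_to_end(key)  # mark as most recently used
            return self._data[key]

    def put(self, key: Hashable, value: Any) -> None:
        with self._lock:
            self._data[key] = value
            self._data.move_to_end(key)
            if len(self._data) > self._limit:
                self._data.popitem(last=False)  # evict the least recently used entry


class FakeArray:
    """Stand-in for a jax.Array: only shape and dtype matter for the key."""

    def __init__(self, shape: tuple, dtype: str = "float32") -> None:
        self.shape = shape
        self.dtype = dtype


cache = BoundedCache(limit=2)
key = signature_key((FakeArray((8, 8)), FakeArray((8, 8))), {})
cache.put(key, {"block_size": 128})
# A different array with the same shape and dtype hits the same entry.
print(cache.get(signature_key((FakeArray((8, 8)), FakeArray((8, 8))), {})))  # {'block_size': 128}
```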

Key Features:
  • Profiler-based timing with Python-level fallback for maximum accuracy

  • Parallel compilation and testing of hyperparameter configurations

  • Automatic optimal memory layout discovery for distributed computation

  • Thread-safe caching with configurable size limits

  • Statistical timing analysis with outlier removal

  • Comprehensive error handling and detailed logging options

  • Support for both concrete and abstract input specifications

Parameters
  • fn – Function to autotune (when used as direct call)

  • allow_fallback_timing – Enable Python timing fallback if profiler fails

  • profiling_samples – Number of profiling iterations for statistical accuracy

  • must_find_profiler_fraction – Minimum fraction of configurations that must yield profiler timings (0.0–1.0)

  • enable_detailed_logging – Enable detailed error logging with full tracebacks

  • timing_warmup_iterations – Warmup iterations before timing measurements

  • timing_rounds – Number of timing rounds for statistical analysis

  • calls_per_round – Function calls per timing round

  • profiler_prefix_filter – Event name prefix filter for profiler events (default: 'jit_')

  • profiler_event_regex – Optional regex filter for profiler events

  • profiler_min_duration_ns – Minimum event duration for profiler inclusion (nanoseconds)

  • profiler_max_events – Maximum events per profile to prevent memory issues

  • profiler_verbose – Enable verbose profiler output

  • find_optimal_layouts_automatically – Auto-discover optimal memory layouts

  • max_compilation_time_seconds – Maximum compilation time per configuration

  • cache_size_limit – Maximum cached optimization results

  • hyperparams – Dictionary mapping parameter names to candidate value lists

  • max_workers – Maximum parallel workers for optimization

  • sample_num – Maximum hyperparameter combinations to test

  • in_shardings – Input sharding specifications for distributed computation

  • out_shardings – Output sharding specifications for distributed computation

  • device – Target device or device string for computation

  • example_args – Concrete example arguments for abstract input shapes

  • example_kws – Concrete example kwargs for abstract input values

  • event_filter_regex – Optional regex filter for profiler events (alias for profiler_event_regex)

  • timeout – Optional compilation timeout override

  • cache_key – Optional custom cache key prefix for disambiguation

Returns

When used as a decorator: the decorated function, which performs automatic hyperparameter optimization. When used as a direct call: a wrapper function that performs optimization and then execution.

After its first execution, the returned function exposes additional attributes:
  • timing_results – List of all timing measurements from optimization

  • optimal_hyperparams – Dictionary of optimal parameter values

Raises
  • TypeError – If fn is not callable (when used as direct call)

  • ValueError – If hyperparameter specifications are invalid

  • RuntimeError – If all hyperparameter configurations fail to compile

Examples

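Basic usage with hyperparameter optimization:

```python
import jax.numpy as jnp
from ejkernel.ops.execution.tuning import autotune

@autotune(hyperparams={'block_size': [64, 128, 256, 512]})
def matrix_multiply(a, b, block_size=128):
    # block_size is unused in this toy body; a real kernel would use it for tiling
    return jnp.dot(a, b)

x = y = jnp.ones((1024, 1024))
result = matrix_multiply(x, y)
print(f"Optimal config: {matrix_multiply.optimal_hyperparams}")
```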

Advanced usage with custom timing and sharding:

```python
import jax

@autotune(
    hyperparams={
        'chunk_size': [32, 64, 128],
        'algorithm': ['parallel', 'sequential'],
    },
    profiling_samples=10,
    max_workers=16,
    in_shardings=jax.sharding.PartitionSpec('data', None),
    enable_detailed_logging=True,
)
def compute_function(data, chunk_size=64, algorithm='parallel'):
    processed_data = data  # placeholder for the chunked computation
    return processed_data
```

Direct call usage:

```python
optimized_fn = autotune(
    my_function,
    hyperparams={'param': [1, 2, 3]},
    profiling_samples=5,
)
result = optimized_fn(input_data)
```

Using with abstract input specifications:

```python
import jax.numpy as jnp

@autotune(
    hyperparams={'tile_size': [16, 32, 64]},
    example_args=(jnp.zeros((1000, 1000)),),
    device='gpu',
)
def process_matrix(matrix, tile_size=32):
    result = matrix  # placeholder for the tiled computation
    return result
```

Note

  • First call triggers optimization and may take longer

  • Subsequent calls with same input signature use cached optimal parameters

  • Profiler-based timing requires a TensorFlow backend; when it is unavailable, timing falls back to Python

  • For distributed computation, ensure proper mesh configuration before use

  • Cache keys are based on input signatures; use cache_key for manual disambiguation
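How allow_fallback_timing and must_find_profiler_fraction might interact can be sketched as a small decision function. The function name pick_timing_source and the exact rule (compare the fraction of configurations the profiler timed against the threshold, then fall back or raise) are assumptions inferred from the parameter descriptions, not the library's code.

```python
def pick_timing_source(
    n_configs: int,
    n_with_profiler: int,
    must_find_profiler_fraction: float = 0.5,
    allow_fallback_timing: bool = True,
) -> str:
    """Hypothetical rule: use profiler timings only if enough configs produced them."""
    if n_configs <= 0:
        raise ValueError("need at least one configuration")
    fraction = n_with_profiler / n_configs
    if fraction >= must_find_profiler_fraction:
        return "profiler"
    if allow_fallback_timing:
        # Too few profiler results: time every configuration in Python instead.
        return "python"
    raise RuntimeError(f"profiler timed only {fraction:.0%} of configurations")


print(pick_timing_source(10, 7))  # profiler
print(pick_timing_source(10, 3))  # python
```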

ejkernel.ops.execution.tuning.autotune_recorded(hyperparameter_selector, *, show_progress=False, repetition_count=1)

Record and replay optimal hyperparameters for functions.

This function provides caching and replay functionality for previously optimized hyperparameters, avoiding redundant tuning operations.

Parameters
  • hyperparameter_selector – Function to select hyperparameters from recorded data

  • show_progress – Whether to display progress bars during optimization

  • repetition_count – Number of times to repeat the optimization process

Returns

Decorated function with recorded hyperparameter optimization

ejkernel.ops.execution.tuning.benchmark(fn, *args, warmup=1, iters=5, **kwargs) → float

Benchmark function execution time with JAX compilation.

Compiles the function with JAX and measures its execution time over multiple iterations, handling both static and dynamic arguments.

Parameters
  • fn – Function to benchmark

  • *args – Positional arguments for the function

  • warmup – Number of warmup iterations before timing

  • iters – Number of timing iterations for measurement

  • **kwargs – Keyword arguments for the function

Returns

Average execution time per iteration in seconds
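A stdlib-only sketch of the warmup/iters averaging that benchmark performs. Assumptions: the hypothetical benchmark_sketch skips the jax.jit compilation step (pass an already jitted function to include it), and it waits for results through a block_until_ready method when the output has one, as jax.Array does.

```python
import time
from typing import Any, Callable


def benchmark_sketch(fn: Callable[..., Any], *args: Any, warmup: int = 1, iters: int = 5, **kwargs: Any) -> float:
    """Return the mean wall-clock seconds per call (hypothetical stand-in for benchmark)."""

    def run() -> Any:
        out = fn(*args, **kwargs)
        # jax.Array results are computed asynchronously; wait before stopping the clock.
        # (For pytree outputs, jax.block_until_ready(out) waits on every leaf.)
        block = getattr(out, "block_until_ready", None)
        if callable(block):
            block()
        return out

    for _ in range(warmup):  # untimed: in JAX, the first call of a jitted fn compiles it
        run()
    start = time.perf_counter()
    for _ in range(iters):
        run()
    return (time.perf_counter() - start) / iters


print(f"{benchmark_sketch(sorted, list(range(100_000, 0, -1)), iters=10):.6f} s per call")
```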