Source code for ejkernel.ops.execution.tuning

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Autotuning and benchmarking utilities for JAX kernel optimization.

This module provides comprehensive tools for automatic hyperparameter optimization
of JAX functions through systematic benchmarking and configuration testing.

Key Features:
    - Profiler-based timing with Python-level fallback for accuracy
    - Parallel compilation and testing of hyperparameter configurations
    - Statistical timing analysis with outlier removal
    - Thread-safe caching with configurable size limits
    - Support for distributed computation with sharding specifications

Classes:
    Measurement: Container for a single performance measurement
    AutotuneData: Container for all optimization measurements and results
    Autotuner: Core autotuning engine for hyperparameter optimization
    Entry: Cache entry for storing optimal configurations
    AutotuningResult: Device-specific optimization results container
    TimingResult: Statistical timing result with mean and standard deviation
    FNAutotuner: Advanced class-based autotuner with profiler integration

Functions:
    autotune: Decorator for automatic hyperparameter optimization
    autotune_recorded: Autotune all recorded invocations
    benchmark: Simple function benchmarking utility

Example:
    >>> @autotune(hyperparams={'block_size': [64, 128, 256]})
    ... def matrix_op(x, y, block_size=128):
    ...     return compute(x, y, block_size)
    >>>
    >>> result = matrix_op(x, y)
    >>> print(matrix_op.optimal_hyperparams)
"""

from __future__ import annotations

import contextlib
import itertools
import logging
import os
import random as pyrandom
import re
import threading
import time
from collections.abc import Callable, Iterable
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from functools import partial, wraps
from typing import Any, Generic, TypeVar

import jax
import numpy as np
from jax import numpy as jnp
from jax import random
from jax.interpreters import pxla
from jax.sharding import PartitionSpec, Sharding, SingleDeviceSharding

from ejkernel.loggings import get_logger

from ..config.cache import overlay_cache
from ..utils.fingerprint import device_fingerprint
from .profiler import Profiler

autotune_logger = get_logger("ejkernel.autotune", "WARNING")


Cfg = TypeVar("Cfg")


[docs]@dataclass class Measurement: """Container for a single performance measurement. Stores the configuration and corresponding execution time for a single hyperparameter combination during optimization. Attributes: cfg: The hyperparameter configuration that was tested seconds: Execution time in seconds for this configuration """ cfg: Any seconds: float
[docs]@dataclass class AutotuneData(Generic[Cfg]): """Container for all optimization measurements and results. Stores performance measurements for all tested hyperparameter configurations and provides utilities to analyze the results. Type Parameters: Cfg: Configuration type (e.g., dict, dataclass, etc.) Attributes: measurements: List of all performance measurements taken """ measurements: list[Measurement] @property def fastest_config(self) -> Cfg: """Get the configuration with the fastest execution time. Finds the measurement with the minimum execution time among all recorded measurements and returns its configuration. Returns: The configuration that achieved the lowest execution time Raises: ValueError: If no measurements are available """ if not self.measurements: raise ValueError("No measurements available to determine fastest config") return min(self.measurements, key=lambda m: m.seconds).cfg
[docs]class Autotuner(Generic[Cfg]): """Core autotuning engine for hyperparameter optimization. This class provides the fundamental optimization algorithm that tests different configurations and measures their performance to find the optimal hyperparameter settings. Type Parameters: Cfg: Configuration type for hyperparameters Attributes: warmup: Number of warmup iterations before timing iters: Number of timing iterations for measurement accuracy """ def __init__(self, warmup=1, iters=3): """Initialize the autotuner with timing parameters. Args: warmup: Number of warmup calls to stabilize performance before timing iters: Number of timing iterations for statistical accuracy """ self.warmup, self.iters = warmup, iters
[docs] def autotune(self, make_fn, args, kwargs, candidates: Iterable[Cfg]) -> AutotuneData[Cfg]: """Optimize hyperparameters by testing candidate configurations. Tests each candidate configuration by compiling and timing the function execution, then returns all measurements for analysis. Args: make_fn: Factory function that creates a function given a config args: Positional arguments for the function being optimized kwargs: Keyword arguments for the function being optimized candidates: Iterable of candidate configurations to test Returns: AutotuneData containing all performance measurements """ measures = [] for cfg in candidates: try: fn = make_fn(cfg) c = jax.jit(fn).lower(*args, **kwargs).compile() for _ in range(self.warmup): _ = c(*args, **kwargs).block_until_ready() t0 = time.perf_counter() for _ in range(self.iters): _ = c(*args, **kwargs).block_until_ready() dt = (time.perf_counter() - t0) / self.iters measures.append(Measurement(cfg, dt)) except Exception as e: autotune_logger.warning(f"Configuration {cfg} failed: {e}") measures.append(Measurement(cfg, float("inf"))) if not measures or all(m.seconds == float("inf") for m in measures): autotune_logger.warning("All candidate configurations failed to execute; returning empty measurements.") return AutotuneData(measures)
[docs]@dataclass(frozen=True) class Entry: """Cache entry for storing optimal configurations. Represents a single cached optimization result with the operation identifier, call signature, and optimal configuration. Attributes: op_id_v: Operation identifier for the optimized function call_key: Hash key representing the function call signature cfg: The optimal configuration found for this operation """ op_id_v: str call_key: str cfg: Any
[docs]@dataclass(frozen=True) class AutotuningResult: """Result container for device-specific optimization results. Stores all optimized configurations for a specific device and provides context manager functionality for temporary cache overlays. Attributes: device: Device identifier these results apply to entries: Tuple of optimization entries (operation -> config mappings) """ device: str entries: tuple[Entry, ...]
[docs] def as_overlay(self): """Convert results to cache overlay mapping format. Creates a dictionary mapping that can be used with the cache overlay system to temporarily apply these optimization results. Returns: Dictionary mapping (device, op_id, call_key) tuples to configurations """ mapping = {(self.device, e.op_id_v, e.call_key): e.cfg for e in self.entries} return mapping
def __enter__(self): """Enter context manager to apply optimization results as cache overlay. Activates the cache overlay with the optimization results, temporarily overriding any existing cache entries with the optimized configurations. Returns: Self for use in with statements """ self._ctx = overlay_cache(self.as_overlay()) self._ctx.__enter__() return self def __exit__(self, exc_type, exc, tb): """Exit context manager and restore previous cache state. Deactivates the cache overlay and restores the previous cache state, ensuring clean cleanup even if exceptions occur. Args: exc_type: Exception type (if any) exc: Exception instance (if any) tb: Exception traceback (if any) """ self._ctx.__exit__(exc_type, exc, tb) delattr(self, "_ctx")
[docs]def autotune_recorded(hyperparameter_selector, *, show_progress=False, repetition_count=1): """Record and replay optimal hyperparameters for functions. This function provides caching and replay functionality for previously optimized hyperparameters, avoiding redundant tuning operations. Args: hyperparameter_selector: Function to select hyperparameters from recorded data show_progress: Whether to display progress bars during optimization repetition_count: Number of times to repeat the optimization process Returns: Decorated function with recorded hyperparameter optimization """ """Autotune all recorded invocations for the current device.""" from ..registry import get_invocations dev = device_fingerprint() invs = get_invocations(dev) entries = [] for op_id_v, d in invs.items(): for call_key, (kernel, args, kwargs) in d.items(): inv_args, inv_kwargs = kernel.prepare(*args, **kwargs) static_fun_kwargs = {k: v for k, v in inv_kwargs.items() if callable(v)} dyn_kwargs = {k: v for k, v in inv_kwargs.items() if k not in static_fun_kwargs} tmp_inv = type( "Tmp", (), dict( op_id=kernel.op_id, args=inv_args, kwargs=dyn_kwargs, batch_axes=None, override_cfg=None, stamp=False ), )() candidates = tuple(kernel.candidate_cfgs(tmp_inv)) def mk(c, _run=kernel.run, _static=static_fun_kwargs): return partial(_run, cfg=c, **_static) best_cfg, best_t = None, float("inf") for c in candidates: t = benchmark(mk(c), *inv_args, **dyn_kwargs) if t < best_t: best_cfg, best_t = c, t hyperparameter_selector.cache.put(dev, op_id_v, call_key, best_cfg) if hyperparameter_selector.persistent and hyperparameter_selector.persist_autotune: hyperparameter_selector.persistent.put(dev, op_id_v, call_key, best_cfg) entries.append(Entry(op_id_v, call_key, best_cfg)) return AutotuningResult(dev, tuple(entries))
def _split_static_callable_kwargs(kwargs): """Split keyword arguments into static and dynamic components. Separates callable arguments (static) from regular arguments (dynamic) for proper JAX compilation and execution. Args: kwargs: Dictionary of keyword arguments to split Returns: Tuple of (static_kwargs, dynamic_kwargs) """ static = {k: v for k, v in kwargs.items() if callable(v)} return static, {k: v for k, v in kwargs.items() if k not in static}
[docs]def benchmark(fn, *args, warmup=1, iters=5, **kwargs) -> float: """Benchmark function execution time with JAX compilation. Compiles the function with JAX and measures its execution time over multiple iterations, handling both static and dynamic arguments. Args: fn: Function to benchmark *args: Positional arguments for the function warmup: Number of warmup iterations before timing iters: Number of timing iterations for measurement **kwargs: Keyword arguments for the function Returns: Average execution time per iteration in seconds """ static, dyn = _split_static_callable_kwargs(kwargs) if static: def fn_wrapped(*a, _fn=fn, _static=static, **k): return _fn(*a, **(k | _static)) c = jax.jit(fn_wrapped).lower(*args, **dyn).compile() for _ in range(warmup): _ = c(*args, **dyn).block_until_ready() t0 = time.perf_counter() for _ in range(iters): _ = c(*args, **dyn).block_until_ready() return (time.perf_counter() - t0) / iters else: c = jax.jit(fn).lower(*args, **kwargs).compile() for _ in range(warmup): _ = c(*args, **kwargs).block_until_ready() t0 = time.perf_counter() for _ in range(iters): _ = c(*args, **kwargs).block_until_ready() return (time.perf_counter() - t0) / iters
[docs]@dataclass class TimingResult: hyperparams: dict[Any, Any] t_mean: float t_std: float
def _get_global_mesh(): env = pxla.thread_resources.env mesh = env.physical_mesh return None if mesh.empty else mesh def _get_default_device(): if jax.config.values["jax_default_device"] is not None: return jax.config.values["jax_default_device"] return jax.devices()[0] @contextlib.contextmanager def _suppress_stdout_stderr(): devnull = open(os.devnull, "w") stdout_fd, stderr_fd = os.dup(1), os.dup(2) try: os.dup2(devnull.fileno(), 1) os.dup2(devnull.fileno(), 2) yield finally: try: os.dup2(stdout_fd, 1) os.dup2(stderr_fd, 2) finally: os.close(stdout_fd) os.close(stderr_fd) devnull.close() def _normalize_sharding( arg: jax.Array | np.ndarray | Any, sharding_or_spec: PartitionSpec | Sharding | None, default_device: jax.Device, # type: ignore ): """Normalize sharding specification to a concrete Sharding object. Converts various sharding specifications (PartitionSpec, Sharding, None) into concrete Sharding objects that can be used for array placement. Handles both global mesh and single-device scenarios. Args: arg: Array or array-like object to be sharded sharding_or_spec: Sharding specification (PartitionSpec, Sharding, or None) default_device: Default device to use for single-device sharding Returns: Concrete Sharding object, or None for non-array arguments Raises: ValueError: If PartitionSpec is provided but no global mesh is defined """ if not isinstance(arg, jax.Array | np.ndarray): return None if isinstance(sharding_or_spec, Sharding): return sharding_or_spec global_mesh = _get_global_mesh() if isinstance(sharding_or_spec, PartitionSpec) and global_mesh is not None: return jax.NamedSharding(global_mesh, sharding_or_spec) elif isinstance(sharding_or_spec, PartitionSpec) and global_mesh is None: raise ValueError("If specifying shardings via ParitionSpec, a global mesh must be defined") else: return SingleDeviceSharding(default_device) def _ensure_dtype(dt): try: return dt.dtype if isinstance(dt, jax.Array | np.ndarray) else dt except Exception: return dt @partial(jax.jit, static_argnames=("sds", "sharding")) def _get_random_value(sds, sharding=None): if hasattr(sds, "shape") and hasattr(sds, "dtype"): dt = _ensure_dtype(sds.dtype) if jnp.issubdtype(dt, jnp.floating): return jax.jit(lambda key: random.normal(key, sds.shape, dt), out_shardings=sharding)(random.key(0)) elif jnp.issubdtype(dt, jnp.integer): return jax.jit(lambda: jnp.zeros(sds.shape, dt), out_shardings=sharding)() else: raise ValueError(f"Unsupported dtype {dt}") else: return sds def _try_hash_input(args, kws, must_be_concrete: bool = True): flat_vals, struct = jax.tree.flatten((args, kws)) all_concrete = all(jax.core.is_concrete(x) for x in flat_vals if isinstance(x, jax.Array)) if not all_concrete and must_be_concrete: return None def _get_sharding(x): try: return x.sharding except AttributeError: return jax.typeof(x).sharding def array_to_hashable(x): return x if not isinstance(x, jax.Array) else hash((jax.typeof(x), _get_sharding(x))) try: return hash((struct, tuple(array_to_hashable(x) for x in flat_vals))) except Exception: return None
[docs]class FNAutotuner: """Class-based JAX autotuner with profiler-first timing and Python fallback.""" PREFIX_FN = "autotune_fn_{}" def __init__( self, *, allow_fallback_timing: bool = True, profiling_samples: int = 5, must_find_profiler_fraction: float = 0.5, enable_detailed_logging: bool = False, find_optimal_layouts_automatically: bool = False, max_compilation_time_seconds: float = 300.0, timing_warmup_iterations: int = 2, timing_rounds: int = 5, calls_per_round: int = 3, cache_size_limit: int = 1000, profiler_prefix_filter: str = "jit_", profiler_event_regex: str | None = None, profiler_min_duration_ns: float = 1000.0, profiler_max_events: int | None = 10000, profiler_verbose: bool = False, ): self.allow_fallback_timing = allow_fallback_timing self.profiling_samples = profiling_samples self.must_find_profiler_fraction = must_find_profiler_fraction self.enable_detailed_logging = enable_detailed_logging self.find_optimal_layouts_automatically = find_optimal_layouts_automatically self.max_compilation_time_seconds = max_compilation_time_seconds self.timing_warmup_iterations = timing_warmup_iterations self.timing_rounds = timing_rounds self.calls_per_round = calls_per_round self.cache_size_limit = cache_size_limit self.profiler = Profiler( prefix_filter=profiler_prefix_filter, event_filter_regex=profiler_event_regex, min_duration_ns=profiler_min_duration_ns, max_events_per_profile=profiler_max_events, verbose=profiler_verbose or (autotune_logger._level <= logging.INFO), ) self._cache_lock = threading.Lock() @staticmethod def _calculate_timing_score(tr: TimingResult) -> float: """Calculate a composite timing score for ranking configurations. Combines mean execution time with standard deviation to create a single score that balances speed with consistency. Penalizes configurations with high variance even if they have good mean performance. The weighting factor of 0.1 for standard deviation provides a reasonable balance between speed and stability. Args: tr: TimingResult containing mean and standard deviation Returns: Composite score (lower is better) = mean + 0.1 * std """ return tr.t_mean + 0.1 * tr.t_std def _create_parameterized_function( self, target_function: Callable[..., Any], hyperparameter_values: dict[str, Any], output_shardings: Any = None, function_id: int = 0, ) -> Callable[..., Any]: """Create a JIT-compiled function with embedded hyperparameters. Wraps the target function with specific hyperparameter values and compiles it using JAX JIT. Sets a unique name for profiler identification and optionally configures output sharding for distributed computation. The function name follows the pattern 'autotune_fn_{function_id}' for easy identification in profiler output. Args: target_function: Original function to parameterize hyperparameter_values: Dictionary of hyperparameter names to values output_shardings: Optional sharding specification for outputs function_id: Unique identifier for profiler tracking Returns: JIT-compiled function with embedded hyperparameters and unique name """ jax_compiler = partial(jax.jit, out_shardings=output_shardings) def parameterized_function(*function_args, **function_kwargs): combined_kwargs = dict(function_kwargs, **hyperparameter_values) return target_function(*function_args, **combined_kwargs) function_name = self.PREFIX_FN.format(function_id) parameterized_function.__name__ = function_name parameterized_function.__qualname__ = function_name return jax_compiler(parameterized_function) def _try_call( self, fn: Callable[..., Any], resolved_args, resolved_kwargs, compile_only: bool = False, compute_layouts: bool = False, optimal_formats: Any | None = None, timeout: float | None = None, ): """Safely compile or execute a function with comprehensive error handling. Attempts to compile or execute a function with the provided arguments, handling various failure modes gracefully. Supports layout optimization and optimal device placement for distributed computation. Args: fn: Function to compile or execute resolved_args: Positional arguments for the function resolved_kwargs: Keyword arguments for the function compile_only: If True, only compile without execution compute_layouts: If True, compute optimal memory layouts optimal_formats: Optional pre-computed optimal device formats timeout: Optional compilation timeout (currently unused) Returns: Tuple of (success_bool, error_message, optimal_input_formats) where success_bool indicates if operation succeeded, error_message contains failure details if any, and optimal_input_formats contains discovered layouts if computed """ optimal_input_formats = None try: if compile_only: if compute_layouts: def to_shape(x): return ( jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding) if isinstance(x, jax.Array) else x ) (argument_shapes, keyword_shapes) = jax.tree.map(to_shape, (resolved_args, resolved_kwargs)) try: compiled_function = jax.jit(fn).lower(*argument_shapes, **keyword_shapes).compile() optimal_input_formats = getattr(compiled_function, "input_formats", None) except Exception as compilation_error: autotune_logger.warning( f"Layout optimization failed during compilation: " f"{compilation_error.__class__.__name__}: {compilation_error}" ) optimal_input_formats = None else: _ = jax.jit(fn).lower(*resolved_args, **resolved_kwargs).compile() else: if optimal_formats is not None: def place_array_on_optimal_device(array_data, target_format): return ( jax.device_put(array_data, target_format) if isinstance(array_data, jax.Array) else array_data ) try: (optimally_placed_args, optimally_placed_kwargs) = jax.tree.map( place_array_on_optimal_device, (resolved_args, resolved_kwargs), optimal_formats ) _ = jax.block_until_ready(fn(*optimally_placed_args, **optimally_placed_kwargs)) except Exception: autotune_logger.warning( "Failed to place arrays on optimal devices - falling back to original argument placement" ) _ = jax.block_until_ready(fn(*resolved_args, **resolved_kwargs)) else: _ = jax.block_until_ready(fn(*resolved_args, **resolved_kwargs)) return True, None, optimal_input_formats except Exception as e: msg = f"{type(e).__name__}: {e!s}" if self.enable_detailed_logging: import traceback msg = traceback.format_exc() return False, msg, optimal_input_formats def _time_fn(self, target_function: Callable[[], None]) -> tuple[float, float]: """Perform high-precision Python-level timing with statistical analysis. Executes the target function multiple times with warmup iterations, measures execution times, and provides statistical analysis with outlier removal for reliable performance measurements. Uses block_until_ready() to ensure accurate timing of JAX computations. Args: target_function: Zero-argument function to time (should return JAX arrays) Returns: Tuple of (mean_time, std_time) in seconds after outlier removal, or (inf, inf) if all timing attempts failed """ def _execute_and_block(): return jax.block_until_ready(target_function()) for _ in range(self.timing_warmup_iterations): try: _execute_and_block() except Exception as warmup_error: autotune_logger.warning(f"Warmup failed: {warmup_error.__class__.__name__}: {warmup_error}") return float("inf"), float("inf") times = [] for _ in range(self.timing_rounds): t0 = time.perf_counter() try: for _ in range(self.calls_per_round): _execute_and_block() times.append(time.perf_counter() - t0) except Exception as timing_error: autotune_logger.warning(f"Timing round failed: {timing_error.__class__.__name__}: {timing_error}") times.append(float("inf")) valid = [t for t in times if not np.isinf(t)] if not valid: return float("inf"), float("inf") arr = np.array(valid) / self.calls_per_round arr = np.sort(arr) if len(arr) > 2: arr = arr[1:-1] return float(np.mean(arr)), float(np.std(arr))
[docs] def tune( self, fn: Callable[..., Any], *, args: tuple[Any, ...], kwargs: dict[str, Any], hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jax.Device | str | None = None, # type: ignore example_args: tuple[Any, ...] | None = None, example_kws: dict[Any, Any] | None = None, sample_num: int = 2**63 - 1, event_filter_regex: str | None = None, timeout: float | None = None, ) -> tuple[Callable[..., Any], dict[str, Any], list[tuple[int, TimingResult]]]: """Return (parameterized_fn, optimal_hyperparams, timing_results_sorted).""" if not callable(fn): raise TypeError("fn must be callable") if max_workers <= 0: raise ValueError("max_workers must be positive") if sample_num < 0: raise ValueError("sample_num must be non-negative") if event_filter_regex is not None: self.profiler._pattern = re.compile(event_filter_regex) def _extract_array_type(x): return x if not isinstance(x, jax.Array) else jax.typeof(x) if len(args) == 0 or all(x is None or jax.core.is_concrete(x) for x in jax.tree.leaves(args)): resolved_args = args elif example_args is not None: if in_shardings is not None or device is not None: raise ValueError( "Cannot combine example_args with explicit in_shardings or device configuration. " "Example arguments should already be properly sharded and placed." ) resolved_args = example_args else: resolved_device = device if isinstance(device, jax.Device) else _get_default_device() if isinstance(resolved_device, str): resolved_device = jax.devices(resolved_device)[0] input_shardings = in_shardings if in_shardings is not None else jax.tree.map(lambda _: None, args) normalized_shardings = jax.tree.map( partial(_normalize_sharding, default_device=resolved_device), args, input_shardings, ) resolved_args = jax.tree.map( lambda x, s: _get_random_value(_extract_array_type(x), s), args, normalized_shardings ) if len(kwargs) == 0 or all(v is None or jax.core.is_concrete(v) for v in kwargs.values()): resolved_kwargs = kwargs elif example_kws is not None: resolved_kwargs = example_kws else: resolved_kwargs = jax.tree.map(lambda x: _get_random_value(_extract_array_type(x)), kwargs) hyperparams = hyperparams if hyperparams is not None else {} hyperparams_norm: dict[str, tuple[Any, ...]] = {} for k, v in hyperparams.items(): if isinstance(v, tuple | list): if len(v) == 0: raise ValueError(f"Hyperparameter '{k}' has empty list of values") hyperparams_norm[k] = tuple(v) else: hyperparams_norm[k] = (v,) if hyperparams_norm: hyperparam_settings = dict(enumerate(itertools.product(*hyperparams_norm.values()))) total_combinations = len(hyperparam_settings) if sample_num < total_combinations: if sample_num == 0: hyperparam_settings = {0: tuple()} else: sample_idx = sorted( pyrandom.sample(list(range(total_combinations)), k=min(sample_num, total_combinations)) ) hyperparam_settings_ = list(hyperparam_settings.items()) hyperparam_settings = dict([hyperparam_settings_[idx] for idx in sample_idx]) autotune_logger.info( f"Testing {len(hyperparam_settings)} hyperparameter combinations out of {total_combinations} possible" ) else: hyperparam_settings = {0: tuple()} with ThreadPoolExecutor(max_workers=max_workers) as executor: fns: dict[int, Callable[..., Any]] = {} optimal_formats: dict[int, Any] = {} for phase in range(2): compile_only = phase == 0 compute_layouts = self.find_optimal_layouts_automatically compiles = {} for i, vals in hyperparam_settings.items(): hs = dict(zip(hyperparams_norm.keys(), vals, strict=True)) fns[i] = self._create_parameterized_function(fn, hs, output_shardings=out_shardings, function_id=i) opts = dict( optimal_formats=optimal_formats.get(i, None), compute_layouts=compute_layouts, timeout=self.max_compilation_time_seconds if timeout is None else timeout, ) fut = executor.submit( self._try_call, fns[i], resolved_args, resolved_kwargs, compile_only=compile_only, **opts ) compiles[fut] = i successful = {} for fut in as_completed(compiles): status, err, optf = fut.result() if status: successful[compiles[fut]] = (status, err, optf) if compile_only and compute_layouts: for i, (_, _, optf) in successful.items(): optimal_formats[i] = optf if not successful: for fut, i in compiles.items(): _, err, _ = fut.result() autotune_logger.error( f"Hyperparameters {hyperparam_settings[i]} failed to compile with message:\n{err}" ) raise ValueError("No hyperparameters compiled successfully") hyperparam_settings = {i: hyperparam_settings[i] for i in successful.keys()} fns = {i: fns[i] for i in successful.keys()} results: dict[int, TimingResult] = {} try: args_with_device = [ next(iter(arg.devices())) for arg in jax.tree.leaves(resolved_args) if hasattr(arg, "devices") ] platform = args_with_device[0].platform if len(args_with_device) > 0 else _get_default_device().platform def _timing_closure(): settings = list(hyperparam_settings.items()) pyrandom.shuffle(settings) for i, _ in settings: self._try_call( fns[i], resolved_args, resolved_kwargs, compile_only=False, optimal_formats=optimal_formats.get(i, None), ) profiler_timings = self.profiler.profile_time_by_function_id( _timing_closure, platform, self.profiling_samples, ) fraction_measured = sum(1 for i in hyperparam_settings.keys() if i in profiler_timings) / len( hyperparam_settings ) if fraction_measured < self.must_find_profiler_fraction: missing = [i for i in hyperparam_settings.keys() if i not in profiler_timings] msg = "Could not find profiler results for some hyperparameter settings:\n" + "\n".join( f" - {i}: {hyperparam_settings[i]}" for i in missing ) raise RuntimeError(msg) for i in hyperparam_settings.keys(): if i not in profiler_timings: autotune_logger.warning( f"Could not find profiler results for hyperparameter settings: {hyperparam_settings[i]}" ) profiler_timings[i] = (float("inf"), float("inf")) for i, hs in hyperparam_settings.items(): hs = dict(zip(hyperparams_norm.keys(), hs, strict=True)) t_mean, t_std = profiler_timings[i] results[i] = TimingResult(hs, float(t_mean), float(t_std)) except Exception as e: if not self.allow_fallback_timing: raise RuntimeError( f"Need to fall back to the python-level timing, but allow_fallback_timing=False. Error: {e}" ) from None for i, hs in hyperparam_settings.items(): hs = dict(zip(hyperparams_norm.keys(), hs, strict=True)) t_mean, t_std = self._time_fn(partial(lambda fn: fn(*resolved_args, **resolved_kwargs), fns[i])) results[i] = TimingResult(hs, t_mean, t_std) results_sorted = sorted(results.items(), key=lambda x: self._calculate_timing_score(x[1])) idx, optimal_hyperparams = results_sorted[0][0], results_sorted[0][1].hyperparams return fns[idx], optimal_hyperparams, results_sorted
[docs] def decorate( self, fn: Callable[..., Any], *, hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jax.Device | str | None = None, # type: ignore example_args: tuple[Any, ...] | None = None, example_kws: dict[str, Any] | None = None, sample_num: int = 2**63 - 1, event_filter_regex: str | None = None, timeout: float | None = None, cache_key: str | None = None, ): if not callable(fn): raise TypeError("fn must be callable") cache: dict[Any, tuple[dict[str, Any], list[tuple[int, TimingResult]]]] = {} cache_lock = threading.Lock() @wraps(fn) def wrapped(*args, **kws): """Wrapper function that performs autotuning on first call and caches results. On first call with a given input signature, performs hyperparameter optimization and caches the results. Subsequent calls with the same signature use the cached optimal hyperparameters. Args: *args: Positional arguments for the wrapped function **kws: Keyword arguments for the wrapped function Returns: Result of executing the function with optimal hyperparameters Raises: Exception: Any exception from the underlying function execution """ input_hash = _try_hash_input(args, kws) lookup_key = f"{cache_key}:{input_hash}" if cache_key and input_hash else input_hash with cache_lock: hit = lookup_key is not None and lookup_key in cache if hit: optimal_hyperparams, results = cache[lookup_key] else: optimal_hyperparams = None results = None if optimal_hyperparams is None: flat_vals = jax.tree.leaves((args, kws)) has_tracers = any(not jax.core.is_concrete(x) for x in flat_vals if isinstance(x, jax.Array)) if has_tracers: if wrapped.optimal_hyperparams: optimal_hyperparams = wrapped.optimal_hyperparams.copy() results = wrapped.timing_results else: optimal_hyperparams = { k: (v[0] if isinstance(v, list) else v) for k, v in (hyperparams or {}).items() } results = [] else: with jax.core.eval_context(): _, optimal_hyperparams, results = self.tune( fn, args=args, kwargs=kws, hyperparams=hyperparams, max_workers=max_workers, in_shardings=in_shardings, out_shardings=out_shardings, device=device, example_args=example_args, example_kws=example_kws, sample_num=sample_num, event_filter_regex=event_filter_regex, timeout=timeout, ) with cache_lock: if lookup_key is not None: if len(cache) >= self.cache_size_limit: oldest_key = next(iter(cache)) del cache[oldest_key] cache[lookup_key] = (optimal_hyperparams, results) wrapped.timing_results = results wrapped.optimal_hyperparams = optimal_hyperparams try: return fn(*args, **dict(kws, **optimal_hyperparams)) except Exception as e: autotune_logger.error(f"Execution failed with optimal hyperparameters {optimal_hyperparams}: {e}") raise wrapped.timing_results = [] wrapped.optimal_hyperparams = {} return wrapped
[docs]def autotune( fn: Callable[..., Any] | None = None, /, *, allow_fallback_timing: bool = True, profiling_samples: int = 5, must_find_profiler_fraction: float = 0.5, enable_detailed_logging: bool = False, find_optimal_layouts_automatically: bool = False, max_compilation_time_seconds: float = 300.0, timing_warmup_iterations: int = 2, timing_rounds: int = 5, calls_per_round: int = 3, cache_size_limit: int = 1000, profiler_prefix_filter: str = "jit_", profiler_event_regex: str | None = None, profiler_min_duration_ns: float = 1000.0, profiler_max_events: int | None = 10000, profiler_verbose: bool = False, hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jax.Device | str | None = None, # type: ignore example_args: tuple[Any, ...] | None = None, example_kws: dict[str, Any] | None = None, sample_num: int = 2**63 - 1, event_filter_regex: str | None = None, timeout: float | None = None, cache_key: str | None = None, ): """Advanced JAX function autotuning decorator with comprehensive optimization features. A flexible decorator that automatically optimizes JAX functions by testing different hyperparameter configurations. Uses profiler-based timing for accuracy with Python fallback, supports parallel execution, and provides intelligent caching. This decorator can be used in three different ways: 1) @autotune() -> decorator factory with custom parameters 2) @autotune -> plain decorator with default parameters 3) autotune(fn, kw=...) -> direct function call returning wrapped function The decorator automatically optimizes hyperparameters on the first function call and caches results for subsequent calls with the same input signature. Key Features: - Profiler-based timing with Python-level fallback for maximum accuracy - Parallel compilation and testing of hyperparameter configurations - Automatic optimal memory layout discovery for distributed computation - Thread-safe caching with configurable size limits - Statistical timing analysis with outlier removal - Comprehensive error handling and detailed logging options - Support for both concrete and abstract input specifications Args: fn: Function to autotune (when used as direct call) allow_fallback_timing: Enable Python timing fallback if profiler fails profiling_samples: Number of profiling iterations for statistical accuracy must_find_profiler_fraction: Minimum fraction of configs needing profiler results (0.0-1.0) enable_detailed_logging: Enable detailed error logging with full tracebacks timing_warmup_iterations: Warmup iterations before timing measurements timing_rounds: Number of timing rounds for statistical analysis calls_per_round: Function calls per timing round profiler_prefix_filter: Event name prefix filter for profiler (default: "jit_") profiler_event_regex: Optional regex filter for profiler events profiler_min_duration_ns: Minimum event duration for profiler inclusion (nanoseconds) profiler_max_events: Maximum events per profile to prevent memory issues profiler_verbose: Enable verbose profiler output find_optimal_layouts_automatically: Auto-discover optimal memory layouts max_compilation_time_seconds: Maximum compilation time per configuration cache_size_limit: Maximum cached optimization results hyperparams: Dictionary mapping parameter names to candidate value lists max_workers: Maximum parallel workers for optimization sample_num: Maximum hyperparameter combinations to test in_shardings: Input sharding specifications for distributed computation out_shardings: Output sharding specifications for distributed computation device: Target device or device string for computation example_args: Concrete example arguments for abstract input shapes example_kws: Concrete example kwargs for abstract input values event_filter_regex: Optional regex filter for profiler events (alias for profiler_event_regex) timeout: Optional compilation timeout override cache_key: Optional custom cache key prefix for disambiguation Returns: When used as decorator: Decorated function with automatic hyperparameter optimization When used directly: Wrapper function that performs optimization and execution The returned function has additional attributes after first execution: - timing_results: List of all timing measurements from optimization - optimal_hyperparams: Dictionary of optimal parameter values Raises: TypeError: If fn is not callable (when used as direct call) ValueError: If hyperparameter specifications are invalid RuntimeError: If all hyperparameter configurations fail to compile Examples: Basic usage with hyperparameter optimization: ```python @autotune(hyperparams={'block_size': [64, 128, 256, 512]}) def matrix_multiply(a, b, block_size=128): return jnp.dot(a, b) result = matrix_multiply(x, y) print(f"Optimal config: {matrix_multiply.optimal_hyperparams}") ``` Advanced usage with custom timing and sharding: ```python @autotune( hyperparams={ 'chunk_size': [32, 64, 128], 'algorithm': ['parallel', 'sequential'] }, profiling_samples=10, max_workers=16, in_shardings=jax.sharding.PartitionSpec('data', None), enable_detailed_logging=True ) def compute_function(data, chunk_size=64, algorithm='parallel'): return processed_data ``` Direct call usage: ```python optimized_fn = autotune( my_function, hyperparams={'param': [1, 2, 3]}, profiling_samples=5 ) result = optimized_fn(input_data) ``` Using with abstract input specifications: ```python @autotune( hyperparams={'tile_size': [16, 32, 64]}, example_args=(jnp.zeros((1000, 1000)),), device='gpu' ) def process_matrix(matrix, tile_size=32): return result ``` Note: - First call triggers optimization and may take longer - Subsequent calls with same input signature use cached optimal parameters - Profiler-based timing requires TensorFlow backend; falls back to Python timing - For distributed computation, ensure proper mesh configuration before use - Cache keys are based on input signatures; use cache_key for manual disambiguation """ tuner = FNAutotuner( allow_fallback_timing=allow_fallback_timing, profiling_samples=profiling_samples, must_find_profiler_fraction=must_find_profiler_fraction, enable_detailed_logging=enable_detailed_logging, find_optimal_layouts_automatically=find_optimal_layouts_automatically, max_compilation_time_seconds=max_compilation_time_seconds, timing_warmup_iterations=timing_warmup_iterations, timing_rounds=timing_rounds, calls_per_round=calls_per_round, cache_size_limit=cache_size_limit, profiler_prefix_filter=profiler_prefix_filter, profiler_event_regex=profiler_event_regex, profiler_min_duration_ns=profiler_min_duration_ns, profiler_max_events=profiler_max_events, profiler_verbose=profiler_verbose, ) def decorator(func: Callable[..., Any]): """Internal decorator function that applies autotuning to a target function. Args: func: Function to wrap with autotuning capabilities Returns: Function decorated with automatic hyperparameter optimization """ return tuner.decorate( func, hyperparams=hyperparams, max_workers=max_workers, in_shardings=in_shardings, out_shardings=out_shardings, device=device, example_args=example_args, example_kws=example_kws, sample_num=sample_num, event_filter_regex=event_filter_regex, timeout=timeout, cache_key=cache_key, ) if callable(fn): return decorator(fn) return decorator