Source code for ejkernel.ops.execution.tuning

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Autotuning and benchmarking utilities for JAX kernel optimization.

This module provides comprehensive tools for automatic hyperparameter optimization
of JAX functions through systematic benchmarking and configuration testing.

Key Features:
    - Profiler-based timing with Python-level fallback for accuracy
    - Parallel compilation and testing of hyperparameter configurations
    - Statistical timing analysis with outlier removal
    - Thread-safe caching with configurable size limits
    - Support for distributed computation with sharding specifications

Classes:
    Measurement: Container for a single performance measurement
    AutotuneData: Container for all optimization measurements and results
    Autotuner: Core autotuning engine for hyperparameter optimization
    Entry: Cache entry for storing optimal configurations
    AutotuningResult: Device-specific optimization results container
    TimingResult: Statistical timing result with mean and standard deviation
    FNAutotuner: Advanced class-based autotuner with profiler integration

Functions:
    autotune: Decorator for automatic hyperparameter optimization
    autotune_recorded: Autotune all recorded invocations
    benchmark: Simple function benchmarking utility

Example:
    >>> @autotune(hyperparams={'block_size': [64, 128, 256]})
    ... def matrix_op(x, y, block_size=128):
    ...     return compute(x, y, block_size)
    >>>
    >>> result = matrix_op(x, y)
    >>> print(matrix_op.optimal_hyperparams)
"""

from __future__ import annotations

import contextlib
import itertools
import logging
import os
import random as pyrandom
import re
import threading
import time
from collections.abc import Callable, Iterable
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from functools import partial, wraps
from typing import Any, Generic, TypeVar

import jax
import numpy as np
from jax import numpy as jnp
from jax import random
from jax.interpreters import pxla
from jax.sharding import PartitionSpec, Sharding, SingleDeviceSharding

from ejkernel.loggings import get_logger

from ..config.cache import overlay_cache
from ..utils.fingerprint import device_fingerprint
from .profiler import Profiler

autotune_logger = get_logger("ejkernel.autotune", "WARNING")


Cfg = TypeVar("Cfg")


[docs]@dataclass class Measurement: """Container for a single performance measurement. Stores the configuration and corresponding execution time for a single hyperparameter combination during optimization. Attributes: cfg: The hyperparameter configuration that was tested seconds: Execution time in seconds for this configuration """ cfg: Any seconds: float
[docs]@dataclass class AutotuneData(Generic[Cfg]): """Container for all optimization measurements and results. Stores performance measurements for all tested hyperparameter configurations and provides utilities to analyze the results. Type Parameters: Cfg: Configuration type (e.g., dict, dataclass, etc.) Attributes: measurements: List of all performance measurements taken """ measurements: list[Measurement] @property def fastest_config(self) -> Cfg: """Get the configuration with the fastest execution time. Finds the measurement with the minimum execution time among all recorded measurements and returns its configuration. Returns: The configuration that achieved the lowest execution time Raises: ValueError: If no measurements are available """ if not self.measurements: raise ValueError("No measurements available to determine fastest config") return min(self.measurements, key=lambda m: m.seconds).cfg
[docs]class Autotuner(Generic[Cfg]): """Core autotuning engine for hyperparameter optimization. This class provides the fundamental optimization algorithm that tests different configurations and measures their performance to find the optimal hyperparameter settings. Type Parameters: Cfg: Configuration type for hyperparameters Attributes: warmup: Number of warmup iterations before timing iters: Number of timing iterations for measurement accuracy """ def __init__(self, warmup=1, iters=3): """Initialize the autotuner with timing parameters. Args: warmup: Number of warmup calls to stabilize performance before timing iters: Number of timing iterations for statistical accuracy """ self.warmup, self.iters = warmup, iters
[docs] def autotune(self, make_fn, args, kwargs, candidates: Iterable[Cfg]) -> AutotuneData[Cfg]: """Optimize hyperparameters by testing candidate configurations. Tests each candidate configuration by compiling and timing the function execution, then returns all measurements for analysis. Args: make_fn: Factory function that creates a function given a config args: Positional arguments for the function being optimized kwargs: Keyword arguments for the function being optimized candidates: Iterable of candidate configurations to test Returns: AutotuneData containing all performance measurements """ measures = [] for cfg in candidates: try: fn = make_fn(cfg) c = jax.jit(fn).lower(*args, **kwargs).compile() for _ in range(self.warmup): _ = c(*args, **kwargs).block_until_ready() t0 = time.perf_counter() for _ in range(self.iters): _ = c(*args, **kwargs).block_until_ready() dt = (time.perf_counter() - t0) / self.iters measures.append(Measurement(cfg, dt)) except Exception as e: autotune_logger.warning(f"Configuration {cfg} failed: {e}") measures.append(Measurement(cfg, float("inf"))) if not measures or all(m.seconds == float("inf") for m in measures): autotune_logger.warning("All candidate configurations failed to execute; returning empty measurements.") return AutotuneData(measures)
[docs]@dataclass(frozen=True) class Entry: """Cache entry for storing optimal configurations. Represents a single cached optimization result with the operation identifier, call signature, and optimal configuration. Attributes: op_id_v: Operation identifier for the optimized function call_key: Hash key representing the function call signature cfg: The optimal configuration found for this operation """ op_id_v: str call_key: str cfg: Any
[docs]@dataclass(frozen=True) class AutotuningResult: """Result container for device-specific optimization results. Stores all optimized configurations for a specific device and provides context manager functionality for temporary cache overlays. Attributes: device: Device identifier these results apply to entries: Tuple of optimization entries (operation -> config mappings) """ device: str entries: tuple[Entry, ...]
[docs] def as_overlay(self): """Convert results to cache overlay mapping format. Creates a dictionary mapping that can be used with the cache overlay system to temporarily apply these optimization results. Returns: Dictionary mapping (device, op_id, call_key) tuples to configurations """ mapping = {(self.device, e.op_id_v, e.call_key): e.cfg for e in self.entries} return mapping
def __enter__(self): """Enter context manager to apply optimization results as cache overlay. Activates the cache overlay with the optimization results, temporarily overriding any existing cache entries with the optimized configurations. Returns: Self for use in with statements """ self._ctx = overlay_cache(self.as_overlay()) self._ctx.__enter__() return self def __exit__(self, exc_type, exc, tb): """Exit context manager and restore previous cache state. Deactivates the cache overlay and restores the previous cache state, ensuring clean cleanup even if exceptions occur. Args: exc_type: Exception type (if any) exc: Exception instance (if any) tb: Exception traceback (if any) """ self._ctx.__exit__(exc_type, exc, tb) delattr(self, "_ctx")
[docs]def autotune_recorded(hyperparameter_selector, *, show_progress=False, repetition_count=1): """Autotune all kernel invocations recorded for the current device. This function iterates through all kernel invocations that have been recorded in the global registry for the current device and runs autotuning to find optimal configurations for each unique operation/call-key combination. The autotuning process: 1. Retrieves all recorded invocations for the current device 2. For each recorded kernel call, prepares arguments and generates candidates 3. Benchmarks each candidate configuration 4. Stores the optimal configuration in both memory and persistent caches 5. Returns results as an AutotuningResult for context manager usage Args: hyperparameter_selector: ConfigSelectorChain instance with cache and persistent storage for storing optimization results show_progress: Whether to display progress bars during optimization (currently unused, reserved for future implementation) repetition_count: Number of times to repeat the optimization process (currently unused, reserved for future implementation) Returns: AutotuningResult containing all optimal configurations found. Can be used as a context manager to temporarily apply the configurations. Example: >>> from ejkernel.ops.config import ConfigCache, ConfigSelectorChain >>> cache = ConfigCache() >>> selector = ConfigSelectorChain(cache) >>> >>> # Record invocations by running with EJKERNEL_OPS_RECORD=1 >>> # Then autotune all recorded operations: >>> result = autotune_recorded(selector) >>> with result: ... # Runs with optimized configurations ... output = my_model(input_data) Note: Requires invocations to be previously recorded using the EJKERNEL_OPS_RECORD=1 environment variable during initial runs. """ from ..registry import get_invocations dev = device_fingerprint() invs = get_invocations(dev) entries = [] for op_id_v, d in invs.items(): for call_key, (kernel, args, kwargs) in d.items(): inv_args, inv_kwargs = kernel.prepare(*args, **kwargs) static_fun_kwargs = {k: v for k, v in inv_kwargs.items() if callable(v)} dyn_kwargs = {k: v for k, v in inv_kwargs.items() if k not in static_fun_kwargs} tmp_inv = type( "Tmp", (), dict( op_id=kernel.op_id, args=inv_args, kwargs=dyn_kwargs, batch_axes=None, override_cfg=None, stamp=False ), )() candidates = tuple(kernel.candidate_cfgs(tmp_inv)) def mk(c, _run=kernel.run, _static=static_fun_kwargs): """Create a partial function binding a specific configuration to the kernel run method.""" return partial(_run, cfg=c, **_static) best_cfg, best_t = None, float("inf") for c in candidates: t = benchmark(mk(c), *inv_args, **dyn_kwargs) if t < best_t: best_cfg, best_t = c, t hyperparameter_selector.cache.put(dev, op_id_v, call_key, best_cfg) if hyperparameter_selector.persistent and hyperparameter_selector.persist_autotune: hyperparameter_selector.persistent.put(dev, op_id_v, call_key, best_cfg) entries.append(Entry(op_id_v, call_key, best_cfg)) return AutotuningResult(dev, tuple(entries))
def _split_static_callable_kwargs(kwargs): """Split keyword arguments into static and dynamic components. Separates callable arguments (static) from regular arguments (dynamic) for proper JAX compilation and execution. Args: kwargs: Dictionary of keyword arguments to split Returns: Tuple of (static_kwargs, dynamic_kwargs) """ static = {k: v for k, v in kwargs.items() if callable(v)} return static, {k: v for k, v in kwargs.items() if k not in static}
[docs]def benchmark(fn, *args, warmup=1, iters=5, **kwargs) -> float: """Benchmark function execution time with JAX compilation. Compiles the function with JAX and measures its execution time over multiple iterations, handling both static and dynamic arguments. Args: fn: Function to benchmark *args: Positional arguments for the function warmup: Number of warmup iterations before timing iters: Number of timing iterations for measurement **kwargs: Keyword arguments for the function Returns: Average execution time per iteration in seconds """ static, dyn = _split_static_callable_kwargs(kwargs) if static: def fn_wrapped(*a, _fn=fn, _static=static, **k): """Wrap function to merge static callable kwargs with dynamic kwargs at call time.""" return _fn(*a, **(k | _static)) c = jax.jit(fn_wrapped).lower(*args, **dyn).compile() for _ in range(warmup): _ = c(*args, **dyn).block_until_ready() t0 = time.perf_counter() for _ in range(iters): _ = c(*args, **dyn).block_until_ready() return (time.perf_counter() - t0) / iters else: c = jax.jit(fn).lower(*args, **kwargs).compile() for _ in range(warmup): _ = c(*args, **kwargs).block_until_ready() t0 = time.perf_counter() for _ in range(iters): _ = c(*args, **kwargs).block_until_ready() return (time.perf_counter() - t0) / iters
[docs]@dataclass class TimingResult: """Statistical timing result for a single hyperparameter configuration. Stores the measured execution time statistics for a specific set of hyperparameters, including both mean and standard deviation for reliability analysis. Attributes: hyperparams: Dictionary of hyperparameter names to their tested values t_mean: Mean execution time in seconds across timing iterations t_std: Standard deviation of execution times in seconds """ hyperparams: dict[Any, Any] t_mean: float t_std: float
def _get_global_mesh(): """Retrieve the current global mesh from JAX thread resources. Accesses JAX's internal thread-local resources to get the currently active device mesh for distributed computation. Returns None if no mesh is configured or the mesh is empty. Returns: The current JAX mesh object, or None if no active mesh exists """ env = pxla.thread_resources.env mesh = env.physical_mesh return None if mesh.empty else mesh def _get_default_device(): """Get the default device for JAX computation. Returns the explicitly configured default device if set, otherwise falls back to the first available device in the system. Returns: The default JAX device for computation """ if jax.config.values["jax_default_device"] is not None: return jax.config.values["jax_default_device"] return jax.devices()[0] @contextlib.contextmanager def _suppress_stdout_stderr(): """Context manager to temporarily suppress stdout and stderr output. Redirects both stdout and stderr to /dev/null for the duration of the context, then restores them afterward. Useful for silencing noisy library output during compilation or benchmarking. Yields: None - This is a context manager with no return value Note: Uses file descriptor duplication to ensure proper restoration even if exceptions occur within the context. """ devnull = open(os.devnull, "w") stdout_fd, stderr_fd = os.dup(1), os.dup(2) try: os.dup2(devnull.fileno(), 1) os.dup2(devnull.fileno(), 2) yield finally: try: os.dup2(stdout_fd, 1) os.dup2(stderr_fd, 2) finally: os.close(stdout_fd) os.close(stderr_fd) devnull.close() def _normalize_sharding( arg: jax.Array | np.ndarray | Any, sharding_or_spec: PartitionSpec | Sharding | None, default_device: jax.Device, # type: ignore ): """Normalize sharding specification to a concrete Sharding object. Converts various sharding specifications (PartitionSpec, Sharding, None) into concrete Sharding objects that can be used for array placement. Handles both global mesh and single-device scenarios. Args: arg: Array or array-like object to be sharded sharding_or_spec: Sharding specification (PartitionSpec, Sharding, or None) default_device: Default device to use for single-device sharding Returns: Concrete Sharding object, or None for non-array arguments Raises: ValueError: If PartitionSpec is provided but no global mesh is defined """ if not isinstance(arg, jax.Array | np.ndarray): return None if isinstance(sharding_or_spec, Sharding): return sharding_or_spec global_mesh = _get_global_mesh() if isinstance(sharding_or_spec, PartitionSpec) and global_mesh is not None: return jax.NamedSharding(global_mesh, sharding_or_spec) elif isinstance(sharding_or_spec, PartitionSpec) and global_mesh is None: raise ValueError("If specifying shardings via ParitionSpec, a global mesh must be defined") else: return SingleDeviceSharding(default_device) def _ensure_dtype(dt): """Extract dtype from an array or return the input if already a dtype. Safely extracts the dtype attribute from arrays, handling edge cases where the dtype extraction might fail. Args: dt: JAX array, NumPy array, or dtype-like object Returns: The dtype of the input array, or the input itself if not an array """ try: return dt.dtype if isinstance(dt, jax.Array | np.ndarray) else dt except Exception: return dt @partial(jax.jit, static_argnames=("sds", "sharding")) def _get_random_value(sds, sharding=None): """Generate random values matching a shape/dtype specification. Creates random data matching the shape and dtype of the input specification. For floating point types, generates normally distributed random values. For integer types, generates zeros. Supports optional output sharding. Args: sds: Shape/dtype specification (ShapeDtypeStruct or similar object with shape and dtype attributes), or any other value to return unchanged sharding: Optional sharding specification for the output array Returns: Random array matching the specification, or the input unchanged if it doesn't have shape/dtype attributes Raises: ValueError: If the dtype is not floating point or integer """ if hasattr(sds, "shape") and hasattr(sds, "dtype"): dt = _ensure_dtype(sds.dtype) if jnp.issubdtype(dt, jnp.floating): return jax.jit(lambda key: random.normal(key, sds.shape, dt), out_shardings=sharding)(random.key(0)) elif jnp.issubdtype(dt, jnp.integer): return jax.jit(lambda: jnp.zeros(sds.shape, dt), out_shardings=sharding)() else: raise ValueError(f"Unsupported dtype {dt}") else: return sds def _try_hash_input(args, kws, must_be_concrete: bool = True): """Attempt to create a hashable key from function input arguments. Creates a hash key based on the structure and types of input arguments, which can be used for caching autotuning results. Arrays are hashed based on their shape, dtype, and sharding rather than their values. Args: args: Positional arguments to hash kws: Keyword arguments to hash must_be_concrete: If True, returns None when any arrays are abstract (e.g., inside JAX transformations). Default is True. Returns: A hash integer uniquely identifying the input signature, or None if: - must_be_concrete is True and arguments contain abstract arrays - Hashing fails for any reason (e.g., unhashable types) """ flat_vals, struct = jax.tree.flatten((args, kws)) all_concrete = all(jax.core.is_concrete(x) for x in flat_vals if isinstance(x, jax.Array)) if not all_concrete and must_be_concrete: return None def _get_sharding(x): """Extract sharding from array or its abstract type.""" try: return x.sharding except AttributeError: return jax.typeof(x).sharding def array_to_hashable(x): """Convert array to hashable representation based on type and sharding.""" return x if not isinstance(x, jax.Array) else hash((jax.typeof(x), _get_sharding(x))) try: return hash((struct, tuple(array_to_hashable(x) for x in flat_vals))) except Exception: return None
[docs]class FNAutotuner: """Advanced class-based JAX autotuner with profiler-first timing and Python fallback. Provides comprehensive hyperparameter optimization for JAX functions using JAX's native profiling infrastructure when available, with automatic fallback to Python-level timing. Supports parallel compilation, statistical timing analysis, and intelligent caching of optimization results. Key Features: - Profiler-based timing for accurate GPU/TPU measurements - Python-level timing fallback when profiler unavailable - Parallel compilation of hyperparameter configurations - Statistical analysis with outlier removal - Thread-safe caching of optimal configurations - Optional automatic memory layout optimization The autotuner works by: 1. Generating all hyperparameter combinations 2. Compiling each configuration in parallel 3. Timing execution using profiler or Python fallback 4. Selecting the configuration with best performance 5. Caching results for future calls Attributes: allow_fallback_timing: Whether to use Python timing when profiler fails profiling_samples: Number of profiling iterations for statistics must_find_profiler_fraction: Minimum fraction of configs needing profiler results enable_detailed_logging: Enable verbose error logging find_optimal_layouts_automatically: Auto-discover optimal memory layouts max_compilation_time_seconds: Maximum compilation time per config timing_warmup_iterations: Warmup iterations before timing timing_rounds: Number of timing rounds for statistics calls_per_round: Function calls per timing round cache_size_limit: Maximum cached optimization results profiler: Profiler instance for trace capture and analysis """ PREFIX_FN = "autotune_fn_{}" def __init__( self, *, allow_fallback_timing: bool = True, profiling_samples: int = 5, must_find_profiler_fraction: float = 0.5, enable_detailed_logging: bool = False, find_optimal_layouts_automatically: bool = False, max_compilation_time_seconds: float = 300.0, timing_warmup_iterations: int = 2, timing_rounds: int = 5, calls_per_round: int = 3, cache_size_limit: int = 1000, profiler_prefix_filter: str = "jit_", profiler_event_regex: str | None = None, profiler_min_duration_ns: float = 1000.0, profiler_max_events: int | None = 10000, profiler_verbose: bool = False, ): """Initialize the autotuner with timing and profiling configuration. Args: allow_fallback_timing: Enable Python timing fallback if profiler fails profiling_samples: Number of profiling iterations for statistical accuracy must_find_profiler_fraction: Minimum fraction of configs needing profiler results (0.0-1.0). Falls back to Python timing if not met. enable_detailed_logging: Enable detailed error logging with tracebacks find_optimal_layouts_automatically: Auto-discover optimal memory layouts for distributed computation max_compilation_time_seconds: Maximum compilation time per configuration timing_warmup_iterations: Warmup iterations before timing measurements timing_rounds: Number of timing rounds for statistical analysis calls_per_round: Function calls per timing round cache_size_limit: Maximum number of cached optimization results profiler_prefix_filter: Event name prefix filter for profiler profiler_event_regex: Optional regex filter for profiler events profiler_min_duration_ns: Minimum event duration for profiler inclusion profiler_max_events: Maximum events per profile to prevent memory issues profiler_verbose: Enable verbose profiler output """ self.allow_fallback_timing = allow_fallback_timing self.profiling_samples = profiling_samples self.must_find_profiler_fraction = must_find_profiler_fraction self.enable_detailed_logging = enable_detailed_logging self.find_optimal_layouts_automatically = find_optimal_layouts_automatically self.max_compilation_time_seconds = max_compilation_time_seconds self.timing_warmup_iterations = timing_warmup_iterations self.timing_rounds = timing_rounds self.calls_per_round = calls_per_round self.cache_size_limit = cache_size_limit self.profiler = Profiler( prefix_filter=profiler_prefix_filter, event_filter_regex=profiler_event_regex, min_duration_ns=profiler_min_duration_ns, max_events_per_profile=profiler_max_events, verbose=profiler_verbose or (autotune_logger._level <= logging.INFO), ) self._cache_lock = threading.Lock() @staticmethod def _calculate_timing_score(tr: TimingResult) -> float: """Calculate a composite timing score for ranking configurations. Combines mean execution time with standard deviation to create a single score that balances speed with consistency. Penalizes configurations with high variance even if they have good mean performance. The weighting factor of 0.1 for standard deviation provides a reasonable balance between speed and stability. Args: tr: TimingResult containing mean and standard deviation Returns: Composite score (lower is better) = mean + 0.1 * std """ return tr.t_mean + 0.1 * tr.t_std def _create_parameterized_function( self, target_function: Callable[..., Any], hyperparameter_values: dict[str, Any], output_shardings: Any = None, function_id: int = 0, ) -> Callable[..., Any]: """Create a JIT-compiled function with embedded hyperparameters. Wraps the target function with specific hyperparameter values and compiles it using JAX JIT. Sets a unique name for profiler identification and optionally configures output sharding for distributed computation. The function name follows the pattern 'autotune_fn_{function_id}' for easy identification in profiler output. Args: target_function: Original function to parameterize hyperparameter_values: Dictionary of hyperparameter names to values output_shardings: Optional sharding specification for outputs function_id: Unique identifier for profiler tracking Returns: JIT-compiled function with embedded hyperparameters and unique name """ jax_compiler = partial(jax.jit, out_shardings=output_shardings) def parameterized_function(*function_args, **function_kwargs): """Execute the target function with embedded hyperparameter values merged into kwargs.""" combined_kwargs = dict(function_kwargs, **hyperparameter_values) return target_function(*function_args, **combined_kwargs) function_name = self.PREFIX_FN.format(function_id) parameterized_function.__name__ = function_name parameterized_function.__qualname__ = function_name return jax_compiler(parameterized_function) def _try_call( self, fn: Callable[..., Any], resolved_args, resolved_kwargs, compile_only: bool = False, compute_layouts: bool = False, optimal_formats: Any | None = None, timeout: float | None = None, ): """Safely compile or execute a function with comprehensive error handling. Attempts to compile or execute a function with the provided arguments, handling various failure modes gracefully. Supports layout optimization and optimal device placement for distributed computation. Args: fn: Function to compile or execute resolved_args: Positional arguments for the function resolved_kwargs: Keyword arguments for the function compile_only: If True, only compile without execution compute_layouts: If True, compute optimal memory layouts optimal_formats: Optional pre-computed optimal device formats timeout: Optional compilation timeout (currently unused) Returns: Tuple of (success_bool, error_message, optimal_input_formats) where success_bool indicates if operation succeeded, error_message contains failure details if any, and optimal_input_formats contains discovered layouts if computed """ optimal_input_formats = None try: if compile_only: if compute_layouts: def to_shape(x): """Convert a JAX array to its ShapeDtypeStruct representation for layout optimization.""" return ( jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding) if isinstance(x, jax.Array) else x ) (argument_shapes, keyword_shapes) = jax.tree.map(to_shape, (resolved_args, resolved_kwargs)) try: compiled_function = jax.jit(fn).lower(*argument_shapes, **keyword_shapes).compile() optimal_input_formats = getattr(compiled_function, "input_formats", None) except Exception as compilation_error: autotune_logger.warning( f"Layout optimization failed during compilation: " f"{compilation_error.__class__.__name__}: {compilation_error}" ) optimal_input_formats = None else: _ = jax.jit(fn).lower(*resolved_args, **resolved_kwargs).compile() else: if optimal_formats is not None: def place_array_on_optimal_device(array_data, target_format): """Place an array on its optimal device layout, passing non-arrays through.""" return ( jax.device_put(array_data, target_format) if isinstance(array_data, jax.Array) else array_data ) try: (optimally_placed_args, optimally_placed_kwargs) = jax.tree.map( place_array_on_optimal_device, (resolved_args, resolved_kwargs), optimal_formats ) _ = jax.block_until_ready(fn(*optimally_placed_args, **optimally_placed_kwargs)) except Exception: autotune_logger.warning( "Failed to place arrays on optimal devices - falling back to original argument placement" ) _ = jax.block_until_ready(fn(*resolved_args, **resolved_kwargs)) else: _ = jax.block_until_ready(fn(*resolved_args, **resolved_kwargs)) return True, None, optimal_input_formats except Exception as e: msg = f"{type(e).__name__}: {e!s}" if self.enable_detailed_logging: import traceback msg = traceback.format_exc() return False, msg, optimal_input_formats def _time_fn(self, target_function: Callable[[], None]) -> tuple[float, float]: """Perform high-precision Python-level timing with statistical analysis. Executes the target function multiple times with warmup iterations, measures execution times, and provides statistical analysis with outlier removal for reliable performance measurements. Uses block_until_ready() to ensure accurate timing of JAX computations. Args: target_function: Zero-argument function to time (should return JAX arrays) Returns: Tuple of (mean_time, std_time) in seconds after outlier removal, or (inf, inf) if all timing attempts failed """ def _execute_and_block(): """Execute the target function and block until all outputs are ready.""" return jax.block_until_ready(target_function()) for _ in range(self.timing_warmup_iterations): try: _execute_and_block() except Exception as warmup_error: autotune_logger.warning(f"Warmup failed: {warmup_error.__class__.__name__}: {warmup_error}") return float("inf"), float("inf") times = [] for _ in range(self.timing_rounds): t0 = time.perf_counter() try: for _ in range(self.calls_per_round): _execute_and_block() times.append(time.perf_counter() - t0) except Exception as timing_error: autotune_logger.warning(f"Timing round failed: {timing_error.__class__.__name__}: {timing_error}") times.append(float("inf")) valid = [t for t in times if not np.isinf(t)] if not valid: return float("inf"), float("inf") arr = np.array(valid) / self.calls_per_round arr = np.sort(arr) if len(arr) > 2: arr = arr[1:-1] return float(np.mean(arr)), float(np.std(arr))
[docs] def tune( self, fn: Callable[..., Any], *, args: tuple[Any, ...], kwargs: dict[str, Any], hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jax.Device | str | None = None, # type: ignore example_args: tuple[Any, ...] | None = None, example_kws: dict[Any, Any] | None = None, sample_num: int = 2**63 - 1, event_filter_regex: str | None = None, timeout: float | None = None, ) -> tuple[Callable[..., Any], dict[str, Any], list[tuple[int, TimingResult]]]: """Tune hyperparameters for a function and return optimal configuration. Performs comprehensive hyperparameter optimization by testing all candidate configurations in parallel, measuring their performance, and returning the best-performing configuration. Args: fn: Function to optimize (must accept hyperparameters as keyword args) args: Positional arguments for the function (can be abstract shapes) kwargs: Keyword arguments for the function (excluding hyperparameters) hyperparams: Dictionary mapping hyperparameter names to lists of candidate values to test. Each combination will be evaluated. max_workers: Maximum number of parallel workers for compilation in_shardings: Input sharding specifications for distributed computation out_shardings: Output sharding specifications for distributed computation device: Target device or device string for computation example_args: Concrete example arguments when args are abstract example_kws: Concrete example kwargs when kwargs are abstract sample_num: Maximum number of hyperparameter combinations to test. If fewer than total combinations, samples randomly. event_filter_regex: Optional regex filter for profiler events timeout: Optional compilation timeout override Returns: Tuple of (parameterized_fn, optimal_hyperparams, timing_results_sorted): - parameterized_fn: JIT-compiled function with optimal hyperparameters - optimal_hyperparams: Dictionary of optimal hyperparameter values - timing_results_sorted: List of (index, TimingResult) sorted by performance Raises: TypeError: If fn is not callable ValueError: If max_workers <= 0, sample_num < 0, or hyperparameters invalid RuntimeError: If all hyperparameter configurations fail to compile """ if not callable(fn): raise TypeError("fn must be callable") if max_workers <= 0: raise ValueError("max_workers must be positive") if sample_num < 0: raise ValueError("sample_num must be non-negative") if event_filter_regex is not None: self.profiler._pattern = re.compile(event_filter_regex) def _extract_array_type(x): """Extract the abstract type of a JAX array, or pass through non-array values.""" return x if not isinstance(x, jax.Array) else jax.typeof(x) if len(args) == 0 or all(x is None or jax.core.is_concrete(x) for x in jax.tree.leaves(args)): resolved_args = args elif example_args is not None: if in_shardings is not None or device is not None: raise ValueError( "Cannot combine example_args with explicit in_shardings or device configuration. " "Example arguments should already be properly sharded and placed." ) resolved_args = example_args else: resolved_device = device if isinstance(device, jax.Device) else _get_default_device() if isinstance(resolved_device, str): resolved_device = jax.devices(resolved_device)[0] input_shardings = in_shardings if in_shardings is not None else jax.tree.map(lambda _: None, args) normalized_shardings = jax.tree.map( partial(_normalize_sharding, default_device=resolved_device), args, input_shardings, ) resolved_args = jax.tree.map( lambda x, s: _get_random_value(_extract_array_type(x), s), args, normalized_shardings ) if len(kwargs) == 0 or all(v is None or jax.core.is_concrete(v) for v in kwargs.values()): resolved_kwargs = kwargs elif example_kws is not None: resolved_kwargs = example_kws else: resolved_kwargs = jax.tree.map(lambda x: _get_random_value(_extract_array_type(x)), kwargs) hyperparams = hyperparams if hyperparams is not None else {} hyperparams_norm: dict[str, tuple[Any, ...]] = {} for k, v in hyperparams.items(): if isinstance(v, tuple | list): if len(v) == 0: raise ValueError(f"Hyperparameter '{k}' has empty list of values") hyperparams_norm[k] = tuple(v) else: hyperparams_norm[k] = (v,) if hyperparams_norm: hyperparam_settings = dict(enumerate(itertools.product(*hyperparams_norm.values()))) total_combinations = len(hyperparam_settings) if sample_num < total_combinations: if sample_num == 0: hyperparam_settings = {0: tuple()} else: sample_idx = sorted( pyrandom.sample(list(range(total_combinations)), k=min(sample_num, total_combinations)) ) hyperparam_settings_ = list(hyperparam_settings.items()) hyperparam_settings = dict([hyperparam_settings_[idx] for idx in sample_idx]) autotune_logger.info( f"Testing {len(hyperparam_settings)} hyperparameter combinations out of {total_combinations} possible" ) else: hyperparam_settings = {0: tuple()} with ThreadPoolExecutor(max_workers=max_workers) as executor: fns: dict[int, Callable[..., Any]] = {} optimal_formats: dict[int, Any] = {} for phase in range(2): compile_only = phase == 0 compute_layouts = self.find_optimal_layouts_automatically compiles = {} for i, vals in hyperparam_settings.items(): hs = dict(zip(hyperparams_norm.keys(), vals, strict=True)) fns[i] = self._create_parameterized_function(fn, hs, output_shardings=out_shardings, function_id=i) opts = dict( optimal_formats=optimal_formats.get(i, None), compute_layouts=compute_layouts, timeout=self.max_compilation_time_seconds if timeout is None else timeout, ) fut = executor.submit( self._try_call, fns[i], resolved_args, resolved_kwargs, compile_only=compile_only, **opts ) compiles[fut] = i successful = {} for fut in as_completed(compiles): status, err, optf = fut.result() if status: successful[compiles[fut]] = (status, err, optf) if compile_only and compute_layouts: for i, (_, _, optf) in successful.items(): optimal_formats[i] = optf if not successful: for fut, i in compiles.items(): _, err, _ = fut.result() autotune_logger.error( f"Hyperparameters {hyperparam_settings[i]} failed to compile with message:\n{err}" ) raise ValueError("No hyperparameters compiled successfully") hyperparam_settings = {i: hyperparam_settings[i] for i in successful.keys()} fns = {i: fns[i] for i in successful.keys()} results: dict[int, TimingResult] = {} try: args_with_device = [ next(iter(arg.devices())) for arg in jax.tree.leaves(resolved_args) if hasattr(arg, "devices") ] platform = args_with_device[0].platform if len(args_with_device) > 0 else _get_default_device().platform def _timing_closure(): """Execute all hyperparameter configurations in random order for profiler timing.""" settings = list(hyperparam_settings.items()) pyrandom.shuffle(settings) for i, _ in settings: self._try_call( fns[i], resolved_args, resolved_kwargs, compile_only=False, optimal_formats=optimal_formats.get(i, None), ) profiler_timings = self.profiler.profile_time_by_function_id( _timing_closure, platform, self.profiling_samples, ) fraction_measured = sum(1 for i in hyperparam_settings.keys() if i in profiler_timings) / len( hyperparam_settings ) if fraction_measured < self.must_find_profiler_fraction: missing = [i for i in hyperparam_settings.keys() if i not in profiler_timings] msg = "Could not find profiler results for some hyperparameter settings:\n" + "\n".join( f" - {i}: {hyperparam_settings[i]}" for i in missing ) raise RuntimeError(msg) for i in hyperparam_settings.keys(): if i not in profiler_timings: autotune_logger.warning( f"Could not find profiler results for hyperparameter settings: {hyperparam_settings[i]}" ) profiler_timings[i] = (float("inf"), float("inf")) for i, hs in hyperparam_settings.items(): hs = dict(zip(hyperparams_norm.keys(), hs, strict=True)) t_mean, t_std = profiler_timings[i] results[i] = TimingResult(hs, float(t_mean), float(t_std)) except Exception as e: if not self.allow_fallback_timing: raise RuntimeError( f"Need to fall back to the python-level timing, but allow_fallback_timing=False. Error: {e}" ) from None for i, hs in hyperparam_settings.items(): hs = dict(zip(hyperparams_norm.keys(), hs, strict=True)) t_mean, t_std = self._time_fn(partial(lambda fn: fn(*resolved_args, **resolved_kwargs), fns[i])) results[i] = TimingResult(hs, t_mean, t_std) results_sorted = sorted(results.items(), key=lambda x: self._calculate_timing_score(x[1])) idx, optimal_hyperparams = results_sorted[0][0], results_sorted[0][1].hyperparams return fns[idx], optimal_hyperparams, results_sorted
[docs] def decorate( self, fn: Callable[..., Any], *, hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jax.Device | str | None = None, # type: ignore example_args: tuple[Any, ...] | None = None, example_kws: dict[str, Any] | None = None, sample_num: int = 2**63 - 1, event_filter_regex: str | None = None, timeout: float | None = None, cache_key: str | None = None, ): """Create a decorated version of a function with automatic hyperparameter tuning. Wraps a function so that the first call triggers hyperparameter optimization, and subsequent calls with the same input signature use cached optimal values. Args: fn: Function to decorate with autotuning capabilities hyperparams: Dictionary mapping hyperparameter names to lists of candidate values to test during optimization max_workers: Maximum number of parallel workers for compilation in_shardings: Input sharding specifications for distributed computation out_shardings: Output sharding specifications for distributed computation device: Target device or device string for computation example_args: Concrete example arguments when function args are abstract example_kws: Concrete example kwargs when function kwargs are abstract sample_num: Maximum number of hyperparameter combinations to test event_filter_regex: Optional regex filter for profiler events timeout: Optional compilation timeout override cache_key: Optional custom cache key prefix for disambiguation Returns: Decorated function with automatic hyperparameter optimization. The returned function has additional attributes after first execution: - timing_results: List of all timing measurements from optimization - optimal_hyperparams: Dictionary of optimal parameter values Raises: TypeError: If fn is not callable """ if not callable(fn): raise TypeError("fn must be callable") cache: dict[Any, tuple[dict[str, Any], list[tuple[int, TimingResult]]]] = {} cache_lock = threading.Lock() @wraps(fn) def wrapped(*args, **kws): """Wrapper function that performs autotuning on first call and caches results. On first call with a given input signature, performs hyperparameter optimization and caches the results. Subsequent calls with the same signature use the cached optimal hyperparameters. Args: *args: Positional arguments for the wrapped function **kws: Keyword arguments for the wrapped function Returns: Result of executing the function with optimal hyperparameters Raises: Exception: Any exception from the underlying function execution """ input_hash = _try_hash_input(args, kws) lookup_key = f"{cache_key}:{input_hash}" if cache_key and input_hash else input_hash with cache_lock: hit = lookup_key is not None and lookup_key in cache if hit: optimal_hyperparams, results = cache[lookup_key] else: optimal_hyperparams = None results = None if optimal_hyperparams is None: flat_vals = jax.tree.leaves((args, kws)) has_tracers = any(not jax.core.is_concrete(x) for x in flat_vals if isinstance(x, jax.Array)) if has_tracers: if wrapped.optimal_hyperparams: optimal_hyperparams = wrapped.optimal_hyperparams.copy() results = wrapped.timing_results else: optimal_hyperparams = { k: (v[0] if isinstance(v, list) else v) for k, v in (hyperparams or {}).items() } results = [] else: with jax.core.eval_context(): _, optimal_hyperparams, results = self.tune( fn, args=args, kwargs=kws, hyperparams=hyperparams, max_workers=max_workers, in_shardings=in_shardings, out_shardings=out_shardings, device=device, example_args=example_args, example_kws=example_kws, sample_num=sample_num, event_filter_regex=event_filter_regex, timeout=timeout, ) with cache_lock: if lookup_key is not None: if len(cache) >= self.cache_size_limit: oldest_key = next(iter(cache)) del cache[oldest_key] cache[lookup_key] = (optimal_hyperparams, results) wrapped.timing_results = results wrapped.optimal_hyperparams = optimal_hyperparams try: return fn(*args, **dict(kws, **optimal_hyperparams)) except Exception as e: autotune_logger.error(f"Execution failed with optimal hyperparameters {optimal_hyperparams}: {e}") raise wrapped.timing_results = [] wrapped.optimal_hyperparams = {} return wrapped
[docs]def autotune( fn: Callable[..., Any] | None = None, /, *, allow_fallback_timing: bool = True, profiling_samples: int = 5, must_find_profiler_fraction: float = 0.5, enable_detailed_logging: bool = False, find_optimal_layouts_automatically: bool = False, max_compilation_time_seconds: float = 300.0, timing_warmup_iterations: int = 2, timing_rounds: int = 5, calls_per_round: int = 3, cache_size_limit: int = 1000, profiler_prefix_filter: str = "jit_", profiler_event_regex: str | None = None, profiler_min_duration_ns: float = 1000.0, profiler_max_events: int | None = 10000, profiler_verbose: bool = False, hyperparams: dict[str, list[int | float | str]] | dict[Any, Any] | None = None, max_workers: int = 32, in_shardings: Any = None, out_shardings: Any = None, device: jax.Device | str | None = None, # type: ignore example_args: tuple[Any, ...] | None = None, example_kws: dict[str, Any] | None = None, sample_num: int = 2**63 - 1, event_filter_regex: str | None = None, timeout: float | None = None, cache_key: str | None = None, ): """Advanced JAX function autotuning decorator with comprehensive optimization features. A flexible decorator that automatically optimizes JAX functions by testing different hyperparameter configurations. Uses profiler-based timing for accuracy with Python fallback, supports parallel execution, and provides intelligent caching. This decorator can be used in three different ways: 1) @autotune() -> decorator factory with custom parameters 2) @autotune -> plain decorator with default parameters 3) autotune(fn, kw=...) -> direct function call returning wrapped function The decorator automatically optimizes hyperparameters on the first function call and caches results for subsequent calls with the same input signature. Key Features: - Profiler-based timing with Python-level fallback for maximum accuracy - Parallel compilation and testing of hyperparameter configurations - Automatic optimal memory layout discovery for distributed computation - Thread-safe caching with configurable size limits - Statistical timing analysis with outlier removal - Comprehensive error handling and detailed logging options - Support for both concrete and abstract input specifications Args: fn: Function to autotune (when used as direct call) allow_fallback_timing: Enable Python timing fallback if profiler fails profiling_samples: Number of profiling iterations for statistical accuracy must_find_profiler_fraction: Minimum fraction of configs needing profiler results (0.0-1.0) enable_detailed_logging: Enable detailed error logging with full tracebacks timing_warmup_iterations: Warmup iterations before timing measurements timing_rounds: Number of timing rounds for statistical analysis calls_per_round: Function calls per timing round profiler_prefix_filter: Event name prefix filter for profiler (default: "jit_") profiler_event_regex: Optional regex filter for profiler events profiler_min_duration_ns: Minimum event duration for profiler inclusion (nanoseconds) profiler_max_events: Maximum events per profile to prevent memory issues profiler_verbose: Enable verbose profiler output find_optimal_layouts_automatically: Auto-discover optimal memory layouts max_compilation_time_seconds: Maximum compilation time per configuration cache_size_limit: Maximum cached optimization results hyperparams: Dictionary mapping parameter names to candidate value lists max_workers: Maximum parallel workers for optimization sample_num: Maximum hyperparameter combinations to test in_shardings: Input sharding specifications for distributed computation out_shardings: Output sharding specifications for distributed computation device: Target device or device string for computation example_args: Concrete example arguments for abstract input shapes example_kws: Concrete example kwargs for abstract input values event_filter_regex: Optional regex filter for profiler events (alias for profiler_event_regex) timeout: Optional compilation timeout override cache_key: Optional custom cache key prefix for disambiguation Returns: When used as decorator: Decorated function with automatic hyperparameter optimization When used directly: Wrapper function that performs optimization and execution The returned function has additional attributes after first execution: - timing_results: List of all timing measurements from optimization - optimal_hyperparams: Dictionary of optimal parameter values Raises: TypeError: If fn is not callable (when used as direct call) ValueError: If hyperparameter specifications are invalid RuntimeError: If all hyperparameter configurations fail to compile Examples: Basic usage with hyperparameter optimization: ```python @autotune(hyperparams={'block_size': [64, 128, 256, 512]}) def matrix_multiply(a, b, block_size=128): return jnp.dot(a, b) result = matrix_multiply(x, y) print(f"Optimal config: {matrix_multiply.optimal_hyperparams}") ``` Advanced usage with custom timing and sharding: ```python @autotune( hyperparams={ 'chunk_size': [32, 64, 128], 'algorithm': ['parallel', 'sequential'] }, profiling_samples=10, max_workers=16, in_shardings=jax.sharding.PartitionSpec('data', None), enable_detailed_logging=True ) def compute_function(data, chunk_size=64, algorithm='parallel'): return processed_data ``` Direct call usage: ```python optimized_fn = autotune( my_function, hyperparams={'param': [1, 2, 3]}, profiling_samples=5 ) result = optimized_fn(input_data) ``` Using with abstract input specifications: ```python @autotune( hyperparams={'tile_size': [16, 32, 64]}, example_args=(jnp.zeros((1000, 1000)),), device='gpu' ) def process_matrix(matrix, tile_size=32): return result ``` Note: - First call triggers optimization and may take longer - Subsequent calls with same input signature use cached optimal parameters - Profiler-based timing requires TensorFlow backend; falls back to Python timing - For distributed computation, ensure proper mesh configuration before use - Cache keys are based on input signatures; use cache_key for manual disambiguation """ tuner = FNAutotuner( allow_fallback_timing=allow_fallback_timing, profiling_samples=profiling_samples, must_find_profiler_fraction=must_find_profiler_fraction, enable_detailed_logging=enable_detailed_logging, find_optimal_layouts_automatically=find_optimal_layouts_automatically, max_compilation_time_seconds=max_compilation_time_seconds, timing_warmup_iterations=timing_warmup_iterations, timing_rounds=timing_rounds, calls_per_round=calls_per_round, cache_size_limit=cache_size_limit, profiler_prefix_filter=profiler_prefix_filter, profiler_event_regex=profiler_event_regex, profiler_min_duration_ns=profiler_min_duration_ns, profiler_max_events=profiler_max_events, profiler_verbose=profiler_verbose, ) def decorator(func: Callable[..., Any]): """Internal decorator function that applies autotuning to a target function. Args: func: Function to wrap with autotuning capabilities Returns: Function decorated with automatic hyperparameter optimization """ return tuner.decorate( func, hyperparams=hyperparams, max_workers=max_workers, in_shardings=in_shardings, out_shardings=out_shardings, device=device, example_args=example_args, example_kws=example_kws, sample_num=sample_num, event_filter_regex=event_filter_regex, timeout=timeout, cache_key=cache_key, ) if callable(fn): return decorator(fn) return decorator