Source code for ejkernel.ops.config.selection

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Configuration selection and autotuning system for kernel optimization.

This module provides a comprehensive configuration selection framework that
intelligently chooses optimal kernel configurations through a multi-tier
fallback chain. The system prioritizes cached results while supporting
automatic performance optimization when needed.

Key Components:
    ConfigSelectorChain: Main selection coordinator with fallback hierarchy
    AutotunePolicy: Configuration policy for autotuning behavior
    Tuner: Performance benchmarking and autotuning engine
    policy_override: Context manager for temporary policy changes

Selection Hierarchy (in order of priority):
    1. Override: Explicit configuration provided by caller
    2. Overlay: Temporary context-specific configuration overrides
    3. Memory Cache: Fast lookup for recently used configurations
    4. Persistent Cache: Disk-based storage across program runs
    5. Autotune: Benchmark candidates to find optimal configuration
    6. Heuristics: Kernel-provided default configuration
    7. Error: No configuration available (raises RuntimeError)

This design ensures optimal performance by:
    - Prioritizing fastest lookup methods (memory cache)
    - Preserving optimization results across runs (persistent cache)
    - Automatically finding optimal configurations (autotuning)
    - Providing sensible defaults (heuristics) as fallback

Example Usage:
    >>> cache = ConfigCache()
    >>> policy = AutotunePolicy(allow_autotune=True)
    >>> selector = ConfigSelectorChain(cache, policy)
    >>>
    >>> # Resolve a configuration through the full fallback chain.
    >>> config = selector.choose(invocation, kernel)
    >>>
    >>> # Temporarily disable autotuning (falls back to heuristics on a miss).
    >>> with policy_override(selector, allow_autotune=False):
    ...     config = selector.choose(invocation, kernel)
"""

from __future__ import annotations

import os
import pprint
import time
import traceback
from collections.abc import Callable, Iterable
from contextvars import ContextVar
from dataclasses import dataclass
from typing import Generic, Literal, TypeVar

import jax
import jax.numpy as jnp
import numpy as np
from jax import core as jcore
from jax import tree_util as jtu

from ejkernel.loggings import get_logger

from ..core import Invocation, Kernel, _get_platform_method
from ..utils.fingerprint import device_fingerprint, get_device_platform
from .cache import ConfigCache, _cache_overlay
from .persistent import PersistentCache

# Generic parameters: ``Cfg`` is a kernel's configuration type, ``Out`` its output type.
Cfg = TypeVar("Cfg")
Out = TypeVar("Out")

# Logger used by the autotuner for per-candidate timings and failure warnings.
autotune_logger = get_logger("ejKernel-Selection")

# Context-local switch read by ``_is_backward_autotune_enabled``. It defaults to
# True; ``forward_autotune_only`` sets it to False for the duration of a block.
_backward_autotune_enabled: ContextVar[bool] = ContextVar(
    "ejkernel_backward_autotune_enabled",
    default=True,
)


@dataclass
class AutotunePolicy:
    """Knobs that steer what ``ConfigSelectorChain`` does after every cache tier misses.

    Fields:
        allow_autotune: Permit benchmarking candidate configurations. Autotuning
            only runs when this is True *and* ``cache_miss_fallback`` is
            ``"autotune"``.
        allow_heuristics: Permit using the kernel's heuristic configuration when
            no cached or autotuned configuration is produced.
        cache_miss_fallback: Strategy tried first on a cache miss, either
            ``"autotune"`` (benchmark candidates) or ``"heuristics"`` (use the
            kernel default).
        validate_backward: When True, autotune timings include gradient
            computation in addition to the forward pass, unless backward
            validation is switched off for the current context by
            ``forward_autotune_only``.
    """

    # Declaration order defines the positional constructor order; keep it stable.
    allow_autotune: bool = True
    allow_heuristics: bool = True
    cache_miss_fallback: Literal["autotune", "heuristics"] = "autotune"
    validate_backward: bool = False
class policy_override:
    """Context manager for temporarily overriding autotuning policy settings.

    Sets attributes on ``selector.policy`` when entered and restores their
    previous values on exit, including when the ``with`` body raises. Useful for:
        - Disabling autotuning for specific operations
        - Forcing use of heuristics during debugging
        - Testing different policy configurations

    The manager is exception-safe and re-entrant:
        - If applying the overrides fails partway (for example, an unknown
          attribute name raises ``AttributeError``), every override already
          applied is rolled back before the error propagates.
        - The same instance may be used in nested ``with`` blocks. Saved values
          are kept on a stack, so each exit restores exactly what its matching
          enter replaced.

    Args:
        selector: ConfigSelectorChain instance whose ``policy`` is modified.
        **updates: Policy attribute names mapped to temporary values.

    Example:
        >>> with policy_override(selector, allow_autotune=False):
        ...     result = executor(kernel, *args)
        >>>
        >>> with policy_override(selector, cache_miss_fallback="heuristics"):
        ...     config = selector.choose(inv, kernel)
    """

    def __init__(self, selector: ConfigSelectorChain, **updates):
        """Initialize the policy override context manager.

        Args:
            selector: ConfigSelectorChain to modify.
            **updates: Policy attributes to override.
        """
        self.selector = selector
        self.updates = updates
        # One dict of saved (original) values per active ``with`` entry.
        self._saved_stack: list[dict] = []

    def __enter__(self):
        """Apply the overrides, rolling back any partial application on failure.

        Returns:
            Self, for use in ``with ... as`` statements.

        Raises:
            AttributeError: If an override names an attribute the policy lacks.
        """
        policy = self.selector.policy
        saved: dict = {}
        try:
            for name, value in self.updates.items():
                # Read before writing so an unknown name fails without mutation.
                saved[name] = getattr(policy, name)
                setattr(policy, name, value)
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so undo here.
            self._restore(saved)
            raise
        self._saved_stack.append(saved)
        return self

    def __exit__(self, *exc):
        """Restore the values replaced by the matching ``__enter__``.

        Args:
            *exc: Exception information (ignored; exceptions propagate).
        """
        if self._saved_stack:
            self._restore(self._saved_stack.pop())

    def _restore(self, saved: dict) -> None:
        """Write ``saved`` values back onto the selector's policy, newest first."""
        policy = self.selector.policy
        for name, value in reversed(saved.items()):
            setattr(policy, name, value)
class forward_autotune_only:
    """Context manager that disables backward validation during autotuning.

    While active, autotune measurements run forward-only even when
    ``AutotunePolicy.validate_backward`` is True. This keeps autotuning focused
    on forward latency and avoids gradient-timing overhead.

    The manager is re-entrant: the context-variable tokens are kept on a stack,
    so one instance can be used in nested ``with`` blocks. Each exit resets the
    token created by its matching enter.

    Example:
        >>> with forward_autotune_only():
        ...     cfg = selector.choose(inv, kernel)
    """

    def __init__(self):
        # Tokens returned by ContextVar.set, one per active ``with`` entry.
        self._tokens: list = []

    def __enter__(self):
        """Disable backward autotune validation for the current context."""
        self._tokens.append(_backward_autotune_enabled.set(False))
        return self

    def __exit__(self, *exc):
        """Restore the flag value that was active before the matching enter."""
        if self._tokens:
            _backward_autotune_enabled.reset(self._tokens.pop())
def _is_backward_autotune_enabled() -> bool:
    """Report whether autotune may time the backward pass in the current context.

    Reads the ``_backward_autotune_enabled`` context variable, which is True by
    default and False inside ``forward_autotune_only``.
    """
    enabled = _backward_autotune_enabled.get()
    return bool(enabled)
class Tuner(Generic[Cfg]):
    """Performance benchmarking and autotuning for kernel configurations.

    Measures the execution time of candidate configurations and selects the
    fastest one.

    Attributes:
        warmup: Number of untimed warmup iterations before timing.
        iters: Number of timed iterations averaged into one measurement.
    """

    def __init__(self, warmup=1, iters=3):
        """Initialize tuner with warmup and iteration settings.

        Args:
            warmup: Number of warmup iterations before timing (default: 1).
            iters: Number of timed iterations to average over (default: 3).

        Raises:
            ValueError: If ``iters`` is less than 1. ``measure`` divides the
                elapsed time by ``iters``, so zero would fail every measurement
                and a negative value would produce negative timings.
        """
        if iters < 1:
            raise ValueError(f"Tuner requires iters >= 1, got {iters!r}")
        self.warmup, self.iters = warmup, iters

    def measure(self, fn, *args, **kwargs) -> float:
        """Measure average execution time, optionally including the backward pass.

        (args, kwargs) are deep-flattened so that only array-like leaves are dynamic:

        - Arrays and JAX tracers become dynamic parameters of the jitted function.
        - Every other leaf (dtypes, strings, bools, callables, ...) is captured as
          a Python constant in the closure.
        - Tracers (e.g. ShardMapTracer, DynamicJaxprTracer) are replaced by
          concrete zeros of the same shape/dtype before compiling and timing.
        - If ``fn._ejk_validate_backward`` is truthy, the gradient of a scalar
          loss (sum of ``fn``'s floating/complex output leaves) is timed with
          respect to the floating/complex input leaves. Other leaves are treated
          as non-differentiable. If ``fn`` produces no floating/complex output,
          forward-only timing is used instead.
        - If ``fn`` cannot be transformed (e.g. it wraps precompiled functions),
          timing falls back to forward-only, and then to non-jitted forward.

        Args:
            fn: Function to measure, optionally tagged with ``_ejk_validate_backward``.
            *args: Positional arguments for ``fn``.
            **kwargs: Keyword arguments for ``fn``.

        Returns:
            Average execution time per iteration in seconds.
        """

        class _NoDifferentiableOutput(Exception):
            """Raised while tracing the loss when ``fn`` has no float/complex outputs."""

        def _is_arrayish(x) -> bool:
            """Return True for NumPy arrays, JAX arrays, and JAX tracers."""
            return isinstance(x, jax.Array | np.ndarray) or isinstance(x, jcore.Tracer)

        def _is_diff(x) -> bool:
            """Return True when ``x`` has a floating-point or complex dtype."""
            dtype = getattr(x, "dtype", None)
            if dtype is None:
                # np.dtype(None) is float64; without this guard dtype-less leaves
                # would be misclassified as differentiable.
                return False
            try:
                return bool(np.issubdtype(np.dtype(dtype), np.inexact))
            except (TypeError, ValueError):
                return False

        def _to_concrete(x):
            """Convert a tracer to concrete zeros of the same shape/dtype.

            Concrete NumPy/JAX arrays are returned unchanged.

            Args:
                x: Value to convert (array, tracer, or scalar).

            Returns:
                A concrete array with the same shape and dtype as ``x``.
            """
            # In current JAX, isinstance(tracer, jax.Array) is True, so tracers
            # must be excluded first or they would be passed through unconverted.
            if not isinstance(x, jcore.Tracer) and isinstance(x, jax.Array | np.ndarray):
                return x
            shape = getattr(x, "shape", None)
            dtype = getattr(x, "dtype", None)
            aval = getattr(x, "aval", None)
            if (shape is None or dtype is None) and aval is not None:
                shape = getattr(aval, "shape", None)
                dtype = getattr(aval, "dtype", None)
            if shape is not None and dtype is not None:
                # Force eager evaluation. If measure() runs while an outer
                # transformation is tracing, jnp.zeros would otherwise be staged
                # into that trace and yield another tracer.
                with jax.ensure_compile_time_eval():
                    return jnp.zeros(shape, dtype)
            return jnp.asarray(x)

        def _block_all(x):
            """Block until all arrays in the pytree are ready, for synchronous timing."""
            return jtu.tree_map(lambda t: t.block_until_ready() if hasattr(t, "block_until_ready") else t, x)

        leaves, treedef = jtu.tree_flatten((args, kwargs))
        is_arr = [_is_arrayish(x) for x in leaves]
        const_leaves = [None if m else x for m, x in zip(is_arr, leaves, strict=False)]
        arr_leaves = [x for m, x in zip(is_arr, leaves, strict=False) if m]
        arr0 = tuple(_to_concrete(x) for x in arr_leaves)

        def _restore_args_kwargs(array_leaves):
            """Rebuild (args, kwargs) by merging dynamic array leaves with closed-over constants."""
            it = iter(array_leaves)
            merged = [next(it) if m else v for m, v in zip(is_arr, const_leaves, strict=False)]
            return jtu.tree_unflatten(treedef, merged)

        def _timed(call, *call_args) -> float:
            """Run ``call`` warmup + iters times; return the mean of the timed runs."""
            for _ in range(self.warmup):
                _block_all(call(*call_args))
            start = time.perf_counter()
            for _ in range(self.iters):
                _block_all(call(*call_args))
            return (time.perf_counter() - start) / self.iters

        def _time_forward(jitted: bool = True) -> float:
            """Time forward-only execution, JIT-compiled when ``jitted`` is True."""

            def core(*arrs):
                """Reconstruct args/kwargs from array leaves and call the target function."""
                aa, kk = _restore_args_kwargs(arrs)
                return fn(*aa, **kk)

            if jitted:
                return _timed(jax.jit(core).lower(*arr0).compile(), *arr0)
            return _timed(core, *arr0)

        def _time_forward_with_fallback() -> float:
            """Time forward jitted, falling back to eager execution on any failure."""
            try:
                return _time_forward(jitted=True)
            except Exception:
                return _time_forward(jitted=False)

        validate_bwd = bool(getattr(fn, "_ejk_validate_backward", False))
        if not validate_bwd:
            return _time_forward_with_fallback()

        diff_mask = [_is_diff(x) for x in arr0]
        if not any(diff_mask):
            return _time_forward_with_fallback()

        def _split(arrs):
            """Split array leaves into differentiable and non-differentiable groups."""
            theta, nondiff = [], []
            for m, v in zip(diff_mask, arrs, strict=False):
                (theta if m else nondiff).append(v)
            return tuple(theta), tuple(nondiff)

        def _merge(theta, nondiff):
            """Merge differentiable and non-differentiable arrays back into original order."""
            it_t, it_n = iter(theta), iter(nondiff)
            return tuple(next(it_t) if m else next(it_n) for m in diff_mask)

        def _scalar_loss(out):
            """Reduce ``fn``'s output pytree to a real float32 scalar.

            Only floating/complex leaves contribute (complex via their real
            part). jnp.sum over a tuple of differently shaped outputs, or over
            integer/complex outputs, cannot be differentiated.
            """
            terms = [jnp.sum(jnp.real(leaf)).astype(jnp.float32) for leaf in jtu.tree_leaves(out) if _is_diff(leaf)]
            if not terms:
                raise _NoDifferentiableOutput("fn produced no floating/complex outputs")
            return sum(terms[1:], terms[0])

        def loss(theta, nondiff):
            """Compute the scalar loss whose gradient is timed."""
            arrs = _merge(theta, nondiff)
            aa, kk = _restore_args_kwargs(arrs)
            return _scalar_loss(fn(*aa, **kk))

        try:
            theta0, nondiff0 = _split(arr0)
            compiled = jax.jit(jax.grad(loss, argnums=0)).lower(theta0, nondiff0).compile()
            return _timed(compiled, theta0, nondiff0)
        except _NoDifferentiableOutput:
            # Nothing to differentiate: a gradient would be constant zeros and
            # its timing meaningless, so time the forward pass instead.
            return _time_forward_with_fallback()
        except Exception as e:
            msg = str(e)
            if ("Cannot apply JAX transformations" in msg) or ("Leaked trace" in msg):
                return _time_forward_with_fallback()
            raise

    def autotune(self, make_fn, args, kwargs, candidates: Iterable[Cfg]) -> Cfg | None:
        """Benchmark all candidate configurations and return the fastest one.

        Each candidate is measured with ``measure``. Candidates whose
        construction or measurement raises are skipped.

        Args:
            make_fn: Factory that returns a callable for a given config.
            args: Positional arguments for the function being benchmarked.
            kwargs: Keyword arguments for the function being benchmarked.
            candidates: Iterable of candidate configurations to test.

        Returns:
            The configuration with the lowest average execution time, or None if
            no candidate was supplied or every candidate failed. In that case a
            warning is logged and the last failure's traceback is printed.
        """
        # Read the logging switch once rather than once per candidate.
        log_each = os.getenv("EJKERNEL_LOG_AUTOTUNE", "0") == "1"
        best_cfg, best_t = None, float("inf")
        last_err = None
        tried = 0
        for cfg in candidates:
            tried += 1
            try:
                t = self.measure(make_fn(cfg), *args, **kwargs)
            except Exception as e:
                last_err = e
                continue
            if log_each:
                autotune_logger.info(pprint.pformat({"config": cfg, "time": t}))
            if t < best_t:
                best_cfg, best_t = cfg, t

        if best_cfg is None:
            if tried == 0:
                autotune_logger.warning("No autotune candidates were provided; falling back to heuristics.")
                return None
            if last_err is not None:
                traceback.print_exception(last_err)
            autotune_logger.warning("All candidates failed during autotune; falling back to heuristics.")
            return None

        if log_each:
            autotune_logger.info(pprint.pformat({"best_config": best_cfg, "best_time": best_t}))
        return best_cfg
class ConfigSelectorChain(Generic[Cfg, Out]):
    """Multi-tier configuration selection system with a fallback chain.

    Selection order:
        1. Override (explicit configuration carried by the invocation)
        2. Overlay (context-specific overrides; the overlay list is scanned
           from its end)
        3. In-memory cache
        4. Persistent cache (a hit is promoted into the memory cache)
        5. Autotune (benchmark candidates; requires policy permission)
        6. Heuristics (kernel-provided default configuration)
        7. Error (RuntimeError when nothing above yields a configuration)

    Attributes:
        cache: In-memory configuration cache.
        policy: Autotuning behavior policy.
        tuner: Benchmarking engine used for autotuning.
        persistent: Optional persistent (cross-run) cache.
        persist_autotune: Whether autotuned and heuristic configs are written to
            ``persistent``.
        on_event: Optional callback invoked as ``on_event(event, **data)``.
        forbid_reautotune: Raise instead of autotuning a key that this selector
            has already autotuned.
    """

    def __init__(
        self,
        cache: ConfigCache[Cfg],
        policy: AutotunePolicy | None = None,
        tuner: Tuner[Cfg] | None = None,
        persistent: PersistentCache[Cfg] | None = None,
        persist_autotune: bool = True,
        on_event: Callable[..., None] | None = None,
        forbid_reautotune: bool = True,
    ):
        """Initialize the configuration selector with cache and policy settings.

        Args:
            cache: In-memory configuration cache for fast lookups.
            policy: Autotuning behavior policy (default: AutotunePolicy()).
            tuner: Benchmarking tool (default: Tuner()).
            persistent: Optional disk-based cache for cross-run persistence.
            persist_autotune: Save autotuned configs to persistent storage (default: True).
            on_event: Optional callback for selection events (monitoring/debugging).
            forbid_reautotune: Prevent re-autotuning the same operation (default: True).
        """
        self.cache = cache
        self.policy = policy or AutotunePolicy()
        self.tuner = tuner or Tuner()
        self.persistent = persistent
        self.persist_autotune = persist_autotune
        self.on_event = on_event
        self.forbid_reautotune = forbid_reautotune
        self._autotuned_keys: set[tuple[str, str, str]] = set()

    def choose(self, inv: Invocation[Cfg, Out], kernel: Kernel[Cfg, Out]) -> Cfg:
        """Select a configuration using the fallback hierarchy.

        Selection priority (highest to lowest):
            1. Override: explicit configuration in the invocation. It is also
               stored in the memory cache and, if configured, the persistent cache.
            2. Overlay: temporary context-specific overrides.
            3. Memory cache.
            4. Persistent cache.
            5. Autotune: only when ``policy.cache_miss_fallback == "autotune"``
               and ``policy.allow_autotune``. If every candidate fails, selection
               continues with heuristics.
            6. Heuristics: only when ``policy.allow_heuristics``.

        Args:
            inv: Function invocation containing arguments and context.
            kernel: Kernel implementation providing candidate/heuristic configs.

        Returns:
            The selected configuration.

        Raises:
            RuntimeError: If no configuration can be determined, if re-autotuning
                is forbidden for an already-autotuned key, or if a shard_map
                kernel lacks ``create_shard_map_wrapper``.
        """
        dev = device_fingerprint()
        op_id = f"{kernel.op_id}@v{getattr(kernel, 'version', '0')}"
        call_key = inv.make_key(kernel.key_builder)

        if inv.override_cfg is not None:
            cfg = inv.override_cfg
            self._emit("override", device=dev, op_id=op_id, call_key=call_key, cfg=cfg)
            self.cache.put(dev, op_id, call_key, cfg)
            if self.persistent is not None:
                self.persistent.put(dev, op_id, call_key, cfg)
            return cfg

        cached = self._lookup_cached(dev, op_id, call_key)
        if cached is not None:
            return cached

        if self.policy.cache_miss_fallback == "autotune" and self.policy.allow_autotune:
            best = self._run_autotune(inv, kernel, dev, op_id, call_key)
            if best is not None:
                return best

        if self.policy.allow_heuristics:
            return self._run_heuristics(inv, kernel, dev, op_id, call_key)

        self._emit("error", device=dev, op_id=op_id, call_key=call_key, reason="no_config")
        raise RuntimeError("No config found: override/overlay/cache/persistent/autotune/heuristics all unavailable.")

    def _lookup_cached(self, dev, op_id, call_key) -> Cfg | None:
        """Return a config from overlays, the memory cache, or the persistent cache; else None."""
        for overlay in reversed(_cache_overlay.get()):
            if (cfg := overlay.get((dev, op_id, call_key))) is not None:
                self._emit("overlay_hit", device=dev, op_id=op_id, call_key=call_key, cfg=cfg)
                return cfg

        if (cfg := self.cache.get(dev, op_id, call_key)) is not None:
            self._emit("cache_hit", level="memory", device=dev, op_id=op_id, call_key=call_key, cfg=cfg)
            return cfg

        if self.persistent is not None:
            if (cfg := self.persistent.get(dev, op_id, call_key)) is not None:
                self._emit("cache_hit", level="persistent", device=dev, op_id=op_id, call_key=call_key, cfg=cfg)
                # Promote so later lookups in this process skip the persistent tier.
                self.cache.put(dev, op_id, call_key, cfg)
                return cfg
        return None

    @staticmethod
    def _platform_method(kernel, name: str, inv, platform):
        """Resolve a platform/context-specific kernel method, or a falsy value if none exists."""
        context = "shard_map" if inv.method == "shard_map" else None
        return _get_platform_method(kernel, name, platform, context)

    def _run_autotune(self, inv, kernel, dev, op_id, call_key) -> Cfg | None:
        """Benchmark the kernel's candidates; cache and return the winner, or None if all fail."""
        if self.forbid_reautotune and (dev, op_id, call_key) in self._autotuned_keys:
            raise RuntimeError(f"Re-autotune requested for {(dev, op_id, call_key)}")

        platform = get_device_platform()
        candidate_cfgs_method = self._platform_method(kernel, "candidate_cfgs", inv, platform)
        candidates = tuple((candidate_cfgs_method or kernel.candidate_cfgs)(inv))
        self._emit(
            "autotune_start",
            device=dev,
            op_id=op_id,
            call_key=call_key,
            candidates=len(candidates),
            platform=platform,
            method=inv.method,
        )

        kw = dict(inv.kwargs)
        # Callables are bound directly into each benchmark wrapper rather than
        # flattened by the tuner, so keep them out of the dynamic kwargs.
        static_fun_kwargs = {k: v for k, v in kw.items() if callable(v)}
        dyn_kwargs = {k: v for k, v in kw.items() if k not in static_fun_kwargs}
        validate_backward = self.policy.validate_backward and _is_backward_autotune_enabled()

        if inv.method == "shard_map":
            mk = self._shard_map_bench_factory(inv, kernel, static_fun_kwargs, validate_backward)
        else:
            run_method = self._platform_method(kernel, "run", inv, platform) or kernel.run
            mk = self._regular_bench_factory(run_method, static_fun_kwargs, validate_backward)

        best = self.tuner.autotune(mk, inv.args, dyn_kwargs, candidates)
        if best is None:
            return None

        self._autotuned_keys.add((dev, op_id, call_key))
        self.cache.put(dev, op_id, call_key, best)
        if self.persistent is not None and self.persist_autotune:
            self.persistent.put(dev, op_id, call_key, best)
        self._emit(
            "autotune_finish",
            device=dev,
            op_id=op_id,
            call_key=call_key,
            cfg=best,
            platform=platform,
            method=inv.method,
        )
        return best

    @staticmethod
    def _shard_map_bench_factory(inv, kernel, static_kwargs, validate_backward):
        """Build ``mk(cfg)`` producing shard_map-wrapped callables for benchmarking.

        Raises:
            RuntimeError: If the kernel does not implement ``create_shard_map_wrapper``.
        """
        if not hasattr(kernel, "create_shard_map_wrapper"):
            raise RuntimeError(
                f"Kernel {kernel.op_id} does not implement create_shard_map_wrapper for shard_map benchmarking"
            )
        tag_backward = validate_backward and getattr(kernel, "supports_grad_validation", False)

        def mk(c):
            """Create a shard_map-wrapped function bound to config ``c``."""

            def f(*a, **k):
                """Build the shard_map wrapper, run it, and apply the optional output callback."""
                eagers = kernel.create_shard_map_wrapper(
                    *a,
                    cfg=c,
                    mesh=inv.mesh,
                    in_specs=inv.in_specs,
                    out_specs=inv.out_specs,
                    check_vma=inv.check_vma,
                    **(k | static_kwargs),
                )
                if len(eagers) == 2:
                    shard_map_fn, call_args = eagers
                    callback = None
                elif len(eagers) == 3:
                    shard_map_fn, call_args, callback = eagers
                else:
                    raise ValueError(
                        f"create_shard_map_wrapper must return 2 or 3 items, got {len(eagers)}"
                    )
                outs = shard_map_fn(*call_args)
                if callback is not None:
                    outs = callback(outs, cfg=c)
                return outs

            f._ejk_method = "shard_map"
            if tag_backward:
                f._ejk_validate_backward = True
            return f

        return mk

    @staticmethod
    def _regular_bench_factory(run_method, static_kwargs, validate_backward):
        """Build ``mk(cfg)`` producing callables that invoke ``run_method`` with ``cfg``."""

        def mk(c):
            """Create a function that runs the kernel with config ``c``."""

            def f(*a, **k):
                """Execute the run method with the bound configuration and static kwargs."""
                return run_method(*a, cfg=c, **(k | static_kwargs))

            f._ejk_method = "regular"
            if validate_backward:
                f._ejk_validate_backward = True
            return f

        return mk

    def _run_heuristics(self, inv, kernel, dev, op_id, call_key) -> Cfg:
        """Compute the kernel's heuristic config, cache it, and return it."""
        platform = get_device_platform()
        heuristic_cfg_method = self._platform_method(kernel, "heuristic_cfg", inv, platform)
        cfg = (heuristic_cfg_method or kernel.heuristic_cfg)(inv)
        self._emit(
            "heuristics",
            device=dev,
            op_id=op_id,
            call_key=call_key,
            cfg=cfg,
            platform=platform,
            method=inv.method,
        )
        self.cache.put(dev, op_id, call_key, cfg)
        # NOTE(review): the persistent cache is consulted before autotune, so a
        # persisted heuristic config will pre-empt autotuning in later runs.
        # This matches existing behavior; confirm it is intended.
        if self.persistent is not None and self.persist_autotune:
            self.persistent.put(dev, op_id, call_key, cfg)
        return cfg

    def _emit(self, event: str, **data):
        """Emit a selection event for monitoring and debugging.

        Calls ``on_event(event, **data)`` when a callback is configured.

        Args:
            event: Event type (e.g. 'cache_hit', 'autotune_start', 'error').
            **data: Additional event data (device, op_id, call_key, etc.).
        """
        if self.on_event:
            self.on_event(event, **data)