Source code for ejkernel.ops.config.selection

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Configuration selection and autotuning system for kernel optimization.

This module provides a comprehensive configuration selection framework that
intelligently chooses optimal kernel configurations through a multi-tier
fallback chain. The system prioritizes cached results while supporting
automatic performance optimization when needed.

Key Components:
    ConfigSelectorChain: Main selection coordinator with fallback hierarchy
    AutotunePolicy: Configuration policy for autotuning behavior
    Tuner: Performance benchmarking and autotuning engine
    policy_override: Context manager for temporary policy changes

Selection Hierarchy (in order of priority):
    1. Override: Explicit configuration provided by caller
    2. Overlay: Temporary context-specific configuration overrides
    3. Memory Cache: Fast lookup for recently used configurations
    4. Persistent Cache: Disk-based storage across program runs
    5. Autotune: Benchmark candidates to find optimal configuration
    6. Heuristics: Kernel-provided default configuration
    7. Error: No configuration available (throws exception)

This design ensures optimal performance by:
    - Prioritizing fastest lookup methods (memory cache)
    - Preserving optimization results across runs (persistent cache)
    - Automatically finding optimal configurations (autotuning)
    - Providing sensible defaults (heuristics) as fallback

Example Usage:
    >>> cache = ConfigCache()
    >>> policy = AutotunePolicy(allow_autotune=True)
    >>> selector = ConfigSelectorChain(cache, policy)
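    >>>
    >>> # Resolve a configuration; a cache miss may trigger autotuning.
    >>> config = selector.choose(invocation, kernel)
    >>>
    >>> # Temporarily disable autotuning for the calls made inside the block.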
    >>> with policy_override(selector, allow_autotune=False):
    ...     config = selector.choose(invocation, kernel)
"""

from __future__ import annotations

import os
import pprint
import time
import traceback
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import Generic, Literal, TypeVar

import jax
import jax.numpy as jnp
import numpy as np
from jax import core as jcore
from jax import tree_util as jtu

from ejkernel.loggings import get_logger

from ..core import Invocation, Kernel, _get_platform_method
from ..utils.fingerprint import device_fingerprint, get_device_platform
from .cache import ConfigCache, _cache_overlay
from .persistent import PersistentCache

Cfg = TypeVar("Cfg")
Out = TypeVar("Out")

autotune_logger = get_logger("ejKernel-Selection")


@dataclass
class AutotunePolicy:
    """Switches that steer how ``ConfigSelectorChain`` resolves a cache miss.

    Attributes:
        allow_autotune: When False, the autotune step is never run.
        allow_heuristics: When False, the kernel's heuristic configuration
            is never used.
        cache_miss_fallback: Preferred step on a cache miss. Autotuning is
            only attempted when this is ``"autotune"``; with ``"heuristics"``
            the selector goes straight to the heuristic configuration.
        validate_backward: When True, benchmark functions built during
            autotuning are tagged so that the tuner times a gradient
            computation instead of the forward pass alone.
    """

    # Field order is part of the positional constructor signature; keep it.
    allow_autotune: bool = True
    allow_heuristics: bool = True
    cache_miss_fallback: Literal["autotune", "heuristics"] = "autotune"
    validate_backward: bool = False
class policy_override:
    """Context manager that temporarily overrides autotuning policy settings.

    On entry the named attributes of ``selector.policy`` are replaced with the
    supplied values; on exit the values that were present at entry are written
    back, whether or not the body raised.

    Typical uses:
        - Disabling autotuning for specific operations
        - Forcing use of heuristics during debugging
        - Testing different policy configurations

    Args:
        selector: ConfigSelectorChain instance whose ``policy`` is modified
        **updates: Policy attributes to override with new values

    Raises:
        AttributeError: On entry, if any name in ``updates`` is not an
            existing attribute of the policy. The policy is left untouched
            in that case.

    Example:
        >>> with policy_override(selector, allow_autotune=False):
        ...     result = executor(kernel, *args)
        >>>
        >>> with policy_override(selector, cache_miss_fallback="heuristics"):
        ...     config = selector.choose(inv, kernel)
    """

    def __init__(self, selector: ConfigSelectorChain, **updates):
        """Store the selector and the requested overrides.

        Args:
            selector: ConfigSelectorChain to modify
            **updates: Policy attributes to override
        """
        self.selector = selector
        self.updates = updates
        self._prev = {}

    def __enter__(self):
        """Apply the overrides and remember the values they replace.

        Returns:
            Self for use in with statements
        """
        policy = self.selector.policy
        # Snapshot every current value BEFORE mutating anything. If a name is
        # unknown, getattr raises here while the policy is still unmodified;
        # __exit__ is not invoked when __enter__ raises, so a partially
        # applied override could otherwise never be rolled back.
        self._prev = {name: getattr(policy, name) for name in self.updates}
        for name, value in self.updates.items():
            setattr(policy, name, value)
        return self

    def __exit__(self, *exc):
        """Restore the values captured on entry.

        Args:
            *exc: Exception information (ignored; exceptions are not suppressed)
        """
        policy = self.selector.policy
        for name, value in self._prev.items():
            setattr(policy, name, value)
class Tuner(Generic[Cfg]):
    """Performance benchmarking and autotuning for kernel configurations.

    Measures execution time of different configurations and selects the
    fastest one.

    Attributes:
        warmup: Number of untimed warmup iterations before timing
        iters: Number of timed iterations to average over
    """

    def __init__(self, warmup=1, iters=3):
        self.warmup, self.iters = warmup, iters

    @staticmethod
    def _is_arrayish(x) -> bool:
        """Return True for leaves that are passed to the benchmark as dynamic arrays."""
        return isinstance(x, (jax.Array, np.ndarray, jcore.Tracer))

    @staticmethod
    def _to_concrete(x):
        """Return an array usable as a benchmark input in place of ``x``.

        Non-tracer arrays are returned unchanged. Anything else (tracers in
        particular) is replaced by zeros of the same shape/dtype when those
        can be read from ``x`` or ``x.aval``; otherwise ``jnp.asarray(x)``.
        """
        # Tracers are also instances of jax.Array, so they have to be excluded
        # explicitly; otherwise they are returned as-is and the zeros
        # substitution below is never reached.
        if not isinstance(x, jcore.Tracer) and isinstance(x, (jax.Array, np.ndarray)):
            return x
        shape = getattr(x, "shape", None)
        dtype = getattr(x, "dtype", None)
        aval = getattr(x, "aval", None)
        if (shape is None or dtype is None) and aval is not None:
            shape = getattr(aval, "shape", None)
            dtype = getattr(aval, "dtype", None)
        if shape is not None and dtype is not None:
            return jnp.zeros(shape, dtype)
        return jnp.asarray(x)

    @staticmethod
    def _is_differentiable(x) -> bool:
        """Return True when ``x`` has an inexact (float/complex) dtype."""
        try:
            return np.issubdtype(np.dtype(getattr(x, "dtype", None)), np.inexact)
        except Exception:
            return False

    @staticmethod
    def _block_all(tree):
        """Call ``block_until_ready`` on every leaf of ``tree`` that provides it."""
        return jtu.tree_map(lambda t: t.block_until_ready() if hasattr(t, "block_until_ready") else t, tree)

    def _bench(self, run) -> float:
        """Call ``run()`` ``warmup`` times untimed, then return mean seconds over ``iters`` timed calls."""
        for _ in range(self.warmup):
            self._block_all(run())
        start = time.perf_counter()
        for _ in range(self.iters):
            self._block_all(run())
        return (time.perf_counter() - start) / self.iters

    def measure(self, fn, *args, **kwargs) -> float:
        """Measure average execution time with optional backward validation.

        ``(args, kwargs)`` is deep-flattened so that only array-like leaves
        are dynamic:

        - Arrays and JAX tracers become parameters of the timed function.
        - Every other leaf (dtypes, strings, bools, callables, ...) is
          captured as a Python constant in the closure.
        - Tracer leaves are replaced by zeros of the same shape/dtype before
          compiling and timing.
        - If ``fn._ejk_validate_backward`` is truthy, the gradient of
          ``jnp.sum(fn(...))`` with respect to the float/complex array leaves
          is timed instead of the forward pass. With no such leaves, the
          forward pass is timed.
        - If the gradient path fails with a "Cannot apply JAX transformations"
          or "Leaked trace" message, forward-only timing is used.
        - Forward timing first tries an ahead-of-time compiled function and,
          if that raises, falls back to calling ``fn`` without ``jax.jit``.

        Args:
            fn: Function to measure (possibly tagged with _ejk_validate_backward)
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Average execution time per iteration in seconds

        Raises:
            Exception: Any error from the gradient path whose message is not
                one of the two recognised above, and any error raised by the
                non-jitted forward fallback.
        """
        leaves, treedef = jtu.tree_flatten((args, kwargs))
        is_arr = [self._is_arrayish(x) for x in leaves]
        # Placeholders (None) mark the slots that array leaves are spliced back into.
        const_leaves = [None if m else x for m, x in zip(is_arr, leaves, strict=False)]
        arr0 = tuple(self._to_concrete(x) for m, x in zip(is_arr, leaves, strict=False) if m)

        def _restore_args_kwargs(array_leaves):
            it = iter(array_leaves)
            merged = [next(it) if m else v for m, v in zip(is_arr, const_leaves, strict=False)]
            return jtu.tree_unflatten(treedef, merged)

        def _forward(*arrs):
            call_args, call_kwargs = _restore_args_kwargs(arrs)
            return fn(*call_args, **call_kwargs)

        def _time_forward(jitted: bool = True) -> float:
            if jitted:
                compiled = jax.jit(_forward).lower(*arr0).compile()
                return self._bench(lambda: compiled(*arr0))
            return self._bench(lambda: _forward(*arr0))

        def _time_forward_with_fallback() -> float:
            # Best effort: some kernels cannot be jitted, so time them eagerly instead.
            try:
                return _time_forward(jitted=True)
            except Exception:
                return _time_forward(jitted=False)

        if not bool(getattr(fn, "_ejk_validate_backward", False)):
            return _time_forward_with_fallback()

        diff_mask = [self._is_differentiable(x) for x in arr0]
        if not any(diff_mask):
            return _time_forward_with_fallback()

        theta0 = tuple(v for m, v in zip(diff_mask, arr0, strict=False) if m)
        nondiff0 = tuple(v for m, v in zip(diff_mask, arr0, strict=False) if not m)

        def _merge(theta, nondiff):
            it_t, it_n = iter(theta), iter(nondiff)
            return tuple(next(it_t) if m else next(it_n) for m in diff_mask)

        def loss(theta, nondiff):
            call_args, call_kwargs = _restore_args_kwargs(_merge(theta, nondiff))
            return jnp.sum(fn(*call_args, **call_kwargs))

        try:
            compiled_grad = jax.jit(jax.grad(loss, argnums=0)).lower(theta0, nondiff0).compile()
            return self._bench(lambda: compiled_grad(theta0, nondiff0))
        except Exception as e:
            msg = str(e)
            if ("Cannot apply JAX transformations" in msg) or ("Leaked trace" in msg):
                return _time_forward_with_fallback()
            raise

    def autotune(self, make_fn, args, kwargs, candidates: Iterable[Cfg]) -> Cfg | None:
        """Benchmark all candidate configurations and return the fastest one.

        Each candidate is turned into a function with ``make_fn`` and timed
        with ``measure``. A candidate whose construction or measurement raises
        is skipped. Setting the environment variable
        ``EJKERNEL_LOG_AUTOTUNE=1`` logs every timing and the final winner.

        Args:
            make_fn: Factory function that creates a function given a config
            args: Positional arguments for the function being benchmarked
            kwargs: Keyword arguments for the function being benchmarked
            candidates: Iterable of candidate configurations to test

        Returns:
            The configuration with the lowest average execution time, or
            ``None`` when no candidate could be measured (including when
            ``candidates`` is empty). In that case the last error, if any, is
            printed and a warning is logged; callers are expected to fall
            back to another strategy.
        """
        log_enabled = os.getenv("EJKERNEL_LOG_AUTOTUNE", "0") == "1"
        best_cfg, best_t = None, float("inf")
        last_err = None
        for cfg in candidates:
            try:
                t = self.measure(make_fn(cfg), *args, **kwargs)
                if log_enabled:
                    autotune_logger.info(pprint.pformat({"config": cfg, "time": t}))
            except Exception as e:
                # A failing candidate is not fatal; remember why and try the next one.
                last_err = e
                continue
            if t < best_t:
                best_cfg, best_t = cfg, t

        if best_cfg is None:
            if last_err is not None:
                traceback.print_exception(last_err)
            autotune_logger.warning("All candidates failed during autotune; falling back to heuristics.")
            return None

        if log_enabled:
            autotune_logger.info(pprint.pformat({"best_config": best_cfg, "best_time": best_t}))
        return best_cfg
class ConfigSelectorChain(Generic[Cfg, Out]):
    """Multi-tier configuration selection system with fallback chain.

    Selection order:
    1. Override (explicit configuration provided)
    2. Overlay (temporary context-specific overrides)
    3. In-memory cache (fast lookup for recently used configs)
    4. Persistent cache (disk-based storage across runs)
    5. Autotune (benchmark candidates to find optimal config)
    6. Heuristics (kernel-provided default configuration)
    7. Error (no configuration available)

    Attributes:
        cache: In-memory configuration cache
        policy: Autotuning behavior policy
        tuner: Performance benchmarking tool
        persistent: Optional disk-based cache
        persist_autotune: Whether autotuned and heuristic configs are written
            to persistent storage
        on_event: Optional callback invoked as ``on_event(event, **data)``
            for selection events
        forbid_reautotune: Raise instead of autotuning a key that this
            selector has already autotuned successfully
    """

    def __init__(
        self,
        cache: ConfigCache[Cfg],
        policy: AutotunePolicy | None = None,
        tuner: Tuner[Cfg] | None = None,
        persistent: PersistentCache[Cfg] | None = None,
        persist_autotune: bool = True,
        on_event: Callable[..., object] | None = None,
        forbid_reautotune: bool = True,
    ):
        self.cache = cache
        self.policy = policy or AutotunePolicy()
        self.tuner = tuner or Tuner()
        self.persistent = persistent
        self.persist_autotune = persist_autotune
        self.on_event = on_event
        self.forbid_reautotune = forbid_reautotune
        # (device, op_id, call_key) triples that were autotuned successfully.
        self._autotuned_keys: set[tuple[str, str, str]] = set()

    def choose(self, inv: Invocation[Cfg, Out], kernel: Kernel[Cfg, Out]) -> Cfg:
        """Select a configuration using the fallback hierarchy.

        Selection Priority (highest to lowest):
        1. Override: Explicit configuration in invocation (also written to
           the memory cache and, if present, the persistent cache)
        2. Overlay: Temporary context-specific overrides
        3. Memory Cache: Previously stored configurations
        4. Persistent Cache: Disk-stored configurations from previous runs
        5. Autotune: Only when the policy has ``allow_autotune`` set and
           ``cache_miss_fallback == "autotune"``
        6. Heuristics: Only when the policy has ``allow_heuristics`` set;
           also used when autotuning produced no result

        Args:
            inv: Function invocation containing arguments and context
            kernel: Kernel implementation with candidate configurations

        Returns:
            Configuration chosen for this invocation

        Raises:
            RuntimeError: If no configuration can be determined, if a
                re-autotune is requested while ``forbid_reautotune`` is set,
                or if a shard_map invocation is autotuned for a kernel
                without ``create_shard_map_wrapper``.
        """
        dev = device_fingerprint()
        op_id = f"{kernel.op_id}@v{getattr(kernel, 'version', '0')}"
        call_key = inv.make_key(kernel.key_builder)

        if inv.override_cfg is not None:
            cfg = inv.override_cfg
            self._emit("override", device=dev, op_id=op_id, call_key=call_key, cfg=cfg)
            self.cache.put(dev, op_id, call_key, cfg)
            if self.persistent is not None:
                self.persistent.put(dev, op_id, call_key, cfg)
            return cfg

        if (cfg := self._lookup_cached(dev, op_id, call_key)) is not None:
            return cfg

        if self.policy.cache_miss_fallback == "autotune" and self.policy.allow_autotune:
            # A None result means every candidate failed; fall through to heuristics.
            if (cfg := self._autotune(inv, kernel, dev, op_id, call_key)) is not None:
                return cfg

        if self.policy.allow_heuristics:
            return self._heuristics(inv, kernel, dev, op_id, call_key)

        self._emit("error", device=dev, op_id=op_id, call_key=call_key, reason="no_config")
        raise RuntimeError("No config found: override/overlay/cache/persistent/autotune/heuristics all unavailable.")

    def _lookup_cached(self, dev, op_id, call_key):
        """Return a config from the overlays, memory cache or persistent cache, else None."""
        # Later overlays take precedence, hence the reversed scan.
        for overlay in reversed(_cache_overlay.get()):
            if (cfg := overlay.get((dev, op_id, call_key))) is not None:
                self._emit("overlay_hit", device=dev, op_id=op_id, call_key=call_key, cfg=cfg)
                return cfg

        if (cfg := self.cache.get(dev, op_id, call_key)) is not None:
            self._emit(
                "cache_hit",
                level="memory",
                device=dev,
                op_id=op_id,
                call_key=call_key,
                cfg=cfg,
            )
            return cfg

        if self.persistent is not None:
            if (cfg := self.persistent.get(dev, op_id, call_key)) is not None:
                self._emit(
                    "cache_hit",
                    level="persistent",
                    device=dev,
                    op_id=op_id,
                    call_key=call_key,
                    cfg=cfg,
                )
                # Promote to the memory cache so the next lookup is served from there.
                self.cache.put(dev, op_id, call_key, cfg)
                return cfg

        return None

    def _autotune(self, inv, kernel, dev, op_id, call_key):
        """Benchmark the kernel's candidates; cache and return the winner, or None if all failed."""
        key = (dev, op_id, call_key)
        if self.forbid_reautotune and key in self._autotuned_keys:
            raise RuntimeError(f"Re-autotune requested for {key}")

        platform = get_device_platform()
        context = "shard_map" if inv.method == "shard_map" else None

        candidate_cfgs_method = _get_platform_method(kernel, "candidate_cfgs", platform, context)
        if candidate_cfgs_method:
            candidates = tuple(candidate_cfgs_method(inv))
        else:
            candidates = tuple(kernel.candidate_cfgs(inv))

        self._emit(
            "autotune_start",
            device=dev,
            op_id=op_id,
            call_key=call_key,
            candidates=len(candidates),
            platform=platform,
            method=inv.method,
        )

        make_fn = self._benchmark_factory(inv, kernel, platform, context)
        best = self.tuner.autotune(make_fn, inv.args, dict(inv.kwargs), candidates)
        if best is None:
            return None

        self._autotuned_keys.add(key)
        self.cache.put(dev, op_id, call_key, best)
        if self.persistent is not None and self.persist_autotune:
            self.persistent.put(dev, op_id, call_key, best)
        self._emit(
            "autotune_finish",
            device=dev,
            op_id=op_id,
            call_key=call_key,
            cfg=best,
            platform=platform,
            method=inv.method,
        )
        return best

    def _benchmark_factory(self, inv, kernel, platform, context):
        """Return ``mk(cfg) -> callable`` that builds the function the tuner times for ``cfg``."""
        # Callable keyword arguments are re-applied on every call of the
        # benchmarked function, on top of whatever keywords it is called with.
        static_fun_kwargs = {k: v for k, v in inv.kwargs.items() if callable(v)}

        if inv.method == "shard_map":
            if not hasattr(kernel, "create_shard_map_wrapper"):
                raise RuntimeError(
                    f"Kernel {kernel.op_id} does not implement create_shard_map_wrapper for shard_map benchmarking"
                )

            def mk(c, _static=static_fun_kwargs):
                def f(*a, **k):
                    eagers = kernel.create_shard_map_wrapper(
                        *a,
                        cfg=c,
                        mesh=inv.mesh,
                        in_specs=inv.in_specs,
                        out_specs=inv.out_specs,
                        check_vma=inv.check_vma,
                        **(k | _static),
                    )
                    callback = None
                    if len(eagers) == 2:
                        shard_map_fn, call_args = eagers
                    elif len(eagers) == 3:
                        shard_map_fn, call_args, callback = eagers
                    else:
                        # Fail with a clear message instead of an UnboundLocalError below.
                        raise RuntimeError(
                            f"Kernel {kernel.op_id} create_shard_map_wrapper returned "
                            f"{len(eagers)} values; expected 2 or 3"
                        )
                    outs = shard_map_fn(*call_args)
                    if callback is not None:
                        outs = callback(outs, cfg=c)
                    return outs

                f._ejk_method = "shard_map"
                if self.policy.validate_backward and getattr(kernel, "supports_grad_validation", False):
                    f._ejk_validate_backward = True
                return f

            return mk

        run_method = _get_platform_method(kernel, "run", platform, context) or kernel.run

        def mk(c, _run=run_method, _static=static_fun_kwargs):
            def f(*a, **k):
                return _run(*a, cfg=c, **(k | _static))

            f._ejk_method = "regular"
            if self.policy.validate_backward:
                f._ejk_validate_backward = True
            return f

        return mk

    def _heuristics(self, inv, kernel, dev, op_id, call_key):
        """Obtain the kernel's heuristic config, cache it and return it."""
        platform = get_device_platform()
        context = "shard_map" if inv.method == "shard_map" else None

        heuristic_cfg_method = _get_platform_method(kernel, "heuristic_cfg", platform, context)
        if heuristic_cfg_method:
            cfg = heuristic_cfg_method(inv)
        else:
            cfg = kernel.heuristic_cfg(inv)

        self._emit(
            "heuristics",
            device=dev,
            op_id=op_id,
            call_key=call_key,
            cfg=cfg,
            platform=platform,
            method=inv.method,
        )
        self.cache.put(dev, op_id, call_key, cfg)
        if self.persistent is not None and self.persist_autotune:
            self.persistent.put(dev, op_id, call_key, cfg)
        return cfg

    def _emit(self, event: str, **data):
        """Forward a selection event to ``on_event`` when a callback is configured.

        Args:
            event: Event type (e.g., 'cache_hit', 'autotune_start', 'error')
            **data: Additional event data (device, op_id, call_key, etc.)
        """
        if self.on_event:
            self.on_event(event, **data)