ejkernel.ops.config.selection
Configuration selection and autotuning system for kernel optimization.
This module selects a kernel configuration for each invocation by walking a multi-tier fallback chain. It prefers cached results and runs autotuning only when no cached configuration is available and the policy permits it.
- Key Components:
ConfigSelectorChain: Main selection coordinator with fallback hierarchy
AutotunePolicy: Configuration policy for autotuning behavior
Tuner: Performance benchmarking and autotuning engine
policy_override: Context manager for temporary policy changes
- Selection Hierarchy (in order of priority):
1. Override: Explicit configuration provided by the caller
2. Overlay: Temporary context-specific configuration overrides
3. Memory Cache: Fast lookup for recently used configurations
4. Persistent Cache: Disk-based storage across program runs
5. Autotune: Benchmark candidates to find the fastest configuration
6. Heuristics: Kernel-provided default configuration
7. Error: No configuration available (raises RuntimeError)
- This design favors performance by:
Prioritizing fastest lookup methods (memory cache)
Preserving optimization results across runs (persistent cache)
Automatically finding optimal configurations (autotuning)
Providing sensible defaults (heuristics) as fallback
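The two cache tiers can be illustrated with a small stand-alone sketch: a dict answers lookups within one process, and a JSON file keeps results across runs. TwoTierCache, its write-through behavior, promotion of disk hits into memory, and the "op|shape|dtype" key format are all hypothetical illustrations, not ejkernel's ConfigCache or PersistentCache.

```python
import json
import os
import tempfile
from typing import Any, Dict, Optional


class TwoTierCache:
    """Hypothetical two-tier cache: an in-memory dict in front of a JSON file.

    Illustration only; this is not ejkernel's ConfigCache or PersistentCache.
    """

    def __init__(self, path: str) -> None:
        self.path = path
        self.memory: Dict[str, Any] = {}  # fast tier, lives only for this process

    def _read_disk(self) -> Dict[str, Any]:
        # Slow tier: a JSON file that survives program restarts.
        if not os.path.exists(self.path):
            return {}
        with open(self.path, "r", encoding="utf-8") as f:
            return json.load(f)

    def get(self, key: str) -> Optional[Any]:
        if key in self.memory:  # cheapest lookup first
            return self.memory[key]
        on_disk = self._read_disk()
        if key in on_disk:
            self.memory[key] = on_disk[key]  # promote so the next lookup skips the disk
            return on_disk[key]
        return None  # miss in both tiers: a caller would autotune or use heuristics

    def put(self, key: str, cfg: Any) -> None:
        # Write-through: update memory and persist to disk.
        self.memory[key] = cfg
        on_disk = self._read_disk()
        on_disk[key] = cfg
        with open(self.path, "w", encoding="utf-8") as f:
            json.dump(on_disk, f)


# A second instance on the same file simulates a fresh program run.
path = os.path.join(tempfile.mkdtemp(), "configs.json")
TwoTierCache(path).put("matmul|1024x1024|float16", {"block_m": 128})
fresh = TwoTierCache(path)
restored = fresh.get("matmul|1024x1024|float16")
```

The fresh instance starts with an empty memory tier, finds the entry on disk, and promotes it, so later lookups in the same run never touch the file.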
- Example Usage:
>>> cache = ConfigCache()
>>> policy = AutotunePolicy(allow_autotune=True)
>>> selector = ConfigSelectorChain(cache, policy)
>>>
>>> # Select a configuration through the fallback chain
>>> config = selector.choose(invocation, kernel)
>>>
>>> # Temporarily disable autotuning
>>> with policy_override(selector, allow_autotune=False):
...     config = selector.choose(invocation, kernel)
- class ejkernel.ops.config.selection.AutotunePolicy(allow_autotune: bool = True, allow_heuristics: bool = True, cache_miss_fallback: Literal['autotune', 'heuristics'] = 'autotune', validate_backward: bool = False)
Bases: object
Configuration policy for autotuning behavior.
- allow_autotune
Whether autotuning is permitted
- Type
bool
- allow_heuristics
Whether heuristic configurations are allowed
- Type
bool
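- cache_miss_fallback
Strategy when no cached config is found ("heuristics" or "autotune")
- Type
Literal['autotune', 'heuristics']
- validate_backward
Whether autotuning also validates the backward pass of each candidate (see Tuner.measure)
- Type
bool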
- allow_autotune: bool = True
- allow_heuristics: bool = True
- cache_miss_fallback: Literal['autotune', 'heuristics'] = 'autotune'
- validate_backward: bool = False
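The policy is a plain record of four flags. The sketch below mirrors the documented field names and defaults in a mutable dataclass (mutable because policy_override changes attributes in place) and adds a hypothetical helper, cache_miss_tiers, showing one plausible reading of how cache_miss_fallback could order the remaining tiers after both caches miss. The helper and the "try the other tier second" rule are assumptions, not ejkernel behavior.

```python
from dataclasses import dataclass
from typing import List, Literal


@dataclass
class AutotunePolicySketch:
    # Field names and defaults mirror the documented AutotunePolicy signature.
    allow_autotune: bool = True
    allow_heuristics: bool = True
    cache_miss_fallback: Literal["autotune", "heuristics"] = "autotune"
    validate_backward: bool = False


def cache_miss_tiers(policy: AutotunePolicySketch) -> List[str]:
    """Hypothetical helper: tiers to try, in order, once both caches miss.

    Assumption: the preferred tier comes first, the other one second, and
    any tier disabled by its allow_* flag is skipped.
    """
    preferred = policy.cache_miss_fallback
    other = "heuristics" if preferred == "autotune" else "autotune"
    allowed = {"autotune": policy.allow_autotune, "heuristics": policy.allow_heuristics}
    return [tier for tier in (preferred, other) if allowed[tier]]
```

With the defaults this yields ["autotune", "heuristics"], matching the module-level priority order.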
- class ejkernel.ops.config.selection.ConfigSelectorChain(cache: ConfigCache[Cfg], policy: AutotunePolicy | None = None, tuner: Tuner[Cfg] | None = None, persistent: PersistentCache[Cfg] | None = None, persist_autotune: bool = True, on_event: callable | None = None, forbid_reautotune: bool = True)
Bases: Generic[Cfg, Out]
Multi-tier configuration selection system with fallback chain.
Selection order:
1. Override (explicit configuration provided)
2. Overlay (temporary context-specific overrides)
3. In-memory cache (fast lookup for recently used configs)
4. Persistent cache (disk-based storage across runs)
5. Autotune (benchmark candidates to find optimal config)
6. Heuristics (kernel-provided default configuration)
7. Error (no configuration available)
- cache
In-memory configuration cache
- policy
Autotuning behavior policy
- tuner
Performance benchmarking tool
- persistent
Optional disk-based cache
- persist_autotune
Whether to save autotuned configs to persistent storage
- on_event
Optional callback for selection events
- forbid_reautotune
Prevent re-autotuning the same operation
- choose(inv: Invocation[Cfg, Out], kernel: Kernel[Cfg, Out]) → Cfg
Select optimal configuration using the fallback hierarchy.
Implements the complete configuration selection algorithm, trying each method in order until a suitable configuration is found.
Selection Priority (highest to lowest):
1. Override: Explicit configuration in invocation
2. Overlay: Temporary context-specific overrides
3. Memory Cache: Previously computed optimal configurations
4. Persistent Cache: Disk-stored configurations from previous runs
5. Autotune: Benchmark candidates to find optimal configuration
6. Heuristics: Kernel-provided default configuration
- Parameters
inv – Function invocation containing arguments and context
kernel – Kernel implementation with candidate configurations
- Returns
Optimal configuration for this invocation
- Raises
RuntimeError – If no configuration can be determined
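The selection loop amounts to "return the first tier that produces a configuration, otherwise raise". The sketch below models each tier as a zero-argument callable so that expensive tiers such as autotuning run only if every earlier tier missed. choose_sketch, the dict-based caches, and an on_event callback that receives a tier name are hypothetical; the real on_event signature is not documented here.

```python
from typing import Callable, Iterable, Optional, Tuple, TypeVar

Cfg = TypeVar("Cfg")
Tier = Tuple[str, Callable[[], Optional[Cfg]]]


def choose_sketch(tiers: Iterable[Tier], on_event: Optional[Callable[[str], None]] = None) -> Cfg:
    """Hypothetical helper: return the first non-None config from ordered tiers."""
    for name, lookup in tiers:
        cfg = lookup()  # evaluated lazily, so later tiers never run after a hit
        if cfg is not None:
            if on_event is not None:
                on_event(name)  # report which tier produced the config
            return cfg
    raise RuntimeError("No configuration could be determined")


memory = {}
persistent = {"attn|seq128": {"block": 64}}
events = []
key = "attn|seq128"
cfg = choose_sketch(
    [
        ("override", lambda: None),  # caller passed no explicit config
        ("overlay", lambda: None),  # no overlay active
        ("memory", lambda: memory.get(key)),
        ("persistent", lambda: persistent.get(key)),
        ("autotune", lambda: None),  # e.g. disabled by the policy
        ("heuristics", lambda: {"block": 32}),
    ],
    on_event=events.append,
)

try:
    choose_sketch([("override", lambda: None)])
    empty_chain_raised = False
except RuntimeError:
    empty_chain_raised = True
```

Here the persistent tier answers first, so the heuristics default is never consulted; a chain in which every tier misses raises RuntimeError, as documented for choose.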
- class ejkernel.ops.config.selection.Tuner(warmup=1, iters=3)
Bases: Generic[Cfg]
Performance benchmarking and autotuning for kernel configurations.
Measures execution time of different configurations and selects the fastest one.
- warmup
Number of warmup iterations before timing
- iters
Number of timing iterations to average over
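The warmup/iters scheme can be sketched as: call the function warmup times without timing (absorbing one-off costs such as compilation), then time iters calls and divide by iters. time_average is a hypothetical helper, and the block_until_ready check is an assumption about JAX outputs: JAX dispatches asynchronously, so the clock should only stop once results are ready.

```python
import time
from typing import Any, Callable


def _wait(result: Any) -> None:
    # Assumption: JAX arrays expose block_until_ready(); plain Python values do not,
    # so for them there is nothing to wait for.
    block = getattr(result, "block_until_ready", None)
    if callable(block):
        block()


def time_average(fn: Callable[..., Any], *args: Any, warmup: int = 1, iters: int = 3, **kwargs: Any) -> float:
    """Hypothetical helper: mean seconds per call, measured after `warmup` untimed calls."""
    if iters < 1:
        raise ValueError("iters must be >= 1")
    for _ in range(warmup):
        _wait(fn(*args, **kwargs))  # untimed: absorbs compilation and cache warmup
    start = time.perf_counter()
    for _ in range(iters):
        _wait(fn(*args, **kwargs))
    return (time.perf_counter() - start) / iters


calls = []


def kernel(x: int) -> int:
    calls.append(x)
    return 2 * x


avg_seconds = time_average(kernel, 3, warmup=1, iters=3)
```

With warmup=1 and iters=3 the kernel runs four times in total, but only the last three calls are timed.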
- autotune(make_fn, args, kwargs, candidates: Iterable[Cfg]) → Cfg
Benchmark all candidate configurations and return the fastest one.
Tests each candidate configuration by measuring its execution time and selects the configuration with the lowest average execution time.
- Parameters
make_fn – Factory function that creates a function given a config
args – Positional arguments for the function being benchmarked
kwargs – Keyword arguments for the function being benchmarked
candidates – Iterable of candidate configurations to test
- Returns
The configuration that achieved the fastest execution time
- Raises
RuntimeError – If no candidates are provided for testing
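The autotune loop can be sketched as "build, measure, keep the minimum". To keep the example deterministic, autotune_sketch takes the timing function as a parameter and the usage passes a stand-in measure whose "kernels" simply return a simulated cost. autotune_sketch, the injected measure parameter, and the block-size candidates are hypothetical, not the Tuner.autotune signature.

```python
from typing import Any, Callable, Dict, Iterable, Tuple, TypeVar

Cfg = TypeVar("Cfg")


def autotune_sketch(
    make_fn: Callable[[Cfg], Callable[..., Any]],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    candidates: Iterable[Cfg],
    measure: Callable[..., float],
) -> Cfg:
    """Hypothetical helper: time every candidate with `measure` and keep the fastest."""
    best_cfg, best_time, found = None, float("inf"), False
    for cfg in candidates:
        fn = make_fn(cfg)  # build the kernel specialized for this config
        elapsed = measure(fn, *args, **kwargs)  # average seconds per call
        if not found or elapsed < best_time:
            best_cfg, best_time, found = cfg, elapsed, True
    if not found:
        raise RuntimeError("No candidates provided for autotuning")
    return best_cfg


simulated_cost = {16: 3.0, 32: 1.0, 64: 2.0}  # pretend seconds per call, per block size
best = autotune_sketch(
    make_fn=lambda block: (lambda scale: simulated_cost[block] * scale),
    args=(1.0,),
    kwargs={},
    candidates=[16, 32, 64],
    measure=lambda fn, *a, **k: fn(*a, **k),  # stand-in: the "kernel" returns its own cost
)

try:
    autotune_sketch(lambda c: (lambda: None), (), {}, [], lambda fn, *a, **k: 0.0)
    empty_raised = False
except RuntimeError:
    empty_raised = True
```

In real use the measure argument would be a timing routine such as Tuner.measure; injecting it keeps the selection logic separate from benchmarking.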
- measure(fn, *args, **kwargs) → float
Measure average execution time with optional backward validation.
Deep-flatten (args, kwargs) so only array-like leaves are dynamic:
- Arrays or JAX tracers become dynamic parameters to the jitted function.
- Everything else (dtype, strings, bools, callables, nested containers) is captured as Python constants in the closure.
Tracer-like arrays (e.g., ShardMapTracer, DynamicJaxprTracer) are converted to concrete zeros of the same shape/dtype before compile and timing.
If fn is tagged with _ejk_validate_backward=True, a scalar loss is differentiated with respect to float/complex array leaves only; all other leaves are treated as non-differentiable.
If a kernel uses precompiled functions that can’t be transformed, we fall back to forward-only timing, and if needed, to non-jitted forward timing.
- Parameters
fn – Function to measure (possibly tagged with _ejk_validate_backward)
*args – Positional arguments
**kwargs – Keyword arguments
- Returns
Average execution time per iteration in seconds
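The dynamic/static split that measure performs before jitting can be illustrated without JAX: walk nested lists, tuples, and dicts, pull out array-like leaves, and return a rebuild function that keeps every other leaf as a closure constant. split_dynamic, _Slot, is_array_like (anything with .shape and .dtype), and FakeArray are hypothetical; a real implementation would more likely build on jax.tree_util, and this version treats other containers (including namedtuples) as constants.

```python
from typing import Any, Callable, List, Tuple


class _Slot:
    """Marks the position of an array leaf that was pulled out of the tree."""

    def __init__(self, index: int) -> None:
        self.index = index


def is_array_like(x: Any) -> bool:
    # Assumption: anything with .shape and .dtype (NumPy/JAX arrays, tracers) is an array.
    return hasattr(x, "shape") and hasattr(x, "dtype")


def split_dynamic(tree: Any) -> Tuple[List[Any], Callable[[List[Any]], Any]]:
    """Hypothetical helper: separate array leaves (dynamic) from everything else (static)."""
    leaves: List[Any] = []

    def strip(node: Any) -> Any:
        if is_array_like(node):
            leaves.append(node)
            return _Slot(len(leaves) - 1)
        if type(node) in (list, tuple):
            return type(node)(strip(n) for n in node)
        if type(node) is dict:
            return {k: strip(v) for k, v in node.items()}
        return node  # dtype, str, bool, callable, ...: kept as a closure constant

    skeleton = strip(tree)

    def rebuild(new_leaves: List[Any]) -> Any:
        # Put new dynamic values (e.g. traced arrays) back where the originals were.
        def fill(node: Any) -> Any:
            if isinstance(node, _Slot):
                return new_leaves[node.index]
            if type(node) in (list, tuple):
                return type(node)(fill(n) for n in node)
            if type(node) is dict:
                return {k: fill(v) for k, v in node.items()}
            return node

        return fill(skeleton)

    return leaves, rebuild


class FakeArray:
    """Tiny stand-in for an array so the example runs without NumPy or JAX."""

    def __init__(self, shape: Tuple[int, ...], dtype: str = "float32") -> None:
        self.shape, self.dtype = shape, dtype


q = FakeArray((2, 4))
bias = FakeArray((4,))
leaves, rebuild = split_dynamic((q, "causal", {"scale": 0.5, "bias": bias}))
rebuilt = rebuild(["Q", "B"])
```

Only q and bias become dynamic; the string "causal" and the float 0.5 stay fixed inside rebuild, which is the property that lets a jitted wrapper trace over arrays alone.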
- class ejkernel.ops.config.selection.policy_override(selector: ConfigSelectorChain, **updates)
Bases: object
Context manager for temporarily overriding autotuning policy settings.
Allows temporary modification of AutotunePolicy attributes within a context, automatically restoring the original values when exiting the context.
This is useful for:
- Disabling autotuning for specific operations
- Forcing use of heuristics during debugging
- Testing different policy configurations
- Parameters
selector – ConfigSelectorChain instance to modify
**updates – Policy attributes to override with new values
Example
>>> # Disable autotuning for one call
>>> with policy_override(selector, allow_autotune=False):
...     result = executor(kernel, *args)
>>>
>>> # Fall back to heuristics on a cache miss
>>> with policy_override(selector, cache_miss_fallback="heuristics"):
...     config = selector.choose(inv, kernel)
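The save-and-restore behavior can be sketched with contextlib.contextmanager: record the current values of the named policy attributes, apply the updates, and restore the saved values in a finally clause so they come back even if the body raises. policy_override_sketch is hypothetical (the documented policy_override is a class), as is its rejection of unknown attribute names; SimpleNamespace stands in for a real selector and policy.

```python
from contextlib import contextmanager
from types import SimpleNamespace
from typing import Any, Iterator


@contextmanager
def policy_override_sketch(selector: Any, **updates: Any) -> Iterator[None]:
    """Hypothetical stand-in: temporarily set attributes on selector.policy."""
    policy = selector.policy
    for name in updates:
        # Assumption: unknown field names are rejected rather than silently added.
        if not hasattr(policy, name):
            raise AttributeError(f"unknown policy field: {name}")
    saved = {name: getattr(policy, name) for name in updates}
    try:
        for name, value in updates.items():
            setattr(policy, name, value)
        yield
    finally:
        for name, value in saved.items():  # restore even if the body raised
            setattr(policy, name, value)


selector = SimpleNamespace(policy=SimpleNamespace(allow_autotune=True, cache_miss_fallback="autotune"))

with policy_override_sketch(selector, allow_autotune=False):
    inside = selector.policy.allow_autotune
after = selector.policy.allow_autotune

try:
    with policy_override_sketch(selector, cache_miss_fallback="heuristics"):
        raise RuntimeError("simulated failure inside the block")
except RuntimeError:
    pass
restored_fallback = selector.policy.cache_miss_fallback
```

Because the policy object is shared, the override is visible to every choose call made inside the block, and the finally clause guarantees it does not leak past the block.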