ejkernel.ops.config.selection

Configuration selection and autotuning system for kernel optimization.

This module provides a configuration selection framework that picks a kernel configuration through a multi-tier fallback chain. Cached results take priority; autotuning and heuristics are consulted only when no override, overlay, or cached configuration is available.

Key Components:

  • ConfigSelectorChain: Main selection coordinator with fallback hierarchy

  • AutotunePolicy: Configuration policy for autotuning behavior

  • Tuner: Performance benchmarking and autotuning engine

  • policy_override: Context manager for temporary policy changes

Selection Hierarchy (in order of priority):
  1. Override: Explicit configuration provided by caller

  2. Overlay: Temporary context-specific configuration overrides

  3. Memory Cache: Fast lookup for recently used configurations

  4. Persistent Cache: Disk-based storage across program runs

  5. Autotune: Benchmark candidates to find optimal configuration

  6. Heuristics: Kernel-provided default configuration

  7. Error: No configuration available (raises RuntimeError)

This design ensures optimal performance by:
  • Prioritizing fastest lookup methods (memory cache)

  • Preserving optimization results across runs (persistent cache)

  • Automatically finding optimal configurations (autotuning)

  • Providing sensible defaults (heuristics) as fallback

Example Usage:
>>> cache = ConfigCache()
>>> policy = AutotunePolicy(allow_autotune=True)
>>> selector = ConfigSelectorChain(cache, policy)
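>>> # Select a configuration through the fallback chain
>>> config = selector.choose(invocation, kernel)
>>>
>>> # Temporarily disable autotuning for a single selection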
>>> with policy_override(selector, allow_autotune=False):
...     config = selector.choose(invocation, kernel)

class ejkernel.ops.config.selection.AutotunePolicy(allow_autotune: bool = True, allow_heuristics: bool = True, cache_miss_fallback: Literal['autotune', 'heuristics'] = 'autotune', validate_backward: bool = False)

Bases: object

Configuration policy for autotuning behavior.

Controls how the configuration selection system behaves when making optimization decisions, including whether to run autotuning, use heuristics, and validate backward pass correctness.

allow_autotune

Whether autotuning is permitted. When True, the system can benchmark multiple configurations to find the optimal one.

Type

bool

allow_heuristics

Whether heuristic configurations are allowed as a fallback when no cached configuration is available.

Type

bool

cache_miss_fallback

Strategy when no cached config is found. Either 'autotune' to benchmark candidates or 'heuristics' to use defaults.

Type

Literal['autotune', 'heuristics']

validate_backward

Whether to validate backward pass during autotuning. When True, autotuning will measure gradient computation time in addition to forward pass, ensuring the selected configuration performs well for training workloads.

Type

bool

allow_autotune: bool = True
allow_heuristics: bool = True
cache_miss_fallback: Literal['autotune', 'heuristics'] = 'autotune'
validate_backward: bool = False
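
The interaction between these four fields can be pictured with a small stand-in. The sketch below is not the library's code: `PolicySketch` and `strategies_on_cache_miss` are hypothetical names, and the rule that 'heuristics' as the fallback skips autotuning is an assumed reading of the field descriptions above.

```python
from dataclasses import dataclass
from typing import List, Literal


@dataclass
class PolicySketch:
    """Stand-in with the same four fields and defaults as AutotunePolicy."""

    allow_autotune: bool = True
    allow_heuristics: bool = True
    cache_miss_fallback: Literal["autotune", "heuristics"] = "autotune"
    validate_backward: bool = False


def strategies_on_cache_miss(policy: PolicySketch) -> List[str]:
    """List the strategies to try, in order, once both caches have missed.

    Assumed interpretation: autotuning is attempted only when it is allowed
    and is the preferred fallback; heuristics stay as a last resort if allowed.
    """
    order: List[str] = []
    if policy.allow_autotune and policy.cache_miss_fallback == "autotune":
        order.append("autotune")
    if policy.allow_heuristics:
        order.append("heuristics")
    return order


default_order = strategies_on_cache_miss(PolicySketch())
heuristics_first = strategies_on_cache_miss(
    PolicySketch(cache_miss_fallback="heuristics")
)
```

Under this reading, an empty list corresponds to the final "Error" tier of the selection hierarchy.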
class ejkernel.ops.config.selection.ConfigSelectorChain(cache: ConfigCache[Cfg], policy: AutotunePolicy | None = None, tuner: Tuner[Cfg] | None = None, persistent: PersistentCache[Cfg] | None = None, persist_autotune: bool = True, on_event: callable | None = None, forbid_reautotune: bool = True)

Bases: Generic[Cfg, Out]

Multi-tier configuration selection system with fallback chain.

Selection order:
  1. Override (explicit configuration provided)

  2. Overlay (temporary context-specific overrides)

  3. In-memory cache (fast lookup for recently used configs)

  4. Persistent cache (disk-based storage across runs)

  5. Autotune (benchmark candidates to find optimal config)

  6. Heuristics (kernel-provided default configuration)

  7. Error (no configuration available)

cache

In-memory configuration cache

policy

Autotuning behavior policy

tuner

Performance benchmarking tool

persistent

Optional disk-based cache

persist_autotune

Whether to save autotuned configs to persistent storage

on_event

Optional callback for selection events

forbid_reautotune

Prevent re-autotuning the same operation

choose(inv: Invocation[Cfg, Out], kernel: Kernel[Cfg, Out]) Cfg

Select optimal configuration using the fallback hierarchy.

Implements the complete configuration selection algorithm, trying each method in order until a suitable configuration is found.

Selection Priority (highest to lowest):
  1. Override: Explicit configuration in invocation

  2. Overlay: Temporary context-specific overrides

  3. Memory Cache: Previously computed optimal configurations

  4. Persistent Cache: Disk-stored configurations from previous runs

  5. Autotune: Benchmark candidates to find optimal configuration

  6. Heuristics: Kernel-provided default configuration

Parameters
  • inv – Function invocation containing arguments and context

  • kernel – Kernel implementation with candidate configurations

Returns

Optimal configuration for this invocation

Raises

RuntimeError – If no configuration can be determined
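
The "try each method in order" behavior described above can be sketched as a loop over tiers. This is a minimal illustration, not the implementation of `choose`: `choose_config`, the tier tuples, the string keys, and the promotion of persistent hits into the memory cache are all assumptions made for the example.

```python
from typing import Callable, Dict, Optional, Sequence, Tuple

# Each tier is a (name, lookup) pair; lookup returns a config or None on a miss.
Tier = Tuple[str, Callable[[str], Optional[dict]]]


def choose_config(key: str, tiers: Sequence[Tier]) -> Tuple[str, dict]:
    """Return (tier_name, config) from the first tier that yields a config."""
    for name, lookup in tiers:
        cfg = lookup(key)
        if cfg is not None:
            return name, cfg
    # Mirrors the documented last step: no tier produced a configuration.
    raise RuntimeError(f"no configuration available for {key!r}")


# Hypothetical stand-ins for the two caches and the kernel heuristics.
memory_cache: Dict[str, dict] = {}
persistent_cache: Dict[str, dict] = {"matmul:128x128": {"block": 64}}


def from_persistent(key: str) -> Optional[dict]:
    cfg = persistent_cache.get(key)
    if cfg is not None:
        memory_cache[key] = cfg  # promote to the faster tier (an assumption)
    return cfg


def from_heuristics(key: str) -> Optional[dict]:
    return {"block": 32} if key.startswith("matmul") else None


tiers = [
    ("memory", memory_cache.get),
    ("persistent", from_persistent),
    ("heuristics", from_heuristics),
]

first = choose_config("matmul:128x128", tiers)   # served from disk
second = choose_config("matmul:128x128", tiers)  # now served from memory
third = choose_config("matmul:64x64", tiers)     # falls through to heuristics
```

Keeping each tier behind the same "config or None" interface is what lets the chain stay a plain loop; the override, overlay, and autotune tiers would slot into the same list.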

class ejkernel.ops.config.selection.Tuner(warmup=1, iters=3)

Bases: Generic[Cfg]

Performance benchmarking and autotuning for kernel configurations.

Measures execution time of different configurations and selects the fastest one.

warmup

Number of warmup iterations before timing

iters

Number of timing iterations to average over

autotune(make_fn, args, kwargs, candidates: Iterable[Cfg]) Cfg

Benchmark all candidate configurations and return the fastest one.

Tests each candidate configuration by measuring its execution time and selects the configuration with the lowest average execution time.

Parameters
  • make_fn – Factory function that creates a function given a config

  • args – Positional arguments for the function being benchmarked

  • kwargs – Keyword arguments for the function being benchmarked

  • candidates – Iterable of candidate configurations to test

Returns

The configuration that achieved the fastest execution time

Raises

RuntimeError – If no candidates are provided for testing
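
The benchmark-and-select loop described above might look roughly like the following. It is a simplified sketch using wall-clock timing from the standard library; `time_call` and `pick_fastest` are hypothetical names, and the real Tuner additionally handles JIT compilation and optional backward timing.

```python
import time
from typing import Callable, Iterable, TypeVar

Cfg = TypeVar("Cfg")


def time_call(fn: Callable, args: tuple, kwargs: dict,
              warmup: int = 1, iters: int = 3) -> float:
    """Average wall-clock seconds per call after `warmup` untimed calls."""
    for _ in range(warmup):
        fn(*args, **kwargs)
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args, **kwargs)
    return (time.perf_counter() - start) / iters


def pick_fastest(make_fn: Callable[[Cfg], Callable], args: tuple, kwargs: dict,
                 candidates: Iterable[Cfg]) -> Cfg:
    """Time every candidate and return the one with the lowest average."""
    best_cfg, best_time, seen_any = None, float("inf"), False
    for cfg in candidates:
        seen_any = True
        elapsed = time_call(make_fn(cfg), args, kwargs)
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    if not seen_any:
        raise RuntimeError("no candidates provided for autotuning")
    return best_cfg


def make_sleeper(delay: float) -> Callable[[], None]:
    # Toy "kernel": the config is simply how long the call takes.
    return lambda: time.sleep(delay)


best = pick_fastest(make_sleeper, (), {}, [0.02, 0.0, 0.01])
```

With JAX functions, dispatch is asynchronous, so a timing loop like this would also need to block on the result (for example with `block_until_ready`) before reading the clock.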

measure(fn, *args, **kwargs) float

Measure average execution time with optional backward validation.

Deep-flattens (args, kwargs) so that only array-like leaves are dynamic:
  • Arrays or JAX tracers become dynamic parameters to the jitted function.

  • Everything else (dtype, strings, bools, callables, nested containers) is captured as Python constants in the closure.

  • Tracer-like arrays (e.g., ShardMapTracer, DynamicJaxprTracer) are converted to concrete zeros of the same shape/dtype before compilation and timing.

  • If _ejk_validate_backward=True, a scalar loss is differentiated w.r.t. float/complex array leaves only; all other leaves are treated as non-differentiable.

  • If a kernel uses precompiled functions that cannot be transformed, timing falls back to forward-only, and if needed to non-jitted forward timing.

Parameters
  • fn – Function to measure (possibly tagged with _ejk_validate_backward)

  • *args – Positional arguments

  • **kwargs – Keyword arguments

Returns

Average execution time per iteration in seconds
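
The dynamic/static split described above can be illustrated without JAX. In this sketch, array-like leaves are pulled out into a flat list and everything else stays in a static skeleton. `Slot`, `split_leaves`, `fill`, and `FakeArray` are hypothetical names, and the duck-typed "has shape and dtype" check is an assumption standing in for the library's own array detection.

```python
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass(frozen=True)
class Slot:
    """Placeholder marking where dynamic leaf number `index` belongs."""

    index: int


def is_array_like(x: Any) -> bool:
    # Duck-typed check (an assumption): anything exposing shape and dtype.
    return hasattr(x, "shape") and hasattr(x, "dtype")


def split_leaves(tree: Any) -> Tuple[List[Any], Any]:
    """Return (dynamic_leaves, skeleton) for a nested dict/list/tuple tree."""
    dynamic: List[Any] = []

    def walk(node: Any) -> Any:
        if is_array_like(node):
            dynamic.append(node)
            return Slot(len(dynamic) - 1)
        if isinstance(node, dict):
            return {k: walk(v) for k, v in node.items()}
        if isinstance(node, list):
            return [walk(v) for v in node]
        if isinstance(node, tuple):
            return tuple(walk(v) for v in node)
        return node  # static value, kept as a Python constant

    skeleton = walk(tree)
    return dynamic, skeleton


def fill(skeleton: Any, leaves: List[Any]) -> Any:
    """Rebuild the original tree by substituting leaves back into the slots."""
    if isinstance(skeleton, Slot):
        return leaves[skeleton.index]
    if isinstance(skeleton, dict):
        return {k: fill(v, leaves) for k, v in skeleton.items()}
    if isinstance(skeleton, list):
        return [fill(v, leaves) for v in skeleton]
    if isinstance(skeleton, tuple):
        return tuple(fill(v, leaves) for v in skeleton)
    return skeleton


@dataclass
class FakeArray:
    """Tiny stand-in for an array: only shape and dtype matter here."""

    shape: tuple
    dtype: str


x = FakeArray((2, 3), "float32")
mask = FakeArray((2,), "bool")
args, kwargs = (x, 0.5), {"mask": mask, "mode": "causal", "dims": [1, 2]}

leaves, skeleton = split_leaves((args, kwargs))
rebuilt_args, rebuilt_kwargs = fill(skeleton, leaves)
```

A timing wrapper built this way would take only `leaves` as its parameters and call `fill` inside the function body, so that strings, flags, and callables never reach the JIT boundary as traced values.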

class ejkernel.ops.config.selection.forward_autotune_only

Bases: object

Context manager that disables backward validation during autotuning.

While active, autotune measurements run forward-only even when AutotunePolicy.validate_backward is True. This keeps autotuning focused on forward latency and avoids gradient-timing overhead.

Example

>>> with forward_autotune_only():
...     cfg = selector.choose(inv, kernel)

class ejkernel.ops.config.selection.policy_override(selector: ConfigSelectorChain, **updates)

Bases: object

Context manager for temporarily overriding autotuning policy settings.

Allows temporary modification of AutotunePolicy attributes within a context, automatically restoring the original values when exiting the context.

This is useful for:
  • Disabling autotuning for specific operations

  • Forcing use of heuristics during debugging

  • Testing different policy configurations

Parameters
  • selector – ConfigSelectorChain instance to modify

  • **updates – Policy attributes to override with new values

Example

>>> with policy_override(selector, allow_autotune=False):
...     result = executor(kernel, *args)
>>>
>>> with policy_override(selector, cache_miss_fallback="heuristics"):
...     config = selector.choose(inv, kernel)
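
The override-then-restore pattern behind such a context manager can be sketched with the standard library alone. `override_attrs` is a hypothetical helper operating on any object, not the library's `policy_override`, and rejecting unknown attribute names is an added assumption.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def override_attrs(obj, **updates):
    """Temporarily set attributes on `obj`, restoring the originals on exit."""
    unknown = [name for name in updates if not hasattr(obj, name)]
    if unknown:
        raise AttributeError(f"unknown attributes: {unknown}")
    saved = {name: getattr(obj, name) for name in updates}
    try:
        for name, value in updates.items():
            setattr(obj, name, value)
        yield obj
    finally:
        # Runs even if the body raises, so the policy never stays modified.
        for name, value in saved.items():
            setattr(obj, name, value)


# Stand-in policy object with two of the documented fields.
policy = SimpleNamespace(allow_autotune=True, cache_miss_fallback="autotune")

with override_attrs(policy, allow_autotune=False):
    inside = policy.allow_autotune

after = policy.allow_autotune
```

Restoring in a `finally` block is the important design choice: an exception raised inside the `with` body still leaves the original policy values in place.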