ejkernel.ops.config.selection

Configuration selection and autotuning system for kernel optimization.

This module provides a configuration selection framework that picks a kernel configuration through a multi-tier fallback chain. Explicit and cached configurations take priority; autotuning benchmarks candidates only when no earlier tier supplies a configuration.

Key Components:
  • ConfigSelectorChain: Main selection coordinator with fallback hierarchy

  • AutotunePolicy: Configuration policy for autotuning behavior

  • Tuner: Performance benchmarking and autotuning engine

  • policy_override: Context manager for temporary policy changes

Selection Hierarchy (in order of priority):
  1. Override: Explicit configuration provided by caller

  2. Overlay: Temporary context-specific configuration overrides

  3. Memory Cache: Fast lookup for recently used configurations

  4. Persistent Cache: Disk-based storage across program runs

  5. Autotune: Benchmark candidates to find optimal configuration

  6. Heuristics: Kernel-provided default configuration

  7. Error: No configuration available (raises RuntimeError)

This design ensures optimal performance by:
  • Prioritizing fastest lookup methods (memory cache)

  • Preserving optimization results across runs (persistent cache)

  • Automatically finding optimal configurations (autotuning)

  • Providing sensible defaults (heuristics) as fallback
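The selection hierarchy above can be sketched as a first-match walk over an ordered list of lookups. This is a minimal illustration, not the library's implementation: `choose_first`, the `tiers` list, and the dictionary-based caches are hypothetical stand-ins introduced only for this example.

```python
from typing import Callable, Optional, Sequence

# Hypothetical stand-in: each tier is a (name, lookup) pair whose lookup
# returns a config or None. None of these names are part of ejkernel.
Tier = tuple[str, Callable[[], Optional[dict]]]


def choose_first(tiers: Sequence[Tier]) -> tuple[str, dict]:
    """Return the first non-None config and the name of the tier that produced it."""
    for name, lookup in tiers:
        cfg = lookup()
        if cfg is not None:
            return name, cfg
    # Final tier: nothing produced a configuration.
    raise RuntimeError("no configuration available")


# Toy caches keyed by an operation identifier (an assumption for this sketch).
memory_cache: dict[str, dict] = {}
persistent_cache = {"matmul:f32": {"block": 128}}
key = "matmul:f32"

tiers = [
    ("override", lambda: None),                       # caller passed nothing
    ("overlay", lambda: None),                        # no active overlay
    ("memory", lambda: memory_cache.get(key)),        # miss
    ("persistent", lambda: persistent_cache.get(key)),  # hit
    ("autotune", lambda: {"block": 64}),              # never reached here
    ("heuristics", lambda: {"block": 32}),
]

source, config = choose_first(tiers)
```

In this toy run the persistent cache answers first, so the autotune and heuristics lookups are never called.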

Example Usage:
>>> cache = ConfigCache()
>>> policy = AutotunePolicy(allow_autotune=True)
>>> selector = ConfigSelectorChain(cache, policy)
>>>
>>> # `invocation` and `kernel` are assumed to be an existing Invocation and Kernel
>>> config = selector.choose(invocation, kernel)
>>>
>>> # Temporarily disable autotuning for one selection
>>> with policy_override(selector, allow_autotune=False):
...     config = selector.choose(invocation, kernel)

class ejkernel.ops.config.selection.AutotunePolicy(allow_autotune: bool = True, allow_heuristics: bool = True, cache_miss_fallback: Literal['autotune', 'heuristics'] = 'autotune', validate_backward: bool = False)

Bases: object

Configuration policy for autotuning behavior.

allow_autotune

Whether autotuning is permitted

Type

bool

allow_heuristics

Whether heuristic configurations are allowed

Type

bool

cache_miss_fallback

Strategy when no cached config is found ("autotune" or "heuristics")

Type

Literal['autotune', 'heuristics']

allow_autotune: bool = True
allow_heuristics: bool = True
cache_miss_fallback: Literal['autotune', 'heuristics'] = 'autotune'
validate_backward: bool = False
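The source does not spell out how these flags combine on a cache miss, so the following is only one plausible reading, sketched with a stand-in dataclass. `PolicySketch` mirrors the documented fields and defaults; `steps_on_cache_miss` is a hypothetical helper, not an ejkernel function.

```python
from dataclasses import dataclass, replace
from typing import Literal


@dataclass(frozen=True)
class PolicySketch:
    """Stand-in mirroring the documented AutotunePolicy fields and defaults."""

    allow_autotune: bool = True
    allow_heuristics: bool = True
    cache_miss_fallback: Literal["autotune", "heuristics"] = "autotune"
    validate_backward: bool = False


def steps_on_cache_miss(policy: PolicySketch) -> list[str]:
    """Hypothetical: which fallbacks are tried after a cache miss, in order.

    Assumption: cache_miss_fallback names the preferred step, and the
    allow_* flags remove a step entirely.
    """
    order = ["autotune", "heuristics"]
    if policy.cache_miss_fallback == "heuristics":
        order.reverse()
    allowed = {
        "autotune": policy.allow_autotune,
        "heuristics": policy.allow_heuristics,
    }
    return [step for step in order if allowed[step]]


default_steps = steps_on_cache_miss(PolicySketch())
no_tuning_steps = steps_on_cache_miss(replace(PolicySketch(), allow_autotune=False))
heuristics_first = steps_on_cache_miss(PolicySketch(cache_miss_fallback="heuristics"))
```

Under this reading, disabling both flags leaves no fallback at all, which would correspond to the error tier of the selection hierarchy.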
class ejkernel.ops.config.selection.ConfigSelectorChain(cache: ConfigCache[Cfg], policy: AutotunePolicy | None = None, tuner: Tuner[Cfg] | None = None, persistent: PersistentCache[Cfg] | None = None, persist_autotune: bool = True, on_event: callable | None = None, forbid_reautotune: bool = True)

Bases: Generic[Cfg, Out]

Multi-tier configuration selection system with fallback chain.

Selection order:
  1. Override (explicit configuration provided)

  2. Overlay (temporary context-specific overrides)

  3. In-memory cache (fast lookup for recently used configs)

  4. Persistent cache (disk-based storage across runs)

  5. Autotune (benchmark candidates to find optimal config)

  6. Heuristics (kernel-provided default configuration)

  7. Error (no configuration available)

cache

In-memory configuration cache

policy

Autotuning behavior policy

tuner

Performance benchmarking tool

persistent

Optional disk-based cache

persist_autotune

Whether to save autotuned configs to persistent storage

on_event

Optional callback for selection events

forbid_reautotune

Prevent re-autotuning the same operation

choose(inv: Invocation[Cfg, Out], kernel: Kernel[Cfg, Out]) → Cfg

Select optimal configuration using the fallback hierarchy.

Implements the complete configuration selection algorithm, trying each method in order until a suitable configuration is found.

Selection Priority (highest to lowest):
  1. Override: Explicit configuration in invocation

  2. Overlay: Temporary context-specific overrides

  3. Memory Cache: Previously computed optimal configurations

  4. Persistent Cache: Disk-stored configurations from previous runs

  5. Autotune: Benchmark candidates to find optimal configuration

  6. Heuristics: Kernel-provided default configuration

Parameters
  • inv – Function invocation containing arguments and context

  • kernel – Kernel implementation with candidate configurations

Returns

Optimal configuration for this invocation

Raises

RuntimeError – If no configuration can be determined

class ejkernel.ops.config.selection.Tuner(warmup=1, iters=3)

Bases: Generic[Cfg]

Performance benchmarking and autotuning for kernel configurations.

Measures execution time of different configurations and selects the fastest one.

warmup

Number of warmup iterations before timing

iters

Number of timing iterations to average over

autotune(make_fn, args, kwargs, candidates: Iterable[Cfg]) → Cfg

Benchmark all candidate configurations and return the fastest one.

Tests each candidate configuration by measuring its execution time and selects the configuration with the lowest average execution time.

Parameters
  • make_fn – Factory function that creates a function given a config

  • args – Positional arguments for the function being benchmarked

  • kwargs – Keyword arguments for the function being benchmarked

  • candidates – Iterable of candidate configurations to test

Returns

The configuration that achieved the fastest execution time

Raises

RuntimeError – If no candidates are provided for testing
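The benchmarking loop described above can be approximated in plain Python. This is a sketch under stated assumptions, not the Tuner's actual code: `time_call` and `pick_fastest` are hypothetical names, timing uses `time.perf_counter`, and the toy "kernel" simply sleeps for a duration given by its config.

```python
import time
from typing import Any, Callable, Iterable, TypeVar

Cfg = TypeVar("Cfg")


def time_call(fn: Callable[..., Any], args: tuple, kwargs: dict,
              warmup: int = 1, iters: int = 3) -> float:
    """Average wall-clock seconds per call, after `warmup` untimed calls."""
    for _ in range(warmup):
        fn(*args, **kwargs)
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args, **kwargs)
    return (time.perf_counter() - start) / iters


def pick_fastest(make_fn: Callable[[Cfg], Callable[..., Any]],
                 args: tuple, kwargs: dict, candidates: Iterable[Cfg]) -> Cfg:
    """Time every candidate and return the one with the lowest average."""
    best_cfg, best_time, seen_any = None, float("inf"), False
    for cfg in candidates:
        seen_any = True
        elapsed = time_call(make_fn(cfg), args, kwargs)
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    if not seen_any:
        raise RuntimeError("no candidates to autotune")
    return best_cfg


def make_fn(cfg: float) -> Callable[[], None]:
    # Toy kernel factory: the "config" is just how long the call sleeps.
    return lambda: time.sleep(cfg)


best = pick_fastest(make_fn, (), {}, [0.05, 0.005, 0.03])
```

Timing a real JAX function this way would also need to block on the result (for example with `block_until_ready()`), because dispatch is asynchronous; the sketch omits that step.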

measure(fn, *args, **kwargs) → float

Measure average execution time with optional backward validation.

Deep-flatten (args, kwargs) so only array-like leaves are dynamic:
  • Arrays or JAX tracers become dynamic parameters to the jitted function.

  • Everything else (dtype, strings, bools, callables, nested containers) is captured as Python constants in the closure.

  • Tracer-like arrays (e.g., ShardMapTracer, DynamicJaxprTracer) are converted to concrete zeros of the same shape/dtype before compile and timing.

  • If _ejk_validate_backward=True, we differentiate a scalar loss w.r.t. float/complex array leaves only; others are treated as non-diff.

  • If a kernel uses precompiled functions that can’t be transformed, we fall back to forward-only timing, and if needed, to non-jitted forward timing.

Parameters
  • fn – Function to measure (possibly tagged with _ejk_validate_backward)

  • *args – Positional arguments

  • **kwargs – Keyword arguments

Returns

Average execution time per iteration in seconds
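The split between dynamic and constant leaves described above might look like the following. It is a simplified sketch without JAX: `split_leaves`, `is_array_like`, and `FakeArray` are hypothetical, and the array check is duck-typed on `shape` and `dtype`, which is an assumption rather than the library's actual test.

```python
from typing import Any


def is_array_like(x: Any) -> bool:
    """Duck-typed check (an assumption): anything exposing shape and dtype."""
    return hasattr(x, "shape") and hasattr(x, "dtype")


def split_leaves(tree: Any, path: tuple = ()) -> tuple[dict, dict]:
    """Recursively flatten nested containers into {path: leaf} maps.

    Returns (dynamic, static): array-like leaves vs. everything else.
    """
    dynamic: dict = {}
    static: dict = {}
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    else:
        # A leaf: route it to the matching map under its full path.
        (dynamic if is_array_like(tree) else static)[path] = tree
        return dynamic, static
    for key, child in items:
        child_dynamic, child_static = split_leaves(child, path + (key,))
        dynamic.update(child_dynamic)
        static.update(child_static)
    return dynamic, static


class FakeArray:
    """Minimal array stand-in so the sketch runs without JAX installed."""

    def __init__(self, shape, dtype="float32"):
        self.shape, self.dtype = shape, dtype


q = FakeArray((4, 8))
args = (q, {"scale": 0.5, "bias": FakeArray((8,))})
kwargs = {"causal": True, "name": "attn"}
dynamic, static = split_leaves((args, kwargs))
```

Only the two array stand-ins end up in `dynamic`; the float, bool, and string leaves land in `static`, where a jitted wrapper could close over them as Python constants.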

class ejkernel.ops.config.selection.policy_override(selector: ConfigSelectorChain, **updates)

Bases: object

Context manager for temporarily overriding autotuning policy settings.

Allows temporary modification of AutotunePolicy attributes within a context, automatically restoring the original values when exiting the context.

This is useful for:
  • Disabling autotuning for specific operations

  • Forcing use of heuristics during debugging

  • Testing different policy configurations

Parameters
  • selector – ConfigSelectorChain instance to modify

  • **updates – Policy attributes to override with new values

Example

>>> with policy_override(selector, allow_autotune=False):
...     result = executor(kernel, *args)
>>>
>>> with policy_override(selector, cache_miss_fallback="heuristics"):
...     config = selector.choose(inv, kernel)
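The save-and-restore behaviour described for this context manager can be sketched with `contextlib.contextmanager`. This is an illustration of the general pattern only: `override_attrs` is a hypothetical helper, and a `SimpleNamespace` stands in for the real policy object.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def override_attrs(target, **updates):
    """Temporarily set attributes on `target`, restoring them on exit."""
    # Reading the old values first fails fast on an unknown attribute name.
    saved = {name: getattr(target, name) for name in updates}
    try:
        for name, value in updates.items():
            setattr(target, name, value)
        yield target
    finally:
        # Runs even if the body raises, so the policy is never left modified.
        for name, value in saved.items():
            setattr(target, name, value)


# Stand-in for a policy object with two of the documented fields.
policy = SimpleNamespace(allow_autotune=True, cache_miss_fallback="autotune")

with override_attrs(policy, allow_autotune=False):
    inside = policy.allow_autotune

after = policy.allow_autotune
```

Restoring in a `finally` block is the key design choice: an exception raised inside the `with` body still leaves the original policy values in place.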