# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main execution engine for kernels with configuration management.
This module provides the Executor class, which serves as the central orchestrator
for running kernel operations with automatic configuration selection, custom
gradient support, and comprehensive profiling capabilities.
Key Components:
Executor: Main execution engine coordinating the entire execution pipeline
ConfigChooser: Protocol defining configuration selection interface
The Executor handles:
- Argument preprocessing via kernel.prepare()
- Configuration selection through ConfigChooser strategies
- Custom VJP (Vector-Jacobian Product) implementation for gradients
- Profiling metadata injection for performance analysis
- Invocation recording for batch optimization
- JAX compilation with pre-selected configurations
Execution Flow:
1. Preprocess arguments using kernel.prepare()
2. Create Invocation object with argument metadata
3. Select configuration via ConfigChooser.choose()
4. Set up custom VJP if kernel implements it
5. Add profiling metadata based on environment settings
6. Execute kernel with chosen configuration
7. Record invocation for future optimization (if enabled)
Environment Variables:
EJKERNEL_OPS_RECORD: Set to "1" to enable invocation recording
EJKERNEL_OPS_STAMP: Controls profiling metadata format:
- "hash": Use operation hash for labeling
- "json": Use full JSON payload for labeling
- "none": Disable profiling metadata (default)
Example Usage:
>>> cache = ConfigCache()
>>> selector = ConfigSelectorChain(cache)
>>> executor = Executor(selector)
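>>>
>>> # Execute directly with automatic configuration selection
>>> result = executor(my_kernel, input_data)
>>>
>>> # Or compile once with a pre-selected configuration and reuse it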
>>> compiled_fn = executor.compile(my_kernel, example_input)
>>> result = compiled_fn(actual_input)
"""
from __future__ import annotations
import dataclasses
import os
from collections.abc import Callable
from enum import Enum
from typing import Generic, Literal, Protocol
import jax
import jax.numpy as jnp
import jax.sharding
import jax.tree_util as jtu
from ...kernels._registry import Backend, Platform, kernel_registry
from ..config.cache import _cache_overlay
from ..core import Invocation, Kernel, _get_platform_method, _has_custom_vjp
from ..core.types import Cfg, Out
from ..utils.fingerprint import abstractify, device_fingerprint, get_device_platform, stable_json
class ConfigChooser(Protocol):
    """Protocol for configuration selection strategies.

    Defines the interface that configuration selection strategies must
    implement. The primary implementer is ConfigSelectorChain, which provides
    a multi-tier selection system with caching and autotuning.

    Methods:
        choose: Select a configuration for the given invocation and kernel.
    """

    def choose(self, inv: Invocation[Cfg, Out], kernel: Kernel[Cfg, Out]) -> Cfg:
        """Select a configuration for the given invocation.

        Args:
            inv: Invocation object containing arguments and metadata.
            kernel: Kernel implementation requiring configuration.

        Returns:
            Configuration object suitable for the kernel and invocation.
        """
        ...
class Executor(Generic[Cfg, Out]):
    """Main execution engine for kernels with automatic configuration selection.

    The Executor coordinates the entire execution process:
        1. Preprocess arguments via kernel.prepare()
        2. Select configuration via the ConfigChooser
        3. Handle custom VJP if implemented by the kernel
        4. Add profiling metadata if requested
        5. Execute the kernel with the chosen configuration

    Supports both regular operations and custom gradient implementations.

    Attributes:
        chooser: Configuration selection strategy (typically ConfigSelectorChain).
        stamp_prefix: Prefix for profiling metadata labels.
    """

    def __init__(self, chooser: ConfigChooser, stamp_prefix: str = "ejkernel_ops"):
        """Initialize executor with configuration chooser and profiling settings.

        Args:
            chooser: Configuration selection strategy (typically ConfigSelectorChain).
            stamp_prefix: Prefix for profiling metadata labels in compiled code.
        """
        self.chooser = chooser
        self.stamp_prefix = stamp_prefix
@staticmethod
def _platform_value(val) -> str | None:
"""Convert a platform value to a lowercase string representation.
Normalizes platform identifiers from various formats (Enum, string, etc.)
into a consistent lowercase string for comparison and matching.
Args:
val: Platform value to normalize. Can be None, an Enum member,
or any object convertible to string.
Returns:
Lowercase string representation of the platform, or None if val is None.
"""
if val is None:
return None
if isinstance(val, Enum):
return str(val.value).lower()
return str(val).lower()
@staticmethod
def _is_nvidia_gpu() -> bool:
"""Detect whether the current GPU device is an NVIDIA GPU.
Checks the JAX backend platform version and device properties to
determine if the system has an NVIDIA (CUDA-capable) GPU. Distinguishes
NVIDIA GPUs from AMD (ROCm) GPUs by examining platform version strings
and device kind identifiers.
Returns:
True if an NVIDIA GPU is detected, False otherwise (including AMD
GPUs, no GPUs, or detection failure).
Note:
First checks the XLA backend platform version for 'cuda' or 'rocm'
strings. If inconclusive, falls back to inspecting individual device
kind attributes for NVIDIA-specific keywords (nvidia, tesla, geforce,
rtx, quadro).
"""
try:
from jax.lib import xla_bridge
platform_version = xla_bridge.get_backend().platform_version
if isinstance(platform_version, str):
pv = platform_version.lower()
if "cuda" in pv:
return True
if "rocm" in pv:
return False
except Exception:
pass
try:
devices = jax.devices("gpu")
except Exception:
devices = []
if not devices:
try:
devices = jax.devices()
except Exception:
devices = []
for dev in devices:
kind = (getattr(dev, "device_kind", "") or "").lower()
if any(token in kind for token in ("nvidia", "tesla", "geforce", "rtx", "quadro")):
return True
return False
@staticmethod
def _has_cuda_impl(algorithm: str) -> bool:
    """Report whether the registry lists a CUDA implementation for an algorithm.

    Args:
        algorithm: Algorithm/kernel name to look up (typically kernel.op_id).

    Returns:
        True when at least one registered implementation has platform
        ``Platform.CUDA`` and a backend of ``Backend.GPU`` or ``Backend.ANY``.
        False when none match or when the registry lookup raises.
    """
    try:
        candidates = kernel_registry.list_implementations(algorithm)
    except Exception:
        # A failing registry lookup is treated as "no implementation".
        return False

    for spec in candidates:
        if spec.platform == Platform.CUDA and spec.backend in (Backend.GPU, Backend.ANY):
            return True
    return False
@staticmethod
def _has_cute_impl(algorithm: str) -> bool:
    """Report whether the registry lists a CuTe DSL implementation for an algorithm.

    Args:
        algorithm: Algorithm/kernel name to look up.

    Returns:
        True when at least one registered implementation has platform
        ``Platform.CUTE`` and a backend of ``Backend.GPU`` or ``Backend.ANY``.
        False when none match or when the registry lookup raises.
    """

    def _is_cute_gpu(spec) -> bool:
        """Match a single registry entry against the CuTe/GPU criteria."""
        return spec.platform == Platform.CUTE and spec.backend in (Backend.GPU, Backend.ANY)

    try:
        registered = kernel_registry.list_implementations(algorithm)
    except Exception:
        # A failing registry lookup is treated as "no implementation".
        return False
    return any(map(_is_cute_gpu, registered))
def _prefer_cuda_cfg(self, cfg: Cfg, kernel: Kernel[Cfg, Out], inv: Invocation[Cfg, Out]) -> Cfg:
"""Upgrade configuration to prefer CUTE/CUDA when conditions are met.
Automatically switches the platform field in a configuration from 'auto'
or 'triton' to 'cute'/'cuda' when all of the following conditions are satisfied:
- The current configuration platform is 'auto' or 'triton'
- No explicit override configuration was provided
- No explicit platform was specified in invocation kwargs
- The current device is a GPU
- The GPU is an NVIDIA GPU (not AMD ROCm)
- A CuTe or CUDA implementation exists in the kernel registry
This enables transparent use of optimized CUTE/CUDA kernels when
available, while gracefully falling back to Triton implementations
on unsupported platforms.
Args:
cfg: Current configuration object (may have 'platform' and/or
'backend' attributes).
kernel: Kernel instance being executed.
inv: Current invocation with arguments and metadata.
Returns:
Modified configuration with platform set to 'cuda' and backend
set to 'gpu' if upgrade conditions are met, otherwise the
original configuration unchanged.
"""
platform_val = self._platform_value(getattr(cfg, "platform", None))
if platform_val is None or platform_val in ("cuda", "cute"):
return cfg
if platform_val not in ("auto", "triton"):
return cfg
if inv.override_cfg is not None:
return cfg
explicit_platform = None
if isinstance(inv.kwargs, dict):
explicit_platform = self._platform_value(inv.kwargs.get("platform", None))
if explicit_platform not in (None, "auto"):
return cfg
if get_device_platform() != "gpu":
return cfg
if not self._is_nvidia_gpu():
return cfg
has_cute = self._has_cute_impl(kernel.op_id)
has_cuda = self._has_cuda_impl(kernel.op_id)
if not has_cute and not has_cuda:
return cfg
if dataclasses.is_dataclass(cfg):
fields = {field.name for field in dataclasses.fields(cfg)}
updates = {}
if "platform" in fields:
updates["platform"] = "cute" if has_cute else "cuda"
if "backend" in fields:
updates["backend"] = "gpu"
if updates:
try:
return dataclasses.replace(cfg, **updates)
except Exception:
pass
try:
if hasattr(cfg, "platform"):
cfg.platform = "cute" if has_cute else "cuda"
if hasattr(cfg, "backend"):
cfg.backend = "gpu"
except Exception:
return cfg
return cfg
def _stamp_hash(self, kernel, inv, fn, cfg):
"""Add hash-based profiling metadata to function.
Creates a compact label using operation ID and call signature hash
for performance profiling and debugging.
Args:
kernel: Kernel being executed
inv: Invocation object
fn: Function to wrap with profiling metadata
cfg: Configuration being used
Returns:
Function wrapped with hash-based profiling label
"""
call_key = inv.make_key(kernel.key_builder)
op_id_v = f"{kernel.op_id}@v{getattr(kernel, 'version', '0')}"
label = f"{self.stamp_prefix}#{op_id_v}:{call_key}"
return self._stamp(label, fn)
def _stamp_json(self, kernel, inv, fn, cfg):
    """Wrap ``fn`` in a named scope carrying a JSON description of the call.

    The JSON payload is built once, at wrap time, from the versioned op id,
    the abstractified positional and keyword arguments, and the
    configuration. Each call of the returned function runs ``fn`` inside
    ``jax.named_scope("<stamp_prefix>:<payload>")``.

    Args:
        kernel: Kernel being executed.
        inv: Invocation object.
        fn: Function to wrap with profiling metadata.
        cfg: Configuration being used.

    Returns:
        Function wrapped with JSON profiling metadata.

    Note:
        The scope name embeds the full payload, so labels are much longer
        than in hash mode.
    """
    description = {
        "op_id": f"{kernel.op_id}@v{getattr(kernel, 'version', '0')}",
        "args": abstractify(inv.args),
        "kwargs": abstractify(dict(inv.kwargs)),
        "cfg": cfg,
    }
    payload = stable_json(description)

    def wrapped(*call_args, **call_kwargs):
        """Run the wrapped function inside the JSON-labelled named scope."""
        scope = jax.named_scope(f"{self.stamp_prefix}:{payload}")
        with scope:
            return fn(*call_args, **call_kwargs)

    return wrapped
def _stamp(self, name: str, fn: Callable) -> Callable:
"""Add profiling metadata to function using JAX naming primitives.
Uses JAX's named_call if available, otherwise falls back to named_scope
for adding operation labels to compiled code.
Args:
name: Label to attach to the operation
fn: Function to wrap with profiling metadata
Returns:
Function wrapped with profiling label
"""
if hasattr(jax, "named_call"):
return jax.named_call(fn, name=name)
def wrapped(*a, **k):
"""Execute the function within a JAX named scope for profiling."""
with jax.named_scope(name):
return fn(*a, **k)
return wrapped
def __call__(
    self,
    kernel: Kernel[Cfg, Out],
    *args,
    cfg: Cfg | None = None,
    stamp: bool = True,
    method: Literal["shard_map"] | None = None,
    mesh: jax.sharding.Mesh | None = None,
    in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None,
    out_specs: jax.sharding.PartitionSpec | None = None,
    check_vma: bool = False,
    **kwargs,
) -> Out:
    """Execute kernel with automatic configuration selection and management.

    Orchestrates the complete pipeline: argument preprocessing,
    configuration selection, optional custom-VJP wiring, optional invocation
    recording, optional profiling stamps, and finally execution (directly or
    through the kernel's shard_map wrapper).

    Args:
        kernel: Kernel implementation to execute.
        *args: Positional arguments for the kernel.
        cfg: Optional configuration override. A ``_cfg`` keyword argument,
            if present, takes precedence over this value.
        stamp: Whether to add profiling metadata to the operation. The
            format is read from ``EJKERNEL_OPS_STAMP`` ("json", "hash", or
            "none"; defaults to "none").
        method: Execution method - "shard_map" for distributed execution.
        mesh: JAX device mesh for shard_map (required if method="shard_map").
        in_specs: Input partition specs for shard_map (required if method="shard_map").
        out_specs: Output partition spec for shard_map (required if method="shard_map").
        check_vma: Forwarded to the kernel's shard_map wrapper.
        **kwargs: Keyword arguments for the kernel.

    Returns:
        Result of kernel execution with the selected configuration.

    Raises:
        ValueError: If method="shard_map" and mesh, in_specs or out_specs is
            missing, or if the kernel's ``create_shard_map_wrapper`` returns
            something other than 2 or 3 items.
        AttributeError: If method="shard_map" and the kernel does not
            implement ``create_shard_map_wrapper``.
        TypeError: During differentiation, if the kernel's vjp returns a
            dict or a number of gradients different from the number of
            positional arguments.

    Note:
        When the kernel provides a custom VJP, only ``jax.Array`` leaves of
        the prepared arguments are passed through ``jax.custom_vjp``; every
        other leaf is closed over as a constant. Array inputs for which the
        kernel's vjp yields no gradient (all keyword arguments, and
        positional arguments whose gradient is None) receive zero gradients.
    """
    # "_cfg" is accepted as an alias for the override and wins over `cfg`.
    if "_cfg" in kwargs:
        cfg = kwargs.pop("_cfg")

    if method == "shard_map":
        if mesh is None:
            raise ValueError("mesh must be provided when method='shard_map'")
        if in_specs is None:
            raise ValueError("in_specs must be provided when method='shard_map'")
        if out_specs is None:
            raise ValueError("out_specs must be provided when method='shard_map'")

    args2, kwargs2 = kernel.prepare(*args, **kwargs)
    inv = Invocation(
        op_id=kernel.op_id,
        args=args2,
        kwargs=kwargs2,
        override_cfg=cfg,
        stamp=stamp,
        method=method,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=out_specs,
        check_vma=check_vma,
    )

    # Choosers exposing a policy whose cache-miss fallback is "heuristics"
    # take the fast path that never triggers autotuning.
    policy = getattr(self.chooser, "policy", None)
    if policy is not None and getattr(policy, "cache_miss_fallback", "heuristics") == "heuristics":
        chosen = self._choose_heuristics_only(inv, kernel)
    else:
        chosen = self.chooser.choose(inv, kernel)
    chosen = self._prefer_cuda_cfg(chosen, kernel, inv)

    platform = get_device_platform()
    context = "shard_map" if method == "shard_map" else None

    if _has_custom_vjp(kernel, platform, context):
        fwd_method = (
            _get_platform_method(kernel, "fwd_with_residuals", platform, context) or kernel.fwd_with_residuals
        )
        vjp_method = _get_platform_method(kernel, "vjp", platform, context) or kernel.vjp

        # Split the prepared inputs into dynamic array leaves (differentiable
        # inputs of the custom_vjp function) and closed-over constants.
        full_leaves, treedef = jtu.tree_flatten((args2, kwargs2))
        is_arr = [isinstance(x, jax.Array) for x in full_leaves]
        const_leaves = [None if m else x for m, x in zip(is_arr, full_leaves, strict=False)]

        def _restore_args_kwargs(array_leaves):
            """Rebuild (args, kwargs) by merging dynamic array leaves into closed constants."""
            it = iter(array_leaves)
            merged = [next(it) if m else v for m, v in zip(is_arr, const_leaves, strict=False)]
            return jtu.tree_unflatten(treedef, merged)

        def fwd_arrays(*array_leaves):
            """Forward rule: takes only array leaves, rebuilds args/kwargs inside."""
            (a, k) = _restore_args_kwargs(array_leaves)
            y, res = fwd_method(*a, cfg=chosen, **k)
            return y, (tuple(array_leaves), res)

        def bwd_arrays(payload, dy):
            """Backward rule: rebuild args/kwargs, call kernel.vjp, and map grads to array inputs."""
            array_leaves, res = payload
            (a, k) = _restore_args_kwargs(array_leaves)
            grads = vjp_method(res, dy, *a, cfg=chosen, **k)
            if isinstance(grads, dict):
                raise TypeError("kernel.vjp must return a tuple of grads for positional args.")
            grads = tuple(grads)
            if len(grads) != len(a):
                raise TypeError(
                    f"kernel.vjp must return one grad per positional arg; got {len(grads)} for {len(a)} args."
                )

            # Build exactly one gradient slot per input leaf, in the same
            # order as `is_arr`. Flattening a pytree of gradients directly
            # would silently drop None entries (tree_flatten treats None as
            # an empty subtree) and shift every later gradient onto the
            # wrong input, so slots are filled per argument instead.
            flat_grads = []
            for x, g in zip(a, grads, strict=False):
                x_leaves, x_def = jtu.tree_flatten(x)
                if g is None:
                    flat_grads.extend([None] * len(x_leaves))
                else:
                    # `g` must match the structure of `x`; leaf positions
                    # may hold None to signal "no gradient".
                    flat_grads.extend(x_def.flatten_up_to(g))
            # kernel.vjp only reports gradients for positional args; keyword
            # argument leaves get empty slots (zero-filled below for arrays).
            flat_grads.extend([None] * len(jtu.tree_leaves(k)))

            primals = iter(array_leaves)
            grad_out = []
            for m, gleaf in zip(is_arr, flat_grads, strict=True):
                if not m:
                    continue
                primal = next(primals)
                grad_out.append(jnp.zeros_like(primal) if gleaf is None else gleaf)
            return tuple(grad_out)

        def primal_only_arrays(*array_inputs):
            """Compute forward pass output only, discarding residuals for custom VJP."""
            return fwd_arrays(*array_inputs)[0]

        g = jax.custom_vjp(primal_only_arrays)
        g.defvjp(fwd_arrays, bwd_arrays)

        def fn(*a, **k):
            """Extract array leaves from args/kwargs and route through the custom VJP wrapper."""
            flat_call, _ = jtu.tree_flatten((a, k))
            array_in = [x for x, m in zip(flat_call, is_arr, strict=False) if m]
            return g(*array_in)

    else:
        run_method = _get_platform_method(kernel, "run", platform, context) or kernel.run

        def fn(*a, **k):
            """Execute the kernel run method with the pre-selected configuration."""
            return run_method(*a, cfg=chosen, **k)

    if os.getenv("EJKERNEL_OPS_RECORD", "0") == "1":
        try:
            from ..registry import record_invocation

            call_key = inv.make_key(kernel.key_builder)
            op_id_v = f"{kernel.op_id}@v{getattr(kernel, 'version', '0')}"
            record_invocation(device_fingerprint(), op_id_v, call_key, kernel, args2, kwargs2)
        except Exception:
            # Recording is best-effort and must never break execution.
            pass

    if stamp:
        mode = os.getenv("EJKERNEL_OPS_STAMP", "none").lower()
        if mode == "json":
            fn = self._stamp_json(kernel, inv, fn, chosen)
        elif mode == "hash":
            fn = self._stamp_hash(kernel, inv, fn, chosen)
        # "none" (the default) and unrecognized values leave fn unstamped.

    if method == "shard_map":
        if not hasattr(kernel, "create_shard_map_wrapper"):
            raise AttributeError(f"Kernel {kernel.op_id} does not implement create_shard_map_wrapper")
        eagers = kernel.create_shard_map_wrapper(
            *args2,
            mesh=mesh,
            in_specs=in_specs,
            out_specs=out_specs,
            check_vma=check_vma,
            cfg=chosen,
            **kwargs2,
        )
        # The wrapper returns (fn, call_args) or (fn, call_args, callback).
        callback = None
        if len(eagers) == 2:
            shard_map_fn, call_args = eagers
        elif len(eagers) == 3:
            shard_map_fn, call_args, callback = eagers
        else:
            raise ValueError(
                f"Kernel {kernel.op_id} create_shard_map_wrapper must return 2 or 3 items; got {len(eagers)}"
            )
        outs = shard_map_fn(*call_args)
        if callback is not None:
            outs = callback(outs, cfg=chosen)
        return outs

    return fn(*args2, **kwargs2)
def choose_config(self, kernel: Kernel[Cfg, Out], *args, cfg: Cfg | None = None, **kwargs) -> Cfg:
    """Select configuration for kernel without executing it.

    Useful for inspecting what configuration would be chosen for given
    arguments, or for pre-selecting configurations for compilation.

    Args:
        kernel: Kernel implementation requiring configuration.
        *args: Positional arguments for the kernel.
        cfg: Optional configuration override. A ``_cfg`` keyword argument,
            if present, takes precedence over this value.
        **kwargs: Keyword arguments for the kernel. After
            ``kernel.prepare``, the keys ``method``, ``mesh``, ``in_specs``,
            ``out_specs`` and ``check_vma`` are removed from the prepared
            kwargs and recorded on the invocation instead.

    Returns:
        Configuration that would be selected for this invocation.
    """
    if "_cfg" in kwargs:
        cfg = kwargs.pop("_cfg")

    args2, kwargs2 = kernel.prepare(*args, **kwargs)

    # Execution-control options are invocation metadata, not kernel inputs.
    method = kwargs2.pop("method", None)
    mesh = kwargs2.pop("mesh", None)
    in_specs = kwargs2.pop("in_specs", None)
    out_specs = kwargs2.pop("out_specs", None)
    check_vma = kwargs2.pop("check_vma", False)

    inv = Invocation(
        op_id=kernel.op_id,
        args=args2,
        kwargs=kwargs2,
        override_cfg=cfg,
        stamp=False,
        method=method,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=out_specs,
        check_vma=check_vma,
    )

    # Choosers exposing a policy whose cache-miss fallback is "heuristics"
    # take the fast path that never triggers autotuning.
    policy = getattr(self.chooser, "policy", None)
    if policy is not None and getattr(policy, "cache_miss_fallback", "heuristics") == "heuristics":
        cfg = self._choose_heuristics_only(inv, kernel)
    else:
        cfg = self.chooser.choose(inv, kernel)
    return self._prefer_cuda_cfg(cfg, kernel, inv)
def _choose_heuristics_only(self, inv: Invocation[Cfg, Out], kernel: Kernel[Cfg, Out]) -> Cfg:
"""Select configuration using fast heuristics path without autotuning.
This method provides a fast configuration selection path that bypasses
the full autotuning system. It checks caches in the following order:
1. Cache overlay (temporary overrides)
2. In-memory cache
3. Persistent cache
4. Kernel heuristics (fallback)
This is used when the policy is set to use heuristics-only on cache miss,
avoiding expensive autotuning operations in production scenarios.
Args:
inv: Invocation object containing operation arguments and metadata
kernel: Kernel implementation requiring configuration
Returns:
Configuration selected from cache or computed via heuristics
Note:
Unlike the full selector.choose path, this method never triggers
autotuning. Selected configurations are written back to caches
for future use.
"""
dev = device_fingerprint()
op_id_v = f"{kernel.op_id}@v{getattr(kernel, 'version', '0')}"
call_key = inv.make_key(kernel.key_builder)
for overlay in reversed(_cache_overlay.get()):
if (cfg := overlay.get((dev, op_id_v, call_key))) is not None:
return cfg
if (cfg := self.chooser.cache.get(dev, op_id_v, call_key)) is not None:
return cfg
if self.chooser.persistent is not None:
if (cfg := self.chooser.persistent.get(dev, op_id_v, call_key)) is not None:
self.chooser.cache.put(dev, op_id_v, call_key, cfg)
return cfg
platform = get_device_platform()
context = "shard_map" if getattr(inv, "method", None) == "shard_map" else None
heuristic_cfg_method = _get_platform_method(kernel, "heuristic_cfg", platform, context) or kernel.heuristic_cfg
cfg = heuristic_cfg_method(inv)
self.chooser.cache.put(dev, op_id_v, call_key, cfg)
if self.chooser.persistent is not None and self.chooser.persist_autotune:
self.chooser.persistent.put(dev, op_id_v, call_key, cfg)
return cfg
def compile(self, kernel: Kernel[Cfg, Out], *example_args, cfg: Cfg | None = None, **example_kwargs):
    """Compile kernel with a pre-selected configuration.

    Selects a configuration once from the example arguments (via
    ``choose_config``), then returns a ``jax.jit``-wrapped function that
    calls ``kernel.run`` with that configuration on every call, so no
    configuration selection happens at call time.

    Args:
        kernel: Kernel implementation to compile.
        *example_args: Example positional arguments for configuration selection.
        cfg: Optional configuration override. A ``_cfg`` keyword argument,
            if present, takes precedence over this value.
        **example_kwargs: Example keyword arguments for configuration selection.

    Returns:
        JAX-compiled function with the pre-selected configuration.

    Example:
        >>> compiled_matmul = executor.compile(matmul_kernel, x_example, y_example)
        >>> result = compiled_matmul(x_actual, y_actual)
    """
    if "_cfg" in example_kwargs:
        cfg = example_kwargs.pop("_cfg")
    chosen = self.choose_config(kernel, *example_args, cfg=cfg, **example_kwargs)

    def run(*a, **k):
        """Execute the kernel with the pre-selected configuration."""
        return kernel.run(*a, cfg=chosen, **k)

    return jax.jit(run)