# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Kernel registry system for managing multi-platform implementations.
This module provides the core registry infrastructure for ejkernel's
multi-platform kernel dispatch system. It enables registration and
lookup of kernel implementations across different platforms (Triton,
Pallas, CUDA, XLA) and backends (GPU, TPU, CPU).
Key Components:
- Platform: Enum for implementation platforms (TRITON, PALLAS, CUDA, CUTE, XLA)
- Backend: Enum for hardware backends (GPU, TPU, CPU, ANY)
- KernelSpec: Dataclass describing a registered kernel implementation
- KernelRegistry: Registry class for managing kernel implementations
- kernel_registry: Global singleton registry instance
Usage:
Registration:
@kernel_registry.register("flash_attention", Platform.TRITON, Backend.GPU)
def flash_attention_triton(q, k, v): ...
Lookup:
impl = kernel_registry.get("flash_attention", platform="triton", backend="gpu")
result = impl(q, k, v)
Priority System:
Multiple implementations of the same algorithm can coexist with different
priorities. When looking up a kernel, the highest priority match is returned.
This enables optimized implementations to take precedence over fallbacks.
Signature Validation:
The validate_signatures() method ensures all implementations of an algorithm
have compatible signatures, catching registration errors early.
Example:
>>> from ejkernel.kernels._registry import kernel_registry, Platform, Backend
>>>
>>> # List all registered algorithms
>>> algorithms = kernel_registry.list_algorithms()
>>>
>>> # Get implementations for an algorithm
>>> impls = kernel_registry.list_implementations("flash_attention")
>>>
>>> # Get best implementation for current platform
>>> impl = kernel_registry.get("flash_attention", platform=Platform.XLA)
"""
from __future__ import annotations
import ast
import functools
import inspect
import re
import textwrap
import warnings
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, Literal, TypeVar, overload
import jax
from ejkernel.errors import EjkernelRuntimeError
#: ``TypeVar`` bound to ``Callable``; used by decorators such as
#: ``KernelRegistry.register`` to hand back the same callable type they receive.
F = TypeVar("F", bound=Callable)

#: Memo for ``_get_ignored_params``: maps a callable to the set of names that
#: appear in ``del`` statements inside its source.
_IGNORED_PARAM_CACHE: dict[Callable[..., Any], set[str]] = {}

#: Kernel tuning/autotuning parameter names excluded from the
#: unsupported-parameter check. They control low-level execution behaviour
#: (block sizes, warp counts, split factors, ...) and may be silently ignored
#: by implementations that do not use them.
_TUNING_PARAM_NAMES: set[str] = {
    # Block / tile sizes.
    "block_m",
    "block_n",
    "block_k",
    "block_q",
    "block_kv",
    "block_size",
    # Launch and scheduling configuration.
    "num_warps",
    "max_warps",
    "num_stages",
    "num_ctas",
    "num_waves",
    "num_sms",
    # Work-splitting factors.
    "num_splits",
    "split_k",
    "num_kv_splits",
    "num_par_softmax_segments",
    # Per-block query / KV-page counts and related thresholds.
    "num_queries_per_block",
    "num_kv_pages_per_block",
    "seq_threshold_3d",
    # Miscellaneous knobs and bundled forward/backward parameter objects.
    "use_bf16",
    "fwd_params",
    "bwd_params",
}
def _get_ignored_params(func: Callable[..., Any]) -> set[str]:
    """Detect parameter names explicitly deleted inside a function body.

    Parses the source code of *func* with the ``ast`` module and collects every
    plain name that is the target of a ``del`` statement. Names inside
    tuple/list targets are included, e.g. ``del (a, b)``. These names typically
    represent parameters the implementation chooses to discard, signaling that
    the corresponding feature is unsupported.

    Results are cached per function in ``_IGNORED_PARAM_CACHE`` to avoid
    repeated source parsing.

    Args:
        func: The callable whose source code will be inspected.

    Returns:
        A set of names found in ``del`` statements within *func*. An empty set
        is returned (and cached) when the source cannot be retrieved or parsed.
        Examples are builtins, ``functools.partial`` objects, and code without
        a backing file.
    """
    cached = _IGNORED_PARAM_CACHE.get(func)
    if cached is not None:
        return cached

    ignored: set[str] = set()
    try:
        # ``inspect.getsource`` raises OSError when the source file is
        # unavailable and TypeError for objects it cannot introspect
        # (builtins, C functions, partials). ``ast.parse`` may raise
        # SyntaxError, or ValueError on older Pythons for NUL bytes.
        # Every one of these means "nothing to inspect".
        source = inspect.getsource(func)
        tree = ast.parse(textwrap.dedent(source))
    except (OSError, TypeError, SyntaxError, ValueError):
        _IGNORED_PARAM_CACHE[func] = ignored
        return ignored

    def _collect_names(target: ast.expr) -> None:
        """Add plain names from a ``del`` target, recursing into tuples/lists."""
        if isinstance(target, ast.Name):
            ignored.add(target.id)
        elif isinstance(target, (ast.Tuple, ast.List)):
            for element in target.elts:
                _collect_names(element)
        # Attribute / subscript targets (``del self.x``, ``del d[k]``) do not
        # remove a parameter name, so they are intentionally skipped.

    for node in ast.walk(tree):
        if isinstance(node, ast.Delete):
            for target in node.targets:
                _collect_names(target)

    _IGNORED_PARAM_CACHE[func] = ignored
    return ignored
def _is_non_default(value: Any, default: Any) -> bool:
    """Determine whether a supplied argument value differs from its default.

    Uses identity checks for ``None`` and ``inspect._empty``. For common scalar
    defaults (``bool``, ``int``, ``float``, ``str``) it uses equality, coerced to
    a plain ``bool``. All other defaults fall back to identity comparison.

    Args:
        value: The argument value that was actually passed by the caller.
        default: The parameter's default value from its signature. If the
            parameter has no default, this should be ``inspect._empty``.

    Returns:
        ``True`` if *value* is considered different from *default*,
        ``False`` otherwise. Always a plain ``bool``.
    """
    if default is inspect._empty:
        # No default exists, so any explicitly supplied value is a deliberate choice.
        return True
    if default is None:
        return value is not None
    if isinstance(default, (bool, int, float, str)):
        try:
            # Comparing an array (NumPy/JAX) with a scalar yields an array.
            # ``bool()`` of a multi-element or traced array raises, so coerce
            # here and report a plain bool to the caller.
            return bool(value != default)
        except (TypeError, ValueError):
            # The value cannot be reduced to a single truth value against the
            # scalar default; conservatively treat it as non-default.
            return True
    return value is not default
def _call_unsupported_hook(hook: Callable[..., Any], arguments: dict[str, Any]) -> Any:
    """Invoke an ``__ejkernel_unsupported__`` hook with the bound kernel arguments.

    A hook may take the arguments either as keywords (``hook(**arguments)``) or
    as a single mapping (``hook(arguments)``). The convention is chosen by
    checking whether the hook's signature can bind the arguments as keywords.
    A ``TypeError`` raised *inside* the hook therefore propagates instead of
    triggering a retry under the other convention.

    Args:
        hook: The callable stored in ``__ejkernel_unsupported__``.
        arguments: Mapping of parameter name to the value passed by the caller.

    Returns:
        Whatever the hook returns (a string, an iterable of reasons, or a falsy
        value).
    """
    try:
        hook_sig = inspect.signature(hook)
    except (TypeError, ValueError):
        hook_sig = None
    if hook_sig is None:
        # Signature not introspectable: fall back to trying keywords first.
        try:
            return hook(**arguments)
        except TypeError:
            return hook(arguments)
    try:
        hook_sig.bind(**arguments)
    except TypeError:
        # The hook cannot take the kernel arguments as keywords; give it the mapping.
        return hook(arguments)
    return hook(**arguments)


def _collect_unsupported_reasons(
    func: Callable[..., Any],
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> list[str]:
    """Collect human-readable reasons why a kernel call is unsupported.

    This function checks two sources of unsupported-feature information:

    1. **Ignored parameters** -- Parameters whose names appear in ``del``
       statements within *func* (detected by ``_get_ignored_params``). If the
       caller supplied a non-default value for such a parameter and it is not
       a tuning parameter, a reason string is generated. Names are visited in
       sorted order so the resulting message is deterministic.
    2. **Explicit unsupported hook** -- If *func* carries an
       ``__ejkernel_unsupported__`` attribute, it is used as follows:
       - a callable is invoked with the bound arguments;
       - a string is used as one reason;
       - any other iterable is expanded into reason strings.

    Args:
        func: The kernel implementation function to inspect.
        args: Positional arguments passed to the kernel call.
        kwargs: Keyword arguments passed to the kernel call.

    Returns:
        A list of human-readable reason strings. An empty list means the
        call is fully supported.

    Raises:
        TypeError: If *args*/*kwargs* cannot be bound to *func*'s signature.
    """
    reasons: list[str] = []
    sig = inspect.signature(func)
    # ``bind_partial`` records only explicitly supplied arguments (defaults are
    # not applied), so absence from ``arguments`` means "left at default".
    arguments = sig.bind_partial(*args, **kwargs).arguments

    # Iterating the raw set would order reasons by string-hash randomisation,
    # making error messages differ between runs.
    for name in sorted(_get_ignored_params(func)):
        if name in _TUNING_PARAM_NAMES or name not in arguments:
            continue
        if _is_non_default(arguments[name], sig.parameters[name].default):
            reasons.append(f"{name} is not supported")

    extra = getattr(func, "__ejkernel_unsupported__", None)
    if extra is None:
        return reasons
    if callable(extra):
        extra = _call_unsupported_hook(extra, arguments)
    if isinstance(extra, str):
        reasons.append(extra)
    elif isinstance(extra, Iterable):
        reasons.extend(str(item) for item in extra if item)
    return reasons
def _normalize_type_string(type_annotation: Any) -> str:
    """Normalize a type annotation into a canonical string for comparison.

    Handles cases where the same type is imported or rendered differently:

    - ``'jaxtyping.Float'`` -> ``'Float'``
    - ``'jaxlib._jax.Array'`` / ``'jax.Array'`` -> ``'Array'``
    - ``'ejkernel.ops.utils.datacarrier.FwdParams'`` -> ``'FwdParams'``
    - ``"<class 'int'>"`` -> ``'int'``

    Precision-like and dtype-like annotations collapse to the sentinels
    ``'PrecisionLike'`` and ``'DTypeLike'`` respectively.

    Args:
        type_annotation: The type annotation to normalize (an object or a
            string, as produced by ``from __future__ import annotations``).

    Returns:
        Normalized string representation of the type. ``inspect._empty``
        maps to the literal string ``"inspect._empty"``.
    """
    if type_annotation is inspect._empty:
        return "inspect._empty"
    type_str = type_annotation if isinstance(type_annotation, str) else str(type_annotation)
    # ``str(SomeClass)`` renders as ``<class 'pkg.SomeClass'>``; keep only the
    # dotted path. ``[^']+`` (instead of a greedy ``.+``) stops several class
    # reprs in one string from being merged into a single match.
    type_str = re.sub(r"<class '([^']+)'>", r"\1", type_str)
    type_str = re.sub(r"\bjaxtyping\.", "", type_str)
    # Normalize common JAX array type spellings.
    #
    # Why: with `from __future__ import annotations`, the annotation may stay as
    # `Float[Array, ...]` (a string), while an eagerly-evaluated annotation will
    # often render as `jaxtyping.Float[jaxlib._jax.Array, ...]`. These are
    # semantically equivalent in this project and shouldn't trigger a mismatch.
    type_str = re.sub(r"\b(?:jaxlib\._jax\.Array|jax\.jaxlib\._jax\.Array|jax\.Array)\b", "Array", type_str)
    type_str = re.sub(r"\bjax\._src\.lax\.lax\.", "", type_str)
    type_str = re.sub(r"\bjax\._src\.typing\.", "", type_str)
    type_str = re.sub(r"\bjax\.typing\.", "", type_str)
    type_str = re.sub(r"\bjax\.lax\.", "lax.", type_str)
    # Strip any ``ejkernel.<...>.`` module prefix, including single-level
    # paths such as ``ejkernel.FwdParams``.
    type_str = re.sub(r"\bejkernel(?:\.\w+)*\.(\w+)", r"\1", type_str)
    # "DotAlgorithmPreset" contains "DotAlgorithm", so two substring checks suffice.
    if "PrecisionLike" in type_str or "DotAlgorithm" in type_str:
        return "PrecisionLike"
    if "DTypeLike" in type_str or "SupportsDType" in type_str:
        return "DTypeLike"
    return type_str
def _types_are_equivalent(type1: Any, type2: Any) -> bool:
    """Report whether two type annotations denote the same type.

    Two missing annotations (``inspect._empty``) are equivalent. A missing
    annotation never matches a present one. Otherwise both annotations are
    passed through ``_normalize_type_string`` and the resulting strings are
    compared. This tolerates the same type being spelled differently in
    different modules, e.g. ``'Float'`` vs ``'jaxtyping.Float'``.

    Args:
        type1: First type annotation.
        type2: Second type annotation.

    Returns:
        ``True`` if the annotations are considered equivalent, else ``False``.
    """
    missing1 = type1 is inspect._empty
    missing2 = type2 is inspect._empty
    if missing1 or missing2:
        # Equivalent only when *both* sides lack an annotation.
        return missing1 and missing2
    return _normalize_type_string(type1) == _normalize_type_string(type2)
[docs]class Backend(StrEnum):
"""Target hardware backends for kernel execution.
Used to tag kernel implementations with the hardware they target.
During lookup, ``Backend.ANY`` acts as a wildcard that matches every
backend query, making it suitable for platform-agnostic implementations.
Attributes:
GPU: NVIDIA (or compatible) GPU backend.
TPU: Google TPU backend.
CPU: CPU-only backend.
ANY: Wildcard backend matching any hardware target.
"""
GPU = "gpu"
TPU = "tpu"
CPU = "cpu"
ANY = "any"
@dataclass(frozen=True)
class KernelSpec:
    """Immutable specification describing a single registered kernel implementation.

    Each ``KernelSpec`` binds a concrete callable to its algorithm name, target
    platform, and hardware backend. The ``priority`` field governs selection
    order when several implementations match a lookup query; higher values
    are preferred.

    Attributes:
        platform: The implementation platform (e.g. ``Platform.TRITON``).
        backend: Target hardware backend (e.g. ``Backend.GPU``).
        algorithm: Algorithm name as passed at registration
            (e.g. ``'flash_attention'``).
        implementation: The wrapped kernel callable.
        priority: Selection priority; higher values are preferred during
            lookup. Defaults to ``0``.
    """

    platform: Platform
    backend: Backend
    algorithm: str
    implementation: Callable
    priority: int = 0
class KernelRegistry:
    """Central registry for managing kernel implementations across platforms and backends.

    ``KernelRegistry`` is the backbone of ejkernel's multi-platform dispatch
    system. It stores ``KernelSpec`` objects keyed by lower-cased algorithm
    name. It provides decorator-based registration, priority-aware lookup,
    and cross-implementation signature validation.

    The typical usage pattern is:

    1. **Register** implementations via the ``register`` decorator.
    2. **Look up** the best implementation for a given algorithm / platform /
       backend combination via ``get``.
    3. (Optional) **Validate** that all implementations of an algorithm share
       a compatible signature via ``validate_signatures``.

    Attributes:
        _registry: Internal mapping from lower-cased algorithm names to lists
            of ``KernelSpec`` objects sorted by descending priority.

    Example:
        >>> registry = KernelRegistry()
        >>> @registry.register("flash_attention", Platform.TRITON, Backend.GPU)
        ... def flash_attention_triton(q, k, v): ...
        >>>
        >>> impl = registry.get("flash_attention", platform="triton", backend="gpu")
    """

    def __init__(self) -> None:
        """Initialize an empty kernel registry with no registered algorithms."""
        self._registry: dict[str, list[KernelSpec]] = {}

    def register(
        self,
        algorithm: str,
        platform: Platform | Literal["triton", "pallas", "cuda", "cute", "xla"],
        backend: Backend | Literal["gpu", "tpu", "cpu", "any"],
        priority: int = 0,
    ) -> Callable[[F], F]:
        """Decorator to register a kernel implementation.

        Wraps *func* in a validation layer. Before dispatch it checks for
        unsupported parameters, and it converts certain runtime exceptions
        into ``EjkernelRuntimeError``. The wrapped function is stored in a
        ``KernelSpec`` and added to the internal registry under *algorithm*
        (case-insensitive). The spec list is kept sorted by descending
        priority; ties keep registration order.

        Nothing is added to the registry if *platform* or *backend* cannot be
        converted to its enum.

        Args:
            algorithm: Name of the algorithm (e.g. ``'flash_attention'``).
            platform: Implementation platform. Accepts a ``Platform`` enum
                member or its string value.
            backend: Target hardware backend. Accepts a ``Backend`` enum
                member or its string value.
            priority: Selection priority (default: ``0``). Higher values are
                preferred during lookup.

        Returns:
            A decorator that registers the kernel and returns the wrapped
            callable (preserving the original signature).

        Example:
            >>> @registry.register("flash_attention", Platform.TRITON, Backend.GPU, priority=10)
            ... def flash_attention_impl(q, k, v):
            ...     return compute_attention(q, k, v)
        """

        def decorator(func: F) -> F:
            """Wrap *func* with validation and add it to the registry."""
            key = algorithm.lower()
            # Convert before touching the registry so a bad platform/backend
            # string does not leave an empty entry behind.
            plat = Platform(platform) if isinstance(platform, str) else platform
            back = Backend(backend) if isinstance(backend, str) else backend

            @functools.wraps(func)
            def _wrapped(*args, **kwargs):
                # Validation wrapper (its docstring would be replaced by
                # ``functools.wraps``). It rejects unsupported parameter usage
                # up front and re-raises "not supported" style errors from the
                # kernel as ``EjkernelRuntimeError``.
                reasons = _collect_unsupported_reasons(func, args, kwargs)
                if reasons:
                    raise EjkernelRuntimeError(f"{algorithm} (platform={plat.value}): " + "; ".join(reasons))
                try:
                    return func(*args, **kwargs)
                except EjkernelRuntimeError:
                    raise
                except (NotImplementedError, ValueError) as exc:
                    msg = str(exc)
                    lowered = msg.lower()
                    if "not supported" in lowered or "unsupported" in lowered:
                        raise EjkernelRuntimeError(f"{algorithm} (platform={plat.value}): {msg}") from exc
                    raise

            # Pin the exposed signature explicitly rather than relying on
            # ``inspect.signature`` following ``__wrapped__``.
            _wrapped.__signature__ = inspect.signature(func)

            spec = KernelSpec(
                platform=plat,
                backend=back,
                algorithm=algorithm,
                implementation=_wrapped,
                priority=priority,
            )
            specs = self._registry.setdefault(key, [])
            specs.append(spec)
            # ``list.sort`` is stable, so equal priorities keep registration order.
            specs.sort(key=lambda entry: entry.priority, reverse=True)
            return _wrapped  # type: ignore[return-value]

        return decorator

    @staticmethod
    def _lookup_backends(platform: Platform | None, backend: Backend | None) -> Iterable[Backend | None]:
        """Yield the backend filters ``get`` tries, in order, each at most once.

        The sequence is built in three steps:

        1. The requested *backend*.
        2. ``Backend.ANY`` when *platform* is ``Platform.XLA``.
        3. Once ``Backend.ANY`` has been tried (requested or via the XLA
           fallback), the backend reported by ``jax.default_backend()``.

        The default backend is resolved lazily, so JAX is only queried when
        earlier attempts failed. Values that are not ``Backend`` members are
        skipped. Deduplication guarantees the sequence is finite.
        """
        tried: set[Backend | None] = {backend}
        yield backend
        if platform == Platform.XLA and Backend.ANY not in tried:
            tried.add(Backend.ANY)
            yield Backend.ANY
        if Backend.ANY in tried:
            try:
                default_backend = Backend(jax.default_backend())
            except ValueError:
                return
            if default_backend not in tried:
                yield default_backend

    def get(
        self,
        algorithm: str,
        platform: Platform | Literal["triton", "pallas", "cuda", "cute", "xla", "auto"] | None = None,
        backend: Backend | Literal["gpu", "tpu", "cpu", "any"] | None = None,
    ) -> Callable:
        """Retrieve the best matching kernel implementation.

        Searches for implementations matching the specified algorithm,
        platform, and backend, and returns the highest-priority match.

        Matching rules:

        - If *platform* is ``None``, any platform matches.
        - If *backend* is ``None``, any backend matches.
        - ``Backend.ANY`` implementations match every backend query.

        If nothing matches, these fallbacks are tried in order, each backend
        at most once:

        1. ``Backend.ANY``, when *platform* is ``Platform.XLA``.
        2. ``jax.default_backend()``, once ``Backend.ANY`` has been tried.

        Args:
            algorithm: Algorithm name to look up (case-insensitive).
            platform: Optional platform filter. Accepts a ``Platform`` enum
                member, its string value, or ``"auto"``.
            backend: Optional backend filter. Accepts a ``Backend`` enum
                member or its string value.

        Returns:
            The matching kernel implementation callable.

        Raises:
            ValueError: If the algorithm is unknown, a platform/backend string
                is invalid, or no implementation matches after all fallbacks.

        Example:
            >>> impl = registry.get("flash_attention", platform="triton", backend="gpu")
            >>> result = impl(q, k, v)
        """
        key = algorithm.lower()
        if key not in self._registry:
            raise ValueError(f"No implementation found for algorithm: {algorithm}")
        candidates = self._registry[key]

        if isinstance(platform, str):
            # NOTE(review): "auto" is forwarded to ``Platform`` unchanged; confirm
            # that ``Platform`` defines it or that callers resolve it first.
            platform = Platform(platform)
        if isinstance(backend, str):
            backend = Backend(backend)

        for attempt in self._lookup_backends(platform, backend):
            # Candidates are sorted by descending priority, so the first hit wins.
            for spec in candidates:
                if platform is not None and spec.platform != platform:
                    continue
                if attempt is not None and spec.backend not in (attempt, Backend.ANY):
                    continue
                return spec.implementation

        raise ValueError(f"No implementation found for algorithm={algorithm}, platform={platform}, backend={backend}")

    def list_algorithms(self) -> list[str]:
        """List all registered algorithm names.

        Returns:
            A sorted list of the lower-cased algorithm names currently present
            in the registry.
        """
        return sorted(self._registry.keys())

    def list_implementations(self, algorithm: str) -> list[KernelSpec]:
        """List all registered implementations for a given algorithm.

        Args:
            algorithm: Algorithm name to query (case-insensitive).

        Returns:
            A shallow copy of the ``KernelSpec`` list for the algorithm, sorted
            by priority in descending order. Returns an empty list if the
            algorithm has not been registered.
        """
        key = algorithm.lower()
        return self._registry.get(key, []).copy()

    def validate_signatures(self, algorithm: str | None, verbose: bool = False) -> bool:
        """Validate that all implementations of an algorithm have compatible signatures.

        Uses the first registered implementation (highest priority) as the
        reference and compares every other implementation against it. The
        following properties are checked for each parameter:

        - **Name** -- parameter names must match in order.
        - **Kind** -- positional-only, keyword-only, etc. must agree.
        - **Default value** -- default values must be equal.
        - **Type annotation** -- annotations are compared after normalization
          via ``_types_are_equivalent``.

        Any mismatch emits a ``UserWarning`` describing the discrepancy.

        When *algorithm* is ``None``, **all** registered algorithms are
        validated. Every algorithm is checked (no short-circuit), and
        *verbose* is forwarded to each check.

        Args:
            algorithm: Algorithm name to validate (case-insensitive), or
                ``None`` to validate every registered algorithm.
            verbose: If ``True``, print detailed parameter information for
                every implementation before running comparisons.

        Returns:
            ``True`` if all checked signatures match, ``False`` otherwise.

        Raises:
            ValueError: If the specified algorithm has not been registered.
        """
        if algorithm is None:
            # Build the full list first so every algorithm emits its warnings.
            results = [self.validate_signatures(name, verbose=verbose) for name in self.list_algorithms()]
            return all(results)

        key = algorithm.lower()
        if key not in self._registry:
            raise ValueError(f"No implementation found for algorithm: {algorithm}")
        specs = self._registry[key]
        if len(specs) < 2:
            return True

        reference_spec = specs[0]
        reference_params = list(inspect.signature(reference_spec.implementation).parameters.values())

        if verbose:
            print(f"\n{'=' * 80}")
            print(f"Algorithm: {algorithm}")
            print(f"{'=' * 80}")
            for spec in specs:
                sig = inspect.signature(spec.implementation)
                print(f"\n{spec.platform}/{spec.backend} (priority={spec.priority}):")
                print(f" Signature: {sig}")
                for param_name, param in sig.parameters.items():
                    print(f" {param_name}:")
                    print(f" kind: {param.kind.name}")
                    print(f" default: {param.default}")
                    print(f" annotation: {param.annotation}")
            print(f"{'=' * 80}\n")

        all_match = True
        for spec in specs[1:]:
            params = list(inspect.signature(spec.implementation).parameters.values())
            if len(params) != len(reference_params):
                warnings.warn(
                    f"Signature mismatch for algorithm '{algorithm}':\n"
                    f" Reference ({reference_spec.platform}/{reference_spec.backend}): "
                    f"{len(reference_params)} parameters\n"
                    f" Implementation ({spec.platform}/{spec.backend}): {len(params)} parameters",
                    UserWarning,
                    stacklevel=2,
                )
                all_match = False
                continue
            # Lengths are equal here, so a strict zip cannot fail.
            for ref_param, param in zip(reference_params, params, strict=True):
                if ref_param.name != param.name:
                    warnings.warn(
                        f"Signature mismatch for algorithm '{algorithm}':\n"
                        f" Reference ({reference_spec.platform}/{reference_spec.backend}): "
                        f"parameter '{ref_param.name}'\n"
                        f" Implementation ({spec.platform}/{spec.backend}): parameter '{param.name}'",
                        UserWarning,
                        stacklevel=2,
                    )
                    all_match = False
                if ref_param.kind != param.kind:
                    warnings.warn(
                        f"Signature mismatch for algorithm '{algorithm}' parameter '{ref_param.name}':\n"
                        f" Reference ({reference_spec.platform}/{reference_spec.backend}): {ref_param.kind.name}\n"
                        f" Implementation ({spec.platform}/{spec.backend}): {param.kind.name}",
                        UserWarning,
                        stacklevel=2,
                    )
                    all_match = False
                if ref_param.default != param.default:
                    warnings.warn(
                        f"Signature mismatch for algorithm '{algorithm}' parameter '{ref_param.name}':\n"
                        f" Reference ({reference_spec.platform}/{reference_spec.backend}): "
                        f"default={ref_param.default}\n"
                        f" Implementation ({spec.platform}/{spec.backend}): default={param.default}",
                        UserWarning,
                        stacklevel=2,
                    )
                    all_match = False
                if not _types_are_equivalent(ref_param.annotation, param.annotation):
                    warnings.warn(
                        f"Signature mismatch for algorithm '{algorithm}' parameter '{ref_param.name}':\n"
                        f" Reference ({reference_spec.platform}/{reference_spec.backend}): "
                        f"type={ref_param.annotation} = {ref_param.default}\n"
                        f" Implementation ({spec.platform}/{spec.backend}): type={param.annotation} = {param.default}",
                        UserWarning,
                        stacklevel=2,
                    )
                    all_match = False
        return all_match
#: Global singleton ``KernelRegistry`` instance.
#:
#: Built-in kernel implementations register themselves on this object at
#: import time. User code should normally use it instead of constructing a
#: separate ``KernelRegistry``.
kernel_registry = KernelRegistry()