ejkernel.kernels._registry
Kernel registry system for managing multi-platform implementations.
This module provides the core registry infrastructure for ejkernel’s multi-platform kernel dispatch system. It enables registration and lookup of kernel implementations across different platforms (Triton, Pallas, CUDA, CuTe, XLA) and backends (GPU, TPU, CPU).
- Key Components:
Platform: Enum for implementation platforms (TRITON, PALLAS, CUDA, CUTE, XLA)
Backend: Enum for hardware backends (GPU, TPU, CPU, ANY)
KernelSpec: Dataclass describing a registered kernel implementation
KernelRegistry: Registry class for managing kernel implementations
kernel_registry: Global singleton registry instance
- Usage:
- Registration:
@kernel_registry.register("flash_attention", Platform.TRITON, Backend.GPU)
def flash_attention_triton(q, k, v):
    ...
- Lookup:
impl = kernel_registry.get("flash_attention", platform="triton", backend="gpu")
result = impl(q, k, v)
- Priority System:
Multiple implementations of the same algorithm can coexist with different priorities. When looking up a kernel, the highest priority match is returned. This enables optimized implementations to take precedence over fallbacks.
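A minimal sketch of the selection step, assuming each candidate is a plain object carrying a numeric priority. Candidate and pick_best are hypothetical names, not part of ejkernel, and the tie-breaking rule (earliest entry wins) is an assumption that the description above does not state.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass(frozen=True)
class Candidate:
    # Hypothetical stand-in for a registered implementation.
    name: str
    priority: int
    fn: Callable


def pick_best(candidates: List[Candidate]) -> Candidate:
    """Return the highest-priority candidate.

    max() returns the first maximal element it meets, so ties go to the
    earliest entry in `candidates` (an assumption about tie-breaking).
    """
    if not candidates:
        raise ValueError("no matching implementation")
    return max(candidates, key=lambda c: c.priority)


fallback = Candidate("xla_reference", priority=0, fn=lambda x: x)
optimized = Candidate("triton_fused", priority=10, fn=lambda x: x)

best = pick_best([fallback, optimized])
print(best.name)  # triton_fused
```

Keeping the fallback registered costs nothing: it is still selected whenever the optimized candidate is filtered out by platform or backend.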
- Signature Validation:
The validate_signatures() method ensures all implementations of an algorithm have compatible signatures, catching registration errors early.
Example
>>> from ejkernel.kernels._registry import kernel_registry, Platform, Backend
>>>
>>> # List all registered algorithms
>>> algorithms = kernel_registry.list_algorithms()
>>>
>>> # Get implementations for an algorithm
>>> impls = kernel_registry.list_implementations("flash_attention")
>>>
>>> # Get best implementation for current platform
>>> impl = kernel_registry.get("flash_attention", platform=Platform.XLA)
- class ejkernel.kernels._registry.Backend(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: StrEnum
Target hardware backends for kernel execution.
Used to tag kernel implementations with the hardware they target. During lookup, Backend.ANY acts as a wildcard that matches every backend query, making it suitable for hardware-agnostic implementations.
- GPU
NVIDIA (or compatible) GPU backend.
- TPU
Google TPU backend.
- CPU
CPU-only backend.
- ANY
Wildcard backend matching any hardware target.
- ANY = 'any'
- CPU = 'cpu'
- GPU = 'gpu'
- TPU = 'tpu'
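The following sketch mirrors the documented Backend members to show two StrEnum properties: members compare equal to their string values, and Backend("gpu") maps a string back to its member. backend_matches is a hypothetical helper illustrating the ANY wildcard, not ejkernel's actual matching code. On Python versions before 3.11, where enum.StrEnum does not exist, a str-mixin enum stands in.

```python
from enum import Enum
from typing import Optional, Union

try:
    from enum import StrEnum  # Python 3.11+
except ImportError:  # older Pythons: close str-mixin stand-in
    class StrEnum(str, Enum):
        pass


class Backend(StrEnum):
    # Same members and values as the documented Backend enum.
    GPU = "gpu"
    TPU = "tpu"
    CPU = "cpu"
    ANY = "any"


def backend_matches(spec_backend: Backend, query: Optional[Union[Backend, str]]) -> bool:
    """Hypothetical helper: does an implementation tagged `spec_backend` satisfy `query`?"""
    if query is None:  # no backend filter: everything matches
        return True
    if spec_backend is Backend.ANY:  # wildcard implementations match every query
        return True
    return spec_backend is Backend(query)  # accepts a member or its string value


print(Backend("tpu") is Backend.TPU)  # True
print(Backend.GPU == "gpu")  # True
print(backend_matches(Backend.ANY, "tpu"))  # True
print(backend_matches(Backend.GPU, "tpu"))  # False
```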
- ejkernel.kernels._registry.F
TypeVar bound to Callable, used for preserving function signatures in decorators.
alias of TypeVar('F', bound=Callable)
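A sketch of how a TypeVar such as F lets a decorator keep the decorated function's type. passthrough is a hypothetical decorator. functools.wraps also sets __wrapped__, which is why inspect.signature still reports the original parameters at runtime.

```python
import functools
import inspect
from typing import Any, Callable, TypeVar, cast

# Same idea as the documented F; Callable[..., Any] spells out the bare Callable bound.
F = TypeVar("F", bound=Callable[..., Any])


def passthrough(func: F) -> F:
    """Hypothetical decorator: adds a wrapper but keeps func's type for checkers."""

    @functools.wraps(func)  # copies __name__ and __doc__, and sets __wrapped__
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return func(*args, **kwargs)

    # cast tells type checkers the wrapper has the same type as func.
    return cast(F, wrapper)


def add(x: int, y: int = 1) -> int:
    return x + y


wrapped = passthrough(add)
print(wrapped.__name__)            # add
print(inspect.signature(wrapped))  # (x: int, y: int = 1) -> int
print(wrapped(2))                  # 3
```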
- class ejkernel.kernels._registry.KernelRegistry
Bases: object
Central registry for managing kernel implementations across platforms and backends.
KernelRegistry is the backbone of ejkernel’s multi-platform dispatch system. It stores KernelSpec objects keyed by algorithm name and provides decorator-based registration, priority-aware lookup, and cross-implementation signature validation.
The typical usage pattern is:
1. Register implementations via the register decorator.
2. Look up the best implementation for a given algorithm / platform / backend combination via get.
3. (Optional) Validate that all implementations of an algorithm share a compatible signature via validate_signatures.
- _registry
Internal mapping from lower-cased algorithm names to lists of KernelSpec objects sorted by descending priority.
Example
>>> registry = KernelRegistry()
>>> @registry.register("flash_attention", Platform.TRITON, Backend.GPU)
... def flash_attention_triton(q, k, v): ...
>>>
>>> impl = registry.get("flash_attention", platform="triton", backend="gpu")
- get(algorithm: str, platform: Optional[Union[Platform, Literal['triton', 'pallas', 'cuda', 'cute', 'xla', 'auto']]] = None, backend: Optional[Union[Backend, Literal['gpu', 'tpu', 'cpu', 'any']]] = None) → Callable
Retrieve the best matching kernel implementation.
Searches for implementations matching the specified algorithm, platform, and backend. Returns the highest-priority match among all candidates.
- Matching rules:
If platform is None, any platform matches.
If backend is None, any backend matches.
Backend.ANY implementations match every backend query.
When platform is Platform.XLA and no direct match is found, a fallback lookup with Backend.ANY is attempted.
When backend is Backend.ANY and no match is found, a fallback lookup using jax.default_backend() is attempted.
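The matching rules can be sketched as a linear scan over priority-sorted specs. Everything here is hypothetical: Spec, lookup, and the default_backend parameter, which stands in for jax.default_backend() so the example runs without JAX. The sketch does not model the "auto" platform value or the Platform.XLA fallback. Because Backend.ANY implementations already match every backend query during the direct search, that fallback would never change the result of this sketch.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional, Sequence


class Platform(str, Enum):
    TRITON = "triton"
    PALLAS = "pallas"
    CUDA = "cuda"
    CUTE = "cute"
    XLA = "xla"


class Backend(str, Enum):
    GPU = "gpu"
    TPU = "tpu"
    CPU = "cpu"
    ANY = "any"


@dataclass(frozen=True)
class Spec:
    platform: Platform
    backend: Backend
    priority: int
    impl: Callable


def lookup(
    specs: Sequence[Spec],
    platform: Optional[str] = None,
    backend: Optional[str] = None,
    default_backend: Callable[[], str] = lambda: "cpu",
) -> Callable:
    """Sketch of the documented matching rules; `specs` must be sorted by descending priority."""

    def search(plat, back):
        for spec in specs:  # highest priority first, so the first hit wins
            if plat is not None and spec.platform is not Platform(plat):
                continue
            if back is not None and spec.backend is not Backend.ANY and spec.backend is not Backend(back):
                continue
            return spec.impl
        return None

    impl = search(platform, backend)
    if impl is None and backend is not None and Backend(backend) is Backend.ANY:
        # Retry with the process default (stands in for jax.default_backend()).
        impl = search(platform, default_backend())
    if impl is None:
        raise ValueError(f"no implementation for platform={platform!r}, backend={backend!r}")
    return impl


def triton_attn(q):
    return "triton"


def xla_attn(q):
    return "xla"


specs = [
    Spec(Platform.TRITON, Backend.GPU, 10, triton_attn),
    Spec(Platform.XLA, Backend.ANY, 0, xla_attn),
]
print(lookup(specs, "triton", "gpu") is triton_attn)  # True
print(lookup(specs, None, "tpu") is xla_attn)  # True: only the Backend.ANY spec matches TPU
print(lookup(specs, "triton", "any", default_backend=lambda: "gpu") is triton_attn)  # True, via the retry
```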
- Parameters
algorithm – Algorithm name to look up (case-insensitive).
platform – Optional platform filter. Accepts a Platform enum member, its string value, or "auto".
backend – Optional backend filter. Accepts a Backend enum member or its string value.
- Returns
The matching kernel implementation callable.
- Raises
ValueError – If no matching implementation is found after all fallback attempts.
Example
>>> impl = registry.get("flash_attention", platform="triton", backend="gpu")
>>> result = impl(q, k, v)
- list_algorithms() → list[str]
List all registered algorithm names.
- Returns
A sorted list of lower-cased algorithm name strings currently present in the registry.
- list_implementations(algorithm: str) → list[ejkernel.kernels._registry.KernelSpec]
List all registered implementations for a given algorithm.
- Parameters
algorithm – Algorithm name to query (case-insensitive).
- Returns
A shallow copy of the KernelSpec list for the algorithm, sorted by priority in descending order. Returns an empty list if the algorithm has not been registered.
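A sketch of the documented behaviour of the two listing methods: case-insensitive keys, a sorted name list, and a shallow copy that callers can modify freely. TinyRegistry is a hypothetical stand-in, not ejkernel's class, and stores (priority, name) tuples instead of KernelSpec objects.

```python
from typing import Dict, List, Tuple


class TinyRegistry:
    """Hypothetical stand-in that stores (priority, name) pairs per algorithm."""

    def __init__(self) -> None:
        # lower-cased algorithm name -> entries, highest priority first
        self._registry: Dict[str, List[Tuple[int, str]]] = {}

    def add(self, algorithm: str, name: str, priority: int = 0) -> None:
        entries = self._registry.setdefault(algorithm.lower(), [])
        entries.append((priority, name))
        # Stable sort: equal priorities keep registration order.
        entries.sort(key=lambda e: e[0], reverse=True)

    def list_algorithms(self) -> List[str]:
        return sorted(self._registry)

    def list_implementations(self, algorithm: str) -> List[Tuple[int, str]]:
        # list(...) is a shallow copy: callers may reorder or clear it without
        # affecting the registry, but the entries themselves are shared objects.
        return list(self._registry.get(algorithm.lower(), []))


reg = TinyRegistry()
reg.add("Flash_Attention", "xla_reference", priority=0)
reg.add("flash_attention", "triton_fused", priority=10)
reg.add("RMSNorm", "pallas", priority=5)

print(reg.list_algorithms())  # ['flash_attention', 'rmsnorm']
impls = reg.list_implementations("FLASH_ATTENTION")
print(impls)  # [(10, 'triton_fused'), (0, 'xla_reference')]
impls.clear()
print(len(reg.list_implementations("flash_attention")))  # 2
```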
- register(algorithm: str, platform: Union[Platform, Literal['triton', 'pallas', 'cuda', 'cute', 'xla']], backend: Union[Backend, Literal['gpu', 'tpu', 'cpu', 'any']], priority: int = 0) → Callable[[F], F]
Decorator to register a kernel implementation.
Wraps func in a validation layer that checks for unsupported parameters before dispatch and converts certain runtime exceptions into EjkernelRuntimeError. The wrapped function is stored in a KernelSpec and appended to the internal registry under algorithm (case-insensitive).
- Parameters
algorithm – Name of the algorithm (e.g., 'flash_attention').
platform – Implementation platform. Accepts a Platform enum member or its string value.
backend – Target hardware backend. Accepts a Backend enum member or its string value.
priority – Selection priority (default: 0). Higher values are preferred during lookup.
- Returns
A decorator that registers the kernel and returns the wrapped callable (preserving the original signature).
Example
>>> @registry.register("flash_attention", Platform.TRITON, Backend.GPU, priority=10)
... def flash_attention_impl(q, k, v):
...     return compute_attention(q, k, v)
- validate_signatures(algorithm: str | None, verbose: bool = False) → bool
Validate that all implementations of an algorithm have compatible signatures.
Uses the first registered implementation (highest priority) as the reference and compares every subsequent implementation against it. The following properties are checked for each parameter:
Name – parameter names must match in order.
Kind – positional-only, keyword-only, etc. must agree.
Default value – default values must be equal.
Type annotation – annotations are compared after normalization via _types_are_equivalent.
Any mismatch emits a UserWarning describing the discrepancy.
When algorithm is None, all registered algorithms are validated in turn (the return value is None in this case).
- Parameters
algorithm – Algorithm name to validate (case-insensitive), or None to validate every registered algorithm.
verbose – If True, print detailed parameter information for every implementation before running comparisons.
- Returns
True if all signatures match, False otherwise. Returns None implicitly when algorithm is None.
- Raises
ValueError – If the specified algorithm has not been registered.
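A sketch of the per-parameter comparison using the standard inspect module. signatures_compatible is hypothetical. It compares annotations with plain equality instead of the normalization performed by _types_are_equivalent, and it reports a differing parameter count as a single mismatch, which the documentation does not specify.

```python
import inspect
import warnings
from typing import Callable, Sequence

_FIELDS = ("name", "kind", "default", "annotation")


def signatures_compatible(impls: Sequence[Callable], algorithm: str = "algorithm") -> bool:
    """Compare every implementation against impls[0], the highest-priority one."""
    reference = list(inspect.signature(impls[0]).parameters.values())
    ok = True
    for impl in impls[1:]:
        params = list(inspect.signature(impl).parameters.values())
        if len(params) != len(reference):
            warnings.warn(f"{algorithm}: {impl.__name__} has {len(params)} parameters, "
                          f"expected {len(reference)}", UserWarning)
            ok = False
            continue
        for ref, par in zip(reference, params):
            for field in _FIELDS:
                expected, actual = getattr(ref, field), getattr(par, field)
                if expected != actual:
                    warnings.warn(f"{algorithm}: {impl.__name__} parameter {par.name!r} "
                                  f"{field} is {actual!r}, expected {expected!r}", UserWarning)
                    ok = False
    return ok


def ref_impl(q, k, v, causal: bool = False):
    ...


def same_impl(q, k, v, causal: bool = False):
    ...


def other_default(q, k, v, causal: bool = True):
    ...


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = signatures_compatible([ref_impl, same_impl, other_default], "flash_attention")

print(result)  # False
print(len(caught))  # 1
```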
- class ejkernel.kernels._registry.KernelSpec(platform: Platform, backend: Backend, algorithm: str, implementation: Callable, priority: int = 0)
Bases: object
Immutable specification describing a single registered kernel implementation.
Each KernelSpec binds a concrete callable to its algorithm name, target platform, and hardware backend. The priority field governs selection order when multiple implementations match a lookup query: higher values are preferred.
- platform
The implementation platform (e.g., Platform.TRITON).
- backend
Target hardware backend (e.g., Backend.GPU).
- algorithm
Canonical algorithm name (e.g., 'flash_attention').
- Type
str
- implementation
The wrapped kernel callable.
- priority
Selection priority; higher values are preferred during lookup. Defaults to 0.
- Type
int
- algorithm: str
- priority: int = 0
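KernelSpec is described as an immutable dataclass, so a frozen dataclass is a natural model for it. This sketch assumes frozen=True and uses a hypothetical KernelSpecSketch class with plain-string platform and backend fields. It shows that fields cannot be reassigned and that dataclasses.replace builds a modified copy.

```python
import dataclasses
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class KernelSpecSketch:
    """Hypothetical mirror of KernelSpec's documented fields (platform and backend as plain strings)."""

    platform: str
    backend: str
    algorithm: str
    implementation: Callable
    priority: int = 0


spec = KernelSpecSketch("triton", "gpu", "flash_attention", implementation=lambda q, k, v: q)

frozen = False
try:
    spec.priority = 5  # frozen dataclasses reject attribute assignment
except dataclasses.FrozenInstanceError:
    frozen = True

# To change a field, build a new spec instead of mutating the old one.
bumped = dataclasses.replace(spec, priority=5)
print(frozen, spec.priority, bumped.priority)  # True 0 5
```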
- class ejkernel.kernels._registry.Platform(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: StrEnum
Supported kernel implementation platforms.
Each member identifies a compilation/execution framework used to implement a kernel.
- TRITON
OpenAI Triton GPU kernels.
- PALLAS
JAX Pallas kernels (supports both GPU and TPU).
- CUDA
Native CUDA C/C++ kernels compiled ahead-of-time.
- CUTE
CUTLASS CuTe DSL kernels.
- XLA
XLA HLO-based implementations using JAX primitives.
- CUDA = 'cuda'
- CUTE = 'cute'
- PALLAS = 'pallas'
- TRITON = 'triton'
- XLA = 'xla'
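register and get accept either a Platform member or its string value. The following is a hedged sketch of that normalization, assuming it relies on ordinary enum value lookup. as_platform is a hypothetical helper, and the str-mixin fallback covers Python versions without enum.StrEnum.

```python
from enum import Enum
from typing import Union

try:
    from enum import StrEnum  # Python 3.11+
except ImportError:  # older Pythons: close str-mixin stand-in
    class StrEnum(str, Enum):
        pass


class Platform(StrEnum):
    # Same members and values as the documented Platform enum.
    TRITON = "triton"
    PALLAS = "pallas"
    CUDA = "cuda"
    CUTE = "cute"
    XLA = "xla"


def as_platform(value: Union[Platform, str]) -> Platform:
    """Hypothetical normalizer for arguments typed Union[Platform, Literal[...]]."""
    return Platform(value)  # a member passes through; a string is looked up by value


print(as_platform("xla") is Platform.XLA)  # True
print(as_platform(Platform.CUTE) is Platform.CUTE)  # True
print(Platform.TRITON == "triton")  # True
try:
    as_platform("auto")  # get() accepts "auto", but it is not a Platform member
except ValueError:
    print("'auto' is not a Platform value")
```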
- ejkernel.kernels._registry.kernel_registry = <ejkernel.kernels._registry.KernelRegistry object>
Global singleton KernelRegistry instance.
All built-in kernel implementations register themselves against this instance at import time. User code should typically interact with this object rather than creating a new KernelRegistry.