ejkernel.kernels._registry#

Kernel registry system for managing multi-platform implementations.

This module provides the core registry infrastructure for ejkernel’s multi-platform kernel dispatch system. It enables registration and lookup of kernel implementations across different platforms (Triton, Pallas, CUDA, CuTe, XLA) and backends (GPU, TPU, CPU).

Key Components:
  • Platform: Enum for implementation platforms (TRITON, PALLAS, CUDA, CUTE, XLA)

  • Backend: Enum for hardware backends (GPU, TPU, CPU, ANY)

  • KernelSpec: Dataclass describing a registered kernel implementation

  • KernelRegistry: Registry class for managing kernel implementations

  • kernel_registry: Global singleton registry instance

Usage:
Registration:

    @kernel_registry.register("flash_attention", Platform.TRITON, Backend.GPU)
    def flash_attention_triton(q, k, v):
        ...

Lookup:

    impl = kernel_registry.get("flash_attention", platform="triton", backend="gpu")
    result = impl(q, k, v)

Priority System:

Multiple implementations of the same algorithm can coexist with different priorities. When looking up a kernel, the highest priority match is returned. This enables optimized implementations to take precedence over fallbacks.
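
As a rough illustration of the priority rule, the standalone sketch below sorts candidates by descending priority and returns the first one. The `Spec` tuple and the `pick_best` helper are hypothetical stand-ins, not ejkernel's internals, and the tie-breaking behavior shown is an assumption of this sketch only.

```python
from typing import Callable, NamedTuple


class Spec(NamedTuple):
    """Hypothetical stand-in for a registered implementation."""
    name: str
    priority: int
    implementation: Callable


def pick_best(specs: list[Spec]) -> Spec:
    """Return the highest-priority spec (assumed selection rule)."""
    if not specs:
        raise ValueError("no implementation registered")
    # sorted() is stable, so in this sketch ties keep registration order.
    return sorted(specs, key=lambda s: s.priority, reverse=True)[0]


specs = [
    Spec("xla_fallback", 0, lambda x: x),
    Spec("triton_optimized", 10, lambda x: x * 2),
]
best = pick_best(specs)  # the priority-10 entry wins over the fallback
```

Registering the fallback with a low priority and the optimized kernel with a higher one means callers never need to name the optimized variant explicitly.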

Signature Validation:

The validate_signatures() method checks that all implementations of an algorithm have compatible signatures and warns on mismatches, catching registration errors early.

Example

>>> from ejkernel.kernels._registry import kernel_registry, Platform, Backend
>>>
>>> # List all registered algorithms
>>> algorithms = kernel_registry.list_algorithms()
>>>
>>> # Get implementations for an algorithm
>>> impls = kernel_registry.list_implementations("flash_attention")
>>>
>>> # Get the highest-priority XLA implementation
>>> impl = kernel_registry.get("flash_attention", platform=Platform.XLA)
class ejkernel.kernels._registry.Backend(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: StrEnum

Target hardware backends for kernel execution.

Used to tag kernel implementations with the hardware they target. During lookup, Backend.ANY acts as a wildcard that matches every backend query, making it suitable for hardware-agnostic implementations.

GPU#

NVIDIA (or compatible) GPU backend.

TPU#

Google TPU backend.

CPU#

CPU-only backend.

ANY#

Wildcard backend matching any hardware target.

ANY = 'any'#
CPU = 'cpu'#
GPU = 'gpu'#
TPU = 'tpu'#
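
The wildcard behavior of Backend.ANY can be sketched as follows. This is a minimal, self-contained approximation: the enum is re-declared locally with a `str` mixin (the documented class derives from StrEnum), and `backend_matches` is a hypothetical helper that does not exist in ejkernel.

```python
from enum import Enum


class Backend(str, Enum):
    """Local stand-in for the documented enum (which derives from StrEnum)."""
    GPU = "gpu"
    TPU = "tpu"
    CPU = "cpu"
    ANY = "any"


def backend_matches(registered: Backend, query: "Backend | str | None") -> bool:
    """Hypothetical helper: does a registered backend satisfy a query?"""
    if query is None:
        return True  # no backend filter requested
    query = Backend(query)  # accepts an enum member or its string value
    # A Backend.ANY registration satisfies every backend query.
    return registered is Backend.ANY or registered is query
```

Because the members are string-valued, `Backend("gpu")` and `Backend.GPU` refer to the same member, which is why string literals and enum members can be used interchangeably in queries.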
ejkernel.kernels._registry.F#

TypeVar bound to Callable, used for preserving function signatures in decorators.

alias of TypeVar('F', bound=Callable)

class ejkernel.kernels._registry.KernelRegistry[source]#

Bases: object

Central registry for managing kernel implementations across platforms and backends.

KernelRegistry is the backbone of ejkernel’s multi-platform dispatch system. It stores KernelSpec objects keyed by algorithm name and provides decorator-based registration, priority-aware lookup, and cross-implementation signature validation.

The typical usage pattern is:

  1. Register implementations via the register decorator.

  2. Look up the best implementation for a given algorithm / platform / backend combination via get.

  3. (Optional) Validate that all implementations of an algorithm share a compatible signature via validate_signatures.

_registry#

Internal mapping from lower-cased algorithm names to lists of KernelSpec objects sorted by descending priority.

Example

>>> registry = KernelRegistry()
>>> @registry.register("flash_attention", Platform.TRITON, Backend.GPU)
... def flash_attention_triton(q, k, v): ...
>>>
>>> impl = registry.get("flash_attention", platform="triton", backend="gpu")
get(algorithm: str, platform: Optional[Union[Platform, Literal['triton', 'pallas', 'cuda', 'cute', 'xla', 'auto']]] = None, backend: Optional[Union[Backend, Literal['gpu', 'tpu', 'cpu', 'any']]] = None) → Callable[source]#

Retrieve the best matching kernel implementation.

Searches for implementations matching the specified algorithm, platform, and backend. Returns the highest-priority match among all candidates.

Matching rules:
  • If platform is None, any platform matches.

  • If backend is None, any backend matches.

  • Backend.ANY implementations match every backend query.

  • When platform is Platform.XLA and no direct match is found, a fallback lookup with Backend.ANY is attempted.

  • When backend is Backend.ANY and no match is found, a fallback lookup using jax.default_backend() is attempted.
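
The matching rules above can be approximated with the following standalone sketch. It is not ejkernel's implementation: `Spec` and `lookup` are hypothetical, the enums are trimmed to a few members, and `default_backend` is a plain argument standing in for the value of jax.default_backend(). The Platform.XLA fallback and the "auto" platform are left out because their exact interaction with the wildcard rule is not described on this page.

```python
from enum import Enum
from typing import Callable, NamedTuple, Optional


class Platform(str, Enum):
    TRITON = "triton"
    XLA = "xla"


class Backend(str, Enum):
    GPU = "gpu"
    CPU = "cpu"
    ANY = "any"


class Spec(NamedTuple):
    platform: Platform
    backend: Backend
    priority: int
    implementation: Callable


def lookup(specs, platform=None, backend=None, default_backend="cpu"):
    """Sketch of the None, wildcard, and default-backend rules."""
    def search(p: Optional[Platform], b: Optional[Backend]):
        hits = [
            s for s in specs
            if (p is None or s.platform is p)
            and (b is None or s.backend is Backend.ANY or s.backend is b)
        ]
        # Highest priority wins; None signals "no candidate".
        return max(hits, key=lambda s: s.priority, default=None)

    p = None if platform is None else Platform(platform)
    b = None if backend is None else Backend(backend)
    hit = search(p, b)
    if hit is None and b is Backend.ANY:
        # Retry with the (assumed) default backend of the runtime.
        hit = search(p, Backend(default_backend))
    if hit is None:
        raise ValueError("no matching implementation")
    return hit.implementation


specs = [
    Spec(Platform.TRITON, Backend.GPU, 10, lambda: "triton-gpu"),
    Spec(Platform.XLA, Backend.ANY, 0, lambda: "xla-any"),
]
impl = lookup(specs, platform="xla", backend="gpu")  # served by the ANY entry
```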

Parameters
  • algorithm – Algorithm name to look up (case-insensitive).

  • platform – Optional platform filter. Accepts a Platform enum member, its string value, or "auto".

  • backend – Optional backend filter. Accepts a Backend enum member or its string value.

Returns

The matching kernel implementation callable.

Raises

ValueError – If no matching implementation is found after all fallback attempts.

Example

>>> impl = registry.get("flash_attention", platform="triton", backend="gpu")
>>> result = impl(q, k, v)
list_algorithms() → list[str][source]#

List all registered algorithm names.

Returns

A sorted list of lower-cased algorithm name strings currently present in the registry.

list_implementations(algorithm: str) → list[ejkernel.kernels._registry.KernelSpec][source]#

List all registered implementations for a given algorithm.

Parameters

algorithm – Algorithm name to query (case-insensitive).

Returns

A shallow copy of the KernelSpec list for the algorithm, sorted by priority in descending order. Returns an empty list if the algorithm has not been registered.

register(algorithm: str, platform: Union[Platform, Literal['triton', 'pallas', 'cuda', 'cute', 'xla']], backend: Union[Backend, Literal['gpu', 'tpu', 'cpu', 'any']], priority: int = 0) → Callable[[F], F][source]#

Decorator to register a kernel implementation.

Wraps the decorated function in a validation layer that checks for unsupported parameters before dispatch and converts certain runtime exceptions into EjkernelRuntimeError. The wrapped function is stored in a KernelSpec and appended to the internal registry under algorithm (case-insensitive).

Parameters
  • algorithm – Name of the algorithm (e.g., 'flash_attention').

  • platform – Implementation platform. Accepts a Platform enum member or its string value.

  • backend – Target hardware backend. Accepts a Backend enum member or its string value.

  • priority – Selection priority (default: 0). Higher values are preferred during lookup.

Returns

A decorator that registers the kernel and returns the wrapped callable (preserving the original signature).

Example

>>> @registry.register("flash_attention", Platform.TRITON, Backend.GPU, priority=10)
... def flash_attention_impl(q, k, v):
...     return compute_attention(q, k, v)
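
The decorator pattern described for register can be sketched with a toy registry. Everything here is illustrative: `MiniRegistry` and `KernelRuntimeError` are hypothetical stand-ins (the latter for EjkernelRuntimeError), the platform and backend arguments are omitted, the unsupported-parameter check is not modeled, and converting only `RuntimeError` is an assumption of this sketch.

```python
import functools
from typing import Callable, TypeVar

F = TypeVar("F", bound=Callable)


class KernelRuntimeError(RuntimeError):
    """Hypothetical stand-in for ejkernel's EjkernelRuntimeError."""


class MiniRegistry:
    """Toy registry illustrating decorator-based registration."""

    def __init__(self) -> None:
        # lower-cased algorithm name -> list of (priority, callable)
        self._registry: dict[str, list[tuple[int, Callable]]] = {}

    def register(self, algorithm: str, priority: int = 0) -> Callable[[F], F]:
        def decorator(func: F) -> F:
            @functools.wraps(func)  # keep the original name and signature
            def wrapper(*args, **kwargs):
                try:
                    return func(*args, **kwargs)
                except RuntimeError as exc:  # assumed set of converted errors
                    raise KernelRuntimeError(str(exc)) from exc

            entries = self._registry.setdefault(algorithm.lower(), [])
            entries.append((priority, wrapper))
            # Keep entries sorted by descending priority.
            entries.sort(key=lambda e: e[0], reverse=True)
            return wrapper  # type: ignore[return-value]
        return decorator


registry = MiniRegistry()


@registry.register("Safe_Div", priority=10)
def safe_div(a, b):
    if b == 0:
        raise RuntimeError("division by zero")
    return a / b
```

Returning the wrapper rather than the original function means that direct calls to `safe_div` go through the same error conversion as calls obtained from the registry.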
validate_signatures(algorithm: str | None, verbose: bool = False) → bool[source]#

Validate that all implementations of an algorithm have compatible signatures.

Uses the first implementation in the priority-sorted list (the highest-priority one) as the reference and compares every other implementation against it. The following properties are checked for each parameter:

  • Name – parameter names must match in order.

  • Kind – positional-only, keyword-only, etc. must agree.

  • Default value – default values must be equal.

  • Type annotation – annotations are compared after normalization via _types_are_equivalent.

Any mismatch emits a UserWarning describing the discrepancy.

When algorithm is None, all registered algorithms are validated in turn (the return value is None in this case).

Parameters
  • algorithm – Algorithm name to validate (case-insensitive), or None to validate every registered algorithm.

  • verbose – If True, print detailed parameter information for every implementation before running comparisons.

Returns

True if all signatures match, False otherwise. Returns None implicitly when algorithm is None.

Raises

ValueError – If the specified algorithm has not been registered.
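
The per-parameter checks listed above can be approximated with the standard inspect module. This is a simplified sketch, not the library's code: `signatures_match` is a hypothetical helper, and annotations are compared with plain equality rather than the normalization performed by _types_are_equivalent.

```python
import inspect
import warnings
from typing import Callable


def signatures_match(reference: Callable, other: Callable) -> bool:
    """Compare two callables parameter by parameter, warning on mismatches."""
    ref = list(inspect.signature(reference).parameters.values())
    oth = list(inspect.signature(other).parameters.values())
    if len(ref) != len(oth):
        warnings.warn(f"parameter count differs: {len(ref)} vs {len(oth)}")
        return False
    ok = True
    for r, o in zip(ref, oth):
        # Name, kind, default value, and annotation must all agree.
        for field in ("name", "kind", "default", "annotation"):
            if getattr(r, field) != getattr(o, field):
                warnings.warn(
                    f"{field} mismatch for parameter {r.name!r}: "
                    f"{getattr(r, field)!r} vs {getattr(o, field)!r}"
                )
                ok = False
    return ok
```

`warnings.warn` defaults to the UserWarning category, which matches the warning type mentioned in the description above.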

class ejkernel.kernels._registry.KernelSpec(platform: Platform, backend: Backend, algorithm: str, implementation: Callable, priority: int = 0)[source]#

Bases: object

Immutable specification describing a single registered kernel implementation.

Each KernelSpec binds a concrete callable to its algorithm name, target platform, and hardware backend. The priority field governs selection order when multiple implementations match a lookup query – higher values are preferred.

platform#

The implementation platform (e.g., Platform.TRITON).

Type

ejkernel.kernels._registry.Platform

backend#

Target hardware backend (e.g., Backend.GPU).

Type

ejkernel.kernels._registry.Backend

algorithm#

Canonical algorithm name (e.g., 'flash_attention').

Type

str

implementation#

The wrapped kernel callable.

Type

collections.abc.Callable

priority#

Selection priority; higher values are preferred during lookup. Defaults to 0.

Type

int

algorithm: str#
backend: Backend#
implementation: Callable#
platform: Platform#
priority: int = 0#
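
A dataclass with the documented fields might look like the sketch below. It assumes the immutability is implemented with `frozen=True`, which this page does not state; `Spec` is a hypothetical name, and platform and backend are plain strings here only to keep the example self-contained.

```python
from dataclasses import FrozenInstanceError, dataclass
from typing import Callable


@dataclass(frozen=True)
class Spec:
    """Simplified mirror of the documented KernelSpec fields."""
    platform: str   # the real field is a Platform member
    backend: str    # the real field is a Backend member
    algorithm: str
    implementation: Callable
    priority: int = 0


spec = Spec("triton", "gpu", "flash_attention", implementation=lambda q, k, v: q)
```

With a frozen dataclass, assigning to a field such as `spec.priority` after construction raises `FrozenInstanceError`, so registry entries cannot be mutated by accident.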
class ejkernel.kernels._registry.Platform(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: StrEnum

Supported kernel implementation platforms.

Each member identifies a compilation/execution framework used to implement a kernel.

TRITON#

OpenAI Triton GPU kernels.

PALLAS#

JAX Pallas kernels (Pallas supports both GPU and TPU).

CUDA#

Native CUDA C/C++ kernels compiled ahead-of-time.

CUTE#

CUTLASS CuTe DSL kernels.

XLA#

XLA HLO-based implementations using JAX primitives.

CUDA = 'cuda'#
CUTE = 'cute'#
PALLAS = 'pallas'#
TRITON = 'triton'#
XLA = 'xla'#
ejkernel.kernels._registry.kernel_registry = <ejkernel.kernels._registry.KernelRegistry object>#

Global singleton KernelRegistry instance.

All built-in kernel implementations register themselves against this instance at import time. User code should typically interact with this object rather than creating a new KernelRegistry.