ejkernel.kernels._registry

Kernel registry system for managing multi-platform implementations.

class ejkernel.kernels._registry.Backend(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

Target hardware backends for kernel execution.

ANY = 'any'
CPU = 'cpu'
GPU = 'gpu'
TPU = 'tpu'
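Because Backend mixes in str, each member is also a plain string. This is presumably what lets get and register accept either a Backend member or one of the literal strings 'gpu', 'tpu', 'cpu', 'any'. The sketch below uses a local copy of the enum (not an import from ejkernel) to show the standard-library behavior this relies on; how the registry itself converts strings is an assumption.

```python
from enum import Enum


class Backend(str, Enum):
    """Local mirror of the documented members; not imported from ejkernel."""

    ANY = "any"
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


# Lookup by value: a plain string resolves to the matching member.
assert Backend("gpu") is Backend.GPU

# Thanks to the str mixin, members compare equal to their string values.
assert Backend.TPU == "tpu"
assert isinstance(Backend.CPU, str)

# Strings that are not member values are rejected with ValueError.
try:
    Backend("fpga")
except ValueError as exc:
    print(f"rejected: {exc}")
```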
class ejkernel.kernels._registry.KernelRegistry

Bases: object

Registry for managing kernel implementations across platforms and backends.

Supports registering multiple implementations of the same algorithm for different platforms and backends, with priority-based selection.

Example

>>> from ejkernel.kernels._registry import Backend, KernelRegistry, Platform
>>> registry = KernelRegistry()
>>> @registry.register("flash_attention", Platform.TRITON, Backend.GPU)
... def flash_attention_triton(q, k, v): ...
>>> impl = registry.get("flash_attention", platform="triton", backend="gpu")
get(algorithm: str, platform: Optional[Union[Platform, Literal['triton', 'pallas', 'cuda', 'xla', 'auto']]] = None, backend: Optional[Union[Backend, Literal['gpu', 'tpu', 'cpu', 'any']]] = None) → Callable

Retrieve the best matching kernel implementation.

Searches for implementations matching the specified algorithm, platform, and backend, and returns the highest-priority match. If backend is not specified, implementations for every backend are candidates. Implementations registered with Backend.ANY match every backend query.

Parameters
  • algorithm – Algorithm name to look up

  • platform – Optional platform filter

  • backend – Optional backend filter

Returns

The matching kernel implementation function

Raises

ValueError – If no matching implementation is found

Example

>>> impl = registry.get("flash_attention", platform="triton", backend="gpu")
>>> result = impl(q, k, v)
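The selection rule described above can be sketched as a filter followed by a maximum over priority. This is a minimal sketch, assuming the specs for one algorithm are held in a flat list; Spec and select are hypothetical names, and tie-breaking (the first registered spec wins, because max returns the first maximal element) is an assumption rather than documented ejkernel behavior.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional


class Backend(str, Enum):
    ANY = "any"
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


@dataclass
class Spec:
    """Hypothetical stand-in for KernelSpec, trimmed to what selection needs."""

    platform: str
    backend: Backend
    implementation: Callable
    priority: int = 0


def select(specs: list[Spec], platform: Optional[str] = None, backend: Optional[str] = None) -> Callable:
    """Return the highest-priority implementation matching the filters.

    A None filter matches everything; a spec registered for Backend.ANY
    matches every backend query.
    """
    wanted = Backend(backend) if backend is not None else None
    candidates = [
        s
        for s in specs
        if (platform is None or s.platform == platform)
        and (wanted is None or s.backend in (wanted, Backend.ANY))
    ]
    if not candidates:
        raise ValueError(f"no implementation for platform={platform!r}, backend={backend!r}")
    # max() keeps the first of several equal priorities (assumed tie-break).
    return max(candidates, key=lambda s: s.priority).implementation


specs = [
    Spec("xla", Backend.ANY, lambda: "xla", priority=0),
    Spec("triton", Backend.GPU, lambda: "triton", priority=10),
    Spec("pallas", Backend.TPU, lambda: "pallas", priority=5),
]

print(select(specs, backend="gpu")())  # triton outranks the ANY fallback
print(select(specs, backend="cpu")())  # only the ANY spec matches a CPU query
```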
list_algorithms() → list[str]

List all registered algorithm names.

Returns

Sorted list of algorithm names

list_implementations(algorithm: str) → list[ejkernel.kernels._registry.KernelSpec]

List all implementations for a given algorithm.

Parameters

algorithm – Algorithm name to query

Returns

List of KernelSpec objects, sorted by priority (descending)
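A minimal sketch of the two listing methods, assuming the registry keeps a dict from algorithm name to a list of specs. The Spec class, the module-level store, the second algorithm name ("ragged_attention"), and the empty-list result for an unknown algorithm are illustrative assumptions, not taken from ejkernel.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Spec:
    """Hypothetical minimal stand-in for KernelSpec."""

    algorithm: str
    platform: str
    priority: int = 0


# Assumed storage layout: algorithm name -> list of registered specs.
store: dict[str, list[Spec]] = defaultdict(list)
for spec in [
    Spec("flash_attention", "xla"),
    Spec("flash_attention", "triton", priority=10),
    Spec("ragged_attention", "pallas", priority=5),  # hypothetical algorithm name
]:
    store[spec.algorithm].append(spec)


def list_algorithms() -> list[str]:
    # Sorted names, matching the documented return value.
    return sorted(store)


def list_implementations(algorithm: str) -> list[Spec]:
    # Highest priority first. sorted() stays stable with reverse=True, so
    # equal priorities keep registration order. Unknown names yield [] here
    # (an assumption; the real method's behavior is not documented).
    return sorted(store.get(algorithm, []), key=lambda s: s.priority, reverse=True)


print(list_algorithms())  # ['flash_attention', 'ragged_attention']
print([s.platform for s in list_implementations("flash_attention")])  # ['triton', 'xla']
```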

register(algorithm: str, platform: Union[Platform, Literal['triton', 'pallas', 'cuda', 'xla']], backend: Union[Backend, Literal['gpu', 'tpu', 'cpu', 'any']], priority: int = 0) → Callable[[F], F]

Decorator to register a kernel implementation.

Parameters
  • algorithm – Name of the algorithm (e.g., ‘flash_attention’)

  • platform – Implementation platform

  • backend – Target hardware backend

  • priority – Selection priority (default: 0). Higher values are preferred.

Returns

Decorator function that registers the kernel and returns it unchanged

Example

>>> @registry.register("flash_attention", Platform.TRITON, Backend.GPU, priority=10)
... def flash_attention_impl(q, k, v):
...     return compute_attention(q, k, v)
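The decorator contract (record a spec, then hand the function back unchanged) can be sketched as follows. MiniRegistry and Spec are hypothetical stand-ins, plain strings replace the Platform and Backend enums to keep the example short, and the function body is placeholder arithmetic rather than attention.

```python
from dataclasses import dataclass
from typing import Callable, TypeVar

F = TypeVar("F", bound=Callable)


@dataclass
class Spec:
    algorithm: str
    platform: str
    backend: str
    implementation: Callable
    priority: int = 0


class MiniRegistry:
    """Hypothetical registry illustrating only the documented decorator contract."""

    def __init__(self) -> None:
        self._specs: dict[str, list[Spec]] = {}

    def register(self, algorithm: str, platform: str, backend: str, priority: int = 0) -> Callable[[F], F]:
        def decorator(func: F) -> F:
            spec = Spec(algorithm, platform, backend, func, priority)
            self._specs.setdefault(algorithm, []).append(spec)
            # Returned unchanged, so the decorated name stays directly callable.
            return func

        return decorator


registry = MiniRegistry()


@registry.register("flash_attention", "triton", "gpu", priority=10)
def flash_attention_impl(q, k, v):
    return q + k + v  # placeholder arithmetic standing in for attention
```

Because the decorator returns the original function, flash_attention_impl can still be called directly, for example in tests, while also being discoverable through the registry.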
validate_signatures(algorithm: str | None, verbose: bool = False) → bool

Validate that all implementations of an algorithm have matching signatures.

Compares parameter names, order, and defaults across all implementations. Issues warnings for any mismatches found.

Parameters
  • algorithm – Algorithm name to validate

  • verbose – If True, log all parameter signatures before comparison

Returns

True if all signatures match, False otherwise

Raises

ValueError – If algorithm is not registered
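The comparison of parameter names, order, and defaults can be approximated with inspect.signature. This sketch is a hypothetical helper that takes a dict of implementations rather than querying the registry. Comparing every implementation against the first one, and reporting mismatches with warnings.warn, are assumptions about how ejkernel performs the check.

```python
import inspect
import warnings
from typing import Callable


def signatures_match(impls: dict[str, Callable], verbose: bool = False) -> bool:
    """Compare parameter names, order, and defaults against the first implementation."""
    if not impls:
        raise ValueError("no implementations to compare")
    sigs = {name: inspect.signature(fn) for name, fn in impls.items()}
    if verbose:
        for name, sig in sigs.items():
            print(f"{name}{sig}")
    ref_name, ref_sig = next(iter(sigs.items()))
    # (name, default) pairs in declaration order; Parameter.empty marks "no default".
    ref_params = [(p.name, p.default) for p in ref_sig.parameters.values()]
    ok = True
    for name, sig in sigs.items():
        params = [(p.name, p.default) for p in sig.parameters.values()]
        if params != ref_params:
            warnings.warn(f"signature mismatch: {name}{sig} vs {ref_name}{ref_sig}")
            ok = False
    return ok


# Hypothetical implementations; pallas_impl differs only in a default value.
def triton_impl(q, k, v, causal=False): ...
def xla_impl(q, k, v, causal=False): ...
def pallas_impl(q, k, v, causal=True): ...


print(signatures_match({"triton": triton_impl, "xla": xla_impl}))  # True
print(signatures_match({"triton": triton_impl, "pallas": pallas_impl}))  # False, plus a warning
```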

class ejkernel.kernels._registry.KernelSpec(platform: Platform, backend: Backend, algorithm: str, implementation: Callable, priority: int = 0)

Bases: object

Specification for a registered kernel implementation.

platform

The implementation platform (triton, pallas, cuda, xla)

Type

ejkernel.kernels._registry.Platform

backend

Target hardware backend (gpu, tpu, cpu, any)

Type

ejkernel.kernels._registry.Backend

algorithm

Algorithm name (e.g., ‘flash_attention’)

Type

str

implementation

The actual kernel function

Type

collections.abc.Callable

priority

Selection priority (higher values preferred)

Type

int

algorithm: str
backend: Backend
implementation: Callable
platform: Platform
priority: int = 0
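The constructor signature, with annotated fields and a default of 0 for priority, is consistent with a dataclass. Assuming KernelSpec is one (this page does not say so), a local mirror behaves as sketched below; xla_attention is a hypothetical placeholder implementation, and the enums are local copies rather than imports from ejkernel.

```python
from dataclasses import dataclass, fields, replace
from enum import Enum
from typing import Callable


class Platform(str, Enum):
    CUDA = "cuda"
    PALLAS = "pallas"
    TRITON = "triton"
    XLA = "xla"


class Backend(str, Enum):
    ANY = "any"
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


@dataclass
class KernelSpec:
    """Hypothetical mirror; field order follows the documented constructor."""

    platform: Platform
    backend: Backend
    algorithm: str
    implementation: Callable
    priority: int = 0


def xla_attention(q, k, v):
    return q * k + v  # placeholder math, not real attention


spec = KernelSpec(Platform.XLA, Backend.ANY, "flash_attention", xla_attention)
print([f.name for f in fields(spec)])  # ['platform', 'backend', 'algorithm', 'implementation', 'priority']

# replace() builds a modified copy; the original spec keeps priority 0.
boosted = replace(spec, priority=5)
```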
class ejkernel.kernels._registry.Platform(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

Supported kernel implementation platforms.

CUDA = 'cuda'
PALLAS = 'pallas'
TRITON = 'triton'
XLA = 'xla'
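get and register accept either a Platform member or its string value. The hypothetical helper to_platform below, written against a local copy of the enum, sketches one way to accept both forms and reject anything else with a message listing the valid values. get's signature also allows 'auto', which is not a Platform member; this sketch does not model it and rejects it like any other unknown string.

```python
from enum import Enum
from typing import Optional, Union


class Platform(str, Enum):
    """Local mirror of the documented members; not imported from ejkernel."""

    CUDA = "cuda"
    PALLAS = "pallas"
    TRITON = "triton"
    XLA = "xla"


def to_platform(value: Optional[Union[Platform, str]]) -> Optional[Platform]:
    """Normalize a member or its string value; None means 'no platform filter'."""
    if value is None or isinstance(value, Platform):
        return value
    try:
        return Platform(value)
    except ValueError:
        valid = ", ".join(p.value for p in Platform)
        raise ValueError(f"unknown platform {value!r}; expected one of: {valid}") from None


print(to_platform("pallas") is Platform.PALLAS)  # True
```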