ejkernel.kernels._registry
Kernel registry system for managing multi-platform implementations.
- class ejkernel.kernels._registry.Backend(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: str, Enum
Target hardware backends for kernel execution.
- ANY = 'any'
- CPU = 'cpu'
- GPU = 'gpu'
- TPU = 'tpu'
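Because Backend mixes in str, each member is also a string. The short sketch below re-declares an enum with the same members (it does not import ejkernel) to show what that implies: members compare equal to their string values, and a string can be turned back into a member by value lookup. That is presumably why get() and register() accept either Backend.GPU or "gpu", but treat that link as an assumption.

```python
from enum import Enum


class Backend(str, Enum):
    """Stand-in with the documented members; not imported from ejkernel."""

    ANY = "any"
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


# The str mixin makes members behave like their string values.
assert Backend.GPU == "gpu"
assert isinstance(Backend.CPU, str)

# Value lookup converts a plain string back into the enum member.
assert Backend("tpu") is Backend.TPU
```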
- class ejkernel.kernels._registry.KernelRegistry
Bases: object
Registry for managing kernel implementations across platforms and backends.
Supports registering multiple implementations of the same algorithm for different platforms and backends, with priority-based selection.
Example
>>> registry = KernelRegistry()
>>> @registry.register("flash_attention", Platform.TRITON, Backend.GPU)
... def flash_attention_triton(q, k, v):
...     ...
>>>
>>> impl = registry.get("flash_attention", platform="triton", backend="gpu")
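To make priority-based selection concrete, here is a minimal, self-contained sketch of a registry with the same overall shape. It is not ejkernel's implementation: TinyRegistry and its internal dictionary are hypothetical, platforms and backends are plain strings rather than the Platform and Backend enums, and the Backend.ANY wildcard rule of get() is left out.

```python
from __future__ import annotations

from collections import defaultdict
from typing import Callable


class TinyRegistry:
    """Hypothetical stand-in for KernelRegistry; not the library's code."""

    def __init__(self) -> None:
        # algorithm name -> list of (priority, platform, backend, implementation)
        self._kernels: dict[str, list[tuple[int, str, str, Callable]]] = defaultdict(list)

    def register(self, algorithm: str, platform: str, backend: str, priority: int = 0):
        def decorator(fn: Callable) -> Callable:
            self._kernels[algorithm].append((priority, platform, backend, fn))
            return fn  # the decorated function is returned unchanged

        return decorator

    def get(self, algorithm: str, platform: str | None = None, backend: str | None = None) -> Callable:
        # A filter left as None accepts every value.
        candidates = [
            entry
            for entry in self._kernels.get(algorithm, [])
            if (platform is None or entry[1] == platform)
            and (backend is None or entry[2] == backend)
        ]
        if not candidates:
            raise ValueError(f"no implementation of {algorithm!r} for {platform=}, {backend=}")
        # Highest priority wins.
        return max(candidates, key=lambda entry: entry[0])[3]


registry = TinyRegistry()


@registry.register("flash_attention", "xla", "gpu")
def attention_xla(q, k, v):
    return "xla"


@registry.register("flash_attention", "triton", "gpu", priority=10)
def attention_triton(q, k, v):
    return "triton"


# Without a platform filter, the priority-10 Triton kernel beats the priority-0 XLA one.
assert registry.get("flash_attention", backend="gpu") is attention_triton
assert registry.get("flash_attention", platform="xla") is attention_xla
```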
- get(algorithm: str, platform: Optional[Union[Platform, Literal['triton', 'pallas', 'cuda', 'xla', 'auto']]] = None, backend: Optional[Union[Backend, Literal['gpu', 'tpu', 'cpu', 'any']]] = None) → Callable
Retrieve the best matching kernel implementation.
Searches for implementations matching the specified algorithm, platform, and backend, and returns the highest-priority match. If backend is not specified, any backend will match. Backend.ANY implementations match all backend queries.
- Parameters
algorithm – Algorithm name to look up
platform – Optional platform filter
backend – Optional backend filter
- Returns
The matching kernel implementation function
- Raises
ValueError – If no matching implementation is found
Example
>>> impl = registry.get("flash_attention", platform="triton", backend="gpu")
>>> result = impl(q, k, v)
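The matching rules described for get() can be written as a small predicate plus a priority pick. The sketch below is an illustration under stated assumptions, not ejkernel's code: Spec and select are hypothetical names, platforms are plain strings, and the 'auto' platform value accepted by get() is not modeled because this reference does not describe its behavior.

```python
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Callable


class Backend(str, Enum):
    ANY = "any"
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


@dataclass
class Spec:
    """Hypothetical, trimmed-down KernelSpec."""

    platform: str
    backend: Backend
    priority: int
    implementation: Callable


def backend_matches(spec_backend: Backend, query: Backend | str | None) -> bool:
    if query is None:  # no backend requested: every implementation qualifies
        return True
    if spec_backend is Backend.ANY:  # ANY implementations serve every backend query
        return True
    return spec_backend == Backend(query)  # accepts "gpu" or Backend.GPU


def select(specs: list[Spec], platform: str | None = None, backend: Backend | str | None = None) -> Callable:
    matches = [
        s for s in specs
        if (platform is None or s.platform == platform) and backend_matches(s.backend, backend)
    ]
    if not matches:
        raise ValueError("no matching implementation")
    return max(matches, key=lambda s: s.priority).implementation


def generic(q, k, v): ...
def tuned(q, k, v): ...


specs = [Spec("xla", Backend.ANY, 0, generic), Spec("triton", Backend.GPU, 10, tuned)]
assert select(specs, backend="tpu") is generic       # only the ANY kernel serves TPU
assert select(specs, backend=Backend.GPU) is tuned   # both match; priority 10 wins
```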
- list_algorithms() → list[str]
List all registered algorithm names.
- Returns
Sorted list of algorithm names
- list_implementations(algorithm: str) → list[ejkernel.kernels._registry.KernelSpec]
List all implementations for a given algorithm.
- Parameters
algorithm – Algorithm name to query
- Returns
List of KernelSpec objects, sorted by priority (descending)
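The two listing methods reduce to ordinary sorts. In this sketch the registry's internal storage is assumed to be a dict from algorithm name to a list of specs (a hypothetical layout), and returning an empty list for an unknown algorithm is an assumption, since the reference does not say what list_implementations does in that case.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable


@dataclass
class Spec:
    """Hypothetical, trimmed-down KernelSpec."""

    algorithm: str
    platform: str
    priority: int
    implementation: Callable


def list_algorithms(table: dict[str, list[Spec]]) -> list[str]:
    # Registered algorithm names in alphabetical order.
    return sorted(table)


def list_implementations(table: dict[str, list[Spec]], algorithm: str) -> list[Spec]:
    # Highest priority first; equal priorities keep their registration order
    # because Python's sort stays stable with reverse=True.
    return sorted(table.get(algorithm, []), key=lambda s: s.priority, reverse=True)


def noop(*args):
    return None


table = {
    "flash_attention": [
        Spec("flash_attention", "xla", 0, noop),
        Spec("flash_attention", "triton", 10, noop),
    ],
    "attention": [Spec("attention", "xla", 0, noop)],
}
assert list_algorithms(table) == ["attention", "flash_attention"]
assert [s.platform for s in list_implementations(table, "flash_attention")] == ["triton", "xla"]
```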
- register(algorithm: str, platform: Union[Platform, Literal['triton', 'pallas', 'cuda', 'xla']], backend: Union[Backend, Literal['gpu', 'tpu', 'cpu', 'any']], priority: int = 0) → Callable[[F], F]
Decorator to register a kernel implementation.
- Parameters
algorithm – Name of the algorithm (e.g., ‘flash_attention’)
platform – Implementation platform
backend – Target hardware backend
priority – Selection priority (default: 0). Higher values are preferred.
- Returns
Decorator function that registers the kernel and returns it unchanged
Example
>>> @registry.register("flash_attention", Platform.TRITON, Backend.GPU, priority=10)
... def flash_attention_impl(q, k, v):
...     return compute_attention(q, k, v)
- validate_signatures(algorithm: str | None, verbose: bool = False) → bool
Validate that all implementations of an algorithm have matching signatures.
Compares parameter names, order, and defaults across all implementations. Issues warnings for any mismatches found.
- Parameters
algorithm – Algorithm name to validate
verbose – If True, log all parameter signatures before comparison
- Returns
True if all signatures match, False otherwise
- Raises
ValueError – If algorithm is not registered
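Comparing parameter names, order, and defaults maps naturally onto inspect.signature. The sketch below compares every implementation against the first one. It takes a plain list of functions instead of an algorithm name, and it reports mismatches through Python's logging module; both choices are assumptions, and signatures_match is a hypothetical name rather than ejkernel's code.

```python
from __future__ import annotations

import inspect
import logging
from typing import Callable

logger = logging.getLogger("signature_check")


def _shape(fn: Callable) -> list[tuple[str, object]]:
    # (name, default) pairs in declaration order; inspect.Parameter.empty marks "no default".
    return [(p.name, p.default) for p in inspect.signature(fn).parameters.values()]


def signatures_match(implementations: list[Callable], verbose: bool = False) -> bool:
    """Compare names, order, and defaults of every implementation against the first one."""
    if not implementations:
        return True
    if verbose:
        for fn in implementations:
            logger.info("%s%s", fn.__name__, inspect.signature(fn))
    reference = _shape(implementations[0])
    ok = True
    for fn in implementations[1:]:
        shape = _shape(fn)
        if shape != reference:
            logger.warning("%s has %s, expected %s", fn.__name__, shape, reference)
            ok = False
    return ok


def attn_a(q, k, v, causal=False): ...
def attn_b(q, k, v, causal=False): ...
def attn_c(q, k, v, causal=True): ...  # same names and order, different default


assert signatures_match([attn_a, attn_b])
assert not signatures_match([attn_a, attn_c])
```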
- class ejkernel.kernels._registry.KernelSpec(platform: Platform, backend: Backend, algorithm: str, implementation: Callable, priority: int = 0)
Bases: object
Specification for a registered kernel implementation.
- platform
The implementation platform (triton, pallas, cuda, xla)
- backend
Target hardware backend (gpu, tpu, cpu, any)
- algorithm
Algorithm name (e.g., ‘flash_attention’)
- Type
str
- implementation
The actual kernel function
- priority
Selection priority (higher values preferred)
- Type
int
- algorithm: str
- priority: int = 0
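The constructor signature and the field-style entries above suggest a dataclass, though that is an assumption. The hypothetical KernelSpecSketch below reproduces the documented field order and the priority default; its Platform enum is likewise a stand-in, with member names PALLAS, CUDA, and XLA guessed from the platform values listed in this reference.

```python
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Callable


class Platform(str, Enum):
    # Hypothetical stand-in; values taken from the documented platforms.
    TRITON = "triton"
    PALLAS = "pallas"
    CUDA = "cuda"
    XLA = "xla"


class Backend(str, Enum):
    ANY = "any"
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


@dataclass
class KernelSpecSketch:
    """Same field order as the documented KernelSpec constructor."""

    platform: Platform
    backend: Backend
    algorithm: str
    implementation: Callable
    priority: int = 0  # higher values are preferred during selection


def attention(q, k, v):
    return q


spec = KernelSpecSketch(Platform.TRITON, Backend.GPU, "flash_attention", attention)
assert spec.priority == 0  # default applies when priority is omitted
```

Because the field order matches the documented signature, positional construction reads the same way as for KernelSpec itself.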