Kernel Registry System Analysis
Overview
The kernel registry system is the heart of ejKernel’s multi-backend architecture, providing a centralized mechanism for managing multiple implementations of the same algorithm across different platforms and backends. It enables automatic platform-specific kernel selection and ensures consistency across implementations.
Core Components
1. Platform and Backend Enumerations
```python
from enum import Enum

class Platform(str, Enum):
    TRITON = "triton"  # GPU kernels via Triton
    PALLAS = "pallas"  # TPU/GPU kernels via Pallas
    CUDA = "cuda"      # Direct CUDA implementations
    XLA = "xla"        # XLA compiler-based

class Backend(str, Enum):
    GPU = "gpu"
    TPU = "tpu"
    CPU = "cpu"
    ANY = "any"  # Backend-agnostic: runs on any device
```
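Both enums mix in `str`, so each member is also a string. A member compares equal to its value, and `Platform("triton")` recovers the member from that value. This is presumably what lets `get()` accept `Platform | str`.

The standalone check below redefines the enums exactly as above. `coerce_platform` is a hypothetical helper for illustration and is not part of ejKernel.

```python
from enum import Enum


class Platform(str, Enum):
    TRITON = "triton"
    PALLAS = "pallas"
    CUDA = "cuda"
    XLA = "xla"


class Backend(str, Enum):
    GPU = "gpu"
    TPU = "tpu"
    CPU = "cpu"
    ANY = "any"


def coerce_platform(value: "Platform | str") -> Platform:
    """Hypothetical helper: accept a Platform member or its (case-insensitive) string value."""
    if isinstance(value, Platform):
        return value
    return Platform(value.lower())  # Raises ValueError for unknown names


same_as_string = Platform.TRITON == "triton"  # str mixin: compares by value
from_string = coerce_platform("Pallas")       # Lookup by value returns the member
any_backend = Backend("any")
```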
2. KernelSpec Dataclass
```python
from dataclasses import dataclass
from typing import Callable

@dataclass(frozen=True)
class KernelSpec:
    platform: Platform
    backend: Backend
    algorithm: str
    implementation: Callable
    priority: int = 0  # Higher values are preferred
```
The KernelSpec encapsulates all metadata about a kernel implementation, including its priority for selection when multiple implementations are available.
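The sketch below shows two properties of `KernelSpec`. First, sorting specs by `priority` in descending order puts the preferred implementation first. Second, `frozen=True` blocks mutation after creation.

The enums are trimmed to two members each for brevity. The lambdas stand in for real kernels.

```python
from dataclasses import FrozenInstanceError, dataclass
from enum import Enum
from typing import Callable


class Platform(str, Enum):
    TRITON = "triton"
    XLA = "xla"


class Backend(str, Enum):
    GPU = "gpu"
    ANY = "any"


@dataclass(frozen=True)
class KernelSpec:
    platform: Platform
    backend: Backend
    algorithm: str
    implementation: Callable
    priority: int = 0  # Higher values are preferred


specs = [
    KernelSpec(Platform.XLA, Backend.ANY, "my_kernel", lambda x: x, priority=50),
    KernelSpec(Platform.TRITON, Backend.GPU, "my_kernel", lambda x: x, priority=100),
]
# Highest priority first, matching the "higher values are preferred" rule
ordered = sorted(specs, key=lambda s: s.priority, reverse=True)

# frozen=True makes a spec immutable once it has been created
try:
    ordered[0].priority = 0
    mutation_blocked = False
except FrozenInstanceError:
    mutation_blocked = True
```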
3. KernelRegistry Class
The central registry that stores every kernel implementation and selects the best match for a given algorithm, platform, and backend.
Key Methods
1. Registration
```python
@kernel_registry.register("flash_attention", Platform.TRITON, Backend.GPU, priority=10)
def flash_attention_triton(q, k, v, **kwargs):
    # Triton implementation
    pass
```
Features:
Decorator-based registration for clean syntax
Priority-based ordering (higher values preferred)
Automatic sorting of implementations by priority
Prevents duplicate registrations for the same (algorithm, platform, backend) tuple
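The features above can be sketched with a minimal decorator-based registry. `MiniRegistry` is hypothetical: it uses plain strings instead of the enums and assumes the decorator returns the function unchanged.

It illustrates three points:
- duplicate rejection for the same (algorithm, platform, backend);
- re-sorting by priority after each registration;
- the decorator syntax itself.

It is not ejKernel's actual implementation.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class KernelSpec:
    platform: str
    backend: str
    algorithm: str
    implementation: Callable
    priority: int = 0


class MiniRegistry:
    """Hypothetical, stripped-down registry illustrating the register() contract."""

    def __init__(self):
        self._implementations = defaultdict(list)

    def register(self, algorithm, platform, backend, priority=0):
        def decorator(fn):
            specs = self._implementations[algorithm]
            # Reject a second implementation for the same (algorithm, platform, backend)
            if any(s.platform == platform and s.backend == backend for s in specs):
                raise ValueError(f"{algorithm!r} already registered for {platform}/{backend}")
            specs.append(KernelSpec(platform, backend, algorithm, fn, priority))
            # Keep the highest-priority implementation first
            specs.sort(key=lambda s: s.priority, reverse=True)
            return fn  # Assumption: the decorated function is returned unchanged
        return decorator


registry = MiniRegistry()


@registry.register("my_kernel", "xla", "any", priority=50)
def my_kernel_xla(x):
    return x + 1


@registry.register("my_kernel", "triton", "gpu", priority=100)
def my_kernel_gpu(x):
    return x + 1


try:
    registry.register("my_kernel", "triton", "gpu")(lambda x: x)
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True

order = [s.platform for s in registry._implementations["my_kernel"]]
```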
2. Kernel Retrieval
```python
def get(self, algorithm: str, platform: Platform | str | None = None,
        backend: Backend | str | None = None) -> Callable:
    """
    Retrieves the best matching kernel implementation.

    Selection hierarchy:
    1. Exact match: (algorithm, platform, backend)
    2. Platform match with Backend.ANY
    3. Auto-detection based on JAX backend
    4. Raise ValueError if no match found
    """
```
Selection Logic:
Explicit Request: If both platform and backend are specified, find an exact match.
Platform-Only: If only platform is specified, match that platform with any backend.
Auto-Detection: If platform is "auto" or None, detect it from the JAX backend:
  JAX backend "gpu" → prefer Triton, fall back to Pallas, then XLA
  JAX backend "tpu" → prefer Pallas, fall back to XLA
  JAX backend "cpu" → use XLA
Priority Resolution: When multiple implementations match, select the one with the highest priority.
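The explicit-request path can be sketched as a small selection function. It tries an exact (platform, backend) match, then a `Backend.ANY` implementation of the same platform, and resolves any remaining ties by priority.

`select` is a hypothetical helper that uses plain strings for platforms and backends. It deliberately leaves out auto-detection.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class KernelSpec:
    platform: str
    backend: str
    implementation: Callable
    priority: int = 0


def select(specs, platform, backend=None):
    """Hypothetical explicit-request path of get(): exact match, then Backend.ANY."""
    candidates = [s for s in specs if s.platform == platform]
    if backend is not None:
        exact = [s for s in candidates if s.backend == backend]
        # Fall back to a Backend.ANY implementation of the same platform
        candidates = exact or [s for s in candidates if s.backend == "any"]
    if not candidates:
        raise ValueError(f"no implementation for platform={platform!r}, backend={backend!r}")
    # Priority resolution: the highest priority wins
    return max(candidates, key=lambda s: s.priority).implementation


specs = [
    KernelSpec("pallas", "gpu", lambda: "pallas-gpu", priority=10),
    KernelSpec("pallas", "tpu", lambda: "pallas-tpu", priority=100),
    KernelSpec("xla", "any", lambda: "xla-any", priority=50),
]
exact_hit = select(specs, "pallas", "tpu")()
any_fallback = select(specs, "xla", "cpu")()   # No exact CPU entry, so Backend.ANY is used
platform_only = select(specs, "pallas")()      # Any backend; priority decides
```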
3. Signature Validation
```python
def validate_signatures(self, algorithm: str | None = None, verbose: bool = False):
    """
    Validates parameter consistency across all implementations of an algorithm.

    Checks:
    - Parameter names match
    - Parameter order matches
    - Parameter kinds match (positional, keyword-only, etc.)
    - Default values match
    - Type annotations are compatible
    """
```
Validation Process:
Groups all implementations by algorithm
For each algorithm with multiple implementations:
  Extract signatures using inspect.signature
  Normalize type annotations (handles different import paths)
  Compare parameter names, order, kinds, and defaults
  Report discrepancies if verbose=True
  Raise ValueError on mismatch
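The comparison step can be sketched with the standard `inspect` module. `compare_signatures` is a hypothetical helper. It checks names and order, then parameter kinds and defaults, and returns a list of mismatches that a caller could turn into a `ValueError`.

Annotation comparison is omitted here for brevity.

```python
import inspect


def compare_signatures(reference, other):
    """Return human-readable mismatches between two callables' signatures."""
    ref_params = list(inspect.signature(reference).parameters.values())
    other_params = list(inspect.signature(other).parameters.values())
    # Names and order: compare the ordered name lists first
    if [p.name for p in ref_params] != [p.name for p in other_params]:
        return ["parameter names/order differ"]
    problems = []
    for a, b in zip(ref_params, other_params):
        if a.kind != b.kind:  # e.g. POSITIONAL_OR_KEYWORD vs KEYWORD_ONLY
            problems.append(f"{a.name}: kind {a.kind.name} != {b.kind.name}")
        if a.default != b.default:
            problems.append(f"{a.name}: default {a.default!r} != {b.default!r}")
    return problems


def matmul_gpu(a, b, transpose_a=False, transpose_b=False): ...
def matmul_xla(a, b, transpose_a=False, transpose_b=False): ...
def matmul_bad(a, b, *, transpose_a=True, transpose_b=False): ...
def matmul_renamed(x, b, transpose_a=False, transpose_b=False): ...


ok = compare_signatures(matmul_gpu, matmul_xla)
bad = compare_signatures(matmul_gpu, matmul_bad)
renamed = compare_signatures(matmul_gpu, matmul_renamed)
```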
4. Type Normalization
```python
def _normalize_type_string(type_annotation: Any) -> str:
    """
    Normalizes type annotation strings for comparison.

    Examples:
    - 'jaxtyping.Float[Array, "..."]' → 'Float[Array, "..."]'
    - 'ejkernel.ops.utils.datacarrier.FwdParams' → 'FwdParams'
    - 'typing.Union[X, None]' → 'Optional[X]'
    """
```
This enables signature validation even when implementations use different import styles for the same types.
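A regex-based sketch can reproduce the three examples from the docstring above. `normalize_type_string` is a hypothetical stand-in, not ejKernel's function.

It has two stated limitations:
- it strips any dotted module prefix, even inside quoted strings;
- it rewrites `Union[X, None]` only when `X` is a single, non-nested name.

```python
import re

# One or more dotted module components directly before an identifier,
# e.g. "jaxtyping." or "ejkernel.ops.utils.datacarrier."
_MODULE_PREFIX = re.compile(r"\b(?:[A-Za-z_]\w*\.)+(?=[A-Za-z_])")
# Union[X, None] where X is a single, non-nested name
_UNION_NONE = re.compile(r"\bUnion\[([^\[\],]+), None\]")


def normalize_type_string(annotation) -> str:
    """Sketch of annotation normalization; handles only the simple cases listed above."""
    if isinstance(annotation, type):
        return annotation.__name__  # e.g. int -> "int"
    text = str(annotation)
    text = _MODULE_PREFIX.sub("", text)  # Drop module paths
    return _UNION_NONE.sub(r"Optional[\1]", text)
```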
Platform Detection
```python
def detect_platform(algorithm: str, platform: str = "auto",
                    maybe_pallas: bool = False) -> Platform:
    """
    Intelligently detects the best platform for the current hardware.

    Detection logic:
    1. If explicit platform requested, return it
    2. Get JAX backend (GPU/TPU/CPU)
    3. Check available implementations in registry
    4. Apply platform preferences:
       - GPU: Triton > Pallas > XLA
       - TPU: Pallas > XLA
       - CPU: XLA only
    """
```
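The preference step can be sketched as a table lookup over the platforms that actually have a registered implementation. `detect_platform_sketch` is hypothetical and takes the JAX backend name as an argument so that it runs without JAX.

A real version would presumably obtain that name from `jax.default_backend()`, which returns `"cpu"`, `"gpu"`, or `"tpu"`.

```python
# Preference order per JAX backend, as listed in the docstring above
PLATFORM_PREFERENCES = {
    "gpu": ("triton", "pallas", "xla"),
    "tpu": ("pallas", "xla"),
    "cpu": ("xla",),
}


def detect_platform_sketch(registered, jax_backend, platform="auto"):
    """Pick a platform name; `registered` is the set of platforms with an implementation."""
    if platform != "auto":
        return platform  # Step 1: an explicit request wins
    for candidate in PLATFORM_PREFERENCES[jax_backend]:  # Step 4: preference order
        if candidate in registered:  # Step 3: only consider registered implementations
            return candidate
    raise ValueError(f"no registered platform usable on {jax_backend!r}")


gpu_choice = detect_platform_sketch({"pallas", "xla"}, "gpu")  # Triton missing, so Pallas
tpu_choice = detect_platform_sketch({"triton", "pallas", "xla"}, "tpu")
explicit = detect_platform_sketch({"xla"}, "gpu", platform="triton")
```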
Registration Patterns
Basic Registration
```python
@kernel_registry.register("my_kernel", Platform.TRITON, Backend.GPU)
def my_kernel_gpu(x, y, z):
    return triton_implementation(x, y, z)
```
Multi-Backend Registration
```python
# GPU implementation
@kernel_registry.register("my_kernel", Platform.TRITON, Backend.GPU, priority=100)
def my_kernel_gpu(x, y, z):
    return gpu_optimized(x, y, z)

# TPU implementation
@kernel_registry.register("my_kernel", Platform.PALLAS, Backend.TPU, priority=100)
def my_kernel_tpu(x, y, z):
    return tpu_optimized(x, y, z)

# Universal fallback
@kernel_registry.register("my_kernel", Platform.XLA, Backend.ANY, priority=50)
def my_kernel_xla(x, y, z):
    return xla_generic(x, y, z)
```
Platform-Specific Features
```python
@kernel_registry.register("attention", Platform.PALLAS, Backend.TPU)
def attention_tpu(query, key, value, **kwargs):
    # TPU-specific: use segment IDs instead of attention masks
    if "attention_mask" in kwargs:
        kwargs["segment_ids"] = mask_to_segment_ids(kwargs.pop("attention_mask"))
    return pallas_attention(query, key, value, **kwargs)
```
Usage Examples
Direct Registry Usage
```python
from ejkernel import kernel_registry, Platform, Backend

# Get a specific implementation
flash_attn_gpu = kernel_registry.get(
    algorithm="flash_attention",
    platform=Platform.TRITON,
    backend=Backend.GPU,
)

# Auto-detect the platform
flash_attn_auto = kernel_registry.get("flash_attention")

# Use the kernel (query, key, value are arrays prepared by the caller)
output = flash_attn_gpu(query, key, value, causal=True)
```
Module Integration
```python
class FlashAttention(Kernel):
    def get_impl(self, cfg: FlashAttentionConfig):
        """Get implementation from registry based on config"""
        return kernel_registry.get(
            algorithm="flash_attention",
            platform=detect_platform("flash_attention", cfg.platform),
            backend=cfg.backend,
        )
```
Design Decisions
1. Priority-Based Selection
Rationale: Different implementations may have different performance characteristics. Priority allows expressing preferences while maintaining flexibility.
Example: Triton implementations typically have priority=100 on GPU, while XLA fallbacks have priority=50.
2. Signature Validation
Rationale: Ensures that different implementations of the same algorithm are drop-in replacements for each other.
Benefits:
Catches API mismatches early
Enforces consistency across backends
Simplifies testing and maintenance
3. Type Normalization
Rationale: Different files may import the same type differently (e.g., absolute vs relative imports).
Solution: Strip module paths and normalize common patterns to enable comparison.
4. Backend.ANY
Rationale: Some implementations are tied to a single platform but run on several backends (e.g., XLA runs on CPU, GPU, and TPU).
Usage: Typically for fallback implementations that prioritize correctness over performance.
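Thread Safety
The registry relies on the following for thread safety:
Individual list operations such as append are atomic under CPython's global interpreter lock (GIL)
Retrieval only reads from the registry, and KernelSpec entries are frozen dataclasses
Registration normally completes at import time, so the registry is not mutated after initialization
Registration performs several list operations in sequence (duplicate check, append, sort), so it is not atomic as a whole unless guarded by a lock; avoid registering kernels while other threads are calling get().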
Error Handling
Registration Errors
```python
# Duplicate registration attempt
@kernel_registry.register("my_kernel", Platform.TRITON, Backend.GPU)
def impl1(): pass

@kernel_registry.register("my_kernel", Platform.TRITON, Backend.GPU)
def impl2(): pass  # Raises ValueError
```
Retrieval Errors
```python
# No matching implementation
kernel = kernel_registry.get("nonexistent_kernel")  # Raises ValueError

# No implementation for the requested platform
kernel = kernel_registry.get("cpu_only_kernel", platform=Platform.TRITON)  # Raises ValueError
```
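Because retrieval failures surface as `ValueError`, a caller can try a preferred platform first and then fall back to auto-detection.

In this sketch, `fake_get` is a hypothetical stand-in for `kernel_registry.get` that knows only one XLA kernel and pretends auto-detection chose XLA. `get_with_fallback` is likewise an illustrative helper, not part of ejKernel.

```python
_IMPLS = {("cpu_only_kernel", "xla"): lambda x: x * 2}


def fake_get(algorithm, platform=None):
    """Hypothetical stand-in for kernel_registry.get: raises ValueError when nothing matches."""
    if platform is None:
        platform = "xla"  # Pretend auto-detection chose XLA (e.g. on a CPU host)
    try:
        return _IMPLS[(algorithm, platform)]
    except KeyError:
        raise ValueError(f"no {algorithm!r} implementation for platform {platform!r}") from None


def get_with_fallback(get, algorithm, preferred_platform):
    """Try the preferred platform first; on ValueError, fall back to auto-detection."""
    try:
        return get(algorithm, platform=preferred_platform)
    except ValueError:
        return get(algorithm)  # A nonexistent algorithm still raises ValueError here


kernel = get_with_fallback(fake_get, "cpu_only_kernel", "triton")
result = kernel(21)
```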
Performance Considerations
Registration (One-time)
O(n log n) sort after each registration
Happens at module import time
No overhead when kernels are later called
Retrieval (Runtime)
O(n) linear scan over an algorithm's implementations
Typically n < 10 per algorithm
The result can be cached by the caller
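Caller-side caching can be as simple as wrapping the lookup in `functools.lru_cache`, keyed on (algorithm, platform, backend). Both plain strings and `str`-based enum members are hashable, so they work as cache keys.

In this sketch, `registry_get` is a hypothetical stand-in that records each call. The design trade-off is that a cached result will not see kernels registered after the first lookup.

```python
from functools import lru_cache

calls = []  # Records every real lookup made by the stand-in


def registry_get(algorithm, platform=None, backend=None):
    """Hypothetical stand-in for kernel_registry.get that logs each call."""
    calls.append((algorithm, platform, backend))
    return lambda x: x + 1


@lru_cache(maxsize=None)
def cached_get(algorithm, platform=None, backend=None):
    # Only the first call per (algorithm, platform, backend) reaches the registry
    return registry_get(algorithm, platform, backend)


for _ in range(1000):
    kernel = cached_get("flash_attention", "triton", "gpu")
```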
Extension Points
Custom Platform
```python
from enum import Enum

class Platform(str, Enum):
    TRITON = "triton"
    PALLAS = "pallas"
    CUDA = "cuda"
    XLA = "xla"
    CUSTOM = "custom"  # Add new platform
```
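A new platform has to be added to the enum definition itself. Python does not allow subclassing an `Enum` that already defines members, so it cannot be extended from another module.

The check below illustrates this with the standard `enum` module only. `MorePlatforms` is a hypothetical name for the failed extension attempt.

```python
from enum import Enum


class Platform(str, Enum):
    TRITON = "triton"
    PALLAS = "pallas"
    CUDA = "cuda"
    XLA = "xla"
    CUSTOM = "custom"  # New member added in the definition itself


# Extending the enum from outside fails: enums with members cannot be subclassed
try:
    class MorePlatforms(Platform):
        OTHER = "other"
    subclass_allowed = True
except TypeError:
    subclass_allowed = False
```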
Custom Selection Logic
```python
class CustomRegistry(KernelRegistry):
    def get(self, algorithm, platform=None, backend=None):
        # Add custom selection logic here; should_use_custom_logic and
        # get_custom_impl are placeholders you would define yourself
        if should_use_custom_logic():
            return self.get_custom_impl(algorithm)
        return super().get(algorithm, platform, backend)
```
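Dynamic Registration
```python
def register_dynamic_kernel(algorithm, impl, platform, backend):
    """Register a kernel at runtime.

    Note: this writes to the private _implementations attribute directly,
    so it bypasses the duplicate check performed by register().
    """
    kernel_registry._implementations[algorithm].append(
        KernelSpec(platform, backend, algorithm, impl, priority=0)
    )
    # Re-sort so the highest-priority implementation is found first
    kernel_registry._implementations[algorithm].sort(
        key=lambda x: x.priority, reverse=True
    )
```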
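Runtime registration can also go through the decorator itself. The form `@registry.register(...)` above a `def` is just sugar for `registry.register(...)(fn)`, so calling it directly keeps the duplicate check and priority sorting.

`MiniRegistry` and this variant of `register_dynamic_kernel` are hypothetical. They assume the same decorator-style `register()` signature shown in the Registration section.

```python
from collections import defaultdict


class MiniRegistry:
    """Hypothetical minimal registry with a decorator-style register()."""

    def __init__(self):
        self._implementations = defaultdict(list)

    def register(self, algorithm, platform, backend, priority=0):
        def decorator(fn):
            specs = self._implementations[algorithm]
            if any((p, b) == (platform, backend) for p, b, _, _ in specs):
                raise ValueError(f"duplicate registration for {algorithm!r}")
            specs.append((platform, backend, priority, fn))
            specs.sort(key=lambda s: s[2], reverse=True)  # Highest priority first
            return fn
        return decorator


def register_dynamic_kernel(registry, algorithm, impl, platform, backend, priority=0):
    """Runtime registration through the public decorator, so duplicate checks still apply."""
    return registry.register(algorithm, platform, backend, priority)(impl)


registry = MiniRegistry()
register_dynamic_kernel(registry, "scale", lambda x: 2 * x, "xla", "any")
try:
    register_dynamic_kernel(registry, "scale", lambda x: 3 * x, "xla", "any")
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True
```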
Best Practices
1. Consistent Signatures
Always ensure all implementations of an algorithm have identical signatures:
```python
from jax import Array

# Good: consistent signatures
@kernel_registry.register("matmul", Platform.TRITON, Backend.GPU)
def matmul_gpu(a: Array, b: Array, transpose_a: bool = False, transpose_b: bool = False):
    pass

@kernel_registry.register("matmul", Platform.XLA, Backend.ANY)
def matmul_xla(a: Array, b: Array, transpose_a: bool = False, transpose_b: bool = False):
    pass
```
2. Platform-Specific Optimizations
Use platform-specific features when beneficial:
```python
@kernel_registry.register("attention", Platform.TRITON, Backend.GPU)
def attention_gpu(q, k, v, **kwargs):
    # Use Triton's shared memory optimizations
    return triton_attention_kernel(q, k, v, **kwargs)

@kernel_registry.register("attention", Platform.PALLAS, Backend.TPU)
def attention_tpu(q, k, v, **kwargs):
    # Use the TPU's matrix units
    return pallas_attention_kernel(q, k, v, **kwargs)
```
3. Graceful Degradation
Always provide a fallback implementation:
```python
import jax.numpy as jnp

@kernel_registry.register("my_operation", Platform.XLA, Backend.ANY, priority=0)
def my_operation_fallback(x, y):
    """Basic implementation that works everywhere"""
    return jnp.dot(x, y)
```
4. Documentation
Document platform-specific limitations:
```python
@kernel_registry.register("sparse_attention", Platform.PALLAS, Backend.TPU)
def sparse_attention_tpu(q, k, v, sparsity_mask):
    """
    TPU implementation of sparse attention.

    Note: the TPU version requires sparsity_mask to be block-aligned,
    with blocks of size 128x128.
    """
    pass
```
Testing Strategies
1. Cross-Backend Validation
```python
from numpy.testing import assert_allclose

def test_kernel_consistency():
    """Test that all backends produce numerically equivalent results"""
    algorithms = ["flash_attention", "page_attention"]
    for algo in algorithms:
        impls = kernel_registry.get_all_implementations(algo)
        # Test with the same inputs (generate_test_inputs is a test helper)
        inputs = generate_test_inputs(algo)
        outputs = [impl(*inputs) for impl in impls]
        # Verify every output matches the first within tolerance
        for out in outputs[1:]:
            assert_allclose(outputs[0], out, rtol=1e-5)
```
2. Signature Validation
```python
def test_signature_consistency():
    """Ensure all implementations have consistent signatures"""
    kernel_registry.validate_signatures(verbose=True)
```
3. Platform Detection
```python
def test_platform_detection():
    """Test automatic platform detection"""
    # mock_jax_backend is a test helper that fakes the active JAX backend
    with mock_jax_backend("gpu"):
        platform = detect_platform("flash_attention")
        assert platform == Platform.TRITON

    with mock_jax_backend("tpu"):
        platform = detect_platform("flash_attention")
        assert platform == Platform.PALLAS
```
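One way to build a helper like `mock_jax_backend` is a context manager around `unittest.mock.patch.object`.

In this sketch, `fake_jax` is a hypothetical stand-in for the `jax` module so the example runs without JAX installed. The `detect_platform` here is also a simplified, hypothetical version that ignores which kernels are registered. Patching the real `jax.default_backend` works the same way only if the code under test looks it up as an attribute at call time.

```python
from contextlib import contextmanager
from types import SimpleNamespace
from unittest import mock

# Stand-in for the jax module; the real jax.default_backend() returns "cpu", "gpu", or "tpu"
fake_jax = SimpleNamespace(default_backend=lambda: "cpu")


def detect_platform(algorithm: str, platform: str = "auto") -> str:
    """Hypothetical, simplified detector that ignores which kernels are registered."""
    if platform != "auto":
        return platform
    return {"gpu": "triton", "tpu": "pallas"}.get(fake_jax.default_backend(), "xla")


@contextmanager
def mock_jax_backend(name: str):
    """Temporarily make fake_jax.default_backend() report `name`."""
    with mock.patch.object(fake_jax, "default_backend", return_value=name):
        yield


with mock_jax_backend("gpu"):
    on_gpu = detect_platform("flash_attention")
with mock_jax_backend("tpu"):
    on_tpu = detect_platform("flash_attention")
restored = detect_platform("flash_attention")  # The patch is undone on exit
```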
Conclusion
The kernel registry system provides a robust, extensible foundation for multi-backend kernel management in ejKernel. Its key strengths include:
Flexibility: Easy to add new implementations and platforms
Consistency: Signature validation ensures API compatibility
Intelligence: Automatic platform detection and priority-based selection
Simplicity: Clean decorator-based registration API
Reliability: Multiple fallback mechanisms ensure availability
This design lets ejKernel select the best-suited available implementation on each hardware platform while keeping a consistent API for users.