ejkernel.modules.base#
Base configuration for kernel modules.
Provides shared configuration infrastructure for kernel execution, but does NOT provide a base Kernel class. Each kernel module should implement its own custom Kernel subclass tailored to its specific needs.
Note
The create_default_executor() function is deprecated. Use
Executor(ConfigSelectorChain(...)) directly instead. See the
function documentation for migration examples.
- class ejkernel.modules.base.KernelConfig(block_q: int = 128, block_k: int = 128, block_d: int = 64, num_warps: int = 4, num_stages: int = 2, platform: Union[Platform, Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = 'auto', backend: Union[Backend, Literal['gpu', 'tpu', 'cpu', 'any']] = Backend.ANY, algorithm: str | None = None, priority: int = 0)[source]#
Bases: object
Configuration for kernel execution with block size tuning.
This is a shared configuration class that can be used by all kernel modules to specify block sizes for autotuning and performance optimization.
- block_q#
Query block size for tiling
- Type
int
- block_k#
Key/value block size for tiling
- Type
int
- block_d#
Head dimension block size (if applicable)
- Type
int
- num_warps#
Number of warps for GPU kernels
- Type
int
- num_stages#
Number of pipeline stages for overlapping compute/memory
- Type
int
- platform#
Implementation platform (triton, pallas, cuda, xla, auto)
- Type
Union[ejkernel.kernels._registry.Platform, Literal['triton', 'pallas', 'cuda', 'xla', 'auto']]
- backend#
Target hardware backend (gpu, tpu, cpu, any)
- Type
Union[ejkernel.kernels._registry.Backend, Literal['gpu', 'tpu', 'cpu', 'any']]
- algorithm#
Specific algorithm variant if multiple exist
- Type
str | None
- priority#
Selection priority when multiple configs match
- Type
int
- algorithm: str | None = None#
- block_d: int = 64#
- block_k: int = 128#
- block_q: int = 128#
- num_stages: int = 2#
- num_warps: int = 4#
- priority: int = 0#
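To make the block-size fields concrete, the sketch below shows how block_q and block_k could translate into a tile grid for given sequence lengths. TileConfig and tile_grid are hypothetical stand-ins written for illustration (they are not part of ejkernel); the defaults are copied from the KernelConfig signature, and the actual kernels may tile differently.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class TileConfig:
    """Hypothetical stand-in mirroring the block-size fields of KernelConfig."""

    block_q: int = 128
    block_k: int = 128


def tile_grid(q_len: int, kv_len: int, cfg: TileConfig) -> tuple:
    """Return (query tiles, key/value tiles) needed to cover both sequences."""
    # Ceiling division: a partially filled last tile still occupies a full block.
    return math.ceil(q_len / cfg.block_q), math.ceil(kv_len / cfg.block_k)


grid = tile_grid(1000, 4096, TileConfig())
print(grid)  # (8, 32)
```

Smaller blocks produce more tiles (more parallelism, more per-tile overhead), which is the kind of trade-off that autotuning over these fields is meant to explore.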
- ejkernel.modules.base.create_default_executor(persistent_cache_path: str | None = None, enable_autotuning: bool = True, warmup_iterations: int = 2, timing_iterations: int = 5) Executor[KernelConfig, jax.jaxlib._jax.Array | tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array]][source]#
Create a default executor with standard configuration.
Deprecated: Use Executor(ConfigSelectorChain(...)) directly instead of create_default_executor(). This function will be removed in a future version.
Sets up an executor with in-memory and optional persistent caching, and autotuning enabled by default.
- Parameters
persistent_cache_path – Optional path for persistent cache storage. If None, only in-memory caching is used.
enable_autotuning – Whether to enable automatic performance tuning. When enabled, the executor will benchmark different implementations.
warmup_iterations – Number of warmup runs before timing measurements
timing_iterations – Number of timing iterations to average over
- Returns
Configured Executor instance ready for kernel execution
Example
>>> # Deprecated:
>>> executor = create_default_executor("/tmp/kernel_cache")
>>>
>>> # Recommended replacement:
>>> from ejkernel.ops import Executor, ConfigSelectorChain, ConfigCache, AutotunePolicy, Tuner, PersistentCache
>>> executor = Executor(
...     ConfigSelectorChain(
...         cache=ConfigCache(),
...         policy=AutotunePolicy(allow_autotune=True),
...         tuner=Tuner(warmup=2, iters=5),
...         persistent=PersistentCache("/tmp/kernel_cache")
...     )
... )
>>>
>>> # Run a kernel module through the executor (q, k, v are existing input arrays):
>>> from ejkernel.modules import FlashAttention
>>> attn = FlashAttention()
>>> output = executor(attn, q, k, v, causal=True)
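The warmup_iterations and timing_iterations parameters describe a common benchmarking pattern: run the candidate a few times untimed, then average over timed runs. The following is a minimal sketch of that pattern using only the standard library. The benchmark function is hypothetical and is not ejkernel's Tuner; it only illustrates what the two counts mean.

```python
import time
from typing import Callable


def benchmark(
    fn: Callable[[], object],
    warmup_iterations: int = 2,
    timing_iterations: int = 5,
) -> float:
    """Return the mean wall-clock seconds per call of fn over the timed runs."""
    # Warmup runs absorb one-time costs (e.g. compilation) and are not timed.
    for _ in range(warmup_iterations):
        fn()
    start = time.perf_counter()
    for _ in range(timing_iterations):
        fn()
    return (time.perf_counter() - start) / timing_iterations


calls = []  # records every invocation, warmup included
mean_seconds = benchmark(lambda: calls.append(sum(range(1000))))
print(len(calls))  # 7
```

With JAX, dispatch is asynchronous, so a real timing loop would also need to wait for results (for example by calling block_until_ready() on the output) before stopping the clock.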
- ejkernel.modules.base.detect_platform(algorithm: str, platform: Optional[Union[Platform, Literal['triton', 'pallas', 'cuda', 'xla', 'auto']]] = 'auto', maybe_pallas: bool = False) Platform[source]#
Detect the best platform for a given algorithm.
Automatically selects the optimal platform based on:
- Explicit platform request (if provided)
- Available implementations for the algorithm
- Current JAX backend (GPU/TPU/CPU)
- Platform-specific optimizations
The selection priority:
- GPU backend + Triton available -> Triton (best GPU performance)
- GPU backend + no Triton -> XLA
- TPU backend -> Pallas (TPU-optimized) or XLA fallback
- CPU backend -> XLA
- Parameters
algorithm – The algorithm name to check for availability (e.g., "flash_attention")
platform – Requested platform. Options:
- "triton": Triton GPU kernels (CUDA/ROCm)
- "pallas": Pallas kernels (TPU/GPU)
- "cuda": CUDA-specific implementations
- "xla": XLA compiler-based implementations
- "auto" or None: Automatic selection (default)
- Returns
The selected Platform enum value
- Raises
ValueError – If the requested platform is not available for the algorithm
Example
>>> # Automatic selection:
>>> platform = detect_platform("flash_attention", platform="auto")
>>>
>>> # Explicit platform request:
>>> platform = detect_platform("flash_attention", platform="triton")
>>>
>>> # Default arguments (automatic selection):
>>> print(f"Selected: {detect_platform('ring_attention')}")
Note
Triton is preferred on NVIDIA GPUs for best performance, but XLA provides broader compatibility across hardware backends.
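The selection priority documented for detect_platform can be sketched as a small pure-Python function. This is an illustration of the documented order only, not ejkernel's implementation: pick_platform, its string-based arguments, and the available set (standing in for the registry lookup of implementations for an algorithm) are all assumptions.

```python
from typing import Optional, Set


def pick_platform(
    backend: str,
    available: Set[str],
    requested: Optional[str] = "auto",
) -> str:
    """Sketch of the documented selection order (hypothetical helper)."""
    if requested not in (None, "auto"):
        # An explicit request is either honored or rejected with ValueError.
        if requested not in available:
            raise ValueError(f"platform {requested!r} is not available")
        return requested
    if backend == "gpu":
        # Prefer Triton on GPU, otherwise fall back to XLA.
        return "triton" if "triton" in available else "xla"
    if backend == "tpu":
        # Prefer Pallas on TPU, otherwise fall back to XLA.
        return "pallas" if "pallas" in available else "xla"
    return "xla"  # CPU backend


print(pick_platform("gpu", {"triton", "xla"}))  # triton
print(pick_platform("tpu", {"xla"}))  # xla
```

Keeping the explicit-request branch first mirrors the documented behavior that a requested but unavailable platform raises ValueError rather than silently falling back.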