ejkernel.modules.base

Base configuration for kernel modules.

Provides shared configuration infrastructure for kernel execution, but does NOT provide a base Kernel class. Each kernel module should implement its own custom Kernel subclass tailored to its specific needs.

Note

The create_default_executor() function is deprecated. Use Executor(ConfigSelectorChain(...)) directly instead. See the function documentation for migration examples.

class ejkernel.modules.base.KernelConfig(block_q: int = 128, block_k: int = 128, block_d: int = 64, num_warps: int = 4, num_stages: int = 2, platform: Union[Platform, Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = 'auto', backend: Union[Backend, Literal['gpu', 'tpu', 'cpu', 'any']] = Backend.ANY, algorithm: str | None = None, priority: int = 0)

Bases: object

Configuration for kernel execution with block size tuning.

This is a shared configuration class that can be used by all kernel modules to specify block sizes for autotuning and performance optimization.
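To make the block sizes concrete: in a typical FlashAttention-style tiling, `block_q` and `block_k` determine how many tiles the query and key/value sequence lengths are split into, and therefore how many programs are launched and how much work each one does. The sketch below assumes that schedule (one program per query block, each iterating over all key/value blocks, with no causal skipping). Individual ejkernel modules may tile differently, and `attention_grid` is a hypothetical helper that only performs the arithmetic.

```python
import math


def attention_grid(seq_len_q: int, seq_len_k: int, block_q: int = 128, block_k: int = 128) -> tuple[int, int]:
    """Return (number of query tiles, key/value tiles visited per query tile).

    Hypothetical helper; assumes a FlashAttention-style schedule with one
    program per query block that loops over every key/value block.
    """
    num_q_tiles = math.ceil(seq_len_q / block_q)
    num_k_tiles = math.ceil(seq_len_k / block_k)
    return num_q_tiles, num_k_tiles


# KernelConfig defaults: block_q=128, block_k=128.
print(attention_grid(4096, 4096))  # (32, 32)
# Halving block_q doubles the number of query programs; each does less work.
print(attention_grid(4096, 4096, block_q=64))  # (64, 32)
# Ragged lengths round up: the last tile is only partially filled (padded or masked).
print(attention_grid(1000, 1000))  # (8, 8)
```

Smaller blocks increase parallelism but also per-tile overhead, which is why these values are exposed for autotuning rather than fixed.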

Attributes
  • block_q (int) – Query block size for tiling.

  • block_k (int) – Key/value block size for tiling.

  • block_d (int) – Head-dimension block size (if applicable).

  • num_warps (int) – Number of warps for GPU kernels.

  • num_stages (int) – Number of pipeline stages used to overlap compute with memory transfers.

  • platform (Union[ejkernel.kernels._registry.Platform, Literal['triton', 'pallas', 'cuda', 'xla', 'auto']]) – Implementation platform (triton, pallas, cuda, xla, auto).

  • backend (Union[ejkernel.kernels._registry.Backend, Literal['gpu', 'tpu', 'cpu', 'any']]) – Target hardware backend (gpu, tpu, cpu, any).

  • algorithm (str | None) – Specific algorithm variant, if more than one exists.

  • priority (int) – Selection priority when multiple configs match.

algorithm: str | None = None
backend: Union[Backend, Literal['gpu', 'tpu', 'cpu', 'any']] = 'any'
block_d: int = 64
block_k: int = 128
block_q: int = 128
num_stages: int = 2
num_warps: int = 4
platform: Union[Platform, Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = 'auto'
priority: int = 0
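Because every field has a default, a tuning search space can be built by overriding only the block sizes and warp count. The sketch below uses a local frozen dataclass, `KernelConfigSketch`, that mirrors the fields and defaults listed above. It is a hypothetical stand-in rather than ejkernel's class (whether the real `KernelConfig` is a dataclass, or immutable, is an assumption here), and `candidate_configs` is likewise illustrative.

```python
from __future__ import annotations

import dataclasses
import itertools


@dataclasses.dataclass(frozen=True)
class KernelConfigSketch:
    """Hypothetical stand-in mirroring KernelConfig's documented fields and defaults."""

    block_q: int = 128
    block_k: int = 128
    block_d: int = 64
    num_warps: int = 4
    num_stages: int = 2
    platform: str = "auto"  # the real field also accepts a Platform enum
    backend: str = "any"  # the real field also accepts a Backend enum
    algorithm: str | None = None
    priority: int = 0


def candidate_configs(
    base: KernelConfigSketch,
    block_sizes: tuple[int, ...] = (64, 128),
    warps: tuple[int, ...] = (4, 8),
) -> list[KernelConfigSketch]:
    """Build a search space by varying block_q, block_k and num_warps around `base`."""
    return [
        dataclasses.replace(base, block_q=bq, block_k=bk, num_warps=w)
        for bq, bk, w in itertools.product(block_sizes, block_sizes, warps)
    ]


base = KernelConfigSketch(platform="triton", backend="gpu")
space = candidate_configs(base)
print(len(space))  # 8
print(space[0])  # first candidate: block_q=64, block_k=64, num_warps=4
```

`dataclasses.replace` copies every untouched field from `base`, so platform, backend and the remaining defaults carry over to each candidate.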
ejkernel.modules.base.create_default_executor(persistent_cache_path: str | None = None, enable_autotuning: bool = True, warmup_iterations: int = 2, timing_iterations: int = 5) Executor[KernelConfig, jax.jaxlib._jax.Array | tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array]]

Create a default executor with standard configuration.

Deprecated: Use Executor(ConfigSelectorChain(...)) directly instead of create_default_executor(). This function will be removed in a future version.

Sets up an executor with in-memory and optional persistent caching, and autotuning enabled by default.
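The caching behavior described above (in-memory caching always, persistent caching only when `persistent_cache_path` is given) can be illustrated with a generic two-tier lookup. This is a sketch under stated assumptions, not ejkernel's `ConfigCache` or `PersistentCache`: the class `TwoTierConfigCache`, the JSON file format and the string cache key are all hypothetical, and a real key would presumably encode input shapes, dtypes and device.

```python
from __future__ import annotations

import json
import os
import tempfile
from typing import Callable


class TwoTierConfigCache:
    """Hypothetical in-memory cache backed by an optional JSON file."""

    def __init__(self, persistent_path: str | None = None):
        self.memory: dict[str, dict] = {}
        self.persistent_path = persistent_path  # None -> in-memory only

    def _load_disk(self) -> dict[str, dict]:
        if self.persistent_path is None or not os.path.exists(self.persistent_path):
            return {}
        with open(self.persistent_path) as f:
            return json.load(f)

    def get_or_tune(self, key: str, tune: Callable[[], dict]) -> dict:
        # 1. Fast path: in-memory hit.
        if key in self.memory:
            return self.memory[key]
        # 2. Persistent hit (only possible when a path was given); promote to memory.
        disk = self._load_disk()
        if key in disk:
            self.memory[key] = disk[key]
            return disk[key]
        # 3. Miss in both tiers: run the expensive tuner once and store the result.
        config = tune()
        self.memory[key] = config
        if self.persistent_path is not None:
            disk[key] = config
            with open(self.persistent_path, "w") as f:
                json.dump(disk, f)
        return config


calls = []


def fake_tune() -> dict:
    calls.append(1)  # count how often the "expensive" tuning actually runs
    return {"block_q": 128, "block_k": 64}


path = os.path.join(tempfile.mkdtemp(), "kernel_cache.json")
cache = TwoTierConfigCache(path)
cache.get_or_tune("flash_attention:q=4096", fake_tune)  # tunes, writes both tiers
cache.get_or_tune("flash_attention:q=4096", fake_tune)  # in-memory hit
fresh = TwoTierConfigCache(path)  # simulates a new process
fresh.get_or_tune("flash_attention:q=4096", fake_tune)  # persistent hit
print(len(calls))  # 1
```

The persistent tier is what lets tuning results survive process restarts; without it, every new process pays the tuning cost again.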

Parameters
  • persistent_cache_path – Optional path for persistent cache storage. If None, only in-memory caching is used.

  • enable_autotuning – Whether to enable automatic performance tuning. When enabled, the executor will benchmark different implementations.

  • warmup_iterations – Number of warmup runs before timing measurements.

  • timing_iterations – Number of timing iterations to average over.

Returns

Configured Executor instance ready for kernel execution

Example

>>> # Deprecated:
>>> from ejkernel.modules.base import create_default_executor
>>> executor = create_default_executor("/tmp/kernel_cache")
>>>
>>> # Recommended replacement:
>>> from ejkernel.ops import Executor, ConfigSelectorChain, ConfigCache, AutotunePolicy, Tuner, PersistentCache
>>> executor = Executor(
...     ConfigSelectorChain(
...         cache=ConfigCache(),
...         policy=AutotunePolicy(allow_autotune=True),
...         tuner=Tuner(warmup=2, iters=5),
...         persistent=PersistentCache("/tmp/kernel_cache")
...     )
... )
>>>
>>> # Running a kernel module through the executor (q, k, v are query/key/value arrays):
>>> from ejkernel.modules import FlashAttention
>>> attn = FlashAttention()
>>> output = executor(attn, q, k, v, causal=True)
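The `warmup_iterations` and `timing_iterations` parameters (like `Tuner(warmup=2, iters=5)` in the replacement code) presumably drive a standard benchmarking loop: run each candidate a few times untimed to absorb compilation and cache effects, then time several runs and keep the candidate with the lowest mean. The following is a generic sketch of that loop, not ejkernel's `Tuner`; `autotune` and the sleep-based fake kernel are hypothetical stand-ins.

```python
from __future__ import annotations

import time
from typing import Any, Callable, Sequence


def autotune(
    run: Callable[[Any], object],
    candidates: Sequence[Any],
    warmup_iterations: int = 2,
    timing_iterations: int = 5,
) -> tuple[Any, float]:
    """Return (fastest candidate, its mean seconds per call).

    `run(candidate)` must execute the kernel once with that configuration and
    block until the result is ready.
    """
    best, best_time = None, float("inf")
    for cand in candidates:
        # Warmup runs absorb one-time costs (compilation, caches) and are not timed.
        for _ in range(warmup_iterations):
            run(cand)
        start = time.perf_counter()
        for _ in range(timing_iterations):
            run(cand)
        mean = (time.perf_counter() - start) / timing_iterations
        if mean < best_time:
            best, best_time = cand, mean
    return best, best_time


calls = []


def fake_kernel(delay: float) -> None:
    # Stand-in for a kernel whose "configuration" is simply how long it takes.
    calls.append(delay)
    time.sleep(delay)


best, mean = autotune(fake_kernel, [0.02, 0.005, 0.01])
print(best)  # 0.005
print(len(calls))  # 21 = 3 candidates * (2 warmup + 5 timed)
```

With JAX kernels, `run` should call `.block_until_ready()` on the output: JAX dispatches work asynchronously, so timing without it would mostly measure dispatch rather than execution.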
ejkernel.modules.base.detect_platform(algorithm: str, platform: Optional[Union[Platform, Literal['triton', 'pallas', 'cuda', 'xla', 'auto']]] = 'auto', maybe_pallas: bool = False) Platform

Detect the best platform for a given algorithm.

Automatically selects the optimal platform based on:
  1. Explicit platform request (if provided)

  2. Available implementations for the algorithm

  3. Current JAX backend (GPU/TPU/CPU)

  4. Platform-specific optimizations

The selection priority:
  • GPU backend + Triton available -> Triton (best GPU performance)

  • GPU backend + no Triton -> XLA

  • TPU backend -> Pallas (TPU-optimized) or XLA fallback

  • CPU backend -> XLA
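The selection priority above can be written as a small decision function. The sketch below is a hypothetical restatement of the documented rules, not `detect_platform` itself: it takes the backend as a plain string (real code could obtain it from `jax.default_backend()`), uses a made-up availability table, returns platform names instead of `Platform` enum values, ignores `maybe_pallas`, and assumes XLA is always available as the fallback.

```python
from __future__ import annotations

# Hypothetical availability table: algorithm -> platforms that implement it.
# The entries are invented for illustration; they are not ejkernel's registry.
AVAILABLE: dict[str, set[str]] = {
    "flash_attention": {"triton", "pallas", "xla"},
    "ring_attention": {"pallas", "xla"},
}


def choose_platform(
    algorithm: str,
    requested: str | None,
    backend: str,
    available: dict[str, set[str]] = AVAILABLE,
) -> str:
    """Pick a platform name following the documented priority rules.

    `backend` is "gpu", "tpu" or "cpu". XLA is assumed to be a universal fallback.
    """
    impls = available.get(algorithm, set())
    # An explicit request wins, but only if an implementation exists for it.
    if requested not in (None, "auto"):
        if requested not in impls:
            raise ValueError(f"{requested!r} is not available for {algorithm!r}")
        return requested
    # Otherwise choose by backend, preferring the specialized kernel.
    if backend == "gpu":
        return "triton" if "triton" in impls else "xla"
    if backend == "tpu":
        return "pallas" if "pallas" in impls else "xla"
    return "xla"  # CPU (and any other backend) falls back to XLA


print(choose_platform("flash_attention", "auto", "gpu"))  # triton
print(choose_platform("ring_attention", None, "gpu"))  # xla
print(choose_platform("ring_attention", "auto", "tpu"))  # pallas
```

Checking explicit requests first matches the documented `ValueError` behavior: a requested platform is never silently replaced by a different one.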

Parameters
  • algorithm – The algorithm name to check for availability (e.g., "flash_attention")

  • platform – Requested platform. Options:
      ◦ "triton": Triton GPU kernels (CUDA/ROCm)
      ◦ "pallas": Pallas kernels (TPU/GPU)
      ◦ "cuda": CUDA-specific implementations
      ◦ "xla": XLA compiler-based implementations
      ◦ "auto" or None: Automatic selection (default)

Returns

The selected Platform enum value

Raises

ValueError – If the requested platform is not available for the algorithm

Example

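>>> from ejkernel.modules.base import detect_platform
>>> # Automatic selection based on the current JAX backend:
>>> platform = detect_platform("flash_attention", platform="auto")
>>> # Explicit request; raises ValueError if Triton is not available for this algorithm:
>>> platform = detect_platform("flash_attention", platform="triton")
>>> # The platform argument defaults to "auto":
>>> print(f"Selected: {detect_platform('ring_attention')}")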

Note

Triton is preferred on NVIDIA GPUs for best performance, but XLA provides broader compatibility across hardware backends.