ejkernel.ops - Kernel Framework and Configuration Management
Overview
The ejkernel.ops package provides a sophisticated framework for implementing high-performance JAX operations with automatic configuration management, caching, and autotuning. It’s the foundation upon which ejkernel.modules operations are built.
Key Features:
Abstract kernel interface with platform-specific implementations
Multi-tier configuration selection and caching
Automatic performance autotuning
Batch processing utilities (vmap/pmap with shared config)
Custom gradient support via VJP
Architecture
```text
┌────────────────────────────────────────────────────────────┐
│                         User Code                          │
│           kernel(args) or executor(kernel, args)           │
└────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌────────────────────────────────────────────────────────────┐
│                          Executor                          │
│ Coordinates config selection, kernel dispatch, and caching │
└────────────────────────────────────────────────────────────┘
                              │
            ┌─────────────────┴─────────────────┐
            ▼                                   ▼
┌────────────────────────┐          ┌────────────────────────┐
│ ConfigSelectorChain    │          │ Kernel                 │
│                        │          │                        │
│ 1. ConfigCache         │          │ - heuristic_cfg()      │
│ 2. PersistentCache     │          │ - config_space()       │
│ 3. Heuristic           │          │ - run()                │
│ 4. Autotuning          │          │ - run_{platform}()     │
└────────────────────────┘          └────────────────────────┘
```
Kernel: The Building Block
A Kernel is an abstract base class for implementing operations. Each kernel defines:
How to run the operation (run() or platform-specific methods)
How to generate default configs (heuristic_cfg())
What configs to explore (config_space() for autotuning)
Basic Kernel Implementation
```python
from dataclasses import dataclass

import jax.numpy as jnp
from ejkernel.ops import Invocation, Kernel


# 1. Define your configuration
@dataclass(frozen=True)
class MyConfig:
    block_size: int = 128
    num_warps: int = 4
    algorithm: str = "default"


# 2. Implement the kernel
class MyKernel(Kernel[MyConfig, jnp.ndarray]):
    def __init__(self):
        super().__init__(name="my_kernel")

    def run(self, a, b, *, cfg: MyConfig, **kwargs) -> jnp.ndarray:
        """Default implementation."""
        # Your operation logic here
        return jnp.dot(a, b)

    def heuristic_cfg(self, inv: Invocation) -> MyConfig:
        """Generate a config based on input shapes."""
        # inv.args contains the input arrays
        a = inv.args[0]
        block = 128 if a.shape[0] >= 128 else 64
        return MyConfig(block_size=block)
```
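Declaring MyConfig with frozen=True makes instances immutable and, with the default eq=True, hashable. Both properties suit objects that are cached and compared. The following standard-library sketch demonstrates them. Deriving variants with dataclasses.replace is an illustrative choice here, not an ejkernel requirement.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class MyConfig:
    block_size: int = 128
    num_warps: int = 4
    algorithm: str = "default"


base = MyConfig()
# frozen=True (with the default eq=True) generates __hash__,
# so equal configs collapse to one entry in a set or dict.
unique = {base, MyConfig()}
# replace() returns a modified copy; the original stays untouched.
bigger = replace(base, block_size=256)

try:
    base.block_size = 64  # assignment on a frozen instance raises
    mutated = True
except FrozenInstanceError:
    mutated = False
```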
Platform-Specific Implementations
Kernels can have different implementations for different platforms:
```python
import jax.numpy as jnp
from ejkernel.ops import Kernel


# AttentionConfig and the triton_flash_attention / pallas_attention / xla_attention
# helpers are placeholders for your own config class and backend implementations.
class AttentionKernel(Kernel[AttentionConfig, jnp.ndarray]):
    def run_triton(self, q, k, v, *, cfg, **kwargs):
        """Triton GPU implementation."""
        return triton_flash_attention(q, k, v, **cfg.triton_params)

    def run_pallas(self, q, k, v, *, cfg, **kwargs):
        """Pallas TPU implementation."""
        return pallas_attention(q, k, v, **cfg.pallas_params)

    def run_xla(self, q, k, v, *, cfg, **kwargs):
        """XLA fallback implementation."""
        return xla_attention(q, k, v)

    def run(self, q, k, v, *, cfg, **kwargs):
        """Default fallback."""
        return self.run_xla(q, k, v, cfg=cfg, **kwargs)
```
The executor automatically selects the appropriate platform method based on available hardware.
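One way such dispatch can work is to look up a run_{platform} method by name and fall back to run() when the kernel does not define one. The sketch below assumes that strategy. It also assumes a device-to-platform mapping (gpu→triton, tpu→pallas, cpu→xla) inferred from the method docstrings above and the Device Detection helpers. The names dispatch, PLATFORM_FOR_DEVICE and ToyKernel are hypothetical, and ejkernel's executor may resolve methods differently.

```python
# Hypothetical mapping from device kind to the run_<platform> method suffix.
PLATFORM_FOR_DEVICE = {"gpu": "triton", "tpu": "pallas", "cpu": "xla"}


def dispatch(kernel, device: str, *args, cfg=None, **kwargs):
    """Call kernel.run_<platform>() if the kernel defines it, else kernel.run()."""
    platform = PLATFORM_FOR_DEVICE.get(device, "xla")
    impl = getattr(kernel, f"run_{platform}", None)
    if impl is None:
        impl = kernel.run  # generic fallback
    return impl(*args, cfg=cfg, **kwargs)


class ToyKernel:
    # Only a Triton specialization; every other platform uses run().
    def run_triton(self, x, *, cfg=None):
        return ("triton", x)

    def run(self, x, *, cfg=None):
        return ("generic", x)
```

With this sketch, dispatch(ToyKernel(), "gpu", 3) uses run_triton, while "tpu" and "cpu" fall back to run() because ToyKernel defines neither run_pallas nor run_xla.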
Custom Gradients (VJP)
Kernels can define custom backward passes:
```python
class MyKernel(Kernel[MyConfig, jnp.ndarray]):
    def run(self, a, b, *, cfg, **kwargs):
        # custom_forward, custom_grad_a and custom_grad_b are placeholders
        return custom_forward(a, b, cfg)

    def custom_vjp(self, primals, cfg, kwargs, cotangents):
        """Custom backward pass."""
        a, b = primals
        ct = cotangents[0]  # cotangent of the kernel's output
        grad_a = custom_grad_a(a, b, ct, cfg)
        grad_b = custom_grad_b(a, b, ct, cfg)
        # One gradient per primal input, in the same order as primals
        return (grad_a, grad_b)
```
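For the jnp.dot kernel defined earlier (C = A·B), the backward pass has a closed form. Given the cotangent ct = ∂L/∂C, the gradients are ∂L/∂A = ct·Bᵀ and ∂L/∂B = Aᵀ·ct, returned in the same order as the primals. The plain-Python sketch below uses lists instead of arrays, and all of its helper names are hypothetical. It checks those formulas against central finite differences of L = Σ ct ⊙ (A·B). It illustrates what a custom_vjp has to compute, not how ejkernel implements one.

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(row) for row in zip(*m)]


def matmul_vjp(a, b, ct):
    """Return (grad_a, grad_b) for C = A @ B given the cotangent ct = dL/dC."""
    grad_a = matmul(ct, transpose(b))  # dL/dA = ct @ B^T
    grad_b = matmul(transpose(a), ct)  # dL/dB = A^T @ ct
    return grad_a, grad_b


def loss(a, b, ct):
    # L = sum(ct * (A @ B)); its gradient with respect to C is exactly ct.
    c = matmul(a, b)
    return sum(ct[i][j] * c[i][j] for i in range(len(c)) for j in range(len(c[0])))


def numeric_grad(f, m, eps=1e-6):
    """Central finite differences of the scalar function f with respect to matrix m."""
    grad = [[0.0] * len(m[0]) for _ in m]
    for i in range(len(m)):
        for j in range(len(m[0])):
            plus = [row[:] for row in m]
            minus = [row[:] for row in m]
            plus[i][j] += eps
            minus[i][j] -= eps
            grad[i][j] = (f(plus) - f(minus)) / (2 * eps)
    return grad


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[0.5, -1.0], [2.0, 0.0]]
ct = [[1.0, 0.0], [0.0, 1.0]]
grad_a, grad_b = matmul_vjp(a, b, ct)
```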
Invocation: Capturing Call Context
An Invocation captures everything about a kernel call:
```python
from ejkernel.ops import Invocation

# Created automatically by the executor
inv = Invocation(
    kernel=my_kernel,
    args=(a, b),             # Positional arguments
    kwargs={'scale': 1.0},   # Keyword arguments
    metadata={...}           # Shapes, dtypes, device info
)

# Access abstracted shapes (for config lookup)
inv.abstract_args        # JAX ShapeDtypeStruct versions
inv.device_fingerprint   # Unique device identifier
```
Executor: Running Kernels
The Executor is the main interface for running kernels with automatic config selection:
```python
from ejkernel.ops import ConfigCache, ConfigSelectorChain, Executor

# Create the execution stack
cache = ConfigCache()
selector = ConfigSelectorChain(cache=cache)
executor = Executor(selector)

# Run a kernel
result = executor(my_kernel, a, b)

# Run with an explicit config
result = executor(my_kernel, a, b, cfg=my_config)
```
How Config Selection Works
When you call executor(kernel, args):
1. Check ConfigCache - fast in-memory lookup
2. Check PersistentCache - disk-based storage for cross-session persistence
3. Use Heuristic - call kernel.heuristic_cfg() for a shape-based default
4. Autotune - if enabled, benchmark multiple configs and cache the best
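The lookup order above can be sketched as a function that returns the first config it finds. In this standard-library sketch, select_config, the dict-based caches, and the autotune callback are hypothetical stand-ins. The arrangement is also an assumption: the heuristic is used when autotuning is off, and an autotuned winner is written back to both caches. ejkernel's ConfigSelectorChain may combine the steps differently. The sketch also honors an explicit cfg first, as in the executor(my_kernel, a, b, cfg=my_config) call above.

```python
from typing import Any, Callable, Hashable, Optional, Tuple


def select_config(
    key: Hashable,
    memory: dict,
    persistent: dict,
    heuristic: Callable[[], Any],
    autotune: Optional[Callable[[], Any]] = None,
    explicit_cfg: Any = None,
) -> Tuple[Any, str]:
    """Return (config, source), trying each tier in order."""
    if explicit_cfg is not None:  # a caller-supplied cfg skips selection entirely
        return explicit_cfg, "explicit"
    if key in memory:  # 1. in-memory cache
        return memory[key], "memory"
    if key in persistent:  # 2. persistent cache; promote the hit into memory
        memory[key] = persistent[key]
        return memory[key], "persistent"
    if autotune is None:  # 3. no tuning: fall back to the shape-based heuristic
        return heuristic(), "heuristic"
    cfg = autotune()  # 4. tuning enabled: benchmark candidates, cache the winner
    memory[key] = persistent[key] = cfg
    return cfg, "autotune"
```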
Configuration Selection and Caching
ConfigCache (In-Memory)
Fast lookup for repeated calls with the same shapes:
```python
from ejkernel.ops import ConfigCache

cache = ConfigCache()

# Store a config
cache.put(key, config)

# Retrieve
config = cache.get(key)  # Returns None if not found

# Check existence
if key in cache:
    ...
```
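The put/get/in contract above maps directly onto a dictionary. DictConfigCache in the following sketch is a hypothetical stand-in that mirrors that usage. The tuple key layout (kernel name, shapes, dtypes) is an assumption made for illustration, not ejkernel's key format.

```python
from typing import Dict, Generic, Hashable, Optional, TypeVar

Cfg = TypeVar("Cfg")


class DictConfigCache(Generic[Cfg]):
    """Hypothetical in-memory cache mirroring the put/get/`in` usage above."""

    def __init__(self) -> None:
        self._store: Dict[Hashable, Cfg] = {}

    def put(self, key: Hashable, config: Cfg) -> None:
        self._store[key] = config  # a later put for the same key overwrites

    def get(self, key: Hashable) -> Optional[Cfg]:
        return self._store.get(key)  # None when the key is missing

    def __contains__(self, key: Hashable) -> bool:
        return key in self._store
```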
PersistentCache (Disk-Based)
Survives across sessions:
```python
from ejkernel.ops import PersistentCache

# Stored in ~/.cache/ejkernel/ by default
persistent = PersistentCache(path="~/.cache/ejkernel/configs")

# Same interface as ConfigCache
persistent.put(key, config)
config = persistent.get(key)
```
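A disk-backed cache needs both keys and configs to survive a round trip through a file. The sketch below assumes a single JSON file, dataclass configs, and a sorted, compact JSON string as the on-disk key. JsonConfigCache and that file layout are hypothetical and are not PersistentCache's actual storage format.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class MyConfig:
    block_size: int = 128
    num_warps: int = 4


def _key_str(key) -> str:
    # Sorted, compact JSON yields the same string for equal keys in every session.
    return json.dumps(key, sort_keys=True, separators=(",", ":"))


class JsonConfigCache:
    """Hypothetical disk cache: one JSON object mapping key strings to config dicts."""

    def __init__(self, path: str) -> None:
        self.path = os.path.expanduser(path)

    def _load(self) -> dict:
        if not os.path.exists(self.path):
            return {}
        with open(self.path, encoding="utf-8") as f:
            return json.load(f)

    def put(self, key, config: MyConfig) -> None:
        data = self._load()
        data[_key_str(key)] = asdict(config)
        os.makedirs(os.path.dirname(self.path) or ".", exist_ok=True)
        with open(self.path, "w", encoding="utf-8") as f:
            json.dump(data, f)

    def get(self, key):
        raw = self._load().get(_key_str(key))
        return None if raw is None else MyConfig(**raw)


# Usage: write with one instance, read back with a fresh one (as a new session would).
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "configs.json")
    key = ["my_kernel", [[128, 64]], ["float32"]]
    JsonConfigCache(path).put(key, MyConfig(block_size=256))
    restored = JsonConfigCache(path).get(key)
    missing = JsonConfigCache(path).get(["other_kernel"])
```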
ConfigSelectorChain
Combines multiple selection strategies:
```python
from ejkernel.ops import ConfigCache, ConfigSelectorChain, PersistentCache

selector = ConfigSelectorChain(
    cache=ConfigCache(),            # Level 1: In-memory
    persistent=PersistentCache(),   # Level 2: Disk
    autotune=True,                  # Level 3: Benchmark if not found
)
```
Temporary Overrides
```python
from ejkernel.ops import overlay_cache, policy_override

# Temporarily use a different cache
with overlay_cache(my_custom_cache):
    result = executor(kernel, args)

# Temporarily change the autotuning policy
with policy_override(autotune=False):
    result = executor(kernel, args)  # No autotuning
```
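Scoped overrides like these are commonly built from a context manager that sets a value on entry and restores the previous one on exit. The sketch below assumes that pattern, implemented with contextvars. The names _AUTOTUNE, autotune_enabled and policy_override_sketch are hypothetical, and ejkernel's own policy_override may keep its state differently.

```python
import contextlib
import contextvars

# Hypothetical global policy flag; a default of True is an assumption.
_AUTOTUNE = contextvars.ContextVar("autotune", default=True)


def autotune_enabled() -> bool:
    return _AUTOTUNE.get()


@contextlib.contextmanager
def policy_override_sketch(*, autotune: bool):
    token = _AUTOTUNE.set(autotune)  # apply the override for this block only
    try:
        yield
    finally:
        _AUTOTUNE.reset(token)  # restore the previous value, even on error
```

Because each exit restores the value saved on entry, overrides nest correctly.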
Autotuning
ejkernel can automatically find the best-performing configuration among a kernel's candidates by benchmarking each one.
Defining Config Space
```python
class MyKernel(Kernel[MyConfig, jnp.ndarray]):
    def config_space(self, inv: Invocation) -> list[MyConfig]:
        """Return configs to try during autotuning."""
        return [
            MyConfig(block_size=64, num_warps=2),
            MyConfig(block_size=128, num_warps=4),
            MyConfig(block_size=256, num_warps=8),
        ]
```
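Instead of listing candidates by hand, a config space can be generated as the Cartesian product of per-parameter choices and then pruned by shape constraints. In this sketch, the candidate values and the rule that block_size must not exceed the sequence length are hypothetical. Choose values that fit your own kernel.

```python
import itertools
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class MyConfig:
    block_size: int = 128
    num_warps: int = 4


def config_space(seq_len: int) -> List[MyConfig]:
    """Cartesian product of candidate values, minus configs that cannot fit."""
    block_sizes = (64, 128, 256)
    warp_counts = (2, 4, 8)
    return [
        MyConfig(block_size=bs, num_warps=w)
        for bs, w in itertools.product(block_sizes, warp_counts)
        if bs <= seq_len  # hypothetical constraint: a block cannot exceed the sequence
    ]
```

For a 4096-token input this yields all 9 combinations, while a 100-token input keeps only the three block_size=64 variants.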
Manual Autotuning
```python
from ejkernel.ops import autotune, benchmark

# Autotune a specific kernel call (inv is the Invocation for this call)
best_config = autotune(
    kernel, a, b,
    configs=kernel.config_space(inv),
    num_warmup=3,
    num_iters=10,
)

# Simple function benchmarking
time_ms = benchmark(lambda: kernel(a, b), num_iters=100)
```
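Conceptually, autotuning runs each candidate a few times to warm up, then times it and keeps the fastest. The following sketch shows that loop using time.perf_counter on a plain Python workload. autotune_sketch and summed are hypothetical, and autotune_sketch does not reproduce ejkernel's autotune signature. When timing JAX functions, call block_until_ready() on the result before reading the clock, because JAX dispatches work asynchronously.

```python
import time
from typing import Callable, Optional, Sequence, TypeVar

C = TypeVar("C")


def autotune_sketch(
    run: Callable[[C], object],
    configs: Sequence[C],
    num_warmup: int = 3,
    num_iters: int = 10,
) -> Optional[C]:
    """Return the config with the lowest mean wall-clock time (None if empty)."""
    best_cfg, best_time = None, float("inf")
    for cfg in configs:
        for _ in range(num_warmup):  # warm caches / trigger any compilation
            run(cfg)
        start = time.perf_counter()
        for _ in range(num_iters):
            run(cfg)
        mean = (time.perf_counter() - start) / num_iters
        if mean < best_time:
            best_cfg, best_time = cfg, mean
    return best_cfg


def summed(n: int) -> int:
    # Toy workload whose cost grows with the "config" n.
    return sum(range(n))
```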
Batch Autotuning
Autotune all recorded invocations at once:
```python
from ejkernel.ops import autotune_recorded, get_invocations, record_invocation

# Enable recording
record_invocation(kernel, inv)

# ... run your model with various inputs ...

# Get all recorded invocations
invocations = get_invocations()

# Autotune all at once
results = autotune_recorded(invocations)
```
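Batch autotuning pays off when many recorded calls share the same structure, because each distinct cache key needs tuning only once. The sketch below assumes each invocation is recorded as a (kernel name, shapes, dtypes) tuple. InvocationRecorder and its methods are hypothetical and are not ejkernel's registry API.

```python
class InvocationRecorder:
    """Hypothetical registry that keeps one entry per distinct cache key."""

    def __init__(self) -> None:
        self._counts: dict = {}

    def record(self, kernel_name: str, shapes: tuple, dtypes: tuple) -> None:
        key = (kernel_name, shapes, dtypes)
        # Repeated calls with the same structure only bump a counter.
        self._counts[key] = self._counts.get(key, 0) + 1

    def unique_keys(self) -> list:
        return list(self._counts)  # dicts preserve first-insertion order

    def tune_all(self, tune) -> dict:
        # One tuning run per distinct key, however many times it was recorded.
        return {key: tune(key) for key in self._counts}
```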
Batch Processing
vmap_with_config
Vectorized execution with shared config selection (avoids re-selecting a config for each batch element):
```python
from ejkernel.ops import vmap_with_config

# Standard vmap would select a config per element;
# vmap_with_config selects once and broadcasts
batched_result = vmap_with_config(
    kernel,
    in_axes=(0, 0),  # Batch over the first axis
)(batched_a, batched_b)
```
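Selecting once is safe because every element of a vmapped batch has the same per-element shape, so a single config fits them all. The following plain-Python sketch counts selector calls to make the difference visible. map_per_element, map_with_shared_config and the toy select and fn are hypothetical. Under jax.vmap, JAX performs the batching itself rather than a Python loop.

```python
from typing import Callable, Sequence

calls = {"n": 0}


def map_per_element(select: Callable, fn: Callable, batch: Sequence) -> list:
    # Naive version: re-run config selection for every element.
    return [fn(x, cfg=select(x)) for x in batch]


def map_with_shared_config(select: Callable, fn: Callable, batch: Sequence) -> list:
    # Elements share one shape, so select from the first and reuse the result.
    if not batch:
        return []
    cfg = select(batch[0])
    return [fn(x, cfg=cfg) for x in batch]


def select(x):
    calls["n"] += 1  # count how often selection runs
    return {"block_size": 64 if len(x) < 128 else 128}


def fn(x, *, cfg):
    return (len(x), cfg["block_size"])
```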
pmap_with_config
Parallel execution across devices:
```python
from ejkernel.ops import pmap_with_config

# Run the kernel across multiple devices
parallel_result = pmap_with_config(
    kernel,
    axis_name='devices',
)(sharded_a, sharded_b)
```
Configuration Parameters
FwdParams and BwdParams
Standard parameter containers for attention kernels:
```python
from ejkernel.ops import BwdParams, FwdParams

fwd = FwdParams(
    q_blocksize=128,    # Query block size
    kv_blocksize=256,   # Key-value block size
    num_warps=4,        # GPU warps
    num_stages=2,       # Pipeline stages
)

bwd = BwdParams(
    q_blocksize=64,
    kv_blocksize=128,
)
```
Utility Functions
Device Detection
```python
from ejkernel.ops import device_fingerprint, device_kind, get_device_platform

# Get device type
device_kind()           # Returns "gpu", "tpu", or "cpu"

# Get platform for dispatch
get_device_platform()   # Returns "triton", "pallas", or "xla"

# Unique device identifier (for caching)
fingerprint = device_fingerprint()
```
Sharding Utilities
```python
from ejkernel.ops import default_key_builder_with_sharding, sharding_fingerprint

# Get a fingerprint of an array's sharding
fingerprint = sharding_fingerprint(array)

# Build a cache key that includes sharding info
key = default_key_builder_with_sharding(kernel, args, kwargs)
```
Serialization
```python
from ejkernel.ops import from_json, stable_json, to_json

# Serialize a config to JSON
json_str = to_json(config)

# Deserialize
config = from_json(json_str, MyConfig)

# Stable JSON (deterministic ordering for cache keys)
stable_str = stable_json(config)
```
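Stable serialization matters because a cache key must be byte-identical no matter how the underlying dict was built. The following standard-library sketch illustrates the idea. stable_json_sketch is a hypothetical helper, and its output format may differ from ejkernel's stable_json.

```python
import json
from dataclasses import asdict, dataclass, is_dataclass


def stable_json_sketch(obj) -> str:
    """Deterministic JSON: sorted keys and no incidental whitespace."""
    if is_dataclass(obj) and not isinstance(obj, type):
        obj = asdict(obj)  # dataclass instances become plain dicts first
    return json.dumps(obj, sort_keys=True, separators=(",", ":"))


@dataclass(frozen=True)
class MyConfig:
    block_size: int = 128
    num_warps: int = 4
```

Two dicts with the same items in different insertion order serialize to the same string, and the string round-trips back to an equal config through json.loads.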
HLO Analysis
```python
import jax
from ejkernel.ops import (
    extract_labels_from_hlo_text,
    find_labels_in_lowered,
    labels_to_configs,
)

# Find labeled operations in lowered JAX code (fn and args are your function and inputs)
lowered = jax.jit(fn).lower(*args)
labels = find_labels_in_lowered(lowered)

# Extract from HLO text
labels = extract_labels_from_hlo_text(hlo_text)

# Convert labels to configs
configs = labels_to_configs(labels)
```
Labeling Operations
```python
from ejkernel.ops import label


# Add a label for HLO analysis
@label("my_attention")
def my_attention(q, k, v):
    return attention_impl(q, k, v)
```
Best Practices
1. Use the Module API for Common Operations
For standard attention operations, use ejkernel.modules, which handles all the complexity:
```python
# Recommended for most users
from ejkernel.modules import flash_attention

output = flash_attention(q, k, v, causal=True)
```
2. Implement Kernel Only When Needed
Create custom kernels only for:
New operations not in ejkernel.modules
Operations needing special configuration
Performance-critical custom implementations
3. Define Good Heuristics
Good heuristic configs reduce the need for autotuning:
```python
def heuristic_cfg(self, inv):
    # Consider input shapes
    q = inv.args[0]
    batch, seq_len, heads, dim = q.shape

    # Larger sequences benefit from larger blocks
    if seq_len >= 2048:
        block = 256
    elif seq_len >= 512:
        block = 128
    else:
        block = 64
    return MyConfig(block_size=block)
```
4. Leverage Caching
Configs are cached automatically. Cache keys are built from the shapes and dtypes of the arguments, not their values, so repeated calls with the same input structure reuse the cached config.
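The rule above can be sketched as a key builder that reads only .shape and .dtype from array arguments. FakeArray and build_key are hypothetical. Including sorted keyword arguments in the key is an assumption (default_key_builder_with_sharding also receives kwargs), and ejkernel's real keys additionally carry device and sharding information.

```python
from dataclasses import dataclass


@dataclass
class FakeArray:
    """Stand-in for a jax.Array: only .shape and .dtype matter for the key."""
    values: list
    shape: tuple
    dtype: str


def build_key(kernel_name: str, args: tuple, kwargs: dict) -> tuple:
    """Hypothetical key builder: array structure plus static keyword arguments."""
    arg_sig = tuple((a.shape, a.dtype) for a in args)  # values are ignored
    kw_sig = tuple(sorted(kwargs.items()))  # sorted so kwarg order is irrelevant
    return (kernel_name, arg_sig, kw_sig)
```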
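5. Profile Before Optimizing
Use JAX's built-in profiler before manual autotuning:
```python
import jax

with jax.profiler.trace("/tmp/jax-trace"):
    result = executor(kernel, args)
    jax.block_until_ready(result)  # wait for async dispatch so the work lands in the trace
```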
Type Variables
| Type Variable | Description |
|---|---|
| Cfg | Configuration type (e.g., MyConfig) |
| Out | Output type (e.g., jnp.ndarray) |
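Cfg and Out behave like ordinary typing.TypeVar parameters, which is what lets a kernel be declared as Kernel[MyConfig, jnp.ndarray]. The following standard-library sketch shows the pattern. TinyKernel, ScaleConfig and DoubleKernel are hypothetical, and the real Kernel base class has more methods and a different call path.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar

Cfg = TypeVar("Cfg")
Out = TypeVar("Out")


class TinyKernel(ABC, Generic[Cfg, Out]):
    """Hypothetical base class parameterized by config and output types."""

    @abstractmethod
    def run(self, *args, cfg: Cfg) -> Out: ...

    @abstractmethod
    def heuristic_cfg(self, *args) -> Cfg: ...

    def __call__(self, *args) -> Out:
        # Without an explicit config, fall back to the heuristic.
        return self.run(*args, cfg=self.heuristic_cfg(*args))


@dataclass(frozen=True)
class ScaleConfig:
    factor: int = 2


class DoubleKernel(TinyKernel[ScaleConfig, list]):
    def run(self, xs, *, cfg: ScaleConfig) -> list:
        return [x * cfg.factor for x in xs]

    def heuristic_cfg(self, xs) -> ScaleConfig:
        return ScaleConfig()
```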
Module Structure
ejkernel/ops/
├── __init__.py # Public API exports
├── core/
│ ├── kernel.py # Kernel, Invocation
│ └── types.py # Cfg, Out type variables
├── config/
│ ├── cache.py # ConfigCache, overlay_cache
│ ├── persistent.py # PersistentCache
│ └── selection.py # ConfigSelectorChain, AutotunePolicy
├── execution/
│ ├── executor.py # Executor, ConfigChooser
│ ├── batch.py # vmap_with_config, pmap_with_config
│ └── tuning.py # Autotuner, benchmark utilities
├── registry.py # Invocation recording
└── utils/ # Utility functions