Ops System Architecture Analysis#
Overview#
The ops system provides the core infrastructure for kernel execution in ejKernel. It implements multi-tier configuration management with autotuning, in-memory and on-disk caching, and device-aware execution. It orchestrates the full lifecycle of a kernel invocation, from configuration selection to execution with custom gradients.
Architecture Components#
ejkernel/ops/
├── config/              # Configuration management
│   ├── selection.py     # Multi-tier config selection
│   ├── cache.py         # In-memory caching
│   └── persistent.py    # Disk-based persistence
├── core/                # Base kernel classes
│   └── kernel.py        # Kernel base class and invocation
├── execution/           # Execution orchestration
│   └── executor.py      # Main executor implementation
└── utils/               # Utilities
    ├── fingerprint.py   # Device fingerprinting
    └── datacarrier.py   # Configuration data structures
Core Components#
1. Kernel Base Class#
The Kernel[Cfg, Out] generic class provides the foundation for all operations:
from typing import Any, Generic, Iterable, TypeVar

Cfg = TypeVar("Cfg")  # Configuration type
Out = TypeVar("Out")  # Output type

class Kernel(Generic[Cfg, Out]):
    """Base class for all kernel operations"""

    def __init__(self, op_id: str):
        self.op_id = op_id

    # Required methods
    def run(self, *args, cfg: Cfg, **kwargs) -> Out:
        """Core operation execution with configuration"""
        raise NotImplementedError

    def heuristic_cfg(self, inv: "Invocation[Cfg, Out]") -> Cfg:
        """Default configuration for given invocation"""
        raise NotImplementedError

    # Optional methods
    def prepare(self, *args, **kwargs) -> tuple[tuple, dict]:
        """Preprocess arguments before execution"""
        return args, kwargs

    def candidate_cfgs(self, inv: "Invocation[Cfg, Out]") -> Iterable[Cfg]:
        """Configurations for autotuning"""
        return []

    # Custom gradient support
    def fwd_with_residuals(self, *args, cfg: Cfg, **kwargs) -> tuple[Out, Any]:
        """Forward pass with residuals for custom VJP"""
        output = self.run(*args, cfg=cfg, **kwargs)
        return output, None

    def vjp(self, residuals: Any, y: Out, dy: Any, *args, cfg: Cfg, **kwargs):
        """Backward pass for custom gradients"""
        raise NotImplementedError
2. Platform-Specific Method Dispatch#
The system supports hierarchical method dispatch for platform-specific optimizations:
# Method resolution order (most specific to least specific):
# 1. run_shard_map_gpu (context + platform)
# 2. run_shard_map (context only)
# 3. run_gpu (platform only)
# 4. run (generic fallback)
class MyKernel(Kernel):
    def run(self, x, cfg):
        """Generic implementation"""
        return generic_impl(x, cfg)

    def run_gpu(self, x, cfg):
        """GPU-optimized implementation"""
        return gpu_optimized_impl(x, cfg)

    def run_shard_map(self, x, cfg):
        """Distributed implementation"""
        return distributed_impl(x, cfg)

    def run_shard_map_gpu(self, x, cfg):
        """Distributed GPU-optimized implementation"""
        return distributed_gpu_impl(x, cfg)
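The resolution order above can be sketched as a small lookup function. This is a minimal illustration, not the library's implementation: `resolve_method_name` and `GpuOnlyKernel` are hypothetical names introduced here, and the kernel is a plain class rather than a `Kernel` subclass so that the snippet runs on its own.

```python
def resolve_method_name(kernel, context, platform):
    """Return the most specific run_* method name defined on kernel."""
    candidates = []
    if context and platform:
        candidates.append(f"run_{context}_{platform}")  # context + platform
    if context:
        candidates.append(f"run_{context}")             # context only
    if platform:
        candidates.append(f"run_{platform}")            # platform only
    candidates.append("run")                            # generic fallback
    for name in candidates:
        if callable(getattr(kernel, name, None)):
            return name
    raise AttributeError("kernel defines no run method")


class GpuOnlyKernel:
    """Hypothetical kernel with a generic and a GPU implementation."""

    def run(self, x):
        return ("generic", x)

    def run_gpu(self, x):
        return ("gpu", x)


kernel = GpuOnlyKernel()
print(resolve_method_name(kernel, "shard_map", "gpu"))  # run_gpu
print(resolve_method_name(kernel, "shard_map", "tpu"))  # run
print(resolve_method_name(kernel, None, "gpu"))         # run_gpu
```

A kernel that only overrides `run_gpu` therefore still works under `shard_map` on GPU: the lookup skips the two context-specific names and lands on the platform-specific one.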
3. Invocation Dataclass#
Captures complete execution context:
@dataclasses.dataclass(frozen=True)
class Invocation(Generic[Cfg, Out]):
    """Complete execution context for a kernel invocation"""
    # Core identification
    op_id: str                                   # Kernel identifier
    args: tuple[Any, ...]                        # Positional arguments
    kwargs: Mapping[str, Any]                    # Keyword arguments
    # Configuration
    override_cfg: Cfg | None = None              # Explicit configuration override
    # Execution context
    method: str | None = None                    # e.g., "shard_map"
    stamp: bool = True                           # Enable profiling metadata
    batch_axes: Mapping[str, int] | None = None  # Batching information
    # Distributed execution
    mesh: jax.sharding.Mesh | None = None
    in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None
    out_specs: jax.sharding.PartitionSpec | None = None
    # In JAX's shard_map, check_vma toggles the varying-manual-axes check
    # (it is not a memory-alignment check)
    check_vma: bool = False

    @property
    def call_key(self) -> str:
        """16-character hash for caching based on arg shapes/types"""
        return short_hash(abstractify((self.args, self.kwargs)))

    @property
    def versioned_op_id(self) -> str:
        """op_id with version suffix for cache invalidation"""
        return f"{self.op_id}@v1"
Configuration Selection System#
ConfigSelectorChain#
Implements a 7-tier fallback hierarchy for configuration selection: six configuration sources tried in priority order, then an error if none of them yields a configuration:
class ConfigSelectorChain(Generic[Cfg]):
    """Multi-tier configuration selection with fallback"""

    def __init__(self,
                 cache: ConfigCache[Cfg] | None = None,
                 policy: AutotunePolicy | None = None,
                 tuner: Tuner[Cfg] | None = None,
                 persistent: PersistentCache[Cfg] | None = None,
                 overlay: dict | None = None):
        self.cache = cache or ConfigCache()
        self.policy = policy or AutotunePolicy()
        self.tuner = tuner or Tuner()
        self.persistent = persistent
        self.overlay = overlay or {}

    def choose(self, inv: Invocation[Cfg, Out], kernel: Kernel[Cfg, Out]) -> Cfg:
        """
        Configuration selection priority:
        1. Override: Explicit cfg in invocation
        2. Overlay: Temporary context-specific configurations
        3. Memory Cache: Fast in-memory lookup
        4. Persistent Cache: Disk-based storage
        5. Autotune: Benchmark candidate configurations
        6. Heuristics: Kernel-provided defaults
        7. Error: No configuration available
        """
Selection Flow#
def choose(self, inv: Invocation, kernel: Kernel) -> Cfg:
    device = device_fingerprint()
    # 1. Override - highest priority
    if inv.override_cfg is not None:
        return inv.override_cfg
    # 2. Overlay - temporary overrides
    cache_key = (device, inv.versioned_op_id, inv.call_key)
    if cache_key in self.overlay:
        return self.overlay[cache_key]
    # 3. Memory Cache - fast lookup
    cfg = self.cache.get(*cache_key)
    if cfg is not None:
        return cfg
    # 4. Persistent Cache - disk storage
    if self.persistent is not None:
        cfg = self.persistent.get(*cache_key)
        if cfg is not None:
            # Promote to the memory cache for later calls
            self.cache.put(*cache_key, cfg)
            return cfg
    # 5. Autotune - benchmark candidates
    if self.policy.allow_autotune:
        candidates = get_candidates(kernel, inv, get_device_platform())
        if candidates:
            # One callable per candidate (sketch: binds cfg to kernel.run)
            make_fn = lambda c: functools.partial(kernel.run, cfg=c)
            best_cfg = self.tuner.autotune(
                make_fn, inv.args, dict(inv.kwargs), candidates)
            # Cache the result, unless every candidate failed
            if best_cfg is not None:
                self.cache.put(*cache_key, best_cfg)
                if self.persistent is not None:
                    self.persistent.put(*cache_key, best_cfg)
                return best_cfg
    # 6. Heuristics - default configuration
    if self.policy.allow_heuristics:
        return kernel.heuristic_cfg(inv)
    # 7. Error - no configuration available
    raise ValueError(f"No configuration available for {inv.op_id}")
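The fallback order can be exercised without JAX using a toy stand-in. The sketch below is an illustration under stated assumptions: `ToySelector` is a hypothetical class that models the overlay, memory, and persistent tiers as plain dicts, omits the autotune tier, and returns the name of the tier that answered so the promotion from persistent to memory cache is visible.

```python
from dataclasses import dataclass, field


@dataclass
class ToySelector:
    """Hypothetical stand-in for ConfigSelectorChain using plain dicts."""
    overlay: dict = field(default_factory=dict)
    memory: dict = field(default_factory=dict)
    persistent: dict = field(default_factory=dict)
    allow_heuristics: bool = True

    def choose(self, key, override=None, heuristic=None):
        """Return (cfg, tier) so the caller can see which tier answered."""
        if override is not None:
            return override, "override"
        if key in self.overlay:
            return self.overlay[key], "overlay"
        if key in self.memory:
            return self.memory[key], "memory"
        if key in self.persistent:
            cfg = self.persistent[key]
            self.memory[key] = cfg  # promote so the next lookup skips disk
            return cfg, "persistent"
        if self.allow_heuristics and heuristic is not None:
            return heuristic(), "heuristics"
        raise ValueError(f"No configuration available for {key}")


# Keys mirror the (device fingerprint, versioned op id, call key) tuple
key = ("gpu|cuda_12.0", "matmul@v1", "abc123")
selector = ToySelector(persistent={key: {"block": 128}})

first = selector.choose(key)                           # served from "disk"
second = selector.choose(key)                          # served from memory
third = selector.choose(key, override={"block": 64})   # override wins
print(first[1], second[1], third[1])
```

The first lookup is answered by the persistent tier and the second by the memory tier, which is the promotion step performed in tier 4 of the selection flow.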
Tuner Class#
Benchmarks candidate configurations and returns the fastest:
class Tuner(Generic[Cfg]):
    """Performance benchmarking for configuration selection"""

    def __init__(self, warmup: int = 5, iters: int = 100):
        self.warmup = warmup
        self.iters = iters

    def measure(self, fn: Callable, *args, **kwargs) -> float:
        """
        Measure execution time with proper handling of JAX specifics:
        - Deep-flattens args/kwargs to separate arrays from constants
        - Replaces JAX tracers with concrete placeholder arrays
        - Supports backward pass validation
        - Falls back gracefully for non-transformable functions
        """
        # Flatten to separate arrays from constants
        leaves, treedef = tree_flatten((args, kwargs))
        array_idx = [i for i, x in enumerate(leaves) if isinstance(x, Array)]
        # A tracer has no concrete value (np.array(tracer) raises), so this
        # sketch substitutes zeros of the same shape and dtype
        arrays = [jnp.zeros(leaves[i].shape, leaves[i].dtype)
                  if isinstance(leaves[i], jax.core.Tracer) else leaves[i]
                  for i in array_idx]

        # Rebuild args/kwargs inside the call so constants stay static
        def call(*arrs):
            full = list(leaves)
            for i, arr in zip(array_idx, arrs):
                full[i] = arr
            call_args, call_kwargs = tree_unflatten(treedef, full)
            return fn(*call_args, **call_kwargs)

        # jax.jit is lazy: tracing errors only surface on the first call
        runner = jax.jit(call)
        try:
            jax.block_until_ready(runner(*arrays))
        except Exception:
            runner = call  # Fall back to the un-jitted function
        # Warmup runs
        for _ in range(self.warmup):
            jax.block_until_ready(runner(*arrays))
        # Timed runs (block_until_ready also handles pytree outputs)
        start = time.perf_counter()
        for _ in range(self.iters):
            jax.block_until_ready(runner(*arrays))
        end = time.perf_counter()
        return (end - start) / self.iters

    def autotune(self, make_fn: Callable[[Cfg], Callable],
                 args: tuple, kwargs: dict,
                 candidates: list[Cfg]) -> Cfg | None:
        """Benchmark all candidates and return the fastest (None if all fail)"""
        best_cfg = None
        best_time = float('inf')
        for cfg in candidates:
            try:
                fn = make_fn(cfg)
                # Named `elapsed` to avoid shadowing the `time` module
                elapsed = self.measure(fn, *args, **kwargs)
                if elapsed < best_time:
                    best_time = elapsed
                    best_cfg = cfg
                if os.getenv("EJKERNEL_LOG_AUTOTUNE"):
                    print(f"Config {cfg}: {elapsed:.6f}s")
            except Exception as e:
                if os.getenv("EJKERNEL_LOG_AUTOTUNE"):
                    print(f"Config {cfg} failed: {e}")
        return best_cfg
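The warmup-then-time loop and the skip-on-failure behaviour of `autotune` can be tried in plain Python. This is a simplified sketch under stated assumptions: it uses no JAX, so there is nothing to JIT or block on, and `make_fn` is a hypothetical workload whose "configuration" is a chunk size.

```python
import time


def measure(fn, warmup=2, iters=20):
    """Average wall-clock seconds per call, after untimed warmup calls."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters


def autotune(make_fn, candidates):
    """Return the fastest candidate, skipping ones that raise."""
    best_cfg, best_time = None, float("inf")
    for cfg in candidates:
        try:
            elapsed = measure(make_fn(cfg))
        except Exception:
            continue  # an invalid configuration must not abort tuning
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    return best_cfg


def make_fn(block):
    """Hypothetical workload: chunked sums over a list, chunk size = block."""
    if block <= 0:
        raise ValueError("block must be positive")
    data = list(range(2_000))

    def run():
        return [sum(data[i:i + block]) for i in range(0, len(data), block)]

    return run


best = autotune(make_fn, [0, 16, 256])  # 0 is rejected by make_fn
print("fastest block size:", best)
```

Which of the valid candidates wins depends on the machine, so the example prints the result rather than asserting a particular value.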
Backward Pass Validation#
def validate_backward(fn, args, kwargs):
    """Validate gradient computation"""
    if not kwargs.get("_ejk_validate_backward", False):
        return True
    try:
        # Extract array leaves for gradient computation
        # (treedef is needed below to rebuild args/kwargs)
        leaves, treedef = tree_flatten((args, kwargs))
        array_indices = [i for i, x in enumerate(leaves)
                         if isinstance(x, Array)]

        # Create value_and_grad function
        def wrapped(*array_args):
            # Rebuild full args/kwargs
            full_leaves = list(leaves)
            for i, idx in enumerate(array_indices):
                full_leaves[idx] = array_args[i]
            call_args, call_kwargs = tree_unflatten(treedef, full_leaves)
            return fn(*call_args, **call_kwargs).sum()

        value_and_grad_fn = jax.value_and_grad(
            wrapped, argnums=tuple(range(len(array_indices))))
        # Test gradient computation
        array_args = [leaves[i] for i in array_indices]
        _, grads = value_and_grad_fn(*array_args)
        return True
    except Exception:
        return False
Executor#
The main orchestrator for kernel execution:
class Executor(Generic[Cfg, Out]):
    """Orchestrates complete kernel execution pipeline"""

    def __init__(self, chooser: ConfigSelectorChain[Cfg]):
        self.chooser = chooser

    def __call__(self, kernel: Kernel[Cfg, Out], *args,
                 cfg: Cfg | None = None,
                 stamp: bool = True,
                 method: str | None = None,
                 mesh: Mesh | None = None,
                 in_specs: tuple | None = None,
                 out_specs: Any | None = None,
                 check_vma: bool = False,
                 **kwargs) -> Out:
        """
        Complete execution flow:
        1. Preprocess arguments via kernel.prepare()
        2. Create Invocation with metadata
        3. Select configuration via chooser.choose()
        4. Setup custom VJP if kernel implements it
        5. Add profiling metadata if stamp=True
        6. Execute with chosen configuration
        7. Record invocation if EJKERNEL_OPS_RECORD=1
        """
Execution Flow#
def execute(self, kernel, *args, cfg=None, stamp=True, method=None,
            mesh=None, in_specs=None, out_specs=None, check_vma=False,
            **kwargs):
    # 1. Preprocessing
    args, kwargs = kernel.prepare(*args, **kwargs)
    # 2. Create invocation context
    # Pop the override first so "_cfg" is not stored in the invocation
    # kwargs; an explicit cfg keyword takes precedence over "_cfg"
    override_cfg = kwargs.pop("_cfg", None)
    if cfg is not None:
        override_cfg = cfg
    inv = Invocation(
        op_id=kernel.op_id,
        args=args,
        kwargs=kwargs,
        override_cfg=override_cfg,
        method=method,
        stamp=stamp,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=out_specs,
        check_vma=check_vma
    )
    # 3. Configuration selection
    cfg = self.chooser.choose(inv, kernel)
    # 4. Determine execution method
    platform = get_device_platform()
    context = inv.method or "default"
    # Try methods in order of specificity; "run" always exists on Kernel
    for method_name in (f"run_{context}_{platform}", f"run_{context}",
                        f"run_{platform}", "run"):
        if hasattr(kernel, method_name):
            break
    # 5. Custom VJP setup
    if _has_custom_vjp(kernel, platform, context):
        runner = create_custom_vjp_wrapper(kernel, cfg, platform, context)
    else:
        runner = getattr(kernel, method_name)
    # 6. Add profiling metadata
    if stamp:
        runner = add_profiling_metadata(runner, inv, cfg)
    # 7. Execute
    output = runner(*args, cfg=cfg, **kwargs)
    # 8. Record invocation (only when EJKERNEL_OPS_RECORD=1)
    if os.getenv("EJKERNEL_OPS_RECORD") == "1":
        record_invocation(inv, cfg, output)
    return output
Custom VJP Integration#
def create_custom_vjp_wrapper(kernel, cfg, platform, context):
    """Create custom VJP wrapper for gradient computation"""
    def runner(*args, cfg=cfg, **kwargs):
        # Only positional arguments are differentiated in this sketch;
        # kwargs are closed over and treated as constants
        @jax.custom_vjp
        def wrapped(*array_args):
            # Forward pass
            return kernel.fwd_with_residuals(*array_args, cfg=cfg, **kwargs)[0]

        def fwd(*array_args):
            output, residuals = kernel.fwd_with_residuals(*array_args, cfg=cfg, **kwargs)
            # Save the output as well: kernel.vjp receives it as y
            return output, (residuals, output, array_args)

        def bwd(ctx, dy):
            residuals, output, array_args = ctx
            grads = tuple(kernel.vjp(residuals, output, dy, *array_args, cfg=cfg, **kwargs))
            # bwd must return one cotangent per positional argument;
            # pad with None when the kernel returns fewer gradients
            return grads + (None,) * (len(array_args) - len(grads))

        wrapped.defvjp(fwd, bwd)
        return wrapped(*args)
    return runner
Caching System#
In-Memory Cache#
class ConfigCache(Generic[Cfg]):
    """Thread-safe in-memory configuration cache"""

    def __init__(self):
        self._cache: dict[tuple[str, str, str], Cfg] = {}
        self._lock = threading.RLock()

    def get(self, dev: str, op_id: str, call_key: str) -> Cfg | None:
        """Thread-safe cache lookup"""
        with self._lock:
            return self._cache.get((dev, op_id, call_key))

    def put(self, dev: str, op_id: str, call_key: str, cfg: Cfg):
        """Thread-safe cache insertion"""
        with self._lock:
            self._cache[(dev, op_id, call_key)] = cfg

    def clear(self):
        """Clear all cached entries"""
        with self._lock:
            self._cache.clear()

    def size(self) -> int:
        """Number of cached entries"""
        with self._lock:
            return len(self._cache)
Persistent Cache#
class PersistentCache(Generic[Cfg]):
    """Disk-based configuration persistence"""

    def __init__(self, opname: str, path: str | None = None,
                 loader: Callable | None = None,
                 dumper: Callable | None = None,
                 cfg_type: type | None = None):
        """
        Args:
            opname: Operation name for default path
            path: Custom cache file path
            loader: Custom deserialization function
            dumper: Custom serialization function
            cfg_type: Configuration type for automatic ser/deser
        """
        if path is None:
            # Default: ~/ejkernel-persistent-cache/{opname}.json
            cache_dir = Path.home() / "ejkernel-persistent-cache"
            cache_dir.mkdir(parents=True, exist_ok=True)
            self.path = cache_dir / f"{opname}.json"
        else:
            self.path = Path(path)
        self.loader = loader or self._default_loader
        self.dumper = dumper or self._default_dumper
        self.cfg_type = cfg_type
        self._lock = threading.RLock()
        self._data = self._load()

    def _default_loader(self, data: dict) -> Cfg:
        """Default deserialization"""
        if self.cfg_type:
            if dataclasses.is_dataclass(self.cfg_type):
                return self.cfg_type(**data)
            elif hasattr(self.cfg_type, 'model_validate'):  # Pydantic
                return self.cfg_type.model_validate(data)
        return data

    def _default_dumper(self, cfg: Cfg) -> dict:
        """Default serialization"""
        if dataclasses.is_dataclass(cfg):
            return dataclasses.asdict(cfg)
        elif hasattr(cfg, 'model_dump'):  # Pydantic
            return cfg.model_dump()
        return cfg

    def _load(self) -> dict:
        """Load cache from disk"""
        if self.path.exists():
            try:
                with open(self.path, 'r') as f:
                    return json.load(f)
            except (OSError, json.JSONDecodeError):
                # Unreadable or corrupt file: start with an empty cache
                return {}
        return {}

    def _save(self):
        """Atomic save to disk"""
        with tempfile.NamedTemporaryFile('w', dir=self.path.parent,
                                         delete=False) as tmp:
            json.dump(self._data, tmp, indent=2)
            tmp.flush()
            os.fsync(tmp.fileno())
        # Replace only after the temp file is closed
        os.replace(tmp.name, self.path)  # Atomic on POSIX

    def get(self, dev: str, op_id: str, call_key: str) -> Cfg | None:
        """Retrieve configuration from the cache loaded from disk"""
        with self._lock:
            key = f"{dev}|{op_id}|{call_key}"
            data = self._data.get(key)
            if data is not None:
                return self.loader(data)
            return None

    def put(self, dev: str, op_id: str, call_key: str, cfg: Cfg):
        """Store configuration to disk"""
        with self._lock:
            key = f"{dev}|{op_id}|{call_key}"
            self._data[key] = self.dumper(cfg)
            self._save()
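The write-to-temp-then-replace step used by `_save` can be tried in isolation. The sketch below is a minimal illustration: `atomic_write_json` and `make_key` are hypothetical helper names, the cache file lives in a throwaway temporary directory, and the `.tmp` suffix is an assumption made here so leftover temp files are easy to spot.

```python
import json
import os
import tempfile
from pathlib import Path


def atomic_write_json(path: Path, data: dict) -> None:
    """Write data as JSON so readers see the old or the new file, never a mix."""
    path.parent.mkdir(parents=True, exist_ok=True)
    # The temp file must live in the target directory: os.replace is only
    # atomic when source and destination are on the same filesystem
    with tempfile.NamedTemporaryFile("w", dir=path.parent, delete=False,
                                     suffix=".tmp") as tmp:
        json.dump(data, tmp, indent=2)
        tmp.flush()
        os.fsync(tmp.fileno())
    os.replace(tmp.name, path)


def make_key(dev: str, op_id: str, call_key: str) -> str:
    """Same 'dev|op_id|call_key' layout as the persistent cache keys."""
    return f"{dev}|{op_id}|{call_key}"


with tempfile.TemporaryDirectory() as d:
    cache_path = Path(d) / "demo_op.json"
    entries = {make_key("gpu|cuda_12.0", "demo_op@v1", "0123456789abcdef"):
               {"num_warps": 4}}
    atomic_write_json(cache_path, entries)
    loaded = json.loads(cache_path.read_text())
    leftovers = [p.name for p in Path(d).iterdir() if p.suffix == ".tmp"]

print(loaded == entries, leftovers)
```

After a successful write the temp file has been renamed onto the target, so the directory contains only the cache file.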
Utility Functions#
Device Fingerprinting#
def device_fingerprint(dev: Device | None = None) -> str:
    """
    Generate unique device identifier

    Returns:
        'device_kind|platform_version'
    Examples:
        'gpu|cuda_12.0'
        'tpu|v4'
        'cpu|'
    """
    if dev is None:
        dev = jax.devices()[0]
    device_kind = dev.platform
    platform_version = ""
    if device_kind == "gpu":
        # Get CUDA/ROCm version
        platform_version = get_cuda_version()
    elif device_kind == "tpu":
        # Get TPU generation
        platform_version = get_tpu_version()
    return f"{device_kind}|{platform_version}"

def get_device_platform(dev: Device | None = None) -> str:
    """Get device platform: 'gpu', 'tpu', 'cpu', or 'unknown'"""
    if dev is None:
        dev = jax.devices()[0]
    return dev.platform
Stable Hashing#
def abstractify(pytree: Any) -> Any:
    """Convert arrays to ShapeDtypeStruct for caching"""
    def _abstractify_leaf(x):
        if isinstance(x, Array):
            return jax.ShapeDtypeStruct(x.shape, x.dtype)
        return x
    return tree_map(_abstractify_leaf, pytree)

def short_hash(obj: Any) -> str:
    """Generate 16-character hash for object"""
    json_str = stable_json(obj)
    return hashlib.sha256(json_str.encode()).hexdigest()[:16]

def stable_json(obj: Any) -> str:
    """
    Deterministic JSON serialization

    Handles:
    - Functions: (module, name, source_position)
    - functools.partial: (func, args, kwargs)
    - Dataclasses: asdict conversion
    - Pydantic models: model_dump
    - JAX/NumPy arrays: (shape, dtype)
    - Primitives: direct serialization
    """
    # Check partial before callable(): partial objects are callable but
    # have no __name__, so the generic function branch would fail on them
    if isinstance(obj, functools.partial):
        return json.dumps({
            "type": "partial",
            "func": stable_json(obj.func),
            "args": [stable_json(arg) for arg in obj.args],
            "kwargs": {k: stable_json(v) for k, v in obj.keywords.items()}
        }, sort_keys=True)
    elif callable(obj):
        return json.dumps({
            "type": "function",
            "module": obj.__module__,
            "name": obj.__name__,
            "position": str(inspect.getsourcefile(obj))
        }, sort_keys=True)
    # ... handle other types
    return json.dumps(obj, sort_keys=True)
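A reduced version of the hashing scheme shows why the call key is stable across keyword order but sensitive to shapes. This is a sketch under stated assumptions: `to_stable` is a hypothetical helper that builds plain data instead of nested JSON strings, `Shape` stands in for `jax.ShapeDtypeStruct`, and Pydantic and array handling are left out.

```python
import dataclasses
import functools
import hashlib
import json


def to_stable(obj):
    """Convert obj into JSON-serializable data with deterministic content."""
    # partial is tested before callable(): partial objects are callable
    if isinstance(obj, functools.partial):
        return {"type": "partial", "func": to_stable(obj.func),
                "args": [to_stable(a) for a in obj.args],
                "kwargs": {k: to_stable(v) for k, v in obj.keywords.items()}}
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        return {"type": type(obj).__name__,
                "fields": to_stable(dataclasses.asdict(obj))}
    if callable(obj):
        return {"type": "function",
                "module": getattr(obj, "__module__", None),
                "name": getattr(obj, "__qualname__", repr(obj))}
    if isinstance(obj, dict):
        return {str(k): to_stable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_stable(v) for v in obj]
    return obj


def short_hash(obj):
    """16 hex characters of the SHA-256 of the sorted-key JSON form."""
    text = json.dumps(to_stable(obj), sort_keys=True)
    return hashlib.sha256(text.encode()).hexdigest()[:16]


@dataclasses.dataclass
class Shape:
    """Hypothetical stand-in for an abstract array (shape and dtype only)."""
    dims: tuple
    dtype: str


def scale(x, factor=1.0):
    return x * factor


fn = functools.partial(scale, factor=2.0)
a = short_hash(((Shape((8, 128), "float32"),), {"fn": fn, "causal": True}))
b = short_hash(((Shape((8, 128), "float32"),), {"causal": True, "fn": fn}))
c = short_hash(((Shape((8, 256), "float32"),), {"causal": True, "fn": fn}))
print(a == b, a != c, len(a))
```

Because `sort_keys=True` orders dictionary keys before hashing, `a` and `b` agree even though the keyword arguments were supplied in a different order, while the changed shape in `c` produces a different key.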
Configuration Data Carriers#
@dataclass
class FwdParams:
    """Forward pass configuration parameters"""
    # Block sizes
    blocksize_m: int | None = None
    blocksize_k: int | None = None
    blocksize_n: int | None = None
    q_blocksize: int | None = None
    kv_blocksize: int | None = None
    # Thread configuration
    blocksize_heads: int | None = None
    blocksize_keys: int | None = None
    num_key_splits: int | None = None
    num_warps: int | None = None
    num_stages: int | None = None

    def __hash__(self):
        """Custom hash for caching"""
        return hash(tuple(
            getattr(self, field.name)
            for field in dataclasses.fields(self)
        ))

@dataclass
class BwdParams:
    """Backward pass configuration parameters"""
    # Typically smaller block sizes for memory efficiency
    blocksize_m: int | None = None
    blocksize_k: int | None = None
    blocksize_n: int | None = None
    q_blocksize: int | None = None
    kv_blocksize: int | None = None
    num_warps: int | None = None
    num_stages: int | None = None

    def __hash__(self):
        """Custom hash for caching"""
        return hash(tuple(
            getattr(self, field.name)
            for field in dataclasses.fields(self)
        ))
Context Managers#
Policy Override#
@contextmanager
def policy_override(selector: ConfigSelectorChain,
                    allow_autotune: bool | None = None,
                    allow_heuristics: bool | None = None,
                    cache_miss_fallback: str | None = None):
    """Temporarily override autotuning policy"""
    old_policy = selector.policy
    new_policy = AutotunePolicy(
        allow_autotune=allow_autotune if allow_autotune is not None
        else old_policy.allow_autotune,
        allow_heuristics=allow_heuristics if allow_heuristics is not None
        else old_policy.allow_heuristics,
        cache_miss_fallback=cache_miss_fallback or old_policy.cache_miss_fallback
    )
    selector.policy = new_policy
    try:
        yield
    finally:
        selector.policy = old_policy
Cache Overlay#
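@contextmanager
def overlay_cache(overrides: dict):
    """Temporarily override cache entries

    Note: `selector` is not a parameter here; it is assumed to be a
    ConfigSelectorChain visible in the enclosing scope.
    """
    old_overlay = selector.overlay
    selector.overlay = {**old_overlay, **overrides}
    try:
        yield
    finally:
        selector.overlay = old_overlay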
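Both context managers rely on the same save, swap, restore-in-`finally` shape, which guarantees the selector returns to its previous state even when the body raises. The sketch below demonstrates that with toy types. `ToyPolicy` and `ToySelector` are hypothetical stand-ins, and unlike the signature shown above, this `overlay_cache` takes the selector as an explicit argument, which is an assumption made for self-containment.

```python
from contextlib import contextmanager
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyPolicy:
    """Hypothetical stand-in for AutotunePolicy."""
    allow_autotune: bool = True
    allow_heuristics: bool = True


class ToySelector:
    """Hypothetical stand-in holding only the state the managers touch."""

    def __init__(self):
        self.policy = ToyPolicy()
        self.overlay = {}


@contextmanager
def policy_override(selector, **changes):
    old = selector.policy
    # None means "keep the current value", as in the section above
    updates = {k: v for k, v in changes.items() if v is not None}
    selector.policy = replace(old, **updates)
    try:
        yield selector
    finally:
        selector.policy = old  # restored even if the body raises


@contextmanager
def overlay_cache(selector, overrides):
    old = selector.overlay
    selector.overlay = {**old, **overrides}
    try:
        yield selector
    finally:
        selector.overlay = old


selector = ToySelector()
seen = []
try:
    with policy_override(selector, allow_autotune=False), \
            overlay_cache(selector, {"k": 1}):
        seen.append((selector.policy.allow_autotune, dict(selector.overlay)))
        raise RuntimeError("simulated kernel failure")
except RuntimeError:
    pass
seen.append((selector.policy.allow_autotune, dict(selector.overlay)))
print(seen)  # [(False, {'k': 1}), (True, {})]
```

Inside the `with` block autotuning is off and the overlay entry is visible; after the simulated failure both are back to their original values.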
Design Patterns#
1. Multi-Tier Configuration Selection#
Pattern: Hierarchical fallback with caching at multiple levels
Benefits:
Performance: Fast path for cached configurations
Flexibility: Multiple ways to provide configurations
Robustness: Heuristics act as the final fallback; an error is raised only when they are disabled by policy
2. Device-Aware Caching#
Pattern: Cache key includes device fingerprint
Benefits:
Correctness: Device-specific optimizations
Performance: Optimal configuration per device
Portability: Works across heterogeneous hardware
3. Atomic Persistence#
Pattern: Write to temp file + atomic replace
Benefits:
Safety: No partial writes or corruption
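Concurrency: Readers never observe a partially written file; concurrent writers are last-writer-wins, since each process rewrites the whole file from its own in-memory copy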
Reliability: Filesystem-level guarantees
4. Custom VJP Integration#
Pattern: Extract array leaves, rebuild in fwd/bwd
Benefits:
Efficiency: Specialized gradient computation
Correctness: Proper handling of constants
Compatibility: Works with JAX transformations
Performance Considerations#
Configuration Selection#
First call: May trigger autotuning (seconds)
Subsequent calls: Cache lookup (microseconds)
Persistent cache: Survives process restart (milliseconds to load)
Autotuning#
Warmup runs: Keep JIT compilation time out of the timed loop
Multiple iterations: Reduces measurement noise
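Parallel testing: Candidates could be benchmarked concurrently (the Tuner shown above runs them sequentially)
Early stopping: Tuning could stop once a clear winner emerges (the Tuner shown above always measures every candidate)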
Memory Management#
In-memory cache: Unbounded growth (consider LRU eviction)
Persistent cache: JSON file size grows with entries
Configuration objects: Lightweight dataclasses
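The LRU eviction suggested for the in-memory cache could look like the following. This is one possible sketch, not part of the ops system: `LRUConfigCache` is a hypothetical class that keeps the `get`/`put`/`size` interface of `ConfigCache` and bounds the number of entries with an `OrderedDict`.

```python
import threading
from collections import OrderedDict


class LRUConfigCache:
    """Thread-safe config cache that evicts the least recently used entry."""

    def __init__(self, max_entries=1024):
        if max_entries < 1:
            raise ValueError("max_entries must be at least 1")
        self._max = max_entries
        self._cache = OrderedDict()
        self._lock = threading.RLock()

    def get(self, dev, op_id, call_key):
        key = (dev, op_id, call_key)
        with self._lock:
            if key not in self._cache:
                return None
            self._cache.move_to_end(key)  # mark as most recently used
            return self._cache[key]

    def put(self, dev, op_id, call_key, cfg):
        key = (dev, op_id, call_key)
        with self._lock:
            self._cache[key] = cfg
            self._cache.move_to_end(key)
            while len(self._cache) > self._max:
                self._cache.popitem(last=False)  # drop the oldest entry

    def size(self):
        with self._lock:
            return len(self._cache)


cache = LRUConfigCache(max_entries=2)
cache.put("gpu", "op@v1", "a", 1)
cache.put("gpu", "op@v1", "b", 2)
cache.get("gpu", "op@v1", "a")      # touch "a" so "b" becomes the oldest
cache.put("gpu", "op@v1", "c", 3)   # exceeds the bound and evicts "b"
print(cache.get("gpu", "op@v1", "b"), cache.size())  # None 2
```

An evicted entry is not lost for good when a persistent cache is configured, because the selection flow falls through to the disk tier and promotes the entry back into memory.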
Testing and Debugging#
Force Specific Configuration#
# Direct override
output = executor(kernel, *args, _cfg=my_config)
# Via overlay cache; cache_key is the tuple
# (device_fingerprint(), inv.versioned_op_id, inv.call_key)
with overlay_cache({cache_key: my_config}):
    output = executor(kernel, *args)
Disable Autotuning#
with policy_override(selector, allow_autotune=False):
    output = executor(kernel, *args)
Inspect Cache#
# In-memory
print(f"Cached configs: {cache.size()}")
for key, cfg in cache._cache.items():
    print(f"{key}: {cfg}")
# Persistent
with open(persistent.path) as f:
    data = json.load(f)
print(f"Persistent configs: {len(data)}")
Conclusion#
The ops system provides the infrastructure for kernel execution with:
Flexibility: Multiple configuration sources with fallback
Performance: Automatic optimization with caching
Extensibility: Easy to add new kernels and configurations
Robustness: Error handling and graceful degradation
Debugging: Rich context and inspection capabilities
This architecture lets ejKernel select a tuned configuration per device and input shape while keeping the call site simple for end users.