# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Core kernel infrastructure for configurable JAX operations.
This module provides the foundational classes for implementing high-performance
JAX operations with automatic configuration optimization and caching.
Key Classes:
Invocation: Represents a specific call to a kernel with arguments and metadata
Kernel: Abstract base class for implementing configurable operations
The kernel system enables:
- Automatic hyperparameter optimization through configuration testing
- Caching of optimal configurations for performance
- Custom gradient implementations with VJP support
- Flexible argument preprocessing and transformation
- Device-aware configuration management
Kernel Implementation Pattern:
1. Inherit from Kernel[ConfigType, OutputType]
2. Implement run() method for the core operation
3. Implement heuristic_cfg() for default configuration
4. Optionally implement candidate_cfgs() for autotuning
5. Optionally implement custom VJP methods for gradients
Example Implementation:
    >>> @dataclass
    ... class MatMulConfig:
    ...     precision: str = 'default'
    ...     transpose_a: bool = False
    >>>
    >>> class MatMulKernel(Kernel[MatMulConfig, jax.Array]):
    ...     def run(self, a, b, cfg: MatMulConfig) -> jax.Array:
    ...         return jnp.dot(a, b, precision=cfg.precision)
    ...
    ...     def heuristic_cfg(self, inv) -> MatMulConfig:
    ...         return MatMulConfig()
    ...
    ...     def candidate_cfgs(self, inv):
    ...         return [MatMulConfig(p) for p in ['float32', 'bfloat16']]
Invocation Usage:
The Invocation class captures all information needed for a kernel call:
- Arguments and their shapes/types
- Optional configuration overrides
- Batching information for vmapping
- Profiling and caching metadata
This design enables seamless integration with the autotuning and caching
system while providing a clean interface for operation implementers.
"""
from __future__ import annotations
import dataclasses
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Generic
import jax
import jax.sharding
from jax import shard_map
from ..utils.fingerprint import abstractify, short_hash
from .types import Cfg, Out
@dataclasses.dataclass(frozen=True)
class Invocation(Generic[Cfg, Out]):
    """One concrete call to a kernel: its arguments plus execution metadata.

    The instance is immutable (frozen dataclass). Only ``args``, ``kwargs``,
    ``batch_axes`` and ``method`` take part in the default cache key; the
    remaining fields are carried along for whoever executes the call and are
    not read by this class.

    Attributes:
        op_id: Unique identifier for the operation.
        args: Positional arguments for the kernel.
        kwargs: Keyword arguments for the kernel.
        batch_axes: Optional mapping of parameter names to batch axes for
            vmapping.
        override_cfg: Optional configuration to use instead of a cached or
            computed one.
        stamp: Whether to add profiling metadata to the operation.
        method: Execution method (e.g. ``"shard_map"``, or ``None`` for the
            standard path).
        mesh: JAX mesh for shard_map execution.
        in_specs: Input partition specs for shard_map execution.
        out_specs: Output partition spec for shard_map execution.
        check_vma: Value for shard_map's ``check_vma`` flag (JAX's
            varying-manual-axes validity check). NOTE(review): presumably
            forwarded by the executor; confirm against the caller.
    """

    op_id: str
    args: tuple[Any, ...]
    kwargs: Mapping[str, Any]
    batch_axes: Mapping[str, int] | None = None
    override_cfg: Cfg | None = None
    stamp: bool = True
    method: str | None = None
    mesh: jax.sharding.Mesh | None = None
    in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None
    out_specs: jax.sharding.PartitionSpec | None = None
    check_vma: bool = False

    @property
    def call_key(self) -> str:
        """Default cache key derived from the call signature.

        The key is ``short_hash`` of a dict holding the abstractified
        positional arguments, the abstractified keyword arguments,
        ``batch_axes`` and ``method``. ``op_id``, ``override_cfg``, ``stamp``
        and the shard_map fields (``mesh``, ``in_specs``, ``out_specs``,
        ``check_vma``) do not influence it.

        Returns:
            The string produced by ``short_hash`` for this signature.

        Note:
            Arguments go through ``abstractify`` before hashing; per the
            original documentation this keeps shapes/types and drops array
            values, so calls with identically structured inputs share a key.
        """
        spec = dict(
            args_spec=abstractify(self.args),
            # Copy into a plain dict so any Mapping implementation is accepted.
            kwargs_spec=abstractify(dict(self.kwargs)),
            batch_axes=self.batch_axes,
            method=self.method,
        )
        return short_hash(spec)

    def make_key(self, key_builder=None) -> str:
        """Return the cache key for this invocation.

        Args:
            key_builder: Optional callable taking this ``Invocation`` and
                returning a key. When given, its result is returned as-is
                (e.g. to fold sharding or device placement into the key).

        Returns:
            ``key_builder(self)`` if a builder is supplied, otherwise
            ``self.call_key``.
        """
        if key_builder is not None:
            return key_builder(self)
        # Single source of truth: reuse call_key rather than rebuilding the
        # same spec here, so the two can never drift apart.
        return self.call_key
class Kernel(Generic[Cfg, Out]):
    """Base class for configurable JAX operations.

    A subclass supplies the computation (``run``) and a default configuration
    (``heuristic_cfg``); it may additionally offer candidate configurations
    for autotuning and a custom VJP. On this base class every hook except
    ``prepare`` and ``candidate_cfgs`` raises ``NotImplementedError``.

    Must be overridden:
        run: Execute the operation with a given configuration.
        heuristic_cfg: Provide a reasonable default configuration.

    May be overridden:
        prepare: Preprocess arguments before execution.
        candidate_cfgs: Alternative configurations for autotuning.
        fwd_with_residuals / vjp: Forward and backward halves of a custom VJP.
        run_shard_map, fwd_with_residuals_shard_map, vjp_shard_map:
            shard_map-specific variants of the above.
        run_shard_map_gpu, fwd_with_residuals_shard_map_gpu,
        vjp_shard_map_gpu: shard_map-on-GPU variants.

    Method naming convention:
        - ``{method}_{platform}``            e.g. ``run_gpu``, ``run_tpu``
        - ``{method}_{context}``             e.g. ``run_shard_map``
        - ``{method}_{context}_{platform}``  e.g. ``run_shard_map_gpu``

        Platforms are hardware backends (``'gpu'``, ``'tpu'``, ``'cpu'``);
        contexts are execution modes (``'shard_map'``).
        Priority: composite > context > platform > generic.

    Attributes:
        op_id: Unique identifier for this operation.
        key_builder: Optional custom function producing cache keys.
        version: Version string for cache invalidation.
    """

    op_id: str
    key_builder: Callable[[Invocation[Cfg, Out]], str] | None = None
    version: str = "0"

    def __init__(self, op_id: str | None = None):
        """Resolve ``op_id`` for this instance.

        An explicit argument always wins. Without one, an already-present
        truthy ``op_id`` (e.g. a class attribute on the subclass) is kept;
        failing that, the fully qualified class name is used.

        Args:
            op_id: Optional identifier overriding any class-level value.
        """
        if op_id is None and not getattr(self, "op_id", None):
            op_id = f"{type(self).__module__}.{type(self).__name__}"
        if op_id is not None:
            self.op_id = op_id

    def prepare(self, *args, **kwargs) -> tuple[tuple[Any, ...], dict[str, Any]]:
        """Hook for transforming arguments before execution.

        The base implementation returns its inputs unchanged. Typical
        overrides validate shapes, convert types or reorder arguments.

        Args:
            *args: Positional arguments to preprocess.
            **kwargs: Keyword arguments to preprocess.

        Returns:
            ``(processed_args, processed_kwargs)``.

        Example:
            >>> def prepare(self, x, y, **kwargs):
            ...     x = jnp.asarray(x)
            ...     y = jnp.asarray(y)
            ...     return (x, y), kwargs
        """
        return args, kwargs

    def run(self, *args, cfg: Cfg, **kwargs) -> Out:
        """Perform the actual computation. Subclasses must override.

        Args:
            *args: Positional arguments (after ``prepare()``).
            cfg: Configuration describing how to execute the operation.
            **kwargs: Keyword arguments (after ``prepare()``).

        Returns:
            The operation result.

        Raises:
            NotImplementedError: Always, on the base class.

        Example:
            >>> def run(self, x, y, cfg: MatMulConfig) -> jax.Array:
            ...     if cfg.transpose_a:
            ...         x = x.T
            ...     return jnp.dot(x, y, precision=cfg.precision)
        """
        raise NotImplementedError

    def heuristic_cfg(self, inv: Invocation[Cfg, Out]) -> Cfg:
        """Pick a sensible default configuration. Subclasses must override.

        The returned configuration should be correct for the given
        invocation, even if it is not the fastest choice.

        Args:
            inv: Invocation carrying the arguments and metadata.

        Returns:
            A default configuration object.

        Raises:
            NotImplementedError: Always, on the base class.

        Example:
            >>> def heuristic_cfg(self, inv) -> MatMulConfig:
            ...     dtype = inv.args[0].dtype
            ...     precision = 'bfloat16' if dtype == jnp.bfloat16 else 'float32'
            ...     return MatMulConfig(precision=precision)
        """
        raise NotImplementedError

    def candidate_cfgs(self, inv: Invocation[Cfg, Out]) -> Iterable[Cfg]:
        """List the configurations autotuning should try.

        The base implementation offers exactly one candidate: whatever
        ``heuristic_cfg(inv)`` returns. Override to expose more options.

        Args:
            inv: Invocation carrying the arguments and metadata.

        Returns:
            Iterable of configuration objects to test.

        Example:
            >>> def candidate_cfgs(self, inv):
            ...     return [
            ...         MatMulConfig(precision='float32', transpose_a=False),
            ...         MatMulConfig(precision='bfloat16', transpose_a=False),
            ...         MatMulConfig(precision='float32', transpose_a=True),
            ...     ]
        """
        return [self.heuristic_cfg(inv)]

    def fwd_with_residuals(self, *args, cfg: Cfg, **kwargs) -> tuple[Out, Any]:
        """Forward half of a custom VJP. Override together with ``vjp()``.

        Args:
            *args: Positional arguments for the operation.
            cfg: Configuration object.
            **kwargs: Keyword arguments for the operation.

        Returns:
            ``(result, residuals)`` where ``result`` matches the output of
            ``run()`` and ``residuals`` holds whatever the backward pass
            needs.

        Raises:
            NotImplementedError: Always, on the base class.

        Example:
            >>> def fwd_with_residuals(self, x, y, cfg):
            ...     result = jnp.dot(x, y)
            ...     residuals = (x, y, cfg)
            ...     return result, residuals
        """
        raise NotImplementedError

    def vjp(self, residuals: Any, y: Out, dy: Out, *args, cfg: Cfg, **kwargs):
        """Backward half of a custom VJP. Override with ``fwd_with_residuals()``.

        Gradients are returned for positional arguments only.

        Args:
            residuals: Values returned from ``fwd_with_residuals()``.
            y: Forward output (from ``fwd_with_residuals()``).
            dy: Incoming cotangents with respect to ``y``.
            *args: Original positional arguments.
            cfg: Configuration object.
            **kwargs: Original keyword arguments.

        Returns:
            Tuple with one gradient per positional argument (``None`` for
            arguments that need no gradient).

        Raises:
            NotImplementedError: Always, on the base class.

        Example:
            >>> def vjp(self, residuals, y, dy, *args, cfg, **kwargs):
            ...     x, y_orig, _ = residuals
            ...     dx = jnp.dot(dy, y_orig.T)
            ...     dy_orig = jnp.dot(x.T, dy)
            ...     return dx, dy_orig
        """
        raise NotImplementedError

    def run_shard_map(self, *args, cfg: Cfg, **kwargs) -> Out:
        """shard_map-specific variant of ``run()``. Optional override.

        Intended for kernels that have a specialised implementation for
        JAX's shard_map; kernels without one are meant to fall back to
        ``run()``.

        Args:
            *args: Positional arguments (after ``prepare()``).
            cfg: Configuration describing how to execute the operation.
            **kwargs: Keyword arguments (after ``prepare()``).

        Returns:
            The operation result.

        Raises:
            NotImplementedError: Always, on the base class.

        Example:
            >>> def run_shard_map(self, x, y, cfg: MatMulConfig) -> jax.Array:
            ...     return custom_sharded_matmul(x, y, cfg)

        Note:
            ``_get_platform_method()`` looks this name up when it is called
            with ``context='shard_map'``.
        """
        raise NotImplementedError

    def fwd_with_residuals_shard_map(self, *args, cfg: Cfg, **kwargs) -> tuple[Out, Any]:
        """shard_map-specific forward half of a custom VJP. Optional override.

        Must be implemented together with ``vjp_shard_map()``.

        Args:
            *args: Positional arguments for the operation.
            cfg: Configuration object.
            **kwargs: Keyword arguments for the operation.

        Returns:
            ``(result, residuals)`` where ``result`` matches the output of
            ``run_shard_map()`` and ``residuals`` feeds the backward pass.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError

    def vjp_shard_map(self, residuals: Any, y: Out, dy: Out, *args, cfg: Cfg, **kwargs):
        """shard_map-specific backward half of a custom VJP. Optional override.

        Must be implemented together with ``fwd_with_residuals_shard_map()``.

        Args:
            residuals: Values returned from ``fwd_with_residuals_shard_map()``.
            y: Forward output (from ``fwd_with_residuals_shard_map()``).
            dy: Incoming cotangents with respect to ``y``.
            *args: Original positional arguments.
            cfg: Configuration object.
            **kwargs: Original keyword arguments.

        Returns:
            Tuple with one gradient per positional argument (``None`` for
            arguments that need no gradient).

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError

    def run_shard_map_gpu(self, *args, cfg: Cfg, **kwargs) -> Out:
        """shard_map-on-GPU variant of ``run()``. Optional override.

        The most specific tier of the naming scheme: shard_map context
        combined with the GPU platform.

        Raises:
            NotImplementedError: Always, on the base class.

        Example:
            >>> def run_shard_map_gpu(self, x, y, cfg: MyConfig) -> jax.Array:
            ...     return gpu_sharded_operation(x, y, cfg)
        """
        raise NotImplementedError

    def fwd_with_residuals_shard_map_gpu(self, *args, cfg: Cfg, **kwargs) -> tuple[Out, Any]:
        """shard_map-on-GPU forward half of a custom VJP. Optional override.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError

    def vjp_shard_map_gpu(self, residuals: Any, y: Out, dy: Out, *args, cfg: Cfg, **kwargs):
        """shard_map-on-GPU backward half of a custom VJP. Optional override.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError

    def create_shard_map_wrapper(
        self,
        fn: Callable,
        *args,
        mesh: jax.sharding.Mesh,
        in_specs: tuple[jax.sharding.PartitionSpec, ...],
        out_specs: jax.sharding.PartitionSpec,
        check_vma: bool = False,
        **kwargs,
    ) -> tuple[Callable, tuple]:
        """Wrap ``fn`` in ``shard_map`` with its keyword arguments fixed.

        The keyword arguments are captured in a closure, so the function
        handed to ``shard_map`` takes positional inputs only. The positional
        ``args`` are not consumed here; they are returned untouched so the
        caller can pass them to the wrapped function.

        Args:
            fn: Function to wrap (e.g. ``flash_attention``).
            *args: Positional arguments to hand back to the caller.
            mesh: JAX device mesh for distributed execution.
            in_specs: Partition specs for the positional inputs (documented
                as having to match ``len(args)``; not validated here).
            out_specs: Partition spec for the output.
            check_vma: Passed through unchanged as shard_map's ``check_vma``.
            **kwargs: Keyword arguments bound to every call of ``fn``.

        Returns:
            ``(shard_map_fn, call_args)`` where ``shard_map_fn`` accepts
            positional arguments only and ``call_args`` is the ``args`` tuple.

        Example:
            >>> from ejkernel.modules.operations import flash_attention
            >>> attn = FlashAttention()
            >>> shard_map_fn, call_args = attn.create_shard_map_wrapper(
            ...     flash_attention,
            ...     query, key, value,
            ...     mesh=my_mesh,
            ...     in_specs=(q_spec, k_spec, v_spec),
            ...     out_specs=out_spec,
            ...     causal=True,
            ...     softmax_scale=0.125,
            ... )
            >>> output = shard_map_fn(*call_args)
        """

        def _wrapped(*positional):
            # Only positional inputs cross the shard_map boundary; the fixed
            # keyword arguments are re-attached on the inside.
            return fn(*positional, **kwargs)

        sharded = shard_map(
            _wrapped,
            mesh=mesh,
            in_specs=in_specs,
            out_specs=out_specs,
            check_vma=check_vma,
        )
        return sharded, args
def _has_custom_vjp(
    k: Kernel,
    platform: str | None = None,
    context: str | None = None,
) -> bool:
    """Report whether a kernel overrides a matching fwd/vjp pair.

    Tiers are examined from most to least specific: composite
    (``_{context}_{platform}``), context (``_{context}``), platform
    (``_{platform}``) and finally the generic ``fwd_with_residuals`` /
    ``vjp`` pair. A tier counts only when BOTH of its methods exist on the
    kernel's class and differ from whatever the base ``Kernel`` defines
    under the same name.

    Args:
        k: Kernel instance to inspect.
        platform: Optional platform identifier (e.g. 'gpu', 'tpu', 'cpu').
        context: Optional execution context (e.g. 'shard_map').

    Returns:
        True if any tier provides a complete custom VJP pair; False
        otherwise, including when an ``AttributeError`` is raised during the
        checks.
    """
    kernel_cls = type(k)

    def _overridden(name: str) -> bool:
        # Lookups go through the class, so plain function objects are
        # compared rather than per-access bound methods.
        return hasattr(kernel_cls, name) and getattr(kernel_cls, name) is not getattr(Kernel, name, None)

    # Most specific suffix first; tiers whose inputs are missing are skipped.
    suffixes = []
    if context and platform:
        suffixes.append(f"_{context}_{platform}")
    if context:
        suffixes.append(f"_{context}")
    if platform:
        suffixes.append(f"_{platform}")

    try:
        for suffix in suffixes:
            fwd_overridden = _overridden(f"fwd_with_residuals{suffix}")
            vjp_overridden = _overridden(f"vjp{suffix}")
            if fwd_overridden and vjp_overridden:
                return True
        # Generic tier: direct attribute access, guarded by the except below.
        return kernel_cls.fwd_with_residuals is not Kernel.fwd_with_residuals and kernel_cls.vjp is not Kernel.vjp
    except AttributeError:
        return False
def _get_platform_method(
    k: Kernel,
    method_name: str,
    platform: str | None = None,
    context: str | None = None,
) -> Callable | None:
    """Resolve the most specific implementation of ``method_name`` on a kernel.

    Lookup order (first match wins):
        1. ``{method}_{context}_{platform}``  (e.g. ``run_shard_map_gpu``)
        2. ``{method}_{context}``             (e.g. ``run_shard_map``)
        3. ``{method}_{platform}``            (e.g. ``run_gpu``)
        4. ``{method}``                       (e.g. ``run``)

    A specialised tier (1-3) matches only when the kernel really overrides
    it. The placeholder stubs that the base ``Kernel`` defines under those
    names (which just raise ``NotImplementedError``) are skipped, so lookup
    falls through to the next tier.

    Args:
        k: Kernel instance.
        method_name: Base method name (e.g. 'run', 'candidate_cfgs',
            'fwd_with_residuals').
        platform: Optional platform identifier (e.g. 'gpu', 'tpu', 'cpu').
        context: Optional execution context (e.g. 'shard_map').

    Returns:
        The most specific overridden specialised method if there is one;
        otherwise the generic ``method_name`` attribute of ``k`` (which may
        be the inherited base implementation, e.g. the default
        ``candidate_cfgs``); ``None`` only when ``k`` has no such attribute.

    Example:
        >>> method = _get_platform_method(kernel, 'run', platform='gpu', context='shard_map')
    """
    # Candidate names, most specific first.
    specialised = []
    if context and platform:
        specialised.append(f"{method_name}_{context}_{platform}")
    if context:
        specialised.append(f"{method_name}_{context}")
    if platform:
        specialised.append(f"{method_name}_{platform}")

    for name in specialised:
        candidate = getattr(k, name, None)
        if candidate is None:
            continue
        # Attribute access on an instance yields a fresh bound-method object
        # that is never identical to the function stored on the class, so
        # unwrap to the underlying function before comparing with the base.
        # Plain callables set on the instance have no __func__ and are
        # compared as they are.
        if getattr(candidate, "__func__", candidate) is not getattr(Kernel, name, None):
            return candidate

    # Generic tier: hand back whatever the kernel exposes under the base
    # name, inherited or not, which is what callers effectively received
    # before.
    return getattr(k, method_name, None)