ejkernel.ops.core.kernel#

Core kernel infrastructure for configurable JAX operations.

This module provides the foundational classes for implementing high-performance JAX operations with automatic configuration optimization and caching.

Key Classes:

  • Invocation: Represents a specific call to a kernel with arguments and metadata

  • Kernel: Abstract base class for implementing configurable operations

The kernel system enables:
  • Automatic hyperparameter optimization through configuration testing

  • Caching of optimal configurations for performance

  • Custom gradient implementations with VJP support

  • Flexible argument preprocessing and transformation

  • Device-aware configuration management

Kernel Implementation Pattern:
  1. Inherit from Kernel[ConfigType, OutputType]

  2. Implement run() method for the core operation

  3. Implement heuristic_cfg() for default configuration

  4. Optionally implement candidate_cfgs() for autotuning

  5. Optionally implement custom VJP methods for gradients

Example Implementation:
>>> from dataclasses import dataclass
>>> import jax
>>> import jax.numpy as jnp
>>> from ejkernel.ops.core.kernel import Kernel
>>>
>>> @dataclass
... class MatMulConfig:
...     precision: str = 'default'
...     transpose_a: bool = False
>>>
>>> class MatMulKernel(Kernel[MatMulConfig, jax.Array]):
...     def run(self, a, b, cfg: MatMulConfig) -> jax.Array:
...         return jnp.dot(a, b, precision=cfg.precision)
...
...     def heuristic_cfg(self, inv) -> MatMulConfig:
...         return MatMulConfig()
...
...     def candidate_cfgs(self, inv):
...         # One candidate per precision setting to benchmark.
...         return [MatMulConfig(precision=p) for p in ['float32', 'bfloat16']]

Invocation Usage:

The Invocation class captures all information needed for a kernel call:
  • Arguments and their shapes/types

  • Optional configuration overrides

  • Batching information for vmapping

  • Profiling and caching metadata

This design enables seamless integration with the autotuning and caching system while providing a clean interface for operation implementers.
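One way to picture how these pieces could fit together is the sketch below. It is not ejkernel's executor: ScaleConfig, FakeInvocation, ScaleKernel, execute, and the plain dict used as a cache are hypothetical stand-ins, and the resolution order (explicit override, then cached configuration, then heuristic) is an assumption based on the description of override_cfg.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ScaleConfig:
    # Hypothetical configuration with a single tunable value.
    factor: int = 1


@dataclass
class FakeInvocation:
    # Hypothetical stand-in for Invocation; only the fields used below.
    args: tuple
    kwargs: dict
    override_cfg: Optional[ScaleConfig] = None
    call_key: str = "demo-key"


class ScaleKernel:
    # Hypothetical kernel following the prepare()/heuristic_cfg()/run() pattern.
    def prepare(self, *args, **kwargs):
        # Normalize inputs before they reach run().
        return tuple(float(a) for a in args), kwargs

    def heuristic_cfg(self, inv):
        return ScaleConfig(factor=2)

    def run(self, x, cfg):
        return x * cfg.factor


def execute(kernel, inv, cache):
    # Assumed order: explicit override, then cached config, then heuristic.
    if inv.override_cfg is not None:
        cfg = inv.override_cfg
    elif inv.call_key in cache:
        cfg = cache[inv.call_key]
    else:
        cfg = kernel.heuristic_cfg(inv)
        cache[inv.call_key] = cfg
    args, kwargs = kernel.prepare(*inv.args, **inv.kwargs)
    return kernel.run(*args, cfg=cfg, **kwargs)
```

Calling execute(ScaleKernel(), FakeInvocation((3,), {}), {}) takes the heuristic path and stores the chosen configuration under the invocation's key, while passing override_cfg bypasses the cache entirely.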

class ejkernel.ops.core.kernel.Invocation(op_id: str, args: tuple[Any, ...], kwargs: Mapping[str, Any], batch_axes: collections.abc.Mapping[str, int] | None = None, override_cfg: Optional[Cfg] = None, stamp: bool = True, method: str | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None, check_vma: bool = False)[source]#

Bases: Generic[Cfg, Out]

Represents a specific call to a kernel with arguments and metadata.

This dataclass captures all the information needed to execute a kernel, including arguments, configuration overrides, and execution metadata.

op_id#

Unique identifier for the operation

Type

str

args#

Positional arguments for the kernel

Type

tuple[Any, …]

kwargs#

Keyword arguments for the kernel

Type

collections.abc.Mapping[str, Any]

batch_axes#

Optional mapping of parameter names to batch axes for vmapping

Type

collections.abc.Mapping[str, int] | None

override_cfg#

Optional configuration to use instead of cached/computed ones

Type

Optional[ejkernel.ops.core.types.Cfg]

stamp#

Whether to add profiling metadata to the operation

Type

bool

method#

Execution method (e.g., “shard_map” or None for standard)

Type

str | None

mesh#

JAX mesh for shard_map execution

Type

jax._src.mesh.Mesh | None

in_specs#

Input partition specs for shard_map

Type

tuple[jax.sharding.PartitionSpec, …] | None

out_specs#

Output partition spec for shard_map

Type

jax.sharding.PartitionSpec | None

check_vma#

Whether to enable shard_map's check_vma validation (varying manual axes, i.e. replication checking)

Type

bool

args: tuple[Any, ...]#
batch_axes: collections.abc.Mapping[str, int] | None = None#
property call_key: str#

Generate a stable hash key for this invocation based on argument shapes and types.

Creates a 16-character hash that uniquely identifies this invocation based on the abstract shapes and types of arguments, not their values. This enables caching of configurations based on operation signature.

Returns

16-character hexadecimal hash string representing the call signature

Note

The hash includes argument shapes/types, keyword argument shapes/types, batch axes information, and execution method. Array values are not included, allowing the same configuration to be reused for arrays with the same structure.
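A minimal sketch of a shape-and-dtype-based key, under stated assumptions: FakeArray, abstract_signature, and sketch_call_key are hypothetical names, SHA-256 is an arbitrary choice of hash, and the exact string layout is invented for illustration. The real call_key may be computed differently; the sketch only shows why two calls with the same structure but different values can share a cache entry.

```python
import hashlib
from dataclasses import dataclass


@dataclass
class FakeArray:
    # Hypothetical stand-in for a JAX array: only shape and dtype are hashed.
    shape: tuple
    dtype: str
    values: list


def abstract_signature(x):
    # Arrays contribute shape and dtype; other values contribute their type name.
    if hasattr(x, "shape") and hasattr(x, "dtype"):
        return f"array{tuple(x.shape)}:{x.dtype}"
    return type(x).__name__


def sketch_call_key(args, kwargs, batch_axes=None, method=None):
    parts = [abstract_signature(a) for a in args]
    # Sort keyword arguments so the key does not depend on call order.
    parts += [f"{k}={abstract_signature(v)}" for k, v in sorted(kwargs.items())]
    parts.append(f"batch_axes={sorted((batch_axes or {}).items())}")
    parts.append(f"method={method}")
    digest = hashlib.sha256("|".join(parts).encode("utf-8")).hexdigest()
    return digest[:16]  # keep the first 16 hexadecimal characters
```

Two FakeArray inputs with shape (2, 3) and dtype "float32" produce the same key regardless of their values, while changing the shape, the batch axes, or the method produces a different one.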

check_vma: bool = False#
in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None#
kwargs: Mapping[str, Any]#
make_key(key_builder=None) → str[source]#

Generate a cache key for this invocation, optionally using a custom key builder.

Provides flexibility in cache key generation by allowing custom key builders while falling back to the default implementation.

Parameters

key_builder – Optional function that takes an Invocation and returns a key

Returns

Cache key string for this invocation

Note

Custom key builders can include additional information like sharding or device placement for more sophisticated caching strategies.
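The fallback behaviour can be sketched as follows. This is an illustration, not the library's code: FakeInvocation, its device_kind field, and device_aware_key are hypothetical, and the assumption is that the default key is call_key.

```python
from dataclasses import dataclass


@dataclass
class FakeInvocation:
    # Hypothetical stand-in for Invocation with a precomputed call_key.
    call_key: str
    device_kind: str = "cpu"  # hypothetical extra field used by the custom builder

    def make_key(self, key_builder=None):
        # Use the custom builder when given, otherwise fall back to call_key.
        if key_builder is not None:
            return key_builder(self)
        return self.call_key


def device_aware_key(inv):
    # Example custom builder: append device information to the default key.
    return f"{inv.call_key}@{inv.device_kind}"
```

With this builder, the same call signature on two device kinds maps to two distinct cache entries, which is the kind of strategy the note above describes.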

mesh: jax._src.mesh.Mesh | None = None#
method: str | None = None#
op_id: str#
out_specs: jax.sharding.PartitionSpec | None = None#
override_cfg: Optional[Cfg] = None#
stamp: bool = True#
class ejkernel.ops.core.kernel.Kernel(op_id: str | None = None)[source]#

Bases: Generic[Cfg, Out]

Abstract base class for implementing custom JAX operations with configuration management.

A Kernel encapsulates the logic for a specific operation, including how to execute it with different configurations, what configurations are available, and optionally how to compute custom gradients.

Required methods to implement:

  • run: Execute the operation with a given configuration

  • heuristic_cfg: Provide a reasonable default configuration

Optional methods:

  • prepare: Preprocess arguments before execution

  • candidate_cfgs: Provide alternative configurations for autotuning

  • fwd_with_residuals: Forward pass with residuals for custom VJP

  • vjp: Backward pass for custom VJP

  • run_shard_map: Specialized execution for shard_map contexts

  • fwd_with_residuals_shard_map: Forward pass with residuals for shard_map

  • vjp_shard_map: Backward pass for shard_map

Method Naming Convention:
  • Platform-specific: {method}_{platform} (e.g., run_gpu, run_tpu)

  • Context-specific: {method}_{context} (e.g., run_shard_map)

  • Composite: {method}_{context}_{platform} (e.g., run_shard_map_gpu)

Platforms: ‘gpu’, ‘tpu’, ‘cpu’ (hardware backends)

Contexts: ‘shard_map’ (execution modes/environments)

Priority: composite > context > platform > generic
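The priority order above can be sketched as a name lookup that tries the most specific method first. The resolve_method helper and DemoKernel class are hypothetical and only illustrate the naming convention; they are not the library's dispatcher.

```python
def resolve_method(kernel, base, platform, context=None):
    # Build candidate names from most specific to least specific:
    # composite > context > platform > generic.
    candidates = []
    if context is not None:
        candidates.append(f"{base}_{context}_{platform}")
        candidates.append(f"{base}_{context}")
    candidates.append(f"{base}_{platform}")
    candidates.append(base)
    for name in candidates:
        method = getattr(kernel, name, None)
        if callable(method):
            return method
    raise AttributeError(f"no implementation of {base!r} found on {type(kernel).__name__}")


class DemoKernel:
    # Hypothetical kernel that defines only some of the variants.
    def run(self, x):
        return ("generic", x)

    def run_gpu(self, x):
        return ("gpu", x)

    def run_shard_map(self, x):
        return ("shard_map", x)
```

Because the real Kernel base class already declares stubs such as run_shard_map, an actual dispatcher would presumably also need to tell overridden methods apart from the base stubs; the plain getattr check here skips that detail.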

op_id#

Unique identifier for this operation

Type

str

key_builder#

Optional custom function to generate cache keys

Type

collections.abc.Callable[[ejkernel.ops.core.kernel.Invocation[ejkernel.ops.core.types.Cfg, ejkernel.ops.core.types.Out]], str] | None

version#

Version string for cache invalidation

Type

str

candidate_cfgs(inv: Invocation[Cfg, Out]) → Iterable[Cfg][source]#

Return alternative configurations for autotuning. Defaults to just heuristic_cfg.

Provides a set of configurations to test during autotuning. The autotuning system will benchmark each configuration and select the fastest one.

Parameters

inv – Invocation object containing arguments and metadata

Returns

Iterable of configuration objects to test

Example

>>> def candidate_cfgs(self, inv):
...     return [
...         MatMulConfig(precision='float32', transpose_a=False),
...         MatMulConfig(precision='bfloat16', transpose_a=False),
...         MatMulConfig(precision='float32', transpose_a=True),
...     ]

Note

The default implementation returns only the heuristic configuration. Override this method to enable autotuning with multiple options.
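The selection step of such an autotuner can be sketched as below. pick_fastest and the measure callback are hypothetical; measure stands in for whatever benchmarking the autotuning system performs, and skipping candidates that raise is an assumption made for the sketch rather than documented behaviour.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MatMulConfig:
    precision: str = "default"
    transpose_a: bool = False


def pick_fastest(candidates, measure):
    # measure(cfg) returns a runtime in seconds for one configuration.
    best_cfg, best_time = None, float("inf")
    for cfg in candidates:
        try:
            elapsed = measure(cfg)
        except Exception:
            # Assumption: a configuration that fails to run is skipped.
            continue
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    if best_cfg is None:
        raise ValueError("no candidate configuration ran successfully")
    return best_cfg
```

Passing the timing function in as an argument keeps the selection logic testable with fixed numbers, independent of real hardware timings.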

create_shard_map_wrapper(fn: Callable, *args, mesh: Mesh, in_specs: tuple[jax.sharding.PartitionSpec, ...], out_specs: PartitionSpec, check_vma: bool = False, **kwargs) → tuple[collections.abc.Callable, tuple][source]#

Create a shard_map wrapper around a function with fixed kwargs.

This helper method simplifies the pattern of wrapping functions with shard_map by handling the functools.partial and shard_map setup automatically.

Parameters
  • fn – Function to wrap with shard_map (e.g., flash_attention)

  • *args – Positional arguments that will be passed through shard_map

  • mesh – JAX device mesh for distributed execution

  • in_specs – Partition specs for input arguments (must match len(args))

  • out_specs – Partition spec for output

  • check_vma – Whether to check replication in shard_map

  • **kwargs – Keyword arguments to fix via functools.partial

Returns

Tuple of (shard_map_wrapped_function, call_arguments):
  • shard_map_wrapped_function: Function that takes positional args

  • call_arguments: Tuple of args to pass to the wrapped function

Example

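>>> from ejkernel.modules.operations import flash_attention
>>> # FlashAttention (a Kernel subclass) and the arrays, mesh, and specs
>>> # used below are assumed to be defined elsewhere.
>>> attn = FlashAttention()
>>>
>>> # Build the shard_map-wrapped function and its call arguments.
>>> shard_map_fn, call_args = attn.create_shard_map_wrapper(
...     flash_attention,
...     query, key, value,
...     mesh=my_mesh,
...     in_specs=(q_spec, k_spec, v_spec),
...     out_specs=out_spec,
...     causal=True,
...     softmax_scale=0.125
... )
>>>
>>> # Run the wrapped function on the positional inputs.
>>> output = shard_map_fn(*call_args)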

Note

This follows the EasyDeL pattern where attention operations are wrapped with shard_map for distributed execution. The function fn is called with the positional args and fixed kwargs inside the shard_map context.
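The partial-then-wrap pattern described above can be sketched without JAX. Here make_wrapper is a hypothetical helper and its transform argument stands in for the shard_map call (which would also receive mesh, in_specs, out_specs, and check_vma); only the keyword-fixing step and the in_specs length check are illustrated.

```python
import functools


def make_wrapper(fn, *args, in_specs, transform, **kwargs):
    # in_specs must provide one entry per positional argument.
    if len(in_specs) != len(args):
        raise ValueError(
            f"in_specs has {len(in_specs)} entries but {len(args)} positional args were given"
        )
    # Fix keyword arguments so the wrapped function takes positional args only.
    fixed = functools.partial(fn, **kwargs)
    # transform is a stand-in for shard_map applied to the partial function.
    return transform(fixed), args


def scaled_add(a, b, scale=1.0):
    # Toy operation used in place of an attention function.
    return (a + b) * scale
```

For instance, make_wrapper(scaled_add, 1, 2, in_specs=("a", "b"), transform=lambda f: f, scale=3.0) returns a callable and the tuple (1, 2); calling the former with the latter evaluates scaled_add(1, 2, scale=3.0).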

fwd_with_residuals(*args, cfg: Cfg, **kwargs) → tuple[Out, Any][source]#

Forward pass that returns residuals for custom VJP. Implement for custom gradients.

When implementing custom gradients, this method performs the forward pass and returns both the result and any residual values needed for the backward pass.

Parameters
  • *args – Positional arguments for the operation

  • cfg – Configuration object

  • **kwargs – Keyword arguments for the operation

Returns

Tuple of (operation_result, residuals):
  • operation_result: Same as run() method output

  • residuals: Any values needed for the backward pass

Raises

NotImplementedError – If not overridden. Implement this method only when providing custom gradients.

Example

>>> def fwd_with_residuals(self, x, y, cfg):
...     result = jnp.dot(x, y)
...     residuals = (x, y, cfg)
...     return result, residuals

Note

Must be implemented together with vjp() method for custom gradients.

fwd_with_residuals_shard_map(*args, cfg: Cfg, **kwargs) → tuple[Out, Any][source]#

Forward pass with residuals for shard_map context. Optional override.

Specialized forward pass for shard_map contexts that returns both the result and residuals needed for the backward pass.

Parameters
  • *args – Positional arguments for the operation

  • cfg – Configuration object

  • **kwargs – Keyword arguments for the operation

Returns

Tuple of (operation_result, residuals):
  • operation_result: Same as run_shard_map() method output

  • residuals: Any values needed for the backward pass

Raises

NotImplementedError – If not overridden. Implement this method only when providing custom gradients for shard_map.

Note

Must be implemented together with vjp_shard_map() method for custom gradients in shard_map contexts.

fwd_with_residuals_shard_map_gpu(*args, cfg: Cfg, **kwargs) → tuple[Out, Any][source]#

Forward pass with residuals for shard_map on GPU. Optional override.

heuristic_cfg(inv: Invocation[Cfg, Out]) → Cfg[source]#

Return a reasonable default configuration for this invocation. Must be implemented.

Provides a sensible default configuration based on the invocation context. This configuration should work correctly for the given arguments, though it may not be optimal for performance.

Parameters

inv – Invocation object containing arguments and metadata

Returns

Default configuration object

Raises

NotImplementedError – Must be overridden in subclasses

Example

>>> def heuristic_cfg(self, inv) -> MatMulConfig:
...     # Pick the precision from the dtype of the first positional argument.
...     dtype = inv.args[0].dtype
...     precision = 'bfloat16' if dtype == jnp.bfloat16 else 'float32'
...     return MatMulConfig(precision=precision)
key_builder: collections.abc.Callable[[ejkernel.ops.core.kernel.Invocation[ejkernel.ops.core.types.Cfg, ejkernel.ops.core.types.Out]], str] | None = None#
op_id: str#
prepare(*args, **kwargs) → tuple[tuple[Any, ...], dict[str, Any]][source]#

Preprocess arguments before execution. Override to modify args/kwargs.

This method is called before the run() method to allow transformation of arguments. Common use cases include shape validation, type conversion, or argument reordering.

Parameters
  • *args – Positional arguments to preprocess

  • **kwargs – Keyword arguments to preprocess

Returns

Tuple of (processed_args, processed_kwargs)

Example

>>> def prepare(self, x, y, **kwargs):
...     # Convert the inputs to JAX arrays before they reach run().
...     x = jnp.asarray(x)
...     y = jnp.asarray(y)
...     return (x, y), kwargs
run(*args, cfg: Cfg, **kwargs) → Out[source]#

Execute the operation with the given configuration. Must be implemented.

This is the core method that performs the actual computation. It receives the preprocessed arguments and a configuration object, and must return the operation result.

Parameters
  • *args – Positional arguments (after prepare() preprocessing)

  • cfg – Configuration object specifying how to execute the operation

  • **kwargs – Keyword arguments (after prepare() preprocessing)

Returns

Result of the operation

Raises

NotImplementedError – Must be overridden in subclasses

Example

>>> def run(self, x, y, cfg: MatMulConfig) -> jax.Array:
...     if cfg.transpose_a:
...         x = x.T
...     return jnp.dot(x, y, precision=cfg.precision)
run_shard_map(*args, cfg: Cfg, **kwargs) → Out[source]#

Execute the operation within a shard_map context. Optional override.

This method can be implemented to provide a specialized version of the operation that runs efficiently within JAX’s shard_map context. If not implemented, the regular run() method will be used as fallback.

Parameters
  • *args – Positional arguments (after prepare() preprocessing)

  • cfg – Configuration object specifying how to execute the operation

  • **kwargs – Keyword arguments (after prepare() preprocessing)

Returns

Result of the operation

Raises

NotImplementedError – If not overridden. Implement this method only when providing shard_map-specific execution.

Example

>>> def run_shard_map(self, x, y, cfg: MatMulConfig) -> jax.Array:
...     # custom_sharded_matmul is a placeholder for a shard_map-aware implementation.
...     return custom_sharded_matmul(x, y, cfg)

Note

This method is automatically detected by _get_platform_method() and used when ‘shard_map’ is specified as the execution context.

run_shard_map_gpu(*args, cfg: Cfg, **kwargs) → Out[source]#

Execute operation in shard_map context on GPU. Optional override.

Most specific implementation combining shard_map execution context with GPU platform optimizations.

Example

>>> def run_shard_map_gpu(self, x, y, cfg: MyConfig) -> jax.Array:
...     # gpu_sharded_operation is a placeholder for a GPU-specific sharded implementation.
...     return gpu_sharded_operation(x, y, cfg)
version: str = '0'#
vjp(residuals: Any, y: Out, dy: Out, *args, cfg: Cfg, **kwargs)[source]#

Backward pass for custom VJP. Return gradients for positional args only.

Computes vector-Jacobian products (gradients) for the custom operation. This method is called during backpropagation to compute gradients with respect to the positional arguments.

Parameters
  • residuals – Values returned from fwd_with_residuals()

  • y – Forward pass output (from fwd_with_residuals())

  • dy – Incoming gradients (cotangents) with respect to y

  • *args – Original positional arguments

  • cfg – Configuration object

  • **kwargs – Original keyword arguments

Returns

Tuple of gradients for each positional argument (None for arguments that don’t need gradients)

Raises

NotImplementedError – If not overridden. Implement this method only when providing custom gradients.

Example

>>> def vjp(self, residuals, y, dy, *args, cfg, **kwargs):
...     x, y_orig, _ = residuals
...     dx = jnp.dot(dy, y_orig.T)
...     dy_orig = jnp.dot(x.T, dy)
...     return dx, dy_orig

Note

Must be implemented together with fwd_with_residuals() method.
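The contract between the forward and backward passes can be illustrated with a scalar product, checked against finite differences. This is a plain-Python sketch with simplified, hypothetical signatures (no cfg, no JAX): the forward pass saves its inputs as residuals, and the backward pass turns the incoming cotangent into one gradient per positional argument.

```python
def fwd_with_residuals(x, y):
    # Forward pass: return the result plus what the backward pass needs.
    out = x * y
    residuals = (x, y)
    return out, residuals


def vjp(residuals, out, d_out):
    # Backward pass: one gradient per positional argument, in order.
    x, y = residuals
    return d_out * y, d_out * x


def finite_difference(f, x, y, eps=1e-6):
    # Central differences as an independent check of the analytic gradients.
    dfdx = (f(x + eps, y) - f(x - eps, y)) / (2 * eps)
    dfdy = (f(x, y + eps) - f(x, y - eps)) / (2 * eps)
    return dfdx, dfdy
```

Comparing vjp(residuals, out, 1.0) with finite_difference(lambda a, b: a * b, x, y) is a cheap way to catch a backward pass that returns gradients in the wrong order or with the wrong sign.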

vjp_shard_map(residuals: Any, y: Out, dy: Out, *args, cfg: Cfg, **kwargs)[source]#

Backward pass for shard_map context. Optional override.

Specialized backward pass for shard_map contexts that computes gradients with respect to positional arguments.

Parameters
  • residuals – Values returned from fwd_with_residuals_shard_map()

  • y – Forward pass output (from fwd_with_residuals_shard_map())

  • dy – Incoming gradients (cotangents) with respect to y

  • *args – Original positional arguments

  • cfg – Configuration object

  • **kwargs – Original keyword arguments

Returns

Tuple of gradients for each positional argument (None for arguments that don’t need gradients)

Raises

NotImplementedError – If not overridden. Implement this method only when providing custom gradients for shard_map.

Note

Must be implemented together with fwd_with_residuals_shard_map() method.

vjp_shard_map_gpu(residuals: Any, y: Out, dy: Out, *args, cfg: Cfg, **kwargs)[source]#

Backward pass for shard_map on GPU. Optional override.