ejkernel.callib._ejit

Compilation utilities for JAX function optimization.

Provides enhanced JIT compilation with persistent caching to disk, reducing compilation overhead across script runs.

Functions:

  • ejit: Enhanced JIT with persistent caching

  • save_compiled_fn: Save compiled function to disk

  • load_compiled_fn: Load compiled function from disk

  • load_cached_functions: Load multiple cached functions

  • smart_compile: Smart compilation with auto-caching

  • hash_fn: Generate a deterministic hash from an object's attribute values

Constants:

  • RECOMPILE_FORCE: Force recompilation flag

  • ECACHE_COMPILES: Enable compilation caching

  • CACHE_DIR: Cache directory path

  • COMPILE_FUNC_DIR: Compiled functions directory

  • COMPILED_CACHE: In-memory cache of compiled functions

Key Features:
  • Persistent disk caching of compiled functions

  • Automatic cache invalidation when the function source, hardware, or signature changes

  • Hardware-specific signatures

  • Two-level caching (memory + disk)

  • Graceful fallback on errors

Example

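>>> import jax.numpy as jnp
>>> from ejkernel.callib._ejit import ejit
>>>
>>> @ejit
... def optimized_fn(x, y):
...     return x @ y + x.T @ y.T
>>>
>>> a = jnp.ones((4, 4))
>>> b = jnp.ones((4, 4))
>>> # First call: traces and compiles (or loads from the disk cache, if enabled)
>>> result = optimized_fn(a, b)
>>> # Later calls with the same shapes and dtypes reuse the cached executable
>>> result = optimized_fn(a, b)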
ejkernel.callib._ejit.ejit(func: Optional[Callable[[P], R]] = None, *, static_argnums: Optional[Union[int, Sequence[int]]] = None, static_argnames: Optional[Union[str, Iterable[str]]] = None, donate_argnums: Optional[Union[int, Sequence[int]]] = None, in_shardings: Any = None, out_shardings: Any = None, donate_argnames: Optional[Union[str, Iterable[str]]] = None, keep_unused: bool = False, backend: str | None = None, inline: bool = False, compiler_options: dict[str, Any] | None = None)

Enhanced JIT compilation with persistent caching.

Drop-in replacement for jax.jit that caches compiled functions on disk so they can be reused across script runs, avoiding repeated compilation on later runs. Can be used as a bare decorator (@ejit) or called with keyword arguments (@ejit(...)).

Features:
  • Two-level caching (in-memory dict + persistent disk cache)

  • Automatic cache invalidation on source, hardware, or signature changes

  • Graceful fallback to standard jax.jit on errors

  • Support for the jax.jit parameters listed under Parameters below
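The invalidation rule can be pictured as a cache key built from everything that should force a recompile. The sketch below is an assumption about how such a key could be composed, not ejit's actual key format: make_cache_key is a hypothetical helper that hashes the function's source text, a hardware descriptor (with JAX this might come from something like jax.devices()[0].device_kind), and a signature description (shapes and dtypes) together.

```python
import hashlib
from typing import Sequence


def make_cache_key(source: str, hardware: str, signature: Sequence[object]) -> str:
    """Hypothetical key that changes when source, hardware or signature change."""
    h = hashlib.sha256()
    # Length-prefix each part so ("ab", "c") and ("a", "bc") hash differently.
    for part in (source, hardware, repr(tuple(signature))):
        data = part.encode("utf-8")
        h.update(len(data).to_bytes(8, "big"))
        h.update(data)
    return h.hexdigest()


src = "def f(x, y):\n    return x @ y\n"
sig = [((4, 4), "float32"), ((4, 4), "float32")]

base = make_cache_key(src, "cpu", sig)
keys = {
    base,
    make_cache_key(src, "gpu", sig),                              # hardware changed
    make_cache_key(src.replace("@", "*"), "cpu", sig),            # source changed
    make_cache_key(src, "cpu", [((8, 4), "float32")] + sig[1:]),  # shape changed
}
print(len(keys), base == make_cache_key(src, "cpu", sig))  # 4 True
```

Because any change to one component yields a different key, a stale entry is simply never looked up again rather than being explicitly deleted.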

Parameters
  • func – Function to JIT-compile and cache. When None, returns a partial decorator so that @ejit(...) syntax works.

  • static_argnums – Indices of arguments to treat as compile-time constants.

  • static_argnames – Names of arguments to treat as compile-time constants.

  • donate_argnums – Indices of arguments whose buffers may be donated to outputs.

  • in_shardings – Input sharding specifications for distributed execution.

  • out_shardings – Output sharding specifications for distributed execution.

  • donate_argnames – Names of arguments whose buffers may be donated.

  • keep_unused – Whether to keep unused arguments in the compiled function.

  • backend – JAX backend to use (e.g., 'cpu', 'gpu', 'tpu').

  • inline – Whether to inline the function into the caller.

  • compiler_options – Dictionary of additional XLA compiler options.

Returns

JIT-compiled function with persistent caching when EASYDEL_CACHE_COMPILES is set, otherwise a standard jax.jit-compiled function with XLA cache directory configured.

Note

Caching behavior is controlled by three environment variables:

  • EASYDEL_CACHE_COMPILES: Enable the custom two-level cache (default: off).

  • EASYDEL_RECOMPILE_FORCE: Force recompilation ignoring cache (default: off).

  • ALLOW_FULL_CACHE: Enable full XLA persistent caching (default: off).
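A minimal sketch of reading the three switches listed above. The accepted "truthy" spellings ("1", "true", "yes", "on") and the helper names env_flag and read_cache_flags are assumptions for illustration only; check the module source for the values it actually recognizes.

```python
import os

FLAG_NAMES = ("EASYDEL_CACHE_COMPILES", "EASYDEL_RECOMPILE_FORCE", "ALLOW_FULL_CACHE")
TRUTHY = {"1", "true", "yes", "on"}  # assumed spellings, not confirmed by the docs


def env_flag(name: str, default: bool = False) -> bool:
    """Return True if the environment variable is set to a truthy spelling."""
    value = os.environ.get(name)
    if value is None:
        return default  # unset variables fall back to the documented default (off)
    return value.strip().lower() in TRUTHY


def read_cache_flags() -> dict:
    """Snapshot of the three caching switches described in the Note."""
    return {name: env_flag(name) for name in FLAG_NAMES}


# Example: enable the two-level cache for this process only.
os.environ["EASYDEL_CACHE_COMPILES"] = "1"
os.environ.pop("EASYDEL_RECOMPILE_FORCE", None)
os.environ.pop("ALLOW_FULL_CACHE", None)
print(read_cache_flags())
# {'EASYDEL_CACHE_COMPILES': True, 'EASYDEL_RECOMPILE_FORCE': False, 'ALLOW_FULL_CACHE': False}
```

Since the module exposes RECOMPILE_FORCE and ECACHE_COMPILES as module-level constants, these variables are likely read at import time, so it is safest to set them before importing ejkernel.callib._ejit.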

Example

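>>> import jax.numpy as jnp
>>> from ejkernel.callib._ejit import ejit
>>>
>>> @ejit
... def fast_matmul(a, b):
...     return a @ b
>>>
>>> x = jnp.ones((128, 64))
>>> y = jnp.ones((64, 32))
>>> result = fast_matmul(x, y)
>>>
>>> # `scale` is static: each distinct value triggers a separate compilation
>>> @ejit(static_argnums=(2,))
... def scaled_matmul(a, b, scale):
...     return a @ b * scale
>>>
>>> result = scaled_matmul(x, y, 2.0)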
ejkernel.callib._ejit.get_hash_of_lowering(lowered_func: Lowered) → str

Generate a SHA-256 hash of a lowered JAX function.

Creates a deterministic hash based on the text representation of the lowered function, useful for cache key generation.

Parameters

lowered_func – JAX lowered function object.

Returns

Hexadecimal string of the SHA-256 hash.
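Assuming the "text representation" is the module text returned by jax.stages.Lowered.as_text(), the hash reduces to a single hashlib call. hash_lowered_text is a hypothetical stand-in that takes the text directly so the sketch runs without JAX.

```python
import hashlib


def hash_lowered_text(lowered_text: str) -> str:
    """Sketch: SHA-256 hex digest of a lowered module's text representation."""
    return hashlib.sha256(lowered_text.encode("utf-8")).hexdigest()


# Stand-in for real lowered IR text; with JAX this would come from lowered.as_text().
ir_text = "module @jit_f { func.func public @main(%arg0: tensor<4xf32>) -> tensor<4xf32> }"
digest = hash_lowered_text(ir_text)
print(len(digest))  # 64
```

With JAX installed, hash_lowered_text(jax.jit(f).lower(x).as_text()) would change whenever the lowered program text changes, for example when x has a different shape or dtype.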

ejkernel.callib._ejit.get_safe_hash_int(text, algorithm='md5')

Generate a deterministic integer hash of text using the specified algorithm.

Converts the input to a string, hashes it with the given hashlib algorithm, and returns the digest as a big-endian unsigned integer.

Parameters
  • text – Input to hash. Will be converted to str before hashing.

  • algorithm – Name of a hashlib algorithm (e.g., 'md5', 'sha256'). Defaults to 'md5'.

Returns

Non-negative integer representing the hash digest.

Raises
  • ValueError – If the specified algorithm is not supported by hashlib.

  • Exception – For any other hashing failure.

Example

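>>> import hashlib
>>> get_safe_hash_int("hello world") == int(hashlib.md5(b"hello world").hexdigest(), 16)
True
>>> get_safe_hash_int("hello world", algorithm="sha256") == int(hashlib.sha256(b"hello world").hexdigest(), 16)
True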
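The documented behavior (convert to str, hash with hashlib, read the digest as a big-endian unsigned integer) can be sketched in a few lines. safe_hash_int is a hypothetical name, and the UTF-8 encoding step is an assumption about how the string is turned into bytes.

```python
import hashlib


def safe_hash_int(text, algorithm: str = "md5") -> int:
    """Sketch of the documented behavior: str() -> hashlib digest -> big-endian int."""
    # hashlib.new raises ValueError for algorithm names it does not recognize.
    hasher = hashlib.new(algorithm)
    hasher.update(str(text).encode("utf-8"))  # UTF-8 is an assumption
    return int.from_bytes(hasher.digest(), byteorder="big", signed=False)


md5_value = safe_hash_int("hello world")
sha_value = safe_hash_int("hello world", algorithm="sha256")
# The integer is exactly the hex digest read as a base-16 number.
print(md5_value == int(hashlib.md5(b"hello world").hexdigest(), 16))  # True
```

The result is bounded by the digest size: below 2**128 for MD5 and below 2**256 for SHA-256.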
ejkernel.callib._ejit.hash_fn(self) → int

Generate a deterministic integer hash for an object based on its attribute values.

Concatenates the string representations of all attribute values of type float, int, bool, dict, or list found in the object’s __dict__ and produces an MD5-based integer hash of the result.

Parameters

self – Any object with a __dict__ containing primitive-typed values.

Returns

Deterministic non-negative integer hash derived from the object’s attributes.

Note

Only attributes of type float, int, bool, dict, or list contribute to the hash. Other attribute types are silently ignored.
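The self parameter suggests hash_fn is meant to be attached to a class as a method such as __hash__; that usage, and the names attr_hash and Config below, are assumptions for illustration. The sketch follows the description above: keep only float/int/bool/dict/list attributes, concatenate their str() forms, and hash with MD5.

```python
import hashlib


def attr_hash(obj) -> int:
    """Sketch of hash_fn: hash the str() of float/int/bool/dict/list attributes."""
    # Only these types contribute; strings, None, arrays, etc. are skipped.
    parts = [
        str(value)
        for value in vars(obj).values()
        if isinstance(value, (float, int, bool, dict, list))
    ]
    digest = hashlib.md5("".join(parts).encode("utf-8")).digest()
    return int.from_bytes(digest, byteorder="big")


class Config:  # hypothetical config class
    def __init__(self, hidden: int, dropout: float, name: str):
        self.hidden = hidden
        self.dropout = dropout
        self.name = name  # a str, so attr_hash ignores it

    __hash__ = attr_hash  # attach the function as the object's hash


a = Config(128, 0.1, "small")
b = Config(128, 0.1, "renamed")
c = Config(256, 0.1, "small")
print(attr_hash(a) == attr_hash(b), attr_hash(a) == attr_hash(c))  # True False
```

Two consequences of the rule are visible here: changing an ignored attribute (name) does not change the hash, and because this sketch joins values without a separator, different attribute splits such as (12, 80.1) and (128, 0.1) produce the same string "1280.1".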

ejkernel.callib._ejit.load_cached_functions(verbose: bool = True) → None

Pre-load all cached compiled functions from disk into memory.

Scans the cache directory and loads all valid compiled functions into the in-memory cache for faster subsequent lookups.

Parameters

verbose – If True, print status messages and warnings about loading.

Note

This is useful at startup to warm up the cache before running performance-critical code paths.
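A minimal sketch of the warm-up idea: scan a directory, unpickle each entry into an in-memory dict, and skip unreadable files. The on-disk layout (one pickle per file, named with the -compiled.executable suffix documented for save_compiled_fn) and the helper name preload_cache are assumptions; the real function restores JAX executables, while this sketch loads arbitrary pickled objects.

```python
import pickle
import tempfile
import warnings
from pathlib import Path

SUFFIX = "-compiled.executable"  # file naming documented for save_compiled_fn


def preload_cache(cache_dir: Path, memory_cache: dict, verbose: bool = True) -> int:
    """Hypothetical warm-up: unpickle every cache file into memory_cache."""
    loaded = 0
    for path in sorted(cache_dir.glob(f"*{SUFFIX}")):
        key = path.name[: -len(SUFFIX)]
        if key in memory_cache:
            continue  # already warm
        try:
            with path.open("rb") as f:
                memory_cache[key] = pickle.load(f)
            loaded += 1
        except Exception as exc:  # corrupted or incompatible entries are skipped
            if verbose:
                warnings.warn(f"skipping {path.name}: {exc}")
    return loaded


with tempfile.TemporaryDirectory() as tmp:
    cache_dir = Path(tmp)
    (cache_dir / f"matmul{SUFFIX}").write_bytes(pickle.dumps({"payload": b"..."}))
    (cache_dir / f"broken{SUFFIX}").write_bytes(b"not a pickle")
    cache: dict = {}
    n = preload_cache(cache_dir, cache, verbose=False)
    print(n, sorted(cache))  # 1 ['matmul']
```

Skipping bad entries instead of raising matches the "graceful fallback" goal: a corrupted file only costs a recompilation later.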

ejkernel.callib._ejit.load_compiled_fn(path: str | os.PathLike, prefix: str | None = None)

Load a previously saved compiled function from disk.

Parameters
  • path – Directory path where the compiled function was saved.

  • prefix – Optional prefix that was used when saving.

Returns

The deserialized compiled JAX function.

Raises
  • FileNotFoundError – If the compiled function file doesn’t exist.

  • pickle.UnpicklingError – If the file is corrupted.

ejkernel.callib._ejit.save_compiled_fn(path: str | os.PathLike, fn: Compiled, prefix: str | None = None)

Save a compiled JAX function to disk for later reuse.

Serializes a compiled function along with its input/output tree structures, allowing it to be loaded and executed in future Python sessions.

Parameters
  • path – Directory path where the compiled function will be saved. Will be created if it doesn’t exist.

  • fn – Compiled JAX function (output of lowered.compile()).

  • prefix – Optional prefix for the filename. Useful for organizing multiple compiled functions in the same directory.

Files Created:
  • {prefix}-compiled.executable: Serialized function and metadata

Example

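>>> import jax
>>> import jax.numpy as jnp
>>> from pathlib import Path
>>> from ejkernel.callib._ejit import load_compiled_fn, save_compiled_fn
>>>
>>> def my_function(x):
...     return jnp.sin(x) * 2.0
>>>
>>> # Lower and compile for a concrete input shape and dtype
>>> sample_input = jnp.ones((8,))
>>> jitted = jax.jit(my_function)
>>> lowered = jitted.lower(sample_input)
>>> compiled = lowered.compile()
>>>
>>> # Persist the executable; the directory is created if missing
>>> cache_dir = Path("./my_cache")
>>> save_compiled_fn(cache_dir, compiled, prefix="model_v1")
>>>
>>> # In a later session on the same hardware
>>> restored = load_compiled_fn(cache_dir, prefix="model_v1")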
Raises

Warning – If serialization fails, a warning is logged instead of raising an exception.

Notes

  • Compiled functions are hardware-specific

  • Large models may produce large cache files

  • Uses pickle for serialization (standard security caveats apply)
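The save/load contract above (directory created on demand, {prefix}-compiled.executable file, failures logged rather than raised on save, FileNotFoundError on load) can be sketched with plain pickle. JAX's jax.experimental.serialize_executable.serialize returns a (payload, in_tree, out_tree) triple of this shape, but whether this module uses it is not stated here; save_sketch, load_sketch, and the file name used when prefix is None are hypothetical.

```python
import logging
import pickle
import tempfile
from pathlib import Path

logger = logging.getLogger(__name__)


def _cache_file(path, prefix):
    # Assumption: without a prefix the file is just "compiled.executable".
    name = f"{prefix}-compiled.executable" if prefix else "compiled.executable"
    return Path(path) / name


def save_sketch(path, payload, in_tree, out_tree, prefix=None) -> bool:
    """Hypothetical save: create the directory, pickle the triple, warn on failure."""
    try:
        Path(path).mkdir(parents=True, exist_ok=True)
        with _cache_file(path, prefix).open("wb") as f:
            pickle.dump((payload, in_tree, out_tree), f)
        return True
    except Exception as exc:  # mirrors "logged, not raised"
        logger.warning("could not save compiled function: %s", exc)
        return False


def load_sketch(path, prefix=None):
    """Hypothetical load: FileNotFoundError / UnpicklingError reach the caller."""
    with _cache_file(path, prefix).open("rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "my_cache"  # does not exist yet
    ok = save_sketch(target, b"\x00serialized", "in_tree", "out_tree", prefix="model_v1")
    restored = load_sketch(target, prefix="model_v1")
    print(ok, restored[1:])  # True ('in_tree', 'out_tree')
```

Saving the tree structures next to the payload is what lets a loader rebuild a callable that accepts and returns the same pytrees as the original function.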

ejkernel.callib._ejit.smart_compile(lowered_func: Lowered, tag: str | None = None, verbose: bool = True, cache_key: tuple[str, tuple] | None = None) → tuple[Compiled, tuple[str, tuple] | None]

Compile a lowered JAX function with persistent caching.

Attempts to load a previously compiled version from disk cache, falling back to fresh compilation if not found. Automatically caches newly compiled functions for future use.

Parameters
  • lowered_func – JAX lowered function to compile.

  • tag – Optional tag to include in the cache filename for organization.

  • verbose – If True, print warnings about cache operations.

  • cache_key – Optional custom cache key for the function signature.

Returns

Tuple of (compiled_function, cache_key) where cache_key may be updated if loaded from disk.

Note

Uses SHA-256 hash of the lowered function text for cache keys, combined with optional tag for namespacing.
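The lookup order described above (memory, then disk, then fresh compilation that is written back) can be sketched without JAX by passing the lowered text and a compile callback. cached_compile, MEMORY, and the .pkl file naming are hypothetical; the real function takes a Lowered object and its cache_key is a tuple, not the string used here.

```python
from __future__ import annotations

import hashlib
import pickle
import tempfile
from pathlib import Path
from typing import Callable

MEMORY: dict[str, object] = {}  # stand-in for the in-memory level


def cached_compile(lowered_text: str, compile_fn: Callable[[], object],
                   cache_dir: Path, tag: str | None = None) -> tuple[object, str]:
    """Hypothetical smart_compile flow: memory -> disk -> compile and store."""
    digest = hashlib.sha256(lowered_text.encode("utf-8")).hexdigest()
    key = f"{tag}-{digest}" if tag else digest  # the tag namespaces the file name
    if key in MEMORY:
        return MEMORY[key], key
    path = cache_dir / f"{key}.pkl"
    if path.exists():
        with path.open("rb") as f:
            MEMORY[key] = pickle.load(f)
        return MEMORY[key], key
    compiled = compile_fn()
    with path.open("wb") as f:
        pickle.dump(compiled, f)
    MEMORY[key] = compiled
    return compiled, key


calls = []


def fake_compile():
    calls.append(1)  # count how often "compilation" actually runs
    return {"executable": "stand-in"}


with tempfile.TemporaryDirectory() as tmp:
    d = Path(tmp)
    text = "module @jit_f { ... }"
    first, key = cached_compile(text, fake_compile, d, tag="attn")
    MEMORY.clear()  # simulate a new Python process: memory empty, disk kept
    second, _ = cached_compile(text, fake_compile, d, tag="attn")
    print(len(calls), first == second, key.startswith("attn-"))  # 1 True True
```

The second call finds the entry on disk, so the compile callback runs only once across the two simulated sessions.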