ejkernel.callib._ejit

Compilation utilities for JAX function optimization.

Provides enhanced JIT compilation with persistent caching to disk, reducing compilation overhead across script runs.

Functions:

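  • ejit: Enhanced JIT with persistent caching

  • save_compiled_fn: Save a compiled function to disk

  • load_compiled_fn: Load a compiled function from disk

  • load_cached_functions: Load multiple cached functions into memory

  • smart_compile: Smart compilation with automatic caching

  • hash_fn: Generate a hash for a function signature

  • get_hash_of_lowering: Generate a SHA-256 hash of a lowered function

  • get_safe_hash_int: Hash text with a chosen algorithm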

Constants:

  • RECOMPILE_FORCE: Force recompilation flag

  • ECACHE_COMPILES: Enable compilation caching

  • CACHE_DIR: Cache directory path

  • COMPILE_FUNC_DIR: Compiled functions directory

  • COMPILED_CACHE: In-memory cache of compiled functions

Key Features:
  • Persistent disk caching of compiled functions

  • Automatic cache invalidation when the function signature changes

  • Hardware-specific signatures

  • Two-level caching (memory + disk)

  • Graceful fallback on errors
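The invalidation and hardware-specific behaviour listed above can be illustrated with a cache key that mixes the program text with a hardware fingerprint, so that either editing the program or moving to different hardware yields a new key. This is a minimal sketch under stated assumptions: `make_cache_key` and its fingerprint fields are hypothetical and not this module's actual key format. With JAX, the inputs would plausibly come from `lowered.as_text()` and `jax.devices()[0].device_kind`, but the strings below are placeholders.

```python
import hashlib


def make_cache_key(lowered_text: str, device_kind: str, jax_version: str) -> str:
    """Hypothetical cache key: SHA-256 over program text plus a hardware fingerprint."""
    h = hashlib.sha256()
    for part in (lowered_text, device_kind, jax_version):
        h.update(part.encode("utf-8"))
        h.update(b"\0")  # separator: ("ab", "c") and ("a", "bc") must not collide
    return h.hexdigest()


program = "module @matmul { ... }"  # placeholder for lowered.as_text()
# "cpu", "NVIDIA A100" and "0.4.30" are example fingerprint strings only.
key_cpu = make_cache_key(program, "cpu", "0.4.30")
key_gpu = make_cache_key(program, "NVIDIA A100", "0.4.30")
key_edit = make_cache_key(program + " // edited", "cpu", "0.4.30")

# Same program on other hardware, or an edited program on the same hardware,
# maps to a different key, so a stale executable is simply never looked up.
assert key_cpu != key_gpu and key_cpu != key_edit
```

Because stale entries are never matched rather than actively deleted, invalidation costs nothing at lookup time; the trade-off is that old files accumulate until the cache directory is cleaned.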

Example

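>>> import jax.numpy as jnp
>>> from ejkernel.callib._ejit import ejit
>>>
>>> @ejit
... def optimized_fn(x, y):
...     return x @ y + x.T @ y.T
>>>
>>> a = jnp.ones((4, 4))
>>> b = jnp.ones((4, 4))
>>> result = optimized_fn(a, b)  # first call: compiles and caches
>>> result = optimized_fn(a, b)  # repeat call: served from the cache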
ejkernel.callib._ejit.ejit(func: Optional[Callable[[P], R]] = None, *, static_argnums: Optional[Union[int, Sequence[int]]] = None, static_argnames: Optional[Union[str, Iterable[str]]] = None, donate_argnums: Optional[Union[int, Sequence[int]]] = None, in_shardings: Any = None, out_shardings: Any = None, donate_argnames: Optional[Union[str, Iterable[str]]] = None, keep_unused: bool = False, backend: str | None = None, inline: bool = False, compiler_options: dict[str, Any] | None = None)

Enhanced JIT compilation with persistent caching.

Drop-in replacement for jax.jit that caches compiled functions to disk, so later script runs can reuse them instead of paying the compilation cost again.

Features:
  • Two-level caching (in-memory LRU + persistent disk cache)

  • Automatic cache invalidation on signature changes

  • Graceful fallback to standard jax.jit on errors

  • Accepts the jax.jit parameters listed in the signature above

Parameters
  • func – Function to JIT-compile and cache. If omitted, ejit returns a decorator, so it can also be used as @ejit(...) with keyword options.

  • static_argnums – Indices of static arguments.

  • static_argnames – Names of static arguments.

  • donate_argnums – Indices of donated arguments.

  • in_shardings – Input sharding specifications.

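  • out_shardings – Output sharding specifications.

  • donate_argnames – Names of donated arguments.

  • keep_unused – If True, unused arguments are kept in the compiled executable instead of being pruned (as in jax.jit).

  • backend – XLA backend to compile for, e.g. 'cpu', 'gpu' or 'tpu' (as in jax.jit).

  • inline – Whether to inline the function into enclosing jaxprs (as in jax.jit).

  • compiler_options – Optional dictionary of XLA compiler options (as in jax.jit).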

Returns

JIT-compiled function with caching.
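Because func defaults to None and every other parameter is keyword-only, ejit presumably works both bare (@ejit) and with options (@ejit(static_argnums=(1,))). The sketch below shows that decorator-factory pattern together with the graceful-fallback path, in plain Python. The names `cached_jit`, `_compile_with_cache` and `simulate_failure` are hypothetical; the real lowering, caching and compilation are replaced by a placeholder. For brevity the callable is resolved once at decoration time, whereas a real JIT wrapper has to lower and compile per input signature.

```python
import functools
import logging
from typing import Any, Callable, Optional, Sequence, Union

logger = logging.getLogger(__name__)


def _compile_with_cache(fn: Callable, static_argnums, simulate_failure: bool) -> Callable:
    """Placeholder for the real work (lowering, cache lookup, compilation)."""
    if simulate_failure:
        raise RuntimeError("simulated cache failure")
    return fn  # a real implementation would return a cached or freshly compiled callable


def cached_jit(
    func: Optional[Callable] = None,
    *,
    static_argnums: Optional[Union[int, Sequence[int]]] = None,
    simulate_failure: bool = False,  # hypothetical, only here to exercise the fallback
):
    def decorator(fn: Callable) -> Callable:
        try:
            target = _compile_with_cache(fn, static_argnums, simulate_failure)
        except Exception as exc:
            # Graceful fallback: log and keep going with the plain function
            # (standing in for "fall back to standard jax.jit").
            logger.warning("caching failed for %s: %s; falling back", fn.__name__, exc)
            target = fn

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return target(*args, **kwargs)

        return wrapper

    # Bare @cached_jit passes the function in directly; @cached_jit(...) passes
    # func=None, so return the decorator and let Python apply it.
    return decorator if func is None else decorator(func)


@cached_jit
def add(a, b):
    return a + b


@cached_jit(static_argnums=1, simulate_failure=True)
def scale(x, factor):
    return x * factor
```

The fallback keeps a caching problem (unwritable directory, corrupt file) from turning into a user-visible failure; the cost is that such problems only show up as log warnings.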

Example

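>>> import jax.numpy as jnp
>>> from ejkernel.callib._ejit import ejit
>>>
>>> @ejit
... def fast_matmul(a, b):
...     return a @ b
>>>
>>> x = jnp.ones((128, 256))
>>> y = jnp.ones((256, 64))
>>> result = fast_matmul(x, y)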
ejkernel.callib._ejit.get_hash_of_lowering(lowered_func: Lowered) → str

Generate a SHA-256 hash of a lowered JAX function.

Creates a deterministic hash based on the text representation of the lowered function, useful for cache key generation.

Parameters

lowered_func – JAX lowered function object.

Returns

Hexadecimal string of the SHA-256 hash.

ejkernel.callib._ejit.get_safe_hash_int(text, algorithm='md5')

Generate a hash of text using the specified algorithm (default: 'md5'), with safety checks.
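The entry does not say what the safety checks are. One plausible reading, written as a hypothetical `safe_hash_int` helper, is: reject unknown algorithm names, coerce the input to bytes, and convert the hex digest to an integer. The specific checks here are assumptions for illustration, not the module's actual behaviour.

```python
import hashlib


def safe_hash_int(text, algorithm: str = "md5") -> int:
    """Hypothetical reading of get_safe_hash_int: hash ``text`` to an integer."""
    if algorithm not in hashlib.algorithms_available:
        raise ValueError(f"unsupported hash algorithm: {algorithm!r}")
    if isinstance(text, str):
        data = text.encode("utf-8")
    elif isinstance(text, (bytes, bytearray)):
        data = bytes(text)
    else:
        data = str(text).encode("utf-8")  # fall back to the object's string form
    # usedforsecurity=False (Python 3.9+) marks this as a non-cryptographic use,
    # which keeps md5 usable on FIPS-restricted builds.
    digest = hashlib.new(algorithm, data, usedforsecurity=False).hexdigest()
    return int(digest, 16)
```

Unlike Python's built-in hash() on strings, a hashlib digest is identical across processes, which matters for any key that is written to disk.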

ejkernel.callib._ejit.hash_fn(self) → int

Generate a hash for an object based on its dictionary values.
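A minimal sketch of hashing an object from its attribute values, assuming "dictionary values" means the instance's __dict__; both this hash_fn body and the `KernelConfig` class are hypothetical. One caveat worth knowing: Python's built-in hash() of strings is salted per process (see PYTHONHASHSEED), so an integer produced this way suits in-memory lookups but is not stable enough to name files on disk.

```python
def hash_fn(self) -> int:
    """Hypothetical sketch: hash an object from the values in its ``__dict__``."""
    # repr() lets unhashable values such as lists take part; dict order is insertion order.
    return hash(tuple(repr(value) for value in vars(self).values()))


class KernelConfig:
    """Hypothetical configuration object that reuses hash_fn as its __hash__."""

    __hash__ = hash_fn  # defined explicitly, so adding __eq__ below does not disable it

    def __init__(self, block_size: int, dims: list):
        self.block_size = block_size
        self.dims = dims

    def __eq__(self, other: object) -> bool:
        # Equal hashes are only useful for dict/set lookups when __eq__ agrees.
        return type(other) is type(self) and vars(self) == vars(other)


a = KernelConfig(128, [1, 2])
b = KernelConfig(128, [1, 2])
c = KernelConfig(64, [1, 2])
```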

ejkernel.callib._ejit.load_cached_functions(verbose: bool = True) → None

Pre-load all cached compiled functions from disk into memory.

Scans the cache directory and loads all valid compiled functions into the in-memory cache for faster subsequent lookups.

Parameters

verbose – If True, print status messages and warnings about loading.

Note

This is useful at startup to warm up the cache before running performance-critical code paths.
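A sketch of what warming the cache could look like, assuming one pickle file per entry named with the -compiled.executable suffix described for save_compiled_fn, and a plain dict standing in for COMPILED_CACHE. `warm_cache` is a hypothetical name. Unreadable files are skipped with a warning so that one corrupt entry does not block the rest.

```python
import logging
import pickle
import tempfile
from pathlib import Path

logger = logging.getLogger(__name__)

SUFFIX = "-compiled.executable"  # naming from the save_compiled_fn "Files Created" note


def warm_cache(cache_dir: Path, memory_cache: dict, verbose: bool = True) -> int:
    """Hypothetical sketch: load every readable cache file into ``memory_cache``.

    Returns how many entries were loaded; unreadable files are skipped, not fatal.
    """
    loaded = 0
    for path in sorted(Path(cache_dir).glob(f"*{SUFFIX}")):
        key = path.name[: -len(SUFFIX)]
        try:
            with path.open("rb") as f:
                memory_cache[key] = pickle.load(f)
        except Exception as exc:  # corrupt, truncated, or incompatible entry
            if verbose:
                logger.warning("skipping unreadable cache file %s: %s", path, exc)
            continue
        loaded += 1
    if verbose:
        logger.info("loaded %d cached entries from %s", loaded, cache_dir)
    return loaded


# Usage: one valid entry and one corrupt file in a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / f"matmul{SUFFIX}").write_bytes(pickle.dumps({"payload": b"\x00"}))
    (Path(tmp) / f"broken{SUFFIX}").write_bytes(b"not a pickle")
    cache: dict = {}
    n_loaded = warm_cache(Path(tmp), cache, verbose=False)
```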

ejkernel.callib._ejit.load_compiled_fn(path: str | os.PathLike, prefix: str | None = None)

Load a previously saved compiled function from disk.

Parameters
  • path – Directory path where the compiled function was saved.

  • prefix – Optional prefix that was used when saving.

Returns

The deserialized compiled JAX function.

Raises
  • FileNotFoundError – If the compiled function file doesn’t exist.

  • pickle.UnpicklingError – If the file is corrupted.
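The load path can be sketched as: build the file name from the prefix, fail clearly if the file is missing, and unpickle it otherwise. The helpers `cache_file` and `load_payload` are hypothetical. They return the raw unpickled object, whereas the real function returns a deserialized compiled JAX function, and the file name used when no prefix is given is a guess.

```python
import os
import pickle
import tempfile
from pathlib import Path
from typing import Any, Optional, Union


def cache_file(path: Union[str, os.PathLike], prefix: Optional[str] = None) -> Path:
    # "{prefix}-compiled.executable" follows the save_compiled_fn docs; the name
    # used when prefix is None is a guess made for this sketch.
    name = f"{prefix}-compiled.executable" if prefix else "compiled.executable"
    return Path(path) / name


def load_payload(path: Union[str, os.PathLike], prefix: Optional[str] = None) -> Any:
    """Hypothetical load step: find the file for ``prefix`` and unpickle it."""
    file = cache_file(path, prefix)
    if not file.is_file():
        raise FileNotFoundError(f"no compiled function at {file}")
    with file.open("rb") as f:
        return pickle.load(f)  # corrupt data raises pickle.UnpicklingError (or similar)


with tempfile.TemporaryDirectory() as tmp:
    cache_file(tmp, "demo").write_bytes(pickle.dumps({"payload": b"\x01", "in_tree": "(x,)"}))
    restored = load_payload(tmp, "demo")
    try:
        load_payload(tmp, "missing")
        missing_raised = False
    except FileNotFoundError:
        missing_raised = True
```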

ejkernel.callib._ejit.save_compiled_fn(path: str | os.PathLike, fn: Compiled, prefix: str | None = None)

Save a compiled JAX function to disk for later reuse.

Serializes a compiled function along with its input/output tree structures, allowing it to be loaded and executed in future Python sessions.

Parameters
  • path – Directory path where the compiled function will be saved. Will be created if it doesn’t exist.

  • fn – Compiled JAX function (output of lowered.compile()).

  • prefix – Optional prefix for the filename. Useful for organizing multiple compiled functions in the same directory.

Files Created:
  • {prefix}-compiled.executable: Serialized function and metadata

Example

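>>> import jax
>>> import jax.numpy as jnp
>>> from pathlib import Path
>>> from ejkernel.callib._ejit import save_compiled_fn
>>>
>>> def my_function(x):
...     return jnp.sin(x) * 2.0
>>>
>>> sample_input = jnp.ones((8,))
>>> jitted = jax.jit(my_function)
>>> lowered = jitted.lower(sample_input)
>>> compiled = lowered.compile()
>>>
>>> cache_dir = Path("./my_cache")
>>> save_compiled_fn(cache_dir, compiled, prefix="model_v1")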
Raises

Nothing. If serialization fails, the error is logged as a warning instead of being raised.

Notes

  • Compiled functions are hardware-specific

  • Large models may produce large cache files

  • Uses pickle for serialization, so only load cache files from trusted sources (unpickling can execute arbitrary code)
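A sketch of the save side under the documented behaviour: create the directory if needed, pickle the payload, and log rather than raise on failure. `save_payload` is hypothetical. Writing to a temporary file and then calling os.replace is one reasonable design choice (a concurrent reader never sees a half-written file), not necessarily what save_compiled_fn does. JAX's jax.experimental.serialize_executable module provides serialize(), which turns a compiled executable into bytes plus its input/output tree structures and would be a plausible source for such a payload, though the source does not say so.

```python
import contextlib
import logging
import os
import pickle
import tempfile
from pathlib import Path
from typing import Any, Optional, Union

logger = logging.getLogger(__name__)


def save_payload(path: Union[str, os.PathLike], payload: Any, prefix: Optional[str] = None) -> bool:
    """Hypothetical save step: create the directory, pickle, log on failure.

    Returns True on success and False on failure; it never raises.
    """
    directory = Path(path)
    name = f"{prefix}-compiled.executable" if prefix else "compiled.executable"
    target = directory / name
    tmp_file = directory / (name + ".tmp")
    try:
        directory.mkdir(parents=True, exist_ok=True)
        with tmp_file.open("wb") as f:
            pickle.dump(payload, f)
        os.replace(tmp_file, target)  # atomic rename on the same filesystem
        return True
    except Exception as exc:
        logger.warning("could not save compiled function to %s: %s", target, exc)
        with contextlib.suppress(OSError):
            tmp_file.unlink(missing_ok=True)  # drop any half-written temp file
        return False


with tempfile.TemporaryDirectory() as tmp:
    target_dir = Path(tmp) / "nested" / "cache"
    ok = save_payload(target_dir, {"payload": b"abc"}, prefix="model_v1")
    failed = save_payload(target_dir, lambda x: x, prefix="bad")  # lambdas cannot be pickled
    leftovers = sorted(p.name for p in target_dir.iterdir())
```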

ejkernel.callib._ejit.smart_compile(lowered_func: Lowered, tag: str | None = None, verbose: bool = True, cache_key: tuple[str, tuple] | None = None) → tuple[Compiled, tuple[str, tuple] | None]

Compile a lowered JAX function with intelligent caching.

Attempts to load a previously compiled version from disk cache, falling back to fresh compilation if not found. Automatically caches newly compiled functions for future use.

Parameters
  • lowered_func – JAX lowered function to compile.

  • tag – Optional tag to include in the cache filename for organization.

  • verbose – If True, print warnings about cache operations.

  • cache_key – Optional custom cache key for the function signature.

Returns

Tuple of (compiled_function, cache_key) where cache_key may be updated if loaded from disk.

Note

Uses a SHA-256 hash of the lowered function text for cache keys, combined with the optional tag for namespacing.
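The flow described for smart_compile, combined with the module's two-level (memory + disk) cache, can be sketched with two dicts standing in for the in-memory cache and the disk directory. `smart_compile_sketch`, `cache_key_for` and both dicts are hypothetical. The key is the tag followed by a SHA-256 hex digest of the lowered text, a simplification of the real cache_key type (tuple[str, tuple]), and a callable stands in for lowered.compile().

```python
import hashlib
from typing import Callable, Dict, Optional, Tuple

# Hypothetical stand-ins for the in-memory cache and the on-disk store.
memory_cache: Dict[str, object] = {}
disk_store: Dict[str, object] = {}


def cache_key_for(lowered_text: str, tag: Optional[str] = None) -> str:
    digest = hashlib.sha256(lowered_text.encode("utf-8")).hexdigest()
    return f"{tag}-{digest}" if tag else digest  # the tag namespaces the key


def smart_compile_sketch(
    lowered_text: str, compile_fn: Callable[[], object], tag: Optional[str] = None
) -> Tuple[object, str]:
    """Hypothetical: memory first, then disk, else compile and store in both."""
    key = cache_key_for(lowered_text, tag)
    if key in memory_cache:
        return memory_cache[key], key
    if key in disk_store:  # in the real module this would be a file on disk
        memory_cache[key] = disk_store[key]
        return memory_cache[key], key
    compiled = compile_fn()  # in JAX: lowered.compile()
    memory_cache[key] = disk_store[key] = compiled
    return compiled, key


calls = []


def fake_compile() -> object:
    calls.append(1)  # count how often "compilation" really happens
    return "executable"


first, key = smart_compile_sketch("module @f", fake_compile, tag="attn")
second, _ = smart_compile_sketch("module @f", fake_compile, tag="attn")
memory_cache.clear()  # simulate a new Python process: memory empty, disk still populated
third, _ = smart_compile_sketch("module @f", fake_compile, tag="attn")
```

Only the first call compiles: the second is a memory hit and the third, after the simulated restart, is a disk hit that also repopulates memory.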