ejkernel.callib._ejit
Compilation utilities for JAX function optimization.
Provides enhanced JIT compilation with persistent caching to disk, reducing compilation overhead across script runs.
- Functions:
ejit: Enhanced JIT with persistent caching
save_compiled_fn: Save compiled function to disk
load_compiled_fn: Load compiled function from disk
load_cached_functions: Load multiple cached functions
smart_compile: Smart compilation with auto-caching
hash_fn: Generate hash for function signature
- Constants:
RECOMPILE_FORCE: Force recompilation flag
ECACHE_COMPILES: Enable compilation caching
CACHE_DIR: Cache directory path
COMPILE_FUNC_DIR: Compiled functions directory
COMPILED_CACHE: In-memory cache of compiled functions
- Key Features:
Persistent disk caching of compiled functions
Automatic cache invalidation on changes
Hardware-specific signatures
Two-level caching (memory + disk)
Graceful fallback on errors
Example
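>>> import jax.numpy as jnp
>>> from ejkernel.callib._ejit import ejit
>>>
>>> @ejit
... def optimized_fn(x, y):
...     return x @ y + x.T @ y.T
>>>
>>> a = jnp.ones((4, 4))
>>> b = jnp.ones((4, 4))
>>> # First call: compiles the function and stores it in the cache.
>>> result = optimized_fn(a, b)
>>> # Second call: reuses the cached compiled function.
>>> result = optimized_fn(a, b)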
- ejkernel.callib._ejit.ejit(func: Optional[Callable[[P], R]] = None, *, static_argnums: Optional[Union[int, Sequence[int]]] = None, static_argnames: Optional[Union[str, Iterable[str]]] = None, donate_argnums: Optional[Union[int, Sequence[int]]] = None, in_shardings: Any = None, out_shardings: Any = None, donate_argnames: Optional[Union[str, Iterable[str]]] = None, keep_unused: bool = False, backend: str | None = None, inline: bool = False, compiler_options: dict[str, Any] | None = None)
Enhanced JIT compilation with persistent caching.
Drop-in replacement for jax.jit that caches compiled functions to disk for reuse across script runs, significantly reducing compilation overhead.
- Features:
Two-level caching (in-memory LRU + persistent disk cache)
Automatic cache invalidation on signature changes
Graceful fallback to standard jax.jit on errors
Support for all jax.jit parameters
- Parameters
func – Function to JIT-compile and cache.
static_argnums – Indices of static arguments.
static_argnames – Names of static arguments.
donate_argnums – Indices of donated arguments.
in_shardings – Input sharding specifications.
out_shardings – Output sharding specifications.
donate_argnames – Names of donated arguments, as in jax.jit.
keep_unused – Whether to keep arguments the function does not use, as in jax.jit.
backend – Name of the backend to compile for, as in jax.jit.
inline – Whether to inline the function into enclosing computations, as in jax.jit.
compiler_options – Compiler options passed through, as in jax.jit.
- Returns
JIT-compiled function with caching.
Example
>>> @ejit
... def fast_matmul(a, b):
...     return a @ b
>>>
>>> result = fast_matmul(x, y)
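The two-level lookup listed under Features can be sketched with the standard library alone. This is a minimal illustration, not ejkernel's implementation: it caches plain Python return values rather than compiled executables, and `two_level_cache`, the `.pkl` file naming, and the key derivation are all hypothetical.

```python
import functools
import hashlib
import pickle
from pathlib import Path


def two_level_cache(cache_dir):
    """Decorator factory: look in memory first, then on disk, then compute."""
    memory = {}  # level 1: per-process dictionary
    cache_dir = Path(cache_dir)  # level 2: one pickle file per key

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args):
            # Hypothetical key: hash of the function name and argument repr.
            raw_key = repr((fn.__qualname__, args)).encode("utf-8")
            key = hashlib.sha256(raw_key).hexdigest()
            if key in memory:
                return memory[key]
            path = cache_dir / f"{key}.pkl"
            try:
                with open(path, "rb") as fh:
                    value = pickle.load(fh)
            except (OSError, pickle.UnpicklingError, EOFError):
                # Missing or unreadable entry: fall back to doing the work.
                value = fn(*args)
                try:
                    cache_dir.mkdir(parents=True, exist_ok=True)
                    with open(path, "wb") as fh:
                        pickle.dump(value, fh)
                except OSError:
                    pass  # a failed disk write must not break the call
            memory[key] = value
            return value

        return wrapper

    return decorator
```

Decorating a function with `@two_level_cache("./my_cache")` makes repeated calls with the same arguments skip the function body, and a fresh process finds the result on disk instead of recomputing it.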
- ejkernel.callib._ejit.get_hash_of_lowering(lowered_func: Lowered) -> str
Generate a SHA-256 hash of a lowered JAX function.
Creates a deterministic hash based on the text representation of the lowered function, useful for cache key generation.
- Parameters
lowered_func – JAX lowered function object.
- Returns
Hexadecimal string of the SHA-256 hash.
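Hashing a text representation with SHA-256 can be sketched as follows. The helper name `sha256_hex` and the sample string are hypothetical; exactly which text the library hashes is not specified here beyond "the text representation of the lowered function".

```python
import hashlib


def sha256_hex(text: str) -> str:
    """Return the SHA-256 digest of `text` as a 64-character hex string."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


# Stand-in for the text of a lowered function. With JAX, such a string could
# come from `jax.jit(f).lower(x).as_text()`; that this is the exact input the
# library uses is an assumption.
lowered_text = "module @jit_f { func.func public @main() { return } }"
key = sha256_hex(lowered_text)
```

The same text always yields the same digest, which is what makes it usable as a cache key across runs.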
- ejkernel.callib._ejit.get_safe_hash_int(text, algorithm='md5')
Generate a hash of text using specified algorithm with safety checks.
- ejkernel.callib._ejit.hash_fn(self) -> int
Generate a hash for an object based on its dictionary values.
- ejkernel.callib._ejit.load_cached_functions(verbose: bool = True) -> None
Pre-load all cached compiled functions from disk into memory.
Scans the cache directory and loads all valid compiled functions into the in-memory cache for faster subsequent lookups.
- Parameters
verbose – If True, print status messages and warnings about loading.
Note
This is useful at startup to warm up the cache before running performance-critical code paths.
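The warm-up pattern described in the note, scanning a directory and keeping only the entries that load cleanly, might look like the sketch below. It is an assumption-laden illustration: `warm_cache`, its explicit `memory` argument, and the `.pkl` suffix are hypothetical, whereas the real function takes only `verbose` and returns None.

```python
import pickle
from pathlib import Path


def warm_cache(cache_dir, memory, verbose=True):
    """Load every readable `*.pkl` file under `cache_dir` into `memory`.

    Returns the number of entries loaded. Unreadable files are skipped.
    """
    loaded = 0
    for path in sorted(Path(cache_dir).glob("*.pkl")):
        try:
            with open(path, "rb") as fh:
                memory[path.stem] = pickle.load(fh)
            loaded += 1
        except Exception as exc:  # a bad entry should not abort the warm-up
            if verbose:
                print(f"skipping {path.name}: {exc!r}")
    return loaded
```

Calling this once at startup moves the disk reads out of the performance-critical path.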
- ejkernel.callib._ejit.load_compiled_fn(path: str | os.PathLike, prefix: str | None = None)
Load a previously saved compiled function from disk.
- Parameters
path – Directory path where the compiled function was saved.
prefix – Optional prefix that was used when saving.
- Returns
The deserialized compiled JAX function.
- Raises
FileNotFoundError – If the compiled function file doesn’t exist.
pickle.UnpicklingError – If the file is corrupted.
- ejkernel.callib._ejit.save_compiled_fn(path: str | os.PathLike, fn: Compiled, prefix: str | None = None)
Save a compiled JAX function to disk for later reuse.
Serializes a compiled function along with its input/output tree structures, allowing it to be loaded and executed in future Python sessions.
- Parameters
path – Directory path where the compiled function will be saved. Will be created if it doesn’t exist.
fn – Compiled JAX function (output of lowered.compile()).
prefix – Optional prefix for the filename. Useful for organizing multiple compiled functions in the same directory.
- Files Created:
{prefix}-compiled.executable: Serialized function and metadata
Example
>>> jitted = jax.jit(my_function)
>>> lowered = jitted.lower(sample_input)
>>> compiled = lowered.compile()
>>>
>>> from pathlib import Path
>>> cache_dir = Path("./my_cache")
>>> save_compiled_fn(cache_dir, compiled, prefix="model_v1")
- Raises
Warning – If serialization fails (logged, not raised).
Notes
Compiled functions are hardware-specific
Large models may produce large cache files
Uses pickle for serialization (standard security caveats apply)
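A save and load round trip with pickle, using the `{prefix}-compiled.executable` filename pattern from Files Created, can be sketched as below. The helpers `save_payload` and `load_payload` are hypothetical stand-ins that store an arbitrary picklable object rather than a compiled function, and the filename used when no prefix is given is an assumption.

```python
import pickle
from pathlib import Path


def _file_for(path, prefix=None):
    # Pattern taken from "Files Created"; the no-prefix name is an assumption.
    name = f"{prefix}-compiled.executable" if prefix else "compiled.executable"
    return Path(path) / name


def save_payload(path, payload, prefix=None):
    """Pickle `payload` into `path`, creating the directory if needed."""
    target = _file_for(path, prefix)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "wb") as fh:
        pickle.dump(payload, fh)
    return target


def load_payload(path, prefix=None):
    """Unpickle a saved payload; raises FileNotFoundError if it is absent."""
    with open(_file_for(path, prefix), "rb") as fh:
        return pickle.load(fh)
```

Because unpickling can execute arbitrary code, files like these should only be loaded from directories the caller controls.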
- ejkernel.callib._ejit.smart_compile(lowered_func: Lowered, tag: str | None = None, verbose: bool = True, cache_key: tuple[str, tuple] | None = None) -> tuple[Compiled, tuple[str, tuple] | None]
Compile a lowered JAX function with intelligent caching.
Attempts to load a previously compiled version from disk cache, falling back to fresh compilation if not found. Automatically caches newly compiled functions for future use.
- Parameters
lowered_func – JAX lowered function to compile.
tag – Optional tag to include in the cache filename for organization.
verbose – If True, print warnings about cache operations.
cache_key – Optional custom cache key for the function signature.
- Returns
Tuple of (compiled_function, cache_key) where cache_key may be updated if loaded from disk.
Note
Uses SHA-256 hash of the lowered function text for cache keys, combined with optional tag for namespacing.