ejkernel.ops.utils.fingerprint
Device fingerprinting and object hashing utilities for caching.
This module provides functions for generating stable, deterministic identifiers for JAX devices, array shapes, and complex Python objects. These identifiers are essential for the caching system to correctly match configurations across different runs and environments.
- Key Functions:
device_fingerprint: Generate stable device identifiers
short_hash: Create short hashes from complex objects
stable_json: Deterministic JSON serialization
abstractify: Convert arrays to abstract shape/dtype specs
sharding_fingerprint: Extract sharding information for caching
- The fingerprinting system handles:
JAX devices with platform version information
Complex nested data structures (PyTrees)
Functions and partial functions
Dataclasses and Pydantic models
JAX arrays with sharding information
NumPy arrays and dtypes
- These utilities ensure that:
Cache keys are consistent across program runs
Device-specific optimizations are properly isolated
Complex objects can be reliably serialized and hashed
Sharding information is preserved for distributed computation
- Example Usage:
>>> device_id = device_fingerprint()
>>> config_hash = short_hash(my_config)
>>> abstract_tree = abstractify(data_with_arrays)
>>> cache_key = default_key_builder_with_sharding(invocation)
- ejkernel.ops.utils.fingerprint.abstractify(pytree: Any) -> Any
Convert a PyTree containing arrays to abstract shape/dtype specifications.
Transforms a nested data structure containing JAX or NumPy arrays into an abstract representation using ShapeDtypeStruct objects. This allows for consistent hashing and comparison based on array shapes and dtypes rather than actual array values.
- Parameters
pytree – Nested data structure potentially containing arrays
- Returns
PyTree with same structure but arrays replaced by ShapeDtypeStruct
Examples
>>> import jax.numpy as jnp
>>> data = {'x': jnp.array([1, 2, 3]), 'y': 'scalar'}
>>> abstract = abstractify(data)
Note
This is essential for creating cache keys based on array structure rather than values, allowing the same optimized configuration to be reused for arrays with the same shape and dtype but different values.
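The transformation described above can be sketched with jax.tree_util.tree_map and jax.ShapeDtypeStruct. This is a minimal illustration, not ejkernel's implementation: abstractify_sketch is a hypothetical name, and the assumption that non-array leaves pass through unchanged is inferred from the "same structure" wording of the return value.

```python
import jax
import jax.numpy as jnp
import numpy as np


def abstractify_sketch(pytree):
    """Replace every JAX/NumPy array leaf with a jax.ShapeDtypeStruct.

    Hypothetical illustration only: non-array leaves (strings, ints, ...)
    are returned unchanged so the tree structure is preserved.
    """
    def to_spec(leaf):
        if isinstance(leaf, (jax.Array, np.ndarray)):
            return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype)
        return leaf

    return jax.tree_util.tree_map(to_spec, pytree)


# Different values, same shape and dtype: the abstract specs agree.
a = abstractify_sketch({"x": jnp.array([1, 2, 3]), "y": "scalar"})
b = abstractify_sketch({"x": jnp.array([7, 8, 9]), "y": "scalar"})
```

Because only shape and dtype survive, a and b would hash identically, which is what lets one tuned configuration serve many inputs.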
- ejkernel.ops.utils.fingerprint.default_key_builder_with_sharding(inv) -> str
Generate cache key that includes sharding information for device-aware caching.
Creates a cache key from argument shapes, types, sharding information, and batch axes, so that in distributed computation environments a cached configuration is reused only for invocations that agree on all of them.
- Parameters
inv – Function invocation object containing args, kwargs, and batch_axes
- Returns
Short hash string representing the complete invocation signature
Note
This key builder is more comprehensive than basic builders as it includes sharding information, making it suitable for distributed workloads where the same logical operation may have different optimal configurations depending on how data is sharded across devices.
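A rough sketch of how such a key might be assembled, assuming the invocation exposes args, kwargs, and batch_axes as described above. The Invocation dataclass and key_with_sharding_sketch are hypothetical stand-ins, and the SHA-256-truncated-to-16-hex-characters step only mirrors short_hash by assumption; ejkernel's exact serialization is not shown here.

```python
import hashlib
import json
from dataclasses import dataclass, field
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np


@dataclass
class Invocation:
    """Hypothetical stand-in for the invocation object passed as `inv`."""
    args: tuple
    kwargs: dict = field(default_factory=dict)
    batch_axes: Any = None


def _leaf_signature(leaf):
    # Arrays contribute shape, dtype and (for JAX arrays) sharding, never values.
    if isinstance(leaf, jax.Array):
        return ["array", list(leaf.shape), str(leaf.dtype), str(leaf.sharding)]
    if isinstance(leaf, np.ndarray):
        return ["ndarray", list(leaf.shape), str(leaf.dtype)]
    return ["value", repr(leaf)]


def _tree_signature(tree):
    # Include the tree structure so different nestings of equal leaves differ.
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    return {"tree": str(treedef), "leaves": [_leaf_signature(leaf) for leaf in leaves]}


def key_with_sharding_sketch(inv: Invocation) -> str:
    payload = {
        "args": _tree_signature(inv.args),
        "kwargs": _tree_signature(inv.kwargs),
        "batch_axes": repr(inv.batch_axes),
    }
    blob = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()[:16]


k_ones = key_with_sharding_sketch(Invocation(args=(jnp.ones((4, 8)),)))
k_zeros = key_with_sharding_sketch(Invocation(args=(jnp.zeros((4, 8)),)))
k_wide = key_with_sharding_sketch(Invocation(args=(jnp.ones((4, 16)),)))
k_batched = key_with_sharding_sketch(Invocation(args=(jnp.ones((4, 8)),), batch_axes=(0,)))
```

Arrays with equal shape, dtype, and sharding map to the same key regardless of their values, while a change in shape or batch axes yields a new key.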
- ejkernel.ops.utils.fingerprint.device_fingerprint(dev: jaxlib._jax.Device | None = None) -> str
Generate a stable identifier for a JAX device including platform version.
Creates a stable identifier for a JAX device that combines the device type and the platform version. Devices of the same type running the same platform version share a fingerprint, so cached configurations are tied to the exact hardware and software environment rather than to an individual device.
- Parameters
dev – JAX device to fingerprint, uses default device if None
- Returns
String identifier like ‘gpu|cuda_12.0’, ‘tpu|v4’, or ‘cpu|’
Examples
>>> import jax
>>> device_fingerprint()
'gpu|cuda_12.0'
>>> device_fingerprint(jax.devices('cpu')[0])
'cpu|'
Note
The format is ‘device_kind|platform_version’ where platform_version may be empty for some device types. This fingerprint is used as a key component in cache storage to ensure device-specific optimization.
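One way to produce a 'platform|version' string from a jax.Device is sketched below. It assumes the device's PJRT client exposes a platform_version attribute and falls back to an empty string otherwise; device_fingerprint_sketch is a hypothetical name, and the real function may normalize the version differently (the documented CPU output is 'cpu|'), so the exact strings are not guaranteed to match.

```python
import jax


def device_fingerprint_sketch(dev=None) -> str:
    """Build a 'platform|platform_version' string (illustrative only)."""
    if dev is None:
        dev = jax.devices()[0]  # first device of the default backend
    # Assumption: the device's client exposes `platform_version`; fall back to "".
    client = getattr(dev, "client", None)
    version = getattr(client, "platform_version", "") if client is not None else ""
    return f"{dev.platform}|{version}"


fp = device_fingerprint_sketch()
```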
- ejkernel.ops.utils.fingerprint.device_kind() -> str
Get the device kind (gpu, cpu, tpu) for the default device.
Returns a simple string identifier for the type of the default JAX device, without platform version information.
- Returns
Device kind string: ‘gpu’, ‘cpu’, ‘tpu’, or ‘unknown’
- Return type
str
Examples
>>> device_kind()
'gpu'
Note
This is a simplified version of device_fingerprint() that only returns the device type without platform version details.
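A hedged sketch of reducing the default device to one of the documented kinds. It reads jax.Device.platform and maps anything outside gpu/cpu/tpu to 'unknown'; device_kind_sketch is a hypothetical name, and folding 'cuda' or 'rocm' into 'gpu' is a defensive assumption rather than documented behavior.

```python
import jax

_KNOWN_KINDS = ("gpu", "cpu", "tpu")


def device_kind_sketch() -> str:
    """Return 'gpu', 'cpu', 'tpu', or 'unknown' for the default device."""
    platform = jax.devices()[0].platform.lower()
    # Assumption: some backends may report 'cuda' or 'rocm'; fold them into 'gpu'.
    if platform in ("cuda", "rocm"):
        platform = "gpu"
    return platform if platform in _KNOWN_KINDS else "unknown"


kind = device_kind_sketch()
```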
- ejkernel.ops.utils.fingerprint.get_device_platform(dev: jaxlib._jax.Device | None = None) -> str
Extract the platform identifier (gpu/tpu/cpu) from a JAX device.
- Parameters
dev – JAX device, uses default device if None
- Returns
Platform string: ‘gpu’, ‘tpu’, ‘cpu’, or ‘unknown’
- Return type
str
Examples
>>> import jax
>>> get_device_platform()
'gpu'
>>> get_device_platform(jax.devices('tpu')[0])
'tpu'
Note
This is used for platform-specific method dispatch in kernels.
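A minimal sketch of the platform-specific dispatch pattern mentioned in the note: a table keyed by platform string selects an implementation at call time. platform_of, softmax_dispatch, and the per-platform table are hypothetical; every entry reuses jax.nn.softmax so the example runs on any backend.

```python
import jax
import jax.numpy as jnp


def _softmax_reference(x):
    return jax.nn.softmax(x, axis=-1)


# Hypothetical per-platform table; a kernel library would register
# specialised variants here. All entries share one implementation
# so that this sketch runs on any backend.
_IMPLS = {
    "gpu": _softmax_reference,
    "tpu": _softmax_reference,
    "cpu": _softmax_reference,
}


def platform_of(dev=None) -> str:
    dev = dev if dev is not None else jax.devices()[0]
    return dev.platform if dev.platform in _IMPLS else "unknown"


def softmax_dispatch(x):
    # Pick the implementation for the current platform, with a portable fallback.
    impl = _IMPLS.get(platform_of(), _softmax_reference)
    return impl(x)


out = softmax_dispatch(jnp.ones((2, 4)))
```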
- ejkernel.ops.utils.fingerprint.sharding_fingerprint(x: Any) -> Any
Extract sharding information from a JAX array for fingerprinting.
Creates a stable representation of how an array is sharded across devices, which is essential for device-aware caching in distributed computation.
- Parameters
x – Object to extract sharding information from
- Returns
String representation of sharding for JAX arrays, None otherwise
Note
The sharding representation is kept stable and compact to ensure consistent cache keys across different program executions.
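A compact, device-id-free representation can be sketched as follows, under the assumption that mesh axis sizes plus the PartitionSpec are enough to distinguish layouts for caching. sharding_fingerprint_sketch is a hypothetical name and the string format is illustrative; ejkernel's actual format is not documented here.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def sharding_fingerprint_sketch(x):
    """Describe how x is sharded, or return None if x is not a JAX array."""
    if not isinstance(x, jax.Array):
        return None
    s = x.sharding
    if isinstance(s, NamedSharding):
        # Mesh axis names/sizes plus the PartitionSpec, without device ids.
        return f"named|{dict(s.mesh.shape)}|{s.spec}"
    # Other shardings (e.g. single-device) reduce to their class name.
    return type(s).__name__


mesh = Mesh(np.array(jax.devices()[:1]), ("x",))
sharded = jax.device_put(jnp.ones(4), NamedSharding(mesh, PartitionSpec("x")))
```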
- ejkernel.ops.utils.fingerprint.short_hash(obj: Any) -> str
Generate a short (16-character) hash from an object using stable JSON serialization.
Creates a compact, deterministic hash of an object by first converting it to stable JSON with stable_json and then computing a SHA-256 hash.
- Parameters
obj – Object to hash (can be arbitrarily complex)
- Returns
16-character hexadecimal hash string
Examples
>>> h = short_hash({'a': 1, 'b': [2, 3]})
>>> len(h)
16
>>> short_hash({'b': [2, 3], 'a': 1}) == h  # key order does not matter
True
Note
Uses SHA-256 internally but truncates to 16 characters for compactness. The hash is deterministic across program runs for equivalent objects.
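The hash-then-truncate scheme can be sketched with the standard library. Here json.dumps with sorted keys and compact separators stands in for stable_json, and default=str is a simplification (it can embed memory addresses for arbitrary objects) that a type-aware serializer would replace; short_hash_sketch is a hypothetical name.

```python
import hashlib
import json


def short_hash_sketch(obj) -> str:
    # Stand-in for stable_json: sorted keys and compact separators.
    blob = json.dumps(obj, sort_keys=True, separators=(",", ":"), default=str)
    # SHA-256 hex digest truncated to 16 characters.
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()[:16]


h = short_hash_sketch({"a": 1, "b": [2, 3]})
```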
- ejkernel.ops.utils.fingerprint.stable_json(obj: Any) -> str
Deterministic JSON serialization that handles JAX/NumPy types and dataclasses.
Provides stable, deterministic JSON serialization for complex Python objects including JAX arrays, NumPy types, dataclasses, Pydantic models, and functions. The serialization is designed to produce identical output for equivalent objects across different program runs.
- Parameters
obj – Object to serialize (can be arbitrarily nested)
- Returns
Deterministic JSON string representation
- Supported Types:
Functions and methods (with module, name, and position info)
functools.partial objects (with function and bound arguments)
Callable objects (with class information)
Pydantic models (using model_dump())
Dataclasses (using asdict())
JAX ShapeDtypeStruct objects
NumPy dtypes and scalar types
Standard Python types
Note
The JSON output uses sorted keys and compact separators to ensure consistent formatting. Function objects are serialized with their module, qualified name, and source location for stability.
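The supported types above map naturally onto a default= hook for json.dumps. The sketch below is an assumption about how such a hook could look, not ejkernel's code: stable_json_sketch and _encode are hypothetical names, the Pydantic check is duck-typed on model_dump to avoid importing Pydantic, functions are identified by module, qualified name, and first line number, and unsupported objects raise TypeError instead of falling back to repr (which can embed memory addresses and break determinism).

```python
import dataclasses
import functools
import inspect
import json

import jax
import numpy as np


def _encode(o):
    """`default=` hook for json.dumps: turn unsupported objects into JSON data."""
    if isinstance(o, jax.ShapeDtypeStruct):
        return {"shape": list(o.shape), "dtype": str(o.dtype)}
    if isinstance(o, type):
        # NumPy scalar types (np.float32, ...) become dtype names; other classes by path.
        if issubclass(o, np.generic):
            return np.dtype(o).name
        return {"class": f"{o.__module__}.{o.__qualname__}"}
    if isinstance(o, np.dtype):
        return str(o)
    if isinstance(o, np.generic):
        return o.item()  # NumPy scalar -> Python scalar
    if isinstance(o, functools.partial):
        return {"partial": _encode(o.func), "args": list(o.args), "kwargs": dict(o.keywords)}
    if inspect.isfunction(o) or inspect.ismethod(o) or inspect.isbuiltin(o):
        fn = getattr(o, "__func__", o)  # unwrap bound methods
        code = getattr(fn, "__code__", None)  # builtins have no __code__
        return {
            "fn": f"{fn.__module__}.{fn.__qualname__}",
            "line": code.co_firstlineno if code is not None else None,
        }
    if hasattr(o, "model_dump"):  # duck-typed Pydantic v2 model
        return o.model_dump()
    if dataclasses.is_dataclass(o):
        return dataclasses.asdict(o)
    if callable(o):
        return {"callable": f"{type(o).__module__}.{type(o).__qualname__}"}
    raise TypeError(f"stable_json_sketch cannot serialize {type(o).__name__}")


def stable_json_sketch(obj) -> str:
    return json.dumps(obj, sort_keys=True, separators=(",", ":"), default=_encode)
```

Values returned by the hook are serialized again by json.dumps, so nested NumPy scalars inside a dataclass or partial are handled by the same hook.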