ejkernel.ops.utils.meta
Metadata extraction and label processing for compiled JAX programs.
This module provides utilities for working with compilation metadata and labels embedded in compiled JAX programs. It extracts operation identifiers from compiled HLO code and maps them back to their configurations.
- Key Functions:
label: Generate standardized labels for operations
extract_labels_from_hlo_text: Find all ejkernel labels in HLO text
find_labels_in_lowered: Extract labels from lowered JAX computations
labels_to_configs: Map found labels back to their configurations
- Label Format:
Labels follow the pattern: ‘ejkernel_ops#operation@version:hash’
Example: ‘ejkernel_ops#matmul@v1:1a2b3c4d5e6f7g8h’
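As a minimal sketch of how the pattern decomposes, the hypothetical helper `parse_label` (not part of ejkernel) splits a label into prefix, operation, version, and hash with a regular expression. It assumes the hash is exactly 16 lowercase hex characters, as the Note on `extract_labels_from_hlo_text` states. By that rule the example hash `1a2b3c4d5e6f7g8h` would not match, since `g` and `h` are not hex digits, so the sketch uses a hex hash.

```python
from __future__ import annotations

import re

# Hypothetical helper (not part of ejkernel). The pattern mirrors the documented
# format '<prefix>#<operation>@<version>:<hash>' and assumes a 16-char lowercase hex hash.
_LABEL_RE = re.compile(
    r"^(?P<prefix>[^#]+)#(?P<op>[^@#:]+)@(?P<version>[^:]+):(?P<hash>[0-9a-f]{16})$"
)


def parse_label(text: str) -> dict[str, str] | None:
    """Split a full label into its parts, or return None if it does not match."""
    match = _LABEL_RE.match(text)
    return match.groupdict() if match else None


print(parse_label("ejkernel_ops#matmul@v1:1a2b3c4d5e6f7a8b"))
# {'prefix': 'ejkernel_ops', 'op': 'matmul', 'version': 'v1', 'hash': '1a2b3c4d5e6f7a8b'}
print(parse_label("ejkernel_ops#matmul@v1:1a2b3c4d5e6f7g8h"))  # None: 'g' and 'h' are not hex
```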
- These utilities enable:
Tracking which operations were compiled with which configurations
Post-compilation analysis of optimization choices
Debugging and profiling of specific operation instances
Configuration recovery from compiled programs
- Example Usage:
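>>> # Generate a label for an operation
>>> op_label = label('matmul@v1', '1a2b3c4d5e6f7g8h')
>>> print(op_label)
ejkernel_ops#matmul@v1:1a2b3c4d5e6f7g8h
>>> # Find labels in a lowered computation
>>> lowered = jax.jit(my_function).lower(args)
>>> labels = find_labels_in_lowered(lowered)
>>> # Map the labels back to their configurations
>>> configs = labels_to_configs(lowered, selector)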
- ejkernel.ops.utils.meta.extract_labels_from_hlo_text(hlo_text: str) → list[str]
Find all ejkernel operation labels in HLO text.
Scans HLO (High Level Operations) text with a regular expression and returns every embedded ejkernel operation label.
- Parameters
hlo_text – String containing the HLO representation of compiled code
- Returns
List of found label strings
Note
The regex pattern matches the standard ejkernel label format: ‘ejkernel_ops#’ + operation_name + ‘:’ + 16-char hex hash
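A minimal sketch of the extraction described above, assuming the pattern is ‘ejkernel_ops#’, an operation name without whitespace, double quotes, or colons, then ‘:’ and 16 lowercase hex characters. `extract_labels_sketch` is a hypothetical stand-in, not the library's actual regex. The HLO fragment is made up, and placing labels inside `op_name` metadata is an assumption for illustration.

```python
import re

# Assumed pattern: 'ejkernel_ops#', an operation name containing no whitespace,
# double quotes or ':', then ':' and exactly 16 lowercase hex characters.
_HLO_LABEL_RE = re.compile(r'ejkernel_ops#[^\s":]+:[0-9a-f]{16}')


def extract_labels_sketch(hlo_text: str) -> list[str]:
    """Return every label found in hlo_text, in order of appearance."""
    return _HLO_LABEL_RE.findall(hlo_text)


# Made-up HLO fragment; embedding labels in op_name metadata is an assumption.
hlo = '''
%dot.1 = f32[8,8] dot(%a, %b), metadata={op_name="jit(f)/ejkernel_ops#matmul@v1:0123456789abcdef/dot_general"}
%add.2 = f32[8,8] add(%dot.1, %c), metadata={op_name="jit(f)/ejkernel_ops#softmax@v2:fedcba9876543210/add"}
'''
print(extract_labels_sketch(hlo))
# ['ejkernel_ops#matmul@v1:0123456789abcdef', 'ejkernel_ops#softmax@v2:fedcba9876543210']
```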
- ejkernel.ops.utils.meta.find_labels_in_lowered(lowered) → list[str]
Extract operation labels from a JAX lowered computation.
Converts the lowered computation to its HLO text representation and extracts all embedded ejkernel operation labels.
- Parameters
lowered – JAX lowered computation object
- Returns
List of operation labels found in the compiled code
Note
First attempts to obtain the HLO text representation; if that fails, falls back to the string representation of the lowered object.
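The fallback in the Note can be sketched as follows. `find_labels_sketch` is hypothetical: it assumes the lowered object exposes `as_text()` (as `jax.stages.Lowered` does), which may not be how the real function obtains the HLO, and it falls back to `str()` on any exception. The two fake classes stand in for a real JAX lowering so the sketch runs without JAX installed.

```python
import re

_LABEL_RE = re.compile(r'ejkernel_ops#[^\s":]+:[0-9a-f]{16}')


def find_labels_sketch(lowered) -> list[str]:
    """Hypothetical: extract labels from a lowered object, falling back to str()."""
    try:
        # jax.stages.Lowered exposes as_text(); using it here is an assumption
        # about how the real function obtains the HLO text.
        text = lowered.as_text()
    except Exception:
        # Fallback from the Note: search the object's plain string representation.
        text = str(lowered)
    return _LABEL_RE.findall(text)


class FakeLowered:
    """Stand-in for a lowering whose text extraction succeeds."""

    def as_text(self) -> str:
        return 'metadata={op_name="ejkernel_ops#matmul@v1:0123456789abcdef"}'


class BrokenLowered:
    """Stand-in whose text extraction fails, forcing the fallback."""

    def as_text(self) -> str:
        raise RuntimeError("HLO text unavailable")

    def __str__(self) -> str:
        return "ejkernel_ops#norm@v3:deadbeefdeadbeef"


print(find_labels_sketch(FakeLowered()))    # ['ejkernel_ops#matmul@v1:0123456789abcdef']
print(find_labels_sketch(BrokenLowered()))  # ['ejkernel_ops#norm@v3:deadbeefdeadbeef']
```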
- ejkernel.ops.utils.meta.label(op_id: str, call_hash: str, prefix: str | None = None) → str
Generate a standardized label for an operation.
Creates a label string that uniquely identifies an operation instance for embedding in compiled code and later retrieval.
- Parameters
op_id – Operation identifier with version (e.g., ‘matmul@v1’)
call_hash – 16-character hash of the call signature
- Returns
Formatted label string following ejkernel convention
Examples
>>> label('matmul@v1', '1a2b3c4d5e6f7g8h')
'ejkernel_ops#matmul@v1:1a2b3c4d5e6f7g8h'
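A minimal sketch of the formatting rule, assuming that `prefix`, when given, replaces the default `ejkernel_ops` prefix. The parameter is not documented above, so this behavior is a guess. `label_sketch` is hypothetical and does not validate the hash length or characters.

```python
from __future__ import annotations

# Assumed default, taken from the documented label format.
DEFAULT_PREFIX = "ejkernel_ops"


def label_sketch(op_id: str, call_hash: str, prefix: str | None = None) -> str:
    """Hypothetical: build '<prefix>#<op_id>:<call_hash>' without validating inputs."""
    head = DEFAULT_PREFIX if prefix is None else prefix
    return f"{head}#{op_id}:{call_hash}"


print(label_sketch("matmul@v1", "1a2b3c4d5e6f7g8h"))
# ejkernel_ops#matmul@v1:1a2b3c4d5e6f7g8h
print(label_sketch("attention@v2", "0123456789abcdef", prefix="my_ops"))
# my_ops#attention@v2:0123456789abcdef
```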
- ejkernel.ops.utils.meta.labels_to_configs(lowered, selector)
Extract labels from lowered computation and map them to configurations.
Finds all ejkernel operation labels in the compiled code and retrieves their corresponding configurations from the cache system.
- Parameters
lowered – JAX lowered computation object
selector – ConfigSelectorChain for cache lookups
- Returns
List of (label, config) tuples for all found operations
Note
Configurations are looked up first in the in-memory cache, then in the persistent cache if one is available. Operations without a cached configuration are paired with None.
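The two-tier lookup order can be sketched with plain dictionaries. This is a hypothetical simplification: `labels_to_configs_sketch` takes an already extracted label list and two mappings instead of a lowered computation and a ConfigSelectorChain, and keying the caches by the full label string is an assumption. The cache contents and config fields are made up.

```python
from __future__ import annotations

from typing import Any, Mapping


def labels_to_configs_sketch(
    labels: list[str],
    memory_cache: Mapping[str, Any],
    persistent_cache: Mapping[str, Any] | None = None,
) -> list[tuple[str, Any]]:
    """Hypothetical: pair each label with its config, checking memory before persistent storage."""
    results: list[tuple[str, Any]] = []
    for lbl in labels:
        config = memory_cache.get(lbl)  # first tier: in-memory cache
        if config is None and persistent_cache is not None:
            config = persistent_cache.get(lbl)  # second tier: persistent cache
        results.append((lbl, config))  # None when neither tier has an entry
    return results


# Made-up cache contents and config fields, for illustration only.
memory = {"ejkernel_ops#matmul@v1:0123456789abcdef": {"block_m": 128}}
disk = {"ejkernel_ops#softmax@v2:fedcba9876543210": {"block_n": 64}}
found = [
    "ejkernel_ops#matmul@v1:0123456789abcdef",
    "ejkernel_ops#softmax@v2:fedcba9876543210",
    "ejkernel_ops#norm@v3:deadbeefdeadbeef",
]
for lbl, cfg in labels_to_configs_sketch(found, memory, disk):
    print(lbl, cfg)
```

Checking the faster in-memory tier first means the persistent store is only consulted on a miss.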