ejkernel.ops.utils.meta

Metadata extraction and label processing for compiled JAX programs.

This module provides utilities for working with compilation metadata and labels embedded in JAX compiled programs. It enables extraction of operation identifiers and configuration mappings from compiled HLO code.

Key Functions:

  • label: Generate standardized labels for operations

  • extract_labels_from_hlo_text: Find all ejkernel labels in HLO text

  • find_labels_in_lowered: Extract labels from lowered JAX computations

  • labels_to_configs: Map found labels back to their configurations

Label Format:

Labels follow the pattern ‘ejkernel_ops#operation@version:hash’.

Example: ‘ejkernel_ops#matmul@v1:1a2b3c4d5e6f7g8h’
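To make the label layout concrete, the following is a minimal sketch that splits a label string into its four parts. The `ParsedLabel` type and `parse_label` helper are hypothetical names introduced here for illustration; they are not part of ejkernel.

```python
from typing import NamedTuple


class ParsedLabel(NamedTuple):
    """Hypothetical container for the parts of a label (not an ejkernel type)."""
    prefix: str
    operation: str
    version: str
    call_hash: str


def parse_label(text: str) -> ParsedLabel:
    """Split 'prefix#operation@version:hash' into its components."""
    prefix, sep, rest = text.partition("#")
    if not sep:
        raise ValueError(f"missing '#' in label: {text!r}")
    # The hash follows the last ':' in the label.
    op_id, sep, call_hash = rest.rpartition(":")
    if not sep:
        raise ValueError(f"missing ':' in label: {text!r}")
    operation, sep, version = op_id.partition("@")
    if not sep:
        raise ValueError(f"missing '@' in label: {text!r}")
    return ParsedLabel(prefix, operation, version, call_hash)


parsed = parse_label("ejkernel_ops#matmul@v1:1a2b3c4d5e6f7g8h")
print(parsed.operation, parsed.version, parsed.call_hash)
```

The sketch only checks for the three separators; it does not validate the length or character set of the hash.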

These utilities enable:
  • Tracking which operations were compiled with which configurations

  • Post-compilation analysis of optimization choices

  • Debugging and profiling of specific operation instances

  • Configuration recovery from compiled programs

Example Usage:
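>>> # Generate a label for an operation
>>> op_label = label('matmul@v1', '1a2b3c4d5e6f7g8h')
>>> print(op_label)
ejkernel_ops#matmul@v1:1a2b3c4d5e6f7g8h
>>> # Find the labels embedded in a lowered computation
>>> lowered = jax.jit(my_function).lower(args)
>>> labels = find_labels_in_lowered(lowered)
>>> # Map the labels in the lowered computation to their configurations
>>> configs = labels_to_configs(lowered, selector)

ejkernel.ops.utils.meta.extract_labels_from_hlo_text(hlo_text: str) -> list[str]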

Find all ejkernel operation labels in HLO text.

Searches through HLO (High Level Operations) text to find all embedded ejkernel operation labels using regex pattern matching.

Parameters

hlo_text – String containing HLO representation of compiled code

Returns

List of found label strings

Note

The regex pattern matches the standard ejkernel label format: ‘ejkernel_ops#’ + operation identifier (operation@version) + ‘:’ + 16-character hex hash
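The extraction described above might look roughly like the sketch below. The regular expression, the `extract_labels` name, and the sample HLO text (including how a label ends up inside `op_name` metadata) are all assumptions made for illustration; the module's actual pattern may differ. The sample hashes use only hex digits so that they match the assumed pattern.

```python
import re

# Assumed pattern: literal prefix, an operation identifier, ':', 16 hex characters.
LABEL_RE = re.compile(r"ejkernel_ops#[A-Za-z0-9_.@-]+:[0-9a-f]{16}")


def extract_labels(hlo_text: str) -> list[str]:
    """Return every substring of hlo_text that matches the assumed label pattern."""
    return LABEL_RE.findall(hlo_text)


# Hypothetical HLO-like text; real HLO output will look different in detail.
sample_hlo = """
HloModule jit_f
ENTRY main {
  %dot = f32[4,4] dot(%a, %b), metadata={op_name="jit(f)/ejkernel_ops#matmul@v1:1a2b3c4d5e6f7a8b/dot_general"}
  %exp = f32[4,4] exponential(%dot), metadata={op_name="jit(f)/ejkernel_ops#softmax@v2:0123456789abcdef/exp"}
}
"""

found = extract_labels(sample_hlo)
print(found)
```

Labels are returned in the order they appear, and repeated labels are kept, which matches the "find all" wording above.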

ejkernel.ops.utils.meta.find_labels_in_lowered(lowered) -> list[str]

Extract operation labels from a JAX lowered computation.

Converts the lowered computation to HLO text representation and extracts all embedded ejkernel operation labels.

Parameters

lowered – JAX lowered computation object

Returns

List of operation labels found in the compiled code

Note

The function first attempts to obtain the HLO representation and falls back to the plain string representation of the lowered object if HLO extraction fails.
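The fallback behavior in the note can be sketched as follows. This is an illustration under assumptions, not the module's code: `lowered_to_text` is a hypothetical helper, and the two stand-in classes replace a real JAX lowered object so the example runs without JAX. Real `jax.stages.Lowered` objects do provide an `as_text()` method, but which dialect the module requests is not stated in the source.

```python
def lowered_to_text(lowered) -> str:
    """Return a text form of `lowered`, preferring its as_text() method."""
    try:
        return lowered.as_text()
    except Exception:
        # Fall back to the generic string form if as_text is missing or fails.
        return str(lowered)


class FakeLowered:
    """Stand-in for a lowered computation whose as_text() succeeds."""

    def as_text(self) -> str:
        return 'metadata={op_name="ejkernel_ops#matmul@v1:1a2b3c4d5e6f7a8b"}'


class BrokenLowered:
    """Stand-in whose as_text() fails, forcing the str() fallback."""

    def as_text(self) -> str:
        raise RuntimeError("HLO extraction failed")

    def __str__(self) -> str:
        return "fallback text with ejkernel_ops#matmul@v1:1a2b3c4d5e6f7a8b"


primary = lowered_to_text(FakeLowered())
fallback = lowered_to_text(BrokenLowered())
print(primary)
print(fallback)
```

Either text can then be passed to a label search; the label is found in both cases because it survives in the string form.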

ejkernel.ops.utils.meta.label(op_id: str, call_hash: str, prefix: str | None = None) -> str

Generate a standardized label for an operation.

Creates a label string that uniquely identifies an operation instance for embedding in compiled code and later retrieval.

Parameters
  • op_id – Operation identifier with version (e.g., ‘matmul@v1’)

  • call_hash – 16-character hash of the call signature

Returns

Formatted label string following ejkernel convention
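A minimal sketch of how such a label could be assembled is shown below. Both helpers are hypothetical: `make_label` simply joins the parts in the documented format and ignores the `prefix` parameter, whose behavior is not described in the source, and `hypothetical_call_hash` is only one possible way to derive a 16-character hash from a call signature (a truncated SHA-256 digest), not necessarily what ejkernel does.

```python
import hashlib


def make_label(op_id: str, call_hash: str) -> str:
    """Join the parts as 'ejkernel_ops#<op_id>:<call_hash>'."""
    return f"ejkernel_ops#{op_id}:{call_hash}"


def hypothetical_call_hash(*signature_parts: object) -> str:
    """Assumed scheme: first 16 hex characters of a SHA-256 over repr() of the parts."""
    digest = hashlib.sha256(repr(signature_parts).encode("utf-8")).hexdigest()
    return digest[:16]


call_hash = hypothetical_call_hash("matmul", (128, 256), "float32")
op_label = make_label("matmul@v1", call_hash)
print(op_label)
```

Because the hash is derived deterministically from the signature parts, the same call signature always yields the same label.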

Examples

>>> label('matmul@v1', '1a2b3c4d5e6f7g8h')
'ejkernel_ops#matmul@v1:1a2b3c4d5e6f7g8h'

ejkernel.ops.utils.meta.labels_to_configs(lowered, selector)

Extract labels from lowered computation and map them to configurations.

Finds all ejkernel operation labels in the compiled code and retrieves their corresponding configurations from the cache system.

Parameters
  • lowered – JAX lowered computation object

  • selector – ConfigSelectorChain for cache lookups

Returns

List of (label, config) tuples for all found operations

Note

Configurations are looked up first in memory cache, then in persistent cache if available. Operations without cached configurations will have None as their config value.
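The lookup order in the note can be illustrated with the sketch below. It is a simplification under stated assumptions: plain dictionaries keyed by the full label string stand in for the memory and persistent caches, and `lookup_configs` is a hypothetical helper. The real function takes a ConfigSelectorChain, whose interface is not described here.

```python
from typing import Any, Optional


def lookup_configs(
    labels: list[str],
    memory_cache: dict[str, Any],
    persistent_cache: Optional[dict[str, Any]] = None,
) -> list[tuple[str, Any]]:
    """Pair each label with a config: memory cache first, then persistent, else None."""
    results = []
    for lbl in labels:
        config = memory_cache.get(lbl)
        if config is None and persistent_cache is not None:
            config = persistent_cache.get(lbl)
        results.append((lbl, config))
    return results


# Hypothetical cache contents for illustration only.
memory = {"ejkernel_ops#matmul@v1:1a2b3c4d5e6f7a8b": {"block_m": 128}}
persistent = {"ejkernel_ops#softmax@v2:0123456789abcdef": {"block_n": 64}}
labels = [
    "ejkernel_ops#matmul@v1:1a2b3c4d5e6f7a8b",
    "ejkernel_ops#softmax@v2:0123456789abcdef",
    "ejkernel_ops#conv@v1:fedcba9876543210",
]

configs = lookup_configs(labels, memory, persistent)
for lbl, cfg in configs:
    print(lbl, cfg)
```

The third label has no entry in either cache, so it is paired with None, as the note describes.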