ejkernel.ops.execution.profiler

JAX profiler for performance analysis and autotuning.

This module provides a profiler for JAX operations that captures execution traces, parses the resulting profile data, and reports per-event timings. It builds on JAX’s built-in profiling infrastructure and supports autotuning and hyperparameter optimization workflows.

Key Features:
  • Profile capture using JAX’s native profiling system

  • Nested event time accounting with interval merging

  • Regex-based event filtering for focused analysis

  • Device-specific profiling across GPU/TPU/CPU platforms

  • Statistical aggregation with outlier removal
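The "nested event time accounting with interval merging" feature can be illustrated with a minimal sketch. It assumes events are reduced to (start, end) pairs and that overlapping or nested intervals should be counted once; the function name `merged_duration` is hypothetical and is not part of ejkernel.

```python
def merged_duration(intervals):
    """Return the total time covered by a set of (start, end) intervals.

    Overlapping or nested intervals are merged first, so time shared by a
    parent event and its children is counted only once.
    """
    total = 0.0
    current_start = current_end = None
    for start, end in sorted(intervals):
        if current_end is None or start > current_end:
            # Disjoint from the running interval: close it and start a new one.
            if current_end is not None:
                total += current_end - current_start
            current_start, current_end = start, end
        else:
            # Overlapping or nested: extend the running interval if needed.
            current_end = max(current_end, end)
    if current_end is not None:
        total += current_end - current_start
    return total


# A parent event (0-100) with a nested child (10-40), plus a separate event (150-200).
events = [(0, 100), (10, 40), (150, 200)]
covered = merged_duration(events)
```

Sorting first keeps the merge to a single pass; the actual profiler may account for children differently, for example by subtracting child time from the parent.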

Classes:

  • Profiler: Main profiler class for capturing and analyzing JAX traces

  • ProfilingError: Exception raised when profiling operations fail

The profiler handles:
  • GPU and TPU timing formats (nanoseconds vs picoseconds)

  • Function identification by ID patterns for autotuning

  • Child event detection for accurate nested timing

  • Graceful fallback when TensorFlow hooks are unavailable
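As a rough illustration of the timing-format point, the sketch below converts raw durations to seconds. The platform-to-unit mapping (GPU in nanoseconds, TPU in picoseconds, everything else treated as nanoseconds) is an assumption read off the ordering in the list above, and `duration_to_seconds` is a hypothetical helper, not an ejkernel function.

```python
# Assumed mapping from platform to the unit of raw event durations.
_UNIT_TO_SECONDS = {
    "gpu": 1e-9,   # assumed: nanoseconds
    "tpu": 1e-12,  # assumed: picoseconds
}


def duration_to_seconds(raw_duration, platform):
    """Convert a raw event duration to seconds for the given platform."""
    # Unknown platforms fall back to nanoseconds (an assumption of this sketch).
    scale = _UNIT_TO_SECONDS.get(platform.lower(), 1e-9)
    return raw_duration * scale


gpu_seconds = duration_to_seconds(1_500, "gpu")
tpu_seconds = duration_to_seconds(1_500_000, "TPU")
```

Normalizing to seconds at the point of extraction keeps every downstream comparison platform-independent.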

Example

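>>> from ejkernel.ops.execution.profiler import Profiler
>>> profiler = Profiler(prefix_filter='jit_', min_duration_ns=1000)
>>> # closure: a zero-argument callable that calls the functions to time
>>> timings = profiler.profile_time_by_function_id(
...     closure, platform='gpu', total_calls_number=5
... )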
class ejkernel.ops.execution.profiler.Profiler(*, prefix_filter: str = 'jit_', event_filter_regex: str | None = None, min_duration_ns: float = 1000.0, max_events_per_profile: int | None = 10000, verbose: bool = False, require_tf: bool = False, silence_tf_cpp_logs: bool = True)

Bases: object

JAX profile capture and parsing with nested-event accounting.

A comprehensive profiler for JAX operations that captures execution traces, parses profile data, and provides detailed timing analysis with support for nested event filtering and device-specific profiling.

This profiler is designed to work with JAX’s built-in profiling infrastructure and provides advanced features like regex-based event filtering, minimum duration thresholds, nested event time accounting, and graceful handling of missing TensorFlow profiler hooks.

The profiler detects TensorFlow Python profiler hook availability and gracefully skips profiling (allowing fallback to Python timing) if the hooks are not available. This makes it robust for deployment in environments where TensorFlow may not be fully configured.
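The fallback pattern described here might look like the following from the caller's side. This is a sketch under stated assumptions: `ProfilingError` is redefined locally so the block runs on its own, and `time_with_fallback`, `profile_fn`, and `unavailable_profiler` are hypothetical names used only for illustration.

```python
import time


class ProfilingError(Exception):
    """Local stand-in for ejkernel's ProfilingError so the sketch runs alone."""


def time_with_fallback(fn, profile_fn, repeats=3):
    """Try profiler-based timing; fall back to wall-clock timing on failure."""
    try:
        return "profiler", profile_fn(fn)
    except ProfilingError:
        # Fallback: time the call directly and keep the best of several runs.
        samples = []
        for _ in range(repeats):
            start = time.perf_counter()
            fn()
            samples.append(time.perf_counter() - start)
        return "python", min(samples)


def unavailable_profiler(fn):
    # Simulates an environment where the profiler hooks are missing.
    raise ProfilingError("profiler hooks not available")


source, seconds = time_with_fallback(lambda: sum(range(1000)), unavailable_profiler)
```

Because JAX dispatches work asynchronously, a real wall-clock fallback would need to call `block_until_ready()` on the results inside `fn` before stopping the timer.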

prefix_filter

String prefix to filter events by name

event_filter_regex

Optional regex pattern for event filtering

min_duration_ns

Minimum event duration in nanoseconds to include

max_events_per_profile

Maximum number of events to process per profile

verbose

Enable verbose logging output

require_tf

Whether to require TensorFlow profiler hooks
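How the filtering attributes could interact is sketched below. The defaults mirror the constructor signature, but the combination rules are assumptions: the filters are ANDed together, the regex is applied with `re.search`, and the event cap is applied after filtering. `filter_events` is a hypothetical helper, not the profiler's own method.

```python
import re
from itertools import islice


def filter_events(events, *, prefix_filter="jit_", event_filter_regex=None,
                  min_duration_ns=1000.0, max_events_per_profile=10000):
    """Filter (name, duration_ns) pairs by prefix, regex, and minimum duration."""
    pattern = re.compile(event_filter_regex) if event_filter_regex else None

    def keep(name, duration_ns):
        if not name.startswith(prefix_filter):
            return False
        if pattern is not None and pattern.search(name) is None:
            return False
        return duration_ns >= min_duration_ns

    kept = ((name, dur) for name, dur in events if keep(name, dur))
    # islice with a stop of None applies no cap at all.
    return list(islice(kept, max_events_per_profile))


events = [
    ("jit_matmul", 5000.0),
    ("jit_tiny", 200.0),      # below the duration threshold
    ("copy_d2h", 9000.0),     # wrong prefix
    ("jit_softmax", 3000.0),
]
default_kept = filter_events(events)
matmul_only = filter_events(events, event_filter_regex=r"matmul")
```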

static find_device_plane_ids(p: Any, device_str: str) → list[int]

Find plane IDs corresponding to a specific device in profile data.

Searches through the profile’s execution planes to find those matching the specified device string. Planes represent different execution contexts (e.g., GPU, TPU, CPU) in the profiling data.

Parameters
  • p – Profile data object containing execution planes

  • device_str – Device identifier to search for (case-insensitive)

Returns

List of plane indices that match the device string

Raises

ProfilingError – If no planes found for device or invalid profile structure
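A case-insensitive plane lookup of this kind can be sketched as follows. The sketch assumes each plane exposes a `name` attribute and that matching is by substring; the plane names, `find_plane_ids`, and the local `ProfilingError` class are illustrative stand-ins rather than ejkernel's implementation.

```python
from types import SimpleNamespace


class ProfilingError(Exception):
    """Local stand-in for ejkernel's ProfilingError so the sketch runs alone."""


def find_plane_ids(planes, device_str):
    """Return indices of planes whose name contains device_str, ignoring case."""
    needle = device_str.lower()
    ids = [i for i, plane in enumerate(planes) if needle in plane.name.lower()]
    if not ids:
        raise ProfilingError(f"no planes found for device {device_str!r}")
    return ids


# Illustrative plane names; real profiles may label planes differently.
planes = [
    SimpleNamespace(name="/host:CPU"),
    SimpleNamespace(name="/device:GPU:0"),
    SimpleNamespace(name="/device:GPU:1"),
]
gpu_ids = find_plane_ids(planes, "gpu")
```

Raising instead of returning an empty list lets callers distinguish a missing device from a device that simply recorded no events.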

get_events_from_plane(p: Any, plane_idx: int) → dict[str, float]

Extract and process events from a specific execution plane.

Processes all events from the specified plane, applying filtering criteria and calculating accurate timing with nested event accounting. This is the main method for extracting performance metrics from profile data.

Parameters
  • p – Profile data object containing execution planes

  • plane_idx – Index of the specific plane to process

Returns

Dictionary mapping event names to execution times in seconds

Raises

ProfilingError – If plane index is invalid or event processing fails

static parse_profile_from_bytes(profile_bytes: bytes)

Parse JAX profile data from serialized bytes.

Converts raw profile bytes (typically from .xplane.pb files) into a structured ProfileData object that can be analyzed for performance metrics.

Parameters

profile_bytes – Raw profile data as bytes from JAX profiler output

Returns

ProfileData object containing parsed profiling information

Raises

ProfilingError – If profile data cannot be parsed or is corrupted

profile_time_by_function_id(timing_closure: Callable[[], None], platform: str, total_calls_number: int) → dict[int, tuple[float, float]]

Profile function execution times across multiple iterations with statistical analysis.

Executes the provided closure multiple times under JAX profiler tracing, extracting timing data for functions identified by ID patterns. Provides statistical aggregation with outlier removal for reliable timing measurements.

This method is designed for autotuning scenarios where multiple function variants (identified by numeric IDs) need to be compared. It requires TensorFlow profiler hooks to be available and raises ProfilingError if they are not found, so that callers can catch the error and fall back to Python-level timing.

Parameters
  • timing_closure – Function to execute and profile (should call the functions to time)

  • platform – Target platform string (e.g., ‘gpu’, ‘tpu’, ‘cpu’) for device selection

  • total_calls_number – Number of profiling iterations to perform for statistical accuracy

Returns

Dictionary mapping function IDs to (mean_time, std_time) tuples in seconds. Function IDs are extracted from event names matching the ‘jit_autotune_fn_{id}’ pattern. Statistical outliers are removed when more than two measurements are available.

Raises
  • ProfilingError – If TensorFlow profiler hooks are not available or profiling fails

  • RuntimeError – If no profile data is generated during execution
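The ID extraction and aggregation steps can be sketched in plain Python. The event-name pattern comes from the description above; the outlier rule (drop the smallest and largest sample when more than two are available) and the use of the population standard deviation are assumptions, since the actual method is not specified here. All function names are hypothetical.

```python
import re
import statistics

_FN_ID = re.compile(r"jit_autotune_fn_(\d+)")


def function_id(event_name):
    """Return the numeric ID embedded in an event name, or None if absent."""
    match = _FN_ID.search(event_name)
    return int(match.group(1)) if match else None


def summarize(samples):
    """Return (mean, std) in the same unit as the samples."""
    values = sorted(samples)
    if len(values) > 2:
        # Assumed outlier rule: trim the minimum and maximum sample.
        values = values[1:-1]
    std = statistics.pstdev(values) if len(values) > 1 else 0.0
    return statistics.fmean(values), std


def aggregate(per_iteration):
    """Group per-iteration timings by function ID and summarize each group."""
    samples = {}
    for timings in per_iteration:
        for name, seconds in timings.items():
            fid = function_id(name)
            if fid is not None:
                samples.setdefault(fid, []).append(seconds)
    return {fid: summarize(values) for fid, values in samples.items()}


# Illustrative timings in seconds, one dict per profiling iteration.
per_iteration = [
    {"jit_autotune_fn_0": 0.5, "jit_autotune_fn_1": 0.25},
    {"jit_autotune_fn_0": 0.5, "jit_autotune_fn_1": 0.75},
    {"jit_autotune_fn_0": 4.0, "jit_other": 1.0},
    {"jit_autotune_fn_0": 0.5},
]
stats = aggregate(per_iteration)
```

In this example the 4.0 s sample for function 0 is trimmed as an outlier, while function 1 has only two samples and is left untouched.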

exception ejkernel.ops.execution.profiler.ProfilingError

Bases: Exception

Exception raised when profiling operations fail.