ejkernel.ops.execution.profiler
JAX profiler for performance analysis and autotuning.
This module provides a profiler for JAX operations that captures execution traces, parses profile data, and reports detailed timing analysis. It is designed to work with JAX’s built-in profiling infrastructure and supports features aimed at autotuning and hyperparameter optimization.
- Key Features:
Profile capture using JAX’s native profiling system
Nested event time accounting with interval merging
Regex-based event filtering for focused analysis
Device-specific profiling across GPU/TPU/CPU platforms
Statistical aggregation with outlier removal
- Classes:
Profiler: Main profiler class for capturing and analyzing JAX traces
ProfilingError: Exception raised when profiling operations fail
- The profiler handles:
GPU and TPU timing formats (nanoseconds vs picoseconds)
Function identification by ID patterns for autotuning
Child event detection for accurate nested timing
Graceful fallback when TensorFlow hooks are unavailable
Example
>>> profiler = Profiler(prefix_filter='jit_', min_duration_ns=1000)
>>> timings = profiler.profile_time_by_function_id(
... closure, platform='gpu', total_calls_number=5
... )
- class ejkernel.ops.execution.profiler.Profiler(*, prefix_filter: str = 'jit_', event_filter_regex: str | None = None, min_duration_ns: float = 1000.0, max_events_per_profile: int | None = 10000, verbose: bool = False, require_tf: bool = False, silence_tf_cpp_logs: bool = True)
Bases: object
JAX profile capture and parsing with nested-event accounting.
A comprehensive profiler for JAX operations that captures execution traces, parses profile data, and provides detailed timing analysis with support for nested event filtering and device-specific profiling.
This profiler is designed to work with JAX’s built-in profiling infrastructure and provides advanced features like regex-based event filtering, minimum duration thresholds, nested event time accounting, and graceful handling of missing TensorFlow profiler hooks.
The profiler detects TensorFlow Python profiler hook availability and gracefully skips profiling (allowing fallback to Python timing) if the hooks are not available. This makes it robust for deployment in environments where TensorFlow may not be fully configured.
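The fallback described above can be sketched as a try/except around the trace-based measurement. This is a minimal illustration, not ejkernel's implementation: `ProfilingError` here is a local stand-in class, and `profile_with_trace`, `wall_clock_seconds`, and `measure` are hypothetical helper names.

```python
import time


class ProfilingError(RuntimeError):
    """Local stand-in for ejkernel's ProfilingError so the sketch runs alone."""


def profile_with_trace(fn):
    # Hypothetical trace-based measurement. It always fails here to imitate
    # an environment where the profiler hooks are missing.
    raise ProfilingError("profiler hooks unavailable")


def wall_clock_seconds(fn, repeats=5):
    # Plain Python timing: best of `repeats` runs. For real JAX code, `fn`
    # should block on its outputs (e.g. via block_until_ready()), because
    # JAX dispatches work asynchronously.
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    return best


def measure(fn):
    # Prefer the trace-based path; fall back to wall-clock timing on failure.
    try:
        return "trace", profile_with_trace(fn)
    except ProfilingError:
        return "wall_clock", wall_clock_seconds(fn)


source, seconds = measure(lambda: sum(range(10_000)))
```

Returning the measurement source alongside the time lets a caller tell trace-derived numbers apart from coarser wall-clock numbers.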
- prefix_filter
String prefix to filter events by name
- event_filter_regex
Optional regex pattern for event filtering
- min_duration_ns
Minimum event duration in nanoseconds to include
- max_events_per_profile
Maximum number of events to process per profile
- verbose
Enable verbose logging output
- require_tf
Whether to require TensorFlow profiler hooks
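To make the filtering attributes concrete, here is a minimal sketch of how a prefix, an optional regex, and a minimum duration might be combined. The helper `keep_event` is hypothetical. Requiring all three conditions to hold, and using `re.search` rather than a full match, are assumptions; the source does not state how the filters are combined.

```python
import re
from typing import Optional


def keep_event(
    name: str,
    duration_ns: float,
    *,
    prefix_filter: str = "jit_",
    event_filter_regex: Optional[str] = None,
    min_duration_ns: float = 1000.0,
) -> bool:
    """Return True if the event passes every filter (assumed AND semantics)."""
    if not name.startswith(prefix_filter):
        return False
    if event_filter_regex is not None and re.search(event_filter_regex, name) is None:
        return False
    return duration_ns >= min_duration_ns


# Made-up (name, duration in ns) pairs for illustration.
events = [
    ("jit_matmul", 5000.0),
    ("jit_tiny", 200.0),      # dropped: shorter than min_duration_ns
    ("copy_h2d", 9000.0),     # dropped: wrong prefix
    ("jit_softmax", 3000.0),
]
kept = [name for name, dur in events if keep_event(name, dur, event_filter_regex=r"matmul|softmax")]
```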
- static find_device_plane_ids(p: Any, device_str: str) -> list[int]
Find plane IDs corresponding to a specific device in profile data.
Searches through the profile’s execution planes to find those matching the specified device string. Planes represent different execution contexts (e.g., GPU, TPU, CPU) in the profiling data.
- Parameters
p – Profile data object containing execution planes
device_str – Device identifier to search for (case-insensitive)
- Returns
List of plane indices that match the device string
- Raises
ProfilingError – If no planes found for device or invalid profile structure
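A case-insensitive plane lookup of this kind could look like the sketch below. `FakePlane` and `find_plane_ids` are hypothetical names, substring matching on the plane name is an assumption, and `LookupError` stands in for `ProfilingError` so the example runs on its own.

```python
from dataclasses import dataclass


@dataclass
class FakePlane:
    """Hypothetical stand-in for a profile plane; only the name is modelled."""
    name: str


def find_plane_ids(planes, device_str):
    # Compare lowercased strings so 'gpu' matches '/device:GPU:0'.
    needle = device_str.lower()
    ids = [i for i, plane in enumerate(planes) if needle in plane.name.lower()]
    if not ids:
        # The real method is documented to raise ProfilingError here.
        raise LookupError(f"no planes found for device {device_str!r}")
    return ids


planes = [FakePlane("/host:CPU"), FakePlane("/device:GPU:0"), FakePlane("/device:GPU:1")]
gpu_ids = find_plane_ids(planes, "gpu")
```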
- get_events_from_plane(p: Any, plane_idx: int) -> dict[str, float]
Extract and process events from a specific execution plane.
Processes all events from the specified plane, applying filtering criteria and calculating accurate timing with nested event accounting. This is the main method for extracting performance metrics from profile data.
- Parameters
p – Profile data object containing execution planes
plane_idx – Index of the specific plane to process
- Returns
Dictionary mapping event names to execution times in seconds
- Raises
ProfilingError – If plane index is invalid or event processing fails
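Nested event accounting with interval merging can be illustrated as follows. This is a sketch of the general technique, not the library's code: overlapping child intervals are merged so that shared time is counted once, and the merged total is subtracted from the parent's duration. Whether the profiler reports this "self time" or some other quantity is an assumption, and `merge_intervals` and `self_time` are hypothetical helpers.

```python
def merge_intervals(intervals):
    """Merge overlapping (start, end) intervals into disjoint ones."""
    merged = []
    for start, end in sorted(intervals):
        if merged and start <= merged[-1][1]:
            # Overlaps (or touches) the previous interval: extend it.
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged


def self_time(parent, children):
    """Parent duration minus the time covered by its (merged) children."""
    p_start, p_end = parent
    # Keep only children that intersect the parent, clipped to its bounds.
    clipped = [
        (max(s, p_start), min(e, p_end))
        for s, e in children
        if e > p_start and s < p_end
    ]
    covered = sum(e - s for s, e in merge_intervals(clipped))
    return (p_end - p_start) - covered


# Times are in any consistent unit. Children (10, 30) and (20, 50) overlap,
# so together they cover 40 units rather than 50.
merged_children = merge_intervals([(10, 30), (20, 50), (70, 80)])
parent_self = self_time((0, 100), [(10, 30), (20, 50), (70, 80)])
```

Without the merge step, the overlap between the first two children would be subtracted twice and the parent's own time would be underestimated.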
- static parse_profile_from_bytes(profile_bytes: bytes)
Parse JAX profile data from serialized bytes.
Converts raw profile bytes (typically from .xplane.pb files) into a structured ProfileData object that can be analyzed for performance metrics.
- Parameters
profile_bytes – Raw profile data as bytes from JAX profiler output
- Returns
ProfileData object containing parsed profiling information
- Raises
ProfilingError – If profile data cannot be parsed or is corrupted
- profile_time_by_function_id(timing_closure: Callable[[], None], platform: str, total_calls_number: int) -> dict[int, tuple[float, float]]
Profile function execution times across multiple iterations with statistical analysis.
Executes the provided closure multiple times under JAX profiler tracing, extracting timing data for functions identified by ID patterns. Provides statistical aggregation with outlier removal for reliable timing measurements.
This method is designed for autotuning scenarios in which multiple function variants (identified by numeric IDs) need to be compared. It requires TensorFlow profiler hooks and raises ProfilingError when they are not found, so that callers can catch the exception and fall back to Python-level timing.
- Parameters
timing_closure – Function to execute and profile (should call the functions to time)
platform – Target platform string (e.g., ‘gpu’, ‘tpu’, ‘cpu’) for device selection
total_calls_number – Number of profiling iterations to perform for statistical accuracy
- Returns
Dictionary mapping function IDs to (mean_time, std_time) tuples in seconds. Function IDs are extracted from event names matching ‘jit_autotune_fn_{id}’ pattern. Statistical outliers are removed if more than 2 measurements are available.
- Raises
ProfilingError – If TensorFlow profiler hooks are not available or profiling fails
RuntimeError – If no profile data is generated during execution
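The aggregation step can be sketched as below: function IDs are pulled from event names matching the `jit_autotune_fn_{id}` pattern, per-run timings are grouped by ID, and a mean and standard deviation are computed. The outlier rule used here (drop the smallest and largest sample when more than 2 are available) and the use of the population standard deviation are assumptions; the source states only that outliers are removed above that threshold. `aggregate` is a hypothetical helper.

```python
import re
import statistics
from collections import defaultdict

FN_ID_PATTERN = re.compile(r"jit_autotune_fn_(\d+)")


def aggregate(samples_per_run):
    """Map function ID -> (mean, std) in seconds from per-run event timings."""
    by_id = defaultdict(list)
    for run in samples_per_run:
        for name, seconds in run.items():
            match = FN_ID_PATTERN.search(name)
            if match:  # events without an ID pattern are ignored
                by_id[int(match.group(1))].append(seconds)

    result = {}
    for fn_id, times in by_id.items():
        if len(times) > 2:
            # Assumed outlier rule: trim the minimum and maximum sample.
            times = sorted(times)[1:-1]
        result[fn_id] = (statistics.fmean(times), statistics.pstdev(times))
    return result


# Made-up timings: function 0 has one slow run (5.0 s) that gets trimmed.
runs = [
    {"jit_autotune_fn_0": 1.0, "jit_autotune_fn_1": 2.0, "jit_other": 9.0},
    {"jit_autotune_fn_0": 1.0, "jit_autotune_fn_1": 4.0},
    {"jit_autotune_fn_0": 5.0},
    {"jit_autotune_fn_0": 1.0},
]
timings = aggregate(runs)
```

An autotuner would then typically pick the ID with the smallest mean, using the standard deviation to judge whether the difference between candidates is meaningful.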