Source code for ejkernel.ops.execution.profiler

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""JAX profiler for performance analysis and autotuning.

This module provides a comprehensive profiler for JAX operations that captures
execution traces, parses profile data, and provides detailed timing analysis.
It is designed to work with JAX's built-in profiling infrastructure and
supports advanced features for autotuning hyperparameter optimization.

Key Features:
    - Profile capture using JAX's native profiling system
    - Nested event time accounting with interval merging
    - Regex-based event filtering for focused analysis
    - Device-specific profiling across GPU/TPU/CPU platforms
    - Statistical aggregation with outlier removal

Classes:
    Profiler: Main profiler class for capturing and analyzing JAX traces
    ProfilingError: Exception raised when profiling operations fail

The profiler handles:
    - GPU and TPU timing formats (nanoseconds vs picoseconds)
    - Function identification by ID patterns for autotuning
    - Child event detection for accurate nested timing
    - Graceful fallback when TensorFlow hooks are unavailable

Example:
    >>> profiler = Profiler(prefix_filter='jit_', min_duration_ns=1000)
    >>> timings = profiler.profile_time_by_function_id(
    ...     closure, platform='gpu', total_calls_number=5
    ... )
"""

from __future__ import annotations

import importlib
import os
import re
import tempfile
import warnings
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
from typing import Any

import jax
import numpy as np
from jax.profiler import ProfileData


class ProfilingError(Exception):
    """Exception raised when profiling operations fail.

    Raised by ``Profiler`` when profile data cannot be parsed, when no
    plane matches the requested device, when event extraction fails, or
    when the required profiler hooks are unavailable.
    """
[docs]class Profiler: """JAX profile capture and parsing with nested-event accounting. A comprehensive profiler for JAX operations that captures execution traces, parses profile data, and provides detailed timing analysis with support for nested event filtering and device-specific profiling. This profiler is designed to work with JAX's built-in profiling infrastructure and provides advanced features like regex-based event filtering, minimum duration thresholds, nested event time accounting, and graceful handling of missing TensorFlow profiler hooks. The profiler detects TensorFlow Python profiler hook availability and gracefully skips profiling (allowing fallback to Python timing) if the hooks are not available. This makes it robust for deployment in environments where TensorFlow may not be fully configured. Attributes: prefix_filter: String prefix to filter events by name event_filter_regex: Optional regex pattern for event filtering min_duration_ns: Minimum event duration in nanoseconds to include max_events_per_profile: Maximum number of events to process per profile verbose: Enable verbose logging output require_tf: Whether to require TensorFlow profiler hooks """ def __init__( self, *, prefix_filter: str = "jit_", event_filter_regex: str | None = None, min_duration_ns: float = 1000.0, max_events_per_profile: int | None = 10000, verbose: bool = False, require_tf: bool = False, silence_tf_cpp_logs: bool = True, ): """Initialize the JAX profiler with filtering and processing options. 
Args: prefix_filter: String prefix to filter events by name (default: "jit_") event_filter_regex: Optional regex pattern for advanced event filtering min_duration_ns: Minimum event duration in nanoseconds to include in results max_events_per_profile: Maximum number of events to process per profile to prevent memory issues with large traces verbose: Enable verbose logging for debugging profiling operations require_tf: If True, require TensorFlow profiler hooks to be available silence_tf_cpp_logs: If True, set TF_CPP_MIN_LOG_LEVEL=3 to reduce TF noise """ self.prefix_filter = prefix_filter self.event_filter_regex = event_filter_regex self.min_duration_ns = min_duration_ns self.max_events_per_profile = max_events_per_profile self.verbose = verbose self._pattern = re.compile(event_filter_regex) if event_filter_regex is not None else None self.require_tf = require_tf self._tf_avail_cache: bool | None = None if silence_tf_cpp_logs and "TF_CPP_MIN_LOG_LEVEL" not in os.environ: os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" def _tf_python_profiler_available(self) -> bool: """Check once whether tensorflow.python.profiler.trace can be imported. Caches the result to avoid repeated import attempts. This is used to gracefully handle cases where TensorFlow is not available or the profiler hooks are not properly installed. Returns: True if TensorFlow profiler hooks are available, False otherwise """ if self._tf_avail_cache is not None: return self._tf_avail_cache try: importlib.import_module("tensorflow.python.profiler.trace") self._tf_avail_cache = True except Exception: self._tf_avail_cache = False return self._tf_avail_cache
[docs] @staticmethod def parse_profile_from_bytes(profile_bytes: bytes): # type: ignore """Parse JAX profile data from serialized bytes. Converts raw profile bytes (typically from .xplane.pb files) into a structured ProfileData object that can be analyzed for performance metrics. Args: profile_bytes: Raw profile data as bytes from JAX profiler output Returns: ProfileData object containing parsed profiling information Raises: ProfilingError: If profile data cannot be parsed or is corrupted """ try: return ProfileData.from_serialized_xspace(profile_bytes) except Exception as e: raise ProfilingError(f"Failed to parse profile data: {e}") from e
[docs] @staticmethod def find_device_plane_ids(p: Any, device_str: str) -> list[int]: """Find plane IDs corresponding to a specific device in profile data. Searches through the profile's execution planes to find those matching the specified device string. Planes represent different execution contexts (e.g., GPU, TPU, CPU) in the profiling data. Args: p: Profile data object containing execution planes device_str: Device identifier to search for (case-insensitive) Returns: List of plane indices that match the device string Raises: ProfilingError: If no planes found for device or invalid profile structure """ try: plane_ids = [i for i, plane in enumerate(p.planes) if device_str.lower() in plane.name.lower()] if not plane_ids: available_devices = [plane.name for plane in p.planes] raise ProfilingError( f"No planes found for device '{device_str}'. Available devices: {available_devices}" ) return plane_ids except AttributeError as e: raise ProfilingError(f"Invalid profile structure: {e}") from e
@staticmethod def _get_stat_value(stat, metadata): """Extract the actual value from a profile statistic object. Profile statistics can store values in different formats (double, int64, uint64, ref, bytes, str). This method attempts to extract the actual value regardless of the storage format. Args: stat: Profile statistic object with various value fields metadata: Optional metadata mapping for reference values Returns: The extracted value, or None if no valid value found """ try: if getattr(stat, "ref_value", 0) != 0: return metadata[stat.ref_value].name if metadata is not None else stat.ref_value for key in ("double", "int64", "uint64", "ref"): v = getattr(stat, f"{key}_value", 0) if v != 0: return v for key in ("bytes", "str"): v = getattr(stat, f"{key}_value", b"" if key == "bytes" else "") if v: return v except Exception: pass return None @classmethod def _parse_stats(cls, stats, stat_metadata): """Parse all statistics from a profile event into a dictionary. Converts the raw statistics data from profile events into a more accessible dictionary format, using metadata to resolve statistic names. Args: stats: Collection of statistic objects from a profile event stat_metadata: Optional metadata mapping for statistic names Returns: Dictionary mapping statistic names to their values """ stats_list = list(stats) if stat_metadata is not None: return {stat_metadata[s.metadata_id].name: cls._get_stat_value(s, stat_metadata) for s in stats_list} return {getattr(s, "metadata_id", i): cls._get_stat_value(s, None) for i, s in enumerate(stats_list)} @classmethod def _parse_event(cls, event, event_metadata, stat_metadata, line_name: str = ""): """Parse a single profile event into a structured dictionary. Extracts timing information, statistics, and metadata from a raw profile event object, creating a standardized representation for analysis. Handles both GPU (uses start_ns/duration_ns) and TPU (uses offset_ps/duration_ps) timing formats. 
Args: event: Raw profile event object event_metadata: Optional metadata mapping for event names stat_metadata: Optional metadata mapping for statistic names line_name: Name of the execution line this event belongs to Returns: Dictionary containing parsed event data with timing and statistics """ if event_metadata is not None: name = event_metadata[event.metadata_id].name else: name = event.name with warnings.catch_warnings(): warnings.filterwarnings( "ignore", message=r".*event_stats.*__module__ attribute", category=DeprecationWarning, ) stats = cls._parse_stats(event.stats, stat_metadata) name = stats.get("hlo_module", name) program_id = stats.get("program_id", stats.get("run_id")) scope_range_id = stats.get("scope_range_id", "None") key = f"{name}({program_id}-{scope_range_id})" if hasattr(event, "duration_ps"): start_ps = int(event.offset_ps) end_ps = start_ps + int(event.duration_ps) dur_ps = int(event.duration_ps) else: start_ps = int(event.start_ns * 1000) end_ps = start_ps + int(event.duration_ns * 1000) dur_ps = int(event.duration_ns * 1000) stats["start_ps"] = start_ps stats["end_ps"] = end_ps stats["duration_ps"] = dur_ps return dict(unified_name=key, fusion=name, line_name=line_name, **stats) @staticmethod def _find_children( own_name: str, start_ps: int, end_ps: int, events_sorted: list[dict[str, Any]], starts_sorted: np.ndarray, ): """Find all child events fully contained within a parent event's timespan. Uses binary search on precomputed sorted start times to efficiently locate all events that occur entirely within the specified time range, excluding the parent event itself. This is crucial for accurate nested event timing analysis. 
Args: own_name: Name of the parent event to exclude from results start_ps: Start time in picoseconds of the parent event end_ps: End time in picoseconds of the parent event events_sorted: List of all events sorted by start time starts_sorted: Precomputed numpy array of start times for binary search optimization Returns: List of child events that occur entirely within the time range """ idx = int(np.searchsorted(starts_sorted, start_ps, side="left")) children = [] for ev in events_sorted[idx:]: s = ev["start_ps"] if s > end_ps: break if ev["unified_name"] == own_name: continue if s >= start_ps and ev["end_ps"] <= end_ps: children.append(ev) return children @staticmethod def _sum_events(events): """Calculate total time covered by a collection of events using interval merging. Handles overlapping events by merging their time intervals to avoid double-counting execution time. This is essential for accurate nested event timing analysis where child events may overlap. Uses an efficient interval merging algorithm that sorts intervals by start time and then merges overlapping or adjacent intervals. Args: events: Collection of events with start_ps and end_ps timing data Returns: Total time in picoseconds covered by all events (with overlaps merged) """ if not events: return 0 intervals = sorted((int(e["start_ps"]), int(e["end_ps"])) for e in events) total = 0 cur_s, cur_e = intervals[0] for s, e in intervals[1:]: if s > cur_e: total += cur_e - cur_s cur_s, cur_e = s, e else: cur_e = max(cur_e, e) total += cur_e - cur_s return total
[docs] def get_events_from_plane(self, p: Any, plane_idx: int) -> dict[str, float]: """Extract and process events from a specific execution plane. Processes all events from the specified plane, applying filtering criteria and calculating accurate timing with nested event accounting. This is the main method for extracting performance metrics from profile data. Args: p: Profile data object containing execution planes plane_idx: Index of the specific plane to process Returns: Dictionary mapping event names to execution times in seconds Raises: ProfilingError: If plane index is invalid or event processing fails """ try: planes = list(p.planes) if plane_idx >= len(planes): raise ProfilingError(f"Plane index {plane_idx} out of range (0-{len(planes) - 1})") plane = planes[plane_idx] event_metadata = getattr(plane, "event_metadata", None) stat_metadata = getattr(plane, "stat_metadata", None) min_duration_ps = int(self.min_duration_ns * 1000) all_events: list[dict[str, Any]] = [] processed = 0 for line in plane.lines: for event in line.events: if self.max_events_per_profile is not None and processed >= self.max_events_per_profile: break ev = self._parse_event(event, event_metadata, stat_metadata, line_name=line.name) if ev["duration_ps"] < min_duration_ps: continue all_events.append(ev) processed += 1 if self.max_events_per_profile is not None and processed >= self.max_events_per_profile: break if not all_events: return {} events_sorted = sorted(all_events, key=lambda x: x["start_ps"]) starts_sorted = np.fromiter((e["start_ps"] for e in events_sorted), dtype=np.int64, count=len(events_sorted)) timed_events: dict[str, float] = {} for ev in events_sorted: name = ev["unified_name"] if self.prefix_filter and not name.startswith(self.prefix_filter): continue children = self._find_children(name, ev["start_ps"], ev["end_ps"], events_sorted, starts_sorted) if self._pattern is not None: children = [ch for ch in children if self._pattern.search(ch["unified_name"]) is not None] 
children_duration_ps = self._sum_events(children) duration_seconds = children_duration_ps / 1e12 else: duration_seconds = (ev["end_ps"] - ev["start_ps"]) / 1e12 if duration_seconds >= 0: timed_events[name] = float(duration_seconds) return timed_events except re.error as e: raise ProfilingError(f"Invalid regex pattern '{self.event_filter_regex}': {e}") from e except Exception as e: raise ProfilingError(f"Failed to extract events from plane {plane_idx}: {e}") from e
[docs] def profile_time_by_function_id( self, timing_closure: Callable[[], None], platform: str, total_calls_number: int, ) -> dict[int, tuple[float, float]]: """Profile function execution times across multiple iterations with statistical analysis. Executes the provided closure multiple times under JAX profiler tracing, extracting timing data for functions identified by ID patterns. Provides statistical aggregation with outlier removal for reliable timing measurements. This method is specifically designed for autotuning scenarios where multiple function variants (identified by numeric IDs) need to be compared. It requires TensorFlow profiler hooks to be available and will raise ProfilingError if they are not found, allowing fallback to Python-level timing. Args: timing_closure: Function to execute and profile (should call the functions to time) platform: Target platform string (e.g., 'gpu', 'tpu', 'cpu') for device selection total_calls_number: Number of profiling iterations to perform for statistical accuracy Returns: Dictionary mapping function IDs to (mean_time, std_time) tuples in seconds. Function IDs are extracted from event names matching 'jit_autotune_fn_{id}' pattern. Statistical outliers are removed if more than 2 measurements are available. 
Raises: ProfilingError: If TensorFlow profiler hooks are not available or profiling fails RuntimeError: If no profile data is generated during execution """ if self.require_tf and not self._tf_python_profiler_available(): raise ProfilingError("TensorFlow Python profiler hooks are not available") if not self._tf_python_profiler_available(): raise ProfilingError("Profiler not available (missing tensorflow.python.profiler.trace)") function_timings: dict[int, list[float]] = {} name_re = re.compile(r"^jit_autotune_fn_([0-9]+).*") for _ in range(total_calls_number): now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") with tempfile.TemporaryDirectory(prefix=f"tuning_profile_{now}_") as tmpdir: with jax.profiler.trace(tmpdir): timing_closure() profile_files = sorted(Path(tmpdir).glob("**/*.xplane.pb"), key=lambda f: f.stat().st_mtime) if not profile_files: raise RuntimeError("No profile was created.") latest_profile = profile_files[-1] profile_proto = self.parse_profile_from_bytes(latest_profile.read_bytes()) device_plane_id = self.find_device_plane_ids(profile_proto, platform)[0] profile_events = self.get_events_from_plane(profile_proto, device_plane_id) for k, dur in profile_events.items(): m = name_re.match(k) if not m: continue key = int(m.group(1)) function_timings.setdefault(key, []).append(dur) agg: dict[int, tuple[float, float]] = {} for key, durations in function_timings.items(): if len(durations) > 2: durations = sorted(durations)[1:-1] agg[key] = (float(np.mean(durations)), float(np.std(durations))) return agg