ejkernel.loggings

Logging utilities for ejKernel with colored output and progress tracking.

This module provides enhanced logging capabilities, including:

  • Colored console output with level-specific formatting

  • Lazy logger initialization for multi-process JAX environments

  • Progress tracking with ETAs and progress bars

  • JAX profiler integration with Perfetto support

The logging system adjusts automatically for distributed training, suppressing output from non-primary processes (every process other than process 0) so that log messages are not duplicated across processes.
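The suppression of non-primary output can be pictured as choosing a stricter level on every process except process 0. The sketch below assumes that non-primary processes are raised to WARNING; the helper name `effective_level` and that threshold are illustrative, not ejKernel's code. The process index is passed in explicitly here; in a JAX program it would typically come from `jax.process_index()`.

```python
import logging


def effective_level(requested: int, process_index: int) -> int:
    """Hypothetical helper: choose a log level based on the process index.

    Process 0 (the primary process) keeps the requested level. Every other
    process is raised to at least WARNING (an assumed threshold), so routine
    INFO/DEBUG output is printed only once per job.
    """
    if process_index == 0:
        return requested
    return max(requested, logging.WARNING)


# Primary process keeps INFO; a worker process is raised to WARNING.
print(logging.getLevelName(effective_level(logging.INFO, 0)))
print(logging.getLevelName(effective_level(logging.INFO, 3)))
```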

Key Components:

  • LazyLogger: Lazy-initialized logger with JAX process awareness

  • ProgressLogger: Terminal progress bars with ETA calculations

  • ColorFormatter: ANSI color formatting for log messages

  • Profiler utilities: JAX trace profiler with Perfetto support

Constants:

  • COLORS: ANSI color code mappings

  • LEVEL_COLORS: Log level to color mappings
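For orientation, the sketch below shows how a COLORS table of ANSI escape sequences and a LEVEL_COLORS table keyed by level name might fit together. The entries, the use of level names as keys, and the `colorize` helper are assumptions for illustration; ejKernel's actual constants may contain different keys and colors.

```python
# Standard ANSI SGR escape sequences (illustrative subset, assumed keys).
COLORS = {
    "RESET": "\033[0m",
    "RED": "\033[31m",
    "GREEN": "\033[32m",
    "YELLOW": "\033[33m",
    "BLUE": "\033[34m",
    "BOLD_RED": "\033[1;31m",
}

# Each level name points at one of the colors above (assumed mapping).
LEVEL_COLORS = {
    "DEBUG": COLORS["BLUE"],
    "INFO": COLORS["GREEN"],
    "WARNING": COLORS["YELLOW"],
    "ERROR": COLORS["RED"],
    "CRITICAL": COLORS["BOLD_RED"],
}


def colorize(levelname: str, text: str) -> str:
    """Hypothetical helper: wrap text in its level's color, or leave it plain."""
    color = LEVEL_COLORS.get(levelname, "")
    return f"{color}{text}{COLORS['RESET']}" if color else text
```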

Example

>>> from ejkernel.loggings import get_logger, ProgressLogger
>>>
>>> logger = get_logger(__name__)
>>> logger.info("Starting training")
>>>
>>> batches = list(range(10))  # placeholder for your real batches
>>> with ProgressLogger("Training") as progress:
...     for i, batch in enumerate(batches):
...         progress.update(i, len(batches), f"Batch {i}")
...
class ejkernel.loggings.ColorFormatter(fmt=None, datefmt=None, style='%', validate=True, *, defaults=None)

Bases: Formatter

Custom formatter that adds colors and timestamps to log messages.

This formatter applies ANSI color codes based on log level and formats multi-line messages with proper indentation and timestamps.

format(record: LogRecord) → str

Format a log record with colors and timestamp.

Parameters

record – The log record to format.

Returns

Formatted string with ANSI color codes and timestamp.
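The following is a minimal `logging.Formatter` subclass in the spirit of ColorFormatter: it colors a timestamped level prefix and indents the continuation lines of multi-line messages. The class name `SketchColorFormatter`, the prefix layout, and the color choices are assumptions; it is not ejKernel's implementation and, for brevity, it ignores exception tracebacks.

```python
import logging
import time

# Assumed level-to-color mapping (ANSI SGR codes).
LEVEL_COLORS = {
    "DEBUG": "\033[34m",
    "INFO": "\033[32m",
    "WARNING": "\033[33m",
    "ERROR": "\033[31m",
    "CRITICAL": "\033[1;31m",
}
RESET = "\033[0m"


class SketchColorFormatter(logging.Formatter):
    """Hypothetical stand-in for ColorFormatter (exception text omitted)."""

    def format(self, record: logging.LogRecord) -> str:
        color = LEVEL_COLORS.get(record.levelname, "")
        timestamp = time.strftime("%H:%M:%S", time.localtime(record.created))
        plain_prefix = f"({timestamp} {record.levelname}) "
        lines = record.getMessage().split("\n")
        # Color only the prefix; the message text keeps the terminal default.
        first = f"{color}{plain_prefix}{RESET}{lines[0]}"
        # Indent continuation lines so they align under the first line's text.
        indent = " " * len(plain_prefix)
        return "\n".join([first] + [indent + line for line in lines[1:]])
```

To try it, attach it to a handler, for example `handler = logging.StreamHandler()` followed by `handler.setFormatter(SketchColorFormatter())`.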

class ejkernel.loggings.LazyLogger(name: str, level: int | None = None)

Bases: object

Lazy-initialized logger that defers creation until first use.

This logger automatically adjusts its level in distributed JAX environments, suppressing output from non-primary processes to avoid clutter. It produces colored output and defers initialization until first use, so that creating a logger (for example, at module import time) does not interact with the JAX runtime prematurely.

name

Logger name.

level

Current logging level.

Example

>>> from ejkernel.loggings import LazyLogger
>>> logger = LazyLogger("MyModule")
>>> logger.info("This message only appears on process 0")
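Lazy initialization can be sketched as a thin wrapper that builds the real `logging.Logger` the first time any logging method is accessed. `SketchLazyLogger` and its attribute names are hypothetical, and the point where ejKernel would apply its process-aware level adjustment is marked only as an assumption.

```python
import logging
from typing import Optional


class SketchLazyLogger:
    """Hypothetical lazy wrapper: the real logging.Logger is built on first use."""

    def __init__(self, name: str, level: Optional[int] = None) -> None:
        self._name = name
        self._level = level if level is not None else logging.INFO
        self._logger: Optional[logging.Logger] = None  # not created yet

    def _ensure_logger(self) -> logging.Logger:
        if self._logger is None:
            # Deferred setup. A process-aware level adjustment would plausibly
            # happen here (assumption), once the runtime is ready.
            logger = logging.getLogger(self._name)
            logger.setLevel(self._level)
            self._logger = logger
        return self._logger

    def __getattr__(self, attr: str):
        # Only called for attributes not found normally (info, debug, ...).
        return getattr(self._ensure_logger(), attr)
```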
property level: int

Return the current logging level as an integer.

Returns

Logging level integer (e.g., 10 for DEBUG, 20 for INFO).

property name: str

Return the logger name.

Returns

The name string provided at construction time.

class ejkernel.loggings.ProgressLogger(name: str = 'Progress', logger_instance: ejkernel.loggings.LazyLogger | None = None)

Bases: object

A progress logger that displays updating progress bars and messages.

It displays progress for long-running operations using progress bars, ETAs, and streaming updates that overwrite the same line in the terminal.

name

Logger name to use for fallback logging.

use_tty

Whether to use TTY features (auto-detected).

start_time

Start time of the progress operation.

_logger

Underlying logger used for fallback output.

Example

>>> from ejkernel.loggings import ProgressLogger
>>> progress = ProgressLogger("Training")
>>> for i in range(100):
...     progress.update(i, 100, f"Processing batch {i}")
...
>>> progress.complete("Training finished!")
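Overwriting the same terminal line usually relies on a carriage return, with a plain newline-per-update fallback when output is redirected (compare the auto-detected `use_tty` attribute above). The helper `write_progress_line` and the exact escape sequences are assumptions, not ejKernel's code.

```python
import io
import sys
from typing import TextIO


def write_progress_line(text: str, stream: TextIO = sys.stdout) -> None:
    """Hypothetical helper: redraw one terminal line, or append when not a TTY."""
    if stream.isatty():
        # "\r" returns the cursor to column 0; "\033[K" clears the rest of the line.
        stream.write("\r" + text + "\033[K")
    else:
        # Redirected output (files, CI logs) gets one line per update instead.
        stream.write(text + "\n")
    stream.flush()


# Demo with an in-memory stream, which reports isatty() == False.
buffer = io.StringIO()
for step in ("step 1/2", "step 2/2"):
    write_progress_line(step, buffer)
```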
complete(message: str | None = None, show_time: bool = True) → None

Mark the progress as complete and show a final message.

Parameters
  • message – Optional completion message.

  • show_time – Whether to show the total elapsed time.

update(current: int, total: int, message: str = '', bar_width: int = 20, show_eta: bool = True, extra_info: str = '') → None

Update the progress display.

Parameters
  • current – Current progress value (0-based).

  • total – Total number of items.

  • message – Message to display after the progress bar.

  • bar_width – Width of the progress bar in characters.

  • show_eta – Whether to show the estimated time remaining.

  • extra_info – Additional information to append at the end of the line.
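To make the `current`, `total`, and `bar_width` parameters concrete, here is a sketch of how a bar and an ETA could be computed. It assumes that a 0-based `current` means `current + 1` items are finished, and it estimates the ETA as the average time per finished item multiplied by the items remaining; the function name `format_progress` and the output layout are hypothetical.

```python
def format_progress(current: int, total: int, elapsed: float, bar_width: int = 20) -> str:
    """Hypothetical formatter: `current` is 0-based, so current + 1 items are done."""
    done = min(current + 1, total) if total > 0 else 0
    fraction = done / total if total > 0 else 1.0
    filled = int(bar_width * fraction)
    bar = "#" * filled + "-" * (bar_width - filled)
    text = f"[{bar}] {fraction * 100:5.1f}% ({done}/{total})"
    # ETA = average seconds per completed item x items still remaining.
    if 0 < done < total:
        eta = elapsed / done * (total - done)
        text += f" ETA {eta:.1f}s"
    return text


print(format_progress(49, 100, elapsed=10.0))
```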

update_simple(message: str) → None

Display a plain message without a progress bar.

Parameters

message – Message to display.

ejkernel.loggings.create_step_profiler(profile_path: str, start_step: int, duration_steps: int, enable_perfetto: bool) → Callable[[int], None]

Create a step-aware profiler that activates during a specific training window.

Creates a callback function that can be called at each training step to automatically start/stop JAX profiling at the specified step range.

Parameters
  • profile_path – Directory to store profiling results (trace files).

  • start_step – Step number to begin profiling (inclusive).

  • duration_steps – Number of steps to profile.

  • enable_perfetto – Whether to generate Perfetto UI links (primary process only).

Returns

Callback function that takes the current step number and manages the profiler lifecycle. Call it once per training step with the current step number.

Example

>>> from ejkernel.loggings import create_step_profiler
>>> profiler = create_step_profiler(
...     profile_path="./profiles",
...     start_step=100,
...     duration_steps=10,
...     enable_perfetto=True
... )
>>> for step in range(1000):
...     profiler(step)
...     train_step(batch)  # your own training step and data
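The step window managed by the returned callback can be sketched independently of JAX: start once `start_step` is reached and stop after `duration_steps` steps. `make_step_window` is hypothetical, and its `start`/`stop` callables stand in for whatever ejKernel actually invokes (plausibly ignite_profiler() and extinguish_profiler(), but that is an assumption).

```python
from typing import Callable


def make_step_window(
    start_step: int,
    duration_steps: int,
    start: Callable[[], None],
    stop: Callable[[], None],
) -> Callable[[int], None]:
    """Hypothetical sketch: run `start` at start_step, `stop` when the window ends."""
    end_step = start_step + duration_steps  # first step after the window
    active = False

    def on_step(step: int) -> None:
        nonlocal active
        if not active and start_step <= step < end_step:
            start()  # start_step is inclusive
            active = True
        elif active and step >= end_step:
            stop()  # exactly duration_steps steps were profiled
            active = False

    return on_step


# Demo: record events instead of calling the JAX profiler.
log = []
profiler_step = make_step_window(2, 3, lambda: log.append("start"), lambda: log.append("stop"))
for step in range(8):
    profiler_step(step)
    log.append(step)
```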
ejkernel.loggings.extinguish_profiler(enable_perfetto: bool) → None

Stop the profiler and handle Perfetto link generation.

Safely stops JAX tracing and finalizes trace files. When Perfetto is enabled, keeps output streams alive during the potentially long finalization process.

Parameters

enable_perfetto – Whether Perfetto links were enabled during profiling. Determines whether output streams must be kept alive while the trace is finalized.

Note

This function should be called after ignite_profiler() to stop tracing and write results to disk.

ejkernel.loggings.get_logger(name: str, level: int | None = None) → LazyLogger

Create a lazy logger that only initializes when first used.

This is the primary factory function for creating loggers in ejKernel. The logger defers initialization to avoid JAX runtime issues and automatically adjusts for distributed training scenarios.

Parameters
  • name – The name of the logger, typically the module name.

  • level – The logging level. Defaults to the value of the LOGGING_LEVEL_ED environment variable, or “INFO” if it is not set.

Returns

A lazy logger instance that initializes on first use.

Example

>>> from ejkernel.loggings import get_logger
>>> logger = get_logger(__name__)
>>> logger.info("Module initialized")
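The documented default for `level` (explicit argument, then LOGGING_LEVEL_ED, then "INFO") can be sketched with the standard library. Case-insensitive matching, acceptance of numeric strings, and the fallback to INFO for unrecognized values are assumptions; `resolve_level` is a hypothetical helper.

```python
import logging
import os
from typing import Optional


def resolve_level(level: Optional[int] = None) -> int:
    """Hypothetical: explicit level, else LOGGING_LEVEL_ED, else INFO."""
    if level is not None:
        return level
    raw = os.environ.get("LOGGING_LEVEL_ED", "INFO").strip().upper()
    if raw.isdigit():
        return int(raw)  # numeric levels such as "10" (assumed to be accepted)
    # logging.getLevelName maps a registered name like "DEBUG" to its integer;
    # unknown names come back as the string "Level <name>".
    value = logging.getLevelName(raw)
    return value if isinstance(value, int) else logging.INFO
```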
ejkernel.loggings.ignite_profiler(profile_path: str, enable_perfetto: bool = False) → None

Start the JAX profiler with optional Perfetto integration.

Begins tracing JAX operations for performance analysis. Trace files are written to the specified path and can be viewed in the Perfetto UI.

Parameters
  • profile_path – Directory to store profiling results (trace files).

  • enable_perfetto – Whether to generate Perfetto UI links. Links are generated only on the primary process (process_index == 0) to avoid duplicates.

Note

Call extinguish_profiler() to stop tracing and finalize results.
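Because every ignite_profiler() call should be matched by extinguish_profiler(), it can help to pair them in a context manager so that tracing is finalized even if the profiled code raises. The `profiling_session` helper is a hypothetical sketch; it takes the start and stop actions as callables so it runs without JAX.

```python
from contextlib import contextmanager
from typing import Callable, Iterator


@contextmanager
def profiling_session(start: Callable[[], None], stop: Callable[[], None]) -> Iterator[None]:
    """Hypothetical wrapper guaranteeing `stop` runs after `start`, even on error."""
    start()
    try:
        yield
    finally:
        stop()  # finalize traces even if the profiled code raised


# Demo: the stop action still runs when the body fails.
events = []
try:
    with profiling_session(lambda: events.append("start"), lambda: events.append("stop")):
        events.append("work")
        raise RuntimeError("simulated failure")
except RuntimeError:
    pass
```

With ejKernel, `start` could be `lambda: ignite_profiler("./profiles", enable_perfetto=True)` and `stop` could be `lambda: extinguish_profiler(enable_perfetto=True)`, passing the same `enable_perfetto` value to both calls.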

ejkernel.loggings.logger = <ejkernel.loggings.LazyLogger object>

Module-level logger instance used internally by profiler utilities.