ejKernel Project Overview#

Executive Summary#

ejKernel is a sophisticated, high-performance kernel library for JAX that provides multi-backend support for various deep learning operations, with a particular focus on efficient attention mechanisms. The project demonstrates advanced software engineering practices with a modular architecture designed for extensibility and performance.

Project Information#

  • Name: ejKernel (EasyDeL JAX Kernels)

  • Version: 0.0.1

  • Author: Erfan Zare Chavoshi

  • License: Apache License 2.0

  • Python Support: 3.11 - 3.13

  • Primary Dependencies: JAX, Triton, Pallas, jaxtyping, beartype

Core Objectives#

  1. Multi-Platform Support: Provide optimized kernel implementations for GPU (NVIDIA/AMD), TPU, and CPU

  2. Performance Optimization: Automatic kernel selection and configuration tuning for optimal performance

  3. Extensibility: Easy addition of new kernel implementations through registry system

  4. Type Safety: Comprehensive type annotations with runtime validation

  5. Developer-Friendly: Clean API with sensible defaults and progressive disclosure of advanced features

Architecture Overview#

ejkernel/
├── kernels/           # Core kernel implementations
│   ├── _triton/      # Triton GPU kernels
│   ├── _pallas/      # Pallas TPU kernels
│   ├── _xla/         # XLA CPU/fallback kernels
│   └── _cuda/        # Native CUDA kernels
├── modules/          # High-level operation modules
│   └── operations/   # Wrapped kernels with auto-selection
├── ops/              # Kernel execution framework
│   ├── config/       # Configuration management
│   ├── core/         # Base kernel classes
│   ├── execution/    # Execution orchestration
│   └── utils/        # Utilities and helpers
├── xla_utils/        # XLA-specific utilities
└── callib/           # Calibration library

Key Features#

1. Multi-Backend Kernel Registry#

  • Automatic Platform Detection: Seamlessly selects optimal implementation based on hardware

  • Priority-based Selection: Configurable kernel selection with fallback mechanisms

  • Signature Validation: Ensures consistency across implementations

2. Configuration Management Hierarchy#

  • Override → Overlay → Cache → Persistent → Autotune → Heuristics

  • In-memory and persistent caching for optimal configurations

  • Sophisticated autotuning with backward pass validation

3. Attention Mechanism Zoo#

  • Flash Attention v2: Memory-efficient O(N) attention with causal masking, dropout, sliding windows

  • Page Attention: Optimized for KV-cache in inference scenarios

  • Ring Attention: Distributed attention for sequence parallelism

  • Block Sparse Attention: Efficient sparse patterns for long-context processing

  • GLA (Gated Linear Attention): Linear complexity attention alternative

  • Lightning Attention: Layer-dependent decay attention mechanism

  • MLA (Multi-head Latent Attention): Efficient latent attention implementation

  • Ragged Attention: Variable-length sequence support

4. Advanced Operations#

  • Recurrent Kernels: Optimized RNN-like operations with custom gradients

  • Mean Pooling: Variable-length sequence pooling with proper masking

  • Grouped Matrix Multiplication: Efficient batched matrix operations

  • Native Sparse Operations: Block-sparse matrix computations

5. Developer Experience#

  • Full Type Hints: jaxtyping annotations for better IDE support

  • Modular Architecture: Easy to extend with new kernel implementations

  • Comprehensive Testing: Extensive test coverage with XLA vs Triton comparisons

  • Automatic Differentiation: Custom VJP rules for efficient gradients

  • Profiling Integration: Built-in support for JAX profiling tools

Platform Support Matrix#

Algorithm

Triton GPU

Pallas TPU

XLA (CPU/Fallback)

CUDA

Flash Attention v2

🚧

Page Attention

🚧

Ring Attention

🚧

Native Sparse

🚧

GLA

🚧

Lightning Attention

🚧

MLA

🚧

Ragged Page Attention

🚧

Recurrent

🚧

🚧

Mean Pooling

🚧

🚧

Grouped MatMul

🚧

🚧

✅ = Implemented and optimized 🚧 = Under development ❌ = Not yet implemented

Design Principles#

  1. Convention over Configuration: Sensible defaults everywhere with optional overrides

  2. Progressive Disclosure: Simple API for common cases, advanced features when needed

  3. Fail Gracefully: Multiple fallback layers ensure reliability

  4. Optimize Lazily: Cache results, autotune on demand

  5. Type Everything: Static and runtime validation for correctness

Usage Example#

import jax
import jax.numpy as jnp
from ejkernel.modules import FlashAttention, create_default_executor

# Initialize
executor = create_default_executor()
attention = FlashAttention()

# Create inputs
batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
key = jax.random.PRNGKey(0)
q = k = v = jax.random.normal(key, (batch, seq_len, num_heads, head_dim))

# Execute attention with automatic optimization
output = executor(
    attention, q, k, v,
    causal=True,
    dropout_prob=0.1,
    logits_soft_cap=30.0  # Gemma-2 style soft capping
)

Environment Variables#

# Autotuning
EJKERNEL_AUTOTUNE_POLICY=autotune|heuristics  # Default: autotune
EJKERNEL_LOG_AUTOTUNE=1                        # Log candidate testing

# Profiling
EJKERNEL_OPS_STAMP=hash|json|none             # Default: none
EJKERNEL_OPS_RECORD=1                          # Enable invocation recording

# Triton-specific
EJKERNEL_TRITON_SMEM_LIMIT=101376             # Shared memory limit (bytes)

Integration Points#

The library integrates seamlessly with:

  • JAX Ecosystem: Full compatibility with JAX transformations

  • EasyDeL Framework: Parent framework for JAX deep learning

  • JAX Profiling Tools: XProf, TensorBoard integration

  • Distributed Training: shard_map support for model parallelism

Future Roadmap#

Near-term#

  • Flash Attention 3 implementation

  • Flash Decoding for optimized inference

  • Quantized attention (INT8/INT4)

  • Fused LayerNorm + Attention kernels

Long-term#

  • Sliding Window Attention with Sinks

  • Differential Attention with learnable patterns

  • Mixture of Attention mechanisms

  • Speculative decoding kernels

  • Continuous batching support

Conclusion#

ejKernel represents a production-quality implementation of high-performance machine learning kernels with a focus on:

  • Modularity: Clean separation of concerns

  • Performance: State-of-the-art optimizations

  • Reliability: Multiple fallback mechanisms

  • Extensibility: Easy to add new features

  • Maintainability: Well-documented and tested

The project serves as an excellent foundation for building high-performance deep learning systems in JAX.