# ejKernel Project Overview

## Executive Summary

ejKernel is a high-performance kernel library for JAX that provides multi-backend implementations of deep learning operations, with a particular focus on efficient attention mechanisms. Its modular architecture is designed for extensibility and performance.
## Project Information

- Name: ejKernel (EasyDeL JAX Kernels)
- Version: 0.0.1
- Author: Erfan Zare Chavoshi
- License: Apache License 2.0
- Python Support: 3.11 to 3.13
- Primary Dependencies: JAX (including its Pallas extension), Triton, jaxtyping, beartype
## Core Objectives

- Multi-Platform Support: provide optimized kernel implementations for GPU (NVIDIA/AMD), TPU, and CPU
- Performance Optimization: automatic kernel selection and configuration tuning
- Extensibility: add new kernel implementations through a registry system
- Type Safety: comprehensive type annotations with runtime validation
- Developer-Friendly: a clean API with sensible defaults and progressive disclosure of advanced features
## Architecture Overview

```
ejkernel/
├── kernels/            # Core kernel implementations
│   ├── _triton/        # Triton GPU kernels
│   ├── _pallas/        # Pallas TPU kernels
│   ├── _xla/           # XLA CPU/fallback kernels
│   └── _cuda/          # Native CUDA kernels
├── modules/            # High-level operation modules
│   └── operations/     # Wrapped kernels with auto-selection
├── ops/                # Kernel execution framework
│   ├── config/         # Configuration management
│   ├── core/           # Base kernel classes
│   ├── execution/      # Execution orchestration
│   └── utils/          # Utilities and helpers
├── xla_utils/          # XLA-specific utilities
└── callib/             # Calibration library
```
## Key Features

### 1. Multi-Backend Kernel Registry

- Automatic Platform Detection: selects the best available implementation for the detected hardware
- Priority-based Selection: configurable kernel selection with fallback mechanisms
- Signature Validation: ensures that all implementations of an operation share a consistent signature
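A small pure-Python sketch can illustrate the registry behavior described above. `KernelRegistry`, `register`, and `select` are hypothetical names, not ejKernel's actual API. The sketch assumes three things: each implementation is tagged with a platform string and an integer priority, the XLA implementation serves as the fallback, and "signature validation" means comparing parameter names.

```python
import inspect
from typing import Callable


class KernelRegistry:
    """Hypothetical registry mapping an operation name to per-platform implementations."""

    def __init__(self) -> None:
        # op name -> list of (platform, priority, implementation)
        self._impls: dict[str, list[tuple[str, int, Callable]]] = {}

    def register(self, op: str, platform: str, priority: int = 0):
        def decorator(fn: Callable) -> Callable:
            existing = self._impls.get(op)
            if existing:
                # Signature validation: every backend must accept the same parameters.
                expected = list(inspect.signature(existing[0][2]).parameters)
                got = list(inspect.signature(fn).parameters)
                if got != expected:
                    raise TypeError(f"{fn.__name__}{got} does not match {expected}")
            self._impls.setdefault(op, []).append((platform, priority, fn))
            return fn

        return decorator

    def select(self, op: str, platform: str, fallback: str = "xla") -> Callable:
        candidates = self._impls.get(op, [])
        for target in (platform, fallback):
            matches = [(prio, fn) for plat, prio, fn in candidates if plat == target]
            if matches:
                # Highest priority wins; max() keeps the first entry on ties.
                return max(matches, key=lambda m: m[0])[1]
        raise LookupError(f"no {op!r} kernel for {platform!r} or fallback {fallback!r}")


registry = KernelRegistry()


@registry.register("attention", "gpu", priority=10)
def attention_triton(q, k, v):
    return "triton"


@registry.register("attention", "xla")
def attention_xla(q, k, v):
    return "xla"


print(registry.select("attention", "gpu").__name__)  # attention_triton
print(registry.select("attention", "cpu").__name__)  # attention_xla (fallback)
```

Because registration and selection are kept separate, adding a backend only requires one more decorated function. Call sites continue to request the operation by name.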
### 2. Configuration Management Hierarchy

- Lookup order, from highest to lowest precedence: Override → Overlay → Cache → Persistent → Autotune → Heuristics
- In-memory and persistent caching of tuned configurations
- Autotuning with backward-pass validation
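One way to read the hierarchy above is as an ordered list of lookups, where the first source that returns a configuration wins. The sketch below assumes that interpretation. `resolve_config`, the dictionary-backed stores, the key format, and the block-size fields are all hypothetical. The `autotune` function is a stand-in that returns a fixed configuration instead of benchmarking anything.

```python
from typing import Callable, Optional

# A source maps a lookup key (e.g. op name + input shape) to a config, or None.
ConfigSource = Callable[[str], Optional[dict]]


def resolve_config(key: str, sources: list[tuple[str, ConfigSource]]) -> tuple[str, dict]:
    """Return (source name, config) from the first source that has an answer."""
    for name, lookup in sources:
        config = lookup(key)
        if config is not None:
            return name, config
    raise LookupError(f"no configuration for {key!r}")


memory_cache: dict[str, dict] = {}
persistent_store = {"flash_attn:1024x64": {"block_q": 128, "block_k": 64}}


def autotune(key: str) -> dict:
    config = {"block_q": 64, "block_k": 64}  # stand-in for benchmarking candidates
    memory_cache[key] = config               # later lookups hit the in-memory cache
    return config


sources: list[tuple[str, ConfigSource]] = [
    ("override", lambda key: None),          # e.g. an explicit user-supplied config
    ("overlay", lambda key: None),
    ("cache", memory_cache.get),
    ("persistent", persistent_store.get),
    ("autotune", autotune),
    # Only reached if every source above returns None (e.g. autotuning disabled).
    ("heuristics", lambda key: {"block_q": 64, "block_k": 32}),
]

print(resolve_config("flash_attn:1024x64", sources))   # found in the persistent store
print(resolve_config("flash_attn:2048x128", sources))  # autotuned, then cached
```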
### 3. Attention Mechanism Zoo

- Flash Attention v2: memory-efficient attention (O(N) memory in sequence length; compute remains O(N²)) with causal masking, dropout, and sliding windows
- Page Attention: optimized for paged KV caches in inference scenarios
- Ring Attention: distributed attention for sequence parallelism
- Block Sparse Attention: efficient sparse patterns for long-context processing
- GLA (Gated Linear Attention): a linear-complexity alternative to softmax attention
- Lightning Attention: an attention mechanism with layer-dependent decay
- MLA (Multi-head Latent Attention): attention over a compressed latent KV representation
- Ragged Attention: support for variable-length sequences
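The O(N) memory figure comes from never materializing the full N×N score matrix. The NumPy sketch below illustrates the online-softmax recurrence that FlashAttention-style kernels use. Queries and keys are processed in tiles, and a running row maximum and denominator let earlier partial sums be rescaled as new tiles arrive. This is a single-head, dropout-free illustration that assumes equal query and key lengths. It is not ejKernel's Triton or Pallas kernel, and both function names are hypothetical.

```python
import numpy as np


def reference_attention(q, k, v, causal=False):
    """Plain softmax attention that materializes the full (N, N) score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if causal:
        n = scores.shape[0]
        scores = np.where(np.tril(np.ones((n, n), dtype=bool)), scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v


def blockwise_attention(q, k, v, block=16, causal=False):
    """Online-softmax attention: only (block x block) score tiles are ever built."""
    n, d = q.shape
    out = np.empty((n, v.shape[1]))
    for qs in range(0, n, block):
        qb = q[qs:qs + block]
        rows = np.arange(qs, qs + len(qb))[:, None]
        m = np.full(len(qb), -np.inf)          # running max per query row
        l = np.zeros(len(qb))                  # running softmax denominator
        acc = np.zeros((len(qb), v.shape[1]))  # running unnormalized output
        for ks in range(0, n, block):
            if causal and ks > rows[-1, 0]:
                break                          # tile lies entirely above the diagonal
            s = qb @ k[ks:ks + block].T / np.sqrt(d)
            if causal:
                cols = np.arange(ks, ks + s.shape[1])[None, :]
                s = np.where(cols <= rows, s, -np.inf)
            m_new = np.maximum(m, s.max(axis=-1))
            p = np.exp(s - m_new[:, None])
            alpha = np.exp(m - m_new)          # rescales the earlier partial sums
            l = l * alpha + p.sum(axis=-1)
            acc = acc * alpha[:, None] + p @ v[ks:ks + block]
            m = m_new
        out[qs:qs + len(qb)] = acc / l[:, None]
    return out
```

Both functions perform the same O(N²) arithmetic. The difference is that `blockwise_attention` holds only per-row statistics and one tile of scores at a time.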
### 4. Advanced Operations

- Recurrent Kernels: optimized RNN-like operations with custom gradients
- Mean Pooling: variable-length sequence pooling with proper masking
- Grouped Matrix Multiplication: efficient batched matrix operations
- Native Sparse Operations: block-sparse matrix computations
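Masked mean pooling over variable-length sequences can be sketched as follows. The sketch assumes right-padded inputs of shape (batch, time, features) and one integer length per sequence. `masked_mean_pool` is a hypothetical NumPy helper, not the library's kernel. It uses `np.where` rather than multiplying by the mask, so non-finite padding values cannot leak into the result.

```python
import numpy as np


def masked_mean_pool(x, lengths):
    """Average x of shape (B, T, D) over the first lengths[b] steps of each row."""
    max_len = x.shape[1]
    mask = np.arange(max_len)[None, :] < np.asarray(lengths)[:, None]  # (B, T)
    summed = np.where(mask[..., None], x, 0.0).sum(axis=1)             # padding ignored, even if NaN
    counts = np.maximum(mask.sum(axis=1), 1)[:, None]                  # avoid dividing by zero
    return summed / counts
```

A sequence of length zero pools to a zero vector here. A real implementation might instead choose to raise an error or return NaN.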
### 5. Developer Experience

- Full Type Hints: jaxtyping annotations for better IDE support
- Modular Architecture: easy to extend with new kernel implementations
- Comprehensive Testing: extensive test coverage, including comparisons of XLA and Triton outputs
- Automatic Differentiation: custom VJP rules for efficient gradients
- Profiling Integration: built-in support for JAX profiling tools
## Platform Support Matrix

| Algorithm | Triton GPU | Pallas TPU | XLA (CPU/Fallback) | CUDA |
|---|---|---|---|---|
| Flash Attention v2 | ✅ | ✅ | ✅ | 🚧 |
| Page Attention | ✅ | ✅ | ✅ | 🚧 |
| Ring Attention | ✅ | ✅ | ✅ | 🚧 |
| Native Sparse | ✅ | ❌ | ✅ | 🚧 |
| GLA | ✅ | 🚧 | ✅ | ❌ |
| Lightning Attention | ✅ | ❌ | ✅ | 🚧 |
| MLA | ✅ | 🚧 | ❌ | ❌ |
| Ragged Page Attention | ✅ | ✅ | ✅ | 🚧 |
| Recurrent | ✅ | 🚧 | ✅ | 🚧 |
| Mean Pooling | ✅ | 🚧 | ✅ | 🚧 |
| Grouped MatMul | 🚧 | ✅ | ✅ | 🚧 |

✅ = Implemented and optimized, 🚧 = Under development, ❌ = Not yet implemented
## Design Principles

- Convention over Configuration: sensible defaults everywhere, with optional overrides
- Progressive Disclosure: a simple API for common cases, with advanced features available when needed
- Fail Gracefully: multiple fallback layers ensure reliability
- Optimize Lazily: cache results and autotune on demand
- Type Everything: static and runtime validation for correctness
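A memoized autotuner is one way to make "Optimize Lazily" concrete. Candidates are benchmarked only the first time a key is requested, for example an operation plus its input shapes, and the winner is reused afterwards. This is a hypothetical sketch, not ejKernel's autotuner: `make_lazy_autotuner`, `fake_kernel`, and the candidate dictionaries are illustrative. It assumes that wall-clock timing after one warm-up call is an acceptable cost measure.

```python
import time
from typing import Callable, Hashable, Sequence


def make_lazy_autotuner(run: Callable[[Hashable, dict], object],
                        candidates: Sequence[dict], repeats: int = 3):
    """Benchmark candidate configs the first time a key is seen, then reuse the winner."""
    cache: dict = {}

    def cost(key: Hashable, config: dict) -> float:
        run(key, config)  # warm-up, so one-time costs such as JIT compilation are not timed
        start = time.perf_counter()
        for _ in range(repeats):
            run(key, config)  # with JAX, run() should call .block_until_ready() on its output
        return time.perf_counter() - start

    def best_config(key: Hashable) -> dict:
        if key not in cache:  # benchmark only on the first request for this key
            cache[key] = min(candidates, key=lambda cfg: cost(key, cfg))
        return cache[key]

    return best_config


def fake_kernel(key, config):
    # Hypothetical stand-in workload: pretend larger blocks run faster.
    time.sleep(0.001 * 128 / config["block"])


tune = make_lazy_autotuner(fake_kernel, [{"block": 64}, {"block": 128}])
print(tune(("flash_attn", 1024, 64)))  # benchmarks both candidates once
print(tune(("flash_attn", 1024, 64)))  # served from the cache
```

The `.block_until_ready()` note matters because JAX dispatches work asynchronously. Without it, the timer would measure only dispatch time.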
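## Usage Example

```python
import jax

from ejkernel.modules import FlashAttention, create_default_executor

# Initialize the executor and the attention operation
executor = create_default_executor()
attention = FlashAttention()

# Create inputs with layout (batch, seq_len, num_heads, head_dim),
# using independent random keys for q, k and v
batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (batch, seq_len, num_heads, head_dim))
k = jax.random.normal(key_k, (batch, seq_len, num_heads, head_dim))
v = jax.random.normal(key_v, (batch, seq_len, num_heads, head_dim))

# Execute attention with automatic optimization
output = executor(
    attention, q, k, v,
    causal=True,
    dropout_prob=0.1,
    logits_soft_cap=30.0,  # Gemma-2 style soft capping
)
```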
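Gemma-2 style soft capping squashes logits with `cap * tanh(logits / cap)`. This bounds their magnitude by `cap` while leaving small values almost unchanged. The NumPy sketch below assumes that `logits_soft_cap` follows this standard formulation and shows where the transform sits relative to the softmax. `soft_cap` and `capped_attention_weights` are hypothetical helpers for illustration only.

```python
import numpy as np


def soft_cap(logits, cap):
    """Smoothly bound logits in magnitude by cap; roughly the identity when |logits| << cap."""
    return cap * np.tanh(logits / cap)


def capped_attention_weights(q, k, cap):
    """Scaled dot-product scores, soft-capped before the softmax (masking omitted)."""
    scores = soft_cap(q @ k.T / np.sqrt(q.shape[-1]), cap)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return weights / weights.sum(axis=-1, keepdims=True)
```

In a causal setting, the mask would be applied after capping and before the softmax.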
## Environment Variables

Alternative values are separated by `|`:

```
# Autotuning
EJKERNEL_AUTOTUNE_POLICY=autotune|heuristics   # Default: autotune
EJKERNEL_LOG_AUTOTUNE=1                        # Log candidate testing

# Profiling
EJKERNEL_OPS_STAMP=hash|json|none              # Default: none
EJKERNEL_OPS_RECORD=1                          # Enable invocation recording

# Triton-specific
EJKERNEL_TRITON_SMEM_LIMIT=101376              # Shared memory limit in bytes (99 KiB)
```
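The sketch below shows how the documented variables and defaults could be read and validated in Python. `read_ejkernel_settings` is a hypothetical helper, not ejKernel's own parsing code. Three behaviors are assumptions of this sketch: lowercasing values, rejecting unknown choices with `ValueError`, and treating any flag value other than `1` as "off".

```python
import os
from typing import Mapping, Optional

_POLICIES = ("autotune", "heuristics")
_STAMPS = ("hash", "json", "none")


def _choice(env: Mapping[str, str], name: str, allowed: tuple, default: str) -> str:
    value = env.get(name, default).strip().lower()
    if value not in allowed:
        raise ValueError(f"{name}={value!r}; expected one of {allowed}")
    return value


def read_ejkernel_settings(env: Optional[Mapping[str, str]] = None) -> dict:
    """Hypothetical helper: parse the documented EJKERNEL_* variables with their defaults."""
    env = os.environ if env is None else env
    smem = env.get("EJKERNEL_TRITON_SMEM_LIMIT")
    return {
        "autotune_policy": _choice(env, "EJKERNEL_AUTOTUNE_POLICY", _POLICIES, "autotune"),
        "log_autotune": env.get("EJKERNEL_LOG_AUTOTUNE") == "1",
        "ops_stamp": _choice(env, "EJKERNEL_OPS_STAMP", _STAMPS, "none"),
        "ops_record": env.get("EJKERNEL_OPS_RECORD") == "1",
        "triton_smem_limit_bytes": int(smem) if smem else None,  # None: no explicit limit set
    }


print(read_ejkernel_settings({"EJKERNEL_AUTOTUNE_POLICY": "heuristics"}))
```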
## Integration Points

The library integrates with:

- JAX Ecosystem: full compatibility with JAX transformations
- EasyDeL Framework: the parent framework for deep learning in JAX
- JAX Profiling Tools: XProf and TensorBoard integration
- Distributed Training: `shard_map` support for model parallelism
## Future Roadmap

### Near-term

- Flash Attention 3 implementation
- Flash Decoding for optimized inference
- Quantized attention (INT8/INT4)
- Fused LayerNorm + Attention kernels

### Long-term

- Sliding Window Attention with Sinks
- Differential Attention with learnable patterns
- Mixture of Attention mechanisms
- Speculative decoding kernels
- Continuous batching support
## Conclusion

ejKernel represents a production-quality implementation of high-performance machine learning kernels with a focus on:

- Modularity: clean separation of concerns
- Performance: state-of-the-art optimizations
- Reliability: multiple fallback mechanisms
- Extensibility: new features are easy to add
- Maintainability: well documented and tested

The project serves as a strong foundation for building high-performance deep learning systems in JAX.