ejkernel.modules - High-Level Operations API#

Overview#

The ejkernel.modules package provides high-level, user-friendly interfaces for attention mechanisms and other tensor operations. These modules abstract away the complexity of kernel selection, platform detection, and configuration management, allowing you to focus on your model architecture.

Key Concepts#

Module vs Function API#

Each operation is available in two forms:

  1. Module Class (e.g., FlashAttention): For advanced use cases where you need fine-grained control over execution, configuration, or integration with the executor system.

  2. Function API (e.g., flash_attention): For simple, direct usage with sensible defaults and automatic optimization.

# Function API - simple and direct
from ejkernel.modules import flash_attention

output = flash_attention(query, key, value, causal=True)

# Module API - for advanced control
from ejkernel.modules import FlashAttention
from ejkernel.ops import Executor, ConfigSelectorChain, ConfigCache

executor = Executor(ConfigSelectorChain(cache=ConfigCache()))
attn_module = FlashAttention()
output = executor(attn_module, query=query, key=key, value=value, causal=True)

Automatic Platform Selection#

All modules automatically detect your hardware (GPU, TPU, or CPU) and select the optimal kernel implementation:

  • GPU: Triton kernels (if available) or XLA fallback

  • TPU: Pallas kernels optimized for TPU matrix units

  • CPU: XLA-based implementation

You can override this with the platform parameter:

output = flash_attention(query, key, value, platform="xla")  # Force XLA
output = flash_attention(query, key, value, platform="triton")  # Force Triton
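
The defaults listed above amount to a small lookup with a fallback. The sketch below only illustrates that rule: `pick_platform` and its arguments are hypothetical names, not part of ejkernel, and the library's real dispatch may take more factors into account.

```python
def pick_platform(device_kind, triton_available=True, override=None):
    """Hypothetical helper mirroring the documented defaults; not ejkernel API."""
    if override is not None:
        return override  # an explicit platform= argument wins
    if device_kind == "gpu":
        # Triton kernels when available, otherwise the XLA fallback
        return "triton" if triton_available else "xla"
    if device_kind == "tpu":
        return "pallas"
    return "xla"  # CPU (and anything unrecognized) uses the XLA implementation


print(pick_platform("gpu"))                          # triton
print(pick_platform("gpu", triton_available=False))  # xla
print(pick_platform("tpu", override="xla"))          # xla
```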

Configuration and Autotuning#

Modules support automatic performance tuning. On first invocation with new input shapes, the system benchmarks multiple configurations and caches the optimal one:

# First call: benchmarks and caches optimal config
output = flash_attention(query, key, value)

# Subsequent calls with same shapes: uses cached config
output = flash_attention(query, key, value)  # Fast!

Attention Operations#

FlashAttention#

Memory-efficient exact attention: memory grows as O(N) in the sequence length N, rather than the O(N²) needed to materialize the full attention matrix.

from ejkernel.modules import flash_attention

# Basic usage
output = flash_attention(query, key, value)

# With causal masking (for autoregressive models)
output = flash_attention(query, key, value, causal=True)

# With sliding window (local attention)
output = flash_attention(query, key, value, sliding_window=256)

# With attention bias
output = flash_attention(query, key, value, bias=attention_bias)

# Variable-length sequences (packed batching)
output = flash_attention(
    query, key, value,
    cum_seqlens_q=cu_seqlens_q,  # [0, 128, 300, 512]
    cum_seqlens_k=cu_seqlens_k,
)

# With MaskInfo for complex masking
from ejkernel.types import MaskInfo
mask_info = MaskInfo.from_segments(segment_ids)
output = flash_attention(query, key, value, mask_info=mask_info)
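
The cumulative lengths used in the variable-length example above are running offsets into the packed token dimension, and segment IDs label each packed token with its source sequence. The helpers below are a plain-Python sketch of how such inputs can be derived from per-sequence lengths. The helper names are hypothetical, and the ID numbering (0-based, no padding marker) and the list return type are assumptions; check what the library expects before relying on them.

```python
from itertools import accumulate


def cumulative_seqlens(lengths):
    """Offsets of each packed sequence: [0, l0, l0 + l1, ...]."""
    return [0, *accumulate(lengths)]


def segment_ids(lengths):
    """One integer id per packed token, identifying its source sequence."""
    return [i for i, n in enumerate(lengths) for _ in range(n)]


# Three sequences of 128, 172 and 212 tokens packed into 512 positions
print(cumulative_seqlens([128, 172, 212]))  # [0, 128, 300, 512]
print(segment_ids([2, 1, 3]))               # [0, 0, 1, 2, 2, 2]
```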

Parameters:

  • query: Query tensor [batch, seq_len_q, num_heads, head_dim]

  • key: Key tensor [batch, seq_len_k, num_kv_heads, head_dim]

  • value: Value tensor [batch, seq_len_k, num_kv_heads, head_dim]

  • causal: Apply causal masking (default: False)

  • softmax_scale: Scaling factor (default: 1/sqrt(head_dim))

  • sliding_window: Window size for local attention, given as an int or a (left, right) tuple

  • dropout_prob: Dropout probability (default: 0.0)

  • bias: Optional attention bias tensor

  • mask_info: Optional MaskInfo for complex masking patterns

  • cum_seqlens_q, cum_seqlens_k: Cumulative sequence lengths for variable-length batching

  • logits_soft_cap: Soft cap for attention logits (Gemma-2 style)
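
As a reading aid for the parameters above, the following naive NumPy reference spells out one plausible interpretation of `causal`, `softmax_scale`, `sliding_window`, and `logits_soft_cap` for a single head. It is not ejkernel code and uses O(N²) memory. The inclusive window bounds, the top-left causal alignment, and applying the soft cap after scaling are assumptions made for illustration.

```python
import numpy as np


def reference_attention(q, k, v, causal=False, softmax_scale=None,
                        sliding_window=None, logits_soft_cap=None):
    """Naive single-head attention. q: [Tq, D]; k, v: [Tk, D]; assumes Tq <= Tk."""
    scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** -0.5
    logits = (q @ k.T) * scale
    if logits_soft_cap is not None:
        # Smoothly bound logits to (-cap, cap)
        logits = logits_soft_cap * np.tanh(logits / logits_soft_cap)
    qi = np.arange(q.shape[0])[:, None]
    kj = np.arange(k.shape[0])[None, :]
    allowed = np.ones(logits.shape, dtype=bool)
    if causal:
        allowed &= kj <= qi
    if sliding_window is not None:
        if isinstance(sliding_window, int):
            left = right = sliding_window
        else:
            left, right = sliding_window
        allowed &= (kj >= qi - left) & (kj <= qi + right)
    logits = np.where(allowed, logits, -np.inf)
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v
```

A reference like this is mainly useful for checking a fused kernel's output on small inputs with `np.allclose`.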


RingAttention#

Distributed attention that uses a ring communication topology to handle ultra-long sequences across multiple devices.

from ejkernel.modules import ring_attention
import jax

# Setup device mesh for distributed execution
devices = jax.devices()
mesh = jax.sharding.Mesh(devices, axis_names=('sp',))

# Execute ring attention
with mesh:
    output = ring_attention(
        query, key, value,
        axis_name='sp',  # Required: name of sequence-parallel axis
        causal=True,
    )

When to use:

  • Sequences too long to fit in a single device's memory

  • Multi-device training with sequence parallelism

  • Context lengths > 128K tokens

Parameters:

  • axis_name: Name of the axis for collective operations (required)

  • chunk_size: Optional chunk size for chunked causal attention

  • sliding_window: Window size for local attention

  • softmax_aux: Attention sink logits for long-context stability
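
The idea behind ring attention can be simulated on one host: each "device" keeps its query chunk, receives one key/value chunk per step, and merges partial results with a running maximum and denominator so the final output matches full softmax attention. The NumPy sketch below covers the non-causal case only and does not use ejkernel or JAX collectives; `ring_attention_sim` is a hypothetical name.

```python
import numpy as np


def ring_attention_sim(q_chunks, k_chunks, v_chunks):
    """Simulate non-causal ring attention; one list entry per 'device'."""
    n = len(q_chunks)
    outs = []
    for dev, q in enumerate(q_chunks):
        scale = q.shape[-1] ** -0.5
        m = np.full((q.shape[0], 1), -np.inf)                # running row maximum
        l = np.zeros((q.shape[0], 1))                        # running softmax denominator
        acc = np.zeros((q.shape[0], v_chunks[0].shape[-1]))  # unnormalized output
        for step in range(n):
            src = (dev - step) % n  # chunk that has arrived after `step` rotations
            s = (q @ k_chunks[src].T) * scale
            m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
            p = np.exp(s - m_new)
            correction = np.exp(m - m_new)  # rescale earlier partial sums
            l = l * correction + p.sum(axis=-1, keepdims=True)
            acc = acc * correction + p @ v_chunks[src]
            m = m_new
        outs.append(acc / l)
    return np.concatenate(outs, axis=0)
```

Because only one key/value chunk is resident per step, per-device memory scales with the chunk length rather than the full sequence length. A causal variant would additionally mask or skip chunks that lie entirely in the future.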


NativeSparseAttention#

Sparse attention with explicit block index specification for maximum control over sparsity patterns.

from ejkernel.modules import native_sparse_attention

# With explicit block indices
output = native_sparse_attention(
    query, key, value,
    block_indices=block_indices,  # Which key blocks each query attends to
    block_counts=16,  # Number of key blocks per query
)

# With variable-length sequences
output = native_sparse_attention(
    query, key, value,
    cu_seqlens=cu_seqlens,
    block_indices=block_indices,
)

Parameters:

  • block_indices: Which key blocks each query block attends to

  • block_counts: Number of key blocks per query block

  • cu_seqlens: Cumulative sequence lengths for variable-length batching


PageAttention#

Paged KV-cache attention optimized for inference serving with dynamic batching.

from ejkernel.modules import page_attention

output = page_attention(
    query,           # [batch, 1, num_heads, head_dim] for decode
    key_cache,       # [num_blocks, block_size, num_kv_heads, head_dim]
    value_cache,     # [num_blocks, block_size, num_kv_heads, head_dim]
    context_lens,    # [batch] - length of context for each sequence
    block_tables,    # [batch, max_blocks] - page table mapping
)

When to use:

  • Inference serving with continuous batching

  • Memory-efficient KV-cache management

  • vLLM-style serving systems
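
The page table indirection can be illustrated by rebuilding one sequence's contiguous keys (or values) from its pages. The NumPy helper below is a sketch with a hypothetical name, assuming the cache layout shown in the comments above; a real paged kernel reads the pages in place instead of materializing a copy.

```python
import numpy as np


def gather_paged_kv(cache, block_table, context_len):
    """Rebuild one sequence's K or V from its pages.

    cache:       [num_blocks, block_size, num_kv_heads, head_dim]
    block_table: [max_blocks] physical block ids for this sequence
    context_len: number of valid tokens in this sequence
    """
    block_size = cache.shape[1]
    num_used = -(-context_len // block_size)  # ceiling division
    pages = cache[block_table[:num_used]]     # [num_used, block_size, H, D]
    # Flatten pages into a token axis and drop the unused tail of the last page
    return pages.reshape(-1, *cache.shape[2:])[:context_len]


cache = np.arange(8).reshape(4, 2, 1, 1)      # 4 blocks of 2 tokens each
tokens = gather_paged_kv(cache, np.array([3, 0, 1]), context_len=3)
print(tokens.ravel().tolist())                # [6, 7, 0]
```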


RaggedPageAttention (v2 and v3)#

Advanced paged attention supporting variable-length sequences without padding.

from ejkernel.modules import ragged_page_attention_v3

output = ragged_page_attention_v3(
    query,
    key_pages,
    value_pages,
    cu_seqlens_q=cu_q,
    cu_seqlens_k=cu_k,
    page_indices=page_indices,
)

v3 improvements over v2:

  • Better memory efficiency

  • Improved softmax numerics

  • Support for attention sinks


ScaledDotProductAttention#

Standard scaled dot-product attention using XLA primitives. A good baseline and fallback.

from ejkernel.modules import scaled_dot_product_attention

output = scaled_dot_product_attention(
    query, key, value,
    attention_mask=mask,
    causal=True,
)

BlockSparseAttention#

Block-sparse attention with configurable sparsity patterns.

from ejkernel.modules import blocksparse_attention

output = blocksparse_attention(
    query, key, value,
    block_mask=block_mask,  # Which blocks can attend to which
    block_size=64,
)
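
As an illustration of what a block-level mask encodes, the sketch below builds a causal local pattern over blocks and expands it to token resolution. The function names are hypothetical, and the boolean [query_blocks, key_blocks] layout is an assumption; the layout and dtype that blocksparse_attention expects for block_mask may differ.

```python
import numpy as np


def local_block_mask(num_blocks, window_blocks):
    """Causal local pattern: query block i sees key blocks i - window_blocks .. i."""
    i = np.arange(num_blocks)[:, None]
    j = np.arange(num_blocks)[None, :]
    return (j <= i) & (j >= i - window_blocks)


def expand_to_tokens(block_mask, block_size):
    """Repeat each block entry into a block_size x block_size tile of token pairs."""
    return np.repeat(np.repeat(block_mask, block_size, axis=0), block_size, axis=1)


block_mask = local_block_mask(num_blocks=4, window_blocks=1)
token_mask = expand_to_tokens(block_mask, block_size=64)
print(block_mask.sum(), token_mask.shape)  # 7 (256, 256)
```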

GLAttention (Gated Linear Attention)#

Linear attention with gating for O(N) complexity.

from ejkernel.modules import gla_attention

output = gla_attention(
    query, key, value,
    gate=gate_values,  # Gating tensor
)
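
The O(N) cost comes from replacing the N×N attention matrix with a fixed-size state that is decayed by the gate and updated once per token. The recurrence below is a generic NumPy sketch of gated linear attention for one head, not ejkernel's kernel. It assumes multiplicative gates in (0, 1] with one value per key dimension; some implementations take the gate in log space instead, so check the convention before comparing results.

```python
import numpy as np


def gla_recurrent_reference(q, k, v, gate):
    """q, k, gate: [T, Dk]; v: [T, Dv]. Returns [T, Dv]."""
    state = np.zeros((k.shape[-1], v.shape[-1]))  # fixed-size memory, independent of T
    outputs = []
    for t in range(q.shape[0]):
        # Decay the old state per key dimension, then add the new key/value outer product
        state = gate[t][:, None] * state + np.outer(k[t], v[t])
        outputs.append(q[t] @ state)
    return np.stack(outputs)
```

With all gates equal to 1 this reduces to plain causal linear attention, which is a convenient sanity check.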

LightningAttention#

Layer-aware attention with decay factors for efficient long-range modeling.

from ejkernel.modules import lightning_attention

output = lightning_attention(
    query, key, value,
    decay=decay_factor,
)

RecurrentAttention#

Stateful recurrent attention for streaming inference.

from ejkernel.modules import recurrent_attention

output, new_state = recurrent_attention(
    query, key, value,
    state=previous_state,  # Carries information from previous steps
)

Other Operations#

GroupedMatmul#

Efficient batched matrix multiplication with variable group sizes.

from ejkernel.modules import grouped_matmul

output = grouped_matmul(
    inputs,     # [total_tokens, hidden]
    weights,    # [num_experts, hidden, output]
    group_ids,  # [total_tokens] - which expert each token uses
)

When to use:

  • Mixture of Experts (MoE) layers

  • Token routing to different experts

  • Batched operations with variable sizes
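
For reference, the computation is equivalent to routing each token through its own expert's weight matrix. The loop below is a slow NumPy sketch with a hypothetical name, using the shapes from the comments above; it is handy for validating a fused kernel on small inputs.

```python
import numpy as np


def grouped_matmul_reference(inputs, weights, group_ids):
    """inputs: [T, hidden]; weights: [E, hidden, out]; group_ids: [T] -> [T, out]."""
    out = np.zeros((inputs.shape[0], weights.shape[-1]), dtype=inputs.dtype)
    for expert in range(weights.shape[0]):
        rows = group_ids == expert  # tokens routed to this expert
        out[rows] = inputs[rows] @ weights[expert]
    return out
```

Fused implementations often expect tokens to be sorted by expert so each group is contiguous; whether that applies here is not stated above, so treat it as something to verify.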


MeanPooling#

Sequence mean pooling with variable-length support.

from ejkernel.modules import mean_pooling

# With attention mask
pooled = mean_pooling(
    hidden_states,  # [batch, seq_len, hidden]
    attention_mask,  # [batch, seq_len]
)

# Result: [batch, hidden] - mean of valid tokens per sequence
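
The masked mean can be written out in a few lines of NumPy, which makes the treatment of padding explicit. This is a reference sketch with a hypothetical name, not the library kernel; the guard against fully padded rows is an assumption.

```python
import numpy as np


def mean_pooling_reference(hidden_states, attention_mask):
    """hidden_states: [B, T, H]; attention_mask: [B, T] with 1 for valid tokens."""
    mask = attention_mask[..., None].astype(hidden_states.dtype)  # [B, T, 1]
    summed = (hidden_states * mask).sum(axis=1)                   # padding contributes 0
    counts = np.maximum(mask.sum(axis=1), 1)  # assumption: avoid dividing by zero
    return summed / counts


hidden = np.array([[[1.0, 1.0], [3.0, 3.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
print(mean_pooling_reference(hidden, mask))  # [[2. 2.]]
```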

Distributed Execution#

All operations support distributed execution via JAX’s shard_map:

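from ejkernel.modules import flash_attention
from jax.sharding import Mesh, PartitionSpec as P
import jax
import numpy as np

# Setup mesh: jax.devices() returns a list, so wrap it in an array before reshaping
devices = np.array(jax.devices())
mesh = Mesh(devices.reshape(2, 4), axis_names=('dp', 'tp'))  # assumes 8 devices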

# Distributed attention with sharding specs
output = flash_attention(
    query, key, value,
    mesh=mesh,
    in_specs=(
        P('dp', None, 'tp', None),  # query sharding
        P('dp', None, 'tp', None),  # key sharding
        P('dp', None, 'tp', None),  # value sharding
    ),
    out_specs=P('dp', None, 'tp', None),
)
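
To see what a sharding spec implies for each device, the plain-Python helper below computes the local shard shape from a global shape, a per-dimension axis assignment, and the mesh axis sizes. The helper name and the example sizes are hypothetical, and the spec is written as a plain tuple rather than a PartitionSpec so the sketch runs without JAX.

```python
def local_shard_shape(global_shape, spec, mesh_shape):
    """spec holds one mesh-axis name (or None for replicated) per array dimension."""
    local = []
    for dim, axis in zip(global_shape, spec):
        parts = 1 if axis is None else mesh_shape[axis]
        if dim % parts:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        local.append(dim // parts)
    return tuple(local)


# [batch, seq_len, num_heads, head_dim] on a 2 x 4 ('dp', 'tp') mesh
shape = local_shard_shape(
    (8, 2048, 32, 128), ("dp", None, "tp", None), {"dp": 2, "tp": 4}
)
print(shape)  # (4, 2048, 8, 128)
```

With the specs shown above, batch is split across 'dp' and heads across 'tp', so each device runs attention on full-length sequences for its own subset of heads.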

Configuration Classes#

Each operation has a corresponding configuration class:

from ejkernel.modules import FlashAttentionConfig
from ejkernel.ops import FwdParams, BwdParams

# Create custom configuration
config = FlashAttentionConfig(
    fwd_params=FwdParams(
        q_blocksize=128,
        kv_blocksize=256,
        num_warps=4,
        num_stages=2,
    ),
    bwd_params=BwdParams(
        q_blocksize=64,
        kv_blocksize=128,
    ),
    platform="triton",
    backend="gpu",
)

# Use custom config
output = flash_attention(query, key, value, cfg=config)

Common configuration parameters:

  • platform: "triton", "pallas", "xla", "cuda", or "auto"

  • backend: "gpu", "tpu", or "any"

  • fwd_params: Forward pass block sizes and tuning parameters

  • bwd_params: Backward pass parameters


Best Practices#

1. Let autotuning work for you#

The first call for each new input shape may be slower because of autotuning; later calls with the same shapes reuse the cached configuration and run at full speed.

2. Use MaskInfo for complex masking#

Instead of manually creating attention masks, use MaskInfo for segment-aware attention:

from ejkernel.types import MaskInfo

# From segment IDs (packed sequences)
mask_info = MaskInfo.from_segments(segment_ids)

# From attention mask
mask_info = MaskInfo.from_attention_mask(attention_mask)

# Use with attention
output = flash_attention(query, key, value, mask_info=mask_info)
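
For intuition, segment-aware masking means a token may attend only to tokens from the same packed sequence. The dense NumPy mask below illustrates that rule; it is a sketch with a hypothetical name and does not show how MaskInfo stores the information internally or how it treats padding IDs.

```python
import numpy as np


def segment_mask(segment_ids):
    """[T] segment ids -> [T, T] boolean mask, True where attention is allowed."""
    ids = np.asarray(segment_ids)
    return ids[:, None] == ids[None, :]


# Two packed sequences: tokens 0-1 belong to segment 0, token 2 to segment 1
print(segment_mask([0, 0, 1]).astype(int))
```

Building such a mask by hand costs O(N²) memory, which is one reason to hand the segment IDs to MaskInfo instead.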

3. Choose the right attention variant#

| Use Case | Recommended |
|---|---|
| Standard training | flash_attention |
| Inference serving | page_attention or ragged_page_attention_v3 |
| Very long sequences (>128K) | ring_attention |
| Sparse patterns | native_sparse_attention or blocksparse_attention |
| Streaming inference | recurrent_attention |
| MoE layers | grouped_matmul |

4. Profile before optimizing#

Use JAX profiling to identify bottlenecks before manually tuning configurations.