ejkernel.types - Masking and Type Utilities
Overview
The ejkernel.types package provides data structures and utilities for managing attention masks, segment IDs, and sequence metadata. The centerpiece is MaskInfo, a unified container that handles various mask representations and provides seamless conversion between them.
Core Concept: Mask Representations
Attention mechanisms use masks to control which query positions can attend to which key positions. ejkernel supports multiple representations:
| Representation | Shape | Description |
|---|---|---|
| Attention Mask | (batch, 1, q_len, kv_len) | Pairwise mask: True = can attend |
| Segment IDs | (batch, q_len) for queries, (batch, kv_len) for keys/values | Group tokens: same ID = can attend |
| Cumulative Lengths | | Start/end positions for variable-length batching |
MaskInfo automatically converts between these representations as needed.
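As an illustration of one such conversion, the NumPy sketch below expands segment IDs into a pairwise boolean mask. It assumes the segment-ID rule used throughout this page (equal IDs attend to each other, -1 marks padding) and returns a (batch, q_len, kv_len) array rather than the 4D layout MaskInfo stores. The helper segments_to_mask is hypothetical and is not ejkernel's implementation.

```python
import numpy as np


def segments_to_mask(q_segment_ids, kv_segment_ids=None):
    """Expand segment IDs into a (batch, q_len, kv_len) boolean mask.

    Assumed rule: query i may attend to key j only if both tokens carry
    the same segment ID and neither is padding (-1).
    """
    q = np.asarray(q_segment_ids)
    kv = q if kv_segment_ids is None else np.asarray(kv_segment_ids)  # self-attention default
    q3 = q[:, :, None]    # (B, Q, 1)
    kv3 = kv[:, None, :]  # (B, 1, K)
    return (q3 == kv3) & (q3 >= 0) & (kv3 >= 0)


# Packed self-attention: two segments plus one padding token
self_mask = segments_to_mask([[1, 1, 2, 2, -1]])
print(self_mask[0].astype(int))

# Cross-attention: 4 query tokens against 6 key/value tokens
cross_mask = segments_to_mask([[1, 1, 2, 2]], [[1, 1, 1, 2, 2, 2]])
print(cross_mask.shape)  # (1, 4, 6)
```

The padding row and column come out all False, so a padded query attends to nothing under this rule.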
MaskInfo
The MaskInfo dataclass holds mask information in whatever form it was created from and computes the other representations lazily, the first time they are requested.
Creating MaskInfo
From Segment IDs
Segment IDs group tokens that can attend to each other. Use non-negative integers for valid tokens and -1 for padding.
from ejkernel.types import MaskInfo
import jax.numpy as jnp
# Single segment: all valid tokens attend to each other
segment_ids = jnp.array([[0, 0, 0, 0, -1, -1]]) # 4 valid tokens, 2 padding
mask_info = MaskInfo.from_segments(segment_ids)
# Multiple segments (packed sequences)
segment_ids = jnp.array([[1, 1, 2, 2, 3, 3, 3, -1]]) # 3 sequences packed
mask_info = MaskInfo.from_segments(segment_ids)
# Cross-attention (different Q and KV)
q_segment_ids = jnp.array([[1, 1, 2, 2]]) # query segments
kv_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2]]) # key-value segments
mask_info = MaskInfo.from_segments(q_segment_ids, kv_segment_ids)
From Attention Mask
To wrap an existing attention mask, pass it in any of the supported ranks (2D, 3D, or 4D):
# 2D padding mask (batch, seqlen) - common for transformers
padding_mask = jnp.array([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]]) # 1=valid, 0=padding
mask_info = MaskInfo.from_attention_mask(padding_mask)
# 3D pairwise mask (batch, qlen, kvlen)
pairwise_mask = jnp.tril(jnp.ones((1, 8, 8), dtype=jnp.bool_)) # causal
mask_info = MaskInfo.from_attention_mask(pairwise_mask)
# 4D multi-head mask (batch, heads, qlen, kvlen)
mh_mask = jnp.ones((2, 8, 64, 64), dtype=jnp.bool_)
mask_info = MaskInfo.from_attention_mask(mh_mask)
# PyTorch-style inverted mask (True = masked out)
inverted_mask = ~jnp.tril(jnp.ones((1, 8, 8), dtype=jnp.bool_))
mask_info = MaskInfo.from_attention_mask(inverted_mask, mask_is_valid=False)
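The rank handling above can be pictured with the following NumPy sketch, which normalizes a 2D, 3D, or 4D mask to a 4D boolean array where True means "may attend". It assumes that a 2D padding mask becomes pairwise by requiring both the query and the key token to be valid, and that mask_is_valid=False simply inverts the input; ejkernel may treat padded query rows differently. The function normalize_mask is hypothetical.

```python
import numpy as np


def normalize_mask(mask, mask_is_valid=True):
    """Return a 4D boolean mask (batch, heads_or_1, q_len, kv_len), True = may attend."""
    m = np.asarray(mask).astype(bool)
    if not mask_is_valid:  # inverted convention: True marked masked-out positions
        m = ~m
    if m.ndim == 2:  # (B, S) padding mask -> (B, S, S) pairwise
        m = m[:, :, None] & m[:, None, :]
    if m.ndim == 3:  # (B, Q, K) -> add a broadcastable head axis
        m = m[:, None, :, :]
    if m.ndim != 4:
        raise ValueError(f"expected a 2D, 3D, or 4D mask, got {m.ndim}D")
    return m


padding = normalize_mask([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]])
causal = np.tril(np.ones((1, 8, 8), dtype=bool))
inverted = normalize_mask(~causal, mask_is_valid=False)
print(padding.shape, inverted.shape)  # (2, 1, 5, 5) (1, 1, 8, 8)
```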
From Cumulative Sequence Lengths
For variable-length batching (FlashAttention-style):
# cu_seqlens format: [start_0, end_0, start_1, end_1, ...]
# Example: 3 sequences with lengths [3, 5, 4], one per batch row, each padded to max_q_len=5
cu_seqlens_q = jnp.array([0, 3, 0, 5, 0, 4]) # start/end pairs
mask_info = MaskInfo.from_cu_seqlens(cu_seqlens_q, max_q_len=5)
# Cross-attention with different Q/KV lengths
cu_seqlens_kv = jnp.array([0, 6, 0, 8, 0, 7])
mask_info = MaskInfo.from_cu_seqlens(
cu_seqlens_q, max_q_len=5,
cu_seqlens_kv=cu_seqlens_kv, max_kv_len=8
)
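Taking the start/end-pair format in the comments above at face value, each pair marks the half-open range [start, end) of real tokens in one padded batch row. The NumPy sketch below, built around the hypothetical helper cu_pairs_to_valid, turns such pairs into per-row validity masks and combines the query and key/value sides into a pairwise mask. Treat it as an illustration of that layout, not as ejkernel's conversion code.

```python
import numpy as np


def cu_pairs_to_valid(cu_seqlens, max_len):
    """[start_0, end_0, start_1, end_1, ...] -> (batch, max_len) boolean validity mask."""
    pairs = np.asarray(cu_seqlens).reshape(-1, 2)  # (B, 2): one (start, end) pair per row
    pos = np.arange(max_len)[None, :]              # (1, max_len)
    return (pos >= pairs[:, :1]) & (pos < pairs[:, 1:])


q_valid = cu_pairs_to_valid([0, 3, 0, 5, 0, 4], max_len=5)
kv_valid = cu_pairs_to_valid([0, 6, 0, 8, 0, 7], max_len=8)
print(q_valid.sum(axis=1))  # per-row query lengths

# Pairwise (B, Q, K) mask: attend only between valid query and valid key positions
pairwise = q_valid[:, :, None] & kv_valid[:, None, :]
print(pairwise.shape)  # (3, 5, 8)
```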
Random Mask (for Testing)
# Generate random attention pattern
mask_info = MaskInfo.from_random(
batch_size=2,
q_len=128,
kv_len=256, # None for self-attention
sparsity=0.5, # 50% masked out
seed=42
)
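For intuition about the sparsity argument, here is a NumPy sketch of a comparable generator: each entry is kept independently when a uniform draw is at least sparsity, so roughly that fraction of the mask ends up masked out. The function random_mask is hypothetical, and it will not reproduce the patterns MaskInfo.from_random produces for the same seed.

```python
import numpy as np


def random_mask(batch_size, q_len, kv_len=None, sparsity=0.5, seed=0):
    """Random (batch, 1, q_len, kv_len) boolean mask with about `sparsity` of entries False."""
    kv_len = q_len if kv_len is None else kv_len  # None means self-attention
    rng = np.random.default_rng(seed)
    # Each entry survives with probability 1 - sparsity.
    return rng.random((batch_size, 1, q_len, kv_len)) >= sparsity


mask = random_mask(batch_size=2, q_len=128, kv_len=256, sparsity=0.5, seed=42)
print(mask.shape, round(1.0 - mask.mean(), 2))  # masked fraction close to 0.5
```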
Dynamic Initialization
For model code that may receive the mask in different forms, dynamic_init builds a MaskInfo from whichever input is available:
# From input_ids (creates all-ones mask)
mask_info = MaskInfo.dynamic_init(input_ids=input_ids)
# From attention_mask (handles 2D/3D/4D)
mask_info = MaskInfo.dynamic_init(attention_mask=attention_mask)
# Pass through existing MaskInfo
mask_info = MaskInfo.dynamic_init(mask_info=existing_mask_info)
Accessing Mask Properties
Properties are computed lazily on first access:
mask_info = MaskInfo.from_segments(jnp.array([[1, 1, 2, 2, -1]]))
# Shape information
mask_info.batch_size # 1
mask_info.q_len # 5
mask_info.kv_len # 5
mask_info.shape # (1, 5, 5)
# Mask representations
mask_info.attention_mask # (B, 1, Q, K) boolean mask
mask_info.q_segment_ids # (B, Q) segment IDs
mask_info.kv_segment_ids # (B, K) segment IDs
mask_info.cu_seqlens_q # Cumulative lengths for Q
mask_info.cu_seqlens_kv # Cumulative lengths for KV
# Derived values
mask_info.q_lens # Per-batch query lengths
mask_info.kv_lens # Per-batch key-value lengths
mask_info.q_attention_mask # 2D query validity mask
mask_info.q_position_ids # Position IDs computed from mask
# Attention bias (for score computation)
mask_info.bias # 0.0 for valid, -inf for masked
mask_info.create_bias(dtype=jnp.float16) # Custom dtype
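Three of the derived values above can be sketched in a few lines of NumPy: per-batch lengths, position IDs, and an additive bias that is 0.0 where attention is allowed and -inf where it is masked (added to the scores before softmax). The restart-at-each-segment convention for position IDs is an assumption, since this page does not say how q_position_ids is computed, and the helper names are hypothetical.

```python
import numpy as np


def mask_to_bias(mask, dtype=np.float32):
    """Additive attention bias: 0.0 where allowed, -inf where masked."""
    return np.where(mask, 0.0, -np.inf).astype(dtype)


def positions_from_segments(segment_ids):
    """Position IDs restarting at 0 for each new segment; padding (-1) gets 0."""
    seg = np.asarray(segment_ids)
    pos = np.zeros_like(seg)
    for b in range(seg.shape[0]):
        for t in range(1, seg.shape[1]):
            if seg[b, t] >= 0 and seg[b, t] == seg[b, t - 1]:
                pos[b, t] = pos[b, t - 1] + 1
    return pos


seg = np.array([[1, 1, 2, 2, -1]])
q_lens = (seg >= 0).sum(axis=1)       # valid (non-padding) tokens per batch element
print(q_lens)                         # [4]
print(positions_from_segments(seg))   # [[0 1 0 1 0]]
print(mask_to_bias(np.array([True, False]), dtype=np.float16))
```

A row that is masked everywhere would produce all -inf scores and a NaN softmax, which is one reason some implementations prefer a large finite negative value instead.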
Applying Mask Transformations
MaskInfo provides chainable methods for modifying attention patterns:
Causal Masking
# Standard causal (each position attends to itself and earlier)
causal_mask = mask_info.apply_causal()
# With offset (allow attending to future positions)
causal_mask = mask_info.apply_causal(offset=5)
# Per-batch offsets
offsets = jnp.array([0, 2, 4]) # one offset per batch element
causal_mask = mask_info.apply_causal(offset=offsets)
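One plausible reading of offset, sketched in NumPy below: key j stays visible to query i when j <= i + offset, so offset=0 is the standard causal mask and a positive offset also exposes that many future keys. A per-batch offset array broadcasts to one mask per batch element. These semantics and the helper causal_mask are assumptions for illustration, and the sketch builds only the causal pattern itself without combining it with an existing mask.

```python
import numpy as np


def causal_mask(q_len, kv_len, offset=0):
    """(batch_or_1, 1, q_len, kv_len) mask: key j is visible to query i iff j <= i + offset.

    `offset` may be a scalar or an array with one entry per batch element.
    """
    offset = np.atleast_1d(np.asarray(offset))[:, None, None, None]  # (B or 1, 1, 1, 1)
    q = np.arange(q_len)[:, None]   # (Q, 1) query positions
    k = np.arange(kv_len)[None, :]  # (1, K) key positions
    return k <= q + offset          # broadcasts to (B or 1, 1, Q, K)


print(causal_mask(4, 4)[0, 0].astype(int))
per_batch = causal_mask(4, 4, offset=np.array([0, 2, 4]))
print(per_batch.shape)  # (3, 1, 4, 4)
```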
Sliding Window
# Symmetric window (attend to 256 positions left and right)
windowed = mask_info.apply_sliding_window(256)
# Asymmetric window (512 left, 0 right = causal local)
windowed = mask_info.apply_sliding_window((512, 0))
# Decode mode: slice mask for single-token generation
decode_mask = mask_info.apply_sliding_window(
256,
mode="decode",
index=current_position
)
# Prefill mode: slice to last N positions
prefill_mask = mask_info.apply_sliding_window(
256,
mode="prefill"
)
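The window arithmetic can be sketched as a band around the diagonal: with a (left, right) pair, query i sees keys i - left through i + right, and a single integer uses the same width on both sides. Inclusive bounds are an assumption here (libraries differ on whether the window size counts the current token), and sliding_window_mask is a hypothetical helper. The decode-style slice at the end only approximates what mode="decode" keeps, under the assumption that it selects the row for the current position.

```python
import numpy as np


def sliding_window_mask(q_len, kv_len, window):
    """(q_len, kv_len) band mask: key j is visible to query i iff i - left <= j <= i + right."""
    left, right = (window, window) if isinstance(window, int) else window
    q = np.arange(q_len)[:, None]
    k = np.arange(kv_len)[None, :]
    return (k >= q - left) & (k <= q + right)


local_causal = sliding_window_mask(6, 6, (2, 0))  # 2 keys to the left, none to the right
symmetric = sliding_window_mask(6, 6, 1)          # 1 key on each side
print(local_causal.astype(int))

# Decode-style slice: only the row for the token currently being generated
current_position = 5
decode_row = local_causal[current_position : current_position + 1]
print(decode_row.astype(int))  # [[0 0 0 1 1 1]]
```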
Chunked Attention
# Split attention into fixed-size chunks with causal ordering
chunked = mask_info.apply_chunked(chunk_size=128)
# Each query only attends within its chunk, with causal constraint
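Read literally, the rule above combines two conditions, which the following NumPy sketch expresses directly: the query and key must fall in the same chunk (equal position // chunk_size), and the key must not come after the query. The helper chunked_causal_mask is hypothetical and assumes chunks are aligned to position 0.

```python
import numpy as np


def chunked_causal_mask(seq_len, chunk_size):
    """(seq_len, seq_len) mask: query i sees key j iff both share a chunk and j <= i."""
    pos = np.arange(seq_len)
    same_chunk = (pos[:, None] // chunk_size) == (pos[None, :] // chunk_size)
    causal = pos[None, :] <= pos[:, None]
    return same_chunk & causal


print(chunked_causal_mask(6, chunk_size=3).astype(int))
```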
Token Type IDs
For segment-aware attention driven by token type IDs (for example, the token_type_ids produced by BERT-style tokenizers):
# Same token types can attend to each other
token_types = jnp.array([[1, 1, 2, 2, 0, 0]]) # 0 = padding
typed_mask = mask_info.apply_token_type_ids(token_types)
# Different modes for combining with existing mask
typed_mask = mask_info.apply_token_type_ids(
token_types,
combine="intersect", # AND with existing mask
zero_policy="both" # treat 0 as padding on both Q and KV
)
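As an illustration of the combine="intersect" and zero_policy="both" settings used above, the NumPy sketch below allows attention only between equal, non-zero token types and then ANDs the result with an existing mask. Other combine modes and zero policies are not modeled, and token_type_mask is a hypothetical helper rather than ejkernel's code.

```python
import numpy as np


def token_type_mask(token_types, existing=None):
    """(batch, Q, K) mask: equal, non-zero token types attend; 0 is padding on both sides."""
    t = np.asarray(token_types)
    q, k = t[:, :, None], t[:, None, :]
    mask = (q == k) & (q != 0) & (k != 0)
    # combine="intersect": keep a position only if the existing mask also allows it
    return mask if existing is None else mask & existing


types = np.array([[1, 1, 2, 2, 0, 0]])
typed = token_type_mask(types)
causal = np.tril(np.ones((6, 6), dtype=bool))
typed_causal = token_type_mask(types, existing=causal)
print(typed_causal[0].astype(int))
```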
KV Length Limiting
For inference where each batch element has its own valid context length:
# Mask out KV positions beyond per-example lengths
kv_lengths = jnp.array([100, 150, 80]) # valid KV length per batch element
limited = mask_info.apply_kv_lengths(kv_lengths)
# With query windowing
limited = mask_info.apply_kv_lengths(
kv_lengths,
q_len=1, # keep only 1 query position
end_index=current_idx, # window ends at this position
sliding_window=256 # limit KV to last 256 positions
)
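The basic length limit amounts to hiding every key position at or beyond each example's valid length, as in this NumPy sketch. It covers only the first call above; the variant with q_len, end_index, and sliding_window is not modeled, and kv_length_mask is a hypothetical helper.

```python
import numpy as np


def kv_length_mask(kv_lengths, q_len, kv_len):
    """(batch, 1, q_len, kv_len) mask that hides keys at positions >= kv_lengths[b]."""
    lens = np.asarray(kv_lengths)[:, None, None, None]  # (B, 1, 1, 1)
    k = np.arange(kv_len)[None, None, None, :]          # (1, 1, 1, K)
    return np.broadcast_to(k < lens, (lens.shape[0], 1, q_len, kv_len))


mask = kv_length_mask([100, 150, 80], q_len=4, kv_len=160)
print(mask[:, 0, 0].sum(axis=-1))  # visible keys per batch element
```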
Conversion and Materialization
# Force computation of specific representations
mask_info = mask_info.materialize_attention_mask()
mask_info = mask_info.materialize_segment_ids()
# Convert attention mask dtype
mask_info = mask_info.to_dtype(jnp.float16)
# Get Q/KV masks separately
q_mask, kv_mask, attention_mask = mask_info.get_qkv_masks()
# Compute or retrieve representations
q_pos, kv_pos = mask_info.get_or_compute_positions()
cu_q, cu_kv = mask_info.get_or_compute_qkv_cu_seqlens()
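Distributed Sharding
MaskInfo integrates with JAX’s sharding system:
import jax
import numpy as np
from jax.sharding import Mesh
# jax.devices() returns a list; convert it to an array to lay 8 devices out as a 2x4 grid
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=('dp', 'tp'))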
# Get sharding specs for all mask components
shardings = mask_info.get_shardings(mesh=mesh)
# With sequence parallelism
shardings = mask_info.get_shardings(
mesh=mesh,
sequence_parallel=True
)
# Customize axis names
mask_info = MaskInfo.from_segments(
segment_ids,
batch_axis_name=('dp', 'fsdp'),
sequence_axis_name='sp',
qheads_axis_name='tp'
)
Visualization
Debug attention patterns with ASCII/Unicode visualization:
mask_info.visualize() # Print to console
mask_info.visualize(
batch=0,
head=0,
block_size=32, # Aggregate into blocks
charset="unicode", # or "ascii"
show_segments=True, # Show segment IDs
return_str=True # Return string instead of printing
)
Output:
Attention mask (batch=0, head=0) block=(32x32) mask_shape=(1, 1, 128, 128)
01 01 02 02
============
01 ||████.. ||
01 ||████.. ||
02 || ████||
02 || ████||
============
Legend mask: full='██'/##, partial='░░'/.., empty=' '
Legend seg: left=Q block ID, top=KV block ID, PAD=-1, MIX=??
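The block aggregation behind block_size can be approximated as follows: split the 2D mask into block_size x block_size tiles and label each tile full, partial, or empty according to the fraction of allowed entries, using the ASCII glyphs from the legend. This NumPy sketch (block_summary is hypothetical) omits the segment labels and headers that visualize() prints, and assumes tiles that do not fit evenly are dropped.

```python
import numpy as np


def block_summary(mask_2d, block_size):
    """Label each block_size x block_size tile as '##' (full), '..' (partial), or '  ' (empty)."""
    m = np.asarray(mask_2d, dtype=bool)
    nq, nk = m.shape[0] // block_size, m.shape[1] // block_size
    tiles = m[: nq * block_size, : nk * block_size].reshape(nq, block_size, nk, block_size)
    frac = tiles.mean(axis=(1, 3))  # fraction of allowed entries per tile
    return np.where(frac == 1.0, "##", np.where(frac > 0.0, "..", "  "))


# Two packed segments of 64 tokens each, causal within each segment
seg = np.repeat([1, 2], 64)
mask = (seg[:, None] == seg[None, :]) & np.tril(np.ones((128, 128), dtype=bool))
for row in block_summary(mask, 32):
    print("||" + "".join(row) + "||")
```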
JAX Pytree Support
MaskInfo is registered as a JAX pytree, so it can be passed directly through JAX transformations:
import jax
# Works with jit
@jax.jit
def process(mask_info, x):
    return x * mask_info.q_attention_mask[..., None]
# Works with vmap
vmapped = jax.vmap(process)
# Works with grad: differentiate with respect to x while mask_info is passed through unchanged
grads = jax.grad(lambda x: process(mask_info, x).sum())(x)
Utility Functions
Debug Mode
Enable verbose tracing for debugging:
from ejkernel.types import set_debug_mode, get_debug_mode
set_debug_mode(True) # Enable debug output
# ... operations will print trace info ...
set_debug_mode(False) # Disable
if get_debug_mode():
print("Debug mode is on")
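Best Practices
1. Choose the Right Constructor
| Situation | Constructor |
|---|---|
| Packed sequences with segment labels | MaskInfo.from_segments |
| Existing attention mask | MaskInfo.from_attention_mask |
| Variable-length batching | MaskInfo.from_cu_seqlens |
| Model forward with flexible input | MaskInfo.dynamic_init |
| Testing | MaskInfo.from_random |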
2. Lazy Computation
MaskInfo computes representations lazily. Access only what you need:
# Good: Only segment IDs computed
segment_q, segment_kv = mask_info.get_or_compute_segment_ids()
# Bad: Forces full mask materialization when not needed
full_mask = mask_info.attention_mask # Avoid if not needed
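The lazy-then-cached behavior described here follows a familiar Python pattern, sketched below with functools.cached_property on a toy class: the pairwise mask is built only when first read, and later reads reuse the stored result. LazyMask is hypothetical. MaskInfo is a JAX pytree whose materialize_* methods return a new instance, so its actual caching mechanism likely differs.

```python
import functools

import numpy as np


class LazyMask:
    """Toy container: keeps segment IDs and builds the pairwise mask on first access."""

    def __init__(self, segment_ids):
        self.segment_ids = np.asarray(segment_ids)
        self.build_count = 0  # how many times the pairwise mask has been built

    @functools.cached_property
    def attention_mask(self):
        self.build_count += 1
        s = self.segment_ids
        # Equal, non-padding segment IDs may attend to each other
        return (s[:, :, None] == s[:, None, :]) & (s[:, :, None] >= 0)


lazy = LazyMask([[1, 1, 2, -1]])
print(lazy.build_count)  # 0: nothing computed yet
first = lazy.attention_mask
second = lazy.attention_mask
print(lazy.build_count)  # 1: the second read hit the cache
```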
3. Chain Transformations
# Chain methods for complex patterns
mask_info = (
MaskInfo.from_segments(segment_ids)
.apply_causal()
.apply_sliding_window(256)
)
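Chaining of this kind is commonly built by having each method return a new object rather than mutating the old one. The toy immutable dataclass below illustrates that design in NumPy, assuming every transformation intersects (logical AND) with the current mask; that matches the combine="intersect" behavior described for token types but is an assumption for the other methods. ToyMask and its methods are hypothetical stand-ins, not MaskInfo.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class ToyMask:
    """Immutable (Q, K) boolean mask; each apply_* returns a new, narrower ToyMask."""

    mask: np.ndarray

    def _positions(self):
        q_len, kv_len = self.mask.shape
        return np.arange(q_len)[:, None], np.arange(kv_len)[None, :]

    def apply_causal(self):
        q, k = self._positions()
        return replace(self, mask=self.mask & (k <= q))

    def apply_sliding_window(self, left):
        q, k = self._positions()
        # Left-only window: keep keys no more than `left` positions behind the query
        return replace(self, mask=self.mask & (k >= q - left))


seg = np.array([1, 1, 1, 2, 2, 2])
base = ToyMask(seg[:, None] == seg[None, :])
chained = base.apply_causal().apply_sliding_window(1)
print(chained.mask.astype(int))
```

Because base is never modified, it can be reused as the starting point for other chains.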
4. Reuse MaskInfo
MaskInfo caches computed values, so build it once and reuse the same instance across layers:
# Good: build once, reuse in every layer
mask_info = MaskInfo.from_segments(ids)
for layer in layers:
    x = layer(x, mask_info=mask_info)
# Bad: rebuilds the mask (starting from an empty cache) in every layer
for layer in layers:
    x = layer(x, mask_info=MaskInfo.from_segments(ids))