Utilities and Helper Functions Analysis#
Overview#
The utilities and helper functions in ejKernel provide essential infrastructure for device management, type conversions, sequence handling, and various computational helpers. These utilities are organized across multiple modules to support the core kernel operations.
Module Organization#
ejkernel/
├── xla_utils/
│ └── utils.py # XLA-specific utilities
├── ops/utils/
│ ├── fingerprint.py # Device fingerprinting and hashing
│ └── datacarrier.py # Configuration data structures
└── callib/
└── __init__.py # Calibration utilities
XLA Utilities#
Sequence Length Handling#
def prepare_lens(cu_seqlens: Int[Array, "batch + 1"]) -> Int[Array, "batch"]:
"""
Convert cumulative sequence lengths to individual lengths.
Args:
cu_seqlens: Cumulative lengths [0, len1, len1+len2, ...]
Returns:
Individual lengths [len1, len2, ...]
Example:
>>> cu_seqlens = jnp.array([0, 5, 12, 18])
>>> prepare_lens(cu_seqlens)
Array([5, 7, 6], dtype=int32)
"""
return cu_seqlens[1:] - cu_seqlens[:-1]
def prepare_lens_from_mask(mask: Bool[Array, "batch seq_len"]) -> Int[Array, "batch"]:
"""
Extract sequence lengths from boolean mask.
Args:
mask: Boolean mask where True indicates valid positions
Returns:
Sequence lengths per batch element
Example:
>>> mask = jnp.array([[1, 1, 1, 0], [1, 1, 0, 0]])
>>> prepare_lens_from_mask(mask)
Array([3, 2], dtype=int32)
"""
return mask.sum(axis=-1, dtype=jnp.int32)
def prepare_cu_seqlens_from_mask(mask: Bool[Array, "batch seq_len"],
out_dtype: jnp.dtype = jnp.int32) -> Int[Array, "batch + 1"]:
"""
Convert boolean mask to cumulative sequence lengths.
Args:
mask: Boolean mask
out_dtype: Output data type
Returns:
Cumulative lengths with leading 0
Example:
>>> mask = jnp.array([[1, 1, 1, 0], [1, 1, 0, 0]])
>>> prepare_cu_seqlens_from_mask(mask)
Array([0, 3, 5], dtype=int32)
"""
lens = prepare_lens_from_mask(mask)
cumsum_lens = lens.cumsum(axis=0, dtype=out_dtype)
return jnp.pad(cumsum_lens, (1, 0))
Segment ID Utilities#
Segment IDs are an efficient representation for attention masking, particularly on TPUs:
def segment_ids_to_mask(
segment_ids: Int[Array, "batch seq_len"] | tuple[Int[Array, "..."], Int[Array, "..."]],
dtype: jnp.dtype = jnp.bool_,
return_separate_masks: bool = False
) -> Bool[Array, "batch seq_len_q seq_len_k"] | tuple:
"""
Convert segment IDs to attention mask.
Segment ID conventions:
- -1 or 0: padding token
- Positive integers: segment identifiers
- Same segment ID = can attend
Args:
segment_ids: Either single array for self-attention or
tuple (q_segments, kv_segments) for cross-attention
dtype: Output mask dtype
return_separate_masks: Return (q_mask, kv_mask, attention_mask)
Returns:
Attention mask or tuple of masks
Example:
>>> # Self-attention
>>> segment_ids = jnp.array([[1, 1, 2, 0], [1, 2, 0, 0]])
>>> mask = segment_ids_to_mask(segment_ids)
>>> # mask[0, 0, 1] = True (same segment 1)
>>> # mask[0, 0, 2] = False (different segments 1 vs 2)
>>> # Cross-attention
>>> q_segments = jnp.array([[1, 1, 2, 0]])
>>> kv_segments = jnp.array([[1, 2, 2, 0]])
>>> mask = segment_ids_to_mask((q_segments, kv_segments))
"""
if isinstance(segment_ids, tuple):
q_segment_ids, kv_segment_ids = segment_ids
# Create masks for valid positions
q_mask = (q_segment_ids > 0).astype(dtype)
kv_mask = (kv_segment_ids > 0).astype(dtype)
# Create attention mask
q_seg = q_segment_ids[:, :, None]
kv_seg = kv_segment_ids[:, None, :]
attention_mask = (q_seg == kv_seg) & (q_seg > 0) & (kv_seg > 0)
attention_mask = attention_mask.astype(dtype)
else:
# Self-attention case
q_mask = (segment_ids > 0).astype(dtype)
kv_mask = q_mask
seg_q = segment_ids[:, :, None]
seg_kv = segment_ids[:, None, :]
attention_mask = (seg_q == seg_kv) & (seg_q > 0) & (seg_kv > 0)
attention_mask = attention_mask.astype(dtype)
if return_separate_masks:
return q_mask, kv_mask, attention_mask
return attention_mask
Mask to Segment ID Conversion#
def mask_to_segment_ids(
mask: ndarray,
per_head: bool = False
) -> tuple[ndarray, ndarray]:
"""
JIT-friendly conversion from attention mask to segment IDs.
Algorithm:
1. Pack each row/column into bit representation
2. Find equal rows/columns via bit comparison
3. Assign contiguous segment IDs
4. Mark padding as -1
Args:
mask: Attention mask with shape:
- (Q, K): 2D mask
- (B, Q, K): Batched mask
- (B, H, Q, K): Per-head mask
per_head: Process each head independently
Returns:
(q_segment_ids, kv_segment_ids) with appropriate shapes
Example:
>>> mask = jnp.array([
... [[1, 1, 0, 0],
... [1, 1, 1, 0],
... [0, 1, 1, 1],
... [0, 0, 1, 1]]
... ])
>>> q_seg, kv_seg = mask_to_segment_ids(mask)
>>> # q_seg[0] = [1, 2, 3, 4] (unique patterns)
>>> # kv_seg[0] = [1, 2, 3, 4] (matching patterns)
"""
def pack_bits(arr):
"""Pack boolean array into bit representation"""
# Reshape to chunks of 8 for packbits
pad_len = (8 - arr.shape[-1] % 8) % 8
padded = jnp.pad(arr, ((0, 0),) * (arr.ndim - 1) + ((0, pad_len),))
packed = jnp.packbits(padded, axis=-1)
return packed
def find_unique_patterns(packed_rows):
"""Find unique bit patterns and assign IDs"""
n_rows = packed_rows.shape[0]
segment_ids = jnp.zeros(n_rows, dtype=jnp.int32)
current_id = 1
for i in range(n_rows):
# Check if this pattern has been seen
is_new = True
for j in range(i):
if jnp.all(packed_rows[i] == packed_rows[j]):
segment_ids = segment_ids.at[i].set(segment_ids[j])
is_new = False
break
if is_new:
# Check if it's padding (all zeros)
if jnp.any(packed_rows[i]):
segment_ids = segment_ids.at[i].set(current_id)
current_id += 1
else:
segment_ids = segment_ids.at[i].set(-1)
return segment_ids
# Handle different input shapes
if mask.ndim == 2:
# Simple 2D case
q_packed = pack_bits(mask)
kv_packed = pack_bits(mask.T)
q_segment_ids = find_unique_patterns(q_packed)
kv_segment_ids = find_unique_patterns(kv_packed)
elif mask.ndim == 3:
# Batched case
batch_size = mask.shape[0]
q_segment_ids = []
kv_segment_ids = []
for b in range(batch_size):
q_packed = pack_bits(mask[b])
kv_packed = pack_bits(mask[b].T)
q_segment_ids.append(find_unique_patterns(q_packed))
kv_segment_ids.append(find_unique_patterns(kv_packed))
q_segment_ids = jnp.stack(q_segment_ids)
kv_segment_ids = jnp.stack(kv_segment_ids)
elif mask.ndim == 4 and per_head:
# Per-head processing
batch_size, num_heads = mask.shape[:2]
q_segment_ids = []
kv_segment_ids = []
for b in range(batch_size):
q_head_ids = []
kv_head_ids = []
for h in range(num_heads):
q_packed = pack_bits(mask[b, h])
kv_packed = pack_bits(mask[b, h].T)
q_head_ids.append(find_unique_patterns(q_packed))
kv_head_ids.append(find_unique_patterns(kv_packed))
q_segment_ids.append(jnp.stack(q_head_ids))
kv_segment_ids.append(jnp.stack(kv_head_ids))
q_segment_ids = jnp.stack(q_segment_ids)
kv_segment_ids = jnp.stack(kv_segment_ids)
return q_segment_ids, kv_segment_ids
Type Conversion Utilities#
def identity_dtype_convert(dtype: jnp.dtype) -> Callable:
"""
Create identity function with custom gradient dtype conversion.
This is useful for mixed-precision training where forward pass
uses one dtype but gradients need to be in another.
Args:
dtype: Target dtype for gradients
Returns:
Identity function with dtype conversion in backward
Example:
>>> # Ensure gradients match query dtype in mixed precision
>>> convert_fn = identity_dtype_convert(query.dtype)
>>> key = convert_fn(key.astype(jnp.float16))
>>> # Forward: key stays in float16
>>> # Backward: gradients converted to query.dtype
"""
@jax.custom_vjp
def identity_fn(x):
return x
def identity_fn_fwd(x):
return x, None
def identity_fn_bwd(res, g):
return (g.astype(dtype),)
identity_fn.defvjp(identity_fn_fwd, identity_fn_bwd)
return identity_fn
Device Fingerprinting Utilities#
Device Identification#
def device_fingerprint(dev: Device | None = None) -> str:
"""
Generate unique device identifier for caching.
Args:
dev: JAX device (defaults to first device)
Returns:
String identifier: 'device_kind|platform_version'
Examples:
'gpu|cuda_12.0'
'tpu|v4'
'cpu|'
"""
if dev is None:
dev = jax.devices()[0]
device_kind = dev.platform
platform_version = ""
if device_kind == "gpu":
# Get CUDA/ROCm version
try:
import subprocess
result = subprocess.run(
["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
capture_output=True,
text=True
)
if result.returncode == 0:
platform_version = f"cuda_{result.stdout.strip()}"
except:
platform_version = "cuda_unknown"
elif device_kind == "tpu":
# Get TPU generation
try:
tpu_type = dev.device_kind
if "v2" in tpu_type:
platform_version = "v2"
elif "v3" in tpu_type:
platform_version = "v3"
elif "v4" in tpu_type:
platform_version = "v4"
else:
platform_version = "unknown"
except:
platform_version = "unknown"
return f"{device_kind}|{platform_version}"
def get_device_platform(dev: Device | None = None) -> str:
"""
Get device platform type.
Returns:
'gpu', 'tpu', 'cpu', or 'unknown'
"""
if dev is None:
dev = jax.devices()[0]
platform = dev.platform.lower()
if "gpu" in platform or "cuda" in platform or "rocm" in platform:
return "gpu"
elif "tpu" in platform:
return "tpu"
elif "cpu" in platform:
return "cpu"
else:
return "unknown"
Hashing Utilities#
def abstractify(pytree: Any) -> Any:
"""
Convert arrays to ShapeDtypeStruct for cache keys.
This creates a lightweight representation that can be hashed
without accessing array data.
Args:
pytree: Nested structure with arrays
Returns:
Same structure with arrays replaced by ShapeDtypeStruct
Example:
>>> arr = jnp.ones((2, 3))
>>> abstract = abstractify(arr)
>>> # abstract = ShapeDtypeStruct((2, 3), float32)
"""
def _abstractify_leaf(x):
if isinstance(x, (jnp.ndarray, np.ndarray)):
return jax.ShapeDtypeStruct(x.shape, x.dtype)
return x
return jax.tree_map(_abstractify_leaf, pytree)
def short_hash(obj: Any) -> str:
"""
Generate 16-character hash for any object.
Args:
obj: Object to hash
Returns:
16-character hex string
Example:
>>> config = {"block_size": 128, "num_warps": 4}
>>> short_hash(config)
'a1b2c3d4e5f6g7h8'
"""
json_str = stable_json(obj)
full_hash = hashlib.sha256(json_str.encode()).hexdigest()
return full_hash[:16]
def stable_json(obj: Any) -> str:
"""
Create deterministic JSON representation.
Handles special types:
- Functions: (module, name, source_file)
- functools.partial: (func, args, kwargs)
- Dataclasses: converted to dict
- Pydantic models: model_dump()
- Arrays: (shape, dtype) only
Args:
obj: Object to serialize
Returns:
Deterministic JSON string
"""
def _serialize(o):
if callable(o):
# Serialize function reference
return {
"_type": "function",
"module": o.__module__,
"name": o.__name__,
"file": inspect.getsourcefile(o)
}
elif isinstance(o, functools.partial):
# Serialize partial application
return {
"_type": "partial",
"func": _serialize(o.func),
"args": [_serialize(arg) for arg in o.args],
"kwargs": {k: _serialize(v) for k, v in o.keywords.items()}
}
elif dataclasses.is_dataclass(o):
# Convert dataclass to dict
return {
"_type": "dataclass",
"class": o.__class__.__name__,
"data": dataclasses.asdict(o)
}
elif hasattr(o, 'model_dump'):
# Pydantic model
return {
"_type": "pydantic",
"class": o.__class__.__name__,
"data": o.model_dump()
}
elif isinstance(o, (jnp.ndarray, np.ndarray)):
# Array metadata only
return {
"_type": "array",
"shape": list(o.shape),
"dtype": str(o.dtype)
}
elif isinstance(o, jax.ShapeDtypeStruct):
# Already abstracted
return {
"_type": "shape_dtype",
"shape": list(o.shape),
"dtype": str(o.dtype)
}
elif isinstance(o, (list, tuple)):
return [_serialize(item) for item in o]
elif isinstance(o, dict):
return {k: _serialize(v) for k, v in sorted(o.items())}
else:
# Primitive types
return o
serialized = _serialize(obj)
return json.dumps(serialized, sort_keys=True, separators=(',', ':'))
Configuration Data Carriers#
Forward Pass Parameters#
@dataclass
class FwdParams:
"""
Forward pass configuration parameters for kernels.
These parameters control block sizes and thread configurations
for optimal performance on specific hardware.
"""
# Block dimensions for matrix operations
blocksize_m: int | None = None # M dimension block size
blocksize_k: int | None = None # K dimension block size
blocksize_n: int | None = None # N dimension block size
# Attention-specific block sizes
q_blocksize: int | None = None # Query block size
kv_blocksize: int | None = None # Key/Value block size
# Thread block configuration
blocksize_heads: int | None = None # Heads per block
blocksize_keys: int | None = None # Keys per block
num_key_splits: int | None = None # Key splitting factor
# GPU thread configuration
num_warps: int | None = None # Warps per block (GPU)
num_stages: int | None = None # Pipeline stages (GPU)
def __hash__(self):
"""Custom hash for use in caching"""
return hash(tuple(
getattr(self, field.name)
for field in dataclasses.fields(self)
))
def to_dict(self) -> dict:
"""Convert to dictionary, excluding None values"""
return {k: v for k, v in dataclasses.asdict(self).items()
if v is not None}
@classmethod
def merge(cls, base: "FwdParams", override: "FwdParams") -> "FwdParams":
"""Merge two parameter sets, override takes precedence"""
base_dict = base.to_dict() if base else {}
override_dict = override.to_dict() if override else {}
return cls(**{**base_dict, **override_dict})
Backward Pass Parameters#
@dataclass
class BwdParams:
"""
Backward pass configuration parameters.
Typically uses smaller block sizes than forward pass
for better memory efficiency during gradient computation.
"""
blocksize_m: int | None = None
blocksize_k: int | None = None
blocksize_n: int | None = None
q_blocksize: int | None = None
kv_blocksize: int | None = None
num_warps: int | None = None
num_stages: int | None = None
def __hash__(self):
"""Custom hash for caching"""
return hash(tuple(
getattr(self, field.name)
for field in dataclasses.fields(self)
))
@classmethod
def from_fwd_params(cls, fwd: FwdParams, scale: float = 0.5) -> "BwdParams":
"""
Create backward params from forward params.
Typically uses smaller blocks for memory efficiency.
"""
return cls(
q_blocksize=int((fwd.q_blocksize or 128) * scale),
kv_blocksize=int((fwd.kv_blocksize or 128) * scale),
num_warps=fwd.num_warps,
num_stages=fwd.num_stages
)
Calibration Utilities#
Warp Configuration Selection#
def select_num_warps(
q_block: int,
kv_block: int,
head_dim: int,
dtype: jnp.dtype = jnp.float16
) -> int:
"""
Select optimal number of warps for GPU kernels.
Args:
q_block: Query block size
kv_block: Key/value block size
head_dim: Head dimension
dtype: Data type
Returns:
Number of warps (4, 8, or 16)
Algorithm:
- Estimate compute intensity
- Balance parallelism vs register pressure
- Consider occupancy limits
"""
# Estimate operations per block
flops = q_block * kv_block * head_dim * 2 # Multiply-add
# Estimate data movement
bytes_moved = (q_block + kv_block) * head_dim * np.dtype(dtype).itemsize
# Compute intensity
intensity = flops / bytes_moved
# High intensity = more warps for compute
# Low intensity = fewer warps to reduce conflicts
if intensity > 10:
num_warps = 8
elif intensity > 5:
num_warps = 4
else:
num_warps = 2
# Adjust for block size
total_threads = num_warps * 32
if q_block * kv_block < 1024:
num_warps = min(num_warps, 4)
elif q_block * kv_block > 16384:
num_warps = max(num_warps, 8)
return num_warps
def select_num_stages(smem_bytes: int) -> int:
"""
Select pipeline stages for GPU kernels.
Args:
smem_bytes: Shared memory usage
Returns:
Number of pipeline stages (1, 2, or 3)
Note:
More stages enable better latency hiding but
require more shared memory and registers.
"""
smem_limit = get_shared_memory_limit()
# Available memory for additional stages
available = smem_limit - smem_bytes
if available > smem_bytes * 2:
return 3 # Can afford 3 stages
elif available > smem_bytes:
return 2 # Can afford 2 stages
else:
return 1 # Single stage only
Array Manipulation Helpers#
Chunking and Padding#
def chunk_array(
arr: Array,
chunk_size: int,
axis: int = 0,
pad_value: float = 0.0
) -> Array:
"""
Chunk array along axis with padding if needed.
Args:
arr: Input array
chunk_size: Size of chunks
axis: Axis to chunk along
pad_value: Value for padding
Returns:
Chunked array with shape [..., num_chunks, chunk_size, ...]
Example:
>>> x = jnp.arange(10)
>>> chunk_array(x, 4)
Array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 0, 0]], dtype=int32)
"""
shape = arr.shape
axis_len = shape[axis]
# Calculate padding needed
num_chunks = (axis_len + chunk_size - 1) // chunk_size
pad_len = num_chunks * chunk_size - axis_len
if pad_len > 0:
# Pad array
pad_width = [(0, 0)] * arr.ndim
pad_width[axis] = (0, pad_len)
arr = jnp.pad(arr, pad_width, constant_values=pad_value)
# Reshape to chunks
new_shape = list(shape)
new_shape[axis:axis+1] = [num_chunks, chunk_size]
return arr.reshape(new_shape)
def unchunk_array(
arr: Array,
original_length: int,
axis: int = 0
) -> Array:
"""
Reverse chunking operation.
Args:
arr: Chunked array
original_length: Original length before chunking
axis: Chunked axis
Returns:
Unchunked array with original length
Example:
>>> chunked = jnp.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 0, 0]])
>>> unchunk_array(chunked, 10)
Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
"""
# Flatten chunks
shape = list(arr.shape)
shape[axis:axis+2] = [-1]
arr = arr.reshape(shape)
# Trim to original length
indices = [slice(None)] * arr.ndim
indices[axis] = slice(original_length)
return arr[tuple(indices)]
Numerical Stability Helpers#
def log_sum_exp(x: Array, axis: int | None = None,
keepdims: bool = False) -> Array:
"""
Numerically stable log-sum-exp.
Args:
x: Input array
axis: Reduction axis
keepdims: Keep reduced dimensions
Returns:
log(sum(exp(x)))
Example:
>>> x = jnp.array([1000., 1001., 1002.])
>>> log_sum_exp(x) # Stable computation
1002.407
"""
x_max = jnp.max(x, axis=axis, keepdims=True)
exp_x = jnp.exp(x - x_max)
sum_exp = jnp.sum(exp_x, axis=axis, keepdims=keepdims)
lse = jnp.log(sum_exp)
if not keepdims and axis is not None:
x_max = jnp.squeeze(x_max, axis=axis)
return lse + x_max
def safe_divide(numerator: Array, denominator: Array,
min_denominator: float = 1e-10) -> Array:
"""
Safe division with small denominator handling.
Args:
numerator: Numerator array
denominator: Denominator array
min_denominator: Minimum denominator value
Returns:
numerator / max(denominator, min_denominator)
"""
safe_denom = jnp.maximum(denominator, min_denominator)
return numerator / safe_denom
Testing Utilities#
Array Comparison#
def assert_allclose(actual: Array, expected: Array,
rtol: float = 1e-5, atol: float = 1e-8,
equal_nan: bool = True):
"""
Assert arrays are close within tolerances.
Args:
actual: Actual array
expected: Expected array
rtol: Relative tolerance
atol: Absolute tolerance
equal_nan: Consider NaN values equal
Raises:
AssertionError: If arrays differ beyond tolerance
"""
if actual.shape != expected.shape:
raise AssertionError(f"Shape mismatch: {actual.shape} vs {expected.shape}")
if actual.dtype != expected.dtype:
print(f"Warning: dtype mismatch: {actual.dtype} vs {expected.dtype}")
close = jnp.allclose(actual, expected, rtol=rtol, atol=atol,
equal_nan=equal_nan)
if not close:
# Find maximum difference for debugging
diff = jnp.abs(actual - expected)
max_diff = jnp.max(diff)
max_idx = jnp.unravel_index(jnp.argmax(diff), diff.shape)
raise AssertionError(
f"Arrays not close\n"
f"Max difference: {max_diff} at index {max_idx}\n"
f"Actual: {actual[max_idx]}\n"
f"Expected: {expected[max_idx]}"
)
Mock Data Generation#
def generate_attention_inputs(
batch_size: int = 2,
seq_len: int = 128,
num_heads: int = 8,
head_dim: int = 64,
dtype: jnp.dtype = jnp.float32,
key: jax.random.PRNGKey | None = None
) -> dict[str, Array]:
"""
Generate random attention inputs for testing.
Returns:
Dictionary with query, key, value arrays
"""
if key is None:
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 3)
query = jax.random.normal(
keys[0], (batch_size, seq_len, num_heads, head_dim), dtype=dtype
) * 0.02
key_array = jax.random.normal(
keys[1], (batch_size, seq_len, num_heads, head_dim), dtype=dtype
) * 0.02
value = jax.random.normal(
keys[2], (batch_size, seq_len, num_heads, head_dim), dtype=dtype
) * 0.02
return {"query": query, "key": key_array, "value": value}
Performance Monitoring#
Timer Context Manager#
@contextmanager
def timer(name: str = "Operation"):
"""
Simple timer context manager.
Example:
>>> with timer("Flash Attention"):
... output = flash_attention(q, k, v)
Flash Attention took 0.123 seconds
"""
start = time.perf_counter()
try:
yield
finally:
elapsed = time.perf_counter() - start
print(f"{name} took {elapsed:.3f} seconds")
Memory Profiling#
def get_memory_usage() -> dict[str, int]:
"""
Get current memory usage on device.
Returns:
Dictionary with allocated and reserved memory in bytes
"""
device = jax.devices()[0]
if device.platform == "gpu":
# GPU memory tracking
stats = device.memory_stats()
return {
"allocated": stats.get("bytes_in_use", 0),
"reserved": stats.get("bytes_reserved", 0),
"limit": stats.get("bytes_limit", 0)
}
else:
# Basic memory info for other platforms
import psutil
process = psutil.Process()
return {
"allocated": process.memory_info().rss,
"reserved": 0,
"limit": psutil.virtual_memory().total
}
Conclusion#
The utilities and helper functions in ejKernel provide:
Device Management: Fingerprinting and platform detection
Data Conversion: Segment IDs, masks, and type conversions
Configuration: Parameter dataclasses with merging and hashing
Numerical Stability: Safe operations and stability helpers
Testing Support: Comparison, generation, and profiling utilities
These utilities form the foundation that enables the kernel system’s flexibility, performance, and reliability across different hardware platforms.