# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Sequence
import jax
import jax.numpy as jnp
import triton
import triton.language as tl
from jaxtyping import Array, Bool, Float, Int
from ejkernel.callib import ejit
from ejkernel.utils import get_strides
# Segment id used as the fill value when `pad_to_block_size` extends `segment_ids`
# to a multiple of the block size.
PADDING_SEGMENT_ID = -1
@triton.jit
def padded_load(
    ptrs,
    offs_a,
    offs_b,
    PA0: tl.constexpr,
    PA1: tl.constexpr,
    LA0: tl.constexpr,
    LA1: tl.constexpr,
):
    """Load a 2-D tile, bounds-checking only the axes flagged at compile time.

    The four flag combinations each issue a single ``tl.load``; an axis whose
    flag is false contributes no mask term at all.

    Args:
        ptrs: Pointers of the tile to load.
        offs_a: Offsets along the first axis (broadcast as a column).
        offs_b: Offsets along the second axis (broadcast as a row).
        PA0: Whether the first axis must be checked against ``LA0``.
        PA1: Whether the second axis must be checked against ``LA1``.
        LA0: Valid length of the first axis.
        LA1: Valid length of the second axis.

    Returns:
        The loaded tile; masked-out elements are filled with ``0.0``.
    """
    if PA1:
        cols_in_range = offs_b[None, :] < LA1
        if PA0:
            rows_in_range = offs_a[:, None] < LA0
            tile = tl.load(ptrs, mask=rows_in_range & cols_in_range, other=0.0)
        else:
            tile = tl.load(ptrs, mask=cols_in_range, other=0.0)
    else:
        if PA0:
            rows_in_range = offs_a[:, None] < LA0
            tile = tl.load(ptrs, mask=rows_in_range, other=0.0)
        else:
            # Neither axis needs a check: plain unmasked load.
            tile = tl.load(ptrs)
    return tile
def calc_bias_strides(
    bias: Float[Array, "batch num_heads seq_len_q seq_len_k"] | None,
    batch: int,
    nheads_q: int,
    QSeq: int,
    KSeq: int,
) -> tuple[int, int, int]:
    """Calculate element strides for a bias tensor with broadcasting support.

    Validates the bias shape and returns strides (in elements, not bytes) for
    the batch, head and query-row dimensions. A batch or head dimension of
    size 1 is broadcast by reporting a stride of 0.

    Args:
        bias: Optional bias tensor with shape [batch, heads, QSeq, KSeq].
        batch: Expected batch size.
        nheads_q: Number of query attention heads.
        QSeq: Query sequence length.
        KSeq: Key sequence length.

    Returns:
        tuple: ``(stride_bz, stride_bh, stride_bm)``; all zeros when ``bias``
        is ``None``.

    Raises:
        ValueError: If the bias is not 4-dimensional, or if any dimension is
            incompatible with the expected shape.
    """
    if bias is None:
        return 0, 0, 0

    # Reject wrong-rank inputs up front so callers get the documented
    # ValueError rather than an IndexError from the shape lookups below.
    if len(bias.shape) != 4:
        raise ValueError(
            f"Bias tensor must be 4-dimensional [batch, heads, {QSeq}, {KSeq}], "
            f"but got {len(bias.shape)} dimension(s). Full bias shape: {bias.shape}"
        )

    if hasattr(bias, "strides"):
        # Objects exposing `.strides` report bytes; convert to elements.
        elem_strides = tuple(s // bias.itemsize for s in bias.strides)
    else:
        # get_strides is treated as returning strides in elements (the original
        # code scaled its result to bytes and divided straight back).
        elem_strides = tuple(get_strides(bias))

    if bias.shape[2] != QSeq or bias.shape[3] != KSeq:
        raise ValueError(
            f"Bias tensor has incompatible sequence dimensions. "
            f"Expected shape [..., {QSeq}, {KSeq}], but got [..., {bias.shape[2]}, {bias.shape[3]}]. "
            f"Full bias shape: {bias.shape}"
        )

    if bias.shape[0] == 1:
        stride_bz = 0
    elif bias.shape[0] == batch:
        stride_bz = elem_strides[0]
    else:
        raise ValueError(
            f"Batch dimension mismatch in bias tensor. "
            f"Expected either 1 (for broadcasting) or {batch} (batch size), "
            f"but got {bias.shape[0]}. Consider reshaping your bias tensor."
        )

    if bias.shape[1] == 1:
        stride_bh = 0
    elif bias.shape[1] == nheads_q:
        stride_bh = elem_strides[1]
    else:
        raise ValueError(
            f"Head dimension mismatch in bias tensor. "
            f"Expected either 1 (for broadcasting) or {nheads_q} (number of heads), "
            f"but got {bias.shape[1]}. Check that your bias tensor matches the model configuration."
        )

    stride_bm = elem_strides[2]
    return stride_bz, stride_bh, stride_bm
@ejit(static_argnames=["max_tokens"])
def attention_pack_with_static_shape(
    x: Float[Array, "batch seq_len num_heads head_dim"],
    attention_mask: Bool[Array, "batch seq_len"],
    max_tokens: int | None = None,
) -> Float[Array, "1 max_tokens num_heads head_dim"]:
    """Pack a padded batch into one row, dropping padding tokens.

    Only the per-row *count* of set mask entries is used: the first
    ``sum(attention_mask[b])`` positions of row ``b`` are treated as the valid
    tokens and are written contiguously after the tokens of earlier rows.
    The output length is a static upper bound so the function stays
    JIT-compatible; unused trailing slots remain zero.

    Args:
        x: Padded tensor of shape [batch, seq_len, num_heads, head_dim].
        attention_mask: Per-token mask of shape [batch, seq_len].
        max_tokens: Static length of the packed axis. Defaults to
            ``batch * seq_len``.

    Returns:
        Packed tensor of shape [1, max_tokens, num_heads, head_dim].
    """
    batch_size, seqlen = attention_mask.shape
    num_heads, head_dim = x.shape[2], x.shape[3]
    if max_tokens is None:
        max_tokens = batch_size * seqlen

    seqlens = jnp.sum(attention_mask, axis=1).astype(jnp.int32)
    # Exclusive prefix sum: where each row's tokens start in the packed axis.
    offsets = jnp.cumsum(seqlens) - seqlens

    positions = jnp.arange(seqlen, dtype=jnp.int32)
    valid = positions[None, :] < seqlens[:, None]
    # Padding tokens are routed to index `max_tokens`, which is out of bounds
    # and therefore discarded by the "drop" scatter mode below.
    dst = jnp.where(valid, offsets[:, None] + positions[None, :], max_tokens)

    packed = jnp.zeros((1, max_tokens, num_heads, head_dim), dtype=x.dtype)
    tokens = x.reshape(batch_size * seqlen, num_heads, head_dim)
    # One scatter replaces the former per-token loop, which rewrote the whole
    # buffer on every iteration. Valid destinations are unique, so write order
    # does not matter.
    return packed.at[0, dst.reshape(-1)].set(tokens, mode="drop")
@triton.jit
def make_segment_mask(q_segment_ids, kv_segment_ids, transposed: tl.constexpr):
    """Build a mask that is true where the query and KV segment ids are equal.

    Args:
        q_segment_ids: Segment ids of the query tokens.
        kv_segment_ids: Segment ids of the KV tokens.
        transposed: When true the KV ids index the first axis of the result
            and the query ids the second; otherwise the reverse.

    Returns:
        2-D elementwise equality mask.
    """
    if transposed:
        same_segment = kv_segment_ids[:, None] == q_segment_ids[None, :]
    else:
        same_segment = kv_segment_ids[None, :] == q_segment_ids[:, None]
    return same_segment
@triton.jit
def make_causal_mask(q_positions, kv_positions, transposed: tl.constexpr):
    """Build a mask that is true where the KV position does not exceed the query position.

    Args:
        q_positions: Positions of the query tokens.
        kv_positions: Positions of the KV tokens.
        transposed: When true the KV positions index the first axis of the
            result and the query positions the second; otherwise the reverse.

    Returns:
        2-D mask, true where ``q_position >= kv_position``.
    """
    if transposed:
        not_future = kv_positions[:, None] <= q_positions[None, :]
    else:
        not_future = kv_positions[None, :] <= q_positions[:, None]
    return not_future
@triton.jit
def make_sliding_window_mask(
    q_positions,
    kv_positions,
    window_left: tl.constexpr,
    window_right: tl.constexpr,
    transposed: tl.constexpr,
):
    """Build a sliding-window attention mask from token positions.

    With ``delta = q_position - kv_position``, an entry is true when
    ``-window_right <= delta <= window_left``.

    Args:
        q_positions: Query token positions.
        kv_positions: KV token positions.
        window_left: Furthest distance into the past (kv before q) that is kept.
        window_right: Furthest distance into the future (kv after q) that is kept.
        transposed: When true the KV positions index the first axis of the
            result and the query positions the second; otherwise the reverse.

    Returns:
        Boolean mask where True means attend.
    """
    if transposed:
        delta = q_positions[None, :] - kv_positions[:, None]
    else:
        delta = q_positions[:, None] - kv_positions[None, :]
    within_past = delta <= window_left
    within_future = delta >= -window_right
    return within_future & within_past
def basic_attention_refrence(
    q: Float[Array, "batch seq_len_q num_heads head_dim"],
    k: Float[Array, "batch seq_len_k num_heads_kv head_dim"],
    v: Float[Array, "batch seq_len_k num_heads_kv head_dim"],
    attn_bias: Float[Array, "batch num_heads seq_len_q seq_len_k"] | None = None,
    query_padding_mask: Bool[Array, "batch seq_len_q"] | None = None,
    key_padding_mask: Bool[Array, "batch seq_len_k"] | None = None,
    dropout_prob: float = 0.0,
    dropout_key: jax.Array | None = None,
    window_size: tuple[int, int] = (-1, -1),
    causal: bool = False,
    softcap: float = 0.0,
) -> Float[Array, "batch seq_len_q num_heads head_dim"]:
    """Reference implementation of attention for testing and validation.

    Plain JAX scaled dot-product attention computed in float32 and cast back
    to the dtype of ``q``. KV heads are repeated when there are fewer of them
    than query heads.

    Args:
        q: Query tensor [batch, seq_len_q, num_heads, head_dim].
        k: Key tensor [batch, seq_len_k, num_heads_kv, head_dim].
        v: Value tensor [batch, seq_len_k, num_heads_kv, head_dim].
        attn_bias: Optional bias added to the scores after masking.
        query_padding_mask: Boolean mask, True for real query positions;
            outputs at False positions are zeroed.
        key_padding_mask: Boolean mask, True for real key positions; False
            positions are masked out of the softmax.
        dropout_prob: Dropout probability for attention weights. Applied only
            when ``dropout_key`` is also given.
        dropout_key: JAX random key for dropout.
        window_size: Local attention window ``(left, right)``; a negative
            entry means unbounded on that side. ``None`` is treated as
            ``(-1, -1)``.
        causal: Whether to apply causal masking (forces the right window to 0).
        softcap: Soft capping value for attention scores; disabled when not
            positive.

    Returns:
        Attention output with the same shape and dtype as ``q``.
    """
    # A missing window means "unbounded"; normalising it here keeps the causal
    # override below from indexing into None.
    if window_size is None:
        window_size = (-1, -1)
    if causal:
        window_size = (window_size[0], 0)

    dtype_og = q.dtype
    q, k, v = q.astype(jnp.float32), k.astype(jnp.float32), v.astype(jnp.float32)
    QSeq, KSeq = q.shape[1], k.shape[1]

    # Expand KV heads so every query head has a matching key/value head.
    repeats = q.shape[2] // k.shape[2]
    if repeats > 1:
        k = jnp.repeat(k, repeats=repeats, axis=2)
        v = jnp.repeat(v, repeats=repeats, axis=2)

    q_scaled = q / math.sqrt(q.shape[-1])
    scores = jnp.einsum("bthd,bshd->bhts", q_scaled, k)

    if softcap is not None and softcap > 0:
        scores = jnp.tanh(scores / softcap) * softcap

    masked_value = jnp.finfo(scores.dtype).min

    if key_padding_mask is not None:
        key_mask = (~key_padding_mask).reshape(key_padding_mask.shape[0], 1, 1, KSeq)
        scores = jnp.where(key_mask, masked_value, scores)

    local_mask = None
    if window_size[0] >= 0 or window_size[1] >= 0:
        row_idx = jnp.arange(QSeq).reshape(-1, 1)
        col_idx = jnp.arange(KSeq)
        # Per-batch lengths are reshaped to [batch, 1, 1, 1] so the resulting
        # mask broadcasts against the 4-D [batch, heads, q, k] scores. The
        # original 5-D reshape broadcast the scores to 5-D and broke the
        # output einsum.
        if key_padding_mask is None:
            sk = KSeq
        else:
            sk = jnp.sum(key_padding_mask, axis=-1).reshape(-1, 1, 1, 1)
        if query_padding_mask is None:
            sq = QSeq
        else:
            sq = jnp.sum(query_padding_mask, axis=-1).reshape(-1, 1, 1, 1)

        if window_size[0] < 0:
            # Unbounded on the left: only the right edge masks anything.
            local_mask = col_idx > row_idx + sk - sq + window_size[1]
        else:
            if key_padding_mask is None:
                sk_full = jnp.full_like(col_idx, KSeq)
            else:
                sk_full = sk
            local_mask = jnp.logical_or(
                col_idx > jnp.minimum(row_idx + sk - sq + window_size[1], sk_full),
                col_idx < row_idx + sk - sq - window_size[0],
            )
        scores = jnp.where(local_mask, masked_value, scores)

    if attn_bias is not None:
        scores = scores + attn_bias

    attention = jax.nn.softmax(scores, axis=-1).astype(v.dtype)

    if local_mask is not None:
        # Rows with no key inside the window would otherwise get a uniform
        # softmax; zero them instead.
        all_masked = jnp.all(local_mask, axis=-1, keepdims=True)
        attention = jnp.where(all_masked, 0.0, attention)

    if query_padding_mask is not None:
        query_mask = (~query_padding_mask).reshape(query_padding_mask.shape[0], 1, QSeq, 1)
        attention = jnp.where(query_mask, 0.0, attention)

    if dropout_prob > 0 and dropout_key is not None:
        # The scale is computed only when dropout is actually applied.
        dropout_scaling = 1.0 / (1 - dropout_prob)
        dropout_mask = jax.random.bernoulli(dropout_key, p=1 - dropout_prob, shape=attention.shape)
        attention_drop = attention * dropout_mask * dropout_scaling
    else:
        attention_drop = attention

    output = jnp.einsum("bhts,bshd->bthd", attention_drop, v)

    if query_padding_mask is not None:
        query_mask_expanded = (~query_padding_mask).reshape(
            query_padding_mask.shape[0],
            QSeq,
            1,
            1,
        )
        output = jnp.where(query_mask_expanded, 0.0, output)

    return output.astype(dtype_og)
@ejit(static_argnames=["max_tokens"])
def attention_pack_from_cu_static(
    x: Float[Array, "batch seq_max num_heads head_dim"],
    cum_seqlens: Int[Array, "batch_plus_one"],
    max_tokens: int | None = None,
) -> Float[Array, "1 max_tokens num_heads head_dim"]:
    """Pack a variable-length batch into a single [1, T, H, D] tensor.

    Row ``b`` contributes its first ``cum_seqlens[b + 1] - cum_seqlens[b]``
    tokens, written starting at packed index ``cum_seqlens[b]``. ``T`` can be
    any static upper bound (e.g. ``B * S_max``); slots that receive no token
    stay zero.

    Args:
        x: Padded tensor of shape [batch, seq_max, num_heads, head_dim].
        cum_seqlens: Cumulative sequence lengths with ``batch + 1`` entries.
        max_tokens: Static length of the packed axis. Defaults to
            ``batch * seq_max``.

    Returns:
        Packed tensor of shape [1, max_tokens, num_heads, head_dim].
    """
    B, S_max, H, D = x.shape
    if max_tokens is None:
        max_tokens = B * S_max

    starts = cum_seqlens[:B]
    lengths = cum_seqlens[1 : B + 1] - starts

    positions = jnp.arange(S_max)
    valid = positions[None, :] < lengths[:, None]
    # Padding tokens are routed to index `max_tokens`, which is out of bounds
    # and therefore discarded by the "drop" scatter mode below.
    dst = jnp.where(valid, starts[:, None] + positions[None, :], max_tokens)

    out = jnp.zeros((1, max_tokens, H, D), dtype=x.dtype)
    # One scatter replaces the nested per-token loops, each step of which
    # re-selected over the entire output buffer.
    return out.at[0, dst.reshape(-1)].set(x.reshape(B * S_max, H, D), mode="drop")
@ejit(static_argnames=["seqlen", "batch_size"])
def attention_unpack_with_static_shape(
    x: Float[Array, "1 max_tokens num_heads head_dim"],
    cum_seqlens: Int[Array, "batch_plus_one"],
    batch_size: int,
    seqlen: int,
) -> Float[Array, "batch seqlen num_heads head_dim"]:
    """Unpack a [1, T, H, D] tensor back into [B, seqlen, H, D].

    Row ``b`` of the result receives packed tokens
    ``cum_seqlens[b] .. cum_seqlens[b + 1] - 1``; positions past a row's
    length are zero. ``seqlen`` is a static padded maximum length.

    Args:
        x: Packed tensor of shape [1, max_tokens, num_heads, head_dim].
        cum_seqlens: Cumulative sequence lengths with ``batch_size + 1`` entries.
        batch_size: Static number of rows to produce.
        seqlen: Static padded length of each row.

    Returns:
        Padded tensor of shape [batch_size, seqlen, num_heads, head_dim].
    """
    starts = cum_seqlens[:batch_size]
    lengths = cum_seqlens[1 : batch_size + 1] - starts

    positions = jnp.arange(seqlen)
    valid = positions[None, :] < lengths[:, None]
    src = starts[:, None] + positions[None, :]

    # One clamped gather replaces the nested per-token loops. Clamping keeps
    # the reads for padding positions in bounds; those values are discarded
    # by the mask on the next line.
    gathered = jnp.take(x[0], src, axis=0, mode="clip")
    return jnp.where(valid[:, :, None, None], gathered, jnp.zeros_like(gathered))
def pad_to_block_size(
    inputs: Sequence[Array] | None,
    indexs: Array | None,
    segment_ids: Array | None,
    block_size: int,
    pos_fill_value: int,
    transposed_inputs: bool = False,
):
    """Pad the sequence axis of the given arrays up to a multiple of ``block_size``.

    The current sequence length is read from the first argument that is
    present, in the order ``indexs`` (axis 1), ``segment_ids`` (axis 1),
    ``inputs[0]`` (the padded axis). Arguments that are ``None`` are returned
    as ``None``. When no padding is needed, everything is returned unchanged.

    Args:
        inputs: Optional sequence of 4-D arrays, zero-padded along axis 1, or
            along axis 2 when ``transposed_inputs`` is true.
        indexs: Optional 2-D array padded along axis 1 with ``pos_fill_value``.
        segment_ids: Optional 2-D array padded along axis 1 with
            ``PADDING_SEGMENT_ID``.
        block_size: Block size the sequence length is rounded up to.
        pos_fill_value: Fill value for the padded tail of ``indexs``.
        transposed_inputs: Whether the sequence axis of ``inputs`` is axis 2
            instead of axis 1.

    Returns:
        tuple: ``(inputs, indexs, segment_ids)``; ``inputs`` is a list when
        padding was applied.
    """
    seq_axis = 2 if transposed_inputs else 1

    # `indexs` is optional, so take the length from whichever argument exists
    # rather than dereferencing it unconditionally.
    if indexs is not None:
        seq_len = indexs.shape[1]
    elif segment_ids is not None:
        seq_len = segment_ids.shape[1]
    elif inputs is not None and len(inputs) > 0:
        seq_len = inputs[0].shape[seq_axis]
    else:
        # Nothing to measure, hence nothing to pad.
        return inputs, indexs, segment_ids

    padded_seq_len = (seq_len + block_size - 1) // block_size * block_size
    pad_len = padded_seq_len - seq_len
    if pad_len <= 0:
        return inputs, indexs, segment_ids

    if inputs is not None:
        if transposed_inputs:
            inputs_axis = ((0, 0), (0, 0), (0, pad_len), (0, 0))
        else:
            inputs_axis = ((0, 0), (0, pad_len), (0, 0), (0, 0))
        inputs = [jnp.pad(e, inputs_axis) for e in inputs]
    if indexs is not None:
        indexs = jnp.pad(indexs, ((0, 0), (0, pad_len)), constant_values=pos_fill_value)
    if segment_ids is not None:
        segment_ids = jnp.pad(segment_ids, ((0, 0), (0, pad_len)), constant_values=PADDING_SEGMENT_ID)
    return inputs, indexs, segment_ids