# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import jax
import jax.numpy as jnp
from jax import lax
from jax._src import dtypes
from ejkernel.callib import ejit
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
[docs]def cdiv(a, b):
assert b != 0
return (a + b - 1) // b
[docs]def get_dtype_bitwidth(dtype):
return dtypes.bit_width(dtype)
[docs]def get_dtype_packing(dtype):
return 32 // get_dtype_bitwidth(dtype)
[docs]def align_to(x, a):
return cdiv(x, a) * a
[docs]def merge_kv(k: jax.Array, v: jax.Array) -> jax.Array:
with jax.named_scope("rpa_v3_xla.merge_kv"):
assert k.shape == v.shape
assert k.dtype == v.dtype
max_num_tokens, actual_num_kv_heads, actual_head_dim = k.shape
kv_packing = get_dtype_packing(k.dtype)
actual_num_kv_heads_x2 = actual_num_kv_heads * 2
num_kv_heads_x2 = align_to(actual_num_kv_heads_x2, kv_packing)
head_dim = align_to(actual_head_dim, 128)
kv = jnp.pad(
jnp.concat([k, v], axis=-1).reshape(max_num_tokens, actual_num_kv_heads_x2, actual_head_dim),
(
(0, 0),
(0, num_kv_heads_x2 - actual_num_kv_heads_x2),
(0, head_dim - actual_head_dim),
),
constant_values=0,
).reshape(
max_num_tokens,
num_kv_heads_x2 // kv_packing,
kv_packing,
head_dim,
)
return kv
[docs]@ejit(
static_argnames=(
"softmax_scale",
"sliding_window",
"logits_soft_cap",
"mask_value",
"q_scale",
"k_scale",
"v_scale",
"chunk_prefill_size",
"num_kv_pages_per_block",
"num_queries_per_block",
"vmem_limit_bytes",
),
donate_argnums=(3,),
inline=True,
)
def ragged_paged_attention(
queries: jax.Array,
keys: jax.Array,
values: jax.Array,
kv_cache: jax.Array,
kv_lens: jax.Array,
block_tables: jax.Array,
query_start_loc: jax.Array,
distribution: jax.Array,
attention_sink: jax.Array | None = None,
*,
softmax_scale: float = 1.0,
sliding_window: int | None = None,
logits_soft_cap: float | None = None,
mask_value: float | None = DEFAULT_MASK_VALUE,
q_scale: float | None = None,
k_scale: float | None = None,
v_scale: float | None = None,
chunk_prefill_size: int | None = None,
num_kv_pages_per_block: int | None = None,
num_queries_per_block: int | None = None,
vmem_limit_bytes: int | None = None,
) -> tuple[jax.Array, jax.Array]:
del chunk_prefill_size, vmem_limit_bytes
if mask_value is None:
mask_value = DEFAULT_MASK_VALUE
with jax.named_scope("rpa_v3_xla.validate"):
static_validate_inputs(
queries,
keys,
values,
kv_cache,
kv_lens,
block_tables,
query_start_loc,
distribution,
softmax_scale=softmax_scale,
sliding_window=sliding_window,
logits_soft_cap=logits_soft_cap,
mask_value=mask_value,
q_scale=q_scale,
k_scale=k_scale,
v_scale=v_scale,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
)
with jax.named_scope("rpa_v3_xla.setup"):
actual_head_dim = queries.shape[2]
total_q = queries.shape[0]
actual_num_q_heads = queries.shape[1]
actual_num_kv_heads = keys.shape[1]
actual_num_q_heads_per_kv_head = actual_num_q_heads // actual_num_kv_heads
(
_total_num_pages,
page_size,
num_kv_heads_x2_per_kv_packing,
kv_packing,
head_dim_padded,
) = kv_cache.shape
num_kv_heads_x2 = num_kv_heads_x2_per_kv_packing * kv_packing
max_num_seqs = kv_lens.shape[0]
pages_per_seq = block_tables.shape[0] // max_num_seqs
tokens_per_seq = pages_per_seq * page_size
# Block sizes for a generic, jittable implementation.
qblocks = 8 if num_queries_per_block is None else int(num_queries_per_block)
if num_kv_pages_per_block is None:
# Larger kvblocks reduces Python/XLA loop overhead in `kv_loop`.
if pages_per_seq >= 256:
kvblocks = 256
elif pages_per_seq >= 128:
kvblocks = 128
elif pages_per_seq >= 64:
kvblocks = 64
else:
kvblocks = max(1, pages_per_seq)
else:
kvblocks = max(1, min(pages_per_seq, int(num_kv_pages_per_block)))
kv_tokens_per_block = kvblocks * page_size
# Pad Q/K/V so any qblocks-sized dynamic_slice is in-bounds.
# This may read across sequence boundaries for the final partial block, but
# masked writes ensure correctness.
pad_q = qblocks - 1
if pad_q:
queries = jnp.concatenate(
[queries, jnp.zeros((pad_q, actual_num_q_heads, actual_head_dim), queries.dtype)],
axis=0,
)
keys = jnp.concatenate(
[keys, jnp.zeros((pad_q, actual_num_kv_heads, actual_head_dim), keys.dtype)],
axis=0,
)
values = jnp.concatenate(
[values, jnp.zeros((pad_q, actual_num_kv_heads, actual_head_dim), values.dtype)],
axis=0,
)
padded_total_q = queries.shape[0]
q_grouped = queries.reshape(
padded_total_q,
actual_num_kv_heads,
actual_num_q_heads_per_kv_head,
actual_head_dim,
)
merged_kv = merge_kv(keys, values)
o_grouped = jnp.zeros_like(q_grouped)
arange_q = jnp.arange(qblocks, dtype=jnp.int32)
arange_kv = jnp.arange(kv_tokens_per_block, dtype=jnp.int32)
# Sliding-window KV start alignment; keep it simple and portable.
bkv_sz = page_size if sliding_window is not None else None
sinks_h = None
if attention_sink is not None:
sinks_h = attention_sink.reshape(actual_num_kv_heads, actual_num_q_heads_per_kv_head).astype(jnp.float32)
def _seq_body(seq_idx, carry):
o_acc, kv_cache_acc = carry
with jax.named_scope("rpa_v3_xla.seq_setup"):
q_start = query_start_loc[seq_idx]
q_end = query_start_loc[seq_idx + 1]
q_len = q_end - q_start
kv_len = kv_lens[seq_idx]
kv_start = jnp.int32(0)
if sliding_window is not None:
kv_start = jnp.maximum(kv_len - jnp.int32(sliding_window), 0)
kv_start = (kv_start // jnp.int32(bkv_sz)) * jnp.int32(bkv_sz)
write_start = kv_len - q_len
num_q_blocks = (q_len + qblocks - 1) // qblocks
with jax.named_scope("rpa_v3_xla.gather_pages"):
indices_start = seq_idx * pages_per_seq
page_indices = lax.dynamic_slice(block_tables, (indices_start,), (pages_per_seq,))
kv_pages = kv_cache_acc[page_indices]
kv_pages_flat = kv_pages.reshape(
tokens_per_seq,
num_kv_heads_x2_per_kv_packing,
kv_packing,
head_dim_padded,
)
def _update_kv_block(qb, kv_flat_pad):
with jax.named_scope("rpa_v3_xla.kv_update_block"):
q_off = qb * qblocks
dst = write_start + q_off
src = q_start + q_off
updates = lax.dynamic_slice(
merged_kv,
(src, 0, 0, 0),
(qblocks, num_kv_heads_x2_per_kv_packing, kv_packing, head_dim_padded),
)
existing = lax.dynamic_slice(
kv_flat_pad,
(dst, 0, 0, 0),
(qblocks, num_kv_heads_x2_per_kv_packing, kv_packing, head_dim_padded),
)
q_tok = q_off + arange_q
q_valid = q_tok < q_len
updates = jnp.where(q_valid[:, None, None, None], updates, existing)
return lax.dynamic_update_slice(kv_flat_pad, updates, (dst, 0, 0, 0))
with jax.named_scope("rpa_v3_xla.kv_update"):
kv_pages_flat_padded = jnp.concatenate(
[
kv_pages_flat,
jnp.zeros(
(
qblocks - 1,
num_kv_heads_x2_per_kv_packing,
kv_packing,
head_dim_padded,
),
dtype=kv_pages_flat.dtype,
),
],
axis=0,
)
kv_pages_flat_padded = lax.fori_loop(0, num_q_blocks, _update_kv_block, kv_pages_flat_padded)
kv_pages_flat = kv_pages_flat_padded[:tokens_per_seq]
kv_pages = kv_pages_flat.reshape(kv_pages.shape)
kv_cache_acc = kv_cache_acc.at[page_indices].set(kv_pages)
with jax.named_scope("rpa_v3_xla.attn_setup"):
# Pad pages axis to make kvblocks-sized slices safe.
kv_pages_padded = jnp.concatenate(
[
kv_pages,
jnp.zeros((kvblocks - 1, *kv_pages.shape[1:]), dtype=kv_pages.dtype),
],
axis=0,
)
num_kv_blocks = (kv_len + kv_tokens_per_block - 1) // kv_tokens_per_block
kv_block_start = kv_start // jnp.int32(kv_tokens_per_block)
def _process_query_block(qb, o_inner):
with jax.named_scope("rpa_v3_xla.q_block"):
q_off = qb * qblocks
q_global_start = q_start + q_off
q_block = lax.dynamic_slice(
q_grouped,
(q_global_start, 0, 0, 0),
(qblocks, actual_num_kv_heads, actual_num_q_heads_per_kv_head, actual_head_dim),
)
if q_scale is not None:
q_block = q_block / q_scale
if jnp.issubdtype(kv_pages.dtype, jnp.floating):
finfo = jnp.finfo(kv_pages.dtype)
q_block = jnp.clip(q_block, finfo.min, finfo.max)
q_block = q_block.astype(kv_pages.dtype)
q_tok = q_off + arange_q
q_valid = q_tok < q_len
q_pos = write_start + q_tok
init_acc = jnp.zeros(
(qblocks, actual_num_kv_heads, actual_num_q_heads_per_kv_head, actual_head_dim),
dtype=jnp.float32,
)
if sinks_h is not None:
init_m = jnp.broadcast_to(
sinks_h[None, :, :],
(qblocks, actual_num_kv_heads, actual_num_q_heads_per_kv_head),
)
init_l = jnp.ones_like(init_m)
else:
init_m = jnp.full(
(qblocks, actual_num_kv_heads, actual_num_q_heads_per_kv_head),
-jnp.inf,
dtype=jnp.float32,
)
init_l = jnp.zeros_like(init_m)
def _process_kv_block(kb, state):
with jax.named_scope("rpa_v3_xla.kv_block"):
acc, l, m = state
page_map_start = kb * kvblocks
kv_page_block = lax.dynamic_slice(
kv_pages_padded,
(page_map_start, 0, 0, 0, 0),
(kvblocks, page_size, num_kv_heads_x2_per_kv_packing, kv_packing, head_dim_padded),
)
kv_tok = kv_page_block.reshape(
kv_tokens_per_block,
num_kv_heads_x2_per_kv_packing,
kv_packing,
head_dim_padded,
)
kv_tok = kv_tok.reshape(kv_tokens_per_block, num_kv_heads_x2, head_dim_padded)
kv_tok = kv_tok[:, : actual_num_kv_heads * 2, :]
kv_tok = kv_tok.reshape(kv_tokens_per_block, actual_num_kv_heads, 2, head_dim_padded)
k_block = kv_tok[:, :, 0, :actual_head_dim]
v_block = kv_tok[:, :, 1, :actual_head_dim]
with jax.named_scope("logits"):
logits = jnp.einsum(
"bihd,kid->bihk",
q_block,
k_block,
preferred_element_type=jnp.float32,
)
logits = logits * softmax_scale
if k_scale is not None:
logits = logits * k_scale
if q_scale is not None:
logits = logits * q_scale
if logits_soft_cap is not None:
logits = logits_soft_cap * jnp.tanh(logits / logits_soft_cap)
with jax.named_scope("mask"):
kv_pos = kb * jnp.int32(kv_tokens_per_block) + arange_kv
kv_valid = jnp.logical_and(kv_pos >= kv_start, kv_pos < kv_len)
mask = jnp.logical_or(kv_pos[None, :] > q_pos[:, None], jnp.logical_not(kv_valid[None, :]))
if sliding_window is not None:
mask = jnp.logical_or(
mask,
kv_pos[None, :] <= (q_pos[:, None] - jnp.int32(sliding_window)),
)
mask = jnp.logical_or(mask, jnp.logical_not(q_valid)[:, None])
mask = mask[:, None, None, :]
logits = logits + jnp.where(mask, mask_value, 0.0)
with jax.named_scope("online_softmax"):
cur_max = jnp.max(logits, axis=-1)
new_m = jnp.maximum(m, cur_max)
exp_logits = jnp.exp(logits - new_m[..., None])
exp_logits = jnp.where(mask, 0.0, exp_logits)
rescale = jnp.exp(m - new_m)
l = rescale * l + jnp.sum(exp_logits, axis=-1)
acc = rescale[..., None] * acc + jnp.einsum(
"bihk,kid->bihd",
exp_logits,
v_block,
preferred_element_type=jnp.float32,
)
return acc, l, new_m
with jax.named_scope("rpa_v3_xla.kv_loop"):
acc, l, _m = lax.fori_loop(kv_block_start, num_kv_blocks, _process_kv_block, (init_acc, init_l, init_m))
l = jnp.maximum(l, 1e-6)
out_block = (acc / l[..., None]).astype(queries.dtype)
if v_scale is not None:
out_block = out_block * v_scale
existing = lax.dynamic_slice(
o_inner,
(q_global_start, 0, 0, 0),
(qblocks, actual_num_kv_heads, actual_num_q_heads_per_kv_head, actual_head_dim),
)
out_block = jnp.where(q_valid[:, None, None, None], out_block, existing)
return lax.dynamic_update_slice(o_inner, out_block, (q_global_start, 0, 0, 0))
with jax.named_scope("rpa_v3_xla.q_loop"):
o_acc = lax.fori_loop(0, num_q_blocks, _process_query_block, o_acc)
return o_acc, kv_cache_acc
num_seqs = distribution[2]
with jax.named_scope("rpa_v3_xla.seq_loop"):
o_grouped, kv_cache = lax.fori_loop(0, num_seqs, _seq_body, (o_grouped, kv_cache))
with jax.named_scope("rpa_v3_xla.finalize"):
out = o_grouped.reshape(padded_total_q, actual_num_q_heads, actual_head_dim)[:total_q]
return out, kv_cache