# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
import jax
import jax.numpy as jnp
import triton
import triton.language as tl
from ejkernel.callib import triton_call
@triton.jit
def _cdiv(x, y):
return (x + y - 1) // y
@triton.jit
def _apply_softcap(scores, cap):
# stable tanh via exp
s = scores / cap
p1 = tl.exp(s)
p2 = tl.exp(-s)
return cap * (p1 - p2) / (p1 + p2)
@triton.jit
def _find_seq_idx(
query_start_len_ptr,
target_idx,
num_seqs,
BLOCK_Q: tl.constexpr,
use_q_block_mode: tl.constexpr,
):
left: tl.int32 = 0 # type:ignore
right = num_seqs
while left < right:
mid = (left + right) // 2
val = tl.load(query_start_len_ptr + mid)
mid_val = val // BLOCK_Q + mid if use_q_block_mode else val
if mid_val <= target_idx:
left = mid + 1
else:
right = mid
return left - 1
@triton.jit
def _unified_attention_2d(
query_ptr,
key_cache_ptr,
value_cache_ptr,
sink_ptr,
block_tables_ptr,
seq_lens_ptr,
alibi_slopes_ptr,
qq_bias_ptr,
query_start_len_ptr,
scale,
softcap,
block_table_stride,
query_stride_0,
query_stride_1,
output_stride_0,
output_stride_1,
qq_bias_stride_0,
stride_k_cache_0,
stride_k_cache_1,
stride_k_cache_2,
stride_k_cache_3,
stride_v_cache_0,
stride_v_cache_1,
stride_v_cache_2,
stride_v_cache_3,
num_seqs,
out_ptr,
*,
num_query_heads: tl.constexpr,
num_queries_per_kv: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
TILE_SIZE: tl.constexpr,
HEAD_SIZE: tl.constexpr,
HEAD_SIZE_PADDED: tl.constexpr,
USE_ALIBI_SLOPES: tl.constexpr,
USE_QQ_BIAS: tl.constexpr,
USE_SOFTCAP: tl.constexpr,
USE_SINKS: tl.constexpr,
SLIDING_WINDOW: tl.constexpr,
BLOCK_Q: tl.constexpr,
BLOCK_M: tl.constexpr,
):
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
num_seqs = num_seqs.to(tl.int32)
seq_idx = _find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True)
q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx
q_block_local_idx = q_block_global_idx - q_block_start_idx
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return
offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
offs_t = tl.arange(0, TILE_SIZE)
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
query_offset_0 = cur_batch_in_all_start_index + query_pos
query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
query_offset = query_offset_0[:, None] * query_stride_0 + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]
dim_mask = (offs_d < HEAD_SIZE).to(tl.int1)
query_mask_0 = (query_pos < cur_batch_query_len).to(tl.int1)
query_mask_1 = (query_offset_1 < num_query_heads).to(tl.int1)
q = tl.load(
query_ptr + query_offset,
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
other=0.0,
)
block_table_offset = seq_idx * block_table_stride
if USE_SINKS:
m_prev = tl.load(sink_ptr + query_offset_1, mask=query_mask_1, other=float("-inf")).to(tl.float32)
l_prev = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
else:
m_prev = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_prev = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
seq_len = tl.load(seq_lens_ptr + seq_idx).to(tl.int32)
context_len = seq_len - cur_batch_query_len
if USE_ALIBI_SLOPES:
alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0).to(tl.float32)
if USE_QQ_BIAS:
qq_bias_row_ptrs = qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
max_seq_prefix_len = (context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) // num_queries_per_kv + 1).to(
tl.int32
)
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
num_tiles = _cdiv(max_seq_prefix_len, TILE_SIZE)
tile_start = 0
tile_end = num_tiles
if SLIDING_WINDOW > 0:
qpos_lo = q_block_local_idx * BLOCK_Q
qpos_hi = tl.minimum(qpos_lo + (BLOCK_M - 1) // num_queries_per_kv, cur_batch_query_len - 1)
first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
last_allowed_key = context_len + qpos_hi
tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE)
tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles)
for j in range(tile_start, tile_end):
seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE).to(tl.int64)
v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)
k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)
k = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None] & tile_mask[None, :], other=0.0)
v = tl.load(value_cache_ptr + v_offset, mask=dim_mask[None, :] & tile_mask[:, None], other=0.0)
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
scores = scale * tl.dot(q, k)
if USE_SOFTCAP:
scores = _apply_softcap(scores, softcap)
scores = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, scores, float("-inf"))
if SLIDING_WINDOW > 0:
scores = tl.where(
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
scores,
float("-inf"),
)
if USE_ALIBI_SLOPES:
scores += alibi_slope[:, None] * (seq_offset - context_len).to(tl.float32)
if USE_QQ_BIAS:
key_rel_pos = seq_offset - context_len
is_query_key = (key_rel_pos >= 0) & (key_rel_pos < qq_bias_stride_0)
qqb = tl.load(qq_bias_row_ptrs + key_rel_pos[None, :], mask=is_query_key[None, :], other=0.0)
scores += qqb.to(tl.float32)
m_curr = tl.maximum(m_prev, tl.max(scores, axis=1))
m_curr = tl.where(m_curr > float("-inf"), m_curr, 0.0)
p = tl.exp(scores - m_curr[:, None])
l_curr = tl.sum(p, axis=1)
alpha = tl.exp(m_prev - m_curr)
acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
l_prev = l_prev * alpha + l_curr
m_prev = m_curr
l_prev = tl.maximum(l_prev, 1e-6)
acc = acc / l_prev[:, None]
output_offset = (
query_offset_0[:, None] * output_stride_0 + query_offset_1[:, None] * output_stride_1 + offs_d[None, :]
)
tl.store(out_ptr + output_offset, acc, mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None])
@triton.jit
def _unified_attention_3d(
query_ptr,
key_cache_ptr,
value_cache_ptr,
sink_ptr,
block_tables_ptr,
seq_lens_ptr,
alibi_slopes_ptr,
qq_bias_ptr,
query_start_len_ptr,
scale,
softcap,
block_table_stride,
query_stride_0,
query_stride_1,
qq_bias_stride_0,
stride_k_cache_0,
stride_k_cache_1,
stride_k_cache_2,
stride_k_cache_3,
stride_v_cache_0,
stride_v_cache_1,
stride_v_cache_2,
stride_v_cache_3,
num_seqs,
segm_output_ptr,
segm_max_ptr,
segm_expsum_ptr,
*,
num_query_heads: tl.constexpr,
num_queries_per_kv: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
TILE_SIZE: tl.constexpr,
HEAD_SIZE: tl.constexpr,
HEAD_SIZE_PADDED: tl.constexpr,
USE_ALIBI_SLOPES: tl.constexpr,
USE_QQ_BIAS: tl.constexpr,
USE_SOFTCAP: tl.constexpr,
USE_SINKS: tl.constexpr,
SLIDING_WINDOW: tl.constexpr,
BLOCK_Q: tl.constexpr,
BLOCK_M: tl.constexpr,
NUM_SEGMENTS_PER_SEQ: tl.constexpr,
):
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
segm_idx = tl.program_id(2)
num_seqs = num_seqs.to(tl.int32)
seq_idx = _find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True)
q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx
q_block_local_idx = q_block_global_idx - q_block_start_idx
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return
seq_len = tl.load(seq_lens_ptr + seq_idx).to(tl.int32)
tiles_per_segment = _cdiv(seq_len, NUM_SEGMENTS_PER_SEQ * TILE_SIZE)
if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
return
offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
offs_t = tl.arange(0, TILE_SIZE)
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
query_offset_0 = cur_batch_in_all_start_index + query_pos
query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
query_offset = query_offset_0[:, None] * query_stride_0 + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]
dim_mask = (offs_d < HEAD_SIZE).to(tl.int1)
query_mask_0 = (query_pos < cur_batch_query_len).to(tl.int1)
query_mask_1 = (query_offset_1 < num_query_heads).to(tl.int1)
q = tl.load(
query_ptr + query_offset,
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
other=0.0,
)
block_table_offset = seq_idx * block_table_stride
if USE_SINKS and segm_idx == 0:
m_prev = tl.load(sink_ptr + query_offset_1, mask=query_mask_1, other=float("-inf")).to(tl.float32)
l_prev = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
else:
m_prev = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_prev = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
context_len = seq_len - cur_batch_query_len
if USE_ALIBI_SLOPES:
alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0).to(tl.float32)
if USE_QQ_BIAS:
qq_bias_row_ptrs = qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
max_seq_prefix_len = (context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) // num_queries_per_kv + 1).to(
tl.int32
)
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
num_tiles = _cdiv(max_seq_prefix_len, TILE_SIZE)
segm_begin = segm_idx * tiles_per_segment
segm_end = tl.minimum((segm_idx + 1) * tiles_per_segment, num_tiles)
for j in range(segm_begin, segm_end):
seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE).to(tl.int64)
v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)
k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)
k = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None] & tile_mask[None, :], other=0.0)
v = tl.load(value_cache_ptr + v_offset, mask=dim_mask[None, :] & tile_mask[:, None], other=0.0)
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
scores = scale * tl.dot(q, k)
if USE_SOFTCAP:
scores = _apply_softcap(scores, softcap)
scores = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, scores, float("-inf"))
if SLIDING_WINDOW > 0:
scores = tl.where(
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
scores,
float("-inf"),
)
if USE_ALIBI_SLOPES:
scores += alibi_slope[:, None] * (seq_offset - context_len).to(tl.float32)
if USE_QQ_BIAS:
key_rel_pos = seq_offset - context_len
is_query_key = (key_rel_pos >= 0) & (key_rel_pos < qq_bias_stride_0)
qqb = tl.load(qq_bias_row_ptrs + key_rel_pos[None, :], mask=is_query_key[None, :], other=0.0)
scores += qqb.to(tl.float32)
m_curr = tl.maximum(m_prev, tl.max(scores, axis=1))
m_curr = tl.where(m_curr > float("-inf"), m_curr, 0.0)
p = tl.exp(scores - m_curr[:, None])
l_curr = tl.sum(p, axis=1)
alpha = tl.exp(m_prev - m_curr)
acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
l_prev = l_prev * alpha + l_curr
m_prev = m_curr
segm_output_offset = (
query_offset_0[:, None].to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
+ query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
+ segm_idx * HEAD_SIZE_PADDED
+ tl.arange(0, HEAD_SIZE_PADDED)[None, :]
)
tl.store(
segm_output_ptr + segm_output_offset,
acc,
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
)
segm_offset = (
query_offset_0.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
+ query_offset_1 * NUM_SEGMENTS_PER_SEQ
+ segm_idx
)
tl.store(segm_max_ptr + segm_offset, m_prev, mask=query_mask_0 & query_mask_1)
tl.store(segm_expsum_ptr + segm_offset, l_prev, mask=query_mask_0 & query_mask_1)
@triton.jit
def _reduce_segments(
segm_output_ptr,
segm_max_ptr,
segm_expsum_ptr,
seq_lens_ptr,
query_start_len_ptr,
num_seqs,
output_stride_0,
output_stride_1,
out_ptr,
*,
num_query_heads: tl.constexpr,
TILE_SIZE: tl.constexpr,
HEAD_SIZE: tl.constexpr,
HEAD_SIZE_PADDED: tl.constexpr,
BLOCK_Q: tl.constexpr,
NUM_SEGMENTS_PER_SEQ: tl.constexpr,
):
query_token_idx = tl.program_id(0)
query_head_idx = tl.program_id(1)
num_seqs = num_seqs.to(tl.int32)
seq_idx = _find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False)
seq_len = tl.load(seq_lens_ptr + seq_idx).to(tl.int32)
tiles_per_segment = _cdiv(seq_len, NUM_SEGMENTS_PER_SEQ * TILE_SIZE)
act_num_segments = _cdiv(seq_len, tiles_per_segment * TILE_SIZE)
segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full([NUM_SEGMENTS_PER_SEQ], act_num_segments, tl.int32)
dim_mask = (tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE).to(tl.int1)
segm_offset = (
query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
+ query_head_idx * NUM_SEGMENTS_PER_SEQ
+ tl.arange(0, NUM_SEGMENTS_PER_SEQ)
)
segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf"))
overall_max = tl.max(segm_max)
segm_expsum = tl.load(segm_expsum_ptr + segm_offset, mask=segm_mask, other=0.0)
segm_expsum = segm_expsum * tl.exp(segm_max - overall_max)
overall_expsum = tl.sum(segm_expsum)
segm_output_offset = (
query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
+ query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
+ tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED
+ tl.arange(0, HEAD_SIZE_PADDED)[None, :]
)
segm_output = tl.load(segm_output_ptr + segm_output_offset, mask=segm_mask[:, None] & dim_mask[None, :], other=0.0)
segm_output *= tl.exp(segm_max - overall_max)[:, None]
acc_sum = tl.sum(segm_output, axis=0)
acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)
out_offset = query_token_idx * output_stride_0 + query_head_idx * output_stride_1 + tl.arange(0, HEAD_SIZE_PADDED)
tl.store(out_ptr + out_offset, acc, mask=dim_mask)
[docs]def unified_attention_triton(
*,
queries: jax.Array,
key_cache: jax.Array,
value_cache: jax.Array,
block_tables: jax.Array,
kv_lens: jax.Array,
query_start_loc: jax.Array,
softmax_scale: float | None,
causal: bool,
sliding_window: int | None,
logits_soft_cap: float | None,
seq_threshold_3d: int | None,
num_par_softmax_segments: int | None,
alibi_slopes: jax.Array | None,
qq_bias: jax.Array | None,
attention_sink: jax.Array | None,
num_warps: int | None,
num_stages: int | None,
) -> jax.Array:
if not causal:
raise NotImplementedError("unified_attention_triton only supports causal attention.")
if queries.ndim != 3:
raise ValueError("queries must be rank-3: [total_tokens, num_q_heads, head_dim]")
if key_cache.ndim != 4 or value_cache.ndim != 4:
raise ValueError("key_cache/value_cache must be rank-4: [num_blocks, block_size, num_kv_heads, head_dim]")
if key_cache.shape != value_cache.shape:
raise ValueError("key_cache and value_cache must have the same shape")
if query_start_loc.dtype != jnp.int32 or kv_lens.dtype != jnp.int32 or block_tables.dtype != jnp.int32:
raise ValueError("query_start_loc/kv_lens/block_tables must be int32")
total_tokens, num_query_heads, head_size = map(int, queries.shape)
_num_blocks, block_size, num_kv_heads, head_size_kv = map(int, key_cache.shape)
if head_size_kv != head_size:
raise ValueError("head_dim mismatch between queries and KV cache")
if num_query_heads % num_kv_heads != 0:
raise ValueError("num_q_heads must be divisible by num_kv_heads (GQA)")
num_seqs, max_blocks_per_seq = map(int, block_tables.shape)
if query_start_loc.shape[0] != num_seqs + 1 or kv_lens.shape[0] != num_seqs:
raise ValueError("query_start_loc must be [num_seqs+1] and kv_lens must be [num_seqs]")
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(head_size)
use_alibi = alibi_slopes is not None
use_qq_bias = qq_bias is not None
use_sinks = attention_sink is not None
if use_alibi:
if alibi_slopes.shape != (num_query_heads,):
raise ValueError("alibi_slopes must have shape (num_q_heads,)")
alibi_slopes = alibi_slopes.astype(jnp.float32)
else:
alibi_slopes = jnp.zeros((1,), dtype=jnp.float32)
if use_qq_bias:
if qq_bias.ndim != 2 or qq_bias.shape[0] != qq_bias.shape[1]:
raise ValueError("qq_bias must be square [num_query_tokens, num_query_tokens]")
qq_bias = qq_bias.astype(jnp.float32)
qq_bias_stride_0 = int(qq_bias.shape[1])
else:
qq_bias = jnp.zeros((1, 1), dtype=jnp.float32)
qq_bias_stride_0 = 0
if use_sinks:
if attention_sink.shape != (num_query_heads,):
raise ValueError("attention_sink must have shape (num_q_heads,)")
attention_sink = attention_sink.astype(jnp.float32)
else:
attention_sink = jnp.zeros((1,), dtype=jnp.float32)
num_queries_per_kv = num_query_heads // num_kv_heads
block_m = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv)
block_q = block_m // num_queries_per_kv
total_num_q_blocks = total_tokens // block_q + num_seqs
tile_size_prefill = 32
tile_size_decode = 16
head_size_padded = triton.next_power_of_2(head_size)
# contiguous strides
q_stride_0 = num_query_heads * head_size
q_stride_1 = head_size
o_stride_0 = q_stride_0
o_stride_1 = q_stride_1
bt_stride = max_blocks_per_seq
k_stride_0 = block_size * num_kv_heads * head_size
k_stride_1 = num_kv_heads * head_size
k_stride_2 = head_size
k_stride_3 = 1
v_stride_0 = k_stride_0
v_stride_1 = k_stride_1
v_stride_2 = k_stride_2
v_stride_3 = 1
if sliding_window is None:
sliding_window_val = 0
else:
sliding_window_val = int(sliding_window)
if sliding_window_val <= 0:
sliding_window_val = 0
if logits_soft_cap is None:
logits_soft_cap_val = 0.0
else:
logits_soft_cap_val = float(logits_soft_cap)
# vLLM selection logic: use segmented 3D kernel for decode-only, small batches.
# We treat the batch as "decode-only" when each sequence has a single query token.
decode_only = total_tokens <= num_seqs
use_2d = (
seq_threshold_3d is None
or num_par_softmax_segments is None
or not decode_only
or num_seqs > int(seq_threshold_3d)
)
metaparams_2d = dict(
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
BLOCK_SIZE=block_size,
TILE_SIZE=tile_size_prefill,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=head_size_padded,
USE_ALIBI_SLOPES=bool(use_alibi),
USE_QQ_BIAS=bool(use_qq_bias),
USE_SOFTCAP=bool(logits_soft_cap_val > 0),
USE_SINKS=bool(use_sinks),
SLIDING_WINDOW=int(sliding_window_val),
BLOCK_Q=int(block_q),
BLOCK_M=int(block_m),
num_warps=num_warps,
num_stages=num_stages,
)
if use_2d:
(out,) = triton_call(
queries,
key_cache,
value_cache,
attention_sink,
block_tables,
kv_lens,
alibi_slopes,
qq_bias,
query_start_loc,
float(softmax_scale),
float(logits_soft_cap_val),
int(bt_stride),
int(q_stride_0),
int(q_stride_1),
int(o_stride_0),
int(o_stride_1),
int(qq_bias_stride_0),
int(k_stride_0),
int(k_stride_1),
int(k_stride_2),
int(k_stride_3),
int(v_stride_0),
int(v_stride_1),
int(v_stride_2),
int(v_stride_3),
int(num_seqs),
kernel=_unified_attention_2d,
out_shape=[jax.ShapeDtypeStruct(queries.shape, queries.dtype)],
grid=lambda META: (int(total_num_q_blocks), int(num_kv_heads)),
name="ejkernel::triton::unified_attention_2d",
**metaparams_2d,
)
return out
num_segments = int(num_par_softmax_segments)
segm_out_shape = jax.ShapeDtypeStruct((total_tokens, num_query_heads, num_segments, head_size_padded), jnp.float32)
segm_ml_shape = jax.ShapeDtypeStruct((total_tokens, num_query_heads, num_segments), jnp.float32)
metaparams_3d = dict(
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
BLOCK_SIZE=block_size,
TILE_SIZE=tile_size_decode,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=head_size_padded,
USE_ALIBI_SLOPES=bool(use_alibi),
USE_QQ_BIAS=bool(use_qq_bias),
USE_SOFTCAP=bool(logits_soft_cap_val > 0),
USE_SINKS=bool(use_sinks),
SLIDING_WINDOW=int(sliding_window_val),
BLOCK_Q=int(block_q),
BLOCK_M=int(block_m),
NUM_SEGMENTS_PER_SEQ=num_segments,
num_warps=num_warps,
num_stages=num_stages,
)
segm_output, segm_max, segm_expsum = triton_call(
queries,
key_cache,
value_cache,
attention_sink,
block_tables,
kv_lens,
alibi_slopes,
qq_bias,
query_start_loc,
float(softmax_scale),
float(logits_soft_cap_val),
int(bt_stride),
int(q_stride_0),
int(q_stride_1),
int(qq_bias_stride_0),
int(k_stride_0),
int(k_stride_1),
int(k_stride_2),
int(k_stride_3),
int(v_stride_0),
int(v_stride_1),
int(v_stride_2),
int(v_stride_3),
int(num_seqs),
kernel=_unified_attention_3d,
out_shape=(segm_out_shape, segm_ml_shape, segm_ml_shape),
grid=lambda META: (int(total_num_q_blocks), int(num_kv_heads), int(num_segments)),
name="ejkernel::triton::unified_attention_3d",
**metaparams_3d,
)
metaparams_reduce = dict(
num_query_heads=num_query_heads,
TILE_SIZE=tile_size_decode,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=head_size_padded,
BLOCK_Q=int(block_q),
NUM_SEGMENTS_PER_SEQ=num_segments,
num_warps=num_warps,
num_stages=num_stages,
)
(out,) = triton_call(
segm_output,
segm_max,
segm_expsum,
kv_lens,
query_start_loc,
int(num_seqs),
int(o_stride_0),
int(o_stride_1),
kernel=_reduce_segments,
out_shape=[jax.ShapeDtypeStruct(queries.shape, queries.dtype)],
grid=lambda META: (int(total_tokens), int(num_query_heads)),
name="ejkernel::triton::unified_attention_reduce_segments",
**metaparams_reduce,
)
return out