# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unified (paged) attention module with automatic platform selection.

This operation wraps the vLLM-style unified attention kernel implemented in
`ejkernel.kernels` and provides a high-level API consistent with other
`ejkernel.modules.operations` entry points.

The unified attention kernel targets inference workloads that use a paged KV
cache and ragged query packing (variable-length sequences without padding).
"""
from __future__ import annotations
from typing import Literal
from jaxtyping import Array, Float, Int32
from ejkernel.kernels._registry import Backend, kernel_registry
from ejkernel.ops import (
AutotunePolicy,
ConfigCache,
ConfigSelectorChain,
Executor,
Invocation,
Kernel,
Tuner,
)
from ejkernel.ops.config.persistent import PersistentCache
from ..base import detect_platform
from .configs import UnifiedAttentionConfig
# Inputs to UnifiedAttention.heuristic_cfg: the first is divided by the number
# of KV heads to obtain ``seq_threshold_3d``; the second is used verbatim as
# ``num_par_softmax_segments``.
MIN_LAUNCH_GRID_SIZE_2D = 128
NUM_PAR_SOFTMAX_SEGMENTS = 16

# Canonical positional order of the array arguments of the attention call.
# _resolve_inv_arg relies on this order when an argument was passed positionally.
_ARGUMENT_ORDER = (
    "queries",
    "key_cache",
    "value_cache",
    "kv_lens",
    "block_tables",
    "query_start_loc",
)

# Reverse lookup: argument name -> its position within _ARGUMENT_ORDER.
_ARGUMENT_INDEX = {arg_name: position for position, arg_name in enumerate(_ARGUMENT_ORDER)}
def _resolve_inv_arg(inv: Invocation, name: str):
    """Fetch a named call argument from an ``Invocation``.

    Keyword arguments take precedence. When ``name`` was not passed by keyword,
    the value is read from ``inv.args`` at the position that ``name`` occupies
    in ``_ARGUMENT_ORDER``.

    Args:
        inv: Invocation exposing the call's ``args`` and ``kwargs``.
        name: Name of the argument to look up.

    Returns:
        The value bound to ``name`` in the invocation.

    Raises:
        KeyError: If ``name`` is absent from ``inv.kwargs`` and either has no
            canonical position or that position lies beyond ``inv.args``.
    """
    if name in inv.kwargs:
        return inv.kwargs[name]

    # Positional fallback: only names listed in _ARGUMENT_ORDER have a slot.
    slot = _ARGUMENT_INDEX.get(name)
    if slot is not None and slot < len(inv.args):
        return inv.args[slot]

    raise KeyError(name)
class UnifiedAttention(Kernel[UnifiedAttentionConfig, Array]):
    """vLLM-style unified attention over a paged KV cache (inference-only).

    This kernel implements unified paged attention for inference serving workloads,
    supporting both prefill and decode phases with a single kernel. It is designed
    to work with vLLM-style paged KV caches where key-value tensors are stored in
    fixed-size blocks that can be dynamically allocated and mapped per sequence.

    The class itself is a thin dispatcher: it resolves a platform-specific
    implementation from the kernel registry and forwards the inputs together with
    the tuning fields of the config (``seq_threshold_3d``,
    ``num_par_softmax_segments``, ``num_warps``, ``num_stages``). The choice
    between a 2D grid launch and a 3D grid launch with parallel softmax reduction
    is made inside that implementation.

    Features:
        - Paged KV cache support with block tables for memory efficiency
        - Ragged query packing (variable-length sequences without padding)
        - Automatic 2D/3D grid selection based on sequence characteristics
        - Support for GQA/MQA (grouped/multi-query attention)
        - Optional sliding window attention
        - Optional logits soft capping (Gemma-2 style)
        - Optional ALiBi position biases
        - Optional attention sinks for streaming inference

    Example:
        >>> kernel = UnifiedAttention()
        >>> output = kernel.run(
        ...     queries=packed_queries,          # [total_tokens, num_q_heads, head_dim]
        ...     key_cache=key_cache,             # [num_blocks, block_size, num_kv_heads, head_dim]
        ...     value_cache=value_cache,         # [num_blocks, block_size, num_kv_heads, head_dim]
        ...     kv_lens=context_lengths,         # [num_seqs]
        ...     block_tables=block_tables,       # [num_seqs, max_blocks_per_seq]
        ...     query_start_loc=cu_seqlens,      # [num_seqs + 1]
        ...     cfg=UnifiedAttentionConfig(),
        ... )
    """

    def __init__(self):
        """Initialize the kernel under the ``"unified_attention"`` op id."""
        super().__init__(op_id="unified_attention")

    def get_impl(self, cfg: UnifiedAttentionConfig):
        """Get the platform-specific implementation.

        Args:
            cfg: Configuration specifying platform and backend preferences.

        Returns:
            Callable implementation function from the kernel registry.
        """
        platform = detect_platform("unified_attention", cfg.platform)
        return kernel_registry.get("unified_attention", platform=platform, backend=cfg.backend)

    def run(
        self,
        queries: Float[Array, "total_tokens num_q_heads head_dim"],
        key_cache: Float[Array, "num_blocks block_size num_kv_heads head_dim"],
        value_cache: Float[Array, "num_blocks block_size num_kv_heads head_dim"],
        kv_lens: Int32[Array, "num_seqs"],
        block_tables: Int32[Array, "num_seqs max_blocks_per_seq"],
        query_start_loc: Int32[Array, "num_seqs_plus_1"],
        *,
        softmax_scale: float | None = None,
        causal: bool = True,
        sliding_window: int | None = None,
        logits_soft_cap: float | None = None,
        alibi_slopes: Float[Array, "num_q_heads"] | None = None,
        qq_bias: Float[Array, "num_query_tokens num_query_tokens"] | None = None,
        attention_sink: Float[Array, "num_q_heads"] | None = None,
        platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
        cfg: UnifiedAttentionConfig,
    ) -> Float[Array, "total_tokens num_q_heads head_dim"]:
        """Execute unified paged attention.

        Args:
            queries: Packed query tensor of shape [total_tokens, num_q_heads, head_dim].
                Contains all query tokens from all sequences concatenated together.
            key_cache: Paged key cache of shape [num_blocks, block_size, num_kv_heads, head_dim].
                Pre-allocated blocks storing key vectors for all sequences.
            value_cache: Paged value cache of shape [num_blocks, block_size, num_kv_heads, head_dim].
                Pre-allocated blocks storing value vectors for all sequences.
            kv_lens: Context lengths per sequence of shape [num_seqs].
                Number of valid KV tokens for each sequence.
            block_tables: Block index mapping of shape [num_seqs, max_blocks_per_seq].
                Maps logical block indices to physical block indices in the cache.
            query_start_loc: Cumulative query token counts of shape [num_seqs + 1].
                query_start_loc[i] gives the starting token index for sequence i.
            softmax_scale: Scaling factor for attention scores. If None, uses 1/sqrt(head_dim).
            causal: Whether to apply causal masking. Default True.
            sliding_window: Optional sliding window size for local attention.
                If provided, each query only attends to the last `sliding_window` KV positions.
            logits_soft_cap: Optional soft cap for attention logits (Gemma-2 style).
                Applies tanh-based capping: logits = soft_cap * tanh(logits / soft_cap).
            alibi_slopes: Optional ALiBi slopes per head of shape [num_q_heads].
                Adds position-dependent bias: bias[i,j] = slope * (j - i).
            qq_bias: Optional query-query bias of shape [num_query_tokens, num_query_tokens].
                Added directly to attention logits between query positions.
            attention_sink: Optional attention sink values per head of shape [num_q_heads].
                Adds constant attention to the first token for streaming inference stability.
            platform: Override platform selection. One of "triton", "pallas", "cuda", "xla", "auto".
                When not None it replaces ``cfg.platform`` for this call.
            cfg: Kernel configuration with tuning parameters.

        Returns:
            Output tensor of shape [total_tokens, num_q_heads, head_dim] with attention results.
        """
        if platform is not None:
            # An explicit ``platform`` wins over ``cfg.platform``. The tuning
            # fields are carried over unchanged into a fresh config.
            cfg = UnifiedAttentionConfig(
                seq_threshold_3d=cfg.seq_threshold_3d,
                num_par_softmax_segments=cfg.num_par_softmax_segments,
                num_warps=cfg.num_warps,
                num_stages=cfg.num_stages,
                platform=platform,
                # Forcing "xla" drops whatever backend constraint ``cfg`` carried.
                backend=Backend.ANY if platform == "xla" else cfg.backend,
            )
        impl = self.get_impl(cfg)
        return impl(
            queries=queries,
            key_cache=key_cache,
            value_cache=value_cache,
            kv_lens=kv_lens,
            block_tables=block_tables,
            query_start_loc=query_start_loc,
            softmax_scale=softmax_scale,
            causal=causal,
            sliding_window=sliding_window,
            logits_soft_cap=logits_soft_cap,
            seq_threshold_3d=cfg.seq_threshold_3d,
            num_par_softmax_segments=cfg.num_par_softmax_segments,
            alibi_slopes=alibi_slopes,
            qq_bias=qq_bias,
            attention_sink=attention_sink,
            num_warps=cfg.num_warps,
            num_stages=cfg.num_stages,
        )

    def heuristic_cfg(self, inv: Invocation[UnifiedAttentionConfig, Array]) -> UnifiedAttentionConfig:
        """Generate default configuration based on input characteristics.

        Follows vLLM's decode kernel selection heuristic to derive the threshold
        used for switching between 2D and 3D grid launches.

        NOTE(review): upstream vLLM appears to compare its ``seq_threshold_3D``
        against the number of sequences in the batch rather than against a
        sequence length. Confirm which quantity the registered implementation
        compares ``seq_threshold_3d`` to.

        Args:
            inv: Invocation containing the input arguments and metadata.

        Returns:
            UnifiedAttentionConfig with heuristically determined parameters:
                - seq_threshold_3d: ``MIN_LAUNCH_GRID_SIZE_2D`` floor-divided by the
                  number of KV heads (read from ``key_cache.shape[2]``)
                - num_par_softmax_segments: ``NUM_PAR_SOFTMAX_SEGMENTS``
                - num_warps / num_stages: None
                - platform "auto", backend "any"
        """
        key_cache = _resolve_inv_arg(inv, "key_cache")
        num_kv_heads = int(key_cache.shape[2])
        # max(1, ...) guards the division against a zero-sized KV-head axis.
        seq_threshold_3d = MIN_LAUNCH_GRID_SIZE_2D // max(1, num_kv_heads)
        return UnifiedAttentionConfig(
            seq_threshold_3d=int(seq_threshold_3d),
            num_par_softmax_segments=int(NUM_PAR_SOFTMAX_SEGMENTS),
            num_warps=None,
            num_stages=None,
            platform="auto",
            backend="any",
        )

    def candidate_cfgs(self, inv: Invocation[UnifiedAttentionConfig, Array]):
        """Return candidate configurations for autotuning.

        This operation exposes the main tuning knobs directly via the config,
        so no autotuning candidates are offered, which reduces overhead.

        NOTE(review): the module-level executor is built with
        ``allow_autotune=True``. Confirm that the selector falls back to
        ``heuristic_cfg`` when the candidate list is empty.

        Args:
            inv: Invocation containing the input arguments and metadata.

        Returns:
            Empty list (no candidates for this kernel).
        """
        return []
# Module-wide executor used by ``unified_attention``. It is created once at
# import time and shared by every call.
_unified_attention_executor: Executor[UnifiedAttentionConfig, Array] = Executor(
    ConfigSelectorChain(
        cache=ConfigCache(),
        policy=AutotunePolicy(
            allow_autotune=True,
            # On a cache miss, fall back to the kernel's heuristic config.
            cache_miss_fallback="heuristics",
            # No backward-pass validation is requested for this executor.
            validate_backward=False,
        ),
        tuner=Tuner(warmup=5, iters=100),
        persistent=PersistentCache("unified-attention"),
    )
)
def unified_attention(
    queries: Float[Array, "total_tokens num_q_heads head_dim"],
    key_cache: Float[Array, "num_blocks block_size num_kv_heads head_dim"],
    value_cache: Float[Array, "num_blocks block_size num_kv_heads head_dim"],
    kv_lens: Int32[Array, "num_seqs"],
    block_tables: Int32[Array, "num_seqs max_blocks_per_seq"],
    query_start_loc: Int32[Array, "num_seqs_plus_1"],
    /,
    *,
    softmax_scale: float | None = None,
    causal: bool = True,
    sliding_window: int | None = None,
    logits_soft_cap: float | None = None,
    alibi_slopes: Float[Array, "num_q_heads"] | None = None,
    qq_bias: Float[Array, "num_query_tokens num_query_tokens"] | None = None,
    attention_sink: Float[Array, "num_q_heads"] | None = None,
    platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
    cfg: UnifiedAttentionConfig | None = None,
) -> Float[Array, "total_tokens num_q_heads head_dim"]:
    """Execute unified paged attention with automatic platform selection.

    This is the main entry point for vLLM-style unified attention, suitable for
    inference serving workloads with paged KV caches. It handles both prefill
    and decode phases efficiently using a single unified kernel.

    The function automatically selects the optimal platform (Triton, Pallas, XLA)
    based on available hardware and applies heuristic-based configuration tuning.
    The six array arguments are positional-only; every option is keyword-only.

    Args:
        queries: Packed query tensor of shape [total_tokens, num_q_heads, head_dim].
            All query tokens from all sequences are concatenated together without padding.
        key_cache: Paged key cache of shape [num_blocks, block_size, num_kv_heads, head_dim].
            Pre-allocated memory blocks storing key vectors. Blocks are shared across sequences
            and mapped via block_tables.
        value_cache: Paged value cache of shape [num_blocks, block_size, num_kv_heads, head_dim].
            Pre-allocated memory blocks storing value vectors, with same layout as key_cache.
        kv_lens: Context lengths per sequence of shape [num_seqs].
            Specifies how many KV tokens are valid for each sequence.
        block_tables: Block index mapping of shape [num_seqs, max_blocks_per_seq].
            Maps each sequence's logical block indices to physical block indices in the cache.
            For sequence i, block_tables[i, j] gives the physical block index for logical block j.
        query_start_loc: Cumulative query token counts of shape [num_seqs + 1].
            Defines the boundaries of each sequence in the packed queries tensor.
            Sequence i's queries span indices [query_start_loc[i], query_start_loc[i+1]).
        softmax_scale: Scaling factor for attention scores. Default: 1/sqrt(head_dim).
        causal: Whether to apply causal masking. Default: True.
            When True, each query can only attend to KV positions at or before its position.
        sliding_window: Optional sliding window size for local attention.
            If provided, each query only attends to the most recent `sliding_window` KV positions.
        logits_soft_cap: Optional soft cap for attention logits (Gemma-2 style).
            Applies: logits = soft_cap * tanh(logits / soft_cap) to prevent extreme values.
        alibi_slopes: Optional ALiBi position bias slopes per head [num_q_heads].
            Adds linear position-dependent bias to attention scores.
        qq_bias: Optional query-query attention bias [num_query_tokens, num_query_tokens].
            Directly added to attention logits for query-to-query interactions.
        attention_sink: Optional attention sink values per head [num_q_heads].
            Provides stable attention to the first token for streaming/infinite context.
        platform: Override automatic platform selection. One of:
            - "triton": Force Triton (GPU)
            - "pallas": Force Pallas (TPU/GPU)
            - "xla": Force XLA fallback
            - "cuda": Force CUDA
            - "auto": Automatic selection (default)
            - None: Same as "auto"
        cfg: Optional configuration override. If None, uses heuristic-based defaults.

    Returns:
        Output tensor of shape [total_tokens, num_q_heads, head_dim] containing the
        attention results for all sequences, packed in the same order as queries.

    Example:
        >>> import jax.numpy as jnp
        >>> from ejkernel.modules import unified_attention
        >>>
        >>> # Setup for 2 sequences with different lengths
        >>> total_tokens = 5  # seq1 has 2 tokens, seq2 has 3 tokens
        >>> queries = jnp.ones((total_tokens, 8, 64))  # 8 heads, dim 64
        >>> key_cache = jnp.ones((100, 16, 2, 64))  # 100 blocks, block_size=16, 2 KV heads
        >>> value_cache = jnp.ones((100, 16, 2, 64))
        >>> kv_lens = jnp.array([32, 48], dtype=jnp.int32)  # context lengths
        >>> # block_tables must be rectangular: [num_seqs, max_blocks_per_seq].
        >>> # seq1 covers only 2 blocks (32 / 16), so its third slot is padding
        >>> # (presumably ignored because kv_lens[0] == 32; confirm against the kernel).
        >>> block_tables = jnp.array([[0, 1, 0], [2, 3, 4]], dtype=jnp.int32)
        >>> # [0, 2) for seq1, [2, 5) for seq2
        >>> query_start_loc = jnp.array([0, 2, 5], dtype=jnp.int32)
        >>>
        >>> output = unified_attention(
        ...     queries, key_cache, value_cache,
        ...     kv_lens, block_tables, query_start_loc,
        ...     causal=True,
        ... )

    Note:
        This kernel is optimized for inference and does not support backward passes.
        For training, use `flash_attention` instead.
    """
    # Arrays are re-bound by keyword for the executor; the optional user config
    # is handed over under the ``_cfg`` keyword.
    return _unified_attention_executor(
        UnifiedAttention(),
        queries=queries,
        key_cache=key_cache,
        value_cache=value_cache,
        kv_lens=kv_lens,
        block_tables=block_tables,
        query_start_loc=query_start_loc,
        softmax_scale=softmax_scale,
        causal=causal,
        sliding_window=sliding_window,
        logits_soft_cap=logits_soft_cap,
        alibi_slopes=alibi_slopes,
        qq_bias=qq_bias,
        attention_sink=attention_sink,
        platform=platform,
        _cfg=cfg,
    )