Source code for ejkernel.modules.operations.ragged_page_attention_v2

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Ragged Page Attention module with automatic optimization.

This module implements ragged page attention, combining the benefits of both
ragged (variable-length) sequence processing and paged KV cache management.
This approach is particularly efficient for serving scenarios where sequences
have variable lengths and KV cache is organized in fixed-size pages.

Ragged page attention addresses key challenges in LLM inference:
    - Variable-length sequences without padding overhead
    - Efficient memory management through paged KV cache
    - Dynamic batching with different sequence lengths
    - Memory sharing for beam search and prefix caching

Key Concepts:
    Ragged Layout: Sequences are concatenated without padding, with start
        locations tracking where each sequence begins
    Pages: Fixed-size blocks holding portions of KV cache
    Block Tables: Mapping from logical sequence positions to physical pages

The combination provides:
    - Zero padding overhead (ragged layout)
    - Flexible memory allocation (paged cache)
    - Efficient batching of variable-length sequences
    - Support for dynamic sequence management

Memory Layout:
    Queries: [total_tokens, num_heads, head_dim] (ragged, no padding)
    KV Cache: [num_pages, page_size, num_heads, head_dim] (paged)

Mathematical Foundation:
    For token i in sequence s:
        start_idx = query_start_loc[s]
        end_idx = query_start_loc[s + 1]
        output[i] = attention(Q[start_idx:end_idx], K[pages[s]], V[pages[s]])

This is the most memory-efficient attention variant for serving workloads.
"""

from __future__ import annotations

import os
from typing import Literal

from jax import numpy as jnp
from jax import shard_map
from jax.sharding import Mesh, PartitionSpec
from jaxtyping import Array, DTypeLike, Float, Int

from ejkernel.kernels._registry import Backend, kernel_registry
from ejkernel.ops import (
    AutotunePolicy,
    ConfigCache,
    ConfigSelectorChain,
    Executor,
    Invocation,
    Kernel,
    Tuner,
)
from ejkernel.ops.config.persistent import PersistentCache

from ..base import detect_platform
from .configs import RaggedPageAttentionv2Config


def _xla_block_candidates_v2(inv: Invocation[RaggedPageAttentionv2Config, Array]) -> list[RaggedPageAttentionv2Config]:
    """Generate power-of-2 XLA configs with larger blocks."""
    try:
        queries = inv.kwargs["queries"]
        block_tables = inv.kwargs["block_tables"]
    except KeyError:
        return []

    total_tokens = int(getattr(queries, "shape", (0,))[0] or 0)
    pages_per_seq = int(getattr(block_tables, "shape", (0, 0))[1] or 0)
    if total_tokens <= 0 or pages_per_seq <= 0:
        return []

    kv_candidates = [k for k in (16, 32, 64) if k <= pages_per_seq]
    if not kv_candidates:
        kv_candidates = [min(16, pages_per_seq)]

    q_candidates = [q for q in (64, 128, 256) if q <= total_tokens]
    if not q_candidates:
        q_candidates = [min(64, total_tokens)]

    configs: list[RaggedPageAttentionv2Config] = []
    for kv in kv_candidates:
        for q in q_candidates:
            configs.append(
                RaggedPageAttentionv2Config(
                    num_kv_pages_per_block=kv,
                    num_queries_per_block=q,
                    num_warps=None,
                    num_stages=None,
                    platform="xla",
                    backend="any",
                )
            )
    return configs


[docs]class RaggedPageAttentionv2(Kernel[RaggedPageAttentionv2Config, Array]): """Ragged Page Attention with custom optimization logic. Combines ragged (variable-length) sequence processing with paged KV cache management for maximum memory efficiency in serving workloads. Features: - Zero padding overhead through ragged layout - Efficient paged KV cache management - Support for variable context lengths per sequence - Sliding window attention for long contexts - Logit soft capping for numerical stability - Attention sink mechanism for improved long-context performance - Configurable block sizes and memory limits - Multiple platform support (Triton/Pallas/CUDA/XLA) This implementation is particularly efficient for: - LLM serving with dynamic batching - Variable-length inference workloads - Memory-constrained deployment - Scenarios requiring efficient KV cache sharing The ragged layout eliminates padding overhead while paged cache enables flexible memory management and sharing. """ def __init__(self): """Initialize Ragged Page Attention module. Sets up the kernel with the operation identifier for registry lookup and configuration management. """ super().__init__(op_id="ragged_page_attention_v2")
[docs] def create_shard_map_wrapper( self, queries: Float[Array, "total_tokens num_q_heads head_dim"], kv_pages: Float[Array, "num_pages page_size num_combined_kv_heads head_dim"], context_lens: Int[Array, "num_seqs"], block_tables: Int[Array, "num_seqs pages_per_seq"], query_start_loc: Int[Array, "num_seqs_plus_one"], num_seqs: Array | int, softmax_scale: float | None = None, logits_soft_cap: float | None = None, compute_dtype: DTypeLike = jnp.bfloat16, optimized: bool = False, sliding_window: int | None = None, softmax_aux: Float[Array, "num_q_heads"] | None = None, mask_value: float | None = None, vmem_limit_bytes: int | None = None, platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None, cfg: RaggedPageAttentionv2Config | None = None, mesh: Mesh | None = None, in_specs: tuple[PartitionSpec, ...] | None = None, out_specs: PartitionSpec | None = None, check_vma: bool = False, ): """Create a shard_map wrapper specifically for ragged page attention. Ragged page attention handles variable-length sequences with paged KV cache, ideal for serving scenarios. Args: queries: Flattened queries [total_tokens, num_q_heads, head_dim] kv_pages: Paged KV cache [num_pages, page_size, num_combined_kv_heads, head_dim] context_lens: Context lengths [num_seqs] block_tables: Block mapping [num_seqs, pages_per_seq] query_start_loc: Start locations [num_seqs + 1] num_seqs: Number of sequences All other args: Ragged page attention parameters to be fixed mesh: JAX device mesh in_specs: Input partition specs (for queries, kv_pages, context_lens, block_tables, query_start_loc, num_seqs, softmax_aux) out_specs: Output partition spec Returns: Tuple of (shard_map_fn, call_args) """ assert mesh is not None, "mesh must be provided for shard_map execution" assert in_specs is not None, "in_specs must be provided for shard_map execution" assert out_specs is not None, "out_specs must be provided for shard_map execution" def _wrapped_ragged_page_attn( queries: Float[Array, "total_tokens num_q_heads head_dim"], kv_pages: Float[Array, "num_pages page_size num_combined_kv_heads head_dim"], context_lens: Int[Array, "num_seqs"], block_tables: Int[Array, "num_seqs pages_per_seq"], query_start_loc: Int[Array, "num_seqs_plus_one"], num_seqs: Array | int, softmax_aux: Float[Array, "num_q_heads"] | None = None, ) -> Float[Array, "total_tokens num_q_heads head_dim"]: return self.run( queries=queries, kv_pages=kv_pages, context_lens=context_lens, block_tables=block_tables, query_start_loc=query_start_loc, num_seqs=num_seqs, softmax_scale=softmax_scale, logits_soft_cap=logits_soft_cap, compute_dtype=compute_dtype, optimized=optimized, sliding_window=sliding_window, softmax_aux=softmax_aux, mask_value=mask_value, vmem_limit_bytes=vmem_limit_bytes, platform=platform, cfg=cfg, ) call_args = ( queries, kv_pages, context_lens, block_tables, query_start_loc, num_seqs, softmax_aux, ) assert len(in_specs) == len(call_args), f"in_specs length {len(in_specs)} != call_args length {len(call_args)}" shard_map_fn = shard_map( _wrapped_ragged_page_attn, mesh=mesh, in_specs=in_specs, out_specs=out_specs, check_vma=check_vma, ) return shard_map_fn, call_args
[docs] def get_impl(self, cfg: RaggedPageAttentionv2Config): """Get kernel implementation from registry. Args: cfg: Configuration specifying platform and backend preferences Returns: Callable kernel implementation for ragged page attention Raises: ValueError: If no matching implementation is found for the configuration """ platform = detect_platform("ragged_page_attention_v2", cfg.platform) return kernel_registry.get("ragged_page_attention_v2", platform=platform, backend=cfg.backend)
[docs] def run( self, queries: Float[Array, "total_tokens num_q_heads head_dim"], kv_pages: Float[Array, "num_pages page_size num_combined_kv_heads head_dim"], context_lens: Int[Array, "num_seqs"], block_tables: Int[Array, "num_seqs pages_per_seq"], query_start_loc: Int[Array, "num_seqs_plus_one"], num_seqs: Array | int, platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None, softmax_scale: float | None = None, logits_soft_cap: float | None = None, compute_dtype: DTypeLike = jnp.bfloat16, optimized: bool = False, sliding_window: int | None = None, softmax_aux: Float[Array, "num_q_heads"] | None = None, mask_value: float | None = None, vmem_limit_bytes: int | None = None, *, cfg: RaggedPageAttentionv2Config, ) -> Float[Array, "total_tokens num_q_heads head_dim"]: """Execute ragged page attention over variable-length sequences. Computes attention where queries are in ragged (concatenated) format and KV cache is organized in pages, providing maximum memory efficiency. Args: queries: Ragged query tensor [total_tokens, num_q_heads, head_dim] All sequences concatenated without padding kv_pages: Paged KV cache [num_pages, page_size, num_combined_kv_heads, head_dim] Combined key-value cache in page format context_lens: Actual context length per sequence [num_seqs] block_tables: Page mapping [num_seqs, pages_per_seq] mapping logical pages to physical page indices query_start_loc: Start indices for each sequence in queries [num_seqs + 1] query_start_loc[i] to query_start_loc[i+1] defines sequence i num_seqs: Number of sequences in the batch softmax_scale: Scaling factor for attention scores (default: 1/sqrt(head_dim)) logits_soft_cap: Optional soft cap to bound attention logits compute_dtype: Data type for computation (default: bfloat16) optimized: Use optimized kernel implementation sliding_window: Window size for local attention (None for full attention) softmax_aux: Optional attention sink logits for long-context handling mask_value: Value to use for masked positions (default: -inf) vmem_limit_bytes: Memory limit for vector memory in bytes (TPU-specific) platform: Optional platform override ("triton", "pallas", "cuda", "xla") cfg: Kernel configuration object containing num_kv_pages_per_block and num_queries_per_block Returns: Attention output [total_tokens, num_q_heads, head_dim] in ragged format Note: The ragged format eliminates all padding overhead. Combined with paged KV cache, this provides the most memory-efficient attention implementation for serving workloads with variable-length sequences. Example: >>> >>> query_start_loc = jnp.array([0, 10, 25]) >>> out = ragged_page_attention_v2( ... queries, kv_pages, context_lens, ... block_tables, query_start_loc, num_seqs=2 ... ) """ if platform is not None: cfg = RaggedPageAttentionv2Config( num_kv_pages_per_block=cfg.num_kv_pages_per_block, num_queries_per_block=cfg.num_queries_per_block, num_warps=cfg.num_warps, num_stages=cfg.num_stages, platform=platform, backend=Backend.ANY if platform == "xla" else cfg.backend, ) impl = self.get_impl(cfg) return impl( queries=queries, kv_pages=kv_pages, context_lens=context_lens, block_tables=block_tables, query_start_loc=query_start_loc, num_seqs=num_seqs, softmax_scale=softmax_scale, logits_soft_cap=logits_soft_cap, compute_dtype=compute_dtype, optimized=optimized, sliding_window=sliding_window, softmax_aux=softmax_aux, mask_value=mask_value, num_kv_pages_per_block=cfg.num_kv_pages_per_block, num_queries_per_block=cfg.num_queries_per_block, vmem_limit_bytes=vmem_limit_bytes, num_warps=cfg.num_warps, num_stages=cfg.num_stages, )
[docs] def heuristic_cfg(self, inv: Invocation[RaggedPageAttentionv2Config, Array]) -> RaggedPageAttentionv2Config: """Provide default configuration optimized for ragged page attention. Args: inv: Invocation object containing arguments and metadata Returns: Default configuration with conservative block sizes suitable for typical ragged attention workloads with variable sequence lengths """ return RaggedPageAttentionv2Config( num_kv_pages_per_block=None, num_queries_per_block=None, num_warps=4, num_stages=1, platform="auto", backend="any", )
[docs] def candidate_cfgs(self, inv: Invocation[RaggedPageAttentionv2Config, Array]): """Generate candidate configurations for autotuning. Creates configurations optimized for ragged attention scenarios with various batch sizes and sequence lengths. Args: inv: Invocation object containing arguments and metadata Returns: List of candidate configurations to benchmark during autotuning Note: Ragged attention performance depends on the distribution of sequence lengths and the page size. Candidates are chosen to work well across common serving scenarios. """ block_configs = [ (None, None, None, None), (1, 64, None, None), (2, 128, None, None), ] candidates = [] for num_kv_pages, num_queries, num_warps, num_stages in block_configs: candidates.append( RaggedPageAttentionv2Config( num_kv_pages_per_block=num_kv_pages, num_queries_per_block=num_queries, num_warps=num_warps, num_stages=num_stages, platform="auto", backend="any", ) ) return candidates
[docs] def candidate_cfgs_gpu(self, inv: Invocation[RaggedPageAttentionv2Config, Array]): """Generate candidate configurations for autotuning on GPU (Triton). Heuristics: """ q = inv.kwargs["queries"] kv = inv.kwargs["kv_pages"] block_tables = inv.kwargs["block_tables"] _total_tokens, num_q_heads, head_dim = map(int, q.shape) page_size = int(kv.shape[1]) pages_per_seq = int(block_tables.shape[1]) combined_kv_heads = int(kv.shape[2]) assert combined_kv_heads % 2 == 0 num_kv_heads = combined_kv_heads // 2 assert num_q_heads % num_kv_heads == 0 if head_dim <= 64: m_opts = [32, 64, 128] elif head_dim <= 128: m_opts = [32, 64, 128] elif head_dim <= 192: m_opts = [32, 64] else: m_opts = [32, 64] if page_size <= 16: p_opts = [2, 4, 8] elif page_size <= 32: p_opts = [1, 2, 4] elif page_size <= 64: p_opts = [1, 2] else: p_opts = [1] max_S_block = 256 p_opts = [p for p in p_opts if p * page_size <= max_S_block] p_opts = [p for p in p_opts if p <= pages_per_seq] if not p_opts: p_opts = [min(2, pages_per_seq)] def pick_warps_stages(block_m: int, npages: int) -> tuple[int, int]: if head_dim <= 64: warps = 2 if block_m <= 64 else 4 elif head_dim <= 128: warps = 4 if block_m <= 64 else 8 else: warps = 4 if block_m <= 64 else 8 if npages >= 4: stages = 4 elif npages == 2: stages = 3 if page_size <= 32 else 2 else: stages = 2 return warps, stages high_value: list[tuple[int, int, int | None, int | None]] = [] hv_core = [(64, 2), (128, 2)] if page_size <= 32 and pages_per_seq >= 4: hv_core += [(64, 4)] if page_size <= 16 and pages_per_seq >= 8: hv_core += [(128, 4), (64, 8)] if page_size >= 64: hv_core += [(64, 1), (128, 1)] if head_dim >= 160: hv_core += [(32, 2), (64, 1)] seen_hv = set() for m, p in hv_core: if p in p_opts and m in m_opts and (m, p) not in seen_hv: w, s = pick_warps_stages(m, p) high_value.append((m, p, w, s)) seen_hv.add((m, p)) grid: list[tuple[int, int, int | None, int | None]] = [] for m in m_opts: for p in p_opts: if (m, p) in seen_hv: continue w, s = pick_warps_stages(m, p) grid.append((m, p, w, s)) block_configs: list[tuple[int, int, int | None, int | None]] = [] seen = set() for tup in high_value + grid: m, p, w, s = tup if (m, p) in seen: continue seen.add((m, p)) block_configs.append((m, p, w, s)) max_candidates = 18 block_configs = block_configs[:max_candidates] candidates = [ # noqa RaggedPageAttentionv2Config( num_kv_pages_per_block=npages, num_queries_per_block=block_m, num_warps=warps, num_stages=stages, platform="triton", backend="gpu", ) for (block_m, npages, warps, stages) in block_configs ] return _xla_block_candidates_v2(inv) or [ RaggedPageAttentionv2Config( num_kv_pages_per_block=None, num_queries_per_block=None, num_warps=None, num_stages=None, platform="xla", backend="any", ) ]
[docs] def candidate_cfgs_tpu(self, inv: Invocation[RaggedPageAttentionv2Config, Array]): """Generate candidate configurations for autotuning on TPU (Pallas backend). Heuristics: - For small head_dim, larger BLOCK_M is fine (64-128). - For large head_dim (>=160), prefer smaller BLOCK_M (32-64). - More KV pages per block helps small page_size (<=32). - Constrain S_block = page_size * num_kv_pages_per_block <= 256 to keep tiles reasonable. """ try: queries = inv.kwargs["queries"] block_tables = inv.kwargs["block_tables"] except KeyError: return [] total_tokens = int(getattr(queries, "shape", (0,))[0] or 0) pages_per_seq = int(getattr(block_tables, "shape", (0, 0))[1] or 0) if total_tokens <= 0 or pages_per_seq <= 0: return [] kv_candidates = [k for k in (16, 32, 64) if k <= pages_per_seq] if not kv_candidates: kv_candidates = [min(16, pages_per_seq)] q_candidates = [q for q in (4, 8, 16, 32, 64) if q <= total_tokens] if not q_candidates: q_candidates = [min(4, total_tokens)] configs = [] for kv in kv_candidates: for q in q_candidates: configs.append( RaggedPageAttentionv2Config( num_kv_pages_per_block=kv, num_queries_per_block=q, num_warps=None, num_stages=None, platform="pallas", backend="tpu", ) ) return configs
candidate_cfgs_shard_map_tpu = candidate_cfgs_tpu candidate_cfgs_shard_map_gpu = candidate_cfgs_gpu
_ragged_page_attention_executor: Executor[RaggedPageAttentionv2Config, Array] = Executor( ConfigSelectorChain( cache=ConfigCache(), policy=AutotunePolicy( allow_autotune=True, cache_miss_fallback=os.getenv("EJKERNEL_AUTOTUNE_POLICY", "autotune"), validate_backward=False, ), tuner=Tuner(warmup=5, iters=100), persistent=PersistentCache("ragged-page-attention"), ) )
[docs]def ragged_page_attention_v2( queries: Float[Array, "total_tokens num_q_heads head_dim"], kv_pages: Float[Array, "num_pages page_size num_combined_kv_heads head_dim"], context_lens: Int[Array, "num_seqs"], block_tables: Int[Array, "num_seqs pages_per_seq"], query_start_loc: Int[Array, "num_seqs_plus_one"], num_seqs: Array | int, softmax_aux: Float[Array, "num_q_heads"] | None = None, /, *, softmax_scale: float | None = None, logits_soft_cap: float | None = None, compute_dtype: DTypeLike = jnp.bfloat16, optimized: bool = False, sliding_window: int | None = None, mask_value: float | None = None, vmem_limit_bytes: int | None = None, platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None, cfg: RaggedPageAttentionv2Config | None = None, mesh: Mesh | None = None, in_specs: tuple[PartitionSpec | None, ...] | None = None, out_specs: PartitionSpec | None = None, ) -> Float[Array, "total_tokens num_q_heads head_dim"]: """Execute ragged page attention with automatic optimization. Ragged page attention efficiently handles variable-length sequences in a single batch using flattened token representation and page-based KV cache. Args: queries: Flattened query tensor [total_tokens, num_q_heads, head_dim] kv_pages: Paged KV cache [num_pages, page_size, num_combined_kv_heads, head_dim] context_lens: Context length per sequence [num_seqs] block_tables: Block mapping table [num_seqs, pages_per_seq] query_start_loc: Start locations for each sequence [num_seqs + 1] num_seqs: Number of sequences in the batch softmax_scale: Softmax scaling factor logits_soft_cap: Soft capping value for logits compute_dtype: Computation dtype (default: bfloat16) optimized: Use optimized implementation sliding_window: Sliding window size for local attention softmax_aux: Attention sink logits mask_value: Value for masked positions vmem_limit_bytes: Memory limit in bytes platform: Specific platform to use ("triton", "pallas", "cuda", or "xla") cfg: Optional config override (num_kv_pages_per_block and num_queries_per_block are set via cfg) mesh: JAX device mesh for shard_map execution (optional) in_specs: Input partition specs for shard_map (optional) out_specs: Output partition spec for shard_map (optional) Returns: Attention output [total_tokens, num_q_heads, head_dim] Example: >>> >>> out = ragged_page_attention_v2( ... queries, kv_pages, context_lens, block_tables, ... query_start_loc, num_seqs ... ) >>> >>> >>> out = ragged_page_attention_v2( ... queries, kv_pages, context_lens, block_tables, ... query_start_loc, num_seqs, sliding_window=256 ... ) >>> >>> >>> out = ragged_page_attention_v2( ... queries, kv_pages, context_lens, block_tables, ... query_start_loc, num_seqs, optimized=True, logits_soft_cap=50.0 ... ) >>> >>> >>> out = ragged_page_attention_v2(..., platform="triton") """ method = None if mesh is not None and in_specs is not None and out_specs is not None: method = "shard_map" return _ragged_page_attention_executor( RaggedPageAttentionv2(), queries=queries, kv_pages=kv_pages, context_lens=context_lens, block_tables=block_tables, query_start_loc=query_start_loc, num_seqs=num_seqs, softmax_scale=softmax_scale, logits_soft_cap=logits_soft_cap, compute_dtype=compute_dtype, optimized=optimized, sliding_window=sliding_window, softmax_aux=softmax_aux, mask_value=mask_value, vmem_limit_bytes=vmem_limit_bytes, platform=platform, method=method, mesh=mesh, in_specs=in_specs, out_specs=out_specs, _cfg=cfg, )