# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ragged Decode Attention module with automatic optimization.
This module implements ragged decode attention, an efficient attention mechanism
optimized for inference scenarios with variable-length sequences in the decode phase.
Unlike standard attention which requires padded sequences, ragged attention processes
sequences with different lengths efficiently by using sequence start/end markers.
Ragged decode attention is particularly valuable for:
- Inference workloads with batched sequences of varying lengths
- Decoder-only models during generation
- Serving scenarios requiring efficient batching
- Situations where padding overhead is significant
The key innovation is using sequence_start and sequence_end arrays to define
valid attention ranges per sequence, eliminating the need for padding while
maintaining efficient vectorized computation.
Key Features:
- Efficient variable-length sequence handling without padding
- Support for sliding window attention for long contexts
- Optional logit soft capping for numerical stability
- Attention sink support for improved long-context performance
- Configurable block sizes for memory-compute tradeoffs
Mathematical Foundation:
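For each sequence s in the batch (one decode query token per sequence) and each
query head:
output[s] = softmax(softmax_scale * Q[s] @ K[s, start[s]:end[s]].T) @ V[s, start[s]:end[s]]
where start[s] (inclusive) and end[s] (exclusive) bound the valid KV positions of
sequence s along its seq_len axis. Optional sliding windows, logit soft capping and
attention sinks further adjust this computation.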
"""
from __future__ import annotations
import math
import os
from typing import Literal
from jax import numpy as jnp
from jax import shard_map
from jax.sharding import Mesh, PartitionSpec
from jaxtyping import Array, Float, Int
from ejkernel.kernels._registry import Backend, kernel_registry
from ejkernel.ops import (
AutotunePolicy,
ConfigCache,
ConfigSelectorChain,
Executor,
FwdParams,
Invocation,
Kernel,
Tuner,
)
from ejkernel.ops.config.persistent import PersistentCache
from ..base import detect_platform
from .configs import RaggedDecodeAttentionConfig
class RaggedDecodeAttention(Kernel[RaggedDecodeAttentionConfig, Array]):
    """Ragged Decode Attention kernel module.

    Computes attention for the decode phase of inference. Each batch row has one
    query token (``query`` is ``[batch, num_heads, head_dim]``) and attends only
    to its own key/value range ``[sequence_start[b], sequence_end[b])``, so
    sequences of different lengths can share one batch without the caller
    building attention masks.

    Features:
        - Per-sequence start/end markers for variable-length KV ranges
        - Optional sliding window attention
        - Optional logit soft capping
        - Optional attention sink logits (``softmax_aux``)
        - Platform selection through the kernel registry
          ("triton", "pallas", "cuda", "xla" or "auto")
        - Tunable ``FwdParams`` (head block, KV splits, KV block, warps, stages)

    Typical uses: batched generation with varying prompt/generation lengths and
    serving workloads with dynamic batching.
    """

    def __init__(self):
        """Initialize the module with the ``"ragged_decode_attention"`` op id.

        The op id identifies this operation for registry lookup and
        configuration management in the ``Kernel`` base class.
        """
        super().__init__(op_id="ragged_decode_attention")

    def get_impl(self, cfg: RaggedDecodeAttentionConfig):
        """Look up the kernel implementation in the registry.

        Args:
            cfg: Configuration whose ``platform`` and ``backend`` select the
                implementation. ``detect_platform`` resolves ``cfg.platform``
                (which may be ``"auto"``) to a concrete platform first.

        Returns:
            Callable kernel implementation for ragged decode attention.

        Raises:
            Whatever ``kernel_registry.get`` raises when no implementation
            matches (documented as ``ValueError`` — confirm in the registry).
        """
        platform = detect_platform("ragged_decode_attention", cfg.platform, maybe_pallas=True)
        return kernel_registry.get("ragged_decode_attention", platform=platform, backend=cfg.backend)

    def run(
        self,
        query: Float[Array, "batch num_heads head_dim"],
        key: Float[Array, "batch seq_len num_kv_heads head_dim"],
        value: Float[Array, "batch seq_len num_kv_heads head_dim"],
        sequence_start: Int[Array, "batch"],
        sequence_end: Int[Array, "batch"],
        softmax_scale: float | None = None,
        sliding_window: tuple[int, int] | None = None,
        logits_soft_cap: float | None = None,
        softmax_aux: Float[Array, "num_sinks"] | None = None,
        platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
        *,
        cfg: RaggedDecodeAttentionConfig,
    ) -> Float[Array, "total_tokens num_q_heads head_dim"]:
        """Execute ragged decode attention with variable-length sequences.

        Each batch row's query attends only to the key/value positions in
        ``[sequence_start[b], sequence_end[b])`` of that row.

        Args:
            query: Query tensor [batch, num_heads, head_dim] (one decode step).
            key: Key tensor [batch, seq_len, num_kv_heads, head_dim].
            value: Value tensor [batch, seq_len, num_kv_heads, head_dim].
            sequence_start: Inclusive start of the valid KV range per row [batch].
            sequence_end: Exclusive end of the valid KV range per row [batch].
            softmax_scale: Multiplier for attention scores; ``None`` lets the
                backend implementation pick its default.
            sliding_window: Optional (left, right) window sizes for local attention.
            logits_soft_cap: Optional soft cap bounding the attention logits.
            softmax_aux: Optional attention sink logits.
            platform: Optional platform override. When given, a new config is
                built from ``cfg.fwd_params`` with this platform; ``"xla"``
                additionally forces ``Backend.ANY``.
            cfg: Kernel configuration; ``cfg.fwd_params`` (head block, KV splits,
                KV block, warps, stages) is passed to the implementation and
                ``cfg.platform`` / ``cfg.backend`` select it.

        Returns:
            Attention output produced by the selected implementation (one row per
            decode query, i.e. total_tokens == batch in the decode setting).

        Example:
            >>> # Row 0 attends to KV positions [0, 50), row 1 to [50, 150).
            >>> sequence_start = jnp.array([0, 50])
            >>> sequence_end = jnp.array([50, 150])
            >>> out = ragged_decode_attention(q, k, v, sequence_start, sequence_end)
        """
        if platform is not None:
            # Explicit platform override; "xla" is paired with Backend.ANY,
            # every other platform keeps the configured backend.
            cfg = RaggedDecodeAttentionConfig(
                fwd_params=cfg.fwd_params,
                platform=platform,
                backend=Backend.ANY if platform == "xla" else cfg.backend,
            )
        impl = self.get_impl(cfg)
        return impl(
            query=query,
            key=key,
            value=value,
            softmax_scale=softmax_scale,
            logits_soft_cap=logits_soft_cap,
            sliding_window=sliding_window,
            softmax_aux=softmax_aux,
            sequence_start=sequence_start,
            sequence_end=sequence_end,
            fwd_params=cfg.fwd_params,
        )

    def heuristic_cfg(
        self, inv: Invocation[RaggedDecodeAttentionConfig, Array] | None
    ) -> RaggedDecodeAttentionConfig:
        """Return the default (non-tuned) configuration.

        The invocation is not inspected, so ``None`` is accepted (the shard_map
        wrapper relies on this).

        Args:
            inv: Invocation object (unused).

        Returns:
            Config with blocksize_heads=16, num_key_splits=16, kv_blocksize=128,
            num_warps=4, num_stages=2 on platform "auto" / backend "any".
        """
        return RaggedDecodeAttentionConfig(
            fwd_params=FwdParams(
                blocksize_heads=16,
                num_key_splits=16,
                kv_blocksize=128,
                num_warps=4,
                num_stages=2,
            ),
            platform="auto",
            backend="any",
        )

    def candidate_cfgs_gpu(
        self, inv: Invocation[RaggedDecodeAttentionConfig, Array]
    ) -> list[RaggedDecodeAttentionConfig]:
        """GPU (Pallas) autotuning candidates for ragged decode attention.

        The search space is derived from the shapes in ``inv.kwargs``:

        - Splits: ``seq_len // 64`` splits of 64 tokens first (when ``seq_len``
          is a multiple of 64), then divisor-based splits whose length lies in
          [32, 8192], ranked by distance to {64, 128, 256, 512} (ties prefer
          longer splits).
        - ``blocksize_heads`` in {4, 8}, limited to values <= the number of
          query heads per KV head (fallback ``min(grouped_heads, 4)``).
        - ``kv_blocksize`` in {64, 128}; it must evenly divide the split length.
        - ``num_warps`` in {2, 4}, plus 8 when ``head_dim >= 128`` or the KV
          block is >= 128.
        - ``num_stages`` = 1 for KV blocks <= 64, otherwise {1, 2}.

        Candidates whose estimated shared-memory use exceeds
        ``EJKERNEL_TRITON_SMEM_LIMIT`` (default 99 KiB) are dropped, and at most
        ``EJKERNEL_RDA_MAX_CANDIDATES`` (default 32) are returned. If nothing
        survives, one fallback split with a 64/32/16-wide KV block is tried.
        All candidates use ``platform="pallas"`` and ``backend="gpu"``.

        Args:
            inv: Invocation whose kwargs hold ``query`` [batch, num_q_heads,
                head_dim] and ``key``/``value`` [batch, seq_len, num_kv_heads,
                head_dim].

        Returns:
            Candidate configurations (empty if even the fallback does not fit).
        """
        q = inv.kwargs["query"]
        k = inv.kwargs["key"]
        v = inv.kwargs["value"]
        seq_len = int(k.shape[1])
        num_q_heads = int(q.shape[1])
        num_kv_heads = int(k.shape[2])
        head_dim = int(q.shape[-1])
        assert num_kv_heads == int(v.shape[2])
        assert head_dim == int(k.shape[-1]) == int(v.shape[-1])
        assert num_q_heads % num_kv_heads == 0, "q_heads must be divisible by kv_heads"
        # Number of query heads sharing one KV head.
        grouped_heads = num_q_heads // num_kv_heads

        preferred_split_lens = (64, 128, 256, 512)

        def best_splits(n: int, targets=preferred_split_lens, min_len=32, max_len=8192):
            """Return (num_splits, split_len) pairs built from the divisors of ``n``.

            Only split lengths in [min_len, max_len] are kept; pairs are ordered
            by distance to the nearest target length, longer splits first on ties.
            """
            divs = set()
            for d in range(1, math.isqrt(n) + 1):
                if n % d == 0:
                    divs.update((d, n // d))
            valid = [(s, n // s) for s in sorted(divs) if min_len <= n // s <= max_len]

            def score(sl):
                return min(abs(sl - t) for t in targets)

            valid.sort(key=lambda x: (score(x[1]), -x[1]))
            return valid

        split_candidates = best_splits(seq_len)
        head_opts = [h for h in (4, 8) if h <= grouped_heads] or [min(grouped_heads, 4)]
        kv_block_opts = [64, 128]

        smem_limit = int(os.getenv("EJKERNEL_TRITON_SMEM_LIMIT", str(99 * 1024)))
        # Bytes per query element (2 for f16/bf16, 4 for f32, 8 for f64, ...).
        elem_bytes = jnp.dtype(q.dtype).itemsize

        def next_pow2_ge(x, min_val=16):
            return max(min_val, 1 << math.ceil(math.log2(max(1, x))))

        block_headdim = next_pow2_ge(head_dim, 16)

        def smem_est_bytes(block_heads: int, block_k: int, num_stages: int) -> int:
            # Heuristic estimate: K and V tiles plus a discounted Q tile, scaled
            # for pipelining stages beyond 2 and padded by a safety factor.
            kv_bytes = 2 * block_k * block_headdim * elem_bytes
            q_bytes = int(0.25 * block_heads * block_headdim * elem_bytes)
            stage_factor = 1.0 + 0.5 * max(0, num_stages - 2)
            fudge = 2.5
            return int((kv_bytes + q_bytes) * stage_factor * fudge)

        def warp_options(block_k: int) -> list[int]:
            opts = [2, 4]
            if head_dim >= 128 or block_k >= 128:
                opts.append(8)
            return opts

        def stage_options(block_k: int) -> list[int]:
            return [1] if block_k <= 64 else [1, 2]

        # Seed splits: exact 64-token splits first, then the best-ranked divisors.
        seeds = []
        if seq_len % 64 == 0:
            seeds.append((seq_len // 64, 64))
        for split in split_candidates[:6]:
            if split not in seeds:
                seeds.append(split)

        max_candidates = int(os.getenv("EJKERNEL_RDA_MAX_CANDIDATES", "32"))
        configs: list[RaggedDecodeAttentionConfig] = []
        seen = set()

        def try_add(H, K, s, sl):
            """Add every warp/stage variant of (H, K, s) that fits the smem budget.

            Returns True once ``max_candidates`` configs are collected, telling
            the caller to stop searching.
            """
            if K > sl or sl % K != 0:
                return False
            for W in warp_options(K):
                for S in stage_options(K):
                    if smem_est_bytes(H, K, S) > smem_limit:
                        continue
                    cfg_key = (H, K, s, W, S)
                    if cfg_key in seen:
                        continue
                    seen.add(cfg_key)
                    configs.append(
                        RaggedDecodeAttentionConfig(
                            fwd_params=FwdParams(
                                blocksize_heads=H,
                                num_key_splits=s,
                                kv_blocksize=K,
                                num_warps=W,
                                num_stages=S,
                            ),
                            platform="pallas",
                            backend="gpu",
                        )
                    )
                    if len(configs) >= max_candidates:
                        return True
            return False

        # Pass 1: one small-head-block, 64-wide-KV-block config family per seed.
        for s, sl in seeds:
            if try_add(4 if 4 in head_opts else head_opts[0], 64, s, sl):
                return configs
        # Pass 2: full head-block x KV-block product over the seeds.
        for s, sl in seeds:
            for H in head_opts:
                for K in kv_block_opts:
                    if try_add(H, K, s, sl):
                        return configs
        # Pass 3: widen to the 12 best-ranked splits.
        for s, sl in split_candidates[:12]:
            for H in head_opts:
                for K in kv_block_opts:
                    if try_add(H, K, s, sl):
                        return configs
        if not configs:
            # Last resort: ~64-token splits (or a single split when that does not
            # tile seq_len) with the widest of 64/32/16 dividing the split length.
            H = min(4, grouped_heads)
            s = max(1, seq_len // 64)
            s = s if seq_len % s == 0 else 1
            sl = seq_len // s
            K = 64 if sl % 64 == 0 else (32 if sl % 32 == 0 else 16)
            try_add(H, K, s, sl)
        return configs

    def candidate_cfgs(self, inv: Invocation[RaggedDecodeAttentionConfig, Array]):
        """Generate platform-agnostic candidate configurations for autotuning.

        Three candidates are produced, varying ``kv_blocksize`` with matching
        warps/stages: (128, 4 warps, 1 stage), (256, 4, 1) and (512, 8, 2).
        ``blocksize_heads`` and ``num_key_splits`` are fixed at 16, and every
        candidate uses platform "auto" / backend "any".

        Args:
            inv: Invocation object (unused).

        Returns:
            List of candidate configurations to benchmark during autotuning.
        """
        block_configs = [
            (128, 4, 1),
            (256, 4, 1),
            (512, 8, 2),
        ]
        return [
            RaggedDecodeAttentionConfig(
                fwd_params=FwdParams(
                    blocksize_heads=16,
                    num_key_splits=16,
                    kv_blocksize=block_size,
                    num_warps=num_warps,
                    num_stages=num_stages,
                ),
                platform="auto",
                backend="any",
            )
            for block_size, num_warps, num_stages in block_configs
        ]

    def create_shard_map_wrapper(
        self,
        query: Float[Array, "batch num_heads head_dim"],
        key: Float[Array, "batch seq_len num_kv_heads head_dim"],
        value: Float[Array, "batch seq_len num_kv_heads head_dim"],
        sequence_start: Int[Array, "batch"],
        sequence_end: Int[Array, "batch"],
        softmax_scale: float | None = None,
        sliding_window: tuple[int, int] | None = None,
        logits_soft_cap: float | None = None,
        softmax_aux: Float[Array, "num_sinks"] | None = None,
        platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
        *,
        cfg: RaggedDecodeAttentionConfig | None = None,
        mesh: Mesh | None = None,
        in_specs: tuple[PartitionSpec, ...] | None = None,
        out_specs: PartitionSpec | None = None,
        check_vma: bool = False,
    ):
        """Create a ``shard_map`` wrapper for distributed execution.

        The array arguments are sharded per ``in_specs``. The scalar/static
        options (``softmax_scale``, ``sliding_window``, ``logits_soft_cap``,
        ``platform``, ``cfg``) are captured in the closure.

        Args:
            query: Query tensor [batch, num_heads, head_dim].
            key: Key tensor [batch, seq_len, num_kv_heads, head_dim].
            value: Value tensor [batch, seq_len, num_kv_heads, head_dim].
            sequence_start: Inclusive start of the valid KV range per row [batch].
            sequence_end: Exclusive end of the valid KV range per row [batch].
            softmax_scale: Multiplier for attention scores (None: backend default).
            sliding_window: Optional (left, right) window sizes for local attention.
            logits_soft_cap: Optional soft cap bounding the attention logits.
            softmax_aux: Optional attention sink logits.
            platform: Platform override forwarded to ``run``.
            cfg: Kernel configuration; ``heuristic_cfg`` is used when None.
            mesh: JAX mesh for distributed execution (required).
            in_specs: Partition specs in call-argument order: (query, key, value,
                sequence_start, sequence_end, softmax_aux) (required).
            out_specs: Partition spec for the output tensor (required).
            check_vma: Forwarded to ``jax.shard_map``'s ``check_vma`` flag, which
                enables its varying-manual-axes consistency checks.

        Returns:
            Tuple ``(shard_map_fn, call_args)``; call ``shard_map_fn(*call_args)``.
        """
        assert mesh is not None, "mesh must be provided for shard_map execution"
        assert in_specs is not None, "in_specs must be provided for shard_map execution"
        assert out_specs is not None, "out_specs must be provided for shard_map execution"

        def _wrapper(
            query,
            key,
            value,
            sequence_start,
            sequence_end,
            softmax_aux,
        ):
            return self.run(
                query=query,
                key=key,
                value=value,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                softmax_scale=softmax_scale,
                sliding_window=sliding_window,
                logits_soft_cap=logits_soft_cap,
                softmax_aux=softmax_aux,
                platform=platform,
                cfg=cfg or self.heuristic_cfg(None),
            )

        shard_map_fn = shard_map(
            _wrapper,
            mesh=mesh,
            in_specs=in_specs,
            out_specs=out_specs,
            check_vma=check_vma,
        )
        # Order must match both _wrapper's parameters and in_specs.
        call_args = (
            query,
            key,
            value,
            sequence_start,
            sequence_end,
            softmax_aux,
        )
        return shard_map_fn, call_args
def _build_ragged_decode_attention_executor() -> Executor[RaggedDecodeAttentionConfig, Array]:
    """Assemble the module-level executor for ragged decode attention.

    The selector chain combines an in-memory ``ConfigCache``, an
    ``AutotunePolicy`` (autotuning allowed, ``validate_backward=False``, cache-miss
    fallback read from ``EJKERNEL_AUTOTUNE_POLICY`` with default ``"autotune"``),
    a ``Tuner`` configured with ``warmup=5, iters=100`` and a ``PersistentCache``
    namespaced ``"ragged-decode-attention"``.
    """
    # Objects are created in the same order the equivalent single nested
    # constructor expression would evaluate them.
    in_memory_cache = ConfigCache()
    autotune_policy = AutotunePolicy(
        allow_autotune=True,
        cache_miss_fallback=os.getenv("EJKERNEL_AUTOTUNE_POLICY", "autotune"),
        validate_backward=False,
    )
    benchmark_tuner = Tuner(warmup=5, iters=100)
    persistent_cache = PersistentCache("ragged-decode-attention", cfg_type=RaggedDecodeAttentionConfig)
    selector_chain = ConfigSelectorChain(
        cache=in_memory_cache,
        policy=autotune_policy,
        tuner=benchmark_tuner,
        persistent=persistent_cache,
    )
    return Executor(selector_chain)


_ragged_decode_attention_executor: Executor[RaggedDecodeAttentionConfig, Array] = (
    _build_ragged_decode_attention_executor()
)
def ragged_decode_attention(
    query: Float[Array, "batch num_heads head_dim"] | Float[Array, "batch 1 num_heads head_dim"],
    key: Float[Array, "batch seq_len num_kv_heads head_dim"],
    value: Float[Array, "batch seq_len num_kv_heads head_dim"],
    sequence_start: Int[Array, "batch"],
    sequence_end: Int[Array, "batch"],
    softmax_aux: Float[Array, "num_sinks"] | None = None,
    /,
    *,
    softmax_scale: float | None = None,
    sliding_window: tuple[int, int] | None = None,
    logits_soft_cap: float | None = None,
    platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
    cfg: RaggedDecodeAttentionConfig | None = None,
    mesh: Mesh | None = None,
    in_specs: tuple[PartitionSpec | None, ...] | None = None,
    out_specs: PartitionSpec | None = None,
) -> Float[Array, "total_tokens num_q_heads head_dim"]:
    """Execute ragged decode attention with automatic configuration selection.

    Each batch row's query attends only to the key/value positions in
    ``[sequence_start[b], sequence_end[b])`` of that row. Execution goes through
    the module-level executor, which picks a config (cache, persistent cache or
    autotuning per ``EJKERNEL_AUTOTUNE_POLICY``) unless ``cfg`` is supplied.

    Args:
        query: Query for the current decode step, either
            [batch, num_heads, head_dim] or [batch, 1, num_heads, head_dim].
            For the 4D form only the last position along axis 1 is used.
        key: Key context [batch, seq_len, num_kv_heads, head_dim].
        value: Value context [batch, seq_len, num_kv_heads, head_dim].
        sequence_start: Inclusive start of the valid KV range per row [batch].
        sequence_end: Exclusive end of the valid KV range per row [batch].
        softmax_aux: Optional attention sink logits (positional-only).
        softmax_scale: Multiplier for attention scores; ``None`` lets the backend
            implementation pick its default.
        sliding_window: Optional (left, right) window sizes for local attention.
        logits_soft_cap: Optional soft cap bounding the attention logits.
        platform: Platform override ("triton", "pallas", "cuda", "xla", "auto").
        cfg: Optional config override (``fwd_params``, ``platform``, ``backend``).
        mesh: Device mesh for ``shard_map`` execution.
        in_specs: Partition specs for (query, key, value, sequence_start,
            sequence_end, softmax_aux).
        out_specs: Partition spec for the output.
            ``shard_map`` is used only when mesh, in_specs and out_specs are all given.

    Returns:
        Attention output from the selected kernel, one row per decode query.
        When ``query`` was 4D a length-1 axis is re-inserted at position 1, giving
        [batch, 1, num_heads, head_dim].

    Example:
        >>> # Default: the executor selects the configuration.
        >>> out = ragged_decode_attention(q, k, v, starts, ends)
        >>> # Explicit configuration plus a sliding window.
        >>> cfg = RaggedDecodeAttentionConfig(
        ...     fwd_params=FwdParams(
        ...         blocksize_heads=16,
        ...         num_key_splits=16,
        ...         kv_blocksize=128,
        ...         num_warps=4,
        ...         num_stages=2,
        ...     ),
        ...     platform="auto",
        ...     backend="any",
        ... )
        >>> out = ragged_decode_attention(
        ...     q, k, v, starts, ends, sliding_window=(256, 256), cfg=cfg
        ... )
        >>> # Soft-capped logits with an explicit score scale.
        >>> out = ragged_decode_attention(
        ...     q, k, v, starts, ends, logits_soft_cap=50.0, softmax_scale=0.125
        ... )
        >>> # Force a specific platform.
        >>> out = ragged_decode_attention(q, k, v, starts, ends, platform="triton")

    Note:
        Intended for decode, where there is one query token per sequence and
        the KV length varies per sequence. For prefill with long queries,
        consider flash_attention instead.
    """
    # shard_map is requested only when the complete sharding triple is provided.
    method = None
    if mesh is not None and in_specs is not None and out_specs is not None:
        method = "shard_map"

    # The kernel works on a [batch, heads, dim] query. A 4D query is reduced to
    # its last position along axis 1, and the axis is restored on the output.
    was4d = query.ndim == 4
    if was4d:
        # NOTE(review): in_specs/out_specs are forwarded unchanged, so under
        # shard_map they presumably must describe the sliced 3D query/output.
        query = query[:, -1, :, :]

    out = _ragged_decode_attention_executor(
        RaggedDecodeAttention(),
        query=query,
        key=key,
        value=value,
        softmax_scale=softmax_scale,
        logits_soft_cap=logits_soft_cap,
        sliding_window=sliding_window,
        softmax_aux=softmax_aux,
        sequence_start=sequence_start,
        sequence_end=sequence_end,
        platform=platform,
        method=method,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=out_specs,
        _cfg=cfg,
    )

    if was4d:
        out = jnp.expand_dims(out, 1)
    return out