# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Recurrent Attention module with automatic optimization.
This module implements recurrent-style attention mechanisms that maintain and update
hidden states across sequence positions. Unlike standard attention which computes
all positions independently, recurrent attention processes sequences sequentially
with stateful computation.
Features:
- Stateful attention with initial_state support
- Separate gating for queries (g), keys (gk), and values (gv)
- Layer-wise gating control via g_gamma
- Bidirectional processing support (forward and reverse)
- Variable-length sequence handling
This is particularly useful for:
- Linear-time attention mechanisms
- Models requiring sequential dependency modeling
- Architectures with explicit state propagation
- Efficient inference with incremental state updates
"""
from __future__ import annotations
import os
import typing
from typing import Literal
from jaxtyping import Array, Float, Int
from ejkernel.kernels._registry import Backend, kernel_registry
from ejkernel.ops import (
AutotunePolicy,
ConfigCache,
ConfigSelectorChain,
Executor,
Invocation,
Kernel,
Tuner,
)
from ejkernel.ops.config.persistent import PersistentCache
from ..base import detect_platform
from .configs import RecurrentAttentionConfig
class RecurrentAttention(Kernel[RecurrentAttentionConfig, Array]):
    """Recurrent Attention with custom optimization logic.

    Implements attention with recurrent state updates, enabling linear-time complexity
    for certain attention patterns. The mechanism maintains a hidden state that is
    updated at each sequence position.

    Features:
        - Stateful computation with hidden state propagation
        - Multiple gating mechanisms (g, gk, gv, g_gamma)
        - Forward and reverse processing modes
        - Support for initial states
        - Variable-length sequence handling via cu_seqlens
        - Platform selection among triton / pallas / cuda / xla (or "auto")

    The gating mechanisms provide fine-grained control:
        - g: Query-level gates
        - gk: Key-level gates
        - gv: Value-level gates
        - g_gamma: Layer-level gates

    The actual computation is delegated to whatever implementation the kernel
    registry returns for the ``"recurrent"`` operation; this class only selects
    that implementation and supplies configurations for the executor.
    """

    def __init__(self):
        """Initialize the kernel under the ``"recurrent"`` operation id.

        The op id is what the base ``Kernel`` uses for registry lookup and
        configuration management.
        """
        super().__init__(op_id="recurrent")

    def get_impl(self, cfg: RecurrentAttentionConfig):
        """Look up the registered implementation matching ``cfg``.

        Args:
            cfg: Configuration whose ``platform`` (resolved through
                ``detect_platform``) and ``backend`` select the implementation.

        Returns:
            The callable returned by ``kernel_registry.get`` for the
            ``"recurrent"`` operation.

        Note:
            What happens when no implementation matches is decided by
            ``detect_platform`` / ``kernel_registry.get``; this method adds no
            error handling of its own.
        """
        platform = detect_platform("recurrent", cfg.platform)
        return kernel_registry.get("recurrent", platform=platform, backend=cfg.backend)

    def run(
        self,
        query: Float[Array, "batch seq_len num_heads head_dim"],
        key: Float[Array, "batch seq_len num_kv_heads head_dim"],
        value: Float[Array, "batch seq_len num_kv_heads head_dim"],
        g: Float[Array, "batch seq_len num_heads head_dim"] | None = None,
        g_gamma: Float[Array, "batch num_heads"] | None = None,
        gk: Float[Array, "batch seq_len num_heads head_dim"] | None = None,
        gv: Float[Array, "batch seq_len num_heads head_dim"] | None = None,
        softmax_scale: float | None = None,
        initial_state: Float[Array, "batch num_heads head_dim head_dim"] | None = None,
        reverse: bool = False,
        cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
        return_state: bool = False,
        platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
        *,
        cfg: RecurrentAttentionConfig,
    ) -> (
        Float[Array, "batch seq_len num_heads head_dim"]
        | tuple[Float[Array, "batch seq_len num_heads head_dim"], Float[Array, "batch num_heads head_dim head_dim"]]
    ):
        """Execute recurrent attention with stateful computation.

        Args:
            query: Query tensor [batch, seq_len, num_heads, head_dim]
            key: Key tensor [batch, seq_len, num_kv_heads, head_dim]
            value: Value tensor [batch, seq_len, num_kv_heads, head_dim]
            g: Query-level gating tensor [batch, seq_len, num_heads, head_dim]
            g_gamma: Layer-level gating parameter [batch, num_heads]
            gk: Key-level gating tensor [batch, seq_len, num_heads, head_dim]
            gv: Value-level gating tensor [batch, seq_len, num_heads, head_dim]
            softmax_scale: Optional scaling factor for attention scores
            initial_state: Initial hidden state [batch, num_heads, head_dim, head_dim]
            reverse: If True, process the sequence in reverse order
            cu_seqlens: Cumulative sequence lengths for variable-length sequences
            return_state: If True, return ``(output, final_state)`` instead of
                just the output
            platform: Optional platform override ("triton", "pallas", "cuda",
                "xla" or "auto"). When given, ``cfg`` is rebuilt with this
                platform, and the backend is forced to ``Backend.ANY`` for "xla".
            cfg: Kernel configuration object

        Returns:
            If return_state=False: attention output [batch, seq_len, num_heads, head_dim].
            If return_state=True: whatever tuple the implementation returned,
            expected to be ``(output, final_state)`` with final_state of shape
            [batch, num_heads, head_dim, head_dim]. If the implementation
            returns a single array, that array is returned as-is even when
            return_state=True.

        Note:
            All gating parameters (g, gk, gv, g_gamma) are optional and are
            forwarded unchanged to the selected implementation.
        """
        if platform is not None:
            # Rebuild the config so registry lookup uses the caller's platform
            # while keeping the tuned block/warp/stage settings.
            cfg = RecurrentAttentionConfig(
                block_q=cfg.block_q,
                block_k=cfg.block_k,
                block_d=cfg.block_d if hasattr(cfg, "block_d") else None,
                num_warps=cfg.num_warps,
                num_stages=cfg.num_stages,
                platform=platform,
                backend=Backend.ANY if platform == "xla" else cfg.backend,
            )

        impl = self.get_impl(cfg)
        result = impl(
            query=query,
            key=key,
            value=value,
            g=g,
            g_gamma=g_gamma,
            gk=gk,
            gv=gv,
            softmax_scale=softmax_scale,
            initial_state=initial_state,
            reverse=reverse,
            cu_seqlens=cu_seqlens,
        )

        # Implementations may hand back (output, final_state); strip the state
        # unless the caller explicitly asked for it.
        if isinstance(result, tuple) and not return_state:
            return result[0]
        return result

    def heuristic_cfg(self, inv: Invocation[RecurrentAttentionConfig, Array]) -> RecurrentAttentionConfig:
        """Return the fixed default configuration.

        Args:
            inv: Invocation object containing arguments and metadata (not
                inspected; the default does not depend on the input shapes).

        Returns:
            A configuration with 64-wide q/k/d blocks, 4 warps and 1 stage, and
            with platform "auto" and backend "any".
        """
        return RecurrentAttentionConfig(
            block_q=64,
            block_k=64,
            block_d=64,
            num_warps=4,
            num_stages=1,
            platform="auto",
            backend="any",
        )

    def candidate_cfgs(self, inv: Invocation[RecurrentAttentionConfig, Array]):
        """Generate candidate configurations for autotuning.

        Args:
            inv: Invocation object containing arguments and metadata (not
                inspected; the candidate set is fixed).

        Returns:
            List of candidate configurations to benchmark during autotuning.
            Each candidate has platform "auto" and backend "any".

        Note:
            Recurrent attention performance is sensitive to state update patterns.
            Candidates are chosen to balance sequential processing efficiency.
        """
        # (block_q, block_k, block_d, num_warps, num_stages)
        block_configs = [
            (64, 64, 64, 4, 1),
            (128, 64, 64, 4, 2),
            (128, 128, 64, 8, 2),
        ]
        return [
            RecurrentAttentionConfig(
                block_q=block_q,
                block_k=block_k,
                block_d=block_d,
                num_warps=num_warps,
                num_stages=num_stages,
                platform="auto",
                backend="any",
            )
            for block_q, block_k, block_d, num_warps, num_stages in block_configs
        ]
# Module-level executor shared by every ``recurrent_attention`` call (it is the
# object invoked in that function's body). It is constructed once, when this
# module is imported, so the EJKERNEL_AUTOTUNE_POLICY environment variable is
# read only at import time; changing it later has no effect on this executor.
_recurrent_executor: Executor[RecurrentAttentionConfig, Array] = Executor(
ConfigSelectorChain(
cache=ConfigCache(),
policy=AutotunePolicy(
allow_autotune=True,
# Fallback used on a config-cache miss; defaults to "autotune" when the
# environment variable is unset.
cache_miss_fallback=os.getenv("EJKERNEL_AUTOTUNE_POLICY", "autotune"),
# NOTE(review): presumably also validates the backward pass of candidate
# configs during tuning — confirm against AutotunePolicy.
validate_backward=True,
),
# Presumably 5 warmup runs then 100 timed iterations per candidate — confirm
# against Tuner.
tuner=Tuner(warmup=5, iters=100),
# Persistent config cache namespaced under "recurrent" (storage location is
# defined by PersistentCache, not visible here).
persistent=PersistentCache("recurrent"),
)
)
def recurrent_attention(
    query: Float[Array, "batch seq_len num_heads head_dim"],
    key: Float[Array, "batch seq_len num_kv_heads head_dim"],
    value: Float[Array, "batch seq_len num_kv_heads head_dim"],
    /,
    g: Float[Array, "batch seq_len num_heads head_dim"] | None = None,
    g_gamma: Float[Array, "batch num_heads"] | None = None,
    gk: Float[Array, "batch seq_len num_heads head_dim"] | None = None,
    gv: Float[Array, "batch seq_len num_heads head_dim"] | None = None,
    initial_state: Float[Array, "batch num_heads head_dim head_dim"] | None = None,
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
    *,
    softmax_scale: float | None = None,
    reverse: bool = False,
    return_state: bool = False,
    platform: typing.Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
    cfg: RecurrentAttentionConfig | None = None,
) -> (
    Float[Array, "batch seq_len num_heads head_dim"]
    | tuple[Float[Array, "batch seq_len num_heads head_dim"], Float[Array, "batch num_heads head_dim head_dim"]]
):
    """Execute recurrent attention with automatic optimization.

    Recurrent attention processes sequences with stateful computation,
    maintaining hidden states across timesteps for efficient sequential processing.

    ``query``, ``key`` and ``value`` are positional-only. The optional tensors
    (``g`` through ``cu_seqlens``) may be passed either positionally, in the
    order listed, or by keyword. The remaining options are keyword-only.

    Args:
        query: Query tensor [batch, seq_len, num_heads, head_dim]
        key: Key tensor [batch, seq_len, num_kv_heads, head_dim]
        value: Value tensor [batch, seq_len, num_kv_heads, head_dim]
        g: Gating tensor for query [batch, seq_len, num_heads, head_dim]
        g_gamma: Gating gamma [batch, num_heads]
        gk: Gating tensor for keys [batch, seq_len, num_heads, head_dim]
        gv: Gating tensor for values [batch, seq_len, num_heads, head_dim]
        initial_state: Initial hidden state [batch, num_heads, head_dim, head_dim]
        cu_seqlens: Cumulative sequence lengths for variable-length sequences
        softmax_scale: Scaling factor for attention
        reverse: Whether to process the sequence in reverse
        return_state: If True, return tuple (output, final_state) instead of just output
        platform: Specific platform to use ("triton", "pallas", "cuda", "xla" or "auto")
        cfg: Optional explicit kernel configuration. When None, the shared
            executor is responsible for choosing one.

    Returns:
        If return_state=False: Attention output with same shape as query
        If return_state=True: Tuple of (output, final_state)

    Example:
        >>> out = recurrent_attention(query, key, value)
        >>> out = recurrent_attention(query, key, value, g=gates, gk=key_gates, gv=value_gates)
        >>> out = recurrent_attention(query, key, value, platform="xla")
    """
    return _recurrent_executor(
        RecurrentAttention(),
        query=query,
        key=key,
        value=value,
        g=g,
        g_gamma=g_gamma,
        gk=gk,
        gv=gv,
        softmax_scale=softmax_scale,
        initial_state=initial_state,
        reverse=reverse,
        cu_seqlens=cu_seqlens,
        return_state=return_state,
        platform=platform,
        _cfg=cfg,
    )