Source code for ejkernel.kernels._xla.recurrent._interface

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Recurrent attention interface for linear-time sequence processing.

This module provides the public API for recurrent linear attention with
O(N) complexity. Supports various gating mechanisms (GLA, Lightning)
and provides custom VJP for efficient gradient computation.
"""

from functools import partial

import jax
import jax.numpy as jnp
import jaxtyping
from beartype import beartype
from jaxtyping import Array, Float, Int

from ..._registry import Backend, Platform, kernel_registry
from ._xla_impl_bwd import _recurrent_attention_bwd
from ._xla_impl_fwd import _recurrent_attention_fwd, _recurrent_attention_varlen_fwd


@partial(jax.custom_vjp, nondiff_argnums=(4, 7, 9, 10))
def _recurrent_core(
    query: Float[Array, "batch seq_len num_heads head_dim"],
    key: Float[Array, "batch seq_len num_heads head_dim"],
    value: Float[Array, "batch seq_len num_heads head_dim"],
    g: Float[Array, "batch seq_len num_heads head_dim"] | None = None,
    g_gamma: Float[Array, "... num_heads"] | None = None,
    gk: Float[Array, "batch seq_len num_heads head_dim"] | None = None,
    gv: Float[Array, "batch seq_len num_heads head_dim"] | None = None,
    softmax_scale: float | None = None,
    initial_state: Float[Array, "... num_heads head_dim head_dim"] | None = None,
    reverse: bool = False,
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
) -> tuple[Float[Array, "batch seq_len num_heads head_dim"], Float[Array, "... num_heads head_dim head_dim"]]:
    """Primal recurrent attention, wrapped in ``jax.custom_vjp``.

    Dispatches to the variable-length forward kernel when ``cu_seqlens`` is
    supplied and to the fixed-length forward kernel otherwise. Optional
    arguments are forwarded untouched (``None`` stays ``None``).

    ``g_gamma``, ``softmax_scale``, ``reverse`` and ``cu_seqlens`` (positional
    indices 4, 7, 9 and 10) are declared non-differentiable.

    Returns:
        ``(output, final_state)`` on the fixed-length path; the intermediate
        hidden states produced by the kernel are discarded. On the
        variable-length path the kernel's return value is forwarded as-is
        (presumably the same pair -- confirm against the varlen kernel).
    """
    if cu_seqlens is None:
        # Fixed-length path: the middle element (per-step hidden states) is
        # only needed by the backward rule, so it is dropped here.
        out, _hidden_states, last_state = _recurrent_attention_fwd(
            query, key, value, g, g_gamma, gk, gv, softmax_scale, initial_state, reverse
        )
        return out, last_state

    # Packed variable-length path.
    return _recurrent_attention_varlen_fwd(
        query, key, value, cu_seqlens, g, g_gamma, gk, gv, softmax_scale, initial_state, reverse
    )


def _recurrent_fwd(
    query: Float[Array, "batch seq_len num_heads head_dim"],
    key: Float[Array, "batch seq_len num_heads head_dim"],
    value: Float[Array, "batch seq_len num_heads head_dim"],
    g: Float[Array, "batch seq_len num_heads head_dim"] | None = None,
    g_gamma: Float[Array, "... num_heads"] | None = None,
    gk: Float[Array, "batch seq_len num_heads head_dim"] | None = None,
    gv: Float[Array, "batch seq_len num_heads head_dim"] | None = None,
    softmax_scale: float | None = None,
    initial_state: Float[Array, "batch num_heads head_dim head_dim"] | None = None,
    reverse: bool = False,
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
) -> tuple[
    tuple[Float[Array, "batch seq_len num_heads head_dim"], Float[Array, "batch num_heads head_dim head_dim"]],
    tuple,
]:
    """Forward rule of the custom VJP: run the kernel and save residuals.

    Missing optional inputs are replaced by concrete defaults before the
    kernel is called, so the residual tuple always contains arrays:

    * ``softmax_scale`` -> ``1 / sqrt(query.shape[-1])`` as a float32 scalar.
    * ``g`` and ``gk``  -> zeros shaped like ``query``.
    * ``g_gamma``       -> zeros of shape ``(num_heads,)``.
    * ``gv``            -> zeros shaped like ``query`` but with the last
      dimension taken from ``value``.
    * ``initial_state`` -> zeros of shape
      ``(batch, num_heads, query_head_dim, value_head_dim)``.

    The value-side dimensions are read from ``value`` rather than ``query`` so
    the defaults stay well-formed when the two head dimensions differ; when
    they are equal the result is identical to shaping everything from
    ``query``.

    Returns:
        ``((output, final_state), residuals)`` where ``residuals`` is
        ``(query, key, value, g, g_gamma, gk, gv, hidden_states, initial_state)``.

    Raises:
        NotImplementedError: If ``cu_seqlens`` is provided; no custom backward
            exists for the variable-length path.
    """
    # Fail fast: nothing below is useful on the unsupported varlen path, so do
    # not allocate default tensors before raising.
    if cu_seqlens is not None:
        raise NotImplementedError("Custom backward for varlen not yet implemented")

    if softmax_scale is None:
        softmax_scale = 1.0 / jnp.sqrt(query.shape[-1]).astype(jnp.float32)
    if g is None:
        g = jnp.zeros_like(query)
    if g_gamma is None:
        g_gamma = jnp.zeros((query.shape[2],))
    if gk is None:
        gk = jnp.zeros_like(query)
    if gv is None:
        # Same dtype and leading dims as ``query``; last dim follows ``value``.
        gv = jnp.zeros_like(query, shape=(*query.shape[:3], value.shape[3]))
    if initial_state is None:
        # State is (key/query head dim) x (value head dim) per batch and head.
        initial_state = jnp.zeros((query.shape[0], query.shape[2], query.shape[3], value.shape[3]))

    o, hidden_states, final_state = _recurrent_attention_fwd(
        query, key, value, g, g_gamma, gk, gv, softmax_scale, initial_state, reverse
    )

    residuals = (query, key, value, g, g_gamma, gk, gv, hidden_states, initial_state)
    return (o, final_state), residuals


def _recurrent_bwd(
    g_gamma_nondiff: Float[Array, "... num_heads"],
    scale_nondiff: float | None,
    reverse: bool,
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None,
    residuals: tuple,
    grads: tuple,
) -> tuple:
    """Backward rule of the custom VJP.

    Args:
        g_gamma_nondiff: Non-differentiable ``g_gamma`` as passed by the
            caller. Not used here; the (defaulted) value stored in
            ``residuals`` is used instead.
        scale_nondiff: Caller-supplied softmax scale, or ``None`` to fall back
            to ``1 / sqrt(query.shape[-1])`` (float32).
        reverse: Forwarded to the backward kernel.
        cu_seqlens: Must be ``None``; the variable-length path is unsupported.
        residuals: Tuple saved by the forward rule:
            ``(query, key, value, g, g_gamma, gk, gv, hidden_states, initial_state)``.
        grads: ``(d_output, d_final_state)`` cotangents of the primal outputs.

    Returns:
        Cotangents for the seven differentiable inputs, in order:
        ``(dq, dk, dv, dg, dgk, dgv, dinitial_state)``.

    Raises:
        NotImplementedError: If ``cu_seqlens`` is not ``None``.
    """
    (
        query,
        key,
        value,
        g,
        g_gamma,
        gk,
        gv,
        hidden_states,
        initial_state,
    ) = residuals
    d_out, d_final_state = grads

    # Mirror the default scale used by the forward rule.
    softmax_scale = (
        scale_nondiff if scale_nondiff is not None else 1.0 / jnp.sqrt(query.shape[-1]).astype(jnp.float32)
    )

    if cu_seqlens is not None:
        raise NotImplementedError("Variable-length backward not yet implemented")

    cotangents = _recurrent_attention_bwd(
        query,
        key,
        value,
        g,
        g_gamma,
        gk,
        gv,
        hidden_states,
        d_out,
        d_final_state,
        softmax_scale,
        initial_state,
        reverse,
    )
    dq, dk, dv, dg, dgk, dgv, dinitial_state = cotangents

    return (dq, dk, dv, dg, dgk, dgv, dinitial_state)


# Attach the forward (residual-saving) and backward rules to the custom-VJP primal.
_recurrent_core.defvjp(_recurrent_fwd, _recurrent_bwd)


@kernel_registry.register("recurrent", Platform.XLA, Backend.ANY)
@jaxtyping.jaxtyped(typechecker=beartype)
def recurrent(
    query: Float[Array, "batch seq_len num_heads qk_head_dim"],
    key: Float[Array, "batch seq_len num_kv_heads qk_head_dim"],
    value: Float[Array, "batch seq_len num_kv_heads v_head_dim"],
    g: Float[Array, "batch seq_len num_heads qk_head_dim"] | None = None,
    g_gamma: Float[Array, "... num_heads"] | None = None,
    gk: Float[Array, "batch seq_len num_heads qk_head_dim"] | None = None,
    gv: Float[Array, "batch seq_len num_heads v_head_dim"] | None = None,
    softmax_scale: float | None = None,
    initial_state: Float[Array, "... num_heads qk_head_dim v_head_dim"] | None = None,
    reverse: bool = False,
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
) -> tuple[Float[Array, "batch seq_len num_heads v_head_dim"], Float[Array, "... num_heads qk_head_dim v_head_dim"]]:
    """Recurrent linear attention with O(N) complexity using JAX/XLA.

    This implements linear attention as a recurrent process, maintaining a
    hidden state that accumulates key-value information sequentially. Unlike
    standard O(N²) attention, this achieves O(N) complexity by processing
    tokens one at a time.

    The core update is:
        h_t = decay_t * h_{t-1} + k_t^T ⊗ v_t
        o_t = h_t @ q_t

    Supports various gating mechanisms for different attention variants.

    Args:
        query: Query tensor [batch, seq_len, num_heads, qk_head_dim].
        key: Key tensor [batch, seq_len, num_kv_heads, qk_head_dim].
        value: Value tensor [batch, seq_len, num_kv_heads, v_head_dim].
        g: Optional gate tensor for GLA-style gating
            [batch, seq_len, num_heads, qk_head_dim].
        g_gamma: Optional per-head decay factor [..., num_heads] for
            Lightning attention.
        gk: Optional gate applied to keys
            [batch, seq_len, num_heads, qk_head_dim].
        gv: Optional gate applied to values
            [batch, seq_len, num_heads, v_head_dim].
        softmax_scale: Query scaling factor. If None, defaults to
            1/sqrt(qk_head_dim).
        initial_state: Initial hidden state
            [..., num_heads, qk_head_dim, v_head_dim].
        reverse: If True, process sequence in reverse order.
        cu_seqlens: Cumulative sequence lengths for variable-length inputs
            [num_seqs + 1].

    Returns:
        Tuple of:
            - output: Attention output [batch, seq_len, num_heads, v_head_dim]
            - final_state: Final hidden state
              [..., num_heads, qk_head_dim, v_head_dim]

    Note:
        ``g_gamma``, ``softmax_scale``, ``reverse`` and ``cu_seqlens`` are
        treated as non-differentiable. The custom backward rule is not
        implemented for variable-length inputs: requesting gradients while
        ``cu_seqlens`` is set raises ``NotImplementedError``.

    Examples:
        >>> import jax.numpy as jnp
        >>> # Plain linear attention.
        >>> query = jnp.ones((2, 100, 8, 64))
        >>> key = jnp.ones((2, 100, 8, 64))
        >>> value = jnp.ones((2, 100, 8, 64))
        >>> output, final_state = recurrent(query, key, value)
        >>> output.shape
        (2, 100, 8, 64)

        >>> # GLA-style gating.
        >>> g = jnp.ones((2, 100, 8, 64))
        >>> output, final_state = recurrent(query, key, value, g=g)

        >>> # Lightning-style per-head decay.
        >>> g_gamma = -jnp.arange(8, dtype=jnp.float32) * 0.1
        >>> output, final_state = recurrent(query, key, value, g_gamma=g_gamma)

        >>> # Variable-length: three sequences of length 50 packed together.
        >>> # Inputs must stay 4-D to satisfy the type annotations; a leading
        >>> # batch dimension of 1 is assumed here (confirm with the varlen kernel).
        >>> query = jnp.ones((1, 150, 8, 64))
        >>> key = jnp.ones((1, 150, 8, 64))
        >>> value = jnp.ones((1, 150, 8, 64))
        >>> cu_seqlens = jnp.array([0, 50, 100, 150])
        >>> output, final_state = recurrent(query, key, value, cu_seqlens=cu_seqlens)
        >>> output.shape
        (1, 150, 8, 64)
        >>> final_state.shape
        (3, 8, 64, 64)
    """
    return _recurrent_core(query, key, value, g, g_gamma, gk, gv, softmax_scale, initial_state, reverse, cu_seqlens)