Source code for ejkernel.kernels._triton.recurrent._triton_impl_fwd

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import jax
import triton
import triton.language as tl
from jax import numpy as jnp
from jaxtyping import Array, Float, Int

from ejkernel.callib import cdiv, triton_call


@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [4, 8]],
    key=["key_chunk_size", "blocksize_v", "USE_G", "USE_G_GAMMA", "USE_GK", "USE_GV"],
)
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] != 1,
        "STORE_FINAL_STATE": lambda args: args["ht"] != 1,
        "IS_VARLEN": lambda args: args["cu_seqlens"] != 1,
    }
)
@triton.jit
def fwd_kernel(
    q,
    k,
    v,
    g,
    g_gamma,
    gk,
    gv,
    h0,
    cu_seqlens,
    softmax_scale,
    o,
    ht,
    T: tl.constexpr,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    key_chunk_size: tl.constexpr,
    blocksize_v: tl.constexpr,
    REVERSE: tl.constexpr,
    USE_G: tl.constexpr,
    USE_G_GAMMA: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Triton kernel for forward pass of recurrent linear attention.

    Processes sequences step-by-step with O(N) complexity, maintaining a
    hidden state that accumulates key-value information. Supports various
    gating mechanisms and both forward/reverse processing.

    Args:
        q, k, v: Query, key, value tensor pointers
        g: Optional gate tensor for GLA-style gating
        g_gamma: Optional decay factor for Lightning attention
        gk, gv: Optional gates applied to keys and values
        h0: Initial hidden state pointer
        cu_seqlens: Cumulative sequence lengths for variable-length mode
        softmax_scale: Query scaling factor
        o: Output tensor pointer
        ht: Final hidden state pointer
        T, B, H, K, V: Tensor dimensions (sequence, batch, heads, key/value dims)
        key_chunk_size, blocksize_v: Block sizes for tiling
        REVERSE: Process sequence in reverse order
        USE_G, USE_G_GAMMA, USE_GK, USE_GV: Gating configuration flags
        USE_INITIAL_STATE: Whether to use initial hidden state
        STORE_FINAL_STATE: Whether to store final hidden state
        IS_VARLEN: Variable-length sequence mode
    """
    i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        scope = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        scope = B * T

    o_k = i_k * key_chunk_size + tl.arange(0, key_chunk_size)
    o_v = i_v * blocksize_v + tl.arange(0, blocksize_v)
    p_q = q + (bos + ((T - 1) if REVERSE else 0)) * H * K + i_h * K + o_k
    p_k = k + (bos + ((T - 1) if REVERSE else 0)) * H * K + i_h * K + o_k
    p_v = v + (bos + ((T - 1) if REVERSE else 0)) * H * V + i_h * V + o_v
    p_o = o + ((i_k * scope + bos) + ((T - 1) if REVERSE else 0)) * H * V + i_h * V + o_v
    if USE_G:
        p_g = g + (bos + ((T - 1) if REVERSE else 0)) * H + i_h
    if USE_GK:
        p_gk = gk + (bos + ((T - 1) if REVERSE else 0)) * H * K + i_h * K + o_k
    if USE_GV:
        p_gv = gv + (bos + ((T - 1) if REVERSE else 0)) * H * V + i_h * V + o_v
    if USE_G_GAMMA:
        b_g_gamma = tl.load(g_gamma + i_h)

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]
    b_h = tl.zeros([key_chunk_size, blocksize_v], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * softmax_scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        if USE_G:
            b_g = tl.load(p_g).to(tl.float32)
            b_h = b_h * tl.exp(b_g)
        if USE_G_GAMMA:
            b_h = b_h * tl.exp(b_g_gamma)
        if USE_GK:
            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
            b_h = b_h * tl.exp(b_gk[:, None])
        if USE_GV:
            b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
            b_h = b_h * tl.exp(b_gv[None, :])
        b_h += b_k[:, None] * b_v[None, :]
        b_o = b_h * b_q[:, None]
        b_o = tl.sum(b_o, axis=0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        p_q += (-1 if REVERSE else 1) * H * K
        p_k += (-1 if REVERSE else 1) * H * K
        p_v += (-1 if REVERSE else 1) * H * V
        p_o += (-1 if REVERSE else 1) * H * V
        if USE_G:
            p_g += (-1 if REVERSE else 1) * H
        if USE_GK:
            p_gk += (-1 if REVERSE else 1) * H * K
        if USE_GV:
            p_gv += (-1 if REVERSE else 1) * H * V

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K * V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


[docs]def fwd_triton_impl( q: Float[Array, "batch seq_len num_heads head_dim"], k: Float[Array, "batch seq_len num_heads head_dim"], v: Float[Array, "batch seq_len num_heads head_dim"], g: Float[Array, "batch seq_len num_heads head_dim"] | None = None, g_gamma: Float[Array, "batch num_heads"] | None = None, gk: Float[Array, "batch seq_len num_heads head_dim"] | None = None, gv: Float[Array, "batch seq_len num_heads head_dim"] | None = None, softmax_scale: float | None = None, initial_state: Float[Array, "batch num_heads head_dim head_dim"] | None = None, reverse: bool = False, cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None, ) -> tuple[Float[Array, "batch seq_len num_heads head_dim"], Float[Array, "batch num_heads head_dim head_dim"]]: B, T, H, K, V = *k.shape, v.shape[-1] N = B if cu_seqlens is None else len(cu_seqlens) - 1 key_chunk_size, blocksize_v = min(K, 64), min(V, 64) NumKBlocks, NumVBlocks = cdiv(K, key_chunk_size), cdiv(V, blocksize_v) h0 = initial_state ht_shape = (N, H, K, V) out_shape = (NumKBlocks, *v.shape) grid = (NumVBlocks, NumKBlocks, N * H) metaparams = dict( T=T, B=B, H=H, K=K, V=V, key_chunk_size=key_chunk_size, blocksize_v=blocksize_v, USE_G=g is not None, USE_G_GAMMA=g_gamma is not None, USE_GK=gk is not None, USE_GV=gv is not None, REVERSE=reverse, ) out, ht = triton_call( q, k, v, g if g is not None else 1, g_gamma if g_gamma is not None else 1, gk if gk is not None else 1, gv if gv is not None else 1, h0 if h0 is not None else 1, cu_seqlens if cu_seqlens is not None else 1, softmax_scale if softmax_scale is not None else 1, kernel=fwd_kernel, out_shape=[ jax.ShapeDtypeStruct(out_shape, q.dtype), jax.ShapeDtypeStruct(ht_shape, jnp.float32), ], name="ejkernel::triton::recurrent_fwd", grid=grid, **metaparams, ) out = jnp.sum(out, axis=0) return out, ht