Source code for ejkernel.kernels._xla.kernel_delta_attention._interface

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Kernel Delta Attention interface for linear-time attention with delta updates.

This module provides the public API for KDA (Kernel Delta Attention), a linear
attention variant using delta rule updates for memory management. Supports
chunked, recurrent, and single-step computation modes.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
import jaxtyping
from beartype import beartype
from jaxtyping import Array, Float

from ..._registry import Backend, Platform, kernel_registry
from ._xla_impl_fwd import _chunk_kda_fwd, _recurrent_kda_fwd, _single_step_kda_fwd


def kda_decay(
    gate: Float[Array, "batch seq_len num_heads"],
    A_log: Float[Array, "num_heads"],
    dt_bias: Float[Array, "num_heads"],
) -> Float[Array, "batch seq_len num_heads"]:
    """Return the per-token KDA decay derived from ``gate``, ``A_log`` and ``dt_bias``.

    All three inputs are upcast to float32 and combined as::

        rate  = -exp(A_log)                  # per head, strictly negative
        decay = rate * softplus(gate + dt_bias)

    Since ``rate`` is negative and softplus is non-negative, the result is
    never positive for finite inputs.

    Args:
        gate: Gating signal, shape ``[batch, seq_len, num_heads]``.
        A_log: Per-head log-scale decay parameter, shape ``[num_heads]``.
        dt_bias: Per-head bias added to ``gate`` ahead of the softplus,
            shape ``[num_heads]``.

    Returns:
        float32 decay values of shape ``[batch, seq_len, num_heads]``.

    Example:
        >>> gate = jnp.zeros((2, 10, 4))  # batch=2, seq_len=10, num_heads=4
        >>> A_log = jnp.zeros((4,))
        >>> dt_bias = jnp.zeros((4,))
        >>> decay = kda_decay(gate, A_log, dt_bias)
        >>> assert jnp.all(decay <= 0)
    """
    compute_dtype = jnp.float32

    # Negating the exponential keeps every per-head rate below zero.
    neg_rate = -jnp.exp(A_log.astype(compute_dtype))

    # dt_bias ([num_heads]) broadcasts over the batch and sequence axes of gate.
    step = jax.nn.softplus(gate.astype(compute_dtype) + dt_bias.astype(compute_dtype))

    # Lift the per-head rate to [1, 1, num_heads] so it lines up with step.
    return neg_rate[None, None, :] * step
@kernel_registry.register("kda", Platform.XLA, Backend.ANY)
@kernel_registry.register("kernel_delta_attention", Platform.XLA, Backend.ANY)
@jaxtyping.jaxtyped(typechecker=beartype)
def kernel_delta_attention(
    query: Float[Array, "batch seq_len num_heads qk_head_dim"],
    key: Float[Array, "batch seq_len num_heads qk_head_dim"],
    value: Float[Array, "batch seq_len num_heads v_head_dim"],
    beta: Float[Array, "batch seq_len num_heads"],
    decay: Float[Array, "batch seq_len num_heads"] | None = None,
    *,
    softmax_scale: float | None = None,
    chunk_size: int = 64,
    initial_state: Float[Array, "batch num_heads qk_head_dim v_head_dim"] | None = None,
    use_qk_l2norm: bool = True,
    use_chunked: bool = True,
) -> tuple[
    Float[Array, "batch seq_len num_heads v_head_dim"],
    Float[Array, "batch num_heads qk_head_dim v_head_dim"],
]:
    """Run Kernel Delta Attention (KDA) through the XLA implementation.

    KDA is a linear-attention variant that keeps a per-head key/value memory
    matrix and writes to it with delta-rule updates. The documented recurrence is::

        h_t = exp(decay_t) * h_{t-1} + k_t ⊗ (beta_t * (v_t - h_{t-1} @ k_t))
        o_t = h_t @ q_t

    where ``h_t`` is the memory matrix, ``exp(decay_t)`` sets how much of the
    previous memory is retained (``decay <= 0`` for stability), and ``beta_t``
    is the learning rate of the delta update.

    This function is a dispatcher: it moves the inputs from
    ``[batch, seq_len, num_heads, ...]`` to ``[batch, num_heads, seq_len, ...]``,
    selects one forward kernel, and moves the output back.

    Kernel selection, in priority order:
        1. Single-step: ``seq_len == 1`` and ``initial_state`` is provided.
           This takes precedence over ``use_chunked``.
        2. Chunked: ``use_chunked`` is True. Parallel within chunks and
           sequential across chunks.
        3. Recurrent: otherwise. A sequential scan over the sequence.

    Args:
        query: Queries, shape ``[batch, seq_len, num_heads, qk_head_dim]``.
        key: Keys used to address the memory,
            shape ``[batch, seq_len, num_heads, qk_head_dim]``.
        value: Values written to the memory,
            shape ``[batch, seq_len, num_heads, v_head_dim]``.
        beta: Per-token learning rate of the delta update (typically in
            ``[0, 1]``), shape ``[batch, seq_len, num_heads]``.
        decay: Per-token retention term (should be ``<= 0``),
            shape ``[batch, seq_len, num_heads]``. ``None`` is forwarded
            unchanged to the kernel and is documented to mean no decay.
        softmax_scale: Query scaling factor. When ``None`` it is set to
            ``qk_head_dim ** -0.5`` (taken from the last axis of ``query``).
        chunk_size: Block size handed to the chunked kernel; not used by
            the other two paths.
        initial_state: Optional starting memory state,
            shape ``[batch, num_heads, qk_head_dim, v_head_dim]``.
        use_qk_l2norm: Forwarded to the kernel; requests L2 normalization
            of queries and keys.
        use_chunked: Choose the chunked kernel (True) or the recurrent
            scan (False) whenever the single-step path does not apply.

    Returns:
        A ``(output, final_state)`` tuple:
            - output: shape ``[batch, seq_len, num_heads, v_head_dim]``.
            - final_state: shape ``[batch, num_heads, qk_head_dim, v_head_dim]``,
              suitable as ``initial_state`` for a later call.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from jax import random
        >>>
        >>> # Full-sequence call
        >>> batch, seq_len, num_heads, head_dim = 2, 64, 8, 32
        >>> rng = random.PRNGKey(0)
        >>> q = random.normal(random.fold_in(rng, 0), (batch, seq_len, num_heads, head_dim))
        >>> k = random.normal(random.fold_in(rng, 1), (batch, seq_len, num_heads, head_dim))
        >>> v = random.normal(random.fold_in(rng, 2), (batch, seq_len, num_heads, head_dim))
        >>> beta = jax.nn.sigmoid(random.normal(random.fold_in(rng, 3), (batch, seq_len, num_heads)))
        >>> decay = random.normal(random.fold_in(rng, 4), (batch, seq_len, num_heads)) * -0.01
        >>>
        >>> output, state = kernel_delta_attention(q, k, v, beta, decay, chunk_size=16)
        >>> output.shape
        (2, 64, 8, 32)
        >>>
        >>> # Incremental inference: one new token, resuming from `state`
        >>> q_new = random.normal(random.fold_in(rng, 5), (batch, 1, num_heads, head_dim))
        >>> k_new = random.normal(random.fold_in(rng, 6), (batch, 1, num_heads, head_dim))
        >>> v_new = random.normal(random.fold_in(rng, 7), (batch, 1, num_heads, head_dim))
        >>> beta_new = jax.nn.sigmoid(random.normal(random.fold_in(rng, 8), (batch, 1, num_heads)))
        >>> decay_new = random.normal(random.fold_in(rng, 9), (batch, 1, num_heads)) * -0.01
        >>>
        >>> output_new, state_new = kernel_delta_attention(
        ...     q_new, k_new, v_new, beta_new, decay_new, initial_state=state
        ... )

    References:
        - Delta Networks: https://arxiv.org/abs/1612.04859
        - Linear Transformers: https://arxiv.org/abs/2006.16236
    """

    def _swap_seq_and_heads(x):
        # Exchanges axes 1 and 2: [batch, seq, heads, ...] <-> [batch, heads, seq, ...].
        return jnp.swapaxes(x, 1, 2)

    scale = query.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale

    # Arguments every forward kernel receives, already in heads-first layout.
    shared_kwargs = dict(
        query=_swap_seq_and_heads(query),
        key=_swap_seq_and_heads(key),
        value=_swap_seq_and_heads(value),
        beta=_swap_seq_and_heads(beta),
        decay=None if decay is None else _swap_seq_and_heads(decay),
        softmax_scale=scale,
        use_qk_l2norm=use_qk_l2norm,
    )

    # The single-token fast path needs an existing state to update.
    is_single_step = initial_state is not None and query.shape[1] == 1

    if is_single_step:
        out, final_state = _single_step_kda_fwd(recurrent_state=initial_state, **shared_kwargs)
    elif use_chunked:
        out, final_state = _chunk_kda_fwd(
            chunk_size=chunk_size,
            initial_state=initial_state,
            **shared_kwargs,
        )
    else:
        out, final_state = _recurrent_kda_fwd(initial_state=initial_state, **shared_kwargs)

    # Restore the caller-facing [batch, seq_len, num_heads, v_head_dim] layout.
    return _swap_seq_and_heads(out), final_state
# Short public alias: `kda(...)` is the same callable as `kernel_delta_attention(...)`.
kda = kernel_delta_attention