Source code for ejkernel.modules.operations.kernel_delta_attention

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Kernel Delta Attention (KDA) operation module with automatic optimization.

This module provides the KernelDeltaAttention operation, a linear attention
variant that uses delta rule updates to maintain a key-value memory matrix.
KDA achieves O(N) complexity while supporting efficient incremental inference.

Key characteristics of KDA:
    - Memory matrix: [num_heads, head_dim, value_dim] per batch
    - Delta updates: Only stores what's new in each value
    - Decay mechanism: Controls memory retention over time
    - Beta parameter: Per-token learning rate for updates

The core recurrence:
    h_t = exp(decay_t) * h_{t-1} + k_t ⊗ (beta_t * (v_t - h_{t-1} @ k_t))
    o_t = h_t @ q_t

Algorithm modes:
    - Chunked (default): Parallel within chunks for training efficiency
    - Recurrent: Pure sequential scan for memory efficiency
    - Single-step: Optimized path for autoregressive generation

Features:
    - Automatic platform selection (XLA primary)
    - Configuration caching for consistent performance
    - L2 normalization option for stability
    - State passthrough for incremental inference

Example:
    >>> from ejkernel.modules.operations import kernel_delta_attention
    >>>
    >>> # Training forward pass
    >>> output = kernel_delta_attention(
    ...     query, key, value, beta, decay,
    ...     chunk_size=64, use_chunked=True,
    ... )
    >>>
    >>> # Inference with state
    >>> output, state = kernel_delta_attention(
    ...     query, key, value, beta, decay,
    ...     return_state=True,
    ... )
    >>> # Next token
    >>> output_new, state_new = kernel_delta_attention(
    ...     q_new, k_new, v_new, beta_new, decay_new,
    ...     initial_state=state, return_state=True,
    ... )

References:
    - Delta Networks: https://arxiv.org/abs/1612.04859
    - Linear Transformers: https://arxiv.org/abs/2006.16236
"""

from __future__ import annotations

import typing
from typing import Literal

from jaxtyping import Array, Float

from ejkernel.kernels._registry import Backend, kernel_registry
from ejkernel.ops import (
    AutotunePolicy,
    ConfigCache,
    ConfigSelectorChain,
    Executor,
    Invocation,
    Kernel,
    Tuner,
)
from ejkernel.ops.config.persistent import PersistentCache

from ..base import detect_platform
from .configs import KernelDeltaAttentionConfig


[docs]class KernelDeltaAttention(Kernel[KernelDeltaAttentionConfig, Array]): """Kernel Delta Attention (KDA) operation. A linear attention mechanism using delta rule updates to maintain an associative memory matrix. Supports both training (chunked) and inference (single-step) modes with O(N) complexity. The operation maintains a [head_dim, value_dim] memory matrix per head that stores key-value associations. The delta update mechanism ensures only novel information is added, improving memory efficiency. Attributes: op_id: Operation identifier for registry lookup ("kernel_delta_attention") """ def __init__(self): """Initialize KernelDeltaAttention operation. Sets up the kernel with operation identifier for registry lookup and configuration management. """ super().__init__(op_id="kernel_delta_attention")
[docs] def get_impl(self, cfg: KernelDeltaAttentionConfig): """Get kernel implementation from registry. Args: cfg: Configuration specifying platform and backend Returns: Callable kernel implementation for KDA Raises: ValueError: If no matching implementation is found """ platform = detect_platform("kernel_delta_attention", cfg.platform) return kernel_registry.get("kernel_delta_attention", platform=platform, backend=cfg.backend)
[docs] def run( self, query: Float[Array, "batch seq_len num_heads qk_head_dim"], key: Float[Array, "batch seq_len num_heads qk_head_dim"], value: Float[Array, "batch seq_len num_heads v_head_dim"], beta: Float[Array, "batch seq_len num_heads"], decay: Float[Array, "batch seq_len num_heads"] | None = None, initial_state: Float[Array, "batch num_heads qk_head_dim v_head_dim"] | None = None, /, *, softmax_scale: float | None = None, chunk_size: int = 64, use_qk_l2norm: bool = True, use_chunked: bool = True, return_state: bool = False, platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None, cfg: KernelDeltaAttentionConfig, ) -> ( Float[Array, "batch seq_len num_heads v_head_dim"] | tuple[ Float[Array, "batch seq_len num_heads v_head_dim"], Float[Array, "batch num_heads qk_head_dim v_head_dim"], ] ): """Execute Kernel Delta Attention operation. Args: query: Query tensor for attention Shape: [batch, seq_len, num_heads, qk_head_dim] key: Key tensor for memory addressing Shape: [batch, seq_len, num_heads, qk_head_dim] value: Value tensor to store in memory Shape: [batch, seq_len, num_heads, v_head_dim] beta: Per-token learning rate for delta updates Shape: [batch, seq_len, num_heads] decay: Per-token decay for memory retention (should be <= 0) Shape: [batch, seq_len, num_heads] initial_state: Optional initial memory state Shape: [batch, num_heads, qk_head_dim, v_head_dim] softmax_scale: Scaling factor for queries (default: head_dim^-0.5) chunk_size: Block size for chunked algorithm use_qk_l2norm: Apply L2 normalization to queries and keys use_chunked: Use chunked (True) or recurrent (False) algorithm return_state: If True, return (output, state) tuple platform: Override platform selection cfg: Kernel configuration object Returns: If return_state is False: output: Attention output [batch, seq_len, num_heads, v_head_dim] If return_state is True: Tuple of (output, final_state) where final_state has shape [batch, num_heads, qk_head_dim, v_head_dim] """ if platform is not None: cfg = KernelDeltaAttentionConfig( platform=platform, backend=Backend.ANY if platform == "xla" else cfg.backend, ) impl = self.get_impl(cfg) out, final_state = impl( query=query, key=key, value=value, beta=beta, decay=decay, softmax_scale=softmax_scale, chunk_size=chunk_size, initial_state=initial_state, use_qk_l2norm=use_qk_l2norm, use_chunked=use_chunked, ) if return_state: return out, final_state return out
[docs] def heuristic_cfg(self, inv: Invocation[KernelDeltaAttentionConfig, Array]) -> KernelDeltaAttentionConfig: """Provide default configuration based on heuristics. Args: inv: Invocation object (unused, config is static) Returns: Default configuration with auto platform selection """ del inv # Unused return KernelDeltaAttentionConfig(platform="auto", backend="any")
[docs] def candidate_cfgs(self, inv: Invocation[KernelDeltaAttentionConfig, Array]): """Generate candidate configurations for autotuning. Args: inv: Invocation object (unused) Returns: Empty list - KDA uses XLA without tunable block sizes Note: KDA currently has a single XLA implementation without tunable parameters, so autotuning is not applicable. """ del inv # Unused return []
_executor: Executor[KernelDeltaAttentionConfig, Array] = Executor( ConfigSelectorChain( cache=ConfigCache(), policy=AutotunePolicy(allow_autotune=True, cache_miss_fallback="heuristics", validate_backward=True), tuner=Tuner(warmup=5, iters=100), persistent=PersistentCache("kernel_delta_attention"), ) )
[docs]def kernel_delta_attention( query: Float[Array, "batch seq_len num_heads qk_head_dim"], key: Float[Array, "batch seq_len num_heads qk_head_dim"], value: Float[Array, "batch seq_len num_heads v_head_dim"], beta: Float[Array, "batch seq_len num_heads"], decay: Float[Array, "batch seq_len num_heads"] | None = None, initial_state: Float[Array, "batch num_heads qk_head_dim v_head_dim"] | None = None, /, *, softmax_scale: float | None = None, chunk_size: int = 64, use_qk_l2norm: bool = True, use_chunked: bool = True, return_state: bool = False, platform: typing.Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None, cfg: KernelDeltaAttentionConfig | None = None, ) -> ( Float[Array, "batch seq_len num_heads v_head_dim"] | tuple[ Float[Array, "batch seq_len num_heads v_head_dim"], Float[Array, "batch num_heads qk_head_dim v_head_dim"], ] ): """Execute Kernel Delta Attention (KDA) with automatic optimization. KDA is a linear attention variant that uses delta rule updates to maintain an associative key-value memory matrix. It provides O(N) complexity while supporting efficient stateful incremental inference. The core recurrence: h_t = exp(decay_t) * h_{t-1} + k_t ⊗ (beta_t * (v_t - h_{t-1} @ k_t)) o_t = h_t @ q_t Args: query: Query tensor for attention Shape: [batch, seq_len, num_heads, qk_head_dim] key: Key tensor for memory addressing Shape: [batch, seq_len, num_heads, qk_head_dim] value: Value tensor to store in memory Shape: [batch, seq_len, num_heads, v_head_dim] beta: Per-token learning rate for delta updates (typically in [0, 1]) Shape: [batch, seq_len, num_heads] decay: Per-token decay for memory retention (should be <= 0) Shape: [batch, seq_len, num_heads] If None, defaults to zeros (no decay, full retention) initial_state: Optional initial memory state for incremental inference Shape: [batch, num_heads, qk_head_dim, v_head_dim] softmax_scale: Scaling factor for queries. If None, uses head_dim^-0.5 chunk_size: Block size for chunked algorithm (default: 64) use_qk_l2norm: If True, apply L2 normalization to queries and keys before attention. Improves numerical stability (default: True) use_chunked: If True, use chunked algorithm (faster for training); else use recurrent scan (more memory efficient) return_state: If True, return (output, final_state) tuple platform: Specific platform to use ("triton", "pallas", "cuda", or "xla") cfg: Optional kernel configuration Returns: If return_state is False: output: Attention output [batch, seq_len, num_heads, v_head_dim] If return_state is True: Tuple of: - output: Attention output [batch, seq_len, num_heads, v_head_dim] - final_state: Final memory state for incremental inference [batch, num_heads, qk_head_dim, v_head_dim] Example: >>> import jax.numpy as jnp >>> from jax import random >>> >>> # Training forward pass >>> batch, seq_len, num_heads, head_dim = 2, 64, 8, 32 >>> key = random.PRNGKey(0) >>> q = random.normal(random.fold_in(key, 0), (batch, seq_len, num_heads, head_dim)) >>> k = random.normal(random.fold_in(key, 1), (batch, seq_len, num_heads, head_dim)) >>> v = random.normal(random.fold_in(key, 2), (batch, seq_len, num_heads, head_dim)) >>> beta = jax.nn.sigmoid(random.normal(random.fold_in(key, 3), (batch, seq_len, num_heads))) >>> decay = random.normal(random.fold_in(key, 4), (batch, seq_len, num_heads)) * -0.01 >>> >>> output = kernel_delta_attention(q, k, v, beta, decay) >>> output.shape (2, 64, 8, 32) >>> >>> # Inference with state >>> output, state = kernel_delta_attention(q, k, v, beta, decay, return_state=True) >>> >>> # Next token generation >>> q_new = random.normal(random.fold_in(key, 5), (batch, 1, num_heads, head_dim)) >>> k_new = random.normal(random.fold_in(key, 6), (batch, 1, num_heads, head_dim)) >>> v_new = random.normal(random.fold_in(key, 7), (batch, 1, num_heads, head_dim)) >>> beta_new = jax.nn.sigmoid(random.normal(random.fold_in(key, 8), (batch, 1, num_heads))) >>> decay_new = random.normal(random.fold_in(key, 9), (batch, 1, num_heads)) * -0.01 >>> >>> output_new, state_new = kernel_delta_attention( ... q_new, k_new, v_new, beta_new, decay_new, ... initial_state=state, return_state=True, ... ) """ return _executor( KernelDeltaAttention(), query=query, key=key, value=value, beta=beta, decay=decay, softmax_scale=softmax_scale, chunk_size=chunk_size, initial_state=initial_state, use_qk_l2norm=use_qk_l2norm, use_chunked=use_chunked, return_state=return_state, platform=platform, _cfg=cfg, )
kda_attention = kernel_delta_attention