ejkernel.kernels._xla.kernel_delta_attention._interface
Kernel Delta Attention interface for linear-time attention with delta updates.
This module provides the public API for KDA (Kernel Delta Attention), a linear attention variant using delta rule updates for memory management. Supports chunked, recurrent, and single-step computation modes.
- ejkernel.kernels._xla.kernel_delta_attention._interface.kda(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], beta: Float[jaxlib._jax.Array, 'batch seq_len num_heads'], decay: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads'] | None = None, *, softmax_scale: float | None = None, chunk_size: int = 64, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim'] | None = None, use_qk_l2norm: bool = True, use_chunked: bool = True) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim']]
Kernel Delta Attention (KDA) linear attention using XLA backend.
KDA is a linear attention variant that maintains a key-value memory matrix and uses delta updates to store and retrieve information. It combines ideas from linear attention and delta networks, giving O(N) time in the sequence length N.
- The core recurrence is:
h_t = exp(decay_t) * h_{t-1} + k_t ⊗ (beta_t * (v_t - h_{t-1} @ k_t))
o_t = h_t @ q_t
- Where:
h_t is the [qk_head_dim, v_head_dim] memory matrix per head
exp(decay_t) controls memory retention (decay <= 0 for stability)
beta_t controls the learning rate for delta updates
The delta term (v_t - h_{t-1} @ k_t) computes what’s new in v_t
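The recurrence above can be sketched as a naive time scan for a single (batch, head) slice. This is an illustrative reference written directly from the formulas on this page, not the library's implementation; the function name `delta_recurrence_reference` is hypothetical, and `h_{t-1} @ k_t` is read here as a contraction over the `qk_head_dim` axis so that the result has shape `[v_head_dim]`.

```python
import jax
import jax.numpy as jnp


def delta_recurrence_reference(q, k, v, beta, decay, h0=None):
    """Naive scan over time for one (batch, head) slice (hypothetical helper).

    q, k: [seq_len, qk_head_dim]; v: [seq_len, v_head_dim];
    beta, decay: [seq_len]; h0: [qk_head_dim, v_head_dim] or None.
    """
    dk, dv = k.shape[-1], v.shape[-1]
    if h0 is None:
        h0 = jnp.zeros((dk, dv), dtype=v.dtype)

    def step(h_prev, inputs):
        q_t, k_t, v_t, beta_t, decay_t = inputs
        # What the memory currently returns for key k_t: shape [v_head_dim].
        v_pred = k_t @ h_prev
        # Delta term scaled by the per-token learning rate.
        delta = beta_t * (v_t - v_pred)
        # Decay the old memory, then add the rank-1 update k_t ⊗ delta.
        h_t = jnp.exp(decay_t) * h_prev + jnp.outer(k_t, delta)
        # Read out with the query.
        o_t = q_t @ h_t
        return h_t, o_t

    h_final, out = jax.lax.scan(step, h0, (q, k, v, beta, decay))
    return out, h_final


# Small usage example with random inputs.
rng = jax.random.PRNGKey(0)
T, dk, dv = 6, 4, 3
q = jax.random.normal(jax.random.fold_in(rng, 0), (T, dk))
k = jax.random.normal(jax.random.fold_in(rng, 1), (T, dk))
v = jax.random.normal(jax.random.fold_in(rng, 2), (T, dv))
beta = jax.nn.sigmoid(jax.random.normal(jax.random.fold_in(rng, 3), (T,)))
decay = -0.01 * jnp.abs(jax.random.normal(jax.random.fold_in(rng, 4), (T,)))
out, h_final = delta_recurrence_reference(q, k, v, beta, decay)
```

A scan like this is useful mainly as a correctness oracle when comparing against a faster path on small inputs.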
- Algorithm Modes:
Chunked (default): Parallel within chunks, sequential across chunks. Better throughput for training with moderate sequence lengths.
Recurrent: Pure sequential scan. Lower memory, good for inference.
Single-step: Optimized path when seq_len=1 with initial_state.
- Parameters
query – Query tensor for attention. Shape: [batch, seq_len, num_heads, qk_head_dim]
key – Key tensor for memory addressing. Shape: [batch, seq_len, num_heads, qk_head_dim]
value – Value tensor to store in memory. Shape: [batch, seq_len, num_heads, v_head_dim]
beta – Per-token learning rate for delta updates (typically in [0, 1]). Shape: [batch, seq_len, num_heads]
decay – Per-token decay for memory retention (should be <= 0). Shape: [batch, seq_len, num_heads]. If None, defaults to zeros (no decay, full retention).
softmax_scale – Scaling factor for queries. If None, uses head_dim^-0.5.
chunk_size – Block size for the chunked algorithm (default: 64).
initial_state – Optional initial memory state for incremental inference. Shape: [batch, num_heads, qk_head_dim, v_head_dim]
use_qk_l2norm – If True, apply L2 normalization to queries and keys before attention. Improves stability (default: True).
use_chunked – If True, use the chunked algorithm; otherwise use the recurrent scan. Chunked is faster for training, recurrent for long inference.
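As a rough illustration of what `use_qk_l2norm` and `softmax_scale` describe, the sketch below normalizes queries and keys along the head dimension and then scales the queries. The helper name `preprocess_qk`, the `eps` constant, and the order of the two steps are assumptions made for this example; the library's internals may differ.

```python
import jax
import jax.numpy as jnp


def preprocess_qk(query, key, softmax_scale=None, use_qk_l2norm=True, eps=1e-6):
    """Hypothetical preprocessing sketch for [batch, seq_len, num_heads, qk_head_dim] inputs."""
    if use_qk_l2norm:
        # L2-normalize along the last (head) axis; eps is an assumed stabilizer.
        query = query / jnp.sqrt(jnp.sum(query**2, axis=-1, keepdims=True) + eps)
        key = key / jnp.sqrt(jnp.sum(key**2, axis=-1, keepdims=True) + eps)
    if softmax_scale is None:
        # Default described in the parameter list: head_dim^-0.5.
        softmax_scale = query.shape[-1] ** -0.5
    return query * softmax_scale, key


rng = jax.random.PRNGKey(0)
q_raw = jax.random.normal(jax.random.fold_in(rng, 0), (2, 8, 4, 32))
k_raw = jax.random.normal(jax.random.fold_in(rng, 1), (2, 8, 4, 32))
q_scaled, k_normed = preprocess_qk(q_raw, k_raw)
```

With unit-norm keys, the rank-1 term in the delta update has a bounded effect on the memory, which is one plausible reading of the "improves stability" remark.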
- Returns
- output: Attention output
Shape: [batch, seq_len, num_heads, v_head_dim]
- final_state: Final memory state for incremental inference
Shape: [batch, num_heads, qk_head_dim, v_head_dim]
- Return type
Tuple of (output, final_state)
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import random
>>> from ejkernel.kernels._xla.kernel_delta_attention._interface import kernel_delta_attention
>>>
>>> # Basic usage
>>> batch, seq_len, num_heads, head_dim = 2, 64, 8, 32
>>> key = random.PRNGKey(0)
>>> q = random.normal(random.fold_in(key, 0), (batch, seq_len, num_heads, head_dim))
>>> k = random.normal(random.fold_in(key, 1), (batch, seq_len, num_heads, head_dim))
>>> v = random.normal(random.fold_in(key, 2), (batch, seq_len, num_heads, head_dim))
>>> beta = jax.nn.sigmoid(random.normal(random.fold_in(key, 3), (batch, seq_len, num_heads)))
>>> # Use -|x| so that decay stays <= 0, as the parameter description recommends
>>> decay = -0.01 * jnp.abs(random.normal(random.fold_in(key, 4), (batch, seq_len, num_heads)))
>>>
>>> output, state = kernel_delta_attention(q, k, v, beta, decay, chunk_size=16)
>>> output.shape
(2, 64, 8, 32)
>>>
>>> # Incremental inference
>>> q_new = random.normal(random.fold_in(key, 5), (batch, 1, num_heads, head_dim))
>>> k_new = random.normal(random.fold_in(key, 6), (batch, 1, num_heads, head_dim))
>>> v_new = random.normal(random.fold_in(key, 7), (batch, 1, num_heads, head_dim))
>>> beta_new = jax.nn.sigmoid(random.normal(random.fold_in(key, 8), (batch, 1, num_heads)))
>>> decay_new = -0.01 * jnp.abs(random.normal(random.fold_in(key, 9), (batch, 1, num_heads)))
>>>
>>> output_new, state_new = kernel_delta_attention(
...     q_new, k_new, v_new, beta_new, decay_new, initial_state=state
... )
References
Delta Networks: https://arxiv.org/abs/1612.04859
Linear Transformers: https://arxiv.org/abs/2006.16236
- ejkernel.kernels._xla.kernel_delta_attention._interface.kda_decay(gate: Float[jaxlib._jax.Array, 'batch seq_len num_heads'], A_log: Float[jaxlib._jax.Array, 'num_heads'], dt_bias: Float[jaxlib._jax.Array, 'num_heads']) → Float[jaxlib._jax.Array, 'batch seq_len num_heads']
Compute KDA per-token decay from gate, A_log, and dt_bias.
This function computes the decay term used in Kernel Delta Attention, following the Mamba-style discretization where decay controls how much of the previous state is retained.
- The computation is:
A = -exp(A_log)  # Ensure A is negative for stability
decay = A * softplus(gate + dt_bias)
- Parameters
gate – Gating signal from input projection. Shape: [batch, seq_len, num_heads]
A_log – Learnable log-scale decay parameter (typically initialized near 0). Shape: [num_heads]
dt_bias – Learnable bias added to gate before softplus. Shape: [num_heads]
- Returns
- Per-token decay values (always <= 0 for stable state decay)
Shape: [batch, seq_len, num_heads]
Example
>>> import jax.numpy as jnp
>>> from ejkernel.kernels._xla.kernel_delta_attention._interface import kda_decay
>>> gate = jnp.zeros((2, 10, 4))  # batch=2, seq_len=10, num_heads=4
>>> A_log = jnp.zeros((4,))
>>> dt_bias = jnp.zeros((4,))
>>> decay = kda_decay(gate, A_log, dt_bias)
>>> assert jnp.all(decay <= 0)  # Decay is always non-positive
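For intuition about the magnitudes involved, the two-line computation documented for `kda_decay` can be written out directly. The function below is a standalone sketch of that formula (the name `kda_decay_reference` is hypothetical), not a copy of the library source.

```python
import jax
import jax.numpy as jnp


def kda_decay_reference(gate, A_log, dt_bias):
    """Standalone sketch of decay = -exp(A_log) * softplus(gate + dt_bias)."""
    A = -jnp.exp(A_log)  # [num_heads], strictly negative
    # A and dt_bias broadcast over the leading [batch, seq_len] axes of gate.
    return A * jax.nn.softplus(gate + dt_bias)


gate = jnp.zeros((2, 10, 4))
A_log = jnp.zeros((4,))
dt_bias = jnp.zeros((4,))
decay = kda_decay_reference(gate, A_log, dt_bias)
retention = jnp.exp(decay)  # multiplicative factor applied to the previous state
```

With all-zero inputs, `softplus(0) = log 2` and `A = -1`, so every decay value is `-log 2` and the retention factor `exp(decay)` is 0.5: half of the previous state is kept at each step.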
- ejkernel.kernels._xla.kernel_delta_attention._interface.kernel_delta_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], beta: Float[jaxlib._jax.Array, 'batch seq_len num_heads'], decay: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads'] | None = None, *, softmax_scale: float | None = None, chunk_size: int = 64, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim'] | None = None, use_qk_l2norm: bool = True, use_chunked: bool = True) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim']]
Kernel Delta Attention (KDA) linear attention using XLA backend.
KDA is a linear attention variant that maintains a key-value memory matrix and uses delta updates to store and retrieve information. It combines ideas from linear attention and delta networks, giving O(N) time in the sequence length N.
- The core recurrence is:
h_t = exp(decay_t) * h_{t-1} + k_t ⊗ (beta_t * (v_t - h_{t-1} @ k_t))
o_t = h_t @ q_t
- Where:
h_t is the [qk_head_dim, v_head_dim] memory matrix per head
exp(decay_t) controls memory retention (decay <= 0 for stability)
beta_t controls the learning rate for delta updates
The delta term (v_t - h_{t-1} @ k_t) computes what’s new in v_t
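A tiny worked case may help show why the delta term matters. In the sketch below (the `delta_step` helper is hypothetical and follows the recurrence as written, with no decay), writing a second value under the same unit-norm key with beta = 1 replaces the stored value instead of adding to it, and beta = 0.5 moves the stored value halfway toward the new one.

```python
import jax.numpy as jnp


def delta_step(h_prev, k_t, v_t, beta_t, decay_t=0.0):
    """One update of the recurrence for a single head (illustrative only)."""
    # Error between the new value and what the memory already returns for k_t.
    delta = beta_t * (v_t - k_t @ h_prev)
    return jnp.exp(decay_t) * h_prev + jnp.outer(k_t, delta)


k = jnp.array([1.0, 0.0])            # unit-norm key, qk_head_dim = 2
v_old = jnp.array([1.0, 2.0, 3.0])   # v_head_dim = 3
v_new = jnp.array([-1.0, 0.0, 5.0])

h0 = jnp.zeros((2, 3))
h1 = delta_step(h0, k, v_old, beta_t=1.0)        # memory now maps k -> v_old
h2 = delta_step(h1, k, v_new, beta_t=1.0)        # full overwrite: k -> v_new
h2_half = delta_step(h1, k, v_new, beta_t=0.5)   # partial update

read_after_first = k @ h1    # equals v_old
read_after_second = k @ h2   # equals v_new
read_half = k @ h2_half      # midpoint of v_old and v_new
```

A purely additive linear-attention update would instead return `v_old + v_new` for the same key, which is the interference the delta rule is meant to avoid.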
- Algorithm Modes:
Chunked (default): Parallel within chunks, sequential across chunks. Better throughput for training with moderate sequence lengths.
Recurrent: Pure sequential scan. Lower memory, good for inference.
Single-step: Optimized path when seq_len=1 with initial_state.
- Parameters
query – Query tensor for attention. Shape: [batch, seq_len, num_heads, qk_head_dim]
key – Key tensor for memory addressing. Shape: [batch, seq_len, num_heads, qk_head_dim]
value – Value tensor to store in memory. Shape: [batch, seq_len, num_heads, v_head_dim]
beta – Per-token learning rate for delta updates (typically in [0, 1]). Shape: [batch, seq_len, num_heads]
decay – Per-token decay for memory retention (should be <= 0). Shape: [batch, seq_len, num_heads]. If None, defaults to zeros (no decay, full retention).
softmax_scale – Scaling factor for queries. If None, uses head_dim^-0.5.
chunk_size – Block size for the chunked algorithm (default: 64).
initial_state – Optional initial memory state for incremental inference. Shape: [batch, num_heads, qk_head_dim, v_head_dim]
use_qk_l2norm – If True, apply L2 normalization to queries and keys before attention. Improves stability (default: True).
use_chunked – If True, use the chunked algorithm; otherwise use the recurrent scan. Chunked is faster for training, recurrent for long inference.
- Returns
- output: Attention output
Shape: [batch, seq_len, num_heads, v_head_dim]
- final_state: Final memory state for incremental inference
Shape: [batch, num_heads, qk_head_dim, v_head_dim]
- Return type
Tuple of (output, final_state)
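The role of final_state in incremental inference can be checked with a naive batched scan: processing a sequence in two segments and passing the first segment's final state as the second segment's initial state should match a single pass. The function `kda_scan_reference` below is a hypothetical reference written from the recurrence on this page (no L2 normalization or query scaling inside), not the library kernel.

```python
import jax
import jax.numpy as jnp


def kda_scan_reference(q, k, v, beta, decay, initial_state=None):
    """Naive scan. q, k: [B, T, H, Dk]; v: [B, T, H, Dv]; beta, decay: [B, T, H]."""
    B, T, H, Dk = k.shape
    Dv = v.shape[-1]
    h0 = jnp.zeros((B, H, Dk, Dv), v.dtype) if initial_state is None else initial_state

    def step(h, xs):
        q_t, k_t, v_t, b_t, g_t = xs  # time axis already removed
        pred = jnp.einsum("bhk,bhkv->bhv", k_t, h)
        delta = b_t[..., None] * (v_t - pred)
        h = jnp.exp(g_t)[..., None, None] * h + jnp.einsum("bhk,bhv->bhkv", k_t, delta)
        o_t = jnp.einsum("bhk,bhkv->bhv", q_t, h)
        return h, o_t

    # Move time to the leading axis so lax.scan iterates over it.
    xs = tuple(jnp.moveaxis(x, 1, 0) for x in (q, k, v, beta, decay))
    h_final, out = jax.lax.scan(step, h0, xs)
    return jnp.moveaxis(out, 0, 1), h_final


rng = jax.random.PRNGKey(0)
B, T, H, Dk, Dv = 2, 8, 3, 4, 5
q = jax.random.normal(jax.random.fold_in(rng, 0), (B, T, H, Dk))
k = jax.random.normal(jax.random.fold_in(rng, 1), (B, T, H, Dk))
k = k / jnp.linalg.norm(k, axis=-1, keepdims=True)  # unit-norm keys
v = jax.random.normal(jax.random.fold_in(rng, 2), (B, T, H, Dv))
beta = jax.nn.sigmoid(jax.random.normal(jax.random.fold_in(rng, 3), (B, T, H)))
decay = -0.01 * jnp.abs(jax.random.normal(jax.random.fold_in(rng, 4), (B, T, H)))

# Single pass over the whole sequence.
out_full, state_full = kda_scan_reference(q, k, v, beta, decay)

# Two passes: first 5 tokens, then the remaining 3 with the carried state.
s = 5
out_a, state_a = kda_scan_reference(q[:, :s], k[:, :s], v[:, :s], beta[:, :s], decay[:, :s])
out_b, state_b = kda_scan_reference(
    q[:, s:], k[:, s:], v[:, s:], beta[:, s:], decay[:, s:], initial_state=state_a
)
out_split = jnp.concatenate([out_a, out_b], axis=1)
```

The same pattern, with a second segment of length 1, corresponds to the single-token decoding case described under the algorithm modes.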
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import random
>>> from ejkernel.kernels._xla.kernel_delta_attention._interface import kernel_delta_attention
>>>
>>> # Basic usage
>>> batch, seq_len, num_heads, head_dim = 2, 64, 8, 32
>>> key = random.PRNGKey(0)
>>> q = random.normal(random.fold_in(key, 0), (batch, seq_len, num_heads, head_dim))
>>> k = random.normal(random.fold_in(key, 1), (batch, seq_len, num_heads, head_dim))
>>> v = random.normal(random.fold_in(key, 2), (batch, seq_len, num_heads, head_dim))
>>> beta = jax.nn.sigmoid(random.normal(random.fold_in(key, 3), (batch, seq_len, num_heads)))
>>> # Use -|x| so that decay stays <= 0, as the parameter description recommends
>>> decay = -0.01 * jnp.abs(random.normal(random.fold_in(key, 4), (batch, seq_len, num_heads)))
>>>
>>> output, state = kernel_delta_attention(q, k, v, beta, decay, chunk_size=16)
>>> output.shape
(2, 64, 8, 32)
>>>
>>> # Incremental inference
>>> q_new = random.normal(random.fold_in(key, 5), (batch, 1, num_heads, head_dim))
>>> k_new = random.normal(random.fold_in(key, 6), (batch, 1, num_heads, head_dim))
>>> v_new = random.normal(random.fold_in(key, 7), (batch, 1, num_heads, head_dim))
>>> beta_new = jax.nn.sigmoid(random.normal(random.fold_in(key, 8), (batch, 1, num_heads)))
>>> decay_new = -0.01 * jnp.abs(random.normal(random.fold_in(key, 9), (batch, 1, num_heads)))
>>>
>>> output_new, state_new = kernel_delta_attention(
...     q_new, k_new, v_new, beta_new, decay_new, initial_state=state
... )
References
Delta Networks: https://arxiv.org/abs/1612.04859
Linear Transformers: https://arxiv.org/abs/2006.16236