ejkernel.kernels._xla.kernel_delta_attention._interface


Kernel Delta Attention interface for linear-time attention with delta updates.

This module provides the public API for KDA (Kernel Delta Attention), a linear attention variant using delta rule updates for memory management. Supports chunked, recurrent, and single-step computation modes.

ejkernel.kernels._xla.kernel_delta_attention._interface.kda(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], beta: Float[jaxlib._jax.Array, 'batch seq_len num_heads'], decay: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads'] | None = None, *, softmax_scale: float | None = None, chunk_size: int = 64, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim'] | None = None, use_qk_l2norm: bool = True, use_chunked: bool = True) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim']]

Kernel Delta Attention (KDA) linear attention using the XLA backend.

KDA is a linear attention variant that maintains a per-head key-value memory matrix and updates it with the delta rule: each write corrects what the memory already stores for the current key instead of simply accumulating onto it. It combines ideas from linear attention and delta networks, and its cost grows linearly (O(N)) with sequence length.

The core recurrence is:

$$h_t = \exp(\mathrm{decay}_t)\, h_{t-1} + k_t \otimes \big(\beta_t\,(v_t - h_{t-1}^\top k_t)\big)$$

$$o_t = h_t^\top q_t$$

Where:
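  • h_t is the [qk_head_dim, v_head_dim] memory matrix for each head

  • exp(decay_t) is the retention factor applied to the previous memory; keeping decay_t <= 0 bounds it to (0, 1] for stability

  • beta_t is the learning rate of the delta update

  • The delta term subtracts the memory's current read-out for key k_t from v_t, leaving only the part of v_t that h_{t-1} does not already store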
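To make the recurrence concrete, here is a minimal single-head reference scan in plain JAX. `kda_recurrent_reference` is a hypothetical helper, not part of ejkernel; it follows the formula literally, treating h_t as a [qk_head_dim, v_head_dim] matrix so that reading it with a key or query uses its transpose. It skips `use_qk_l2norm` and `softmax_scale` and is meant for checking small inputs, not for speed.

```python
import jax
import jax.numpy as jnp


def kda_recurrent_reference(q, k, v, beta, decay, h0=None):
    """Hypothetical single-head scan of the KDA recurrence (not ejkernel code).

    Shapes: q, k [seq_len, qk_head_dim]; v [seq_len, v_head_dim];
    beta, decay [seq_len]; h0 [qk_head_dim, v_head_dim] or None.
    No L2 normalization or query scaling is applied here.
    """
    dk, dv = q.shape[-1], v.shape[-1]
    if h0 is None:
        h0 = jnp.zeros((dk, dv), dtype=v.dtype)

    def step(h_prev, xs):
        q_t, k_t, v_t, b_t, g_t = xs
        # Delta term: what v_t adds beyond the memory's read-out for k_t.
        delta = v_t - h_prev.T @ k_t                                # [dv]
        h_t = jnp.exp(g_t) * h_prev + jnp.outer(k_t, b_t * delta)   # [dk, dv]
        o_t = h_t.T @ q_t                                           # [dv]
        return h_t, o_t

    h_final, o = jax.lax.scan(step, h0, (q, k, v, beta, decay))
    return o, h_final


# Usage: write the same unit-norm key/value pair twice with beta = 1, no decay.
q = jnp.array([[1.0, 0.0], [1.0, 0.0]])
k = jnp.array([[1.0, 0.0], [1.0, 0.0]])
v = jnp.array([[2.0, 3.0, 4.0], [2.0, 3.0, 4.0]])
o, h = kda_recurrent_reference(q, k, v, jnp.ones(2), jnp.zeros(2))
# On the second write, delta = v - h.T @ k = 0, so h stays outer(k, v).
```

Because the second write finds nothing new to store, the memory is unchanged after it; a plain additive linear-attention update would instead have doubled h.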

Algorithm Modes:
  • Chunked (default): Parallel within chunks, sequential across chunks. Better throughput for training with moderate sequence lengths.

  • Recurrent: Pure sequential scan. Lower memory, good for inference.

  • Single-step: Optimized path when seq_len=1 with initial_state.
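The modes are alternative schedules for the same recurrence. The sketch below, again a hypothetical pure-JAX scan rather than ejkernel's kernels, checks the property that chunked and single-step execution depend on: processing the sequence in pieces and carrying each piece's final state into the next (as `initial_state` does) reproduces one uninterrupted scan. It only illustrates the sequential-across-chunks part; the real chunked algorithm also parallelizes work inside each chunk, which this sketch does not attempt.

```python
import jax
import jax.numpy as jnp


def scan_head(q, k, v, beta, decay, h0):
    # Hypothetical single-head scan of the documented recurrence.
    def step(h, xs):
        q_t, k_t, v_t, b_t, g_t = xs
        h = jnp.exp(g_t) * h + jnp.outer(k_t, b_t * (v_t - h.T @ k_t))
        return h, h.T @ q_t

    h, o = jax.lax.scan(step, h0, (q, k, v, beta, decay))
    return o, h


def scan_in_chunks(q, k, v, beta, decay, h0, chunk_size):
    # Sequential across chunks: each chunk resumes from the previous final state.
    outs, h = [], h0
    for s in range(0, q.shape[0], chunk_size):
        sl = slice(s, s + chunk_size)
        o, h = scan_head(q[sl], k[sl], v[sl], beta[sl], decay[sl], h)
        outs.append(o)
    return jnp.concatenate(outs, axis=0), h


keys = jax.random.split(jax.random.PRNGKey(0), 5)
T, dk, dv = 10, 4, 3
q = jax.random.normal(keys[0], (T, dk))
k = jax.random.normal(keys[1], (T, dk))
k = k / jnp.linalg.norm(k, axis=-1, keepdims=True)       # unit-norm keys
v = jax.random.normal(keys[2], (T, dv))
beta = jax.nn.sigmoid(jax.random.normal(keys[3], (T,)))   # in (0, 1)
decay = -0.01 * jnp.abs(jax.random.normal(keys[4], (T,)))  # <= 0
h0 = jnp.zeros((dk, dv))

o_full, h_full = scan_head(q, k, v, beta, decay, h0)
o_chunk, h_chunk = scan_in_chunks(q, k, v, beta, decay, h0, chunk_size=4)
o_step, h_step = scan_in_chunks(q, k, v, beta, decay, h0, chunk_size=1)  # token by token
```

With `chunk_size=1` the loop degenerates into repeated single-token calls, which is the incremental-inference pattern; the last chunk here is shorter than the others, and the state carry still lines up.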

Parameters
  • query – Query tensor. Shape: [batch, seq_len, num_heads, qk_head_dim].

  • key – Key tensor used to address the memory. Shape: [batch, seq_len, num_heads, qk_head_dim].

  • value – Value tensor written into the memory. Shape: [batch, seq_len, num_heads, v_head_dim].

  • beta – Per-token learning rate for the delta update, typically in [0, 1]. Shape: [batch, seq_len, num_heads].

  • decay – Per-token log-retention for the memory; should be <= 0. Shape: [batch, seq_len, num_heads]. If None, defaults to zeros (no decay, full retention).

  • softmax_scale – Scaling factor applied to queries. If None, uses qk_head_dim ** -0.5.

  • chunk_size – Block size for the chunked algorithm (default: 64).

  • initial_state – Optional initial memory state for incremental inference, such as the final_state returned by a previous call. Shape: [batch, num_heads, qk_head_dim, v_head_dim].

  • use_qk_l2norm – If True, L2-normalize each query and key vector before attention, which improves numerical stability (default: True).

  • use_chunked – If True (default), use the chunked algorithm; otherwise use the recurrent scan. Chunked is faster for training; recurrent suits long inference.
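The two preprocessing options, `use_qk_l2norm` and `softmax_scale`, can be sketched as below. `preprocess_qk` is a hypothetical helper: the epsilon and the order (normalize first, then scale the queries) are assumptions, not taken from ejkernel's source.

```python
import jax.numpy as jnp


def preprocess_qk(q, k, softmax_scale=None, use_qk_l2norm=True, eps=1e-6):
    """Hypothetical sketch of the query/key preprocessing (not ejkernel code).

    q, k: [batch, seq_len, num_heads, qk_head_dim].
    """
    if use_qk_l2norm:
        # Normalize every query/key vector along its head dimension.
        q = q / jnp.sqrt(jnp.sum(q * q, axis=-1, keepdims=True) + eps)
        k = k / jnp.sqrt(jnp.sum(k * k, axis=-1, keepdims=True) + eps)
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5  # qk_head_dim ** -0.5
    # Only queries are scaled; keys keep unit norm for the delta update.
    return q * softmax_scale, k


q = jnp.ones((1, 2, 1, 4))
k = 2.0 * jnp.ones((1, 2, 1, 4))
q_n, k_n = preprocess_qk(q, k)  # default scale: 4 ** -0.5 = 0.5
q_s, _ = preprocess_qk(q, k, softmax_scale=2.0, use_qk_l2norm=False)
```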

Returns

  • output: Attention output

    Shape: [batch, seq_len, num_heads, v_head_dim]

  • final_state: Final memory state for incremental inference

    Shape: [batch, num_heads, qk_head_dim, v_head_dim]

Return type

tuple of (output, final_state)

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from jax import random
>>> from ejkernel.kernels._xla.kernel_delta_attention._interface import kernel_delta_attention
>>>
>>> # Basic usage
>>> batch, seq_len, num_heads, head_dim = 2, 64, 8, 32
>>> rng = random.PRNGKey(0)
>>> q = random.normal(random.fold_in(rng, 0), (batch, seq_len, num_heads, head_dim))
>>> k = random.normal(random.fold_in(rng, 1), (batch, seq_len, num_heads, head_dim))
>>> v = random.normal(random.fold_in(rng, 2), (batch, seq_len, num_heads, head_dim))
>>> beta = jax.nn.sigmoid(random.normal(random.fold_in(rng, 3), (batch, seq_len, num_heads)))
>>> # abs() keeps decay <= 0; a raw normal sample times -0.01 is positive about half the time
>>> decay = -0.01 * jnp.abs(random.normal(random.fold_in(rng, 4), (batch, seq_len, num_heads)))
>>>
>>> output, state = kernel_delta_attention(q, k, v, beta, decay, chunk_size=16)
>>> output.shape
(2, 64, 8, 32)
>>> state.shape
(2, 8, 32, 32)
>>>
>>> # Incremental inference: one new token, resuming from the previous final state
>>> q_new = random.normal(random.fold_in(rng, 5), (batch, 1, num_heads, head_dim))
>>> k_new = random.normal(random.fold_in(rng, 6), (batch, 1, num_heads, head_dim))
>>> v_new = random.normal(random.fold_in(rng, 7), (batch, 1, num_heads, head_dim))
>>> beta_new = jax.nn.sigmoid(random.normal(random.fold_in(rng, 8), (batch, 1, num_heads)))
>>> decay_new = -0.01 * jnp.abs(random.normal(random.fold_in(rng, 9), (batch, 1, num_heads)))
>>>
>>> output_new, state_new = kernel_delta_attention(
...     q_new, k_new, v_new, beta_new, decay_new, initial_state=state
... )
>>> output_new.shape
(2, 1, 8, 32)


ejkernel.kernels._xla.kernel_delta_attention._interface.kda_decay(gate: Float[jaxlib._jax.Array, 'batch seq_len num_heads'], A_log: Float[jaxlib._jax.Array, 'num_heads'], dt_bias: Float[jaxlib._jax.Array, 'num_heads']) → Float[jaxlib._jax.Array, 'batch seq_len num_heads']

Compute KDA per-token decay from gate, A_log, and dt_bias.

This function computes the decay term used in Kernel Delta Attention, following the Mamba-style discretization where decay controls how much of the previous state is retained.

The computation is:

$$A = -\exp(A_{\log}) \qquad \text{(so } A < 0 \text{ for stability)}$$

$$\mathrm{decay} = A \cdot \operatorname{softplus}(\mathrm{gate} + \mathrm{dt\_bias})$$
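The formula maps directly onto a couple of JAX operations. `kda_decay_sketch` is a hypothetical stand-in written only from the formula above, not ejkernel's implementation; `A_log` and `dt_bias` broadcast over the trailing `num_heads` axis of `gate`.

```python
import jax
import jax.numpy as jnp


def kda_decay_sketch(gate, A_log, dt_bias):
    """Hypothetical sketch of the decay formula.

    gate: [batch, seq_len, num_heads]; A_log, dt_bias: [num_heads].
    """
    A = -jnp.exp(A_log)                          # strictly negative per head
    return A * jax.nn.softplus(gate + dt_bias)   # softplus > 0, so decay < 0


decay = kda_decay_sketch(jnp.zeros((2, 10, 4)), jnp.zeros(4), jnp.zeros(4))
retention = jnp.exp(decay)  # per-token factor applied to the previous state
```

With all-zero inputs, A = -1 and softplus(0) = ln 2, so every decay value is -ln 2 (about -0.693) and the retention factor exp(decay) is 0.5.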

Parameters
  • gate – Gating signal from the input projection. Shape: [batch, seq_len, num_heads].

  • A_log – Learnable log-scale decay parameter (typically initialized near 0). Shape: [num_heads].

  • dt_bias – Learnable bias added to gate before the softplus. Shape: [num_heads].

Returns

Per-token decay values (always <= 0 for stable state decay)

Shape: [batch, seq_len, num_heads]

Example

>>> import jax.numpy as jnp
>>> from ejkernel.kernels._xla.kernel_delta_attention._interface import kda_decay
>>> gate = jnp.zeros((2, 10, 4))  # batch=2, seq_len=10, num_heads=4
>>> A_log = jnp.zeros((4,))
>>> dt_bias = jnp.zeros((4,))
>>> decay = kda_decay(gate, A_log, dt_bias)
>>> assert jnp.all(decay <= 0)  # Decay is always non-positive

ejkernel.kernels._xla.kernel_delta_attention._interface.kernel_delta_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], beta: Float[jaxlib._jax.Array, 'batch seq_len num_heads'], decay: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads'] | None = None, *, softmax_scale: float | None = None, chunk_size: int = 64, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim'] | None = None, use_qk_l2norm: bool = True, use_chunked: bool = True) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim']]

Same signature and documentation as kda above; see that entry for the recurrence, algorithm modes, parameters, return values, and examples.