ejkernel.modules.operations.kernel_delta_attention

Kernel Delta Attention (KDA) operation module with automatic optimization.

This module provides the KernelDeltaAttention operation, a linear attention variant that uses delta rule updates to maintain a key-value memory matrix. KDA runs in O(N) time for a sequence of length N while supporting efficient incremental inference.

Key characteristics of KDA:
  • Memory matrix: [num_heads, head_dim, value_dim] per batch element

  • Delta updates: Stores only the part of each value that the memory does not already return for its key

  • Decay mechanism: Controls memory retention over time

  • Beta parameter: Per-token learning rate for updates

The core recurrence:

h_t = exp(decay_t) * h_{t-1} + k_t ⊗ (beta_t * (v_t - h_{t-1} @ k_t))
o_t = h_t @ q_t
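
The recurrence can be written out as a sequential loop for a single head. The following is a minimal NumPy sketch under stated assumptions, not the library's XLA kernel: `kda_recurrent_reference` is a hypothetical name, and `h @ k` is read as a contraction over the key dimension (`h.T @ k` for a memory of shape [head_dim, value_dim]).

```python
import numpy as np

def kda_recurrent_reference(q, k, v, beta, decay, h0=None):
    """Sequential reference for one head (illustrative only).

    q, k: [seq_len, head_dim]; v: [seq_len, value_dim];
    beta, decay: [seq_len]. Returns (outputs, final_state).
    """
    seq_len, head_dim = k.shape
    value_dim = v.shape[1]
    h = np.zeros((head_dim, value_dim)) if h0 is None else h0.copy()
    outputs = np.empty((seq_len, value_dim))
    for t in range(seq_len):
        error = v[t] - h.T @ k[t]                    # v_t - h_{t-1} @ k_t
        h = np.exp(decay[t]) * h + np.outer(k[t], beta[t] * error)
        outputs[t] = h.T @ q[t]                      # o_t = h_t @ q_t
    return outputs, h

rng = np.random.default_rng(0)
seq_len, head_dim, value_dim = 6, 4, 3
q = rng.normal(size=(seq_len, head_dim))
k = rng.normal(size=(seq_len, head_dim))
v = rng.normal(size=(seq_len, value_dim))
beta = rng.uniform(size=seq_len)              # learning rates in [0, 1)
decay = -0.01 * rng.uniform(size=seq_len)     # non-positive decay

out, state = kda_recurrent_reference(q, k, v, beta, decay)
```

The loop does constant work per token on a fixed-size state, which is where the linear cost in sequence length comes from.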

Algorithm modes:
  • Chunked (default): Parallel within chunks for training efficiency

  • Recurrent: Pure sequential scan for memory efficiency

  • Single-step: Optimized path for autoregressive generation

Features:
  • Automatic platform selection (XLA primary)

  • Configuration caching for consistent performance

  • L2 normalization option for stability

  • State passthrough for incremental inference

Example

>>> from ejkernel.modules.operations import kernel_delta_attention
>>>
>>> # Training forward pass
>>> output = kernel_delta_attention(
...     query, key, value, beta, decay,
...     chunk_size=64, use_chunked=True,
... )
>>>
>>> # Inference with state
>>> output, state = kernel_delta_attention(
...     query, key, value, beta, decay,
...     return_state=True,
... )
>>> # Next token
>>> output_new, state_new = kernel_delta_attention(
...     q_new, k_new, v_new, beta_new, decay_new,
...     initial_state=state, return_state=True,
... )
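
The state passthrough shown above relies on the recurrence being resumable: running a prefix, then continuing from its final state, should match a single pass over the whole sequence. The sketch below checks this property on a plain NumPy version of the documented recurrence for one head. It is an illustration under assumptions, not a test of ejkernel itself; `scan` is a hypothetical helper.

```python
import numpy as np

def scan(q, k, v, beta, decay, h):
    """Apply the documented recurrence token by token, starting from state h."""
    outs = []
    for t in range(len(q)):
        h = np.exp(decay[t]) * h + np.outer(k[t], beta[t] * (v[t] - h.T @ k[t]))
        outs.append(h.T @ q[t])
    return np.stack(outs), h

rng = np.random.default_rng(1)
n, d_k, d_v = 5, 4, 3
q, k = rng.normal(size=(2, n, d_k))
v = rng.normal(size=(n, d_v))
beta = rng.uniform(size=n)
decay = -0.05 * rng.uniform(size=n)
h0 = np.zeros((d_k, d_v))

# One pass over the full sequence.
full_out, full_state = scan(q, k, v, beta, decay, h0)

# Prefill on the first n-1 tokens, then one incremental step on the last token.
prefix_out, prefix_state = scan(q[:-1], k[:-1], v[:-1], beta[:-1], decay[:-1], h0)
step_out, step_state = scan(q[-1:], k[-1:], v[-1:], beta[-1:], decay[-1:], prefix_state)
```

Here `step_out[0]` agrees with `full_out[-1]` and `step_state` agrees with `full_state`, so only the fixed-size state needs to be kept between generation steps.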

class ejkernel.modules.operations.kernel_delta_attention.KernelDeltaAttention

Bases: Kernel[KernelDeltaAttentionConfig, Array]

Kernel Delta Attention (KDA) operation.

A linear attention mechanism using delta rule updates to maintain an associative memory matrix. Supports both training (chunked) and inference (single-step) modes with O(N) complexity.

The operation maintains a [head_dim, value_dim] memory matrix per head that stores key-value associations. The delta update mechanism ensures only novel information is added, improving memory efficiency.
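
A single update step shows what "only novel information is added" means in practice. This is a minimal NumPy sketch of the recurrence documented on this page for one head, not the library's kernel. `delta_step` is a hypothetical helper, and reading the memory as `h.T @ k` is an assumed interpretation of `h @ k` for a [head_dim, value_dim] matrix.

```python
import numpy as np

def delta_step(h, k, v, beta, decay=0.0):
    """One delta-rule update of a [head_dim, value_dim] memory (illustrative)."""
    predicted = h.T @ k                  # what the memory already returns for key k
    novelty = beta * (v - predicted)     # only the part of v not yet stored
    return np.exp(decay) * h + np.outer(k, novelty)

head_dim, value_dim = 4, 3
h = np.zeros((head_dim, value_dim))
k = np.array([1.0, 0.0, 0.0, 0.0])       # unit-norm key
v = np.array([0.5, -1.0, 2.0])

h = delta_step(h, k, v, beta=1.0)        # first write stores v under k
h_again = delta_step(h, k, v, beta=1.0)  # writing the same pair again adds nothing
```

With a unit-norm key, `beta=1.0`, and no decay, reading the memory with `k` returns `v` exactly, and the second write leaves the memory unchanged.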

op_id

Operation identifier for registry lookup (“kernel_delta_attention”)

Type

str

candidate_cfgs(inv: Invocation[KernelDeltaAttentionConfig, Array])

Generate candidate configurations for autotuning.

Parameters

inv – Invocation object (unused)

Returns

Empty list - KDA uses XLA without tunable block sizes

Note

KDA currently has a single XLA implementation without tunable parameters, so autotuning is not applicable.

get_impl(cfg: KernelDeltaAttentionConfig)

Get kernel implementation from registry.

Parameters

cfg – Configuration specifying platform and backend

Returns

Callable kernel implementation for KDA

Raises

ValueError – If no matching implementation is found

heuristic_cfg(inv: Invocation[KernelDeltaAttentionConfig, Array]) -> KernelDeltaAttentionConfig

Provide default configuration based on heuristics.

Parameters

inv – Invocation object (unused, config is static)

Returns

Default configuration with auto platform selection

run(
    query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'],
    key: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'],
    value: Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'],
    beta: Float[jaxlib._jax.Array, 'batch seq_len num_heads'],
    decay: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads'] | None = None,
    initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim'] | None = None,
    /,
    *,
    softmax_scale: float | None = None,
    chunk_size: int = 64,
    use_qk_l2norm: bool = True,
    use_chunked: bool = True,
    return_state: bool = False,
    platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None,
    cfg: KernelDeltaAttentionConfig,
) -> jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim']]

Execute Kernel Delta Attention operation.

Parameters
  • query – Query tensor for attention. Shape: [batch, seq_len, num_heads, qk_head_dim]

  • key – Key tensor for memory addressing. Shape: [batch, seq_len, num_heads, qk_head_dim]

  • value – Value tensor to store in memory. Shape: [batch, seq_len, num_heads, v_head_dim]

  • beta – Per-token learning rate for delta updates. Shape: [batch, seq_len, num_heads]

  • decay – Per-token decay for memory retention (should be <= 0). Shape: [batch, seq_len, num_heads]

  • initial_state – Optional initial memory state. Shape: [batch, num_heads, qk_head_dim, v_head_dim]

  • softmax_scale – Scaling factor for queries (default: head_dim^-0.5)

  • chunk_size – Block size for chunked algorithm

  • use_qk_l2norm – Apply L2 normalization to queries and keys

  • use_chunked – Use chunked (True) or recurrent (False) algorithm

  • return_state – If True, return (output, state) tuple

  • platform – Override platform selection

  • cfg – Kernel configuration object
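
The constraints on beta and decay can be met by construction when they come from unconstrained projections. The sketch below is one possible parameterization in NumPy, offered as an assumption rather than something the library requires: a sigmoid keeps beta inside (0, 1), and a negated softplus keeps decay below zero. The names `beta_logits` and `decay_logits` are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softplus(x):
    return np.logaddexp(0.0, x)      # numerically stable log(1 + exp(x))

rng = np.random.default_rng(0)
batch, seq_len, num_heads = 2, 8, 4
beta_logits = rng.normal(size=(batch, seq_len, num_heads))
decay_logits = rng.normal(size=(batch, seq_len, num_heads))

beta = sigmoid(beta_logits)          # in (0, 1): per-token learning rate
decay = -softplus(decay_logits)      # negative: a valid log-space decay
retention = np.exp(decay)            # in (0, 1): fraction of memory kept per step
```

Because the recurrence multiplies the previous state by `exp(decay)`, a decay of zero keeps the full memory and more negative values forget faster.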

Returns

If return_state is False: output, the attention output of shape [batch, seq_len, num_heads, v_head_dim].

If return_state is True: a tuple (output, final_state), where final_state has shape [batch, num_heads, qk_head_dim, v_head_dim].

ejkernel.modules.operations.kernel_delta_attention.kda_attention(
    query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'],
    key: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'],
    value: Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'],
    beta: Float[jaxlib._jax.Array, 'batch seq_len num_heads'],
    decay: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads'] | None = None,
    initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim'] | None = None,
    /,
    *,
    softmax_scale: float | None = None,
    chunk_size: int = 64,
    use_qk_l2norm: bool = True,
    use_chunked: bool = True,
    return_state: bool = False,
    platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None,
    cfg: ejkernel.modules.operations.configs.KernelDeltaAttentionConfig | None = None,
) -> jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim']]

Execute Kernel Delta Attention (KDA) with automatic optimization.

KDA is a linear attention variant that uses delta rule updates to maintain an associative key-value memory matrix. It provides O(N) complexity while supporting efficient stateful incremental inference.

The core recurrence:

h_t = exp(decay_t) * h_{t-1} + k_t ⊗ (beta_t * (v_t - h_{t-1} @ k_t))
o_t = h_t @ q_t

Parameters
  • query – Query tensor for attention. Shape: [batch, seq_len, num_heads, qk_head_dim]

  • key – Key tensor for memory addressing. Shape: [batch, seq_len, num_heads, qk_head_dim]

  • value – Value tensor to store in memory. Shape: [batch, seq_len, num_heads, v_head_dim]

  • beta – Per-token learning rate for delta updates (typically in [0, 1]). Shape: [batch, seq_len, num_heads]

  • decay – Per-token decay for memory retention (should be <= 0). Shape: [batch, seq_len, num_heads]. If None, defaults to zeros (no decay, full retention)

  • initial_state – Optional initial memory state for incremental inference. Shape: [batch, num_heads, qk_head_dim, v_head_dim]

  • softmax_scale – Scaling factor for queries. If None, uses head_dim^-0.5

  • chunk_size – Block size for chunked algorithm (default: 64)

  • use_qk_l2norm – If True, apply L2 normalization to queries and keys before attention. Improves numerical stability (default: True)

  • use_chunked – If True, use chunked algorithm (faster for training); else use recurrent scan (more memory efficient)

  • return_state – If True, return (output, final_state) tuple

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Optional kernel configuration
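
The `use_qk_l2norm` and `softmax_scale` options can be pictured as a preprocessing step on queries and keys. The NumPy sketch below is illustrative only: the `eps` stabilizer, the helper name `l2_normalize`, and the order of operations (normalize first, then scale the queries) are assumptions, not confirmed details of the ejkernel implementation.

```python
import numpy as np

def l2_normalize(x, eps=1e-6):
    """Normalize the last axis to unit length (eps is an assumed stabilizer)."""
    return x / np.sqrt(np.sum(x * x, axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
batch, seq_len, num_heads, head_dim = 2, 8, 4, 16
q = rng.normal(size=(batch, seq_len, num_heads, head_dim))
k = rng.normal(size=(batch, seq_len, num_heads, head_dim))

q_n, k_n = l2_normalize(q), l2_normalize(k)
softmax_scale = head_dim ** -0.5     # documented default when softmax_scale is None
q_scaled = q_n * softmax_scale       # assumed order: normalize, then scale
```

Unit-norm keys keep the size of each delta update tied to beta rather than to the magnitude of the key, which is one way to see why the option helps stability.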

Returns

If return_state is False:
  • output: Attention output [batch, seq_len, num_heads, v_head_dim]

If return_state is True, a tuple of:
  • output: Attention output [batch, seq_len, num_heads, v_head_dim]

  • final_state: Final memory state for incremental inference [batch, num_heads, qk_head_dim, v_head_dim]

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from jax import random
>>> from ejkernel.modules.operations import kernel_delta_attention
>>>
>>> # Training forward pass
>>> batch, seq_len, num_heads, head_dim = 2, 64, 8, 32
>>> key = random.PRNGKey(0)
>>> q = random.normal(random.fold_in(key, 0), (batch, seq_len, num_heads, head_dim))
>>> k = random.normal(random.fold_in(key, 1), (batch, seq_len, num_heads, head_dim))
>>> v = random.normal(random.fold_in(key, 2), (batch, seq_len, num_heads, head_dim))
>>> beta = jax.nn.sigmoid(random.normal(random.fold_in(key, 3), (batch, seq_len, num_heads)))
>>> # decay should be <= 0, so negate the magnitude of the random draw
>>> decay = -0.01 * jnp.abs(random.normal(random.fold_in(key, 4), (batch, seq_len, num_heads)))
>>>
>>> output = kernel_delta_attention(q, k, v, beta, decay)
>>> output.shape
(2, 64, 8, 32)
>>>
>>> # Inference with state
>>> output, state = kernel_delta_attention(q, k, v, beta, decay, return_state=True)
>>>
>>> # Next token generation
>>> q_new = random.normal(random.fold_in(key, 5), (batch, 1, num_heads, head_dim))
>>> k_new = random.normal(random.fold_in(key, 6), (batch, 1, num_heads, head_dim))
>>> v_new = random.normal(random.fold_in(key, 7), (batch, 1, num_heads, head_dim))
>>> beta_new = jax.nn.sigmoid(random.normal(random.fold_in(key, 8), (batch, 1, num_heads)))
>>> decay_new = -0.01 * jnp.abs(random.normal(random.fold_in(key, 9), (batch, 1, num_heads)))
>>>
>>> output_new, state_new = kernel_delta_attention(
...     q_new, k_new, v_new, beta_new, decay_new,
...     initial_state=state, return_state=True,
... )

ejkernel.modules.operations.kernel_delta_attention.kernel_delta_attention(
    query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'],
    key: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'],
    value: Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'],
    beta: Float[jaxlib._jax.Array, 'batch seq_len num_heads'],
    decay: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads'] | None = None,
    initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim'] | None = None,
    /,
    *,
    softmax_scale: float | None = None,
    chunk_size: int = 64,
    use_qk_l2norm: bool = True,
    use_chunked: bool = True,
    return_state: bool = False,
    platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None,
    cfg: ejkernel.modules.operations.configs.KernelDeltaAttentionConfig | None = None,
) -> jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim']]

Execute Kernel Delta Attention (KDA) with automatic optimization.

KDA is a linear attention variant that uses delta rule updates to maintain an associative key-value memory matrix. It provides O(N) complexity while supporting efficient stateful incremental inference.

The core recurrence:

h_t = exp(decay_t) * h_{t-1} + k_t ⊗ (beta_t * (v_t - h_{t-1} @ k_t))
o_t = h_t @ q_t

Parameters
  • query – Query tensor for attention. Shape: [batch, seq_len, num_heads, qk_head_dim]

  • key – Key tensor for memory addressing. Shape: [batch, seq_len, num_heads, qk_head_dim]

  • value – Value tensor to store in memory. Shape: [batch, seq_len, num_heads, v_head_dim]

  • beta – Per-token learning rate for delta updates (typically in [0, 1]). Shape: [batch, seq_len, num_heads]

  • decay – Per-token decay for memory retention (should be <= 0). Shape: [batch, seq_len, num_heads]. If None, defaults to zeros (no decay, full retention)

  • initial_state – Optional initial memory state for incremental inference. Shape: [batch, num_heads, qk_head_dim, v_head_dim]

  • softmax_scale – Scaling factor for queries. If None, uses head_dim^-0.5

  • chunk_size – Block size for chunked algorithm (default: 64)

  • use_qk_l2norm – If True, apply L2 normalization to queries and keys before attention. Improves numerical stability (default: True)

  • use_chunked – If True, use chunked algorithm (faster for training); else use recurrent scan (more memory efficient)

  • return_state – If True, return (output, final_state) tuple

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Optional kernel configuration
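
The shape contract in the parameter list can be checked before calling the kernel. The helper below is a hypothetical pure-Python sketch, not part of ejkernel: it takes shape tuples, validates them against the documented layouts, and returns the expected output and state shapes.

```python
def kda_expected_shapes(query, key, value, beta, decay=None, initial_state=None):
    """Validate shape tuples against the documented layout.

    Returns (output_shape, state_shape). Hypothetical helper, not part of ejkernel.
    """
    batch, seq_len, num_heads, qk_head_dim = query
    if tuple(key) != tuple(query):
        raise ValueError(f"key {key} must match query {query}")
    if len(value) != 4 or tuple(value[:3]) != (batch, seq_len, num_heads):
        raise ValueError(f"value {value} must be [batch, seq_len, num_heads, v_head_dim]")
    v_head_dim = value[3]
    if tuple(beta) != (batch, seq_len, num_heads):
        raise ValueError(f"beta {beta} must be [batch, seq_len, num_heads]")
    if decay is not None and tuple(decay) != (batch, seq_len, num_heads):
        raise ValueError(f"decay {decay} must be [batch, seq_len, num_heads]")
    state_shape = (batch, num_heads, qk_head_dim, v_head_dim)
    if initial_state is not None and tuple(initial_state) != state_shape:
        raise ValueError(f"initial_state {initial_state} must be {state_shape}")
    return (batch, seq_len, num_heads, v_head_dim), state_shape


# Prefill over 64 tokens, then a single-token step that reuses the state shape.
out_shape, state_shape = kda_expected_shapes(
    (2, 64, 8, 32), (2, 64, 8, 32), (2, 64, 8, 48), (2, 64, 8)
)
step_out_shape, _ = kda_expected_shapes(
    (2, 1, 8, 32), (2, 1, 8, 32), (2, 1, 8, 48), (2, 1, 8),
    initial_state=state_shape,
)
```

The state shape does not depend on seq_len, so the same state can be carried from a long prefill into single-token decoding steps.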

Returns

If return_state is False:
  • output: Attention output [batch, seq_len, num_heads, v_head_dim]

If return_state is True, a tuple of:
  • output: Attention output [batch, seq_len, num_heads, v_head_dim]

  • final_state: Final memory state for incremental inference [batch, num_heads, qk_head_dim, v_head_dim]

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from jax import random
>>> from ejkernel.modules.operations import kernel_delta_attention
>>>
>>> # Training forward pass
>>> batch, seq_len, num_heads, head_dim = 2, 64, 8, 32
>>> key = random.PRNGKey(0)
>>> q = random.normal(random.fold_in(key, 0), (batch, seq_len, num_heads, head_dim))
>>> k = random.normal(random.fold_in(key, 1), (batch, seq_len, num_heads, head_dim))
>>> v = random.normal(random.fold_in(key, 2), (batch, seq_len, num_heads, head_dim))
>>> beta = jax.nn.sigmoid(random.normal(random.fold_in(key, 3), (batch, seq_len, num_heads)))
>>> # decay should be <= 0, so negate the magnitude of the random draw
>>> decay = -0.01 * jnp.abs(random.normal(random.fold_in(key, 4), (batch, seq_len, num_heads)))
>>>
>>> output = kernel_delta_attention(q, k, v, beta, decay)
>>> output.shape
(2, 64, 8, 32)
>>>
>>> # Inference with state
>>> output, state = kernel_delta_attention(q, k, v, beta, decay, return_state=True)
>>>
>>> # Next token generation
>>> q_new = random.normal(random.fold_in(key, 5), (batch, 1, num_heads, head_dim))
>>> k_new = random.normal(random.fold_in(key, 6), (batch, 1, num_heads, head_dim))
>>> v_new = random.normal(random.fold_in(key, 7), (batch, 1, num_heads, head_dim))
>>> beta_new = jax.nn.sigmoid(random.normal(random.fold_in(key, 8), (batch, 1, num_heads)))
>>> decay_new = -0.01 * jnp.abs(random.normal(random.fold_in(key, 9), (batch, 1, num_heads)))
>>>
>>> output_new, state_new = kernel_delta_attention(
...     q_new, k_new, v_new, beta_new, decay_new,
...     initial_state=state, return_state=True,
... )