ejkernel.modules.operations.kernel_delta_attention
Kernel Delta Attention (KDA) operation module with automatic optimization.
This module provides the KernelDeltaAttention operation, a linear attention variant that uses delta rule updates to maintain a key-value memory matrix. KDA runs in O(N) time in the sequence length and supports efficient incremental inference.
- Key characteristics of KDA:
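Memory matrix: [num_heads, qk_head_dim, v_head_dim] per batch element
Delta updates: Each step writes only the difference between v_t and the value the memory currently returns for k_t
Decay mechanism: A per-token factor exp(decay_t) controls how much of the previous memory is retained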
Beta parameter: Per-token learning rate for updates
- The core recurrence:
h_t = exp(decay_t) * h_{t-1} + k_t ⊗ (beta_t * (v_t - k_t @ h_{t-1}))
o_t = q_t @ h_t
- Algorithm modes:
Chunked (default): Parallel within chunks for training efficiency
Recurrent: Pure sequential scan for memory efficiency
Single-step: Optimized path for autoregressive generation
- Features:
Automatic platform selection (XLA primary)
Configuration caching for consistent performance
L2 normalization option for stability
State passthrough for incremental inference
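The core recurrence and the state passthrough feature listed above can be checked with a small NumPy reference for a single head. This is a sketch, not ejkernel code: kda_step and kda_scan are hypothetical helper names, the retrieval is computed as k_t @ h_{t-1} so the shapes match a [qk_head_dim, v_head_dim] memory, decay is applied to h_{t-1} but not to the retrieved prediction (exactly as the recurrence is written), and softmax_scale and L2 normalization are omitted. The last lines show that scanning a sequence in one pass matches scanning two segments where the second starts from the first segment's final state.

```python
import numpy as np

# Hypothetical reference helpers (not ejkernel APIs): one head, float64,
# no softmax_scale and no L2 normalization.

def kda_step(h, q, k, v, beta, decay):
    """Apply one step of the recurrence.

    h: [qk_head_dim, v_head_dim]; q, k: [qk_head_dim]; v: [v_head_dim];
    beta and decay: scalars.
    """
    retrieved = k @ h                        # what the memory currently returns for k
    delta = beta * (v - retrieved)           # only the residual is written
    h_new = np.exp(decay) * h + np.outer(k, delta)
    o = q @ h_new                            # read out with the updated memory
    return o, h_new


def kda_scan(q, k, v, beta, decay, h0=None):
    """Recurrent-mode scan over one head: q, k [T, Dk]; v [T, Dv]; beta, decay [T]."""
    h = np.zeros((k.shape[1], v.shape[1])) if h0 is None else h0
    outputs = []
    for t in range(q.shape[0]):
        o, h = kda_step(h, q[t], k[t], v[t], beta[t], decay[t])
        outputs.append(o)
    return np.stack(outputs), h


rng = np.random.default_rng(0)
T, Dk, Dv = 8, 4, 3
q = rng.normal(size=(T, Dk))
k = rng.normal(size=(T, Dk))
k /= np.linalg.norm(k, axis=-1, keepdims=True)
v = rng.normal(size=(T, Dv))
beta = rng.uniform(size=T)                   # learning rate in [0, 1)
decay = -0.1 * rng.uniform(size=T)           # <= 0, so exp(decay) is in (0, 1]

# One pass over the whole sequence ...
out_full, h_full = kda_scan(q, k, v, beta, decay)
# ... matches two passes where the second starts from the first pass's state.
out_a, h_a = kda_scan(q[:5], k[:5], v[:5], beta[:5], decay[:5])
out_b, h_b = kda_scan(q[5:], k[5:], v[5:], beta[5:], decay[5:], h0=h_a)
```

Because the memory has a fixed [qk_head_dim, v_head_dim] size, carrying h between calls is all an incremental decoder needs.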
Example
>>> from ejkernel.modules.operations import kernel_delta_attention
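>>>
>>> # query, key: [batch, seq_len, num_heads, qk_head_dim]; value: [batch, seq_len, num_heads, v_head_dim]; beta, decay: [batch, seq_len, num_heads]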
>>> # Training forward pass
>>> output = kernel_delta_attention(
... query, key, value, beta, decay,
... chunk_size=64, use_chunked=True,
... )
>>>
>>> # Inference with state
>>> output, state = kernel_delta_attention(
... query, key, value, beta, decay,
... return_state=True,
... )
>>> # Next token
>>> output_new, state_new = kernel_delta_attention(
... q_new, k_new, v_new, beta_new, decay_new,
... initial_state=state, return_state=True,
... )
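The chunked mode (use_chunked=True) can be illustrated with a NumPy sketch that parallelizes the recurrence within each chunk. Each delta depends linearly on the earlier deltas in the same chunk, so all of a chunk's deltas come from one unit lower-triangular solve, after which the chunk's outputs and final state are plain matrix products. This is one valid derivation for the recurrence as written in this module, not necessarily the formulation ejkernel's XLA kernel uses; recurrent, chunk_forward, and chunked are hypothetical names, and softmax_scale and L2 normalization are omitted. The sketch checks the chunked result against a token-by-token scan.

```python
import numpy as np

# Hypothetical reference helpers (not ejkernel APIs): one head, float64.

def recurrent(q, k, v, beta, decay, h0):
    """Token-by-token scan of the recurrence."""
    h, outs = h0, []
    for t in range(len(q)):
        delta = beta[t] * (v[t] - k[t] @ h)
        h = np.exp(decay[t]) * h + np.outer(k[t], delta)
        outs.append(q[t] @ h)
    return np.stack(outs), h


def chunk_forward(q, k, v, beta, decay, h0):
    """Process one chunk of C tokens at once, starting from state h0."""
    C = len(q)
    c = np.cumsum(decay)                     # c[i]: log retention from chunk start through i
    c_prev = c - decay                       # c[i - 1], and 0 for the first token
    i, j = np.arange(C)[:, None], np.arange(C)[None, :]
    # Decay weights, masked to zero outside the needed triangle (exp(-inf) == 0).
    w_prev = np.exp(np.where(j < i, c_prev[:, None] - c[None, :], -np.inf))
    w_cur = np.exp(np.where(j <= i, c[:, None] - c[None, :], -np.inf))
    # The deltas satisfy a unit lower-triangular system:
    #   delta_i + beta_i * sum_{j<i} w_prev[i, j] (k_i . k_j) delta_j
    #       = beta_i * (v_i - exp(c_prev[i]) * k_i @ h0)
    A = np.eye(C) + beta[:, None] * w_prev * (k @ k.T)
    rhs = beta[:, None] * (v - np.exp(c_prev)[:, None] * (k @ h0))
    deltas = np.linalg.solve(A, rhs)
    # o_i = q_i @ h_i, with h_i expanded in terms of h0 and the deltas.
    out = np.exp(c)[:, None] * (q @ h0) + (w_cur * (q @ k.T)) @ deltas
    h_end = np.exp(c[-1]) * h0 + (np.exp(c[-1] - c)[:, None] * k).T @ deltas
    return out, h_end


def chunked(q, k, v, beta, decay, h0, chunk_size):
    """Sequential over chunks, parallel within each chunk."""
    outs, h = [], h0
    for s in range(0, len(q), chunk_size):
        part = slice(s, s + chunk_size)
        o, h = chunk_forward(q[part], k[part], v[part], beta[part], decay[part], h)
        outs.append(o)
    return np.concatenate(outs), h


rng = np.random.default_rng(0)
T, Dk, Dv = 10, 4, 3
q = rng.normal(size=(T, Dk))
k = rng.normal(size=(T, Dk))
k /= np.linalg.norm(k, axis=-1, keepdims=True)
v = rng.normal(size=(T, Dv))
beta = rng.uniform(size=T)
decay = -0.1 * rng.uniform(size=T)
h0 = rng.normal(size=(Dk, Dv))

o_rec, h_rec = recurrent(q, k, v, beta, decay, h0)
o_chk, h_chk = chunked(q, k, v, beta, decay, h0, chunk_size=4)  # chunks of 4, 4, 2
```

The trade-off mirrors the modes listed above: the chunked form replaces most of the sequential loop with dense matrix products, while the recurrent form only ever holds one [Dk, Dv] state.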
References
Delta Networks: https://arxiv.org/abs/1612.04859
Linear Transformers: https://arxiv.org/abs/2006.16236
- class ejkernel.modules.operations.kernel_delta_attention.KernelDeltaAttention
Bases: Kernel[KernelDeltaAttentionConfig, Array]
Kernel Delta Attention (KDA) operation.
A linear attention mechanism using delta rule updates to maintain an associative memory matrix. Supports both training (chunked) and inference (single-step) modes in O(N) time in the sequence length.
The operation maintains a [qk_head_dim, v_head_dim] memory matrix per head that stores key-value associations. The delta update writes only the residual between each new value and what the memory already returns for its key, so an association that is already stored is not added to the fixed-size state again.
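The difference between a delta-rule write and a plain linear-attention write shows up when the same key-value pair is written twice. The NumPy sketch below uses decay_t = 0 and beta_t = 1; write_delta and write_additive are hypothetical helper names for illustration, not ejkernel functions.

```python
import numpy as np


def write_delta(h, k, v, beta=1.0):
    """Delta-rule write (decay_t = 0): add only the residual between v and k @ h."""
    return h + np.outer(k, beta * (v - k @ h))


def write_additive(h, k, v):
    """Plain linear-attention write: always add the outer product k v^T."""
    return h + np.outer(k, v)


k = np.array([0.6, 0.8, 0.0])    # unit-norm key, qk_head_dim = 3
v = np.array([1.0, -2.0])        # value, v_head_dim = 2
h0 = np.zeros((3, 2))            # [qk_head_dim, v_head_dim] memory

# Write the same association twice with each rule.
h_delta = write_delta(write_delta(h0, k, v), k, v)
h_add = write_additive(write_additive(h0, k, v), k, v)

# Delta rule: the second write has (near) zero residual, so k @ h_delta stays close to v.
# Additive rule: the association is stored twice, so k @ h_add is close to 2 * v.
retrieved_delta, retrieved_add = k @ h_delta, k @ h_add
```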
- op_id
Operation identifier for registry lookup (“kernel_delta_attention”)
- Type
str
- candidate_cfgs(inv: Invocation[KernelDeltaAttentionConfig, Array])
Generate candidate configurations for autotuning.
- Parameters
inv – Invocation object (unused)
- Returns
Empty list; KDA uses a single XLA implementation without tunable block sizes
Note
KDA currently has a single XLA implementation without tunable parameters, so autotuning is not applicable.
- get_impl(cfg: KernelDeltaAttentionConfig)
Get kernel implementation from registry.
- Parameters
cfg – Configuration specifying platform and backend
- Returns
Callable kernel implementation for KDA
- Raises
ValueError – If no matching implementation is found
- heuristic_cfg(inv: Invocation[KernelDeltaAttentionConfig, Array]) → KernelDeltaAttentionConfig
Provide default configuration based on heuristics.
- Parameters
inv – Invocation object (unused, config is static)
- Returns
Default configuration with auto platform selection
- run(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], beta: Float[jaxlib._jax.Array, 'batch seq_len num_heads'], decay: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads'] | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim'] | None = None, /, *, softmax_scale: float | None = None, chunk_size: int = 64, use_qk_l2norm: bool = True, use_chunked: bool = True, return_state: bool = False, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: KernelDeltaAttentionConfig) → jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim']]
Execute Kernel Delta Attention operation.
- Parameters
query – Query tensor for attention. Shape: [batch, seq_len, num_heads, qk_head_dim]
key – Key tensor for memory addressing. Shape: [batch, seq_len, num_heads, qk_head_dim]
value – Value tensor to store in memory. Shape: [batch, seq_len, num_heads, v_head_dim]
beta – Per-token learning rate for delta updates. Shape: [batch, seq_len, num_heads]
decay – Per-token decay (in log space) for memory retention. The previous state is scaled by exp(decay), so values should be <= 0, giving a retention factor in (0, 1]. Shape: [batch, seq_len, num_heads]
initial_state – Optional initial memory state. Shape: [batch, num_heads, qk_head_dim, v_head_dim]
softmax_scale – Scaling factor for queries (default: qk_head_dim^-0.5)
chunk_size – Block size for the chunked algorithm (default: 64)
use_qk_l2norm – Apply L2 normalization to queries and keys
use_chunked – Use the chunked (True) or recurrent (False) algorithm
return_state – If True, return an (output, state) tuple
platform – Override platform selection
cfg – Kernel configuration object
- Returns
output: Attention output [batch, seq_len, num_heads, v_head_dim] If return_state is True:
Tuple of (output, final_state) where final_state has shape [batch, num_heads, qk_head_dim, v_head_dim]
- Return type
If return_state is False
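The use_qk_l2norm, softmax_scale, and decay parameters of run() can be made concrete with a NumPy sketch of the preprocessing they describe: queries and keys are L2-normalized along the feature axis, the query is scaled by qk_head_dim^-0.5 when softmax_scale is None, and exp(decay) is the per-token retention factor. The order (normalize, then scale the query), the eps value, and the helper names l2norm, prepare_qk, and decay_for_half_life are assumptions for illustration, not ejkernel's implementation.

```python
import numpy as np


def l2norm(x, eps=1e-6):
    """Normalize along the last (feature) axis; eps is an assumed guard against zero vectors."""
    return x / np.sqrt(np.sum(x * x, axis=-1, keepdims=True) + eps)


def prepare_qk(query, key, softmax_scale=None, use_qk_l2norm=True):
    """Hypothetical preprocessing: optional L2 norm, then scale the query."""
    qk_head_dim = query.shape[-1]
    scale = qk_head_dim ** -0.5 if softmax_scale is None else softmax_scale
    if use_qk_l2norm:
        query, key = l2norm(query), l2norm(key)
    return query * scale, key


def decay_for_half_life(half_life_tokens):
    """Pick decay <= 0 so that exp(decay) ** half_life_tokens == 0.5."""
    return np.log(0.5) / half_life_tokens


rng = np.random.default_rng(0)
q = rng.normal(size=(2, 5, 3, 16))   # [batch, seq_len, num_heads, qk_head_dim]
k = rng.normal(size=(2, 5, 3, 16))
q_scaled, k_unit = prepare_qk(q, k)
decay = np.full((2, 5, 3), decay_for_half_life(10.0))  # [batch, seq_len, num_heads]
```

Parameterizing decay by a half-life is one convenient way to keep it non-positive; any values <= 0 satisfy the documented constraint.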
- ejkernel.modules.operations.kernel_delta_attention.kda_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], beta: Float[jaxlib._jax.Array, 'batch seq_len num_heads'], decay: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads'] | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim'] | None = None, /, *, softmax_scale: float | None = None, chunk_size: int = 64, use_qk_l2norm: bool = True, use_chunked: bool = True, return_state: bool = False, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.KernelDeltaAttentionConfig | None = None) → jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim']]
Execute Kernel Delta Attention (KDA) with automatic optimization.
Same signature, parameters, and return values as kernel_delta_attention, documented below.
- ejkernel.modules.operations.kernel_delta_attention.kernel_delta_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], beta: Float[jaxlib._jax.Array, 'batch seq_len num_heads'], decay: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads'] | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim'] | None = None, /, *, softmax_scale: float | None = None, chunk_size: int = 64, use_qk_l2norm: bool = True, use_chunked: bool = True, return_state: bool = False, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.KernelDeltaAttentionConfig | None = None) → jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads qk_head_dim v_head_dim']]
Execute Kernel Delta Attention (KDA) with automatic optimization.
KDA is a linear attention variant that uses delta rule updates to maintain an associative key-value memory matrix. It provides O(N) complexity while supporting efficient stateful incremental inference.
- The core recurrence:
h_t = exp(decay_t) * h_{t-1} + k_t ⊗ (beta_t * (v_t - k_t @ h_{t-1}))
o_t = q_t @ h_t
- Parameters
query – Query tensor for attention. Shape: [batch, seq_len, num_heads, qk_head_dim]
key – Key tensor for memory addressing. Shape: [batch, seq_len, num_heads, qk_head_dim]
value – Value tensor to store in memory. Shape: [batch, seq_len, num_heads, v_head_dim]
beta – Per-token learning rate for delta updates (typically in [0, 1]). Shape: [batch, seq_len, num_heads]
decay – Per-token decay (in log space) for memory retention. The previous state is scaled by exp(decay), so values should be <= 0. Shape: [batch, seq_len, num_heads]. If None, defaults to zeros (no decay, full retention).
initial_state – Optional initial memory state for incremental inference. Shape: [batch, num_heads, qk_head_dim, v_head_dim]
softmax_scale – Scaling factor for queries. If None, uses qk_head_dim^-0.5.
chunk_size – Block size for the chunked algorithm (default: 64)
use_qk_l2norm – If True, apply L2 normalization to queries and keys before attention, which improves numerical stability (default: True)
use_chunked – If True, use the chunked algorithm (faster for training); otherwise use the recurrent scan (more memory efficient)
return_state – If True, return an (output, final_state) tuple
platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”). If None, the platform is selected automatically.
cfg – Optional kernel configuration
- Returns
output: Attention output [batch, seq_len, num_heads, v_head_dim] If return_state is True:
- Tuple of:
output: Attention output [batch, seq_len, num_heads, v_head_dim]
- final_state: Final memory state for incremental inference
[batch, num_heads, qk_head_dim, v_head_dim]
- Return type
If return_state is False
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import random
>>> from ejkernel.modules.operations import kernel_delta_attention
>>>
>>> # Training forward pass
>>> batch, seq_len, num_heads, head_dim = 2, 64, 8, 32
>>> rng = random.PRNGKey(0)
>>> q = random.normal(random.fold_in(rng, 0), (batch, seq_len, num_heads, head_dim))
>>> k = random.normal(random.fold_in(rng, 1), (batch, seq_len, num_heads, head_dim))
>>> v = random.normal(random.fold_in(rng, 2), (batch, seq_len, num_heads, head_dim))
>>> beta = jax.nn.sigmoid(random.normal(random.fold_in(rng, 3), (batch, seq_len, num_heads)))
>>> # decay must be <= 0, so use -0.01 * |noise|
>>> decay = -0.01 * jnp.abs(random.normal(random.fold_in(rng, 4), (batch, seq_len, num_heads)))
>>>
>>> output = kernel_delta_attention(q, k, v, beta, decay)
>>> output.shape
(2, 64, 8, 32)
>>>
>>> # Inference with state
>>> output, state = kernel_delta_attention(q, k, v, beta, decay, return_state=True)
>>>
>>> # Next token generation
>>> q_new = random.normal(random.fold_in(rng, 5), (batch, 1, num_heads, head_dim))
>>> k_new = random.normal(random.fold_in(rng, 6), (batch, 1, num_heads, head_dim))
>>> v_new = random.normal(random.fold_in(rng, 7), (batch, 1, num_heads, head_dim))
>>> beta_new = jax.nn.sigmoid(random.normal(random.fold_in(rng, 8), (batch, 1, num_heads)))
>>> decay_new = -0.01 * jnp.abs(random.normal(random.fold_in(rng, 9), (batch, 1, num_heads)))
>>>
>>> output_new, state_new = kernel_delta_attention(
... q_new, k_new, v_new, beta_new, decay_new,
... initial_state=state, return_state=True,
... )