ejkernel.kernels._xla.kernel_delta_attention._xla_impl_fwd
Forward pass implementations for Kernel Delta Attention (KDA).
This module provides three forward pass variants for KDA:
- Recurrent (_recurrent_kda_fwd): a purely sequential scan over the sequence, taking O(L) time in the sequence length L. Best suited to very long sequences or memory-constrained inference.
- Chunked (_chunk_kda_fwd): a hybrid approach that computes each chunk in parallel and propagates the state sequentially from one chunk to the next. Best suited to training on moderate-length sequences.
- Single-step (_single_step_kda_fwd): an optimized path for seq_len=1 during autoregressive inference.
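The three variants differ in how they schedule the work, not in what they compute: up to floating-point error, each should produce the same outputs and final state as applying the update rule stated later in this docstring one token at a time. The minimal sketch below illustrates that contract. All names (`kda_step`, `scan_kda`, `chunked_kda`, `single_step_decode`) are hypothetical and are not ejkernel's API. It assumes `decay` is a log-space vector of shape [head_dim] broadcast across the rows of the [head_dim, value_dim] state and that `beta` is one scalar per token. Important caveat: `chunked_kda` only reproduces the sequential inter-chunk state hand-off; inside each chunk it still runs a sequential scan, not the parallel intra-chunk computation the real chunked variant uses.

```python
import jax
import jax.numpy as jnp


def kda_step(h, q, k, v, beta, decay):
    # Hypothetical helper: one token of the update rule. h: [K, V], q/k/decay: [K], v: [V].
    delta = beta * (v - h.T @ k)                     # prediction error read from h_{t-1}
    h = jnp.exp(decay)[:, None] * h + jnp.outer(k, delta)
    return h, h.T @ q                                # (h_t, o_t)


def scan_kda(h, q, k, v, beta, decay):
    # Recurrent-style path: one sequential lax.scan over the whole sequence.
    def body(h, xs):
        return kda_step(h, *xs)
    return jax.lax.scan(body, h, (q, k, v, beta, decay))  # -> (h_final, o [L, V])


def chunked_kda(h, q, k, v, beta, decay, chunk=4):
    # Illustrates ONLY inter-chunk state propagation: each chunk is scanned
    # sequentially here, and its final state seeds the next chunk.
    outs = []
    for s in range(0, q.shape[0], chunk):
        sl = slice(s, s + chunk)
        h, o = scan_kda(h, q[sl], k[sl], v[sl], beta[sl], decay[sl])
        outs.append(o)
    return h, jnp.concatenate(outs, axis=0)


def single_step_decode(h, q, k, v, beta, decay):
    # Single-step path: feed one token (seq_len=1) per call, carrying h between calls.
    outs = []
    for t in range(q.shape[0]):
        h, o = kda_step(h, q[t], k[t], v[t], beta[t], decay[t])
        outs.append(o)
    return h, jnp.stack(outs)


# Random inputs: L=10 tokens, K=8 key dims, V=4 value dims.
L, K, V = 10, 8, 4
kq, kk, kv, kb, kg = jax.random.split(jax.random.PRNGKey(0), 5)
q = jax.random.normal(kq, (L, K))
k = jax.random.normal(kk, (L, K))
k = k / jnp.linalg.norm(k, axis=-1, keepdims=True)       # unit-norm keys (an assumption of this sketch)
v = jax.random.normal(kv, (L, V))
beta = jax.nn.sigmoid(jax.random.normal(kb, (L,)))       # per-token write strength in (0, 1)
decay = -jax.nn.softplus(jax.random.normal(kg, (L, K)))  # log-space decay <= 0
h0 = jnp.zeros((K, V))

h_rec, o_rec = scan_kda(h0, q, k, v, beta, decay)
h_chk, o_chk = chunked_kda(h0, q, k, v, beta, decay, chunk=4)
h_one, o_one = single_step_decode(h0, q, k, v, beta, decay)
print(jnp.allclose(o_rec, o_chk, atol=1e-4), jnp.allclose(o_rec, o_one, atol=1e-4))
```

Because all three paths end in the same state under these assumptions, the state produced by a full-sequence pass can in principle seed the single-step path for subsequent decoding.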
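The KDA update rule is:

$$h_t = \exp(\mathrm{decay}_t)\, h_{t-1} + k_t \otimes \big(\beta_t\,(v_t - h_{t-1}^\top k_t)\big)$$

$$o_t = h_t^\top q_t$$

where h_t is the [head_dim, value_dim] memory matrix that stores key-value associations, q_t and k_t are the query and key vectors (length head_dim), v_t is the value vector (length value_dim), beta_t scales how strongly the correction is written, decay_t is a log-space decay applied through exp(decay_t) (values of decay_t ≤ 0 shrink the previous state), and ⊗ is the outer product. The term h_{t-1}^T k_t is the value the memory currently returns for key k_t, so each step writes only the error v_t - h_{t-1}^T k_t (the delta rule). The output o_t is retrieved by projecting the updated memory onto the query q_t.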
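The update rule can be written almost line for line in JAX. This is a minimal sketch, not ejkernel's implementation: `kda_step` and `recurrent_kda` are hypothetical names, and the shapes are assumptions (q_t, k_t and decay_t of length head_dim, with decay_t broadcast across the rows of h; v_t of length value_dim; beta_t one scalar per token). Because h is stored as [head_dim, value_dim], the read-out `h @ k` in the rule is computed as `h.T @ k`. The sketch follows the formula as written, so the prediction error uses the undecayed h_{t-1}.

```python
import jax
import jax.numpy as jnp


def kda_step(h, q_t, k_t, v_t, beta_t, decay_t):
    """One token of the KDA update rule (hypothetical helper).

    Assumed shapes: h [K, V], q_t / k_t / decay_t [K], v_t [V], beta_t scalar.
    """
    read = h.T @ k_t                      # h_{t-1}^T k_t: current memory response to k_t
    delta = beta_t * (v_t - read)         # scaled prediction error (delta rule)
    h = jnp.exp(decay_t)[:, None] * h     # decay the old state, one factor per key channel
    h = h + jnp.outer(k_t, delta)         # write k_t ⊗ delta
    return h, h.T @ q_t                   # (h_t, o_t = h_t^T q_t)


def recurrent_kda(q, k, v, beta, decay, h0=None):
    """Sequential O(L) scan (hypothetical). q, k, decay: [L, K]; v: [L, V]; beta: [L]."""
    if h0 is None:
        h0 = jnp.zeros((k.shape[-1], v.shape[-1]), dtype=v.dtype)

    def body(h, xs):
        return kda_step(h, *xs)

    h_final, o = jax.lax.scan(body, h0, (q, k, v, beta, decay))
    return o, h_final


# Usage: store two associations under orthogonal unit-norm keys, then read them back.
keys = jnp.eye(3)[:2]                                  # e0 and e1, used as keys and queries
vals = jnp.array([[2.0, -1.0], [0.5, 3.0]])
out, h = recurrent_kda(keys, keys, vals, beta=jnp.ones(2), decay=jnp.zeros((2, 3)))
print(h.T @ keys[0], h.T @ keys[1])                    # recovers vals[0] and vals[1]
```

With beta_t = 1 and decay_t = 0, each write stores exactly the missing part of v_t, so querying with a stored unit-norm key returns its value, and re-writing an association that is already stored leaves h unchanged. `jax.lax.scan` traces the step once and runs it as a compiled loop, which mirrors the sequential O(L) character of the recurrent variant.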