ejkernel.kernels._xla.kernel_delta_attention._xla_impl_fwd


Forward pass implementations for Kernel Delta Attention (KDA).

This module provides three forward pass variants for KDA:

  1. Recurrent (_recurrent_kda_fwd): Pure sequential scan that takes O(L) sequential steps for a sequence of length L. Best for very long sequences or memory-constrained inference.

  2. Chunked (_chunk_kda_fwd): Hybrid approach with parallel intra-chunk computation and sequential inter-chunk state propagation. Best for training on moderate-length sequences.

  3. Single-step (_single_step_kda_fwd): Optimized path for seq_len=1 during autoregressive inference.

The KDA update rule:

    h_t = exp(decay_t) * h_{t-1} + k_t ⊗ (beta_t * (v_t - h_{t-1} @ k_t))
    o_t = h_t @ q_t

Where h_t is the [head_dim, value_dim] memory matrix that stores key-value associations and supports efficient retrieval via query projection.
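
The update rule above can be sketched as a plain sequential scan in JAX. This is a minimal reference under stated assumptions, not the module's actual _recurrent_kda_fwd: the names kda_step and recurrent_kda_reference are hypothetical, decay_t and beta_t are assumed to be one scalar per time step (with decay_t in log space), batch and head dimensions are omitted, and "h @ k" is read as a contraction over head_dim so that the shapes agree with a [head_dim, value_dim] memory matrix.

```python
import jax
import jax.numpy as jnp


def kda_step(h_prev, q_t, k_t, v_t, decay_t, beta_t):
    """One KDA update, following the rule above term by term (hypothetical helper).

    h_prev: [head_dim, value_dim]; q_t, k_t: [head_dim]; v_t: [value_dim].
    decay_t and beta_t are assumed to be scalars for this sketch.
    """
    # Value currently stored under key k_t (contract over head_dim).
    v_old = k_t @ h_prev                      # [value_dim]
    # Delta-rule correction, scaled by the write strength beta_t.
    delta = beta_t * (v_t - v_old)            # [value_dim]
    # Decay the old memory, then write the correction as an outer product.
    h_t = jnp.exp(decay_t) * h_prev + jnp.outer(k_t, delta)
    # Read out with the query.
    o_t = q_t @ h_t                           # [value_dim]
    return h_t, o_t


def recurrent_kda_reference(q, k, v, decay, beta, h0=None):
    """Sequential scan over time (hypothetical reference, no batch/head dims).

    q, k: [L, head_dim]; v: [L, value_dim]; decay, beta: [L].
    Returns outputs [L, value_dim] and the final state [head_dim, value_dim].
    """
    if h0 is None:
        h0 = jnp.zeros((k.shape[-1], v.shape[-1]), dtype=v.dtype)

    def body(h, xs):
        # xs is the time slice (q_t, k_t, v_t, decay_t, beta_t).
        return kda_step(h, *xs)

    h_final, o = jax.lax.scan(body, h0, (q, k, v, decay, beta))
    return o, h_final


# Small usage example with random inputs.
key = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(key, 3)
L, head_dim, value_dim = 6, 4, 3
q = jax.random.normal(kq, (L, head_dim))
k = jax.random.normal(kk, (L, head_dim))
v = jax.random.normal(kv, (L, value_dim))
decay = jnp.full((L,), -0.1)   # log-space decay, so exp(decay) < 1
beta = jnp.full((L,), 0.5)

o, h_final = recurrent_kda_reference(q, k, v, decay, beta)
```

In this sketch the single-step case is just one call to kda_step with the cached state, which is why an autoregressive decoder only needs to keep the [head_dim, value_dim] matrix between tokens rather than the full history.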