ejkernel.kernels._xla.recurrent._interface
Recurrent attention interface for linear-time sequence processing.
This module provides the public API for recurrent linear attention with O(N) complexity. It supports several gating mechanisms (GLA, Lightning) and provides a custom VJP for efficient gradient computation.
- ejkernel.kernels._xla.recurrent._interface.recurrent(
    query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'],
    key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads qk_head_dim'],
    value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads v_head_dim'],
    g: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'] | None = None,
    g_gamma: jaxtyping.Float[jaxlib._jax.Array, '... num_heads'] | None = None,
    gk: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'] | None = None,
    gv: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'] | None = None,
    softmax_scale: float | None = None,
    initial_state: jaxtyping.Float[jaxlib._jax.Array, '... num_heads qk_head_dim v_head_dim'] | None = None,
    reverse: bool = False,
    cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None
) -> tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, '... num_heads qk_head_dim v_head_dim']]
Recurrent linear attention with O(N) complexity using JAX/XLA.
This implements linear attention as a recurrent process, maintaining a hidden state that accumulates key-value information sequentially. Unlike standard O(N²) attention, this achieves O(N) complexity by processing tokens one at a time.
- The core update is:
    h_t = decay_t * h_{t-1} + k_t^T ⊗ v_t
    o_t = h_t @ q_t
Supports various gating mechanisms for different attention variants.
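The update rule above can be sketched for a single head with jax.lax.scan. This is a minimal reference sketch, not the library's implementation: `naive_recurrent_head` and its scalar `decay` argument are hypothetical names introduced here, the scalar stands in for the gating terms, and query scaling is omitted. With the state shaped [qk_head_dim, v_head_dim] as in the signature, the readout is written here as q_t @ h_t.

```python
import jax
import jax.numpy as jnp


def naive_recurrent_head(q, k, v, decay=1.0):
    """Reference recurrence for one head (illustrative only).

    q, k: [seq_len, qk_head_dim]; v: [seq_len, v_head_dim].
    `decay` is a hypothetical scalar standing in for the gating terms.
    """
    dk, dv = k.shape[-1], v.shape[-1]

    def step(h, qkv):
        q_t, k_t, v_t = qkv
        # State update: h_t = decay * h_{t-1} + outer(k_t, v_t), shape [dk, dv].
        h = decay * h + jnp.outer(k_t, v_t)
        # Readout for this token: o_t = q_t @ h_t, shape [dv].
        return h, q_t @ h

    h0 = jnp.zeros((dk, dv), dtype=v.dtype)
    h_final, out = jax.lax.scan(step, h0, (q, k, v))
    return out, h_final


q = jnp.ones((4, 2))
k = jnp.ones((4, 2))
v = jnp.ones((4, 3))
out, h_final = naive_recurrent_head(q, k, v)
```

Each step touches only the fixed-size state, so the cost grows linearly with seq_len rather than quadratically.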
- Parameters
query – Query tensor [batch, seq_len, num_heads, qk_head_dim]
key – Key tensor [batch, seq_len, num_kv_heads, qk_head_dim]
value – Value tensor [batch, seq_len, num_kv_heads, v_head_dim]
g – Optional gate tensor for GLA-style gating [batch, seq_len, num_heads, qk_head_dim]
g_gamma – Optional per-head decay factor [..., num_heads] for Lightning attention
gk – Optional gate applied to keys [batch, seq_len, num_heads, qk_head_dim]
gv – Optional gate applied to values [batch, seq_len, num_heads, v_head_dim]
softmax_scale – Query scaling factor. If None, defaults to 1/sqrt(head_dim)
initial_state – Optional initial hidden state [..., num_heads, qk_head_dim, v_head_dim]
reverse – If True, process sequence in reverse order
cu_seqlens – Cumulative sequence lengths for variable-length inputs [num_seqs+1]
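For cu_seqlens, the following sketch shows one way to build the cumulative boundaries from per-sequence lengths and to recover which sequence each packed token belongs to. It assumes the usual packed layout in which cu_seqlens[i] and cu_seqlens[i+1] delimit sequence i. `lengths`, `positions` and `segment_ids` are illustrative names, and how the kernel consumes cu_seqlens internally is not shown here.

```python
import jax.numpy as jnp

# Three packed sequences of 50 tokens each (illustrative values).
lengths = jnp.array([50, 50, 50])

# Prepend 0 to the running total so the array has num_seqs + 1 entries.
cu_seqlens = jnp.concatenate(
    [jnp.zeros(1, dtype=lengths.dtype), jnp.cumsum(lengths)]
)

# Token i belongs to sequence s when cu_seqlens[s] <= i < cu_seqlens[s + 1].
positions = jnp.arange(int(cu_seqlens[-1]))
segment_ids = jnp.searchsorted(cu_seqlens, positions, side="right") - 1
```

A recurrence over packed input would have to reset its state wherever `segment_ids` changes, which is consistent with the per-sequence final states in the variable-length example.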
- Returns
output: Attention output [batch, seq_len, num_heads, v_head_dim]
final_state: Final hidden state [..., num_heads, qk_head_dim, v_head_dim]
- Return type
Tuple of (output, final_state)
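The pairing of initial_state and final_state suggests that a long sequence can be processed in chunks by carrying the state across calls. The sketch below illustrates that property on a plain, ungated single-head recurrence. `scan_head` is a hypothetical helper written for this illustration, not part of the library, and the sketch does not claim that `recurrent` handles chunking identically.

```python
import jax
import jax.numpy as jnp


def scan_head(q, k, v, h0):
    """Ungated single-head recurrence starting from state h0 (illustrative)."""

    def step(h, qkv):
        q_t, k_t, v_t = qkv
        h = h + jnp.outer(k_t, v_t)  # accumulate key-value information
        return h, q_t @ h            # emit the output for this token

    h_final, out = jax.lax.scan(step, h0, (q, k, v))
    return out, h_final


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (8, 4))
k = jax.random.normal(key_k, (8, 4))
v = jax.random.normal(key_v, (8, 6))
h0 = jnp.zeros((4, 6))

# One pass over all 8 tokens.
full_out, full_state = scan_head(q, k, v, h0)

# Two passes: the final state of the first chunk seeds the second chunk.
out_a, state_a = scan_head(q[:5], k[:5], v[:5], h0)
out_b, state_b = scan_head(q[5:], k[5:], v[5:], state_a)
chunked_out = jnp.concatenate([out_a, out_b])
```

Under these assumptions, `chunked_out` and `state_b` agree with `full_out` and `full_state` up to floating-point tolerance.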
Examples
>>> import jax.numpy as jnp
>>> from ejkernel.kernels._xla.recurrent._interface import recurrent
>>>
>>> # Basic usage
>>> query = jnp.ones((2, 100, 8, 64))
>>> key = jnp.ones((2, 100, 8, 64))
>>> value = jnp.ones((2, 100, 8, 64))
>>> output, final_state = recurrent(query, key, value)
>>> output.shape
(2, 100, 8, 64)

>>> # GLA-style gating
>>> g = jnp.ones((2, 100, 8, 64))
>>> output, final_state = recurrent(query, key, value, g=g)

>>> # Lightning attention with a per-head decay factor
>>> g_gamma = -jnp.arange(8, dtype=jnp.float32) * 0.1
>>> output, final_state = recurrent(query, key, value, g_gamma=g_gamma)

>>> # Variable-length sequences packed along the token axis
>>> query = jnp.ones((150, 8, 64))
>>> key = jnp.ones((150, 8, 64))
>>> value = jnp.ones((150, 8, 64))
>>> cu_seqlens = jnp.array([0, 50, 100, 150])
>>> output, final_state = recurrent(query, key, value, cu_seqlens=cu_seqlens)
>>> output.shape
(150, 8, 64)
>>> final_state.shape
(3, 8, 64, 64)
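As a further illustration, independent of the library, the sketch below checks on small inputs that the ungated recurrence (decay fixed to 1, no query scaling) produces the same output as a causally masked quadratic form. All names here are illustrative, and the check says nothing about the numerical behaviour of `recurrent` itself.

```python
import jax.numpy as jnp

seq_len, dk, dv = 5, 3, 2
q = jnp.arange(seq_len * dk, dtype=jnp.float32).reshape(seq_len, dk) / 10.0
k = jnp.cos(jnp.arange(seq_len * dk, dtype=jnp.float32)).reshape(seq_len, dk)
v = jnp.sin(jnp.arange(seq_len * dv, dtype=jnp.float32)).reshape(seq_len, dv)

# Recurrent form: one fixed-size state update per token, O(N) in seq_len.
h = jnp.zeros((dk, dv))
steps = []
for t in range(seq_len):
    h = h + jnp.outer(k[t], v[t])
    steps.append(q[t] @ h)
recurrent_out = jnp.stack(steps)

# Quadratic form: full [seq_len, seq_len] score matrix with a causal mask.
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))
quadratic_out = ((q @ k.T) * causal_mask) @ v
```

Both forms compute, for each token t, the sum over s <= t of (q_t · k_s) v_s. The recurrent form avoids materializing the score matrix.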