ejkernel.kernels._xla.recurrent._interface

Recurrent attention interface for linear-time sequence processing.

This module provides the public API for recurrent linear attention with O(N) complexity in sequence length. It supports several gating mechanisms (GLA, Lightning) and defines a custom VJP for efficient gradient computation.

ejkernel.kernels._xla.recurrent._interface.recurrent(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads qk_head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads v_head_dim'], g: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'] | None = None, g_gamma: jaxtyping.Float[jaxlib._jax.Array, '... num_heads'] | None = None, gk: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'] | None = None, gv: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'] | None = None, softmax_scale: float | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, '... num_heads qk_head_dim v_head_dim'] | None = None, reverse: bool = False, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, '... num_heads qk_head_dim v_head_dim']]

Recurrent linear attention with O(N) complexity using JAX/XLA.

This implements linear attention as a recurrent process, maintaining a hidden state that accumulates key-value information sequentially. Unlike standard O(N²) attention, this achieves O(N) complexity by processing tokens one at a time.

The core update is:

    h_t = decay_t * h_{t-1} + k_t^T ⊗ v_t
    o_t = q_t @ h_t
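The recurrence can be sketched in plain JAX for a single head. This is a minimal reference under stated assumptions, not the library's implementation: `naive_recurrent` and `quadratic_reference` are hypothetical names, the decay is simplified to one scalar per step, and query scaling is omitted. The state is stored as [qk_head_dim, v_head_dim], so the readout is written as `q_t @ h`.

```python
import jax
import jax.numpy as jnp


def naive_recurrent(q, k, v, decay=None):
    """Single-head reference recurrence (hypothetical helper).

    q, k: [seq_len, qk_head_dim]; v: [seq_len, v_head_dim].
    decay: optional [seq_len] scalar decay per step (a simplification).
    """
    seq_len, dk = q.shape
    dv = v.shape[-1]
    if decay is None:
        decay = jnp.ones((seq_len,), dtype=q.dtype)

    def step(h, inputs):
        q_t, k_t, v_t, d_t = inputs
        h = d_t * h + jnp.outer(k_t, v_t)  # state update, shape [dk, dv]
        o_t = q_t @ h                      # readout, shape [dv]
        return h, o_t

    h0 = jnp.zeros((dk, dv), dtype=q.dtype)
    h_final, out = jax.lax.scan(step, h0, (q, k, v, decay))
    return out, h_final


def quadratic_reference(q, k, v):
    """Causal linear attention computed the O(N^2) way, for comparison."""
    scores = jnp.tril(q @ k.T)  # keep only pairs with key position <= query position
    return scores @ v
```

With no decay, the scan and the masked quadratic form should agree up to floating-point error, which is a convenient sanity check when experimenting with the update rule.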

Supports various gating mechanisms for different attention variants.

Parameters
  • query – Query tensor [batch, seq_len, num_heads, qk_head_dim]

  • key – Key tensor [batch, seq_len, num_kv_heads, qk_head_dim]

  • value – Value tensor [batch, seq_len, num_kv_heads, v_head_dim]

  • g – Optional gate tensor for GLA-style gating [batch, seq_len, num_heads, qk_head_dim]

  • g_gamma – Optional per-head decay factor [..., num_heads] for Lightning attention

  • gk – Optional gate applied to keys [batch, seq_len, num_heads, qk_head_dim]

  • gv – Optional gate applied to values [batch, seq_len, num_heads, v_head_dim]

  • softmax_scale – Query scaling factor. If None, defaults to 1/sqrt(qk_head_dim)

  • initial_state – Optional initial hidden state [..., num_heads, qk_head_dim, v_head_dim]

  • reverse – If True, process sequence in reverse order

  • cu_seqlens – Cumulative sequence lengths for variable-length inputs [num_seqs+1]
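How the gates enter the state update is not spelled out above. The sketch below shows one common convention, in which every gate is a log-space decay that is exponentiated before it multiplies the state. That convention is an assumption here (the negative `g_gamma` values in the examples are only suggestive), and `gated_step` is a hypothetical helper for a single head, not part of the library.

```python
import jax.numpy as jnp


def gated_step(h, k_t, v_t, gk_t=None, gv_t=None, g_gamma=None):
    """One state update for a single head; h has shape [dk, dv].

    ASSUMPTION: all gates are log-space decays (values <= 0 shrink the state).
    """
    if g_gamma is not None:
        h = h * jnp.exp(g_gamma)        # one scalar decay for the whole head
    if gk_t is not None:
        h = h * jnp.exp(gk_t)[:, None]  # decay per key channel (rows of h)
    if gv_t is not None:
        h = h * jnp.exp(gv_t)[None, :]  # decay per value channel (columns of h)
    return h + jnp.outer(k_t, v_t)      # accumulate the new key-value pair
```

Under this convention a gate of zero leaves the state untouched, so omitting a gate and passing zeros are equivalent.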

Returns
  Tuple of (output, final_state):

  • output: Attention output [batch, seq_len, num_heads, v_head_dim]

  • final_state: Final hidden state [..., num_heads, qk_head_dim, v_head_dim]
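Returning the final state makes it possible to process a long sequence in chunks, feeding each chunk's final state in as the next chunk's initial state. The sketch below illustrates that property with a plain-JAX, single-head, ungated recurrence; `scan_recurrent` is a hypothetical stand-in and not the library function.

```python
import jax
import jax.numpy as jnp


def scan_recurrent(q, k, v, h0):
    """Ungated single-head recurrence starting from state h0 (hypothetical helper)."""
    def step(h, qkv):
        q_t, k_t, v_t = qkv
        h = h + jnp.outer(k_t, v_t)
        return h, q_t @ h

    h_final, out = jax.lax.scan(step, h0, (q, k, v))
    return out, h_final


keys = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(keys[0], (12, 4))
k = jax.random.normal(keys[1], (12, 4))
v = jax.random.normal(keys[2], (12, 6))
h0 = jnp.zeros((4, 6))

# Whole sequence in one call.
full_out, full_state = scan_recurrent(q, k, v, h0)

# Same sequence in two chunks, carrying the state across the boundary.
out_a, state_a = scan_recurrent(q[:8], k[:8], v[:8], h0)
out_b, state_b = scan_recurrent(q[8:], k[8:], v[8:], state_a)
chunked_out = jnp.concatenate([out_a, out_b], axis=0)
```

The same pattern would presumably apply to `recurrent` via its `initial_state` argument and `final_state` return value, though that should be verified against the library.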

Examples

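>>> import jax.numpy as jnp
>>> from ejkernel.kernels._xla.recurrent._interface import recurrent
>>> # Basic usage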
>>> query = jnp.ones((2, 100, 8, 64))
>>> key = jnp.ones((2, 100, 8, 64))
>>> value = jnp.ones((2, 100, 8, 64))
>>> output, final_state = recurrent(query, key, value)
>>> output.shape
(2, 100, 8, 64)
>>> # GLA-style gating
>>> g = jnp.ones((2, 100, 8, 64))
>>> output, final_state = recurrent(query, key, value, g=g)
>>> # Lightning attention with a per-head decay factor
>>> g_gamma = -jnp.arange(8, dtype=jnp.float32) * 0.1
>>> output, final_state = recurrent(query, key, value, g_gamma=g_gamma)
>>> # Variable-length sequences packed along the token axis
>>> query = jnp.ones((150, 8, 64))
>>> key = jnp.ones((150, 8, 64))
>>> value = jnp.ones((150, 8, 64))
>>> cu_seqlens = jnp.array([0, 50, 100, 150])
>>> output, final_state = recurrent(query, key, value, cu_seqlens=cu_seqlens)
>>> output.shape
(150, 8, 64)
>>> final_state.shape
(3, 8, 64, 64)
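The `cu_seqlens` array in the last example can be built from per-sequence lengths with a cumulative sum. The sketch below also derives, for each packed token, which sequence it belongs to and whether it starts a new sequence. How the library consumes these boundaries internally is not stated here; resetting the state at each start would be consistent with one final state per sequence, but that is an assumption. The names `lengths`, `segment_ids`, and `is_start` are illustrative.

```python
import jax.numpy as jnp

# Three packed sequences of 50 tokens each, as in the example above.
lengths = jnp.array([50, 50, 50])

# Prepend 0 so that sequence i spans cu_seqlens[i] : cu_seqlens[i + 1].
cu_seqlens = jnp.concatenate(
    [jnp.zeros((1,), dtype=lengths.dtype), jnp.cumsum(lengths)]
)

# Map every packed token position to the index of its sequence.
positions = jnp.arange(int(cu_seqlens[-1]))
segment_ids = jnp.searchsorted(cu_seqlens, positions, side="right") - 1

# Mark the first token of each sequence.
is_start = positions == cu_seqlens[segment_ids]
```

The number of sequences is `cu_seqlens.shape[0] - 1`, which matches the leading dimension of `final_state` in the example.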