ejkernel.kernels._triton.recurrent._interface
Recurrent linear attention implementation using Triton kernels.
This module provides a GPU-accelerated implementation of recurrent linear attention built on custom Triton kernels. Standard softmax attention costs O(N²) in the sequence length N. Recurrent linear attention instead processes the sequence one step at a time and carries a state whose size does not depend on N, so its cost grows as O(N). This makes it well suited to very long sequences.
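The following minimal pure-JAX sketch makes the O(N) versus O(N²) contrast concrete. It computes ungated causal linear attention in two ways. The first is a step-by-step recurrence with `jax.lax.scan` that carries a (qk_head_dim x v_head_dim) state per head. The second builds an explicit seq_len x seq_len causal score matrix. The helper names `recurrent_reference` and `quadratic_reference` are hypothetical and are not part of ejkernel. The update rule S_t = S_{t-1} + k_t^T v_t, o_t = (scale * q_t) S_t is the textbook form, assumed here rather than read from the Triton kernel.

```python
import jax
import jax.numpy as jnp


def recurrent_reference(q, k, v, scale=None):
    """Hypothetical O(N) reference: ungated causal linear attention via lax.scan.

    q, k: (batch, seq_len, num_heads, K); v: (batch, seq_len, num_heads, V).
    The carried state has shape (batch, num_heads, K, V), independent of seq_len.
    """
    if scale is None:
        scale = q.shape[-1] ** -0.5
    b, _, h, dk = q.shape
    dv = v.shape[-1]

    def step(state, inputs):
        q_t, k_t, v_t = inputs  # each (batch, num_heads, dim)
        # S_t = S_{t-1} + k_t^T v_t (one outer product per batch and head)
        state = state + jnp.einsum("bhk,bhv->bhkv", k_t, v_t)
        # o_t = (scale * q_t) S_t
        o_t = jnp.einsum("bhk,bhkv->bhv", q_t * scale, state)
        return state, o_t

    init = jnp.zeros((b, h, dk, dv), dtype=q.dtype)
    # lax.scan iterates over the leading axis, so make the time axis leading
    xs = tuple(jnp.moveaxis(x, 1, 0) for x in (q, k, v))
    final_state, o = jax.lax.scan(step, init, xs)
    return jnp.moveaxis(o, 0, 1), final_state


def quadratic_reference(q, k, v, scale=None):
    """Hypothetical O(N^2) reference: explicit causal score matrix, no softmax."""
    if scale is None:
        scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bthk,bshk->bhts", q * scale, k)  # (batch, heads, t, s)
    t = q.shape[1]
    causal = jnp.tril(jnp.ones((t, t), dtype=bool))  # keep s <= t
    scores = jnp.where(causal, scores, 0.0)
    return jnp.einsum("bhts,bshv->bthv", scores, v)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 16, 4, 8))
k = jax.random.normal(kk, (2, 16, 4, 8))
v = jax.random.normal(kv, (2, 16, 4, 6))

o_rec, state = recurrent_reference(q, k, v)
o_quad = quadratic_reference(q, k, v)
```

The recurrent form only ever holds the state. The quadratic form materializes a score matrix per head that grows with seq_len squared. Both produce the same output up to floating-point round-off.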
The implementation is general enough to support various linear attention variants, including:
- Gated Linear Attention (GLA) via the g parameter
- Lightning Attention via the g_gamma parameter
- Standard linear attention without gating
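The following sketch shows how gating changes the recurrence. It assumes a convention common in GLA-style kernels: gates are given in log space and shrink the state before the new outer product is added. The sketch has two gates. `log_gate_k` stands in for a data-dependent, per-key-dimension gate like the documented g. `log_gamma` stands in for a fixed per-head decay like the documented g_gamma. Both names and the exact placement of the decay are assumptions, not the kernel's verified semantics. The fixed-decay case is checked against its closed form.

```python
import jax
import jax.numpy as jnp


def gated_recurrent_reference(q, k, v, log_gate_k=None, log_gamma=None):
    """Hypothetical pure-JAX reference for gated linear attention (not the ejkernel API).

    Assumed convention (log-space gates applied to the state's key rows):
        S_t = exp(log_gate_k[t] + log_gamma)[:, None] * S_{t-1} + k_t^T v_t
    q, k, log_gate_k: (batch, seq_len, num_heads, K); v: (batch, seq_len, num_heads, V)
    log_gamma: (num_heads,), one fixed decay per head
    """
    scale = q.shape[-1] ** -0.5
    b, _, h, dk = q.shape
    dv = v.shape[-1]
    if log_gate_k is None:
        log_gate_k = jnp.zeros_like(q)  # exp(0) = 1: no data-dependent decay
    if log_gamma is None:
        log_gamma = jnp.zeros((h,), dtype=q.dtype)  # no fixed decay

    def step(state, inputs):
        q_t, k_t, v_t, g_t = inputs  # each (batch, num_heads, dim)
        decay = jnp.exp(g_t + log_gamma[None, :, None])  # (batch, num_heads, K)
        state = decay[..., None] * state + jnp.einsum("bhk,bhv->bhkv", k_t, v_t)
        o_t = jnp.einsum("bhk,bhkv->bhv", q_t * scale, state)
        return state, o_t

    init = jnp.zeros((b, h, dk, dv), dtype=q.dtype)
    xs = tuple(jnp.moveaxis(x, 1, 0) for x in (q, k, v, log_gate_k))  # time-major
    final_state, o = jax.lax.scan(step, init, xs)
    return jnp.moveaxis(o, 0, 1), final_state


kq, kk, kv, kg = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(kq, (1, 12, 2, 4))
k = jax.random.normal(kk, (1, 12, 2, 4))
v = jax.random.normal(kv, (1, 12, 2, 3))

# Lightning-style: fixed per-head decay gamma (0.9 for head 0, 0.5 for head 1)
log_gamma = jnp.log(jnp.array([0.9, 0.5]))
o_decay, _ = gated_recurrent_reference(q, k, v, log_gamma=log_gamma)

# Closed form for a constant decay: o_t = sum_{s<=t} gamma^(t-s) (scale q_t . k_s) v_s
pos = jnp.arange(12)
dist = pos[:, None] - pos[None, :]  # t - s
decay_mask = jnp.where(dist >= 0, jnp.exp(log_gamma[:, None, None] * dist), 0.0)
scores = jnp.einsum("bthk,bshk->bhts", q * q.shape[-1] ** -0.5, k) * decay_mask
o_closed = jnp.einsum("bhts,bshv->bthv", scores, v)

# GLA-style: data-dependent per-key-dim log gate in (-inf, 0), i.e. decay in (0, 1)
log_gate_k = jax.nn.log_sigmoid(jax.random.normal(kg, q.shape))
o_gla, s_gla = gated_recurrent_reference(q, k, v, log_gate_k=log_gate_k)
```

With both gates left at their defaults, the sketch reduces to standard ungated linear attention. This matches the third variant in the list above.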
Key features:
- O(N) time complexity in the sequence length
- Custom Triton kernels for GPU acceleration
- Support for variable-length sequences via cumulative sequence lengths (cu_seqlens)
- Reverse-order (last token to first) processing via the reverse parameter, which can be combined with a forward pass for bidirectional models
- Stateful processing via the initial_state parameter for chunked computation
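The reverse option can be pictured with the following single-head sketch. It assumes that "processed in reverse order" means output t depends on tokens s >= t. Under that assumption, it equals running the forward recurrence on the time-flipped inputs and flipping the result back. The sketch uses the `reverse` flag of `jax.lax.scan`. The helper `linear_recurrence` is hypothetical and ungated, and it drops the batch and head axes for brevity.

```python
import jax
import jax.numpy as jnp


def linear_recurrence(q, k, v, reverse=False):
    """Hypothetical single-head reference; q, k: (T, K), v: (T, V).

    With reverse=True, lax.scan walks from the last token to the first while
    keeping outputs aligned with input positions, so o_t sums over s >= t.
    """
    scale = q.shape[-1] ** -0.5

    def step(state, inputs):
        q_t, k_t, v_t = inputs
        state = state + jnp.outer(k_t, v_t)  # S += k_t^T v_t
        return state, (q_t * scale) @ state

    init = jnp.zeros((q.shape[-1], v.shape[-1]), dtype=q.dtype)
    final_state, o = jax.lax.scan(step, init, (q, k, v), reverse=reverse)
    return o, final_state


kq, kk, kv = jax.random.split(jax.random.PRNGKey(1), 3)
q = jax.random.normal(kq, (10, 4))
k = jax.random.normal(kk, (10, 4))
v = jax.random.normal(kv, (10, 5))

o_rev, _ = linear_recurrence(q, k, v, reverse=True)

# Same result: forward pass over flipped inputs, then flip the outputs back
o_fwd_flipped, _ = linear_recurrence(q[::-1], k[::-1], v[::-1])
o_flip = o_fwd_flipped[::-1]

# Anti-causal closed form: upper-triangular scores keep s >= t
o_anticausal = jnp.triu((q * q.shape[-1] ** -0.5) @ k.T) @ v
```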
Example
>>> import jax.numpy as jnp
>>> from ejkernel.kernels._triton.recurrent import recurrent
>>>
>>> batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
>>> q = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> k = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> v = jnp.ones((batch, seq_len, num_heads, head_dim))
>>>
>>> output, final_state = recurrent(q, k, v)
- ejkernel.kernels._triton.recurrent._interface.recurrent(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads qk_head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads v_head_dim'], g: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'] | None = None, g_gamma: jaxtyping.Float[jaxlib._jax.Array, '... num_heads'] | None = None, gk: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'] | None = None, gv: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'] | None = None, softmax_scale: float | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, '... num_heads qk_head_dim v_head_dim'] | None = None, reverse: bool = False, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None) -> tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, '... num_heads qk_head_dim v_head_dim']]
Computes a general recurrent linear attention using a custom Triton kernel.
This function provides a flexible implementation of recurrent linear attention. It processes the sequence one step at a time, so its cost is O(N) in the sequence length, which suits long sequences. Different linear attention mechanisms are selected by configuring the gate inputs (g, g_gamma, gk, gv).
It supports both standard batch processing and variable-length sequence processing using cumulative sequence lengths (cu_seqlens).
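The following sketch shows how variable-length packing works. Sequences are concatenated along the time axis, and cu_seqlens holds the cumulative offsets, so sequence i occupies positions cu_seqlens[i]:cu_seqlens[i+1]. The sketch assumes the recurrence starts from a zero state at each sequence start, so no state leaks between neighbours. It checks that a single scan over the packed input matches running each sequence separately. `packed_recurrence` is a hypothetical helper. The documented API expects packed inputs with batch size 1, but the sketch drops the batch and head axes for brevity.

```python
import jax
import jax.numpy as jnp


def packed_recurrence(q, k, v, cu_seqlens):
    """Hypothetical single-head reference over packed sequences.

    q, k: (total_len, K), v: (total_len, V), sequences stored back to back.
    cu_seqlens: (num_seqs + 1,) offsets, e.g. [0, len_0, len_0 + len_1, ...].
    """
    scale = q.shape[-1] ** -0.5
    total = q.shape[0]
    # True at the first token of every sequence (positions cu_seqlens[:-1])
    is_start = jnp.zeros((total,), dtype=bool).at[cu_seqlens[:-1]].set(True)

    def step(state, inputs):
        q_t, k_t, v_t, start = inputs
        state = jnp.where(start, jnp.zeros_like(state), state)  # reset per sequence
        state = state + jnp.outer(k_t, v_t)
        return state, (q_t * scale) @ state

    init = jnp.zeros((q.shape[-1], v.shape[-1]), dtype=q.dtype)
    _, o = jax.lax.scan(step, init, (q, k, v, is_start))
    return o


lengths = [3, 5, 2]
cu_seqlens = jnp.cumsum(jnp.array([0] + lengths, dtype=jnp.int32))  # [0, 3, 8, 10]

kq, kk, kv = jax.random.split(jax.random.PRNGKey(4), 3)
q = jax.random.normal(kq, (10, 4))
k = jax.random.normal(kk, (10, 4))
v = jax.random.normal(kv, (10, 3))

o_packed = packed_recurrence(q, k, v, cu_seqlens)

# Reference: run every sequence on its own and concatenate the outputs
bounds = [int(x) for x in cu_seqlens]
o_separate = jnp.concatenate([
    packed_recurrence(q[a:b], k[a:b], v[a:b], jnp.array([0, b - a]))
    for a, b in zip(bounds[:-1], bounds[1:])
])

# Treating everything as one long sequence lets state leak across boundaries
o_no_reset = packed_recurrence(q, k, v, jnp.array([0, 10]))
```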
- Parameters
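query – The query tensor, of shape (batch, seq_len, num_heads, qk_head_dim).
key – The key tensor, of shape (batch, seq_len, num_kv_heads, qk_head_dim).
value – The value tensor, of shape (batch, seq_len, num_kv_heads, v_head_dim).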
g – Optional gate tensor for Gated Linear Attention (GLA) style gating.
g_gamma – Optional decay factor, used for mechanisms like Lightning Attention where the decay is fixed per-head or per-layer.
gk – Optional gate tensor applied element-wise to keys.
gv – Optional gate tensor applied element-wise to values.
softmax_scale – A scaling factor applied to the query. If None, it defaults to 1 / sqrt(qk_head_dim).
initial_state – The initial hidden state for the recurrence. This is useful for chunked processing of very long sequences or for stateful autoregressive decoding.
reverse – If True, the sequence is processed in reverse order (from last token to first).
cu_seqlens – Cumulative sequence lengths for variable-length inputs. This is a 1D tensor like [0, len_seq1, len_seq1+len_seq2, …]. If provided, the input tensors are expected to be “packed” with a batch size of 1.
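The softmax_scale parameter is applied to the query. The following single-head sketch illustrates what that implies. It assumes no softmax or other normalization follows the scaled dot product, as in the plain recurrence. Under that assumption, the output is linear in the query, so changing the scale simply rescales the output. This differs from softmax attention, where the scale changes how peaked the attention weights are. `causal_linear_attention` is a hypothetical helper that mirrors the documented default of 1 / sqrt(qk_head_dim).

```python
import jax
import jax.numpy as jnp


def causal_linear_attention(q, k, v, softmax_scale=None):
    """Hypothetical single-head helper; q, k: (T, K), v: (T, V)."""
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5  # 1 / sqrt(qk_head_dim)
    scores = jnp.tril((q * softmax_scale) @ k.T)  # causal mask, no softmax
    return scores @ v


kq, kk, kv = jax.random.split(jax.random.PRNGKey(3), 3)
q = jax.random.normal(kq, (8, 64))
k = jax.random.normal(kk, (8, 64))
v = jax.random.normal(kv, (8, 16))

default_scale = q.shape[-1] ** -0.5  # 0.125 for qk_head_dim = 64
o_default = causal_linear_attention(q, k, v)
o_unscaled = causal_linear_attention(q, k, v, softmax_scale=1.0)
# Without a softmax, the default scale just multiplies the output by 0.125
```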
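- Returns
A tuple containing:
o (jax.Array): The output tensor, of shape (batch, seq_len, num_heads, v_head_dim). It has the same shape as query when v_head_dim equals qk_head_dim.
final_state (jax.Array): The final hidden state of the recurrence, of shape (..., num_heads, qk_head_dim, v_head_dim), which can be used as initial_state for a subsequent segment.
- Return type
tuple[jax.Array, jax.Array]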
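Passing final_state back in as initial_state lets a long sequence be processed in segments. The following single-head, ungated sketch checks the idea: running two chunks with the state handed over gives the same outputs and final state as one pass over the whole sequence. `recurrence_with_state` is a hypothetical stand-in that mimics the documented (o, final_state) return convention. With the documented signature, the equivalent call would presumably be `o2, s2 = recurrent(q2, k2, v2, initial_state=s1)`.

```python
import jax
import jax.numpy as jnp


def recurrence_with_state(q, k, v, initial_state=None):
    """Hypothetical single-head reference returning (o, final_state).

    q, k: (T, K), v: (T, V); the state has shape (K, V).
    """
    scale = q.shape[-1] ** -0.5
    if initial_state is None:
        initial_state = jnp.zeros((q.shape[-1], v.shape[-1]), dtype=q.dtype)

    def step(state, inputs):
        q_t, k_t, v_t = inputs
        state = state + jnp.outer(k_t, v_t)  # S_t = S_{t-1} + k_t^T v_t
        return state, (q_t * scale) @ state

    final_state, o = jax.lax.scan(step, initial_state, (q, k, v))
    return o, final_state


kq, kk, kv = jax.random.split(jax.random.PRNGKey(2), 3)
q = jax.random.normal(kq, (20, 4))
k = jax.random.normal(kk, (20, 4))
v = jax.random.normal(kv, (20, 3))

# One pass over all 20 tokens
o_full, s_full = recurrence_with_state(q, k, v)

# Two segments: the first segment's final state seeds the second
o1, s1 = recurrence_with_state(q[:12], k[:12], v[:12])
o2, s2 = recurrence_with_state(q[12:], k[12:], v[12:], initial_state=s1)
o_chunked = jnp.concatenate([o1, o2])
```

The same hand-off is what makes stateful autoregressive decoding possible: each new token is a segment of length one.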