ejkernel.kernels._triton.lightning_attn._interface
Lightning Attention implementation using Triton kernels.
This module provides an efficient implementation of Lightning Attention, a linear attention mechanism that uses layer-dependent decay rates to adaptively adjust the temporal receptive field across different layers of a neural network.
Lightning Attention is designed for efficient sequence processing with O(N) complexity while allowing different layers to have different memory characteristics. Shallow layers can focus on local patterns with faster decay, while deeper layers can capture longer-range dependencies with slower decay.
Key innovation: The decay rate (g_gamma) is computed based on layer position:
g_gamma = -(8 / num_heads) * (1 - layer_idx / num_layers) * head_indices
This creates a progressive increase in temporal context as we move deeper into the network, mimicking the hierarchical feature learning in transformers but with linear complexity.
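The formula above can be sketched in a few lines of JAX. This is a minimal illustration, not the module's own code: the helper name `lightning_decay` is hypothetical, and it assumes `head_indices` is simply `0, 1, ..., num_heads - 1`.

```python
import jax.numpy as jnp


def lightning_decay(layer_idx: int, num_layers: int, num_heads: int) -> jnp.ndarray:
    """Per-head g_gamma for one layer, following the formula above."""
    # Assumption: head_indices enumerates the heads as 0, 1, ..., num_heads - 1.
    head_indices = jnp.arange(num_heads, dtype=jnp.float32)
    # Shrinks toward 0 as layer_idx approaches num_layers.
    depth_factor = 1.0 - layer_idx / num_layers
    return -(8.0 / num_heads) * depth_factor * head_indices


shallow = lightning_decay(layer_idx=0, num_layers=12, num_heads=8)
deep = lightning_decay(layer_idx=11, num_layers=12, num_heads=8)
```

Under these assumptions every head's g_gamma in the deep layer is closer to zero than in the shallow layer, which corresponds to slower decay and a longer effective context.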
Features:
- O(N) time complexity via recurrent formulation
- Layer-adaptive decay rates for hierarchical learning
- Support for variable-length sequences
- GPU-optimized Triton kernels
- Full gradient support via JAX autodiff
Example
>>> import jax.numpy as jnp
>>> from ejkernel.kernels._triton.lightning_attn import lightning_attn
>>>
>>> batch, seq_len, num_heads, head_dim = 2, 2048, 8, 64
>>> q = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> k = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> v = jnp.ones((batch, seq_len, num_heads, head_dim))
>>>
>>> output, final_state = lightning_attn(q, k, v, layer_idx=5, num_layers=12)
- ejkernel.kernels._triton.lightning_attn._interface.lightning_attn(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads qk_head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads v_head_dim'], layer_idx: int, num_layers: int, softmax_scale: float | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, '... num_heads qk_head_dim v_head_dim'] | None = None, reverse: bool = False, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, '... num_heads qk_head_dim v_head_dim']]
Computes Lightning Attention using a recurrent, linear-time mechanism.
This function implements the Lightning Attention mechanism, a form of linear attention where the decay rate (g_gamma) is dynamically determined by the layer’s depth within the model. This allows for different temporal receptive fields across layers.
The computation is performed efficiently using a recurrent formulation, making it suitable for long sequences. It serves as a specialized wrapper around the generic recurrent function and supports both standard batch processing and packed variable-length inputs via cu_seqlens.
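To make the recurrent formulation concrete, here is a naive pure-JAX reference for a single sequence. It is a sketch under stated assumptions, not the Triton kernel or the library's generic recurrent function: the name `decayed_linear_attn_reference` is hypothetical, and it assumes g_gamma is a per-head log-decay (the state is multiplied by `exp(g_gamma)` at each step) and that the scale is applied to the query.

```python
import jax
import jax.numpy as jnp


def decayed_linear_attn_reference(q, k, v, g_gamma, scale=None, initial_state=None):
    """Naive recurrence. q, k: (T, H, Dk); v: (T, H, Dv); g_gamma: (H,)."""
    _, H, Dk = q.shape
    Dv = v.shape[-1]
    scale = Dk ** -0.5 if scale is None else scale
    state = jnp.zeros((H, Dk, Dv), q.dtype) if initial_state is None else initial_state
    # Assumption: g_gamma is a log-decay, so exp(g_gamma) lies in (0, 1].
    decay = jnp.exp(g_gamma)[:, None, None]

    def step(state, qkv_t):
        q_t, k_t, v_t = qkv_t
        # Decay the running state, then add the outer product of k_t and v_t.
        state = decay * state + jnp.einsum("hk,hv->hkv", k_t, v_t)
        # Read out with the scaled query.
        o_t = jnp.einsum("hk,hkv->hv", q_t * scale, state)
        return state, o_t

    final_state, out = jax.lax.scan(step, state, (q, k, v))
    return out, final_state


T, H, Dk, Dv = 8, 2, 4, 3
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (T, H, Dk))
k = jax.random.normal(kk, (T, H, Dk))
v = jax.random.normal(kv, (T, H, Dv))
g_gamma = jnp.array([0.0, -0.5])

full_out, full_state = decayed_linear_attn_reference(q, k, v, g_gamma)

# Chunked processing: the first half's final state seeds the second half.
out_a, state_a = decayed_linear_attn_reference(q[:4], k[:4], v[:4], g_gamma)
out_b, state_b = decayed_linear_attn_reference(
    q[4:], k[4:], v[4:], g_gamma, initial_state=state_a
)
```

In this sketch, processing the sequence in two chunks and passing the state between them reproduces the single-pass result, which is the use case initial_state is meant for.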
- Parameters
query – The query tensor. Expected shape is (batch, seq_len, num_heads, head_dim) or (1, total_tokens, num_heads, head_dim) if cu_seqlens is used.
key – The key tensor. Must have the same shape as query.
value – The value tensor. Must have the same shape as query.
layer_idx – The 0-indexed index of the current layer, used to compute the layer-specific decay factor.
num_layers – The total number of layers in the model.
softmax_scale – A scaling factor applied to the query. If None, it defaults to 1 / sqrt(head_dim).
initial_state – The initial hidden state for the recurrence. Useful for chunked processing of long sequences.
reverse – If True, the sequence is processed in reverse order.
cu_seqlens – Cumulative sequence lengths for variable-length inputs. This is a 1D tensor like [0, len_seq1, len_seq1+len_seq2, …]. If provided, the input tensors are expected to be “packed” with a batch size of 1.
- Returns
A tuple containing:
o (jax.Array): The output tensor, with the same shape as query.
final_state (jax.Array): The final hidden state of the recurrence.
- Return type
tuple[jax.Array, jax.Array]
- Raises
ValueError – If cu_seqlens is provided and the batch size of query is not 1.
ValueError – If cu_seqlens is provided and the number of initial states does not match the number of sequences.
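The packed layout expected with cu_seqlens, and the two error conditions listed above, can be sketched as follows. Both helpers are hypothetical illustrations rather than part of ejkernel, and the check on initial_state assumes its leading axis indexes sequences.

```python
import jax.numpy as jnp


def pack_sequences(seqs):
    """Concatenate arrays of shape (len_i, H, D) into a (1, total_tokens, H, D) batch."""
    lengths = jnp.array([s.shape[0] for s in seqs], dtype=jnp.int32)
    # Cumulative offsets: [0, len_1, len_1 + len_2, ...].
    cu_seqlens = jnp.concatenate([jnp.zeros((1,), jnp.int32), jnp.cumsum(lengths)])
    packed = jnp.concatenate(seqs, axis=0)[None]
    return packed, cu_seqlens


def check_packed_inputs(query, cu_seqlens, initial_state=None):
    """Hypothetical helper mirroring the two documented ValueError conditions."""
    if query.shape[0] != 1:
        raise ValueError("With cu_seqlens, the batch size of query must be 1.")
    num_seqs = cu_seqlens.shape[0] - 1
    # Assumption: the leading axis of initial_state holds one state per sequence.
    if initial_state is not None and initial_state.shape[0] != num_seqs:
        raise ValueError("Number of initial states must match the number of sequences.")


seqs = [jnp.ones((n, 8, 64)) for n in (3, 5, 2)]
packed_q, cu_seqlens = pack_sequences(seqs)
check_packed_inputs(packed_q, cu_seqlens)
```

Sequence i can then be recovered from the packed tensor as `packed_q[0, cu_seqlens[i]:cu_seqlens[i + 1]]`.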