ejkernel.kernels._xla.lightning_attn._interface


Lightning Attention interface with layer-dependent decay.

This module provides the public API for Lightning Attention. It uses a recurrent formulation whose decay rate depends on layer depth, which gives different transformer layers different temporal receptive fields.

ejkernel.kernels._xla.lightning_attn._interface.lightning_attn(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads qk_head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads v_head_dim'], layer_idx: int, num_layers: int, softmax_scale: float | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, '... num_heads qk_head_dim v_head_dim'] | None = None, reverse: bool = False, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, '... num_heads qk_head_dim v_head_dim']]

Computes Lightning Attention using a recurrent, linear-time mechanism with JAX/XLA.

This function implements Lightning Attention, a form of linear attention in which the decay rate (g_gamma) is derived from the layer's depth in the model, that is, from layer_idx relative to num_layers. Different layers therefore get different temporal receptive fields.
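The depth dependence can be illustrated with a hypothetical helper. The sketch below assumes an ALiBi-style per-head slope scaled by (1 - layer_idx / num_layers), a convention used in some open-source Lightning Attention implementations; ejkernel's exact formula is not given on this page and may differ. It shows how deeper layers end up with per-step decay factors exp(g_gamma) closer to 1, i.e. longer memory.

```python
import math


def layer_decay_gamma(layer_idx: int, num_layers: int, num_heads: int) -> list[float]:
    """Hypothetical per-head log-decay g_gamma for one layer.

    ASSUMPTION: g_gamma[h] = -(8 / num_heads) * (1 - layer_idx / num_layers) * h.
    This illustrates the idea only; it is not taken from ejkernel.
    """
    depth_factor = 1.0 - layer_idx / num_layers  # shrinks toward 0 in deeper layers
    return [-(8.0 / num_heads) * depth_factor * h for h in range(num_heads)]


# Per-step multiplicative decay is exp(g_gamma); values closer to 1 keep context longer.
shallow = [math.exp(g) for g in layer_decay_gamma(layer_idx=0, num_layers=24, num_heads=8)]
deep = [math.exp(g) for g in layer_decay_gamma(layer_idx=23, num_layers=24, num_heads=8)]
```

Under this assumption head 0 never decays (g_gamma = 0), and every other head's decay weakens as layer_idx approaches num_layers.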

The computation uses a recurrent formulation whose cost grows linearly with sequence length, which makes it suitable for long sequences. The function is a specialized wrapper around a generic recurrent routine and supports both standard batched inputs and packed variable-length inputs via cu_seqlens.
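For reference, the kind of recurrence such a wrapper evaluates can be written as a short NumPy loop. This is a sketch under stated assumptions, not the XLA kernel: it handles a single unpacked sequence, assumes equal head counts for query and key/value, multiplies the state by exp(g_gamma) before adding the current key/value outer product, and scales the query by softmax_scale. The function name recurrent_decay_attention is hypothetical.

```python
import numpy as np


def recurrent_decay_attention(q, k, v, g_gamma, scale=None, initial_state=None, reverse=False):
    """Hypothetical NumPy reference for per-head constant-decay linear attention.

    Shapes (single sequence): q, k: (T, H, Dk); v: (T, H, Dv); g_gamma: (H,).
    ASSUMED update: S_t = exp(g) * S_{t-1} + k_t^T v_t,  o_t = (scale * q_t) @ S_t.
    """
    T, H, Dk = q.shape
    Dv = v.shape[-1]
    if scale is None:
        scale = Dk ** -0.5  # mirrors the documented 1 / sqrt(head_dim) default
    decay = np.exp(np.asarray(g_gamma, dtype=np.float64))[:, None, None]  # (H, 1, 1)
    if initial_state is None:
        state = np.zeros((H, Dk, Dv))
    else:
        state = np.array(initial_state, dtype=np.float64)  # copy; caller's array untouched
    out = np.zeros((T, H, Dv))
    steps = range(T - 1, -1, -1) if reverse else range(T)
    for t in steps:
        # Decay the running state, then add the per-head outer product k_t^T v_t.
        state = decay * state + np.einsum("hk,hv->hkv", k[t], v[t])
        # Read out with the scaled query: (H, Dk) x (H, Dk, Dv) -> (H, Dv).
        out[t] = np.einsum("hk,hkv->hv", scale * q[t], state)
    return out, state
```

Each step costs O(H · Dk · Dv) independent of position, so the pass is linear in sequence length; an optimized kernel would normally replace the Python loop with a chunked or scan-based formulation.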

Parameters
  • query – The query tensor. Expected shape is (batch, seq_len, num_heads, qk_head_dim), or (1, total_tokens, num_heads, qk_head_dim) when cu_seqlens is used.

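  • key – The key tensor, annotated as (batch, seq_len, num_kv_heads, qk_head_dim). Expected to have the same shape as query.

  • value – The value tensor, annotated as (batch, seq_len, num_kv_heads, v_head_dim). Its batch, sequence, and head dimensions must match those of key.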

  • layer_idx – The 0-indexed index of the current layer, used to compute the layer-specific decay factor.

  • num_layers – The total number of layers in the model.

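  • softmax_scale – A scaling factor applied to the query. If None, it defaults to 1 / sqrt(qk_head_dim).

  • initial_state – The initial hidden state for the recurrence, of shape (..., num_heads, qk_head_dim, v_head_dim). When cu_seqlens is provided, one state per sequence is expected. Useful for chunked processing of long sequences.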

  • reverse – If True, the sequence is processed in reverse order.

  • cu_seqlens – Cumulative sequence lengths for variable-length inputs: a 1D integer array with num_seqs + 1 entries, such as [0, len_seq1, len_seq1 + len_seq2, …]. If provided, the input tensors must be "packed" along the sequence axis with a batch size of 1.
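A packed input and its cu_seqlens can be built from per-sequence arrays as sketched below. The helper names pack_sequences and unpack_sequences are hypothetical, and the sketch uses NumPy; it only reproduces the layout described for cu_seqlens (batch size 1, tokens concatenated along the sequence axis).

```python
import numpy as np


def pack_sequences(seqs):
    """Hypothetical helper: pack (T_i, H, D) arrays into (1, sum(T_i), H, D).

    Also returns cu_seqlens = [0, T_0, T_0 + T_1, ...] with num_seqs + 1 entries.
    """
    lengths = [s.shape[0] for s in seqs]
    cu_seqlens = np.concatenate([[0], np.cumsum(lengths)]).astype(np.int32)
    packed = np.concatenate(seqs, axis=0)[None]  # add a leading batch axis of size 1
    return packed, cu_seqlens


def unpack_sequences(packed, cu_seqlens):
    """Hypothetical helper: split a packed (1, total_tokens, ...) array per sequence."""
    # Sequence i occupies tokens cu_seqlens[i] : cu_seqlens[i + 1] of the single batch row.
    return [packed[0, cu_seqlens[i]:cu_seqlens[i + 1]] for i in range(len(cu_seqlens) - 1)]


seqs = [np.ones((n, 8, 64), dtype=np.float32) for n in (50, 30, 70)]
packed, cu_seqlens = pack_sequences(seqs)
```

With three sequences of 50 tokens each, the same helper yields the cu_seqlens = [0, 50, 100, 150] used in the Examples.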

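Returns

A tuple containing:

  • o (jax.Array): The output tensor, of shape (batch, seq_len, num_heads, v_head_dim).

  • final_state (jax.Array): The final hidden state of the recurrence, of shape (..., num_heads, qk_head_dim, v_head_dim). It can be passed as initial_state to a subsequent call.

Return type

tuple[jax.Array, jax.Array]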
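Because final_state can seed a later call, a long sequence can be processed in chunks, and threading the state through the chunks should reproduce a single full pass. The NumPy sketch below checks that property on a minimal stand-in recurrence. The helper decayed_scan is hypothetical: it omits softmax_scale, reverse, and cu_seqlens, and it assumes a decay-then-accumulate update rather than reproducing ejkernel's kernel.

```python
import numpy as np


def decayed_scan(q, k, v, g_gamma, initial_state=None):
    """Hypothetical stand-in recurrence returning (outputs, final_state).

    q, k: (T, H, Dk); v: (T, H, Dv); g_gamma: (H,) per-head log-decay.
    """
    H, Dk, Dv = q.shape[1], q.shape[2], v.shape[2]
    decay = np.exp(g_gamma)[:, None, None]
    state = np.zeros((H, Dk, Dv)) if initial_state is None else initial_state
    outs = []
    for qt, kt, vt in zip(q, k, v):  # iterate over the time axis
        state = decay * state + np.einsum("hk,hv->hkv", kt, vt)
        outs.append(np.einsum("hk,hkv->hv", qt, state))
    return np.stack(outs), state


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 2, 4)) for _ in range(3))
g = np.array([-0.3, -1.0])

full_out, full_state = decayed_scan(q, k, v, g)
# Two chunks: the first chunk's final state becomes the second chunk's initial state.
out1, state1 = decayed_scan(q[:5], k[:5], v[:5], g)
out2, state2 = decayed_scan(q[5:], k[5:], v[5:], g, initial_state=state1)
```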

Raises
  • ValueError – If cu_seqlens is provided and the batch size of query is not 1.

  • ValueError – If cu_seqlens is provided and the number of initial states does not match the number of sequences.
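The two documented error conditions can also be checked by the caller before invoking the kernel. The helper below is a hypothetical sketch; in particular, it assumes that per-sequence initial states are stacked along the leading axis, which the "..." dimension in the initial_state annotation allows but this page does not state explicitly.

```python
def validate_varlen_inputs(query_shape, cu_seqlens, initial_state_shape=None):
    """Hypothetical pre-check mirroring the documented ValueError conditions."""
    if cu_seqlens is None:
        return  # standard batched input: nothing to check here
    num_seqs = len(cu_seqlens) - 1  # cu_seqlens has num_seqs + 1 entries
    if query_shape[0] != 1:
        raise ValueError(f"cu_seqlens requires batch size 1, got {query_shape[0]}")
    # ASSUMPTION: one initial state per sequence, stacked on the leading axis.
    if initial_state_shape is not None and initial_state_shape[0] != num_seqs:
        raise ValueError(
            f"expected {num_seqs} initial states, got {initial_state_shape[0]}"
        )
```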

Examples

>>> import jax.numpy as jnp
>>> from ejkernel.kernels._xla.lightning_attn._interface import lightning_attn
>>> # Standard batched input
>>> query = jnp.ones((2, 100, 8, 64))
>>> key = jnp.ones((2, 100, 8, 64))
>>> value = jnp.ones((2, 100, 8, 64))
>>> output, final_state = lightning_attn(query, key, value, layer_idx=5, num_layers=24)
>>> output.shape
(2, 100, 8, 64)
>>> # Packed variable-length input: three sequences of 50 tokens each
>>> query = jnp.ones((1, 150, 8, 64))
>>> key = jnp.ones((1, 150, 8, 64))
>>> value = jnp.ones((1, 150, 8, 64))
>>> cu_seqlens = jnp.array([0, 50, 100, 150])
>>> output, final_state = lightning_attn(
...     query, key, value, layer_idx=10, num_layers=24, cu_seqlens=cu_seqlens
... )