ejkernel.modules.operations.lightning_attention

Lightning Attention module with automatic optimization.

This module implements Lightning Attention, a layer-aware attention mechanism that adapts computation based on the layer position in the network. It’s particularly efficient for deep transformers where different layers may benefit from different attention strategies.

Lightning Attention uses layer-specific optimizations and can maintain state across sequence processing for improved efficiency in recurrent-style computation.

class ejkernel.modules.operations.lightning_attention.LightningAttention

Bases: Kernel[LightningAttentionConfig, Array]

Lightning Attention with custom optimization logic.

Implements a layer-aware attention mechanism optimized for deep transformer architectures. The attention computation adapts based on the layer index, allowing for more efficient processing in multi-layer networks.

Features:
  • Layer-specific optimization strategies

  • Support for stateful computation with initial states

  • Bidirectional and reverse sequence processing

  • Variable-length sequence handling

  • Automatic platform selection (Triton/Pallas/XLA/CUDA)

This is particularly useful for:
  • Very deep transformers where layer position matters

  • Models with recurrent-style attention patterns

  • Scenarios requiring different attention behavior per layer

candidate_cfgs(inv: Invocation[LightningAttentionConfig, Array])

Generate candidate configurations for autotuning.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of candidate configurations to benchmark during autotuning

Note

Lightning attention’s layer-aware design means performance may vary across layer depths. Candidates cover a range of block sizes.
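As an illustration of what "a range of block sizes" can look like in practice, the following is a minimal sketch of candidate enumeration. The `BlockConfig` dataclass, its field names, and the specific block sizes are hypothetical stand-ins chosen for the example; they are not taken from `LightningAttentionConfig`.

```python
from dataclasses import dataclass
from itertools import product


@dataclass(frozen=True)
class BlockConfig:
    """Hypothetical stand-in for a kernel config; field names are assumptions."""

    block_q: int
    block_k: int
    num_warps: int


def candidate_block_configs(seq_len: int) -> list[BlockConfig]:
    """Enumerate block-size combinations, skipping tiles larger than the sequence."""
    block_sizes = (32, 64, 128)  # assumed search space
    warps = (4, 8)  # assumed search space
    candidates = [
        BlockConfig(bq, bk, w)
        for bq, bk, w in product(block_sizes, block_sizes, warps)
        if bq <= seq_len and bk <= seq_len
    ]
    # Keep at least one candidate so an autotuner always has something to benchmark.
    return candidates or [BlockConfig(min(block_sizes), min(block_sizes), warps[0])]
```

An autotuner would then time each candidate on the actual inputs and cache the fastest one.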

get_impl(cfg: LightningAttentionConfig)

Get kernel implementation from registry.

Parameters

cfg – Configuration specifying platform and backend

Returns

Callable kernel implementation for lightning attention

Raises

ValueError – If no matching implementation is found
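The lookup-or-raise behaviour described above can be sketched with a small dictionary-backed registry. Everything here (the `_REGISTRY` dict, the `register` decorator, the `"lightning_attn"` key, and the placeholder kernel) is hypothetical and only illustrates the pattern; it is not ejkernel's registry API.

```python
from typing import Callable

# Hypothetical registry keyed by (algorithm, platform); not ejkernel's data structure.
_REGISTRY: dict[tuple[str, str], Callable] = {}


def register(algorithm: str, platform: str):
    """Decorator that stores an implementation under (algorithm, platform)."""

    def decorator(fn: Callable) -> Callable:
        _REGISTRY[(algorithm, platform)] = fn
        return fn

    return decorator


@register("lightning_attn", "xla")
def _lightning_attn_xla(*args, **kwargs):
    return "xla-result"  # placeholder body standing in for a real kernel


def get_impl(algorithm: str, platform: str) -> Callable:
    """Return the registered callable, or raise ValueError if none matches."""
    try:
        return _REGISTRY[(algorithm, platform)]
    except KeyError:
        raise ValueError(
            f"No implementation of {algorithm!r} for platform {platform!r}"
        ) from None
```

Raising `ValueError` rather than letting the `KeyError` escape keeps the error message in terms the caller used (algorithm and platform) instead of exposing the registry's internal key.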

heuristic_cfg(inv: Invocation[LightningAttentionConfig, Array]) -> LightningAttentionConfig

Provide default configuration with block sizes.

Parameters

inv – Invocation object containing arguments and metadata

Returns

Default configuration with conservative block sizes suitable for typical lightning attention workloads across various layer depths

run(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], layer_idx: int, num_layers: int, softmax_scale: float | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim'] | None = None, reverse: bool = False, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, return_state: bool = False, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: LightningAttentionConfig) -> jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim']]

Execute lightning attention with layer-specific optimization.

Parameters
  • query – Query tensor [batch, seq_len, num_heads, head_dim]

  • key – Key tensor [batch, seq_len, num_kv_heads, head_dim]

  • value – Value tensor [batch, seq_len, num_kv_heads, head_dim]

  • layer_idx – Index of current layer in the model (0-indexed)

  • num_layers – Total number of layers in the model

  • softmax_scale – Optional scaling factor for attention scores

  • initial_state – Optional initial hidden state [batch, num_heads, head_dim, head_dim]

  • reverse – If True, process sequence in reverse order

  • cu_seqlens – Cumulative sequence lengths for variable-length sequences

  • return_state – If True, return tuple (output, final_state) instead of just output

  • platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Kernel configuration object

Returns

If return_state=False: attention output [batch, seq_len, num_heads, head_dim]. If return_state=True: tuple of (output, final_state), where final_state is [batch, num_heads, head_dim, head_dim].

Note

The layer_idx and num_layers parameters enable layer-specific optimizations that can improve performance in deep networks.
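The [batch, num_heads, head_dim, head_dim] state shape is what a linear-attention recurrence carries between steps. The sketch below is a plain NumPy reference under stated assumptions: it assumes a recurrence of the form S_t = decay · S_{t-1} + k_tᵀ v_t with a per-head decay, assumes num_kv_heads equals num_heads, and ignores softmax_scale, layer_idx, and reverse. The function name and the decay values are hypothetical, and this is not the kernel ejkernel runs. It shows why passing the final state of one chunk as the initial state of the next reproduces a single pass over the full sequence.

```python
import numpy as np


def linear_attention_reference(q, k, v, decay, initial_state=None):
    """Naive recurrence: S_t = decay * S_{t-1} + k_t^T v_t, o_t = q_t S_t.

    q, k, v: [batch, seq_len, num_heads, head_dim]; decay: [num_heads].
    Returns (output, final_state); final_state is
    [batch, num_heads, head_dim, head_dim].
    """
    batch, seq_len, num_heads, head_dim = q.shape
    if initial_state is None:
        state = np.zeros((batch, num_heads, head_dim, head_dim), dtype=q.dtype)
    else:
        state = initial_state.copy()
    out = np.empty_like(q)
    for t in range(seq_len):
        # Outer product k_t^T v_t per (batch, head): [batch, heads, dim, dim].
        kv = np.einsum("bhd,bhe->bhde", k[:, t], v[:, t])
        state = decay[None, :, None, None] * state + kv
        out[:, t] = np.einsum("bhd,bhde->bhe", q[:, t], state)
    return out, state


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 8, 4, 16)) for _ in range(3))
decay = np.exp(-np.linspace(0.1, 0.5, 4))  # assumed per-head decay values

# One pass over the whole sequence.
full, full_state = linear_attention_reference(q, k, v, decay)

# Two passes over halves, carrying the state across the boundary.
first, state = linear_attention_reference(q[:, :4], k[:, :4], v[:, :4], decay)
second, final_state = linear_attention_reference(
    q[:, 4:], k[:, 4:], v[:, 4:], decay, initial_state=state
)
chunked = np.concatenate([first, second], axis=1)
```

With the real API, the analogous pattern would be to call the kernel with return_state=True on the first chunk and pass the returned state as initial_state for the next.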

ejkernel.modules.operations.lightning_attention.lightning_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim'] | None = None, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, /, *, layer_idx: int, num_layers: int, softmax_scale: float | None = None, reverse: bool = False, return_state: bool = False, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.LightningAttentionConfig | None = None) -> jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim']]

Execute lightning attention with automatic optimization.

Lightning attention is an efficient attention mechanism that uses layer-specific optimizations for improved performance.

Parameters
  • query – Query tensor [batch, seq_len, num_heads, head_dim]

  • key – Key tensor [batch, seq_len, num_kv_heads, head_dim]

  • value – Value tensor [batch, seq_len, num_kv_heads, head_dim]

  • layer_idx – Current layer index in the model

  • num_layers – Total number of layers in the model

  • softmax_scale – Scaling factor for attention

  • initial_state – Initial state for recurrent computation

  • reverse – Whether to process sequence in reverse

  • cu_seqlens – Cumulative sequence lengths for variable-length sequences

  • return_state – If True, return tuple (output, final_state) instead of just output

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Optional kernel configuration object (defaults to None)

Returns

If return_state=False: attention output with the same shape as query. If return_state=True: tuple of (output, final_state).

Example

>>> # Basic call for layer 5 of a 32-layer model
>>> out = lightning_attention(query, key, value, layer_idx=5, num_layers=32)
>>>
>>> # Custom softmax scale
>>> out = lightning_attention(query, key, value, layer_idx=0, num_layers=24, softmax_scale=0.125)
>>>
>>> # Variable-length sequences via cumulative sequence lengths
>>> out = lightning_attention(query, key, value, layer_idx=10, num_layers=32, cu_seqlens=cu_seqs)
>>>
>>> # Explicit platform override
>>> out = lightning_attention(query, key, value, layer_idx=0, num_layers=24, platform="pallas")
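The `cu_seqlens` argument is annotated with shape 'num_seqs_plus_one', which is consistent with the common convention of cumulative offsets that start at zero. A minimal sketch of building such an array from per-sequence lengths follows; `build_cu_seqlens` is a hypothetical helper, NumPy is used so the snippet runs anywhere (the same calls exist in `jax.numpy`), and how ejkernel expects the packed query/key/value tensors to be laid out is not covered here.

```python
import numpy as np


def build_cu_seqlens(lengths):
    """Return cumulative offsets [0, l0, l0 + l1, ...] with shape [num_seqs + 1]."""
    lengths = np.asarray(lengths, dtype=np.int32)
    return np.concatenate(
        [np.zeros(1, dtype=np.int32), np.cumsum(lengths, dtype=np.int32)]
    )


cu_seqs = build_cu_seqlens([3, 5, 2])

# Sequence i occupies packed positions cu_seqs[i]:cu_seqs[i + 1].
spans = [(int(s), int(e)) for s, e in zip(cu_seqs[:-1], cu_seqs[1:])]
```

The last entry equals the total number of packed tokens, so it can double as a sanity check against the packed tensor's sequence dimension.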