ejkernel.modules.operations.lightning_attention
Lightning Attention module with automatic optimization.
This module implements Lightning Attention, a layer-aware attention mechanism whose computation depends on the layer's position in the network (given by layer_idx and num_layers). It is aimed at deep transformers, where different layers may benefit from different attention behavior.
Lightning Attention uses layer-specific optimizations and can carry a recurrent state of shape [batch, num_heads, head_dim, head_dim] across calls (via initial_state and return_state), which allows recurrent-style, chunk-by-chunk processing of long sequences.
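The initial_state shape [batch, num_heads, head_dim, head_dim] points to a linear-attention-style recurrence, in which each head keeps a (head_dim, head_dim) state that is decayed and then updated with the outer product of k_t and v_t at every step. The following is a minimal NumPy sketch of that idea, not ejkernel's implementation. It assumes num_kv_heads equals num_heads and a default scale of head_dim ** -0.5, and the decay rule in hypothetical_log_decay is an illustrative stand-in for whatever layer-dependent rule the kernel actually uses. The usage at the end shows why carrying the state matters: processing the sequence in two chunks, with the first chunk's final state passed as the second chunk's initial state, reproduces the full-sequence result.

```python
import numpy as np

def hypothetical_log_decay(num_heads, layer_idx, num_layers):
    """Per-head log-decay (<= 0). HYPOTHETICAL formula, for illustration only."""
    depth_factor = 1.0 - layer_idx / num_layers  # 1.0 at layer 0, shrinks with depth
    return -(8.0 / num_heads) * depth_factor * np.arange(1, num_heads + 1)

def recurrent_reference(q, k, v, log_decay, scale=None, initial_state=None):
    """Naive per-token recurrence; q, k, v: [batch, seq_len, num_heads, head_dim].

    Assumes num_kv_heads == num_heads. Returns (output, final_state), where
    final_state has shape [batch, num_heads, head_dim, head_dim].
    """
    b, t, h, d = q.shape
    scale = d ** -0.5 if scale is None else scale  # assumed default scale
    if initial_state is None:
        state = np.zeros((b, h, d, d), dtype=q.dtype)
    else:
        state = initial_state.copy()
    # Per-head decay factor in (0, 1], broadcast to [1, h, 1, 1].
    gamma = np.exp(log_decay).astype(q.dtype)[None, :, None, None]
    out = np.empty_like(q)
    for i in range(t):
        # S_i = gamma * S_{i-1} + outer(k_i, v_i), per batch element and head
        state = gamma * state + np.einsum("bhd,bhe->bhde", k[:, i], v[:, i])
        # o_i = scale * q_i S_i
        out[:, i] = scale * np.einsum("bhd,bhde->bhe", q[:, i], state)
    return out, state

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 8, 4, 16)) for _ in range(3))
decay = hypothetical_log_decay(num_heads=4, layer_idx=5, num_layers=32)

full_out, full_state = recurrent_reference(q, k, v, decay)
# Same sequence in two chunks, carrying the state across the boundary.
out_a, state_a = recurrent_reference(q[:, :5], k[:, :5], v[:, :5], decay)
out_b, state_b = recurrent_reference(q[:, 5:], k[:, 5:], v[:, 5:], decay, initial_state=state_a)
```

Tiled kernels typically evaluate this kind of recurrence block by block rather than token by token; the per-token loop above trades speed for readability.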
- class ejkernel.modules.operations.lightning_attention.LightningAttention
Bases: Kernel[LightningAttentionConfig, Array]
Lightning Attention with custom optimization logic.
Implements a layer-aware attention mechanism optimized for deep transformer architectures. The attention computation adapts based on the layer index, allowing for more efficient processing in multi-layer networks.
- Features:
Layer-specific optimization strategies
Support for stateful computation with initial states
Bidirectional and reverse sequence processing
Variable-length sequence handling
Automatic platform selection (Triton/Pallas/XLA/CUDA)
- This is particularly useful for:
Very deep transformers where layer position matters
Models with recurrent-style attention patterns
Scenarios requiring different attention behavior per layer
- candidate_cfgs(inv: Invocation[LightningAttentionConfig, Array])
Generate candidate configurations for autotuning.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of candidate configurations to benchmark during autotuning
Note
Lightning attention’s layer-aware design means performance may vary across layer depths. Candidates cover a range of block sizes.
- get_impl(cfg: LightningAttentionConfig)
Get kernel implementation from registry.
- Parameters
cfg – Configuration specifying platform and backend
- Returns
Callable kernel implementation for lightning attention
- Raises
ValueError – If no matching implementation is found
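The lookup described above, where the configuration selects an implementation and a miss raises ValueError, can be sketched as a keyed registry. Everything here (ToyConfig, register, the (name, platform, backend) key, and the backend value "any") is hypothetical and only illustrates the documented contract, not ejkernel's registry.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple

@dataclass(frozen=True)
class ToyConfig:
    """Hypothetical stand-in for LightningAttentionConfig (fields are assumptions)."""
    platform: str
    backend: str

# (operation name, platform, backend) -> implementation
_REGISTRY: Dict[Tuple[str, str, str], Callable] = {}

def register(name: str, platform: str, backend: str):
    """Decorator that records an implementation under a (name, platform, backend) key."""
    def decorator(fn: Callable) -> Callable:
        _REGISTRY[(name, platform, backend)] = fn
        return fn
    return decorator

def get_impl(name: str, cfg: ToyConfig) -> Callable:
    """Return the registered implementation, or raise ValueError on a miss."""
    key = (name, cfg.platform, cfg.backend)
    if key not in _REGISTRY:
        raise ValueError(
            f"no {name!r} implementation for platform={cfg.platform!r}, backend={cfg.backend!r}"
        )
    return _REGISTRY[key]

@register("lightning_attn", "xla", "any")
def _lightning_attn_xla(*args, **kwargs):
    # Placeholder body; a real entry would run the kernel.
    return "xla impl called"
```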
- heuristic_cfg(inv: Invocation[LightningAttentionConfig, Array]) → LightningAttentionConfig
Provide default configuration with block sizes.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
Default configuration with conservative block sizes suitable for typical lightning attention workloads across various layer depths
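How candidate_cfgs and heuristic_cfg typically fit together can be sketched as follows: the heuristic supplies a safe default, the candidate list spans several block sizes, and an autotuner keeps the cheapest one. ToyBlockConfig, its block_size field, the size list, and the cost function are hypothetical; a real tuner would time the compiled kernel instead of using the stand-in cost shown here.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyBlockConfig:
    """Hypothetical stand-in for LightningAttentionConfig; the field name is an assumption."""
    block_size: int

def toy_heuristic_cfg(seq_len: int) -> ToyBlockConfig:
    # Conservative default: a modest block, never larger than the sequence.
    return ToyBlockConfig(block_size=min(64, seq_len))

def toy_candidate_cfgs(seq_len: int, sizes=(16, 32, 64, 128, 256)):
    # Keep block sizes that fit the sequence; fall back to the heuristic default.
    fitting = [s for s in sizes if s <= seq_len]
    return [ToyBlockConfig(s) for s in fitting] or [toy_heuristic_cfg(seq_len)]

def autotune(candidates, measure):
    """Return the candidate with the lowest measured cost (e.g. runtime)."""
    return min(candidates, key=measure)

# Stand-in cost function; a real tuner would benchmark each configuration.
best = autotune(toy_candidate_cfgs(100), measure=lambda c: abs(c.block_size - 40))
```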
- run(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], layer_idx: int, num_layers: int, softmax_scale: float | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim'] | None = None, reverse: bool = False, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, return_state: bool = False, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: LightningAttentionConfig) → jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim']]
Execute lightning attention with layer-specific optimization.
- Parameters
query – Query tensor [batch, seq_len, num_heads, head_dim]
key – Key tensor [batch, seq_len, num_kv_heads, head_dim]
value – Value tensor [batch, seq_len, num_kv_heads, head_dim]
layer_idx – Index of current layer in the model (0-indexed)
num_layers – Total number of layers in the model
softmax_scale – Optional scaling factor for attention scores
initial_state – Optional initial hidden state [batch, num_heads, head_dim, head_dim]
reverse – If True, process sequence in reverse order
cu_seqlens – Cumulative sequence lengths for variable-length sequences [num_seqs + 1]
return_state – If True, return tuple (output, final_state) instead of just output
platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Kernel configuration object
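- Returns
If return_state=False: attention output [batch, seq_len, num_heads, head_dim].
If return_state=True: tuple (output, final_state), where final_state is [batch, num_heads, head_dim, head_dim].
- Return type
Float[Array, 'batch seq_len num_heads head_dim'] | tuple[Float[Array, 'batch seq_len num_heads head_dim'], Float[Array, 'batch num_heads head_dim head_dim']]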
Note
The layer_idx and num_layers parameters enable layer-specific optimizations that can improve performance in deep networks.
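cu_seqlens follows the usual cumulative-offset convention implied by its [num_seqs + 1] shape: entry i is the start offset of sequence i along seq_len, and the last entry is the total number of tokens. A minimal NumPy sketch of building it from per-sequence lengths is below. Packing the sequences into a single batch row (batch = 1) mirrors common variable-length attention APIs and is an assumption here, as are the helper names.

```python
import numpy as np

def build_cu_seqlens(lengths):
    """Per-sequence lengths -> offsets [0, l0, l0 + l1, ..., total], shape [num_seqs + 1]."""
    lengths = np.asarray(lengths, dtype=np.int32)
    return np.concatenate([np.zeros(1, dtype=np.int32), np.cumsum(lengths, dtype=np.int32)])

def segment_bounds(cu_seqlens):
    """(start, end) token range of each packed sequence along seq_len."""
    return [(int(s), int(e)) for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:])]

def pack(sequences):
    """Concatenate [len_i, num_heads, head_dim] arrays into [1, total, num_heads, head_dim]."""
    return np.concatenate(sequences, axis=0)[None]

cu_seqs = build_cu_seqlens([3, 5, 2])
packed_q = pack([np.zeros((n, 4, 16), dtype=np.float32) for n in (3, 5, 2)])
```

For lengths 3, 5 and 2 the offsets are 0, 3, 8, 10, so the three sequences occupy tokens [0, 3), [3, 8) and [8, 10) of the packed row; the resulting array can be converted with jax.numpy.asarray before being passed as cu_seqlens.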
- ejkernel.modules.operations.lightning_attention.lightning_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim'] | None = None, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, /, *, layer_idx: int, num_layers: int, softmax_scale: float | None = None, reverse: bool = False, return_state: bool = False, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.LightningAttentionConfig | None = None) → jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim']]
Execute lightning attention with automatic optimization.
Lightning attention is an efficient attention mechanism that uses layer-specific optimizations for improved performance.
- Parameters
query – Query tensor [batch, seq_len, num_heads, head_dim]
key – Key tensor [batch, seq_len, num_kv_heads, head_dim]
value – Value tensor [batch, seq_len, num_kv_heads, head_dim]
layer_idx – Index of current layer in the model (0-indexed)
num_layers – Total number of layers in the model
softmax_scale – Optional scaling factor for attention scores
initial_state – Optional initial state for recurrent computation [batch, num_heads, head_dim, head_dim] (positional-only)
reverse – If True, process sequence in reverse order
cu_seqlens – Cumulative sequence lengths for variable-length sequences [num_seqs + 1] (positional-only)
return_state – If True, return tuple (output, final_state) instead of just output
platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Optional kernel configuration object (LightningAttentionConfig); defaults to None
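- Returns
If return_state=False: attention output with the same shape as query.
If return_state=True: tuple (output, final_state), where final_state is [batch, num_heads, head_dim, head_dim].
- Return type
Float[Array, 'batch seq_len num_heads head_dim'] | tuple[Float[Array, 'batch seq_len num_heads head_dim'], Float[Array, 'batch num_heads head_dim head_dim']]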
Example
>>> # Basic usage
>>> out = lightning_attention(query, key, value, layer_idx=5, num_layers=32)
>>> # Custom softmax scale
>>> out = lightning_attention(query, key, value, layer_idx=0, num_layers=24, softmax_scale=0.125)
>>> # Variable-length sequences (cu_seqlens is positional-only and follows initial_state)
>>> out = lightning_attention(query, key, value, None, cu_seqs, layer_idx=10, num_layers=32)
>>> # Force a specific platform
>>> out = lightning_attention(query, key, value, layer_idx=0, num_layers=24, platform="pallas")
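key and value may have fewer heads (num_kv_heads) than query, while the state is documented per query head ([batch, num_heads, head_dim, head_dim]). A common grouped-query convention, assumed here rather than confirmed for ejkernel, lets query head h read kv head h // (num_heads // num_kv_heads). The hypothetical helper below expands key or value to num_heads under that assumption, which is handy when comparing against a reference that expects equal head counts.

```python
import numpy as np

def expand_kv_heads(x, num_heads):
    """[batch, seq_len, num_kv_heads, head_dim] -> [batch, seq_len, num_heads, head_dim].

    Assumed grouping: query head h reads kv head h // (num_heads // num_kv_heads).
    """
    num_kv_heads = x.shape[2]
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    # np.repeat keeps each kv head's copies adjacent: kv0, kv0, ..., kv1, kv1, ...
    return np.repeat(x, num_heads // num_kv_heads, axis=2)

kv = np.arange(2 * 3 * 2 * 4, dtype=np.float32).reshape(2, 3, 2, 4)  # 2 kv heads
kv_expanded = expand_kv_heads(kv, num_heads=8)  # 8 query heads, groups of 4
```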