ejkernel.kernels._triton.ring_attention._interface
Ring Attention Implementation using Triton Flash Attention.
This module provides a ring attention implementation that wraps the Triton flash attention kernel for distributed execution across multiple GPU devices.
Key features:
- Uses flash attention as the inner kernel for optimized GPU execution
- Supports distributed ring topology via lax.ppermute
- All flash attention features (causal, sliding window, dropout, etc.)
- Full backward pass support with custom VJP
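The ring topology mentioned above can be illustrated without any devices. `lax.ppermute` takes a list of (source, destination) index pairs; the sketch below builds such a list for a one-step shift and simulates it on a plain Python list. The helper names (`ring_permutation`, `rotate`, `ring_schedule`) are hypothetical, and the shift direction is an assumption, since this page does not say which way the module rotates blocks.

```python
def ring_permutation(num_devices):
    """(source, destination) pairs that shift data one step around the ring."""
    return [(src, (src + 1) % num_devices) for src in range(num_devices)]


def rotate(blocks, perm):
    """Simulate one permutation step on a list indexed by device id."""
    out = [None] * len(blocks)
    for src, dst in perm:
        out[dst] = blocks[src]
    return out


def ring_schedule(num_devices):
    """For each device, the order in which it sees the key/value blocks."""
    held = list(range(num_devices))  # assumption: device i starts with block i
    perm = ring_permutation(num_devices)
    seen = [[] for _ in range(num_devices)]
    for _ in range(num_devices):
        for dev, blk in enumerate(held):
            seen[dev].append(blk)
        held = rotate(held, perm)
    return seen


schedule = ring_schedule(4)
for dev, order in enumerate(schedule):
    print(f"device {dev} sees K/V blocks in order {order}")
```

After `num_devices` steps every device has seen every key/value block exactly once, which is what lets each device compute attention for its own queries against the full sequence.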
- ejkernel.kernels._triton.ring_attention._interface.ring_attention(query: Float[Array, 'batch seq_len_q num_heads head_dim'], key: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], value: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], q_segment_ids: Int[Array, 'batch seq_len_q'] | None = None, kv_segment_ids: Int[Array, 'batch seq_len_k'] | None = None, softmax_aux: Float[Array, 'num_heads'] | Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, mask_builder: Callable[[int, int, int, int, int], Mask] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = False, logits_soft_cap: float | None = None, softmax_scale: float | None = None, axis_name: str | None = None, fwd_params: FwdParams | None = None, bwd_params: BwdParams | None = None, fused_backward: bool = False) -> Float[Array, 'batch seq_len_q num_heads head_dim']
Ring attention using Triton flash attention kernels.
Distributes attention computation across devices using a ring topology, where each device holds its query partition and rotates key/value blocks through all devices, computing partial attention and combining results using online softmax.
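The online-softmax combination described above can be sketched in plain NumPy for one device's query partition. This is a minimal single-head illustration with no batching, masking, or bias, not the module's Triton kernel; the function names are hypothetical. Each incoming key/value block updates a running row maximum `m`, a running normalizer `l`, and an unnormalized accumulator `acc`, rescaling earlier contributions whenever the maximum grows.

```python
import numpy as np


def attention_reference(q, k, v, scale):
    """Ordinary softmax attention over the full key/value sequence."""
    logits = (q @ k.T) * scale
    logits = logits - logits.max(axis=-1, keepdims=True)
    p = np.exp(logits)
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def ring_attention_one_device(q, kv_blocks, scale):
    """Process key/value blocks one at a time, as they would arrive over the ring."""
    seq_q = q.shape[0]
    m = np.full((seq_q, 1), -np.inf)            # running row maximum
    l = np.zeros((seq_q, 1))                    # running softmax normalizer
    acc = np.zeros((seq_q, kv_blocks[0][1].shape[-1]))  # unnormalized output
    for k_blk, v_blk in kv_blocks:
        logits = (q @ k_blk.T) * scale
        m_new = np.maximum(m, logits.max(axis=-1, keepdims=True))
        correction = np.exp(m - m_new)          # rescales what was accumulated so far
        p = np.exp(logits - m_new)
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ v_blk
        m = m_new
    return acc / l


rng = np.random.default_rng(0)
q = rng.standard_normal((8, 16))
k = rng.standard_normal((32, 16))
v = rng.standard_normal((32, 16))
scale = 1.0 / np.sqrt(q.shape[-1])

# Split K/V into 4 blocks, standing in for the partitions held by 4 devices.
kv_blocks = list(zip(np.split(k, 4), np.split(v, 4)))
out_ring = ring_attention_one_device(q, kv_blocks, scale)
out_ref = attention_reference(q, k, v, scale)
print(np.allclose(out_ring, out_ref))
```

Because the rescaling is exact, the result does not depend on the order in which blocks arrive, so each device can start with whichever block it already holds.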
- Parameters
query – Query tensor [batch, seq_len_q, num_heads, head_dim]
key – Key tensor [batch, seq_len_k, num_kv_heads, head_dim]
value – Value tensor [batch, seq_len_k, num_kv_heads, head_dim]
attention_mask – Optional attention mask
bias – Optional attention bias tensor
softmax_scale – Attention score scaling factor (default: 1/sqrt(head_dim))
dropout_prob – Dropout probability (default: 0.0)
causal – Whether to use causal masking (default: False)
dropout_seed – Random seed for dropout
sliding_window – Sliding window size. Can be:
  - int: symmetric window (same size left and right)
  - tuple[int, int]: (left_window, right_window) for an asymmetric window
  - None: no sliding window
logits_soft_cap – Soft cap value for attention logits (tanh-based capping)
axis_name – Name of the axis for ring communication (None for single device)
fwd_params – Forward pass block size parameters
bwd_params – Backward pass block size parameters
- Returns
Output tensor [batch, seq_len_q, num_heads, head_dim]
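As a rough illustration of how softmax_scale, logits_soft_cap, sliding_window, and causal could act on the attention logits, here is a single-head NumPy sketch. Several details are assumptions not confirmed by this page: the order of operations (scale, then soft cap, then masking), the common `cap * tanh(logits / cap)` form of tanh capping, and inclusive window boundaries. `transform_logits` is a hypothetical helper, not part of the library.

```python
import numpy as np


def transform_logits(q, k, softmax_scale=None, logits_soft_cap=None,
                     sliding_window=None, causal=False):
    """Return masked, scaled logits for one head; masked entries are -inf."""
    seq_q, head_dim = q.shape
    seq_k = k.shape[0]

    # Default scale documented above: 1/sqrt(head_dim).
    scale = softmax_scale if softmax_scale is not None else 1.0 / np.sqrt(head_dim)
    logits = (q @ k.T) * scale

    # Assumed tanh capping: squashes logits smoothly into (-cap, cap).
    if logits_soft_cap is not None:
        logits = logits_soft_cap * np.tanh(logits / logits_soft_cap)

    q_pos = np.arange(seq_q)[:, None]
    k_pos = np.arange(seq_k)[None, :]
    allowed = np.ones((seq_q, seq_k), dtype=bool)

    if causal:
        allowed &= k_pos <= q_pos  # a query sees only itself and earlier keys

    if sliding_window is not None:
        if isinstance(sliding_window, int):
            left, right = sliding_window, sliding_window  # symmetric window
        else:
            left, right = sliding_window                  # (left_window, right_window)
        # Assumption: both window boundaries are inclusive.
        allowed &= (k_pos >= q_pos - left) & (k_pos <= q_pos + right)

    return np.where(allowed, logits, -np.inf)


q = np.ones((4, 4))
k = np.ones((4, 4))
causal_logits = transform_logits(q, k, causal=True)
window_logits = transform_logits(q, k, sliding_window=(1, 0))
capped_logits = transform_logits(q, k, logits_soft_cap=1.0)
```

With these inputs every unmasked logit is 4 * (1/sqrt(4)) = 2.0 before capping, and the soft cap of 1.0 brings it to tanh(2.0), just under 1.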
Example
>>> # Basic causal ring attention
>>> output = ring_attention(q, k, v, causal=True, axis_name="sp")
>>> # With sliding window
>>> output = ring_attention(
...     q, k, v,
...     sliding_window=256,
...     causal=True,
...     axis_name="sp",
... )