ejkernel.kernels._triton.ring_attention._interface

Ring Attention Implementation using Triton Flash Attention.

This module provides a ring attention implementation that wraps the Triton flash attention kernel for distributed execution across multiple GPU devices.

Key features:
  • Uses flash attention as the inner kernel for optimized GPU execution

  • Supports distributed ring topology via lax.ppermute

  • All flash attention features (causal, sliding window, dropout, etc.)

  • Full backward pass support with custom VJP
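The ring topology mentioned above can be pictured with a small pure-Python simulation. This is only a sketch: `ring_perm` and `simulate_ring` are hypothetical helper names, and the rotation direction (each device sends to its right-hand neighbour) is an assumption, not something this page states. The list of (source, destination) index pairs is the form that the `perm` argument of `jax.lax.ppermute` takes.

```python
def ring_perm(num_devices):
    """(source, destination) pairs that shift every shard one step around the ring."""
    return [(src, (src + 1) % num_devices) for src in range(num_devices)]


def simulate_ring(num_devices):
    """Return, for each ring step, which key/value shard every device holds."""
    perm = ring_perm(num_devices)
    held = list(range(num_devices))  # device i starts with its own shard i
    schedule = [list(held)]
    for _ in range(num_devices - 1):
        received = [None] * num_devices
        for src, dst in perm:
            received[dst] = held[src]  # what a ppermute with this perm would deliver
        held = received
        schedule.append(list(held))
    return schedule


schedule = simulate_ring(4)
for step, held in enumerate(schedule):
    print(f"step {step}: kv shard held by devices 0..3 = {held}")
```

After `num_devices - 1` rotations every device has seen every key/value shard exactly once, which is what lets each device finish attention for its own query partition.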

ejkernel.kernels._triton.ring_attention._interface.ring_attention(
    query: Float[Array, 'batch seq_len_q num_heads head_dim'],
    key: Float[Array, 'batch seq_len_k num_kv_heads head_dim'],
    value: Float[Array, 'batch seq_len_k num_kv_heads head_dim'],
    q_segment_ids: Int[Array, 'batch seq_len_q'] | None = None,
    kv_segment_ids: Int[Array, 'batch seq_len_k'] | None = None,
    softmax_aux: Float[Array, 'num_heads'] | Float[Array, 'num_sinks'] | None = None,
    bias: Float[Array, 'batch num_heads seq_len_q seq_len_k'] | None = None,
    mask_builder: Callable[[int, int, int, int, int], Mask] | None = None,
    sliding_window: int | tuple[int, int] | None = None,
    chunk_size: int | None = None,
    causal: bool = False,
    logits_soft_cap: float | None = None,
    softmax_scale: float | None = None,
    axis_name: str | None = None,
    fwd_params: FwdParams | None = None,
    bwd_params: BwdParams | None = None,
    fused_backward: bool = False,
) -> Float[Array, 'batch seq_len_q num_heads head_dim']

Ring attention using Triton flash attention kernels.

Distributes attention computation across devices using a ring topology, where each device holds its query partition and rotates key/value blocks through all devices, computing partial attention and combining results using online softmax.
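The "combining results using online softmax" step can be illustrated without any GPU code. The sketch below handles a single query vector in pure Python and merges two key/value blocks as if they had arrived in successive ring steps. All function names here are hypothetical and the real kernel works on tiled tensors, so treat this as a numerical illustration of the rescaling idea rather than the library's implementation.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def block_partial(q, keys, values, scale):
    """Attention of one query over one KV block, kept unnormalised.

    Returns (m, denom, acc): the block's max score, the sum of
    exp(score - m), and the exp-weighted sum of the value vectors.
    """
    scores = [scale * dot(q, k) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    acc = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return m, denom, acc


def merge(state, partial):
    """Combine two partial results by rescaling both to a shared max."""
    m1, d1, acc1 = state
    m2, d2, acc2 = partial
    m = max(m1, m2)
    a1, a2 = math.exp(m1 - m), math.exp(m2 - m)
    return m, a1 * d1 + a2 * d2, [a1 * x + a2 * y for x, y in zip(acc1, acc2)]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [-1.0, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
scale = 1.0 / math.sqrt(len(q))

# Two KV blocks processed one after the other, then merged
state = block_partial(q, keys[:2], values[:2], scale)
state = merge(state, block_partial(q, keys[2:], values[2:], scale))
ring_out = [x / state[1] for x in state[2]]

# Reference: ordinary softmax attention over all keys at once
_, ref_denom, ref_acc = block_partial(q, keys, values, scale)
reference = [x / ref_denom for x in ref_acc]
print(ring_out, reference)
```

Because only the running max, denominator, and accumulator are carried between steps, no device ever needs the full score matrix, and the merged result matches the single-pass softmax up to floating-point rounding.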

Parameters
  • query – Query tensor [batch, seq_len_q, num_heads, head_dim]

  • key – Key tensor [batch, seq_len_k, num_kv_heads, head_dim]

  • value – Value tensor [batch, seq_len_k, num_kv_heads, head_dim]

  • attention_mask – Optional attention mask (not present in the signature above)

  • bias – Optional attention bias tensor

  • softmax_scale – Attention score scaling factor (default: 1/sqrt(head_dim))

  • dropout_prob – Dropout probability (default: 0.0; not present in the signature above)

  • causal – Whether to use causal masking (default: False)

  • dropout_seed – Random seed for dropout (not present in the signature above)

  • sliding_window – Sliding window size. Can be:
      - int: symmetric window (same size left and right)
      - tuple[int, int]: (left_window, right_window) for an asymmetric window
      - None: no sliding window

  • logits_soft_cap – Soft cap value for attention logits (tanh-based capping)

  • axis_name – Name of the axis for ring communication (None for single device)

  • fwd_params – Forward pass block size parameters

  • bwd_params – Backward pass block size parameters
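The following pure-Python sketch shows one plausible reading of the sliding_window, causal, logits_soft_cap, and softmax_scale parameters described above. The helper names are hypothetical, the inclusive window bounds are an assumption, and `cap * tanh(logit / cap)` is the commonly used tanh soft-cap form rather than a formula confirmed by this page.

```python
import math


def normalize_window(sliding_window):
    """Map the int / tuple / None forms to a (left, right) pair or None."""
    if sliding_window is None:
        return None
    if isinstance(sliding_window, int):
        return (sliding_window, sliding_window)  # symmetric window
    left, right = sliding_window
    return (left, right)


def is_visible(q_pos, k_pos, sliding_window=None, causal=False):
    """True if the query at q_pos may attend to the key at k_pos.

    Assumption: window bounds are inclusive; the kernel may differ.
    """
    if causal and k_pos > q_pos:
        return False
    window = normalize_window(sliding_window)
    if window is None:
        return True
    left, right = window
    return q_pos - left <= k_pos <= q_pos + right


def soft_cap(logit, cap):
    """Assumed tanh soft cap: output stays within (-cap, cap)."""
    return cap * math.tanh(logit / cap)


def default_scale(head_dim):
    """Default softmax_scale stated in the parameter list: 1/sqrt(head_dim)."""
    return 1.0 / math.sqrt(head_dim)


# Visibility of keys 0..7 for the query at position 4
row = [int(is_visible(4, k, sliding_window=2, causal=True)) for k in range(8)]
print(row)
print(soft_cap(100.0, 30.0), default_scale(64))
```

With causal=True the right half of a symmetric window is masked out anyway, so only the left extent affects which keys are visible.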

Returns

Output tensor [batch, seq_len_q, num_heads, head_dim]

Example

>>> # Basic causal ring attention
>>> output = ring_attention(q, k, v, causal=True, axis_name="sp")
>>> # With sliding window
>>> output = ring_attention(
...     q, k, v,
...     sliding_window=256,
...     causal=True,
...     axis_name="sp",
... )