ejkernel.kernels._pallas.tpu.ring_attention._interface
Ring Attention interface using Splash Attention kernels.
This module provides the public API for ring attention on TPU, built on the Splash Attention kernels and a ring communication topology.
- ejkernel.kernels._pallas.tpu.ring_attention._interface.ring_attention(query: Float[Array, 'batch seq_len_q num_heads head_dim'], key: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], value: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], q_segment_ids: Int[Array, 'batch seq_len_q'] | None = None, kv_segment_ids: Int[Array, 'batch seq_len_k'] | None = None, softmax_aux: Float[Array, 'num_heads'] | Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, mask_builder: Callable[[int, int, int, int, int], Mask] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = False, logits_soft_cap: float | None = None, softmax_scale: float | None = None, axis_name: str | None = None, fwd_params: FwdParams | None = None, bwd_params: BwdParams | None = None, fused_backward: bool = False) → chex.Array
Computes ring attention using Splash Attention kernels on TPU.
This implementation uses JAX's Splash Attention kernels with a ring communication topology to distribute the attention computation across devices.
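To make the ring topology concrete, here is a minimal single-host sketch of the general ring-attention idea; it is not ejkernel's implementation. The sequence is split into `num_devices` shards, each simulated device keeps its query shard, and the KV shards move one hop per step while partial results are merged with an online softmax. The names `simulated_ring_attention` and `reference_attention` are hypothetical, `jnp.roll` stands in for the collective permute (for example `jax.lax.ppermute` over a device mesh axis) that a real multi-device version would use, and causal masking, segment IDs, multiple heads and logit scaling are deliberately omitted.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    """Hypothetical helper: plain softmax attention for one head, q/k/v [T, D]."""
    return jax.nn.softmax(q @ k.T, axis=-1) @ v


def simulated_ring_attention(q, k, v, num_devices):
    """Hypothetical single-host simulation of ring attention (one head, no masking).

    "Device" i keeps query shard i; KV shards rotate one hop per step, and the
    partial results are combined with the online-softmax update.
    """
    tq = q.shape[0] // num_devices
    q_shards = q.reshape(num_devices, tq, -1)
    k_shards = k.reshape(num_devices, -1, k.shape[-1])
    v_shards = v.reshape(num_devices, -1, v.shape[-1])

    # Running statistics per device: max logit m, softmax denominator l, unnormalized output o.
    m = jnp.full((num_devices, tq, 1), -jnp.inf)
    l = jnp.zeros((num_devices, tq, 1))
    o = jnp.zeros((num_devices, tq, v.shape[-1]))

    for _ in range(num_devices):
        # Each device scores its local queries against the KV shard it currently holds.
        logits = jnp.einsum("nqd,nkd->nqk", q_shards, k_shards)
        m_new = jnp.maximum(m, logits.max(axis=-1, keepdims=True))
        p = jnp.exp(logits - m_new)
        # Rescale previous partial sums to the new running maximum.
        correction = jnp.exp(m - m_new)
        l = l * correction + p.sum(axis=-1, keepdims=True)
        o = o * correction + jnp.einsum("nqk,nkd->nqd", p, v_shards)
        m = m_new
        # Pass KV shards one hop around the ring (device i receives from device i - 1).
        k_shards = jnp.roll(k_shards, shift=1, axis=0)
        v_shards = jnp.roll(v_shards, shift=1, axis=0)

    return (o / l).reshape(q.shape[0], v.shape[-1])


rng = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(rng, 3)
q = jax.random.normal(kq, (16, 8))
k = jax.random.normal(kk, (16, 8))
v = jax.random.normal(kv, (16, 8))

out = simulated_ring_attention(q, k, v, num_devices=4)
expected = reference_attention(q, k, v)
```

Because each device only holds one KV shard at a time, per-device memory scales with the shard size, while the online-softmax rescaling makes the combined result equal to full softmax attention over the whole sequence.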
- Parameters
query – Query tensor [batch, seq_len_q, num_heads, head_dim].
key – Key tensor [batch, seq_len_k, num_kv_heads, head_dim].
value – Value tensor [batch, seq_len_k, num_kv_heads, head_dim].
q_segment_ids – Optional query segment IDs [batch, seq_len_q].
kv_segment_ids – Optional KV segment IDs [batch, seq_len_k].
softmax_aux – Optional attention-sink logits of shape [num_heads] or [num_sinks] (maps to the sinks parameter).
bias – Optional attention bias [batch, num_heads, seq_len_q, seq_len_k]. Not supported by Splash Attention; passing a non-None value raises NotImplementedError.
mask_builder – Optional custom mask builder: a callable that takes five int arguments and returns a Mask.
sliding_window – Window size for local (sliding-window) attention, given as an int or a tuple of two ints.
chunk_size – Chunk size for chunked causal attention.
causal – Whether to use causal masking.
logits_soft_cap – Soft cap for attention logits.
softmax_scale – Scaling factor for attention scores.
axis_name – Name of the device axis over which the ring communication runs.
fwd_params – Forward pass block size parameters.
bwd_params – Backward pass block size parameters.
fused_backward – Whether to use the fused backward kernel.
- Returns
Attention output [batch, seq_len_q, num_heads, head_dim].
- Raises
NotImplementedError – If bias is provided (not supported).
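When testing calls to `ring_attention`, a dense single-device oracle can help check the masking options. The sketch below is hypothetical and not part of ejkernel; it follows the documented tensor layouts and assumes (1) query head `h` reads KV head `h // (num_heads // num_kv_heads)`, (2) a default scale of `1/sqrt(head_dim)` when `softmax_scale` is None, (3) the common `cap * tanh(logits / cap)` form of soft capping, and (4) top-left-aligned causal masking with `seq_len_q == seq_len_k`. It ignores `softmax_aux`, `bias`, `sliding_window`, `chunk_size` and `mask_builder`.

```python
import jax
import jax.numpy as jnp


def dense_attention_reference(query, key, value, *, causal=False,
                              q_segment_ids=None, kv_segment_ids=None,
                              logits_soft_cap=None, softmax_scale=None):
    """Hypothetical dense oracle (not an ejkernel API).

    query: [batch, seq_len_q, num_heads, head_dim]
    key, value: [batch, seq_len_k, num_kv_heads, head_dim]
    """
    batch, seq_q, num_heads, head_dim = query.shape
    seq_k, num_kv_heads = key.shape[1], key.shape[2]

    # Assumed GQA layout: query head h reads KV head h // group.
    group = num_heads // num_kv_heads
    key = jnp.repeat(key, group, axis=2)
    value = jnp.repeat(value, group, axis=2)

    # Assumed default scale when softmax_scale is None.
    scale = head_dim ** -0.5 if softmax_scale is None else softmax_scale
    logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) * scale

    # Assumed tanh soft cap: squashes logits smoothly into (-cap, cap).
    if logits_soft_cap is not None:
        logits = logits_soft_cap * jnp.tanh(logits / logits_soft_cap)

    mask = jnp.ones((batch, 1, seq_q, seq_k), dtype=bool)
    if causal:
        # Top-left aligned: query i may attend to keys 0..i.
        mask = mask & jnp.tril(jnp.ones((seq_q, seq_k), dtype=bool))
    if q_segment_ids is not None and kv_segment_ids is not None:
        # Tokens attend only to keys in the same segment.
        same_segment = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :]
        mask = mask & same_segment[:, None]

    # A large finite negative value (instead of -inf) avoids NaNs on fully masked rows.
    logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, value)


rng = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(rng, 3)
q = jax.random.normal(kq, (2, 8, 4, 16))  # num_heads = 4
k = jax.random.normal(kk, (2, 8, 2, 16))  # num_kv_heads = 2 (GQA)
v = jax.random.normal(kv, (2, 8, 2, 16))

out = dense_attention_reference(q, k, v, causal=True, logits_soft_cap=30.0)

# Two segments of 4 tokens each; zeroing the values of segment 0 must not change segment 1.
seg = jnp.array([[0, 0, 0, 0, 1, 1, 1, 1]] * 2)
out_seg = dense_attention_reference(q, k, v, q_segment_ids=seg, kv_segment_ids=seg)
out_seg_zeroed = dense_attention_reference(
    q, k, v.at[:, :4].set(0.0), q_segment_ids=seg, kv_segment_ids=seg)
```

Under causal masking the first query can only see the first key, so its output equals that key's value row; this gives a quick sanity check that does not depend on the exact scale or soft cap.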