ejkernel.kernels._pallas.tpu.ring_attention._interface

Ring Attention interface using Splash Attention kernels.

This module provides the public API for ring attention on TPU using the splash attention implementation with ring communication topology.

ejkernel.kernels._pallas.tpu.ring_attention._interface.ring_attention(query: Float[Array, 'batch seq_len_q num_heads head_dim'], key: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], value: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], q_segment_ids: Int[Array, 'batch seq_len_q'] | None = None, kv_segment_ids: Int[Array, 'batch seq_len_k'] | None = None, softmax_aux: Float[Array, 'num_heads'] | Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, mask_builder: Callable[[int, int, int, int, int], Mask] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = False, logits_soft_cap: float | None = None, softmax_scale: float | None = None, axis_name: str | None = None, fwd_params: FwdParams | None = None, bwd_params: BwdParams | None = None, fused_backward: bool = False) -> chex.Array

Computes ring attention using Splash Attention kernels on TPU.

This implementation combines JAX’s splash attention kernels with a ring communication topology, so that attention can be computed in a distributed fashion across the devices on the axis named by axis_name.
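
The general idea behind ring attention can be sketched on a single host with NumPy. The sketch below is an illustration only, not this module's implementation: it assumes that each simulated device keeps its own query shard while key/value shards rotate around the ring, and that partial results are merged with a running (online) softmax. The names dense_attention and simulated_ring_attention, and the rotation direction, are hypothetical choices made for this example.

```python
import numpy as np


def dense_attention(q, k, v, scale):
    """Plain softmax attention for one head: q [q_len, d], k/v [kv_len, d]."""
    logits = (q @ k.T) * scale
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


def simulated_ring_attention(q_shards, k_shards, v_shards, scale):
    """Simulate a ring of len(q_shards) devices on one host (illustrative only)."""
    n = len(q_shards)
    outputs = []
    for i in range(n):  # "device" i owns query shard i
        q = q_shards[i]
        running_max = np.full((q.shape[0], 1), -np.inf)
        running_sum = np.zeros((q.shape[0], 1))
        acc = np.zeros((q.shape[0], v_shards[0].shape[-1]))
        for step in range(n):
            j = (i + step) % n  # KV shard visiting device i at this step (assumed direction)
            logits = (q @ k_shards[j].T) * scale
            new_max = np.maximum(running_max, logits.max(axis=-1, keepdims=True))
            correction = np.exp(running_max - new_max)  # rescale earlier partial sums
            p = np.exp(logits - new_max)
            running_sum = running_sum * correction + p.sum(axis=-1, keepdims=True)
            acc = acc * correction + p @ v_shards[j]
            running_max = new_max
        outputs.append(acc / running_sum)
    return np.concatenate(outputs, axis=0)


rng = np.random.default_rng(0)
q = rng.normal(size=(8, 4))
k = rng.normal(size=(8, 4))
v = rng.normal(size=(8, 4))
scale = 4 ** -0.5

ring_out = simulated_ring_attention(np.split(q, 4), np.split(k, 4), np.split(v, 4), scale)
dense_out = dense_attention(q, k, v, scale)
```

Because the running softmax rescales earlier partial sums whenever a larger logit appears, the merged result matches dense attention over the full sequence up to floating-point error, regardless of the order in which KV shards arrive.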

Parameters
  • query – Query tensor [batch, q_len, num_heads, head_dim].

  • key – Key tensor [batch, kv_len, num_kv_heads, head_dim].

  • value – Value tensor [batch, kv_len, num_kv_heads, head_dim].

  • q_segment_ids – Optional query segment IDs [batch, q_len].

  • kv_segment_ids – Optional KV segment IDs [batch, kv_len].

  • softmax_aux – Optional attention sink logits of shape [num_heads] or [num_sinks] (maps to the sinks parameter).

  • bias – Optional attention bias [batch, num_heads, q_len, kv_len]. Not supported by splash attention: passing a value raises NotImplementedError.

  • mask_builder – Optional custom mask builder function.

  • sliding_window – Sliding window size for local attention, given as a single int or a pair of ints.

  • chunk_size – Chunk size for chunked causal attention.

  • causal – Whether to use causal masking.

  • logits_soft_cap – Soft cap for attention logits.

  • softmax_scale – Scaling factor for attention scores.

  • axis_name – Name of the ring communication axis.

  • fwd_params – Forward pass block size parameters.

  • bwd_params – Backward pass block size parameters.

  • fused_backward – Whether to use fused backward kernel.
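
To make the masking parameters more concrete, the following NumPy sketch builds a boolean mask for a single sequence from causal, sliding_window, and segment IDs. It rests on assumptions that this page does not confirm: a tuple sliding_window is read as (left, right) extents, an int is treated as the same extent on both sides, and a query attends to a key only when their segment IDs are equal. The helper build_mask is hypothetical and is not part of this module; chunk_size and mask_builder are not covered.

```python
import numpy as np


def build_mask(q_len, kv_len, causal=False, sliding_window=None,
               q_segment_ids=None, kv_segment_ids=None):
    """Return a [q_len, kv_len] boolean mask; True means 'may attend'."""
    q_pos = np.arange(q_len)[:, None]
    kv_pos = np.arange(kv_len)[None, :]
    mask = np.ones((q_len, kv_len), dtype=bool)
    if causal:
        # A query may not look at later key positions.
        mask &= kv_pos <= q_pos
    if sliding_window is not None:
        # Assumption: int means a symmetric window, tuple means (left, right).
        if isinstance(sliding_window, int):
            left, right = sliding_window, sliding_window
        else:
            left, right = sliding_window
        mask &= (kv_pos >= q_pos - left) & (kv_pos <= q_pos + right)
    if q_segment_ids is not None and kv_segment_ids is not None:
        # Assumption: attention is restricted to matching segment IDs.
        mask &= q_segment_ids[:, None] == kv_segment_ids[None, :]
    return mask


segment_ids = np.array([0, 0, 1, 1])
mask = build_mask(4, 4, causal=True, sliding_window=(1, 0),
                  q_segment_ids=segment_ids, kv_segment_ids=segment_ids)
mask_rows = mask.astype(int).tolist()
# mask_rows == [[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]]
```

In this example the third query position sees only itself: the window would admit the previous position, but that position belongs to a different segment.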

Returns

Attention output [batch, q_len, num_heads, head_dim].

Raises

NotImplementedError – If bias is provided (not supported).
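
The documented shapes and the bias restriction can be illustrated with a dense NumPy reference. This is a sketch under stated assumptions, not the TPU kernel: it assumes the default softmax_scale is head_dim ** -0.5, that logits_soft_cap is applied as cap * tanh(logits / cap), and that grouped-query attention maps consecutive query heads onto one KV head. The function reference_attention is hypothetical and exists only for this example.

```python
import numpy as np


def reference_attention(query, key, value, bias=None, causal=False,
                        logits_soft_cap=None, softmax_scale=None):
    """Dense reference: query [batch, q_len, num_heads, head_dim],
    key/value [batch, kv_len, num_kv_heads, head_dim]."""
    if bias is not None:
        # Mirrors the documented behavior: bias is rejected.
        raise NotImplementedError("bias is not supported")
    batch, q_len, num_heads, head_dim = query.shape
    kv_len, num_kv_heads = key.shape[1], key.shape[2]
    assert num_heads % num_kv_heads == 0, "num_heads must be a multiple of num_kv_heads"
    group = num_heads // num_kv_heads
    # Assumption: consecutive query heads share one KV head.
    key = np.repeat(key, group, axis=2)
    value = np.repeat(value, group, axis=2)
    # Assumption: default scale is 1 / sqrt(head_dim).
    scale = head_dim ** -0.5 if softmax_scale is None else softmax_scale
    logits = np.einsum("bqhd,bkhd->bhqk", query, key) * scale
    if logits_soft_cap is not None:
        # Assumption: tanh-based soft capping.
        logits = logits_soft_cap * np.tanh(logits / logits_soft_cap)
    if causal:
        allowed = np.tril(np.ones((q_len, kv_len), dtype=bool))
        logits = np.where(allowed, logits, -np.inf)
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", weights, value)


rng = np.random.default_rng(0)
q = rng.normal(size=(2, 6, 4, 8))   # 4 query heads
k = rng.normal(size=(2, 6, 2, 8))   # 2 KV heads
v = rng.normal(size=(2, 6, 2, 8))

out = reference_attention(q, k, v, causal=True, logits_soft_cap=30.0)

try:
    reference_attention(q, k, v, bias=np.zeros((2, 4, 6, 6)))
    bias_rejected = False
except NotImplementedError:
    bias_rejected = True
```

The output has the same layout as the query, [batch, q_len, num_heads, head_dim]. With causal masking, the first query position can attend only to the first key, so its output equals the first value row for the corresponding KV head.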