ejkernel.kernels._xla.ring_attention._interface

Ring attention interface for distributed sequence processing.

This module provides the public API for ring attention, which uses blockwise transformers and rotates key/value blocks across devices. It supports chunked computation and uses a custom VJP for gradient computation.
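The blockwise computation described above can be illustrated with a minimal single-device sketch in JAX. This is not ejkernel's implementation: `reference_attention` and `blockwise_attention` are hypothetical names introduced here, the loop over key/value chunks only stands in for the rotation of key/value blocks across devices, and the sketch assumes equal query and key/value head counts with no masking, bias, soft cap, or custom VJP.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    # Plain softmax attention, used only to check the blockwise result.
    # q: (batch, seq_q, heads, dim); k, v: (batch, seq_k, heads, dim)
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


def blockwise_attention(q, k, v, chunk_size):
    # Hypothetical helper: processes K/V one chunk at a time while keeping
    # a running max and running sum (online softmax), so the full
    # (seq_q, seq_k) logits matrix is never materialized at once.
    batch, seq_q, heads, dim = q.shape
    seq_k = k.shape[1]
    assert seq_k % chunk_size == 0, "seq_k must be divisible by chunk_size"
    num_chunks = seq_k // chunk_size
    scale = dim ** -0.5

    # Split the key/value sequence axis into chunks and move the chunk axis
    # to the front so lax.scan can iterate over it.
    k_chunks = k.reshape(batch, num_chunks, chunk_size, heads, dim).swapaxes(0, 1)
    v_chunks = v.reshape(batch, num_chunks, chunk_size, heads, dim).swapaxes(0, 1)

    def step(carry, kv):
        acc, row_max, row_sum = carry
        k_blk, v_blk = kv
        logits = jnp.einsum("bqhd,bkhd->bhqk", q, k_blk) * scale
        new_max = jnp.maximum(row_max, logits.max(axis=-1))
        # Rescale previously accumulated values to the new running max.
        correction = jnp.exp(row_max - new_max)
        probs = jnp.exp(logits - new_max[..., None])
        row_sum = row_sum * correction + probs.sum(axis=-1)
        acc = acc * correction[..., None] + jnp.einsum(
            "bhqk,bkhd->bhqd", probs, v_blk
        )
        return (acc, new_max, row_sum), None

    init = (
        jnp.zeros((batch, heads, seq_q, dim), q.dtype),
        jnp.full((batch, heads, seq_q), -jnp.inf, q.dtype),
        jnp.zeros((batch, heads, seq_q), q.dtype),
    )
    (acc, _, row_sum), _ = jax.lax.scan(step, init, (k_chunks, v_chunks))
    out = acc / row_sum[..., None]
    return out.transpose(0, 2, 1, 3)  # back to (batch, seq_q, heads, dim)


# Small example with the same axis layout as the documented signature.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 8, 4, 16))
k = jax.random.normal(kk, (2, 32, 4, 16))
v = jax.random.normal(kv, (2, 32, 4, 16))

out_blockwise = blockwise_attention(q, k, v, chunk_size=8)
out_reference = reference_attention(q, k, v)
```

In a distributed setting, each device would hold one key/value block and pass it to its neighbor after every step; the running max and sum let each device combine the partial results from blocks that arrive in any order.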

ejkernel.kernels._xla.ring_attention._interface.ring_attention(
    query: Float[Array, 'batch seq_len_q num_heads head_dim'],
    key: Float[Array, 'batch seq_len_k num_kv_heads head_dim'],
    value: Float[Array, 'batch seq_len_k num_kv_heads head_dim'],
    q_segment_ids: Int[Array, 'batch seq_len_q'] | None = None,
    kv_segment_ids: Int[Array, 'batch seq_len_k'] | None = None,
    softmax_aux: Float[Array, 'num_heads'] | Float[Array, 'num_sinks'] | None = None,
    bias: Float[Array, 'batch num_heads seq_len_q seq_len_k'] | None = None,
    mask_builder: Callable[[int, int, int, int, int], Mask] | None = None,
    sliding_window: int | tuple[int, int] | None = None,
    chunk_size: int | None = None,
    causal: bool = False,
    logits_soft_cap: float | None = None,
    softmax_scale: float | None = None,
    axis_name: str | None = None,
    fwd_params: FwdParams | None = None,
    bwd_params: BwdParams | None = None,
    fused_backward: bool = False,
) -> Float[Array, 'batch seq_len_q num_heads head_dim']