ejkernel.kernels._xla.blocksparse_attention._interface

Block-sparse attention interface for XLA fallback computation.

This module provides the public API for block-sparse attention over packed multi-sequence inputs described by segment IDs and positions. It acts as a correctness fallback when specialized kernels are unavailable.

ejkernel.kernels._xla.blocksparse_attention._interface.blocksparse_attention(query: Float[Array, 'batch num_heads seq_len head_dim'], key: Float[Array, 'batch kv_num_heads kv_len head_dim'], value: Float[Array, 'batch kv_num_heads kv_len vhead_dim'], q_segment_ids: Int[Array, 'batch seq_len'] | None = None, kv_segment_ids: Int[Array, 'batch kv_len'] | None = None, q_positions: Int[Array, 'batch seq_len'] | None = None, kv_positions: Int[Array, 'batch kv_len'] | None = None, softmax_aux: Float[Array, 'num_sinks'] | None = None, bias: Float[Array, 'batch num_heads seq_len head_dim'] | None = None, attention_mask: Bool[Array, 'batch num_heads_or_1 seq_len kv_len'] | Int[Array, 'batch num_heads_or_1 seq_len kv_len'] | None = None, sequence_parallelism_mesh_axis_name: str | None = None, logits_soft_cap: float | None = None, qkv_layouts: tuple['SparseMask'] | None = None, softmax_scale: float | None = None, fwd_params: FwdParams | None = None, bwd_params: BwdParams | None = None, mask_builder: Callable[[int, int, int, int, int], 'Mask'] | Callable[[], 'SparseMask'] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = True, fused_backward: bool = False) -> Float[Array, 'batch num_heads seq_len vhead_dim']

XLA fallback for block-sparse attention with packed (multi-sequence) support.

This implementation is a correctness fallback: it materializes the token-level mask implied by segment IDs, positions, causal/sliding-window settings (and an optional attention_mask), then computes dense attention in JAX/XLA.
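
The mask materialization described above can be illustrated with a minimal JAX sketch. This is not the library's implementation: the helper names `build_token_mask` and `dense_masked_attention` are hypothetical, the sketch assumes equal query and key/value head counts, and it assumes `sliding_window` means a `(left, right)` token range (an `int` is treated as the same width on both sides). It omits `bias`, `softmax_aux`, `logits_soft_cap`, and `attention_mask`.

```python
import jax
import jax.numpy as jnp


def build_token_mask(q_segment_ids, kv_segment_ids, q_positions, kv_positions,
                     causal=True, sliding_window=None):
    """Hypothetical helper: boolean mask of shape (batch, seq_len, kv_len)."""
    # A query may only attend to keys in the same packed sequence.
    mask = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :]
    # Offset of each query position relative to each key position.
    delta = q_positions[:, :, None] - kv_positions[:, None, :]
    if causal:
        # Keys must not lie in the query's future.
        mask = mask & (delta >= 0)
    if sliding_window is not None:
        # Assumption: an int means the same width to the left and right.
        if isinstance(sliding_window, int):
            sliding_window = (sliding_window, sliding_window)
        left, right = sliding_window
        mask = mask & (delta <= left) & (delta >= -right)
    return mask


def dense_masked_attention(query, key, value, mask, softmax_scale=None):
    """Hypothetical helper: dense attention under a token-level mask.

    query: (b, h, q, d), key: (b, h, k, d), value: (b, h, k, dv), mask: (b, q, k).
    """
    if softmax_scale is None:
        softmax_scale = query.shape[-1] ** -0.5
    logits = jnp.einsum("bhqd,bhkd->bhqk", query, key) * softmax_scale
    mask4 = mask[:, None, :, :]  # broadcast over heads
    logits = jnp.where(mask4, logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(logits, axis=-1)
    # Rows with no visible key would otherwise become uniform; zero them out.
    weights = jnp.where(mask4, weights, 0.0)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, value)


# Two sequences of length 2 packed into one row of length 4.
segment_ids = jnp.array([[0, 0, 1, 1]])
positions = jnp.array([[0, 1, 0, 1]])
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (1, 2, 4, 8))
key = jax.random.normal(kk, (1, 2, 4, 8))
value = jax.random.normal(kv, (1, 2, 4, 8))

mask = build_token_mask(segment_ids, segment_ids, positions, positions, causal=True)
out = dense_masked_attention(query, key, value, mask)
```

In this example the first token of each packed sequence can only see itself, so its output row equals its own value vector. Because the full `seq_len × kv_len` mask and logits are materialized, memory grows quadratically with sequence length, which is the usual cost of a dense correctness fallback.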