# Source code for ejkernel.kernels._xla.blocksparse_attention._interface

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Block-sparse attention interface for XLA fallback computation.

This module provides the public API for block-sparse attention that
handles packed multi-sequence inputs with segment IDs and positions.
Acts as a correctness fallback when specialized kernels are unavailable.
"""

from __future__ import annotations

import typing as tp

import jax
import jaxtyping
from beartype import beartype
from beartype.typing import Callable
from jax import numpy as jnp
from jaxtyping import Array, Bool, Float, Int

from ejkernel.ops import BwdParams, FwdParams

from ..._registry import Backend, Platform, kernel_registry

if tp.TYPE_CHECKING:
    from ejkernel.kernels._pallas.tpu.blocksparse_attention._masks import Mask
    from ejkernel.kernels._triton.blocksparse_attention._mask import SparseMask


def _normalize_segment_ids(ids: Int[Array, "..."] | None, *, which: str) -> Int[Array, "batch seqlen"] | None:
    """Coerce optional segment IDs into a 2D int32 array.

    Args:
        ids: Segment IDs of rank 2, or rank 3 (in which case only index 0 of
            axis 1 is kept), or ``None``.
        which: Prefix used in the error message (e.g. ``"q"`` or ``"kv"``).

    Returns:
        ``None`` if ``ids`` is ``None``, otherwise an int32 array of rank 2.

    Raises:
        ValueError: if ``ids`` is neither rank 2 nor rank 3.
    """
    if ids is not None:
        arr = jnp.asarray(ids, jnp.int32)
        rank = arr.ndim
        if rank == 3:
            # Collapse the middle axis by taking its first slice.
            return arr[:, 0, :]
        if rank == 2:
            return arr
        raise ValueError(f"{which}_segment_ids must be 2D or 3D, got shape {arr.shape}")
    return None


def _normalize_positions(
    pos: Int[Array, "..."] | None, *, batch: int, seqlen: int, fill: int
) -> Int[Array, "batch seqlen"]:
    """Return token positions as a ``(batch, seqlen)`` int32 array.

    Args:
        pos: Explicit positions of shape ``(batch, seqlen)``, or ``None`` to use
            ``0..seqlen-1`` for every batch row.
        batch: Expected batch size.
        seqlen: Expected sequence length.
        fill: Value substituted for NaN entries when ``pos`` has a floating
            dtype (callers use it to push such tokens out of range).

    Returns:
        int32 positions of shape ``(batch, seqlen)``.

    Raises:
        ValueError: if ``pos`` does not have shape ``(batch, seqlen)``.
    """
    if pos is None:
        return jnp.broadcast_to(jnp.arange(seqlen, dtype=jnp.int32)[None, :], (batch, seqlen))
    # Keep the input dtype for now: casting to int32 first would turn NaN into
    # an arbitrary integer before it could be detected and replaced by `fill`.
    pos = jnp.asarray(pos)
    if pos.shape != (batch, seqlen):
        raise ValueError(f"positions must have shape {(batch, seqlen)}, got {pos.shape}")
    if jnp.issubdtype(pos.dtype, jnp.floating):
        is_nan = jnp.isnan(pos)
        as_int = jnp.where(is_nan, 0, pos).astype(jnp.int32)
        # Insert `fill` in the integer domain so large fills (e.g. int32 max)
        # are not distorted by a float round-trip.
        return jnp.where(is_nan, jnp.asarray(fill, dtype=jnp.int32), as_int)
    return pos.astype(jnp.int32)


def _normalize_attention_mask(
    attention_mask: Bool[Array, "..."] | Int[Array, "..."] | None,
    *,
    batch: int,
    q_len: int,
    kv_len: int,
) -> Bool[Array, "batch q kv"] | None:
    """Convert an optional attention mask to a boolean ``(batch, q_len, kv_len)`` array.

    Non-boolean masks are interpreted as ``mask != 0``. Accepted ranks:

    * 4D ``(B, H or 1, Q, K)``: only the slice at index 0 of axis 1 is
      returned; any other head slices are ignored here.
    * 3D ``(B, Q, K)``: returned as-is.
    * 2D ``(B, K)``: treated as a key padding mask and broadcast over queries.

    Returns:
        ``None`` if ``attention_mask`` is ``None``, otherwise a boolean mask.

    Raises:
        ValueError: on a shape mismatch or an unsupported rank.
    """
    if attention_mask is None:
        return None

    mask = attention_mask if attention_mask.dtype == jnp.bool_ else attention_mask != 0
    rank = mask.ndim

    if rank == 2:
        if mask.shape != (batch, kv_len):
            raise ValueError(f"2D attention_mask is treated as KV padding mask with shape (B, K); got {mask.shape}")
        return jnp.broadcast_to(mask[:, None, :], (batch, q_len, kv_len))

    if rank == 3:
        if mask.shape != (batch, q_len, kv_len):
            raise ValueError(f"attention_mask must have shape (B, Q, K); got {mask.shape}")
        return mask

    if rank == 4:
        mask_batch, _, mask_q, mask_k = mask.shape
        if (mask_batch, mask_q, mask_k) != (batch, q_len, kv_len):
            raise ValueError(f"attention_mask must have shape (B, H/1, Q, K); got {mask.shape}")
        # A single (B, Q, K) result cannot hold distinct per-head masks; head 0 is
        # used so the output is deterministic.
        return mask[:, 0, :, :]

    raise ValueError(f"Unsupported attention_mask rank {mask.ndim} with shape {mask.shape}")


def _normalize_softmax_aux(
    softmax_aux: Float[Array, "..."] | None,
    *,
    num_heads: int,
    num_kv_heads: int,
    dtype: jnp.dtype,
) -> Float[Array, "num_heads num_sinks"] | None:
    """Expand optional attention-sink logits to one row per query head.

    Args:
        softmax_aux: ``None``, a 1D ``(num_sinks,)`` array shared by all heads,
            or a 2D array whose first dim is ``num_heads`` or ``num_kv_heads``.
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads; a per-KV-head row is repeated
            ``num_heads // num_kv_heads`` times consecutively along axis 0.
        dtype: dtype the result is cast to.

    Returns:
        ``None`` or an array of shape ``(num_heads, num_sinks)``.

    Raises:
        ValueError: for an unsupported rank or leading dimension.
    """
    if softmax_aux is None:
        return None

    sinks = jnp.asarray(softmax_aux, dtype=dtype)
    if sinks.ndim == 1:
        num_sinks = sinks.shape[0]
        return jnp.broadcast_to(sinks[None, :], (num_heads, num_sinks))
    if sinks.ndim != 2:
        raise ValueError(f"softmax_aux must be 1D or 2D, got shape {sinks.shape}")

    leading = sinks.shape[0]
    if leading == num_heads:
        return sinks
    if leading == num_kv_heads:
        return jnp.repeat(sinks, repeats=num_heads // num_kv_heads, axis=0)
    raise ValueError(
        f"softmax_aux first dim must be num_kv_heads ({num_kv_heads}) or num_heads ({num_heads}); got {leading}"
    )


def _expand_head_attention_mask(
    attention_mask: Bool[Array, "..."] | Int[Array, "..."],
    *,
    num_heads: int,
    num_kv_heads: int,
) -> Bool[Array, "batch num_kv_heads group q_len kv_len"]:
    """Lay out a 4D per-head attention mask in the grouped-query shape.

    The caller is expected to have validated the batch/query/key dims already.
    Query head ``h`` is paired with KV head ``h // (num_heads // num_kv_heads)``,
    matching the ``query`` reshape in ``blocksparse_attention``.

    Returns:
        A boolean mask of shape ``(B, num_kv_heads, reps, Q, K)`` when the mask
        has one slice per query head, or ``(B, num_kv_heads, 1, Q, K)`` when it
        has one slice per KV head (shared by all query heads of that group).

    Raises:
        ValueError: if the head dimension is neither ``num_heads`` nor ``num_kv_heads``.
    """
    m = attention_mask if attention_mask.dtype == jnp.bool_ else attention_mask != 0
    batch, mask_heads, q_len, kv_len = m.shape
    if mask_heads == num_heads:
        return m.reshape(batch, num_kv_heads, num_heads // num_kv_heads, q_len, kv_len)
    if mask_heads == num_kv_heads:
        return m[:, :, None, :, :]
    raise ValueError(
        f"attention_mask head dim must be 1, num_kv_heads ({num_kv_heads}) or num_heads ({num_heads}); got {mask_heads}"
    )


@kernel_registry.register("blocksparse_attention", Platform.XLA, Backend.ANY)
@jaxtyping.jaxtyped(typechecker=beartype)
def blocksparse_attention(
    query: Float[Array, "batch num_heads seq_len head_dim"],
    key: Float[Array, "batch kv_num_heads kv_len head_dim"],
    value: Float[Array, "batch kv_num_heads kv_len vhead_dim"],
    q_segment_ids: Int[Array, "batch seq_len"] | None = None,
    kv_segment_ids: Int[Array, "batch kv_len"] | None = None,
    q_positions: Int[Array, "batch seq_len"] | None = None,
    kv_positions: Int[Array, "batch kv_len"] | None = None,
    softmax_aux: Float[Array, "num_sinks"] | Float[Array, "aux_heads num_sinks"] | None = None,
    bias: Float[Array, "batch num_heads seq_len kv_len"] | None = None,
    attention_mask: (
        Bool[Array, "batch num_heads_or_1 seq_len kv_len"]
        | Int[Array, "batch num_heads_or_1 seq_len kv_len"]
        | Bool[Array, "batch seq_len kv_len"]
        | Int[Array, "batch seq_len kv_len"]
        | Bool[Array, "batch kv_len"]
        | Int[Array, "batch kv_len"]
        | None
    ) = None,
    sequence_parallelism_mesh_axis_name: str | None = None,
    logits_soft_cap: float | None = None,
    qkv_layouts: tuple["SparseMask"] | None = None,
    softmax_scale: float | None = None,
    fwd_params: FwdParams | None = None,
    bwd_params: BwdParams | None = None,
    mask_builder: Callable[[int, int, int, int, int], "Mask"] | Callable[[], "SparseMask"] | None = None,
    sliding_window: int | tuple[int, int] | None = None,
    chunk_size: int | None = None,
    causal: bool = True,
    fused_backward: bool = False,
) -> Float[Array, "batch num_heads seq_len vhead_dim"]:
    """XLA fallback for block-sparse attention with packed (multi-sequence) support.

    This implementation is a correctness fallback. It materializes the
    token-level mask implied by segment IDs, positions, causal/sliding-window
    settings and an optional ``attention_mask``, then computes dense attention
    in JAX/XLA.

    Args:
        query: ``(B, H, Q, D)`` queries.
        key: ``(B, Hkv, K, D)`` keys; ``H`` must be divisible by ``Hkv``. Query
            head ``h`` attends with KV head ``h // (H // Hkv)``.
        value: ``(B, Hkv, K, Vd)`` values.
        q_segment_ids / kv_segment_ids: Tokens attend only to tokens with an
            equal segment ID; negative IDs mark invalid tokens. If only one side
            is given and ``K == Q``, it is reused for the other side; a missing
            side otherwise defaults to all ones.
        q_positions / kv_positions: Token positions used by the causal and
            sliding-window checks; default to ``0..T-1`` per row.
        softmax_aux: Sink logits, 1D ``(num_sinks,)`` or 2D with a leading
            ``H`` or ``Hkv`` dim. They join each softmax normalization but
            contribute no value vector.
        bias: Not supported; must be ``None``.
        attention_mask: Nonzero/True means "may attend". Accepts
            ``(B, 1|Hkv|H, Q, K)``, ``(B, Q, K)``, or a ``(B, K)`` key padding mask.
        sequence_parallelism_mesh_axis_name, qkv_layouts, fwd_params,
            bwd_params, mask_builder, chunk_size, fused_backward: Accepted for
            interface compatibility with other backends and ignored here.
        logits_soft_cap: If set, logits become ``cap * tanh(logits / cap)``.
        softmax_scale: Logit scale; defaults to ``D ** -0.5``.
        sliding_window: ``w`` or ``(left, right)``. Key positions are restricted
            to ``[q_pos - left, q_pos + right]``.
        causal: If True, keys need ``kv_pos <= q_pos``.

    Returns:
        ``(B, H, Q, Vd)`` outputs. Rows whose query has no attendable key, and
        queries with a negative segment ID, are zero.

    Raises:
        NotImplementedError: if ``bias`` is given.
        ValueError: on inconsistent shapes or an unsupported mask/aux layout.
    """
    del (
        fused_backward,
        qkv_layouts,
        fwd_params,
        bwd_params,
        mask_builder,
        chunk_size,
        sequence_parallelism_mesh_axis_name,
    )

    if bias is not None:
        raise NotImplementedError("Bias is not supported in blocksparse_attention (XLA fallback)")
    if query.ndim != 4 or key.ndim != 4 or value.ndim != 4:
        raise ValueError("query/key/value must be rank-4 tensors (B, H, T, D)")

    batch, num_heads, q_len, head_dim = query.shape
    key_batch, num_kv_heads, kv_len, _ = key.shape
    if key_batch != batch:
        raise ValueError(f"batch mismatch: query batch {batch}, key batch {key_batch}")
    if value.shape[:3] != (batch, num_kv_heads, kv_len):
        raise ValueError(f"value must have shape (B, Hkv, K, Vd); got {value.shape}")
    if num_heads % num_kv_heads != 0:
        raise ValueError(f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
    reps = num_heads // num_kv_heads

    if softmax_scale is None:
        softmax_scale = head_dim**-0.5

    if sliding_window is None:
        window_left = window_right = None
    elif isinstance(sliding_window, int):
        window_left = window_right = int(sliding_window)
    else:
        window_left, window_right = int(sliding_window[0]), int(sliding_window[1])

    # --- Token-level mask (B, Q, K) from segments, causality and window. ---
    q_ids = _normalize_segment_ids(q_segment_ids, which="q")
    kv_ids = _normalize_segment_ids(kv_segment_ids, which="kv")
    if kv_ids is None and q_ids is not None and kv_len == q_len:
        kv_ids = q_ids
    if q_ids is None and kv_ids is not None and kv_len == q_len:
        q_ids = kv_ids
    if q_ids is None:
        q_ids = jnp.ones((batch, q_len), dtype=jnp.int32)
    if kv_ids is None:
        kv_ids = jnp.ones((batch, kv_len), dtype=jnp.int32)

    q_pos = _normalize_positions(q_positions, batch=batch, seqlen=q_len, fill=-1)
    kv_pos = _normalize_positions(kv_positions, batch=batch, seqlen=kv_len, fill=jnp.iinfo(jnp.int32).max)

    q_valid = q_ids >= 0
    kv_valid = kv_ids >= 0
    token_mask = (q_ids[:, :, None] == kv_ids[:, None, :]) & q_valid[:, :, None] & kv_valid[:, None, :]
    if causal:
        token_mask = token_mask & (q_pos[:, :, None] >= kv_pos[:, None, :])
    if window_left is not None or window_right is not None:
        wl = window_left if window_left is not None else jnp.iinfo(jnp.int32).max
        wr = window_right if window_right is not None else jnp.iinfo(jnp.int32).max
        token_mask = (
            token_mask
            & (kv_pos[:, None, :] >= (q_pos[:, :, None] - wl))
            & (kv_pos[:, None, :] <= (q_pos[:, :, None] + wr))
        )

    # --- Lift to the grouped-query layout (B, Hkv, reps|1, Q, K) and apply
    # the user mask; per-head 4D masks are honoured per head instead of being
    # reduced to head 0. ---
    full_mask = token_mask[:, None, None, :, :]
    attn_mask = _normalize_attention_mask(attention_mask, batch=batch, q_len=q_len, kv_len=kv_len)
    if attn_mask is not None:
        if attention_mask.ndim == 4 and attention_mask.shape[1] > 1:
            full_mask = full_mask & _expand_head_attention_mask(
                attention_mask, num_heads=num_heads, num_kv_heads=num_kv_heads
            )
        else:
            full_mask = full_mask & attn_mask[:, None, None, :, :]
    # Per-(head, query) flag: rows with no attendable key are zeroed at the end.
    row_has_any = jnp.any(full_mask, axis=-1, keepdims=True)

    # --- Dense attention. ---
    q = query.reshape(batch, num_kv_heads, reps, q_len, head_dim)
    logits = jnp.einsum("bhrqd,bhkd->bhrqk", q, key, optimize=True) * jnp.asarray(softmax_scale, dtype=query.dtype)
    if logits_soft_cap is not None:
        cap = jnp.asarray(logits_soft_cap, dtype=logits.dtype)
        logits = cap * jnp.tanh(logits / cap)

    logits = jnp.where(full_mask, logits, jnp.finfo(logits.dtype).min)

    aux = _normalize_softmax_aux(softmax_aux, num_heads=num_heads, num_kv_heads=num_kv_heads, dtype=logits.dtype)
    if aux is not None:
        num_sinks = aux.shape[-1]
        aux = aux.reshape(num_kv_heads, reps, num_sinks)
        sinks = jnp.broadcast_to(aux[None, :, :, None, :], (batch, num_kv_heads, reps, q_len, num_sinks))
        combined = jnp.concatenate([logits, sinks], axis=-1)
        probs = jax.nn.softmax(combined.astype(jnp.float32), axis=-1).astype(logits.dtype)
        # Sinks only absorb probability mass; drop their columns.
        weights = probs[..., :kv_len]
    else:
        weights = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(logits.dtype)

    weights = weights * row_has_any.astype(weights.dtype)
    out = jnp.einsum("bhrqk,bhkd->bhrqd", weights, value, optimize=True).reshape(
        batch, num_heads, q_len, value.shape[-1]
    )
    return out * q_valid[:, None, :, None].astype(out.dtype)