Source code for ejkernel.modules.operations.blocksparse_attention

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Block-sparse attention module with automatic optimization.

This module implements block-sparse attention, which applies attention only to
predefined blocks of the attention matrix, significantly reducing computational
cost for long sequences while maintaining important attention patterns.

The block-sparse pattern is defined by a mask builder function that determines
which blocks should be computed. This is particularly useful for document-level
attention, local attention patterns, and sparse attention architectures.
"""

from __future__ import annotations

import math
import os
import typing

from jax import numpy as jnp
from jax import shard_map
from jax.sharding import Mesh, PartitionSpec
from jaxtyping import Array, Float, Int

from ejkernel.kernels._registry import Backend, kernel_registry
from ejkernel.ops import (
    AutotunePolicy,
    BwdParams,
    ConfigCache,
    ConfigSelectorChain,
    Executor,
    FwdParams,
    Invocation,
    Kernel,
    Tuner,
)
from ejkernel.ops.config.persistent import PersistentCache
from ejkernel.types.mask import MaskInfo

from ..base import detect_platform
from .configs import BlockSparseAttentionConfig

if typing.TYPE_CHECKING:
    from ejkernel.kernels._pallas.tpu.blocksparse_attention._masks import Mask
    from ejkernel.kernels._triton.blocksparse_attention._mask import SparseMask


[docs]class BlockSparseAttention(Kernel[BlockSparseAttentionConfig, Array]): """Block-sparse attention kernel with custom optimization logic. Implements attention computation over sparse block patterns, computing attention only for specified blocks rather than the full attention matrix. This reduces computational complexity from O(N^2) to O(N * B) where B is the average number of blocks per row. Features: - Configurable sparse block patterns via mask builder - Support for causal masking and sliding windows - Automatic platform/backend selection - Optional autotuning for optimal block sizes - Gradient support for training with custom VJP - Logit soft capping with tanh activation for numerical stability (Gemma-2 style) - Separate forward/backward block sizes for performance tuning The mask builder function defines which blocks to compute, enabling patterns like: - Local attention (nearby tokens only) - Global + local (attending to special tokens + local context) - Strided patterns (every nth block) - Custom patterns based on document structure Example: >>> from ejkernel.modules.operations import BlockSparseAttention >>> from ejkernel.modules import create_default_executor >>> >>> executor = create_default_executor() >>> attn = BlockSparseAttention() >>> >>> >>> def local_mask(q_idx, k_idx, q_size, k_size, window): ... ... pass >>> >>> output = executor( ... attn, ... query, key, value, ... mask_builder=local_mask, ... chunk_size=128 ... ) """ def __init__(self): """Initialize BlockSparseAttention module.""" super().__init__(op_id="blocksparse_attention")
[docs] def create_shard_map_wrapper( self, query: Float[Array, "batch num_heads seq_len head_dim"], key: Float[Array, "batch kv_num_heads kv_len head_dim"], value: Float[Array, "batch kv_num_heads kv_len vhead_dim"], softmax_aux: Float[Array, "num_sinks"] | None = None, bias: Float[Array, "batch num_heads seq_len head_dim"] | None = None, q_segment_ids: Int[Array, "batch seq_len"] | None = None, kv_segment_ids: Int[Array, "batch kv_len"] | None = None, q_positions: Int[Array, "batch seq_len"] | None = None, kv_positions: Int[Array, "batch kv_len"] | None = None, sequence_parallelism_mesh_axis_name: str | None = None, logits_soft_cap: float | None = None, qkv_layouts: tuple["SparseMask"] | None = None, softmax_scale: float | None = None, mask_builder: typing.Callable[[int, int, int, int, int], "Mask"] | typing.Callable[[], "SparseMask"] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = True, fused_backward: bool = False, platform: typing.Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None, cfg: BlockSparseAttentionConfig | None = None, mesh: Mesh | None = None, in_specs: tuple[PartitionSpec, ...] | None = None, out_specs: PartitionSpec | None = None, check_vma: bool = False, ): """Create a shard_map wrapper specifically for blocksparse attention. Args: mesh: JAX device mesh in_specs: Input partition specs (must match length of tensor args) out_specs: Output partition spec query, key, value: Input tensors to be sharded All other args: Blocksparse attention parameters Returns: Tuple of (shard_map_fn, call_args) """ assert mesh is not None, "mesh must be provided for shard_map execution" assert in_specs is not None, "in_specs must be provided for shard_map execution" assert out_specs is not None, "out_specs must be provided for shard_map execution" def _wrapped_blocksparse_attn( query: Float[Array, "batch num_heads seq_len head_dim"], key: Float[Array, "batch kv_num_heads kv_len head_dim"], value: Float[Array, "batch kv_num_heads kv_len vhead_dim"], softmax_aux: Float[Array, "num_sinks"] | None, bias: Float[Array, "batch num_heads seq_len head_dim"] | None, q_segment_ids: Int[Array, "batch seq_len"] | None, kv_segment_ids: Int[Array, "batch kv_len"] | None, q_positions: Int[Array, "batch seq_len"] | None, kv_positions: Int[Array, "batch kv_len"] | None, ) -> Float[Array, "batch seq_len num_heads head_dim"]: return self.run( query=query, key=key, value=value, q_segment_ids=q_segment_ids, kv_segment_ids=kv_segment_ids, q_positions=q_positions, kv_positions=kv_positions, softmax_aux=softmax_aux, bias=bias, sequence_parallelism_mesh_axis_name=sequence_parallelism_mesh_axis_name, logits_soft_cap=logits_soft_cap, qkv_layouts=qkv_layouts, softmax_scale=softmax_scale, mask_builder=mask_builder, sliding_window=sliding_window, chunk_size=chunk_size, causal=causal, fused_backward=fused_backward, platform=platform, cfg=cfg or self.heuristic_cfg(None), ) call_args = ( query, key, value, softmax_aux, bias, q_segment_ids, kv_segment_ids, q_positions, kv_positions, ) assert len(in_specs) == len(call_args), f"in_specs length {len(in_specs)} != call_args length {len(call_args)}" shard_map_fn = shard_map( _wrapped_blocksparse_attn, mesh=mesh, in_specs=in_specs, out_specs=out_specs, check_vma=check_vma, ) return shard_map_fn, call_args
[docs] def get_impl(self, cfg: BlockSparseAttentionConfig): """Get kernel implementation from registry based on configuration. Args: cfg: Configuration specifying platform and backend preferences Returns: Callable kernel implementation for block-sparse attention Raises: ValueError: If no matching implementation is found for the configuration """ return kernel_registry.get( algorithm="blocksparse_attention", platform=detect_platform("blocksparse_attention", cfg.platform), backend=cfg.backend, )
[docs] def run( self, query: Float[Array, "batch num_heads seq_len head_dim"], key: Float[Array, "batch kv_num_heads kv_len head_dim"], value: Float[Array, "batch kv_num_heads kv_len vhead_dim"], softmax_aux: Float[Array, "num_sinks"] | None = None, bias: Float[Array, "batch num_heads seq_len head_dim"] | None = None, q_segment_ids: Int[Array, "batch seq_len"] | None = None, kv_segment_ids: Int[Array, "batch kv_len"] | None = None, q_positions: Int[Array, "batch seq_len"] | None = None, kv_positions: Int[Array, "batch kv_len"] | None = None, sequence_parallelism_mesh_axis_name: str | None = None, logits_soft_cap: float | None = None, qkv_layouts: tuple["SparseMask"] | None = None, softmax_scale: float | None = None, mask_builder: typing.Callable[[int, int, int, int, int], "Mask"] | typing.Callable[[], "SparseMask"] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = True, fused_backward: bool = False, platform: typing.Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None, *, cfg: BlockSparseAttentionConfig, ) -> Float[Array, "batch seq_len_q num_heads head_dim"]: """Execute block-sparse attention with the given configuration. Args: query: Query tensor [batch, num_heads, seq_len, head_dim] key: Key tensor [batch, kv_num_heads, kv_len, head_dim] value: Value tensor [batch, kv_num_heads, kv_len, vhead_dim] q_segment_ids: Segment IDs for queries to handle multiple sequences [batch, seq_len] kv_segment_ids: Segment IDs for keys/values [batch, kv_len] softmax_aux: Auxiliary values added to attention scores (e.g., for attention sinks) logits_soft_cap: Optional soft cap value to bound attention logits softmax_scale: Scaling factor for attention scores (default: 1/sqrt(head_dim)) mask_builder: Function that builds the sparse mask pattern. Takes (q_idx, k_idx, q_size, k_size, window_size) and returns a Mask object sliding_window: Window size for local attention, int for symmetric or (left, right) tuple chunk_size: Overall chunk size (alternative to separate query/key chunk sizes) causal: Whether to apply causal masking (default: True) fused_backward: Use fused backward pass for improved gradient computation platform: Specific platform to use ("triton", "pallas", "cuda", or "xla") cfg: Configuration object specifying platform/backend and kernel parameters Returns: Attention output tensor [batch, seq_len_q, num_heads, head_dim] Note: The mask_builder function is critical for defining sparsity patterns. It should return a mask indicating which blocks to compute. """ if platform is not None: cfg = BlockSparseAttentionConfig( fwd_params=cfg.fwd_params, bwd_params=cfg.bwd_params, platform=platform, backend=Backend.ANY if platform == "xla" else cfg.backend, ) impl = self.get_impl(cfg) return impl( query=query, key=key, value=value, q_segment_ids=q_segment_ids, kv_segment_ids=kv_segment_ids, q_positions=q_positions, kv_positions=kv_positions, softmax_aux=softmax_aux, logits_soft_cap=logits_soft_cap, bias=bias, sequence_parallelism_mesh_axis_name=sequence_parallelism_mesh_axis_name, qkv_layouts=qkv_layouts, fwd_params=cfg.fwd_params, bwd_params=cfg.bwd_params, softmax_scale=softmax_scale, mask_builder=mask_builder, sliding_window=sliding_window, chunk_size=chunk_size, causal=causal, fused_backward=fused_backward, )
[docs] def heuristic_cfg_gpu(self, inv: Invocation[BlockSparseAttentionConfig, Array]) -> BlockSparseAttentionConfig: """Provide default configuration based on invocation context. Selects optimal block sizes based on sequence length and head dimension. Args: inv: Invocation object with arguments and metadata Returns: Default configuration with block sizes """ return BlockSparseAttentionConfig( fwd_params=FwdParams( q_blocksize=64, kv_blocksize=64, num_warps=4, num_stages=2, ), bwd_params=BwdParams( q_blocksize=32, kv_blocksize=32, num_warps=4, num_stages=2, ), platform="triton", backend="gpu", )
[docs] def heuristic_cfg_tpu(self, inv: Invocation[BlockSparseAttentionConfig, Array]) -> BlockSparseAttentionConfig: """Provide default configuration based on invocation context. Selects optimal block sizes based on sequence length and head dimension. Args: inv: Invocation object with arguments and metadata Returns: Default configuration with block sizes """ return BlockSparseAttentionConfig( fwd_params=FwdParams( q_blocksize=128, kv_blocksize=128, num_warps=None, num_stages=None, ), bwd_params=BwdParams( q_blocksize=128, kv_blocksize=128, num_warps=None, num_stages=None, ), platform="pallas", backend="tpu", )
[docs] def heuristic_cfg(self, inv: Invocation[BlockSparseAttentionConfig, Array]) -> BlockSparseAttentionConfig: """Provide default configuration based on invocation context. Selects optimal block sizes based on sequence length and head dimension. Args: inv: Invocation object with arguments and metadata Returns: Default configuration with block sizes """ return BlockSparseAttentionConfig( fwd_params=FwdParams( q_blocksize=128, kv_blocksize=128, num_warps=None, num_stages=None, ), bwd_params=BwdParams( q_blocksize=128, kv_blocksize=128, num_warps=None, num_stages=None, ), platform="auto", backend="any", )
[docs] def candidate_cfgs(self, inv: Invocation[BlockSparseAttentionConfig, Array]): """Generate candidate configurations for autotuning. Creates multiple block size configurations for benchmarking to find the optimal tiling parameters for the given input shapes. Args: inv: Invocation object with arguments and metadata Returns: Iterable of candidate configurations to test during autotuning Note: The autotuning system will benchmark each candidate and select the fastest one for the given input configuration. """ block_configs = [(256, 256), (512, 512)] candidates = [] for q_block, kv_block in block_configs: candidates.append( BlockSparseAttentionConfig( q_blocksize=q_block, kv_blocksize=kv_block, bwd_q_blocksize=q_block * 2, bwd_kv_blocksize=kv_block * 2, num_warps=4, num_stages=2, platform="auto", backend="any", ) ) return candidates
[docs] def candidate_cfgs_gpu(self, inv: Invocation[BlockSparseAttentionConfig, Array]): """Generate GPU-optimized candidate configurations for autotuning (Triton). Heuristics: - q/kv blocks in {32, 64, 128, 256} depending on head_dim - If sliding_window is set, favor kv blocks ≲ window size (rounded) - num_warps: 2-8 based on head_dim and block sizes - num_stages: 2-4 (bigger when kv block is large) - Backward block sizes smaller to reduce register pressure """ q = inv.kwargs["query"] k = inv.kwargs["key"] head_dim = int(q.shape[-1]) q_len = int(q.shape[-2]) k_len = int(k.shape[-2]) dtype = q.dtype sliding_window = inv.kwargs.get("sliding_window", None) causal = bool(inv.kwargs.get("causal", True)) def window_total(sw): if sw is None: return None if isinstance(sw, int): right = 0 if causal else sw return sw + right + 1 wl, wr = sw return wl + wr + 1 win = window_total(sliding_window) smem_limit = int(os.getenv("EJKERNEL_TRITON_SMEM_LIMIT", str(99 * 1024))) block_headdim = max(1 << max(4, math.ceil(math.log2(max(1, head_dim)))), 16) elem_bytes = 2 if dtype in (jnp.float16, jnp.bfloat16) else 4 def smem_est_bytes(qb: int, kb: int, num_stages: int) -> int: kv_bytes = 2 * kb * block_headdim * elem_bytes q_bytes = int(0.5 * qb * block_headdim * elem_bytes) base = kv_bytes + q_bytes stage_factor = 1.0 + 0.5 * max(0, num_stages - 2) fudge = 3.0 return int(base * stage_factor * fudge) if head_dim <= 64: q_opts = [32, 64, 128] elif head_dim <= 128: q_opts = [32, 64, 128] elif head_dim <= 192: q_opts = [32, 64, 128] else: q_opts = [32, 64, 128] base_kv = [32, 64, 128, 256] if win is not None: target = max(32, min(256, 1 << (int(math.log2(max(32, win))) if win > 0 else 5))) kv_opts = sorted(set([32, 64, min(128, target), min(256, target)])) else: kv_opts = base_kv if k_len < 128: kv_opts = [x for x in kv_opts if x <= 128] or [64, 128] if q_len < 128: q_opts = [x for x in q_opts if x <= 128] or [64, 128] def pick_warps_stages(qb: int, kb: int, dh: int) -> tuple[int, int]: if dh <= 64: warps = 2 if max(qb, kb) <= 64 else 4 elif dh <= 128: warps = 4 if max(qb, kb) <= 128 else 8 else: warps = 8 if max(qb, kb) >= 128 else 4 if kb >= 256: stages = 3 elif kb >= 128: stages = 2 else: stages = 2 return warps, stages def bwd_block(x: int, cap: int = 128) -> int: return max(32, min(cap, x // 2 if x >= 64 else x)) hv_pairs = [] preferred = [(32, 64), (64, 64), (64, 128), (128, 64)] if win is not None: preferred.insert(0, (32, min(128, max(64, win)))) for qb, kb in preferred: if qb in q_opts and kb in kv_opts: hv_pairs.append((qb, kb)) grid_pairs = [] for qb in q_opts: for kb in kv_opts: if (qb, kb) not in hv_pairs: grid_pairs.append((qb, kb)) max_candidates = 18 pairs = [] seen = set() for qb, kb in hv_pairs + grid_pairs: if (qb, kb) in seen: continue w, s = pick_warps_stages(qb, kb, head_dim) if smem_est_bytes(qb, kb, s) <= smem_limit: seen.add((qb, kb)) pairs.append((qb, kb, w, s)) if len(pairs) >= max_candidates: break if not pairs: qb, kb = 32, 64 w, s = pick_warps_stages(qb, kb, head_dim) pairs = [(qb, kb, w, s)] configs: list[BlockSparseAttentionConfig] = [] for qb, kb, w, s in pairs: configs.append( BlockSparseAttentionConfig( fwd_params=FwdParams( q_blocksize=qb, kv_blocksize=kb, num_warps=w, num_stages=s, ), bwd_params=BwdParams( q_blocksize=bwd_block(qb), kv_blocksize=bwd_block(kb), num_warps=w, num_stages=max(2, s - 0), ), platform="triton", backend="gpu", ) ) return configs
[docs] def candidate_cfgs_tpu(self, inv: Invocation[BlockSparseAttentionConfig, Array]): """Generate TPU-optimized candidate configurations for autotuning (Pallas).""" q = inv.kwargs["query"] k = inv.kwargs["key"] q_len = int(q.shape[-2]) k_len = int(k.shape[-2]) sliding_window = inv.kwargs.get("sliding_window", None) causal = bool(inv.kwargs.get("causal", True)) def win_span(sw): if sw is None: return None if isinstance(sw, int): right = 0 if causal else sw return sw + right + 1 wl, wr = sw return wl + wr + 1 def nearest_128_from_set(x: int, allowed=(128, 256, 512, 1024)) -> int: return min(allowed, key=lambda v: (abs(v - x), v)) allowed = (128, 256, 512, 1024) win = win_span(sliding_window) q_opts = [b for b in allowed if b <= max(128, q_len)] or [128] kv_opts = [b for b in allowed if b <= max(128, k_len)] or [128] if win is not None: t = nearest_128_from_set(max(128, min(1024, win)), allowed) kv_opts = sorted(set([*kv_opts, t, min(1024, 2 * t)])) q_opts = sorted(set(q_opts)) kv_opts = sorted(set(kv_opts)) def bwd_tile(x: int) -> int: return 128 if x <= 256 else 256 hv_pairs: list[tuple[int, int]] = [] if win is not None: t = nearest_128_from_set(max(128, min(1024, win)), allowed) for qb in (128, 256): if qb in q_opts and t in kv_opts: hv_pairs.append((qb, t)) if 2 * t <= 1024 and (128 in q_opts) and (2 * t in kv_opts): hv_pairs.append((128, 2 * t)) hv_pairs += [(128, 128), (128, 256), (256, 256), (256, 512)] hv_pairs = [(qb, kb) for (qb, kb) in hv_pairs if qb in q_opts and kb in kv_opts] grid_pairs = [] for qb in q_opts: for kb in kv_opts: if (qb, kb) not in hv_pairs: grid_pairs.append((qb, kb)) max_candidates = 16 pairs: list[tuple[int, int]] = [] seen = set() for qb, kb in hv_pairs + grid_pairs: if (qb, kb) in seen: continue seen.add((qb, kb)) pairs.append((qb, kb)) if len(pairs) >= max_candidates: break configs: list[BlockSparseAttentionConfig] = [] for qb, kb in pairs: configs.append( BlockSparseAttentionConfig( fwd_params=FwdParams( q_blocksize=qb, kv_blocksize=kb, num_warps=None, num_stages=None, ), bwd_params=BwdParams( q_blocksize=bwd_tile(qb), kv_blocksize=bwd_tile(kb), num_warps=None, num_stages=None, ), platform="pallas", backend="tpu", ) ) return configs
[docs] def candidate_cfgs_xla(self, inv: Invocation[BlockSparseAttentionConfig, Array]): q = inv.kwargs["query"] k = inv.kwargs["key"] q_len = int(q.shape[-2]) k_len = int(k.shape[-2]) sliding_window = inv.kwargs.get("sliding_window", None) causal = bool(inv.kwargs.get("causal", True)) def win_span(sw): if sw is None: return None if isinstance(sw, int): right = 0 if causal else sw return sw + right + 1 wl, wr = sw return wl + wr + 1 def nearest_128_from_set(x: int, allowed=(128, 256, 512, 1024)) -> int: return min(allowed, key=lambda v: (abs(v - x), v)) allowed = (128, 256, 512, 1024) win = win_span(sliding_window) q_opts = [b for b in allowed if b <= max(128, q_len)] or [128] kv_opts = [b for b in allowed if b <= max(128, k_len)] or [128] if win is not None: t = nearest_128_from_set(max(128, min(1024, win)), allowed) kv_opts = sorted(set([*kv_opts, t, min(1024, 2 * t)])) q_opts = sorted(set(q_opts)) kv_opts = sorted(set(kv_opts)) def bwd_tile(x: int) -> int: return 128 if x <= 256 else 256 hv_pairs: list[tuple[int, int]] = [] if win is not None: t = nearest_128_from_set(max(128, min(1024, win)), allowed) for qb in (128, 256): if qb in q_opts and t in kv_opts: hv_pairs.append((qb, t)) hv_pairs += [(128, 128), (128, 256), (256, 256), (256, 128)] hv_pairs = [(qb, kb) for (qb, kb) in hv_pairs if qb in q_opts and kb in kv_opts] grid_pairs = [] for qb in q_opts: for kb in kv_opts: if (qb, kb) not in hv_pairs: grid_pairs.append((qb, kb)) max_candidates = 12 pairs: list[tuple[int, int]] = [] seen = set() for qb, kb in hv_pairs + grid_pairs: if (qb, kb) in seen: continue seen.add((qb, kb)) pairs.append((qb, kb)) if len(pairs) >= max_candidates: break configs: list[BlockSparseAttentionConfig] = [] for qb, kb in pairs: configs.append( BlockSparseAttentionConfig( fwd_params=FwdParams( q_blocksize=qb, kv_blocksize=kb, num_warps=None, num_stages=None, ), bwd_params=BwdParams( q_blocksize=bwd_tile(qb), kv_blocksize=bwd_tile(kb), num_warps=None, num_stages=None, ), platform="xla", backend="any", ) ) return configs
candidate_cfgs_shard_map_gpu = candidate_cfgs_gpu candidate_cfgs_shard_map_tpu = candidate_cfgs_tpu candidate_cfgs_shard_map_xla = candidate_cfgs_xla
_executor: Executor[BlockSparseAttentionConfig, Array] = Executor( ConfigSelectorChain( cache=ConfigCache(), policy=AutotunePolicy( allow_autotune=True, cache_miss_fallback=os.getenv("EJKERNEL_AUTOTUNE_POLICY", "autotune"), validate_backward=True, ), tuner=Tuner(warmup=5, iters=100), persistent=PersistentCache("blocksparse", cfg_type=BlockSparseAttentionConfig), ), )
[docs]def blocksparse_attention( query: Float[Array, "batch num_heads seq_len head_dim"], key: Float[Array, "batch kv_num_heads kv_len head_dim"], value: Float[Array, "batch kv_num_heads kv_len vhead_dim"], softmax_aux: Float[Array, "num_sinks"] | None = None, bias: Float[Array, "batch num_heads seq_len head_dim"] | None = None, /, *, mask_info: MaskInfo | None = None, sequence_parallelism_mesh_axis_name: str | None = None, logits_soft_cap: float | None = None, qkv_layouts: tuple["SparseMask"] | None = None, softmax_scale: float | None = None, mask_builder: typing.Callable[[int, int, int, int, int], "Mask"] | typing.Callable[[], "SparseMask"] | None = None, sliding_window: int | tuple[int, int] | None = None, chunk_size: int | None = None, causal: bool = True, fused_backward: bool = False, purify: bool = False, platform: typing.Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None, cfg: BlockSparseAttentionConfig | None = None, mesh: Mesh | None = None, in_specs: tuple[PartitionSpec | None, ...] | None = None, out_specs: PartitionSpec | None = None, ) -> Float[Array, "batch kv_num_heads kv_len vhead_dim"]: """Execute block-sparse attention with automatic optimization. Performs efficient attention computation over sparse block patterns, significantly reducing memory and computation compared to dense attention while maintaining flexibility through custom mask builders. Args: query: Query tensor [batch, num_heads, seq_len, head_dim] key: Key tensor [batch, kv_num_heads, kv_len, head_dim] value: Value tensor [batch, kv_num_heads, kv_len, vhead_dim] mask_info: Optional MaskInfo containing attention mask, segment IDs, and position indices q_positions: Optional query position indices [batch, seq_len] for positional embeddings. If None and mask_info is provided, will use positions from mask_info. kv_positions: Optional key-value position indices [batch, kv_len] for positional embeddings. If None and mask_info is provided, will use positions from mask_info. softmax_aux: Optional auxiliary attention values (e.g., attention sinks) logits_soft_cap: Optional soft capping for attention logits query_chunk_size: Query chunk size for block tiling (default: 128) key_chunk_size: Key chunk size for block tiling (default: 128) softmax_scale: Attention score scaling factor (default: 1/sqrt(head_dim)) mask_builder: Callable defining sparse pattern. Signature: (q_idx, k_idx, q_size, k_size, window) -> Mask sliding_window: Window size for local attention (int or (left, right) tuple) chunk_size: Alternative to separate query_chunk_size/key_chunk_size causal: Apply causal masking (default: True) fused_backward: Use fused backward pass (default: False) platform: Specific platform to use ("triton", "pallas", "cuda", or "xla") cfg: Optional configuration override mesh: JAX device mesh for shard_map execution (optional) in_specs: Input partition specs for shard_map (optional) out_specs: Output partition spec for shard_map (optional) Returns: Attention output [batch, kv_num_heads, kv_len, vhead_dim] Example: >>> from ejkernel.modules.operations import blocksparse_attention >>> >>> >>> output = blocksparse_attention(query, key, value, causal=True) >>> >>> >>> def local_plus_global(q_idx, k_idx, q_size, k_size, window): ... ... return create_local_global_mask(q_idx, k_idx, window) >>> >>> output = blocksparse_attention( ... query, key, value, ... mask_builder=local_plus_global, ... sliding_window=256 ... ) >>> >>> >>> output = blocksparse_attention( ... query, key, value, ... platform="triton" ... ) Note: Block-sparse attention is particularly effective for: - Long document processing where full attention is prohibitive - Architectures with specific attention patterns (e.g., Longformer) - Scenarios where custom sparsity patterns are needed """ q_segment_ids = None kv_segment_ids = None q_positions = None kv_positions = None q_mask = None if mask_info is not None: q_segment_ids, kv_segment_ids = mask_info.get_or_compute_segment_ids() q_ids_for_mask = q_segment_ids if q_ids_for_mask is not None and q_ids_for_mask.ndim == 3: q_ids_for_mask = q_ids_for_mask[:, 0, :] if q_ids_for_mask is not None: q_mask = q_ids_for_mask >= 0 if q_positions is None or kv_positions is None: mask_q_pos, mask_kv_pos = mask_info.get_or_compute_positions() if q_positions is None: q_positions = mask_q_pos if kv_positions is None: kv_positions = mask_kv_pos method = None if mesh is not None and in_specs is not None and out_specs is not None: method = "shard_map" if mask_info is None: in_specs = (*in_specs, None, None, None, None) else: shardings = mask_info.get_shardings(False, mesh=mesh) in_specs = ( *in_specs, shardings.q_segment_ids, shardings.kv_segment_ids, shardings.q_positions, shardings.kv_positions, ) out = _executor( BlockSparseAttention(), query=query, key=key, value=value, q_segment_ids=q_segment_ids, kv_segment_ids=kv_segment_ids, q_positions=q_positions, kv_positions=kv_positions, softmax_aux=softmax_aux, logits_soft_cap=logits_soft_cap, bias=bias, sequence_parallelism_mesh_axis_name=sequence_parallelism_mesh_axis_name, qkv_layouts=qkv_layouts, softmax_scale=softmax_scale, mask_builder=mask_builder, sliding_window=sliding_window, chunk_size=chunk_size, causal=causal, fused_backward=fused_backward, platform=platform, method=method, mesh=mesh, in_specs=in_specs, out_specs=out_specs, _cfg=cfg, ) if q_mask is not None and purify: out = out * q_mask[:, None, :, None].astype(out.dtype) return out