Source code for ejkernel.modules.operations.scaled_dot_product_attention

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Scaled Dot Product Attention module with automatic optimization.

This module implements the standard scaled dot-product attention mechanism,
which is the fundamental building block of transformer architectures. It computes:

    Attention(Q,K,V) = softmax((Q @ K^T) / sqrt(d_k)) @ V

where Q, K, V are the query, key, and value matrices, and d_k is the key dimension.

This implementation provides:
    - Automatic platform selection (XLA, Triton, Pallas, CUDA)
    - Support for various attention patterns (causal, sliding window)
    - Variable-length sequence handling
    - Distributed execution via shard_map
    - Attention biasing and masking
    - Numerical stability through soft capping

Unlike FlashAttention which uses tiling for memory efficiency, this implementation
relies on platform-specific optimizations (e.g., XLA's attention primitive).
"""

from __future__ import annotations

import typing

import jax
from jax import shard_map
from jaxtyping import Array, Bool, Float, Int

from ejkernel.kernels._registry import Backend, kernel_registry
from ejkernel.ops import (
    AutotunePolicy,
    ConfigCache,
    ConfigSelectorChain,
    Executor,
    Invocation,
    Kernel,
    Tuner,
)
from ejkernel.ops.config.persistent import PersistentCache
from ejkernel.types.mask import MaskInfo

from ..base import detect_platform
from .configs import ScaledDotProductAttentionConfig


[docs]class ScaledDotProductAttention(Kernel[ScaledDotProductAttentionConfig, Array]): """ScaledDotProductAttention with custom optimization logic. Supports causal masking, dropout, sliding windows, and variable-length sequences. Features: - Automatic platform/backend selection (XLA Only ;0) - Configuration caching for consistent performance - Optional autotuning to find optimal implementation - Custom gradient support for efficient backpropagation - Support for variable-length sequences via cumulative sequence lengths - Sliding window attention for local attention patterns - Logits soft capping for numerical stability Example: >>> from ejkernel.modules import ScaledDotProductAttention, create_default_executor >>> >>> >>> executor = create_default_executor() >>> attn = ScaledDotProductAttention() >>> >>> >>> output = executor(attn, query, key, value, causal=True, softmax_scale=0.125) >>> >>> >>> output = executor( ... attn, query, key, value,... ... ) >>> >>> >>> output = executor(attn, query, key, value, sliding_window=(256, 256)) """ def __init__(self): """Initialize ScaledDotProductAttention module. Sets up the kernel with the operation identifier for registry lookup and configuration management. """ super().__init__(op_id="scaled_dot_product_attention")
[docs] def get_impl(self, cfg: ScaledDotProductAttentionConfig): """Get kernel implementation from registry based on configuration. Args: cfg: Configuration specifying platform and backend Returns: Callable kernel implementation Raises: ValueError: If no matching implementation is found """ return kernel_registry.get( algorithm="scaled_dot_product_attention", platform=detect_platform("scaled_dot_product_attention", cfg.platform), backend=cfg.backend, )
[docs] def run( self, query: Float[Array, "batch seq_len num_q_heads head_dim"], key: Float[Array, "batch kv_len num_kv_heads head_dim"], value: Float[Array, "batch kv_len num_kv_heads head_dim"], attention_mask: Bool[Array, "batch num_heads_or_1 seq_len kv_len"] | None = None, bias: Float[Array, "batch num_heads seq_len kv_len"] | None = None, init_bias: typing.Callable[[], Float[Array, "batch num_heads seq_len kv_len"]] | None = None, softmax_scale: float | None = None, causal: bool = False, sliding_window: int | tuple[int, int] | None = None, cum_seqlens_q: Int[Array, "batch"] | None = None, cum_seqlens_k: Int[Array, "batch"] | None = None, platform: typing.Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None, *, cfg: ScaledDotProductAttentionConfig, ) -> Float[Array, "batch seq_len_q num_heads head_dim"]: """Execute scaled_dot_product_attention with the given configuration. Args: query: Query tensor [batch, seq_len_q, num_heads, head_dim] key: Key tensor [batch, seq_len_k, num_heads, head_dim] value: Value tensor [batch, seq_len_k, num_heads, head_dim] attention_mask: Optional scaled_dot_product_attention mask (legacy, prefer bias) bias: Optional scaled_dot_product_attention bias tensor softmax_scale: Scaling factor for attention scores dropout_prob: Dropout probability for attention weights causal: Whether to apply causal masking dropout_seed: Random seed for dropout cum_seqlens_q: Cumulative sequence lengths for variable-length queries cum_seqlens_k: Cumulative sequence lengths for variable-length keys sliding_window: Window size for local attention logits_soft_cap: Optional soft cap value for logits softmax_aux: Optional attention sink logits cfg: Configuration object specifying platform/backend segment_ids: Segment IDs for grouped sequences (TPU-specific) block_sizes: Block sizes for kernel execution (TPU-specific) Returns: ScaledDotProductAttention output [batch, seq_len_q, num_heads, head_dim] """ if platform is not None: cfg = ScaledDotProductAttentionConfig( platform=platform, backend=Backend.ANY if platform == "xla" else cfg.backend, ) impl = self.get_impl(cfg) return impl( query=query, key=key, value=value, attention_mask=attention_mask, bias=bias, softmax_scale=softmax_scale, init_bias=init_bias, sliding_window=sliding_window, causal=causal, cum_seqlens_q=cum_seqlens_q, cum_seqlens_k=cum_seqlens_k, )
[docs] def create_shard_map_wrapper( self, query: Float[Array, "batch seq_len num_q_heads head_dim"], key: Float[Array, "batch kv_len num_kv_heads head_dim"], value: Float[Array, "batch kv_len num_kv_heads head_dim"], attention_mask: Bool[Array, "batch num_heads_or_1 seq_len kv_len"] | None = None, bias: Float[Array, "batch num_heads seq_len kv_len"] | None = None, cum_seqlens_q: Int[Array, "batch"] | None = None, cum_seqlens_k: Int[Array, "batch"] | None = None, *, mesh: jax.sharding.Mesh, in_specs: tuple[jax.sharding.PartitionSpec, ...], out_specs: jax.sharding.PartitionSpec, check_vma: bool = False, cfg: ScaledDotProductAttentionConfig, init_bias: typing.Callable[[], Float[Array, "batch num_heads seq_len kv_len"]] | None = None, softmax_scale: float | None = None, causal: bool = False, sliding_window: int | tuple[int, int] | None = None, platform: typing.Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None, ): """Create a shard_map wrapper for distributed ScaledDotProductAttention execution. Enables efficient distributed execution of attention across multiple devices using JAX's shard_map functionality. This is particularly useful for model parallelism and handling very large attention computations. Args: query: Query tensor [batch, seq_len, num_q_heads, head_dim] key: Key tensor [batch, kv_len, num_kv_heads, head_dim] value: Value tensor [batch, kv_len, num_kv_heads, head_dim] attention_mask: Optional attention mask [batch, 1, seq_len, kv_len] bias: Optional attention bias [batch, num_heads, seq_len, kv_len] cum_seqlens_q: Cumulative sequence lengths for queries [batch] cum_seqlens_k: Cumulative sequence lengths for keys [batch] mesh: JAX mesh defining device topology for distributed execution in_specs: Partition specifications for each input tensor out_specs: Partition specification for output tensor check_vma: Whether to check for virtual memory access patterns cfg: Configuration object specifying platform/backend init_bias: Optional callable to initialize bias on-device softmax_scale: Scaling factor for attention scores causal: Whether to apply causal masking sliding_window: Window size for local attention platform: Optional platform override Returns: Tuple of (shard_map function, call args) where: - shard_map function: Callable for distributed execution - call args: Tuple of arguments to pass to the shard_map function Note: The shard_map wrapper handles device placement and communication automatically based on the provided mesh and partition specs. """ impl = self.get_impl(cfg) def _wrapped_sdpa( query, key, value, bias, cum_seqlens_q, cum_seqlens_k, attention_mask, ): return impl( query=query, key=key, value=value, attention_mask=attention_mask, bias=bias, cum_seqlens_q=cum_seqlens_q, cum_seqlens_k=cum_seqlens_k, init_bias=init_bias, softmax_scale=softmax_scale, causal=causal, sliding_window=sliding_window, ) call_args = ( query, key, value, bias, cum_seqlens_q, cum_seqlens_k, attention_mask, ) assert len(in_specs) == len(call_args), f"in_specs length {len(in_specs)} != call_args length {len(call_args)}" shard_map_fn = shard_map( _wrapped_sdpa, mesh=mesh, in_specs=in_specs, out_specs=out_specs, check_vma=check_vma, ) return shard_map_fn, call_args
[docs] def heuristic_cfg(self, inv: Invocation[ScaledDotProductAttentionConfig, Array]) -> ScaledDotProductAttentionConfig: """Provide default configuration based on invocation context. Args: inv: Invocation object with arguments and metadata Returns: Default configuration for platform/backend selection """ return ScaledDotProductAttentionConfig( platform="auto", backend="any", )
[docs] def candidate_cfgs(self, inv: Invocation[ScaledDotProductAttentionConfig, Array]): """Generate candidate configurations for autotuning. This operation uses XLA primitives directly without tunable block sizes, so autotuning provides no benefit. Returns empty list to skip autotuning. Args: inv: Invocation object with arguments and metadata Returns: Empty list - no candidates to autotune since XLA handles optimization Note: XLA's scaled_dot_product_attention primitive is not parameterized by block sizes, so there are no meaningful configurations to benchmark. """ return []
_executor: Executor[ScaledDotProductAttentionConfig, Array] = Executor( ConfigSelectorChain( cache=ConfigCache(), policy=AutotunePolicy(allow_autotune=True, cache_miss_fallback="heuristics", validate_backward=True), tuner=Tuner(warmup=5, iters=100), persistent=PersistentCache("sdpa"), ) )
[docs]def scaled_dot_product_attention( query: Float[Array, "batch seq_len num_q_heads head_dim"], key: Float[Array, "batch kv_len num_kv_heads head_dim"], value: Float[Array, "batch kv_len num_kv_heads head_dim"], bias: Float[Array, "batch num_heads seq_len kv_len"] | None = None, cum_seqlens_q: Int[Array, "batch"] | None = None, cum_seqlens_k: Int[Array, "batch"] | None = None, /, *, mask_info: MaskInfo | None = None, init_bias: typing.Callable[[], Float[Array, "batch num_heads seq_len kv_len"]] | None = None, softmax_scale: float | None = None, causal: bool = False, sliding_window: int | tuple[int, int] | None = None, platform: typing.Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None, mesh: jax.sharding.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None, ) -> Float[Array, "batch seq_len_q num_heads head_dim"]: """Execute scaled dot product attention with automatic optimization. Convenience function that uses a default executor and flash attention module. Args: query: Query tensor [batch, seq_len, num_heads, head_dim] key: Key tensor [batch, seq_len_k, num_heads, head_dim] value: Value tensor [batch, seq_len_k, num_heads, head_dim] mask_info: Optional MaskInfo containing attention mask and/or segment IDs bias: Optional attention bias tensor softmax_scale: Scaling factor for attention scores (default: 1/sqrt(head_dim)) dropout_prob: Dropout probability for attention weights causal: Whether to apply causal masking dropout_seed: Random seed for dropout cum_seqlens_q: Cumulative sequence lengths for variable-length queries cum_seqlens_k: Cumulative sequence lengths for variable-length keys sliding_window: Window size for local attention (int or (left, right) tuple) logits_soft_cap: Optional soft cap value for logits mesh: JAX device mesh for shard_map execution (optional) in_specs: Input partition specs for shard_map (optional) out_specs: Output partition spec for shard_map (optional) Returns: ScaledDotProductAttention output with same shape as query Example: >>> >>> out = scaled_dot_product_attention(query, key, value, causal=True) >>> >>> >>> out = scaled_dot_product_attention(query, key, value, dropout_prob=0.1, softmax_scale=0.125) >>> >>> >>> out = scaled_dot_product_attention(query, key, value, cum_seqlens_q=cu_q, cum_seqlens_k=cu_k) """ attention_mask = None if mask_info is not None: attention_mask = mask_info.get_or_compute_attention_mask() method = None if mesh is not None and in_specs is not None and out_specs is not None: method = "shard_map" if mask_info is None: in_specs = (*in_specs, None) else: shardings = mask_info.get_shardings(False, mesh=mesh) in_specs = (*in_specs, shardings.attention_mask) return _executor( ScaledDotProductAttention(), query=query, key=key, value=value, attention_mask=attention_mask, bias=bias, softmax_scale=softmax_scale, init_bias=init_bias, sliding_window=sliding_window, causal=causal, cum_seqlens_q=cum_seqlens_q, cum_seqlens_k=cum_seqlens_k, platform=platform, method=method, mesh=mesh, in_specs=in_specs, out_specs=out_specs, )