Source code for ejkernel.modules.operations.attention

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Standard multi-head attention module with automatic optimization.

This module implements standard multi-head attention (MHA) with XLA-optimized kernels.
It provides a flexible interface supporting various attention patterns including causal
masking, dropout, sliding windows, and variable-length sequences.

Unlike FlashAttention which uses tiling for memory efficiency, this implementation
leverages XLA's compiler optimizations for straightforward attention computation.
"""

from __future__ import annotations

import typing as tp

from jax import numpy as jnp
from jaxtyping import Array, Bool, DTypeLike, Float, PRNGKeyArray

from ejkernel.kernels._registry import kernel_registry
from ejkernel.ops import (
    AutotunePolicy,
    ConfigCache,
    ConfigSelectorChain,
    Executor,
    Invocation,
    Kernel,
    Tuner,
)
from ejkernel.ops.config.persistent import PersistentCache
from ejkernel.types.mask import MaskInfo

from ..base import detect_platform
from .configs import AttentionConfig


class Attention(Kernel[AttentionConfig, tuple[Array, Array]]):
    """Standard multi-head attention kernel resolved through the kernel registry.

    The kernel registers itself under the operation id ``"attention"`` and
    forwards every call to the implementation that ``kernel_registry`` returns
    for the platform/backend named in the supplied ``AttentionConfig``.

    Options exposed by :meth:`run` (all forwarded to the implementation):
        - Causal masking (``causal``)
        - Boolean attention mask and additive bias (``attention_mask``, ``bias``,
          ``init_bias``)
        - Attention dropout (``dropout_prob``, ``dropout_rng``, ``deterministic``)
        - Sliding-window attention (``sliding_window``)
        - Logits soft capping (``logits_soft_cap``)
        - Auxiliary softmax inputs (``softmax_aux``)

    Example:
        >>> from ejkernel.modules import Attention, create_default_executor
        >>>
        >>> executor = create_default_executor()
        >>> attn = Attention()
        >>>
        >>> output = executor(attn, query, key, value, causal=True, softmax_scale=0.125)
        >>>
        >>> output = executor(attn, query, key, value, sliding_window=(256, 256))
    """

    def __init__(self):
        """Initialize the kernel under the operation id ``"attention"``."""
        super().__init__(op_id="attention")

    def get_impl(self, cfg: AttentionConfig):
        """Look up the kernel implementation for ``cfg`` in the registry.

        Args:
            cfg: Configuration whose ``platform`` and ``backend`` fields select
                the implementation. ``cfg.platform`` is first passed through
                ``detect_platform("attention", ...)``.

        Returns:
            The object returned by ``kernel_registry.get`` for the
            ``"attention"`` algorithm (called as a function by :meth:`run`).

        Raises:
            ValueError: If no matching implementation is found
                (NOTE(review): raised by the registry, not by this method —
                confirm the exception type against ``kernel_registry.get``).
        """
        return kernel_registry.get(
            algorithm="attention",
            platform=detect_platform("attention", cfg.platform),
            backend=cfg.backend,
        )

    def run(
        self,
        query: Float[Array, "batch seq_len num_q_heads head_dim"],
        key: Float[Array, "batch kv_len num_kv_heads head_dim"],
        value: Float[Array, "batch kv_len num_kv_heads vhead_dim"],
        attention_mask: Bool[Array, "batch num_heads_or_1 seq_len kv_len"] | None = None,
        bias: Float[Array, "batch num_heads seq_len kv_len"] | None = None,
        init_bias: tp.Callable[[], Float[Array, "batch num_heads seq_len kv_len"]] | None = None,
        deterministic: bool = True,
        dropout_rng: PRNGKeyArray | None = None,
        softmax_aux: Float[Array, "num_heads num_sinks"] | Float[Array, "num_sinks"] | None = None,
        softmax_scale: float | None = None,
        logits_soft_cap: float | None = None,
        dtype: DTypeLike | None = jnp.bfloat16,
        softmax_dtype: DTypeLike | None = None,
        dropout_prob: float = 0.0,
        causal: bool = False,
        sliding_window: int | tuple[int, int] | None = None,
        *,
        cfg: AttentionConfig,
    ) -> tuple[
        Float[Array, "batch seq_len num_q_heads vhead_dim"],
        Float[Array, "batch num_heads seq_len kv_len"],
    ]:
        """Execute attention with the implementation selected by ``cfg``.

        Every argument except ``cfg`` is forwarded unchanged, by keyword, to the
        implementation returned by :meth:`get_impl`.

        Args:
            query: Query tensor [batch, seq_len, num_q_heads, head_dim].
            key: Key tensor [batch, kv_len, num_kv_heads, head_dim].
            value: Value tensor [batch, kv_len, num_kv_heads, vhead_dim].
            attention_mask: Optional boolean mask
                [batch, num_heads_or_1, seq_len, kv_len] (legacy, prefer bias).
            bias: Optional attention bias [batch, num_heads, seq_len, kv_len].
            init_bias: Optional zero-argument callable producing a bias tensor.
            deterministic: Forwarded flag (presumably disables dropout when
                True — confirm against the implementation).
            dropout_rng: Optional PRNG key forwarded for dropout.
            softmax_aux: Optional auxiliary softmax tensor, shaped
                [num_heads, num_sinks] or [num_sinks].
            softmax_scale: Scaling factor for attention scores.
            logits_soft_cap: Optional soft cap applied to attention logits.
            dtype: Computation dtype forwarded to the implementation
                (defaults to ``jnp.bfloat16``).
            softmax_dtype: Optional dtype forwarded for the softmax.
            dropout_prob: Dropout probability for attention weights.
            causal: Whether to request causal masking.
            sliding_window: Window size for local attention, either an int or
                a (left, right) tuple.
            cfg: Configuration object specifying platform/backend.

        Returns:
            Whatever the selected implementation returns; annotated as a tuple of
            the attention output [batch, seq_len, num_q_heads, vhead_dim] and the
            attention weights [batch, num_heads, seq_len, kv_len].
        """
        impl = self.get_impl(cfg)
        return impl(
            query=query,
            key=key,
            value=value,
            attention_mask=attention_mask,
            bias=bias,
            softmax_scale=softmax_scale,
            logits_soft_cap=logits_soft_cap,
            dropout_prob=dropout_prob,
            init_bias=init_bias,
            deterministic=deterministic,
            dropout_rng=dropout_rng,
            dtype=dtype,
            softmax_dtype=softmax_dtype,
            sliding_window=sliding_window,
            softmax_aux=softmax_aux,
            causal=causal,
        )

    def heuristic_cfg(self, inv: Invocation[AttentionConfig, Array]) -> AttentionConfig:
        """Provide the default configuration for an invocation.

        The returned values are fixed constants; ``inv`` is accepted for
        interface compatibility and is not inspected.

        Args:
            inv: Invocation object with arguments and metadata (unused).

        Returns:
            ``AttentionConfig`` with 128x128 blocks, 4 warps, 2 stages,
            ``platform="auto"`` and ``backend="any"``.
        """
        return AttentionConfig(
            block_q=128,
            block_k=128,
            num_warps=4,
            num_stages=2,
            platform="auto",
            backend="any",
        )

    def candidate_cfgs(self, inv: Invocation[AttentionConfig, Array]) -> list[AttentionConfig]:
        """Generate candidate configurations for autotuning.

        Always returns an empty list, so there is nothing for a tuner to
        benchmark. The stated rationale is that this operation relies on XLA
        rather than on kernels parameterized by block sizes.

        Args:
            inv: Invocation object with arguments and metadata (unused).

        Returns:
            Empty list - no candidates to autotune.
        """
        return []
# Module-level executor used by the `attention` convenience function below.
# Its configuration selection is composed from an in-memory ConfigCache, an
# AutotunePolicy, a Tuner (warmup=5, iters=100) and a PersistentCache named
# "attention".
# NOTE(review): autotuning is enabled here while `Attention.candidate_cfgs`
# returns an empty list, so lookups presumably resolve via the caches or the
# "heuristics" fallback — confirm against ConfigSelectorChain.
_executor: Executor[AttentionConfig, tuple[Array, Array]] = Executor(
    ConfigSelectorChain(
        cache=ConfigCache(),
        policy=AutotunePolicy(allow_autotune=True, cache_miss_fallback="heuristics", validate_backward=True),
        tuner=Tuner(warmup=5, iters=100),
        persistent=PersistentCache("attention"),
    )
)
def attention(
    query: Float[Array, "batch seq_len num_q_heads head_dim"],
    key: Float[Array, "batch kv_len num_kv_heads head_dim"],
    value: Float[Array, "batch kv_len num_kv_heads vhead_dim"],
    bias: Float[Array, "batch num_heads seq_len kv_len"] | None = None,
    dropout_rng: PRNGKeyArray | None = None,
    softmax_aux: Float[Array, "num_heads num_sinks"] | Float[Array, "num_sinks"] | None = None,
    /,
    *,
    mask_info: MaskInfo | None = None,
    init_bias: tp.Callable[[], Float[Array, "batch num_heads seq_len kv_len"]] | None = None,
    deterministic: bool = True,
    softmax_scale: float | None = None,
    logits_soft_cap: float | None = None,
    dtype: DTypeLike | None = jnp.bfloat16,
    softmax_dtype: DTypeLike | None = None,
    dropout_prob: float = 0.0,
    causal: bool = False,
    sliding_window: int | tuple[int, int] | None = None,
) -> Float[Array, "batch seq_len num_q_heads vhead_dim"]:
    """Execute standard attention through the module-level default executor.

    Convenience function that builds an ``Attention`` kernel and runs it with
    the shared ``_executor``. If ``mask_info`` is given, its attention mask is
    obtained via ``mask_info.get_or_compute_attention_mask()`` and forwarded as
    ``attention_mask``; otherwise no mask is passed.

    ``query``, ``key``, ``value``, ``bias``, ``dropout_rng`` and ``softmax_aux``
    are positional-only; all remaining options are keyword-only.

    Args:
        query: Query tensor [batch, seq_len, num_q_heads, head_dim].
        key: Key tensor [batch, kv_len, num_kv_heads, head_dim].
        value: Value tensor [batch, kv_len, num_kv_heads, vhead_dim].
        bias: Optional attention bias [batch, num_heads, seq_len, kv_len].
        dropout_rng: Optional PRNG key forwarded for dropout.
        softmax_aux: Optional auxiliary softmax tensor, shaped
            [num_heads, num_sinks] or [num_sinks].
        mask_info: Optional MaskInfo from which the attention mask is taken.
        init_bias: Optional zero-argument callable producing a bias tensor.
        deterministic: Forwarded flag (presumably disables dropout when True —
            confirm against the kernel implementation).
        softmax_scale: Scaling factor for attention scores
            (NOTE(review): originally documented as defaulting to
            1/sqrt(head_dim) when None; that default is applied, if at all,
            by the kernel implementation, not here).
        logits_soft_cap: Optional soft cap applied to attention logits.
        dtype: Computation dtype forwarded to the kernel
            (defaults to ``jnp.bfloat16``).
        softmax_dtype: Optional dtype forwarded for the softmax.
        dropout_prob: Dropout probability for attention weights.
        causal: Whether to request causal masking.
        sliding_window: Window size for local attention
            (int or (left, right) tuple).

    Returns:
        The executor's result for the ``Attention`` kernel, annotated as the
        attention output [batch, seq_len, num_q_heads, vhead_dim].
        NOTE(review): ``Attention.run`` is annotated as returning an
        (output, weights) tuple — confirm which form the executor yields.

    Example:
        >>> out = attention(query, key, value)
        >>>
        >>> out = attention(query, key, value, dropout_prob=0.1, softmax_scale=0.125)
        >>>
        >>> out = attention(query, key, value, causal=True, sliding_window=(256, 256))
    """
    attention_mask = mask_info.get_or_compute_attention_mask() if mask_info is not None else None
    return _executor(
        Attention(),
        query=query,
        key=key,
        value=value,
        attention_mask=attention_mask,
        bias=bias,
        softmax_scale=softmax_scale,
        logits_soft_cap=logits_soft_cap,
        dropout_prob=dropout_prob,
        init_bias=init_bias,
        deterministic=deterministic,
        dropout_rng=dropout_rng,
        dtype=dtype,
        softmax_dtype=softmax_dtype,
        sliding_window=sliding_window,
        softmax_aux=softmax_aux,
        causal=causal,
    )