Source code for ejkernel.modules.operations.multi_head_latent_attention

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Multi-head Latent Attention (MLA) module with automatic optimization.

This module implements Multi-head Latent Attention, a memory-efficient attention
variant that uses low-rank compression for key-value pairs. MLA reduces the KV cache
size by projecting keys and values through a low-rank bottleneck while maintaining
attention quality.

The key idea is to compress the KV representations:
    1. Keys and values share a single low-rank latent tensor of width kv_lora_rank
    2. Only this compressed latent needs to be stored (e.g., in the KV cache)
    3. Full-rank keys/values are reconstructed on the fly using learned
       decompression weights (w_kc for keys, w_vc for values)

This is particularly beneficial for:
    - Long context inference where KV cache dominates memory
    - Multi-query or grouped-query attention patterns
    - Deployment scenarios with memory constraints
"""

from __future__ import annotations

import os
from typing import Literal

from jaxtyping import Array, Float, Int

from ejkernel.kernels._registry import Backend, kernel_registry
from ejkernel.ops import (
    AutotunePolicy,
    ConfigCache,
    ConfigSelectorChain,
    Executor,
    Invocation,
    Kernel,
    Tuner,
)
from ejkernel.ops.config.persistent import PersistentCache

from ..base import detect_platform
from .configs import FlashMLAConfig


class FlashMLA(Kernel[FlashMLAConfig, Array]):
    """Flash Multi-head Latent Attention with custom optimization logic.

    Combines flash attention's memory efficiency with MLA's low-rank KV
    compression. This implementation uses tiling and on-the-fly decompression
    to achieve both reduced memory footprint and computational efficiency.

    Features:
        - Low-rank KV compression via w_kc and w_vc weight matrices
        - Optional RoPE bias for positional encoding (b_q, b_k)
        - Flash attention-style tiling for memory efficiency
        - Support for causal masking and variable-length sequences
        - Multiple platform support (Triton/Pallas/CUDA/XLA)

    The compression scheme:
        - key_value: Compressed KV tensor [batch, seq_len, kv_lora_rank]
        - w_kc, w_vc: Decompression weights [kv_lora_rank, kv_heads, head_dim]
        - Keys/values are reconstructed by contracting key_value with
          w_kc / w_vc over the shared kv_lora_rank axis
          (conceptually ``key = key_value @ w_kc``, ``value = key_value @ w_vc``)
    """

    def __init__(self):
        """Initialize the Flash MLA kernel.

        Registers the operation identifier ``"flash_mla"``, which is used for
        registry lookup and configuration management.
        """
        super().__init__(op_id="flash_mla")

    def get_impl(self, cfg: FlashMLAConfig):
        """Get the kernel implementation from the registry.

        Args:
            cfg: Configuration specifying platform and backend.

        Returns:
            Callable kernel implementation for flash MLA.

        Raises:
            ValueError: If no matching implementation is found.
        """
        # Resolve "auto"/explicit platform requests to a concrete platform first,
        # then look up the implementation registered for that platform/backend.
        platform = detect_platform("flash_mla", cfg.platform)
        return kernel_registry.get("flash_mla", platform=platform, backend=cfg.backend)

    def run(
        self,
        query: Float[Array, "batch seq_len q_heads head_dim"],
        key_value: Float[Array, "batch seq_len kv_lora_rank"],
        w_kc: Float[Array, "kv_lora_rank kv_heads head_dim"],
        w_vc: Float[Array, "kv_lora_rank kv_heads head_dim"],
        b_q: Float[Array, "batch seq_len qk_rope_head_dim"] | None = None,
        b_k: Float[Array, "batch seq_len qk_rope_head_dim"] | None = None,
        softmax_scale: float | None = None,
        causal: bool = False,
        cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
        platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
        *,
        cfg: FlashMLAConfig,
    ) -> Float[Array, "batch seq_len q_heads head_dim"]:
        """Execute flash multi-head latent attention.

        Args:
            query: Query tensor [batch, seq_len, q_heads, head_dim]
            key_value: Compressed key-value tensor [batch, seq_len, kv_lora_rank]
            w_kc: Key decompression weights [kv_lora_rank, kv_heads, head_dim]
            w_vc: Value decompression weights [kv_lora_rank, kv_heads, head_dim]
            b_q: Optional query RoPE bias [batch, seq_len, qk_rope_head_dim]
            b_k: Optional key RoPE bias [batch, seq_len, qk_rope_head_dim]
            softmax_scale: Optional scaling factor for attention scores
            causal: Whether to apply causal masking (default: False)
            cu_seqlens: Cumulative sequence lengths for variable-length sequences
            platform: Optional platform override ("triton", "pallas", "cuda",
                "xla", "auto"). When given, it replaces ``cfg.platform``.
            cfg: Kernel configuration object

        Returns:
            Attention output [batch, seq_len, q_heads, head_dim]

        Note:
            The kv_lora_rank determines the compression ratio. Lower ranks save
            more memory but may reduce quality. Typical values: 64-256.
        """
        if platform is not None:
            # An explicit platform argument wins over the tuned/cached config:
            # rebuild the config with the same tiling parameters but the
            # requested platform. XLA is always paired with Backend.ANY
            # (presumably the XLA path is registered backend-agnostically — confirm).
            cfg = FlashMLAConfig(
                block_q=cfg.block_q,
                block_k=cfg.block_k,
                num_warps=cfg.num_warps,
                num_stages=cfg.num_stages,
                platform=platform,
                backend=Backend.ANY if platform == "xla" else cfg.backend,
            )
        impl = self.get_impl(cfg)
        return impl(
            query=query,
            key_value=key_value,
            w_kc=w_kc,
            w_vc=w_vc,
            b_q=b_q,
            b_k=b_k,
            softmax_scale=softmax_scale,
            causal=causal,
            cu_seqlens=cu_seqlens,
        )

    def heuristic_cfg(self, inv: Invocation[FlashMLAConfig, Array]) -> FlashMLAConfig:
        """Provide the default configuration used without autotuning.

        Args:
            inv: Invocation object containing arguments and metadata (unused;
                the default does not depend on the inputs).

        Returns:
            A fixed configuration: 128x128 blocks, 4 warps, 2 pipeline stages,
            automatic platform selection and any backend.
        """
        return FlashMLAConfig(
            block_q=128,
            block_k=128,
            num_warps=4,
            num_stages=2,
            platform="auto",
            backend="any",
        )

    def candidate_cfgs(self, inv: Invocation[FlashMLAConfig, Array]):
        """Generate candidate configurations for autotuning.

        Args:
            inv: Invocation object containing arguments and metadata (unused;
                the candidate set is fixed).

        Returns:
            List of candidate configurations to benchmark during autotuning.

        Note:
            MLA performance depends on the compression rank and decompression
            overhead. Candidates balance memory efficiency with compute cost.
        """
        # (block_q, block_k, num_warps, num_stages) tuples to benchmark.
        block_configs = (
            (64, 64, 4, 1),
            (128, 128, 4, 2),
            (256, 256, 8, 2),
        )
        return [
            FlashMLAConfig(
                block_q=block_q,
                block_k=block_k,
                num_warps=num_warps,
                num_stages=num_stages,
                platform="auto",
                backend="any",
            )
            for block_q, block_k, num_warps, num_stages in block_configs
        ]
def _build_mla_executor() -> Executor[FlashMLAConfig, Array]:
    """Construct the module-level executor used by ``mla_attention``.

    The executor's config selector chains an in-memory ``ConfigCache``, an
    ``AutotunePolicy`` (autotuning allowed, backward pass validated), a
    ``Tuner`` with 5 warmup and 100 measured iterations, and a
    ``PersistentCache`` keyed ``"mla"``.
    """
    in_memory_cache = ConfigCache()
    # What to do on a cache miss is overridable through the environment;
    # without the variable set, the policy falls back to "autotune".
    miss_fallback = os.getenv("EJKERNEL_AUTOTUNE_POLICY", "autotune")
    autotune_policy = AutotunePolicy(
        allow_autotune=True,
        cache_miss_fallback=miss_fallback,
        validate_backward=True,
    )
    tuner = Tuner(warmup=5, iters=100)
    persistent_cache = PersistentCache("mla")
    selector = ConfigSelectorChain(
        cache=in_memory_cache,
        policy=autotune_policy,
        tuner=tuner,
        persistent=persistent_cache,
    )
    return Executor(selector)


_mla_executor: Executor[FlashMLAConfig, Array] = _build_mla_executor()
def mla_attention(
    query: Float[Array, "batch seq_len q_heads head_dim"],
    key_value: Float[Array, "batch seq_len kv_lora_rank"],
    w_kc: Float[Array, "kv_lora_rank kv_heads head_dim"],
    w_vc: Float[Array, "kv_lora_rank kv_heads head_dim"],
    /,
    # The optional tensors below may be passed positionally (as before) or by
    # keyword (as the examples in the docstring do). Previously they sat before
    # the "/" marker, which made keyword calls like ``b_q=...`` a TypeError.
    b_q: Float[Array, "batch seq_len qk_rope_head_dim"] | None = None,
    b_k: Float[Array, "batch seq_len qk_rope_head_dim"] | None = None,
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
    *,
    softmax_scale: float | None = None,
    causal: bool = False,
    platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
    cfg: FlashMLAConfig | None = None,
) -> Float[Array, "batch seq_len q_heads head_dim"]:
    """Execute flash multi-head latent attention with automatic optimization.

    MLA uses low-rank compression for key-value pairs to reduce memory and
    computation while maintaining attention quality. Configuration selection
    (cache lookup / autotuning) is delegated to the module-level executor.

    Args:
        query: Query tensor [batch, seq_len, q_heads, head_dim] (positional-only)
        key_value: Compressed key-value tensor [batch, seq_len, kv_lora_rank]
            (positional-only)
        w_kc: Key decompression weights [kv_lora_rank, kv_heads, head_dim]
            (positional-only)
        w_vc: Value decompression weights [kv_lora_rank, kv_heads, head_dim]
            (positional-only)
        b_q: Optional query RoPE bias [batch, seq_len, qk_rope_head_dim]
        b_k: Optional key RoPE bias [batch, seq_len, qk_rope_head_dim]
        cu_seqlens: Optional cumulative sequence lengths for variable-length
            sequences
        softmax_scale: Scaling factor for attention scores (keyword-only)
        causal: Whether to apply causal masking (keyword-only)
        platform: Specific platform to use ("triton", "pallas", "cuda", "xla",
            or "auto") (keyword-only)
        cfg: Optional kernel configuration override (keyword-only)

    Returns:
        Attention output with same shape as query

    Example:
        >>> # Basic usage
        >>> out = mla_attention(query, key_value, w_kc, w_vc)
        >>> # With causal masking
        >>> out = mla_attention(query, key_value, w_kc, w_vc, causal=True)
        >>> # With RoPE biases
        >>> out = mla_attention(query, key_value, w_kc, w_vc, b_q=q_rope, b_k=k_rope)
        >>> # Forcing a specific platform
        >>> out = mla_attention(..., platform="triton")
    """
    return _mla_executor(
        FlashMLA(),
        query=query,
        key_value=key_value,
        w_kc=w_kc,
        w_vc=w_vc,
        b_q=b_q,
        b_k=b_k,
        softmax_scale=softmax_scale,
        causal=causal,
        cu_seqlens=cu_seqlens,
        platform=platform,
        # The explicit config override is handed to the executor as ``_cfg``
        # (presumably kept separate from the kernel's own ``cfg`` keyword — confirm).
        _cfg=cfg,
    )