ejkernel.modules.operations.multi_head_latent_attention
Multi-head Latent Attention (MLA) module with automatic optimization.
This module implements Multi-head Latent Attention, a memory-efficient attention variant that uses low-rank compression for key-value pairs. MLA reduces the KV cache size by projecting keys and values through a low-rank bottleneck while maintaining attention quality.
- The key innovation is compressing the KV representations:
  - Keys and values are jointly projected into a shared low-rank latent of size kv_lora_rank
  - Only the compressed latent is stored, which keeps the KV cache small
  - Full keys and values are reconstructed on the fly using learned weights
- This is particularly beneficial for:
  - Long-context inference, where the KV cache dominates memory
  - Multi-query or grouped-query attention patterns
  - Deployment scenarios with memory constraints
- class ejkernel.modules.operations.multi_head_latent_attention.FlashMLA
Bases: Kernel[FlashMLAConfig, Array]
Flash Multi-head Latent Attention with custom optimization logic.
Combines flash attention’s memory efficiency with MLA’s low-rank KV compression. This implementation uses tiling and on-the-fly decompression to achieve both reduced memory footprint and computational efficiency.
- Features:
  - Low-rank KV compression, decompressed with the w_kc and w_vc weight matrices
  - Optional RoPE bias for positional encoding (b_q, b_k)
  - Flash-attention-style tiling for memory efficiency
  - Support for causal masking and variable-length sequences (cu_seqlens)
  - Support for multiple platforms (Triton, Pallas, CUDA, XLA)
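- The compression scheme:
  - key_value: compressed KV tensor [batch, seq_len, kv_lora_rank]
  - w_kc, w_vc: decompression weights [kv_lora_rank, kv_heads, head_dim]
  - Keys are reconstructed by contracting key_value with w_kc over the kv_lora_rank axis (key = key_value @ w_kc, yielding [batch, seq_len, kv_heads, head_dim]); values are reconstructed the same way from w_vc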
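Because w_kc and w_vc carry a head axis, the reconstruction is a contraction over kv_lora_rank rather than a plain 2-D matrix product. A minimal NumPy sketch of that step, assuming the shapes listed above; decompress_kv is a hypothetical helper, not part of ejkernel, and a fused kernel would not materialize the full tensors this way:

```python
import numpy as np


def decompress_kv(key_value, w_kc, w_vc):
    """Hypothetical helper: rebuild per-head keys/values from the latent.

    key_value:  [batch, seq_len, kv_lora_rank]
    w_kc, w_vc: [kv_lora_rank, kv_heads, head_dim]
    returns:    keys, values, each [batch, seq_len, kv_heads, head_dim]
    """
    # Contract over the shared kv_lora_rank axis (r).
    keys = np.einsum("bsr,rhd->bshd", key_value, w_kc)
    values = np.einsum("bsr,rhd->bshd", key_value, w_vc)
    return keys, values


batch, seq_len, rank, kv_heads, head_dim = 2, 5, 16, 4, 8
rng = np.random.default_rng(0)
kv = rng.standard_normal((batch, seq_len, rank))
w_kc = rng.standard_normal((rank, kv_heads, head_dim))
w_vc = rng.standard_normal((rank, kv_heads, head_dim))
k, v = decompress_kv(kv, w_kc, w_vc)
print(k.shape)  # (2, 5, 4, 8)
```

The einsum is equivalent to flattening w_kc to [kv_lora_rank, kv_heads * head_dim], applying an ordinary matmul, and reshaping the result back into heads.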
- candidate_cfgs(inv: Invocation[FlashMLAConfig, Array])
Generate candidate configurations for autotuning.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of candidate configurations to benchmark during autotuning
Note
MLA performance depends on the compression rank and decompression overhead. Candidates balance memory efficiency with compute cost.
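Autotuning of this kind generally means enumerating candidates, timing each one, and keeping the fastest. The sketch below illustrates that loop only; ToyConfig, its block_q/block_k fields, toy_candidate_cfgs, and toy_autotune are all hypothetical stand-ins, not FlashMLAConfig's real fields or ejkernel's tuner:

```python
import time
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyConfig:
    # Hypothetical tiling knobs; the real FlashMLAConfig fields may differ.
    block_q: int
    block_k: int


def toy_candidate_cfgs(seq_len):
    # Hypothetical search space: power-of-two tiles no larger than seq_len.
    sizes = [s for s in (16, 32, 64, 128) if s <= seq_len]
    return [ToyConfig(bq, bk) for bq in sizes for bk in sizes]


def _time_once(run, cfg):
    # Wall-clock time of a single call with the given config.
    start = time.perf_counter()
    run(cfg)
    return time.perf_counter() - start


def toy_autotune(run, candidates, repeats=3):
    # Benchmark every candidate and keep the fastest (best-of-`repeats` timing).
    best_cfg, best_time = None, float("inf")
    for cfg in candidates:
        elapsed = min(_time_once(run, cfg) for _ in range(repeats))
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    return best_cfg


# Stand-in workload; a real tuner would run the compiled kernel with cfg.
best = toy_autotune(lambda cfg: sum(range(cfg.block_q * cfg.block_k)), toy_candidate_cfgs(64))
```

Taking the minimum over repeats reduces noise from one-off effects such as compilation or cache warm-up on the first call.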
- get_impl(cfg: FlashMLAConfig)
Get kernel implementation from registry.
- Parameters
cfg – Configuration specifying platform and backend
- Returns
Callable kernel implementation for flash MLA
- Raises
ValueError – If no matching implementation is found
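A registry lookup of this shape can be sketched as a dictionary keyed by platform and backend, with a missing key surfaced as ValueError. The names below (_REGISTRY, register, toy_get_impl) are hypothetical and only illustrate the documented behavior, not ejkernel's actual registry:

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry keyed by (platform, backend); not ejkernel's actual registry.
_REGISTRY: Dict[Tuple[str, str], Callable] = {}


def register(platform: str, backend: str):
    """Decorator that records a kernel implementation under (platform, backend)."""
    def decorator(fn: Callable) -> Callable:
        _REGISTRY[(platform, backend)] = fn
        return fn
    return decorator


def toy_get_impl(platform: str, backend: str) -> Callable:
    """Look up an implementation, raising ValueError when none matches."""
    try:
        return _REGISTRY[(platform, backend)]
    except KeyError:
        raise ValueError(
            f"no flash MLA implementation for platform={platform!r}, backend={backend!r}"
        ) from None


@register("xla", "any")
def _xla_flash_mla(*args, **kwargs):
    # Placeholder body standing in for a real kernel.
    return "xla"
```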
- heuristic_cfg(inv: Invocation[FlashMLAConfig, Array]) → FlashMLAConfig
Provide default configuration with block sizes.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
Default configuration optimized for MLA’s low-rank decompression and on-the-fly reconstruction requirements
- run(query: Float[jaxlib._jax.Array, 'batch seq_len q_heads head_dim'], key_value: Float[jaxlib._jax.Array, 'batch seq_len kv_lora_rank'], w_kc: Float[jaxlib._jax.Array, 'kv_lora_rank kv_heads head_dim'], w_vc: Float[jaxlib._jax.Array, 'kv_lora_rank kv_heads head_dim'], b_q: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len qk_rope_head_dim'] | None = None, b_k: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len qk_rope_head_dim'] | None = None, softmax_scale: float | None = None, causal: bool = False, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: FlashMLAConfig) → Float[jaxlib._jax.Array, 'batch seq_len q_heads head_dim']
Execute flash multi-head latent attention.
- Parameters
query – Query tensor [batch, seq_len, q_heads, head_dim]
key_value – Compressed key-value tensor [batch, seq_len, kv_lora_rank]
w_kc – Key decompression weights [kv_lora_rank, kv_heads, head_dim]
w_vc – Value decompression weights [kv_lora_rank, kv_heads, head_dim]
b_q – Optional query RoPE bias [batch, seq_len, qk_rope_head_dim]
b_k – Optional key RoPE bias [batch, seq_len, qk_rope_head_dim]
softmax_scale – Optional scaling factor for attention scores
causal – Whether to apply causal masking (default: False)
cu_seqlens – Cumulative sequence lengths for variable-length sequences
platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Kernel configuration object
- Returns
Attention output [batch, seq_len, q_heads, head_dim]
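For reference, the computation can be written without tiling. The NumPy sketch below is an unfused equivalent under stated assumptions: it omits b_q, b_k and cu_seqlens, assumes softmax_scale defaults to 1/sqrt(head_dim), and assumes query heads share KV heads GQA-style when q_heads is a multiple of kv_heads. mla_reference is a hypothetical name, and the real kernel's numerics (tiling, online softmax, precision) will differ:

```python
import numpy as np


def mla_reference(query, key_value, w_kc, w_vc, softmax_scale=None, causal=False):
    """Hypothetical unfused MLA forward (no tiling, no RoPE bias, no cu_seqlens).

    query:      [batch, seq_len, q_heads, head_dim]
    key_value:  [batch, seq_len, kv_lora_rank]
    w_kc, w_vc: [kv_lora_rank, kv_heads, head_dim]
    returns:    [batch, seq_len, q_heads, head_dim]
    """
    _, s, q_heads, d = query.shape
    kv_heads = w_kc.shape[1]
    # Decompress the latent into per-head keys and values.
    k = np.einsum("bsr,rhd->bshd", key_value, w_kc)
    v = np.einsum("bsr,rhd->bshd", key_value, w_vc)
    # Assumption: each KV head serves q_heads // kv_heads query heads (GQA-style).
    group = q_heads // kv_heads
    k = np.repeat(k, group, axis=2)
    v = np.repeat(v, group, axis=2)
    # Assumption: the default scale is 1/sqrt(head_dim).
    scale = softmax_scale if softmax_scale is not None else 1.0 / np.sqrt(d)
    scores = np.einsum("bqhd,bkhd->bhqk", query, k) * scale
    if causal:
        # Query position i may attend only to key positions <= i.
        mask = np.tril(np.ones((s, s), dtype=bool))
        scores = np.where(mask, scores, -np.inf)
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", probs, v)
```

A flash-style kernel would typically compute the same result tile by tile with an online softmax, so the full [batch, heads, seq_len, seq_len] score tensor built here never needs to exist in memory at once.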
Note
The kv_lora_rank determines the compression ratio. Lower ranks save more memory but may reduce quality. Typical values: 64-256.
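To make the compression ratio concrete, the per-token, per-layer cache footprint can be compared directly. A small sketch, assuming standard attention caches one key and one value of head_dim per KV head, and that MLA caches only the kv_lora_rank latent plus, optionally, the qk_rope_head_dim key bias (the rope_dim argument). kv_cache_floats_per_token is a hypothetical helper, and the example sizes are illustrative, not ejkernel defaults:

```python
def kv_cache_floats_per_token(kv_heads, head_dim, kv_lora_rank, rope_dim=0):
    """Hypothetical helper: cached values per token, per layer."""
    # Standard attention stores a full key and a full value for every KV head.
    standard = 2 * kv_heads * head_dim
    # MLA stores only the shared latent (plus an optional RoPE key component).
    mla = kv_lora_rank + rope_dim
    return standard, mla


standard, mla = kv_cache_floats_per_token(kv_heads=32, head_dim=128, kv_lora_rank=256)
print(standard, mla, standard / mla)  # 8192 256 32.0
```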
- ejkernel.modules.operations.multi_head_latent_attention.mla_attention(query: Float[jaxlib._jax.Array, 'batch seq_len q_heads head_dim'], key_value: Float[jaxlib._jax.Array, 'batch seq_len kv_lora_rank'], w_kc: Float[jaxlib._jax.Array, 'kv_lora_rank kv_heads head_dim'], w_vc: Float[jaxlib._jax.Array, 'kv_lora_rank kv_heads head_dim'], b_q: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len qk_rope_head_dim'] | None = None, b_k: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len qk_rope_head_dim'] | None = None, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, /, *, softmax_scale: float | None = None, causal: bool = False, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.FlashMLAConfig | None = None) → Float[jaxlib._jax.Array, 'batch seq_len q_heads head_dim']
Execute flash multi-head latent attention with automatic optimization.
MLA uses low-rank compression for key-value pairs to reduce memory and computation while maintaining attention quality.
- Parameters
query – Query tensor [batch, seq_len, q_heads, head_dim]
key_value – Compressed key-value tensor [batch, seq_len, kv_lora_rank]
w_kc – Key decompression weights [kv_lora_rank, kv_heads, head_dim]
w_vc – Value decompression weights [kv_lora_rank, kv_heads, head_dim]
b_q – Optional query RoPE bias [batch, seq_len, qk_rope_head_dim] (positional-only)
b_k – Optional key RoPE bias [batch, seq_len, qk_rope_head_dim] (positional-only)
softmax_scale – Optional scaling factor for attention scores
causal – Whether to apply causal masking (default: False)
cu_seqlens – Optional cumulative sequence lengths for variable-length sequences (positional-only)
platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Optional kernel configuration override
- Returns
Attention output with same shape as query
Example
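>>> # Basic call
>>> out = mla_attention(query, key_value, w_kc, w_vc)
>>> # With causal masking
>>> out = mla_attention(query, key_value, w_kc, w_vc, causal=True)
>>> # With RoPE biases (b_q and b_k are positional-only, so pass them by position)
>>> out = mla_attention(query, key_value, w_kc, w_vc, q_rope, k_rope)
>>> # Force a specific platform
>>> out = mla_attention(query, key_value, w_kc, w_vc, platform="triton")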
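The num_seqs_plus_one shape of cu_seqlens implies cumulative offsets: entry i is where sequence i starts in the packed token axis, and the last entry is the total token count. A NumPy sketch for building it, assuming int32 offsets as in the common flash-attention varlen convention; make_cu_seqlens is a hypothetical helper, and how mla_attention expects the packed tokens to be laid out along batch and seq_len is not specified here:

```python
import numpy as np


def make_cu_seqlens(lengths):
    """Hypothetical helper: [0, l0, l0 + l1, ...] with num_seqs + 1 entries."""
    lengths = np.asarray(lengths, dtype=np.int32)
    return np.concatenate([np.zeros(1, dtype=np.int32), np.cumsum(lengths, dtype=np.int32)])


cu = make_cu_seqlens([3, 5, 2])
print(cu.tolist())  # [0, 3, 8, 10]

# Sequence i occupies packed token positions cu[i]:cu[i + 1].
tokens = np.arange(cu[-1])
segments = [tokens[cu[i]:cu[i + 1]] for i in range(len(cu) - 1)]
```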