ejkernel.modules.operations.multi_head_latent_attention


Multi-head Latent Attention (MLA) module with automatic optimization.

This module implements Multi-head Latent Attention, a memory-efficient attention variant that uses low-rank compression for key-value pairs. MLA reduces the KV cache size by projecting keys and values through a low-rank bottleneck while maintaining attention quality.

The key innovation is compressing the KV representations:
  1. Keys and values are projected to a low-rank space (kv_lora_rank)

  2. Compressed representations are stored efficiently

  3. Full-rank keys/values are reconstructed on-the-fly using learned weights

This is particularly beneficial for:
  • Long context inference where KV cache dominates memory

  • Multi-query or grouped-query attention patterns

  • Deployment scenarios with memory constraints
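To make the memory argument concrete, here is a rough per-layer estimate of cache size with and without the latent bottleneck. The sizes below are illustrative assumptions, not taken from any particular model, and `kv_cache_bytes` is a hypothetical helper. The estimate also ignores any separately cached RoPE key component.

```python
def kv_cache_bytes(seq_len, per_token_elems, bytes_per_elem=2, batch=1):
    """Bytes needed to cache `per_token_elems` values for every token."""
    return batch * seq_len * per_token_elems * bytes_per_elem


# Assumed, illustrative sizes (16-bit elements, one layer, batch of 1).
seq_len, kv_heads, head_dim, kv_lora_rank = 32_768, 8, 128, 256

# Standard cache: full-rank keys and values for every KV head.
full = kv_cache_bytes(seq_len, 2 * kv_heads * head_dim)
# Latent cache: one compressed vector of size kv_lora_rank per token.
latent = kv_cache_bytes(seq_len, kv_lora_rank)

print(f"full-rank: {full / 2**20:.0f} MiB")   # full-rank: 128 MiB
print(f"latent:    {latent / 2**20:.0f} MiB") # latent:    16 MiB
print(f"ratio:     {full // latent}x")        # ratio:     8x
```

The ratio depends only on `2 * kv_heads * head_dim / kv_lora_rank`, so the saving grows with the number of KV heads and shrinks as the rank increases.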

class ejkernel.modules.operations.multi_head_latent_attention.FlashMLA

Bases: Kernel[FlashMLAConfig, Array]

Flash Multi-head Latent Attention with custom optimization logic.

Combines flash attention’s memory efficiency with MLA’s low-rank KV compression. This implementation uses tiling and on-the-fly decompression to achieve both reduced memory footprint and computational efficiency.
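The combination of tiling and on-the-fly decompression can be sketched in plain JAX. This is only an illustrative reference, not the kernel that ejkernel dispatches to: `tiled_mla_reference` and `block_size` are hypothetical names, the RoPE biases, causal masking, and `cu_seqlens` are omitted, and the sketch assumes `q_heads == kv_heads` and a scale of `head_dim ** -0.5`.

```python
import jax
import jax.numpy as jnp


def tiled_mla_reference(query, key_value, w_kc, w_vc, block_size=4):
    """Non-causal attention that decompresses one KV tile at a time.

    Assumption: q_heads == kv_heads. Only one tile of full-rank keys and
    values exists at any moment; the compressed tensor is the only
    sequence-length KV storage.
    """
    batch, seq_len, heads, head_dim = query.shape
    scale = head_dim ** -0.5

    # Online-softmax state per (batch, head, query position).
    acc = jnp.zeros((batch, heads, seq_len, head_dim), dtype=jnp.float32)
    row_max = jnp.full((batch, heads, seq_len), -jnp.inf, dtype=jnp.float32)
    row_sum = jnp.zeros((batch, heads, seq_len), dtype=jnp.float32)

    for start in range(0, seq_len, block_size):
        tile = key_value[:, start:start + block_size]        # [b, blk, rank]
        k = jnp.einsum("bsr,rhd->bshd", tile, w_kc)           # decompress keys
        v = jnp.einsum("bsr,rhd->bshd", tile, w_vc)           # decompress values
        scores = jnp.einsum("bqhd,bkhd->bhqk", query, k) * scale

        # Rescale the running statistics to the new row-wise maximum.
        new_max = jnp.maximum(row_max, scores.max(axis=-1))
        correction = jnp.exp(row_max - new_max)
        probs = jnp.exp(scores - new_max[..., None])
        row_sum = row_sum * correction + probs.sum(axis=-1)
        acc = acc * correction[..., None] + jnp.einsum("bhqk,bkhd->bhqd", probs, v)
        row_max = new_max

    out = acc / row_sum[..., None]
    return out.transpose(0, 2, 1, 3)                          # [b, seq, heads, dim]


kq, kc, kk, kv = jax.random.split(jax.random.PRNGKey(0), 4)
query = jax.random.normal(kq, (2, 8, 4, 16))
key_value = jax.random.normal(kc, (2, 8, 32))
w_kc = jax.random.normal(kk, (32, 4, 16)) * 0.1
w_vc = jax.random.normal(kv, (32, 4, 16)) * 0.1

out = tiled_mla_reference(query, key_value, w_kc, w_vc)
print(out.shape)  # (2, 8, 4, 16)
```

The result matches a single softmax over the full score matrix up to floating-point error, while the full-rank keys and values for the whole sequence are never materialized at once.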

Features:
  • Low-rank KV compression via w_kc and w_vc weight matrices

  • Optional RoPE bias for positional encoding (b_q, b_k)

  • Flash attention-style tiling for memory efficiency

  • Support for causal masking and variable-length sequences

  • Multiple platform support (Triton/Pallas/CUDA/XLA)

The compression scheme:
  • key_value: Compressed KV tensor [batch, seq_len, kv_lora_rank]

  • w_kc, w_vc: Decompression weights [kv_lora_rank, kv_heads, head_dim]

  • Keys and values are reconstructed by contracting over kv_lora_rank: key = key_value @ w_kc, value = key_value @ w_vc
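With the shapes listed above, the reconstruction is a contraction over the `kv_lora_rank` axis. A minimal sketch in JAX, using small made-up sizes and random data purely for illustration:

```python
import jax
import jax.numpy as jnp

# Illustrative sizes only.
batch, seq_len, kv_lora_rank, kv_heads, head_dim = 2, 16, 64, 4, 32
k_c, k_wk, k_wv = jax.random.split(jax.random.PRNGKey(0), 3)

key_value = jax.random.normal(k_c, (batch, seq_len, kv_lora_rank))
w_kc = jax.random.normal(k_wk, (kv_lora_rank, kv_heads, head_dim))
w_vc = jax.random.normal(k_wv, (kv_lora_rank, kv_heads, head_dim))

# Contract over the shared rank axis; each head uses its own weight slice.
key = jnp.einsum("bsr,rhd->bshd", key_value, w_kc)
value = jnp.einsum("bsr,rhd->bshd", key_value, w_vc)

print(key.shape, value.shape)  # (2, 16, 4, 32) (2, 16, 4, 32)
```

For a single head `h`, this is the plain matrix product `key_value @ w_kc[:, h]`, which is what the `@` shorthand in the list above denotes.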

candidate_cfgs(inv: Invocation[FlashMLAConfig, Array])

Generate candidate configurations for autotuning.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of candidate configurations to benchmark during autotuning

Note

MLA performance depends on the compression rank and decompression overhead. Candidates balance memory efficiency with compute cost.

get_impl(cfg: FlashMLAConfig)

Get kernel implementation from registry.

Parameters

cfg – Configuration specifying platform and backend

Returns

Callable kernel implementation for flash MLA

Raises

ValueError – If no matching implementation is found

heuristic_cfg(inv: Invocation[FlashMLAConfig, Array]) -> FlashMLAConfig

Provide default configuration with block sizes.

Parameters

inv – Invocation object containing arguments and metadata

Returns

Default configuration optimized for MLA’s low-rank decompression and on-the-fly reconstruction requirements

run(query: Float[jaxlib._jax.Array, 'batch seq_len q_heads head_dim'], key_value: Float[jaxlib._jax.Array, 'batch seq_len kv_lora_rank'], w_kc: Float[jaxlib._jax.Array, 'kv_lora_rank kv_heads head_dim'], w_vc: Float[jaxlib._jax.Array, 'kv_lora_rank kv_heads head_dim'], b_q: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len qk_rope_head_dim'] | None = None, b_k: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len qk_rope_head_dim'] | None = None, softmax_scale: float | None = None, causal: bool = False, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: FlashMLAConfig) -> Float[jaxlib._jax.Array, 'batch seq_len q_heads head_dim']

Execute flash multi-head latent attention.

Parameters
  • query – Query tensor [batch, seq_len, q_heads, head_dim]

  • key_value – Compressed key-value tensor [batch, seq_len, kv_lora_rank]

  • w_kc – Key decompression weights [kv_lora_rank, kv_heads, head_dim]

  • w_vc – Value decompression weights [kv_lora_rank, kv_heads, head_dim]

  • b_q – Optional query RoPE bias [batch, seq_len, qk_rope_head_dim]

  • b_k – Optional key RoPE bias [batch, seq_len, qk_rope_head_dim]

  • softmax_scale – Optional scaling factor for attention scores

  • causal – Whether to apply causal masking (default: False)

  • cu_seqlens – Cumulative sequence lengths for variable-length sequences

  • platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Kernel configuration object

Returns

Attention output [batch, seq_len, q_heads, head_dim]

Note

The kv_lora_rank determines the compression ratio. Lower ranks save more memory but may reduce quality. Typical values: 64-256.
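The `cu_seqlens` argument has shape `[num_seqs_plus_one]`. The sketch below shows how such an array is commonly built from per-sequence lengths, assuming the usual packed variable-length convention (a leading zero followed by a running sum). The source does not state how the kernel interprets it, so treat the convention and the derived mask as assumptions; `seq_lens`, `segment_ids`, and `same_sequence` are hypothetical names.

```python
import jax.numpy as jnp

# Three packed sequences of lengths 3, 5, and 2 (made-up values).
seq_lens = jnp.array([3, 5, 2], dtype=jnp.int32)

# Leading zero plus running sum: one more entry than there are sequences.
cu_seqlens = jnp.concatenate(
    [jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(seq_lens, dtype=jnp.int32)]
)
print(cu_seqlens)  # [ 0  3  8 10]

# Recover which sequence each packed position belongs to.
positions = jnp.arange(int(cu_seqlens[-1]))
segment_ids = jnp.searchsorted(cu_seqlens, positions, side="right") - 1
print(segment_ids)  # [0 0 0 1 1 1 1 1 2 2]

# Under this convention, a position may only attend within its own sequence.
same_sequence = segment_ids[:, None] == segment_ids[None, :]
print(same_sequence.shape)  # (10, 10)
```

Sequence `i` occupies positions `cu_seqlens[i]` up to, but not including, `cu_seqlens[i + 1]`.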

ejkernel.modules.operations.multi_head_latent_attention.mla_attention(query: Float[jaxlib._jax.Array, 'batch seq_len q_heads head_dim'], key_value: Float[jaxlib._jax.Array, 'batch seq_len kv_lora_rank'], w_kc: Float[jaxlib._jax.Array, 'kv_lora_rank kv_heads head_dim'], w_vc: Float[jaxlib._jax.Array, 'kv_lora_rank kv_heads head_dim'], b_q: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len qk_rope_head_dim'] | None = None, b_k: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len qk_rope_head_dim'] | None = None, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, /, *, softmax_scale: float | None = None, causal: bool = False, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.FlashMLAConfig | None = None) -> Float[jaxlib._jax.Array, 'batch seq_len q_heads head_dim']

Execute flash multi-head latent attention with automatic optimization.

MLA uses low-rank compression for key-value pairs to reduce memory and computation while maintaining attention quality.

Parameters
  • query – Query tensor [batch, seq_len, q_heads, head_dim]

  • key_value – Compressed key-value tensor [batch, seq_len, kv_lora_rank]

  • w_kc – Key decompression weights [kv_lora_rank, kv_heads, head_dim]

  • w_vc – Value decompression weights [kv_lora_rank, kv_heads, head_dim]

  • b_q – Query RoPE bias [batch, seq_len, qk_rope_head_dim]

  • b_k – Key RoPE bias [batch, seq_len, qk_rope_head_dim]

  • softmax_scale – Scaling factor for attention scores

  • causal – Whether to apply causal masking

  • cu_seqlens – Cumulative sequence lengths for variable-length sequences

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Optional kernel configuration override

Returns

Attention output with same shape as query

Example

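>>> # Basic usage
>>> out = mla_attention(query, key_value, w_kc, w_vc)
>>>
>>> # With causal masking
>>> out = mla_attention(query, key_value, w_kc, w_vc, causal=True)
>>>
>>> # With RoPE biases (b_q and b_k are positional-only in the signature, so pass them by position)
>>> out = mla_attention(query, key_value, w_kc, w_vc, q_rope, k_rope)
>>>
>>> # With an explicit platform
>>> out = mla_attention(..., platform="triton")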