ejkernel.kernels._triton.flash_mla._interface


Interface for Flash Multi-Latent Attention (MLA) operations.

ejkernel.kernels._triton.flash_mla._interface.flash_mla_attention(query: Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim'], key: Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim'], value: Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim'], latent_key: Float[jaxlib._jax.Array, 'head_dim latent_dim'], latent_value: Float[jaxlib._jax.Array, 'head_dim latent_dim'], bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len seq_len'] | None = None, causal: bool = False, softmax_scale: float | None = None) → Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim']

Multi-Latent Attention with automatic differentiation support.

This function wraps flash_mla_attention_call with JAX’s custom gradient support for efficient backpropagation through the attention operation.

Parameters
  • query – Query tensor of shape (batch, heads, seq_len, head_dim).

  • key – Key tensor of shape (batch, heads, seq_len, head_dim).

  • value – Value tensor of shape (batch, heads, seq_len, head_dim).

  • latent_key – Latent key projection matrix of shape (head_dim, latent_dim).

  • latent_value – Latent value projection matrix of shape (head_dim, latent_dim).

  • bias – Optional attention bias of shape (batch, heads, seq_len, seq_len).

  • causal – Whether to apply causal masking.

  • softmax_scale – Scale factor applied to the attention logits before the softmax. Defaults to 1/sqrt(head_dim).

Returns

Output tensor of shape (batch, heads, seq_len, head_dim).

ejkernel.kernels._triton.flash_mla._interface.flash_mla_attention_call(query: Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim'], key: Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim'], value: Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim'], latent_key: Float[jaxlib._jax.Array, 'head_dim latent_dim'], latent_value: Float[jaxlib._jax.Array, 'head_dim latent_dim'], bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len seq_len'] | None = None, causal: bool = False, softmax_scale: float | None = None) → Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim']

Execute Multi-Latent Attention using Triton kernels.

Multi-Latent Attention reduces memory and computation by projecting the key and value tensors, through latent_key and latent_value, into lower-dimensional latent spaces of size latent_dim before computing attention.
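A pure-JAX reference can make the shapes concrete. This is a minimal sketch under loudly stated assumptions, not the Triton kernel's algorithm: `mla_attention_reference` is a hypothetical helper; it assumes queries and keys are both projected with `latent_key`, values with `latent_value`, that the attended latent values are mapped back to `head_dim` with `latent_value.T`, and that `bias` is added to the logits. Only the parameter names, shapes and the 1/sqrt(head_dim) default come from this page.

```python
import math

import jax
import jax.numpy as jnp


def mla_attention_reference(query, key, value, latent_key, latent_value,
                            bias=None, causal=False, softmax_scale=None):
    """Hypothetical reference for latent-projected attention (assumed math)."""
    head_dim = query.shape[-1]
    if softmax_scale is None:
        # Documented default for the ejkernel functions.
        softmax_scale = 1.0 / math.sqrt(head_dim)

    # (batch, heads, seq, head_dim) @ (head_dim, latent_dim) -> latent_dim.
    q_lat = query @ latent_key   # assumption: queries share the key projection
    k_lat = key @ latent_key
    v_lat = value @ latent_value

    # Attention logits of shape (batch, heads, seq_q, seq_k).
    logits = jnp.einsum("bhqd,bhkd->bhqk", q_lat, k_lat) * softmax_scale
    if bias is not None:
        logits = logits + bias   # assumption: bias is additive
    if causal:
        seq_q, seq_k = logits.shape[-2:]
        # Query i may attend to keys j <= i.
        mask = jnp.arange(seq_q)[:, None] >= jnp.arange(seq_k)[None, :]
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)

    weights = jax.nn.softmax(logits, axis=-1)
    out_lat = jnp.einsum("bhqk,bhkd->bhqd", weights, v_lat)
    # Assumption: project back from latent_dim to head_dim.
    return out_lat @ latent_value.T
```

Under these assumptions the output keeps the (batch, num_heads, seq_len, head_dim) shape listed under Returns, while the logits and weighted sum are computed in latent_dim rather than head_dim. Such a reference is mainly useful for checking a kernel's output on small inputs.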

Parameters
  • query – Query tensor of shape (batch, heads, seq_len, head_dim).

  • key – Key tensor of shape (batch, heads, seq_len, head_dim).

  • value – Value tensor of shape (batch, heads, seq_len, head_dim).

  • latent_key – Latent key projection matrix of shape (head_dim, latent_dim).

  • latent_value – Latent value projection matrix of shape (head_dim, latent_dim).

  • bias – Optional attention bias of shape (batch, heads, seq_len, seq_len).

  • causal – Whether to apply causal masking.

  • softmax_scale – Scale factor applied to the attention logits before the softmax. Defaults to 1/sqrt(head_dim).

Returns

Output tensor of shape (batch, heads, seq_len, head_dim).