ejkernel.kernels._triton.flash_mla._interface
Interface for Flash Multi-Latent Attention (MLA) operations.
- ejkernel.kernels._triton.flash_mla._interface.flash_mla_attention(query: Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim'], key: Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim'], value: Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim'], latent_key: Float[jaxlib._jax.Array, 'head_dim latent_dim'], latent_value: Float[jaxlib._jax.Array, 'head_dim latent_dim'], bias: Float[jaxlib._jax.Array, 'batch num_heads seq_len seq_len'] | None = None, causal: bool = False, softmax_scale: float | None = None) → Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim']
Multi-Latent Attention with automatic differentiation support.
This function wraps flash_mla_attention_call with JAX’s custom gradient support for efficient backpropagation through the attention operation.
- Parameters
query – Query tensor of shape (batch, num_heads, seq_len, head_dim).
key – Key tensor of shape (batch, num_heads, seq_len, head_dim).
value – Value tensor of shape (batch, num_heads, seq_len, head_dim).
latent_key – Latent key projection matrix of shape (head_dim, latent_dim).
latent_value – Latent value projection matrix of shape (head_dim, latent_dim).
bias – Optional attention bias of shape (batch, num_heads, seq_len, seq_len).
causal – Whether to apply causal masking, so that each position attends only to itself and earlier positions.
softmax_scale – Scale factor applied to the attention scores before the softmax. Defaults to 1/sqrt(head_dim).
- Returns
Output tensor of shape (batch, num_heads, seq_len, head_dim).
- ejkernel.kernels._triton.flash_mla._interface.flash_mla_attention_call(query: Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim'], key: Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim'], value: Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim'], latent_key: Float[jaxlib._jax.Array, 'head_dim latent_dim'], latent_value: Float[jaxlib._jax.Array, 'head_dim latent_dim'], bias: Float[jaxlib._jax.Array, 'batch num_heads seq_len seq_len'] | None = None, causal: bool = False, softmax_scale: float | None = None) → Float[jaxlib._jax.Array, 'batch num_heads seq_len head_dim']
Execute Multi-Latent Attention using Triton kernels.
Multi-Latent Attention reduces memory and computation by projecting key and value tensors to lower-dimensional latent spaces before computing attention.
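As a rough illustration of the projection step, the sketch below applies the documented (head_dim, latent_dim) matrices to key and value tensors with the documented shapes. The einsum contraction over head_dim is an assumption based on those shapes; how the kernel then combines the latent tensors with the query and produces a head_dim-sized output is not stated here and is not shown.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes only; latent_dim < head_dim is what yields the reduction.
batch, num_heads, seq_len, head_dim, latent_dim = 2, 4, 16, 64, 16

k_key, k_val, k_lk, k_lv = jax.random.split(jax.random.PRNGKey(0), 4)
key = jax.random.normal(k_key, (batch, num_heads, seq_len, head_dim))
value = jax.random.normal(k_val, (batch, num_heads, seq_len, head_dim))
latent_key = jax.random.normal(k_lk, (head_dim, latent_dim))
latent_value = jax.random.normal(k_lv, (head_dim, latent_dim))

# Assumed projection: contract head_dim against each projection matrix.
key_latent = jnp.einsum("bhsd,dl->bhsl", key, latent_key)
value_latent = jnp.einsum("bhsd,dl->bhsl", value, latent_value)

print(key_latent.shape)              # (2, 4, 16, 16)
print(key.size // key_latent.size)   # 4, i.e. head_dim / latent_dim
```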
- Parameters
query – Query tensor of shape (batch, num_heads, seq_len, head_dim).
key – Key tensor of shape (batch, num_heads, seq_len, head_dim).
value – Value tensor of shape (batch, num_heads, seq_len, head_dim).
latent_key – Latent key projection matrix of shape (head_dim, latent_dim).
latent_value – Latent value projection matrix of shape (head_dim, latent_dim).
bias – Optional attention bias of shape (batch, num_heads, seq_len, seq_len).
causal – Whether to apply causal masking, so that each position attends only to itself and earlier positions.
softmax_scale – Scale factor applied to the attention scores before the softmax. Defaults to 1/sqrt(head_dim).
- Returns
Output tensor of shape (batch, num_heads, seq_len, head_dim).
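To make the bias, causal, and softmax_scale parameters concrete, here is a minimal pure-JAX reference for ordinary (non-latent) attention over the documented shapes. reference_attention is a hypothetical helper, not part of ejkernel. It assumes the bias is added to the scaled scores before the softmax and that causal masking is lower-triangular, and it leaves out the latent projections entirely.

```python
import jax
import jax.numpy as jnp


def reference_attention(query, key, value, bias=None, causal=False, softmax_scale=None):
    """Plain attention over (batch, num_heads, seq_len, head_dim) inputs."""
    head_dim = query.shape[-1]
    if softmax_scale is None:
        # Documented default: 1/sqrt(head_dim).
        softmax_scale = head_dim ** -0.5

    scores = jnp.einsum("bhqd,bhkd->bhqk", query, key) * softmax_scale

    if bias is not None:
        # Assumption: additive bias of shape (batch, num_heads, seq_len, seq_len).
        scores = scores + bias

    if causal:
        # Assumption: position i may attend only to positions j <= i.
        seq_len = query.shape[-2]
        mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)

    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", probs, value)
```

Under causal masking the first position can attend only to itself, so its output row equals the first value row, which makes a convenient sanity check when comparing a kernel against a reference like this.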