ejkernel.kernels._triton.flash_attention._interface

Flash Attention implementation using Triton kernels.

This module provides a highly optimized implementation of Flash Attention, an IO-aware exact attention algorithm that reduces memory usage from O(N²) to O(N) through tiling and recomputation strategies.

Flash Attention preserves exact attention semantics while substantially reducing the memory footprint. The key insight is to split the attention computation into blocks and fuse the operations on each block, so that intermediate results stay in on-chip SRAM and memory reads/writes between GPU HBM and SRAM are minimized.

Key advantages over standard attention:
1. Linear memory: O(N) instead of O(N²) for sequence length N
2. Faster wall-clock time: reduced memory I/O translates to speed improvements
3. Exact attention: no approximation; results match standard attention up to floating-point rounding
4. Better scaling: enables processing of much longer sequences
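
To make the memory claim concrete, the following back-of-the-envelope sketch compares the size of a fully materialized score matrix with the buffers that grow only linearly in sequence length. The shapes are illustrative, and the assumption that a tiled kernel keeps one float32 softmax statistic per query row is mine, not a statement about what this module stores.

```python
# Illustrative shapes: batch=2, seq_len=2048, num_heads=12, head_dim=64, float32.
batch, seq_len, num_heads, head_dim = 2, 2048, 12, 64
bytes_per_float32 = 4
MIB = 1024 ** 2

# Standard attention materializes one [seq_len, seq_len] score matrix per head.
score_matrix_bytes = batch * num_heads * seq_len * seq_len * bytes_per_float32

# Assumption: a tiled kernel keeps one softmax statistic per query row.
row_stats_bytes = batch * num_heads * seq_len * bytes_per_float32

# The attention output itself has the same shape as the query tensor.
output_bytes = batch * seq_len * num_heads * head_dim * bytes_per_float32

print(f"score matrix:   {score_matrix_bytes / MIB:.0f} MiB")  # grows as seq_len**2
print(f"row statistics: {row_stats_bytes / MIB:.4f} MiB")     # grows as seq_len
print(f"output:         {output_bytes / MIB:.0f} MiB")        # grows as seq_len
```

Doubling `seq_len` quadruples the first figure but only doubles the other two.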

Algorithm overview:
- Query and key-value sequences are split into blocks
- Attention is computed block-by-block using online softmax
- Partial results are accumulated incrementally
- No full attention matrix is ever materialized
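
The steps above can be sketched in plain JAX for a single head. This is an illustration of the online-softmax technique, not the Triton kernel itself: the function names and the `block_k` parameter are hypothetical, and only keys and values are tiled here because query rows are independent of each other.

```python
import jax
import jax.numpy as jnp


def naive_attention(q, k, v):
    """Reference attention that materializes the full [seq_q, seq_k] score matrix."""
    scale = q.shape[-1] ** -0.5
    return jax.nn.softmax((q @ k.T) * scale, axis=-1) @ v


def blockwise_attention(q, k, v, block_k=4):
    """Single-head attention over key/value blocks using an online softmax."""
    seq_q, head_dim = q.shape
    scale = head_dim ** -0.5
    running_max = jnp.full((seq_q,), -jnp.inf)  # row-wise max of scores seen so far
    running_sum = jnp.zeros((seq_q,))           # row-wise softmax denominator
    acc = jnp.zeros((seq_q, v.shape[-1]))       # unnormalized output accumulator
    for start in range(0, k.shape[0], block_k):
        k_blk = k[start:start + block_k]
        v_blk = v[start:start + block_k]
        scores = (q @ k_blk.T) * scale          # only a [seq_q, block_k] tile
        new_max = jnp.maximum(running_max, scores.max(axis=-1))
        correction = jnp.exp(running_max - new_max)  # rescales earlier partial results
        probs = jnp.exp(scores - new_max[:, None])
        running_sum = running_sum * correction + probs.sum(axis=-1)
        acc = acc * correction[:, None] + probs @ v_blk
        running_max = new_max
    return acc / running_sum[:, None]           # normalize once at the end


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (8, 16))
k = jax.random.normal(kk, (10, 16))
v = jax.random.normal(kv, (10, 16))
max_abs_diff = jnp.abs(blockwise_attention(q, k, v) - naive_attention(q, k, v)).max()
print(max_abs_diff)  # expected to be at the level of float32 rounding error
```

The `correction` factor is what lets each block be processed once: whenever a larger row maximum appears, the previously accumulated sum and output are rescaled to the new reference point instead of being recomputed.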

Supported features:
- Causal and non-causal attention
- Attention bias and masking
- Dropout during training
- Variable-length sequences (via cu_seqlens)
- Sliding window attention for local patterns
- Grouped-query attention (GQA) and multi-query attention (MQA)
- Attention sinks via softmax_aux parameter
- Logits soft capping for numerical stability
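
For checking a few of these features on small inputs, a naive reference along the following lines can be useful. It is a sketch under stated assumptions, not this module's implementation: `reference_attention` is a hypothetical helper, the soft cap uses the common `cap * tanh(logits / cap)` form, each key/value head is assumed to serve a consecutive group of query heads, and the causal mask assumes equal query and key lengths.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, causal=False, logits_soft_cap=None):
    """Naive attention on [batch, seq, heads, head_dim] inputs; small shapes only."""
    num_heads, num_kv_heads = q.shape[2], k.shape[2]
    group = num_heads // num_kv_heads
    # Assumption: GQA/MQA maps each key/value head to `group` consecutive query heads.
    k = jnp.repeat(k, group, axis=2)
    v = jnp.repeat(v, group, axis=2)
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    if logits_soft_cap is not None:
        # Assumption: tanh-based soft cap, bounding logits to (-cap, cap).
        logits = logits_soft_cap * jnp.tanh(logits / logits_soft_cap)
    if causal:
        seq_q, seq_k = logits.shape[-2:]
        # Assumption: seq_q == seq_k, so position i attends to positions <= i.
        mask = jnp.tril(jnp.ones((seq_q, seq_k), dtype=bool))
        logits = jnp.where(mask, logits, -jnp.inf)
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 6, 4, 8))  # 4 query heads
k = jax.random.normal(kk, (2, 6, 2, 8))  # 2 key/value heads (GQA)
v = jax.random.normal(kv, (2, 6, 2, 8))
out = reference_attention(q, k, v, causal=True, logits_soft_cap=20.0)
print(out.shape)
```

Because this version materializes the full score matrix, it is only suitable as a correctness oracle on short sequences.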

Example

>>> import jax.numpy as jnp
>>> from ejkernel.kernels._triton.flash_attention import flash_attention
>>>
>>> batch, seq_len, num_heads, head_dim = 2, 2048, 12, 64
>>> q = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> k = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> v = jnp.ones((batch, seq_len, num_heads, head_dim))
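>>>
>>> # Standard (non-causal) attention
>>> output = flash_attention(q, k, v)
>>>
>>> # Causal attention for autoregressive models
>>> output = flash_attention(q, k, v, causal=True)
>>>
>>> # Attention with dropout and a fixed seed
>>> output = flash_attention(q, k, v, dropout_prob=0.1, dropout_seed=42)

Reference: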

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness https://arxiv.org/abs/2205.14135

ejkernel.kernels._triton.flash_attention._interface.flash_attention(
    query: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim'],
    key: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'],
    value: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'],
    attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len_q seq_len_k'] | jaxtyping.Int[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len_q seq_len_k'] | None = None,
    bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len_q seq_len_k'] | None = None,
    softmax_scale: float | None = None,
    dropout_prob: float = 0.0,
    causal: bool = False,
    dropout_seed: int | None = None,
    cum_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None,
    cum_seqlens_k: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None,
    sliding_window: int | tuple[int, int] | None = None,
    fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None,
    bwd_params: ejkernel.ops.utils.datacarrier.BwdParams | None = None,
    logits_soft_cap: float | None = None,
    softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_heads num_sinks'] | jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None,
    normalize_output: bool = True,
    precision: typing.Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], jax._src.lax.lax.DotAlgorithm, jax._src.lax.lax.DotAlgorithmPreset] = Precision.DEFAULT,
    logits_dtype: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType] = jax.numpy.float32,
    *,
    q_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len_q'] | None = None,
    kv_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len_k'] | None = None,
) -> jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim']

Compute flash attention for efficient scaled dot-product attention.

Flash Attention is a memory-efficient and fast implementation of exact attention that uses tiling and recomputation to reduce memory usage from O(N²) to O(N) where N is sequence length.

Parameters
  • query – Query tensor of shape [batch, seq_len_q, num_heads, head_dim]

  • key – Key tensor of shape [batch, seq_len_k, num_kv_heads, head_dim]

  • value – Value tensor of shape [batch, seq_len_k, num_kv_heads, head_dim]

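  • attention_mask – Optional boolean or integer attention mask of shape [batch, num_heads or 1, seq_len_q, seq_len_k] (legacy; prefer the bias parameter)

  • bias – Optional attention bias of shape [batch, num_heads, seq_len_q, seq_len_k], for masking or relative position encoding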

  • softmax_scale – Scaling factor for QK^T (default: 1/sqrt(head_dim))

  • dropout_prob – Dropout probability for attention weights, between 0 and 1 (default: 0.0, no dropout)

  • causal – Whether to apply causal masking for autoregressive models

  • dropout_seed – Random seed for reproducible dropout

  • cum_seqlens_q – Cumulative query sequence lengths of shape [batch + 1], for packed variable-length sequences

  • cum_seqlens_k – Cumulative key sequence lengths of shape [batch + 1], for variable-length mode

  • sliding_window – Size of the local attention window for sparse patterns, given as an int or a tuple of two ints

  • logits_soft_cap – Optional soft cap value for logits (e.g., 20.0 for Gemma)

  • softmax_aux – Optional attention sink logits of shape [num_heads, num_sinks] or [num_sinks]

  • q_segment_ids, kv_segment_ids – Optional keyword-only segment IDs for packed sequences, of shape [batch, seq_len_q] and [batch, seq_len_k]; attention across different segments is masked
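
As an illustration of the packed-sequence parameters, the sketch below builds cumulative sequence lengths and segment IDs from a list of per-sequence lengths. The variable names are hypothetical, and the exact token layout the kernel expects in variable-length mode is not specified here, so treat this as one plausible way to prepare the inputs rather than the required recipe.

```python
import jax.numpy as jnp

# Three sequences of different lengths packed back to back (illustrative values).
lengths = jnp.array([3, 5, 2], dtype=jnp.int32)

# Cumulative offsets with a leading zero: one more entry than there are sequences.
cum_seqlens = jnp.concatenate([jnp.zeros((1,), jnp.int32), jnp.cumsum(lengths)])

# Segment IDs: one integer per packed token naming the sequence it belongs to.
total_tokens = int(lengths.sum())
segment_ids = jnp.repeat(
    jnp.arange(lengths.shape[0]), lengths, total_repeat_length=total_tokens
)

# Cross-segment masking: a token pair is allowed only when the IDs match.
same_segment = segment_ids[:, None] == segment_ids[None, :]

print(cum_seqlens)
print(segment_ids)
```

Sequence `i` occupies packed positions `cum_seqlens[i]` up to, but not including, `cum_seqlens[i + 1]`. Since the signature gives the segment ID arrays a leading batch dimension, a single packed row would be passed as `segment_ids[None, :]`.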

Returns

Attention output with shape [batch, seq_len_q, num_heads, head_dim]

Return type

chex.Array

Examples

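>>> # Causal attention
>>> out = flash_attention(query, key, value, causal=True)
>>>
>>> # Dropout with an explicit softmax scale (0.125 = 1/sqrt(64))
>>> out = flash_attention(query, key, value, dropout_prob=0.1, softmax_scale=0.125)
>>>
>>> # Variable-length sequences; cum_lens holds the cumulative sequence lengths
>>> out = flash_attention(query, key, value, cum_seqlens_q=cum_lens, cum_seqlens_k=cum_lens)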