ejkernel.kernels._xla.flash_attention._interface
Flash Attention interface for the XLA backend.
This module provides the public API for Flash Attention on XLA, including the main flash_attention function, which has custom VJP (vector-Jacobian product) support.
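To make "custom VJP support" concrete, here is a minimal sketch of the general technique using jax.custom_vjp on single-head attention. It is not ejkernel's implementation. The names attention, _attention_fwd and _attention_bwd are hypothetical, and the backward pass recomputes the full probability matrix, whereas flash-attention backward passes typically save a per-row log-sum-exp and recompute the scores chunk by chunk.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def attention(q, k, v):
    """Single-head softmax attention: q is (Lq, D), k and v are (Lk, D)."""
    out, _ = _attention_fwd(q, k, v)
    return out


def _attention_fwd(q, k, v):
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    s = scale * q @ k.T                      # (Lq, Lk) attention logits
    p = jax.nn.softmax(s, axis=-1)           # attention probabilities
    out = p @ v
    # Keep only inputs and output as residuals; P is recomputed in the backward pass.
    return out, (q, k, v, out)


def _attention_bwd(res, d_out):
    q, k, v, out = res
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    p = jax.nn.softmax(scale * q @ k.T, axis=-1)  # recompute P instead of storing it
    d_v = p.T @ d_out
    d_p = d_out @ v.T
    # Row-wise softmax correction: sum_j P_ij * dP_ij equals sum_d dO_id * O_id.
    delta = jnp.sum(d_out * out, axis=-1, keepdims=True)
    d_s = p * (d_p - delta)
    d_q = scale * d_s @ k
    d_k = scale * d_s.T @ q
    return d_q, d_k, d_v


attention.defvjp(_attention_fwd, _attention_bwd)
```

Saving only q, k, v and the output keeps the forward residuals small; the price is recomputing the softmax during differentiation.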
- The implementation supports:
  - Configurable chunk sizes for query and key processing
  - Causal and non-causal attention modes
  - Sliding window attention
  - Attention masks and bias tensors
  - Segment IDs for packed sequence processing
  - Dropout with reproducible randomness
  - Multiple precision modes (DEFAULT, HIGH, HIGHEST)
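The causal, sliding-window and segment-ID features listed above all reduce to one boolean "may attend" mask. The following is a minimal sketch under stated assumptions: the helper build_mask is hypothetical, and the window semantics (a tuple (left, right) keeps keys with i - left <= j <= i + right, an int w acts as (w, w)) are assumed and may differ from how ejkernel interprets sliding_window.

```python
import jax.numpy as jnp


def build_mask(seq_len_q, seq_len_k, causal=False, sliding_window=None,
               q_segment_ids=None, kv_segment_ids=None):
    """Return a (batch_or_1, seq_len_q, seq_len_k) bool mask; True means "may attend".

    Assumed (hypothetical) semantics:
      * query position i and key position j share one index space;
      * sliding_window=(left, right) keeps i - left <= j <= i + right,
        and an int w is treated as (w, w);
      * segment IDs restrict attention to tokens of the same packed sequence.
    """
    q_pos = jnp.arange(seq_len_q)[:, None]
    k_pos = jnp.arange(seq_len_k)[None, :]
    mask = jnp.ones((seq_len_q, seq_len_k), dtype=bool)
    if causal:
        mask = mask & (k_pos <= q_pos)
    if sliding_window is not None:
        if isinstance(sliding_window, int):
            left, right = sliding_window, sliding_window
        else:
            left, right = sliding_window
        mask = mask & (k_pos >= q_pos - left) & (k_pos <= q_pos + right)
    mask = mask[None]  # broadcastable batch axis
    if q_segment_ids is not None and kv_segment_ids is not None:
        # (batch, seq_len_q, 1) == (batch, 1, seq_len_k) -> same-segment pairs only
        mask = mask & (q_segment_ids[:, :, None] == kv_segment_ids[:, None, :])
    return mask
```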
- Internal functions:
  - _make_core_func: Creates specialized attention cores for the given static parameters.
  - _precision_to_code: Converts a JAX precision to an integer code.
  - _dtype_to_code: Converts a dtype to an integer code for JIT compilation.
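One common way to combine these helpers is to turn every static option into a small hashable integer and cache one jitted core per combination. The sketch below illustrates that pattern only; the lookup tables and the functions precision_to_code, dtype_to_code and make_core_func are hypothetical stand-ins, not ejkernel's actual code or numbering.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical code tables; ejkernel's real integer codes are not documented here.
_PRECISION_CODES = {jax.lax.Precision.DEFAULT: 0, jax.lax.Precision.HIGH: 1,
                    jax.lax.Precision.HIGHEST: 2}
_DTYPE_CODES = {jnp.dtype(jnp.float32): 0, jnp.dtype(jnp.bfloat16): 1,
                jnp.dtype(jnp.float16): 2}
_CODE_TO_PRECISION = {code: p for p, code in _PRECISION_CODES.items()}
_CODE_TO_DTYPE = {code: d for d, code in _DTYPE_CODES.items()}


def precision_to_code(precision):
    return _PRECISION_CODES[precision]


def dtype_to_code(dtype):
    return _DTYPE_CODES[jnp.dtype(dtype)]  # accepts "float32", jnp.float32, ...


@functools.lru_cache(maxsize=None)
def make_core_func(causal: bool, precision_code: int, logits_dtype_code: int):
    """Build (once) and cache a jitted attention core for these static options."""
    precision = _CODE_TO_PRECISION[precision_code]
    logits_dtype = _CODE_TO_DTYPE[logits_dtype_code]

    @jax.jit
    def core(q, k, v):
        s = jnp.matmul(q, k.T, precision=precision).astype(logits_dtype)
        s = s / jnp.sqrt(jnp.asarray(q.shape[-1], logits_dtype))
        if causal:  # Python-level branch, fixed when this core is built
            allowed = jnp.arange(s.shape[1])[None, :] <= jnp.arange(s.shape[0])[:, None]
            s = jnp.where(allowed, s, jnp.finfo(logits_dtype).min)
        p = jax.nn.softmax(s, axis=-1)
        return p.astype(v.dtype) @ v

    return core
```

Because equivalent inputs such as jnp.float32 and "float32" map to the same integer, they hit the same cache entry and reuse one compiled function.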
- ejkernel.kernels._xla.flash_attention._interface.flash_attention(
    query: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim'],
    key: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'],
    value: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'],
    attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len_q seq_len_k'] | jaxtyping.Int[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len_q seq_len_k'] | None = None,
    bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len_q seq_len_k'] | None = None,
    softmax_scale: float | None = None,
    dropout_prob: float = 0.0,
    causal: bool = False,
    dropout_seed: int | None = None,
    cum_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None,
    cum_seqlens_k: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None,
    sliding_window: int | tuple[int, int] | None = None,
    fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None,
    bwd_params: ejkernel.ops.utils.datacarrier.BwdParams | None = None,
    logits_soft_cap: float | None = None,
    softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_heads num_sinks'] | jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None,
    normalize_output: bool = True,
    precision: typing.Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], jax._src.lax.lax.DotAlgorithm, jax._src.lax.lax.DotAlgorithmPreset] = Precision.DEFAULT,
    logits_dtype: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType] = jax.numpy.float32,
    *,
    q_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len_q'] | None = None,
    kv_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len_k'] | None = None
  ) → Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim']
Flash attention with memory-efficient chunked computation and attention sinks.
This implementation uses an online softmax to compute attention chunk by chunk, so the memory needed for the attention scores grows as O(N) rather than O(N²) in the sequence length N. It supports sliding window attention, logit soft capping, grouped query attention (GQA/MQA), and attention sinks.
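The online-softmax recurrence can be sketched for a single head as follows, with a naive reference for comparison. The names chunked_attention and reference_attention are hypothetical. Two conventions are assumed rather than taken from ejkernel: the soft cap maps a logit s to cap * tanh(s / cap), and each sink logit enlarges the softmax denominator without contributing a value vector. Only a (Lq, chunk_size) block of logits exists at any time.

```python
import jax
import jax.numpy as jnp


def chunked_attention(q, k, v, chunk_size=4, soft_cap=None, sink_logits=None):
    """Single-head attention over key/value chunks using an online softmax.

    q: (Lq, D); k, v: (Lk, D). The last chunk may be shorter than chunk_size.
    soft_cap: optional float (assumed form: soft_cap * tanh(s / soft_cap)).
    sink_logits: optional (num_sinks,) array (assumed: denominator-only terms).
    """
    lq, d = q.shape
    scale = 1.0 / jnp.sqrt(d)
    # Running per-row statistics: max logit m, normaliser l, unnormalised output acc.
    m = jnp.full((lq, 1), -jnp.inf)
    l = jnp.zeros((lq, 1))
    acc = jnp.zeros((lq, v.shape[-1]))
    if sink_logits is not None:
        # Seed the statistics with the sinks so they only enlarge the denominator.
        sink_max = jnp.max(sink_logits)
        m = jnp.full((lq, 1), sink_max)
        l = jnp.full((lq, 1), jnp.sum(jnp.exp(sink_logits - sink_max)))
    for start in range(0, k.shape[0], chunk_size):
        s = scale * q @ k[start:start + chunk_size].T       # (Lq, chunk) logits
        if soft_cap is not None:
            s = soft_cap * jnp.tanh(s / soft_cap)
        m_new = jnp.maximum(m, jnp.max(s, axis=-1, keepdims=True))
        correction = jnp.exp(m - m_new)                      # rescale earlier chunks
        p = jnp.exp(s - m_new)
        l = l * correction + jnp.sum(p, axis=-1, keepdims=True)
        acc = acc * correction + p @ v[start:start + chunk_size]
        m = m_new
    return acc / l


def reference_attention(q, k, v, soft_cap=None, sink_logits=None):
    """Naive O(Lq * Lk) version used to check the chunked result."""
    s = q @ k.T / jnp.sqrt(q.shape[-1])
    if soft_cap is not None:
        s = soft_cap * jnp.tanh(s / soft_cap)
    if sink_logits is not None:
        sinks = jnp.broadcast_to(sink_logits, (s.shape[0], sink_logits.shape[0]))
        s = jnp.concatenate([s, sinks], axis=-1)
    p = jax.nn.softmax(s, axis=-1)[:, :k.shape[0]]  # drop the sink columns
    return p @ v
```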