ejkernel.kernels._xla.flash_attention._interface


Flash Attention interface for XLA backend.

This module provides the public API for the XLA implementation of Flash Attention, including the main flash_attention function, which is differentiable through a custom VJP.

The implementation supports:
  • Configurable chunk sizes for query and key processing
  • Causal and non-causal attention modes
  • Sliding window attention
  • Attention masks and bias tensors
  • Segment IDs for packed sequence processing
  • Dropout with reproducible randomness
  • Multiple precision modes (DEFAULT, HIGH, HIGHEST)
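The masking-related options in the list above (causal mode, sliding window, segment IDs) can be pictured as boolean masks that are combined with a logical AND. The following is a minimal sketch of that idea for a single sequence, not the module's actual code: the helper name `build_mask` is hypothetical, and the reading of `sliding_window` as a `(left, right)` pair of inclusive offsets (with a single int meaning the same width on both sides) is an assumption.

```python
import jax.numpy as jnp


def build_mask(seq_len_q, seq_len_k, causal=False, sliding_window=None,
               q_segment_ids=None, kv_segment_ids=None):
    """Hypothetical helper: returns a (seq_len_q, seq_len_k) boolean mask.

    True means "this query position may attend to this key position".
    """
    q_pos = jnp.arange(seq_len_q)[:, None]  # column of query positions
    k_pos = jnp.arange(seq_len_k)[None, :]  # row of key positions
    mask = jnp.ones((seq_len_q, seq_len_k), dtype=bool)

    if causal:
        # A query may only see keys at the same or an earlier position.
        mask = mask & (k_pos <= q_pos)

    if sliding_window is not None:
        # Assumption: an int means the same width on both sides.
        if isinstance(sliding_window, int):
            left, right = sliding_window, sliding_window
        else:
            left, right = sliding_window
        mask = mask & (k_pos >= q_pos - left) & (k_pos <= q_pos + right)

    if q_segment_ids is not None and kv_segment_ids is not None:
        # Packed sequences: attend only within the same segment.
        mask = mask & (q_segment_ids[:, None] == kv_segment_ids[None, :])

    return mask


segments = jnp.array([0, 0, 1, 1])
mask = build_mask(4, 4, causal=True, sliding_window=(1, 0),
                  q_segment_ids=segments, kv_segment_ids=segments)
```

In a chunked kernel the same predicates would typically be evaluated per chunk rather than for the full matrix, so the complete mask never needs to be stored.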

Internal Functions:
  • _make_core_func: Creates specialized attention cores for given static params
  • _precision_to_code: Convert JAX precision to integer code
  • _dtype_to_code: Convert dtype to integer code for JIT compilation

ejkernel.kernels._xla.flash_attention._interface.flash_attention(query: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim'], key: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'], value: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'], attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len_q seq_len_k'] | jaxtyping.Int[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len_q seq_len_k'] | None = None, bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, softmax_scale: float | None = None, dropout_prob: float = 0.0, causal: bool = False, dropout_seed: int | None = None, cum_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, cum_seqlens_k: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, sliding_window: int | tuple[int, int] | None = None, fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, bwd_params: ejkernel.ops.utils.datacarrier.BwdParams | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_heads num_sinks'] | jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, normalize_output: bool = True, precision: typing.Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], jax._src.lax.lax.DotAlgorithm, jax._src.lax.lax.DotAlgorithmPreset] = Precision.DEFAULT, logits_dtype: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType] = jax.numpy.float32, *, q_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len_q'] | None = None, kv_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len_k'] | None = None) -> Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim']

Flash attention with memory-efficient chunked computation and attention sinks.

This implementation uses online softmax to compute attention in chunks, so the full N × N score matrix is never materialized and attention memory drops from O(N²) to O(N) in the sequence length N. It supports sliding window attention, logit soft capping, grouped query attention (GQA/MQA), and attention sinks.
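To illustrate the online-softmax idea, here is a minimal single-head sketch in JAX that visits keys and values one chunk at a time while carrying a running maximum, a running normalizer, and an unnormalized output. It is an illustration under simplifying assumptions (no batch or head axes, no masking, dropout, soft capping, or sinks, and a key length divisible by the chunk size); the names `chunked_attention` and `chunk_size_k` are hypothetical and this is not the module's actual implementation.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    # Baseline: materializes the full (seq_q, seq_k) score matrix.
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.T) * scale
    return jax.nn.softmax(scores, axis=-1) @ v


def chunked_attention(q, k, v, chunk_size_k=4):
    # q: (seq_q, d); k, v: (seq_k, d). Only one (seq_q, chunk) block of
    # scores exists at any time.
    seq_q, d = q.shape
    seq_k = k.shape[0]
    assert seq_k % chunk_size_k == 0, "sketch assumes an even split"
    scale = d ** -0.5
    num_chunks = seq_k // chunk_size_k
    k_chunks = k.reshape(num_chunks, chunk_size_k, d)
    v_chunks = v.reshape(num_chunks, chunk_size_k, d)

    def step(carry, kv_chunk):
        m, l, acc = carry              # running max, normalizer, output
        k_c, v_c = kv_chunk
        s = (q @ k_c.T) * scale        # scores for this chunk only
        m_new = jnp.maximum(m, s.max(axis=-1))
        correction = jnp.exp(m - m_new)    # rescales previous statistics
        p = jnp.exp(s - m_new[:, None])
        l_new = l * correction + p.sum(axis=-1)
        acc_new = acc * correction[:, None] + p @ v_c
        return (m_new, l_new, acc_new), None

    init = (
        jnp.full((seq_q,), -jnp.inf, dtype=q.dtype),
        jnp.zeros((seq_q,), dtype=q.dtype),
        jnp.zeros((seq_q, d), dtype=q.dtype),
    )
    (_, l, acc), _ = jax.lax.scan(step, init, (k_chunks, v_chunks))
    return acc / l[:, None]            # final normalization


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (8, 16))
k = jax.random.normal(kk, (12, 16))
v = jax.random.normal(kv, (12, 16))
out = chunked_attention(q, k, v, chunk_size_k=4)
```

Because each step rescales the earlier partial sums by `exp(m - m_new)`, the result matches the standard softmax up to floating-point error, while the carried state grows only linearly with the query length.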