ejkernel.modules.operations.flash_attention#

Flash Attention module with automatic optimization.

This module implements Flash Attention, a memory-efficient attention mechanism that uses tiling and recomputation to achieve O(N) memory complexity instead of the standard O(N²) for sequence length N.

Key features of Flash Attention:
  • Memory-efficient: Uses tiling to process attention in blocks

  • IO-aware: Minimizes HBM (high bandwidth memory) accesses

  • Exact: Computes the same result as standard attention with no approximation (outputs match up to floating-point rounding)

  • Fast: Often faster than standard attention despite recomputation

The algorithm works by:
  1. Splitting Q, K, V into blocks along sequence dimension

  2. Computing attention block-by-block with on-the-fly softmax

  3. Using online softmax correction for numerical stability

  4. Fusing operations to minimize memory transfers
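
The steps above can be sketched for a single query row in plain Python. This is an illustrative sketch, not the kernel this module dispatches to: the names `dot`, `reference_row`, and `tiled_row` and the list-based layout are assumptions made for readability. It shows how the running maximum and running sum let each key/value block be processed once, without ever holding the full row of scores.

```python
import math


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def reference_row(q, keys, values, scale):
    """Standard softmax attention for one query row (materializes all scores)."""
    scores = [scale * dot(q, k) for k in keys]
    row_max = max(scores)
    weights = [math.exp(s - row_max) for s in scores]
    total = sum(weights)
    return [
        sum(w * v[c] for w, v in zip(weights, values)) / total
        for c in range(len(values[0]))
    ]


def tiled_row(q, keys, values, scale, block_size):
    """Same output, visiting keys/values one block at a time."""
    running_max = -math.inf          # largest score seen so far
    running_sum = 0.0                # softmax denominator, relative to running_max
    acc = [0.0] * len(values[0])     # unnormalized output accumulator
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * dot(q, k) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new maximum (online softmax).
        correction = math.exp(running_max - new_max)
        probs = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * correction + sum(probs)
        acc = [
            a * correction + sum(p * v[c] for p, v in zip(probs, v_blk))
            for c, a in enumerate(acc)
        ]
        running_max = new_max
    return [a / running_sum for a in acc]
```

On the first block `running_max` is negative infinity, so `correction` evaluates to 0.0 and the empty accumulator is simply overwritten. A real kernel applies the same recurrence to whole query blocks at once and keeps the tiles in on-chip memory.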

Supports:
  • Causal and non-causal masking

  • Variable sequence lengths via cumulative sequence lengths

  • Dropout (during training)

  • Sliding window attention

  • Multi-query and grouped-query attention patterns

  • Attention biasing and soft capping
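
As an illustration of the variable-length interface, cumulative sequence lengths are typically the running offsets of each sequence in a packed token axis, with a leading zero, which matches the `batch_plus_one` shape in the signatures below. The helper names here are hypothetical and the leading-zero convention is an assumption borrowed from other FlashAttention implementations.

```python
from itertools import accumulate


def cumulative_seqlens(lengths):
    """Return [0, l0, l0 + l1, ...]: batch + 1 offsets into a packed token axis."""
    return [0, *accumulate(lengths)]


def sequence_slices(cum_seqlens):
    """Recover the (start, stop) token range of each sequence from the offsets."""
    return list(zip(cum_seqlens[:-1], cum_seqlens[1:]))


lengths = [3, 5, 2]
cu = cumulative_seqlens(lengths)
print(cu)                   # [0, 3, 8, 10]
print(sequence_slices(cu))  # [(0, 3), (3, 8), (8, 10)]
```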

Mathematical formulation:

Standard: Attention(Q,K,V) = softmax(QK^T/√d)V

Flash: Same output, but computed in O(N) memory via tiling
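
To make the memory claim concrete, the following back-of-the-envelope sketch compares the size of the full score matrix with the working set of a tiled computation, per head and per batch element. The function names, the float32 element size, and the 128×128 tile are illustrative assumptions, not values taken from this module.

```python
def score_matrix_bytes(seq_len, bytes_per_elem=4):
    """Memory for the full N x N score matrix of one head."""
    return seq_len * seq_len * bytes_per_elem


def tile_bytes(block_q, block_k, bytes_per_elem=4):
    """Memory for a single block_q x block_k tile of scores."""
    return block_q * block_k * bytes_per_elem


def row_stats_bytes(seq_len, bytes_per_elem=4):
    """Running max and running sum kept per query row: the O(N) part."""
    return 2 * seq_len * bytes_per_elem


n = 8192
print(score_matrix_bytes(n) / 2**20)   # 256.0  (MiB, grows as N^2)
print(tile_bytes(128, 128) / 2**10)    # 64.0   (KiB, independent of N)
print(row_stats_bytes(n) / 2**10)      # 64.0   (KiB, grows as N)
```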

Reference:

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022) https://arxiv.org/abs/2205.14135

class ejkernel.modules.operations.flash_attention.FlashAttention[source]#

Bases: Kernel[FlashAttentionConfig, Array]

Flash Attention with custom optimization logic.

Memory-efficient exact attention with O(N) memory complexity. Supports causal masking, dropout, sliding windows, and variable-length sequences.

Features:
  • Automatic platform/backend selection (Triton/Pallas/XLA)

  • Configuration caching for consistent performance

  • Optional autotuning to find optimal implementation

  • Custom gradient support for efficient backpropagation

  • Support for variable-length sequences via cumulative sequence lengths

  • Sliding window attention for local attention patterns

  • Logits soft capping for numerical stability
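
Logit soft capping is commonly implemented as a scaled tanh applied to the attention scores before the softmax. The sketch below assumes that form; whether this module uses exactly the same expression is not stated here, and `soft_cap` is a hypothetical helper name.

```python
import math


def soft_cap(logit, cap):
    """Squash a logit smoothly into [-cap, cap]; near zero it is almost the identity."""
    return cap * math.tanh(logit / cap)


for x in (0.5, 10.0, 1000.0):
    print(x, soft_cap(x, cap=30.0))
```

Small scores pass through nearly unchanged, while very large scores saturate at the cap, which bounds the inputs to the exponential in the softmax.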

Example

>>> from ejkernel.modules import FlashAttention, create_default_executor
>>>
>>> # Create an executor and the attention module
>>> executor = create_default_executor()
>>> attn = FlashAttention()
>>>
>>> # Causal attention with an explicit softmax scale
>>> output = executor(attn, query, key, value, causal=True, softmax_scale=0.125)
>>>
>>> # Variable-length sequences via cumulative sequence lengths
>>> output = executor(
...     attn, query, key, value,
...     cum_seqlens_q=cu_seqlens_q,
...     cum_seqlens_k=cu_seqlens_k
... )
>>>
>>> # Sliding-window (local) attention
>>> output = executor(attn, query, key, value, sliding_window=(256, 256))

candidate_cfgs(inv: Invocation[FlashAttentionConfig, Array])[source]#

Generate candidate configurations for autotuning.

Creates multiple block size configurations for benchmarking to find the optimal tiling parameters for the given input shapes.

Parameters

inv – Invocation object with arguments and metadata

Returns

Iterable of candidate configurations to test during autotuning

Note

The autotuning system will benchmark each candidate and select the fastest one for the given input configuration.
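
The selection step described in the note can be sketched as a small benchmark loop. This is a generic illustration, not the autotuner shipped with ejkernel: `time_once`, `pick_fastest`, and the `measure` hook are hypothetical names introduced here.

```python
import time


def time_once(fn, *args):
    """Wall-clock time of a single call, in seconds."""
    start = time.perf_counter()
    fn(*args)
    return time.perf_counter() - start


def pick_fastest(candidates, run_with_cfg, measure=time_once, warmup=1, repeats=3):
    """Benchmark every candidate config and return (best_cfg, best_time)."""
    best_cfg, best_time = None, float("inf")
    for cfg in candidates:
        # Warm-up runs absorb one-off costs such as compilation.
        for _ in range(warmup):
            run_with_cfg(cfg)
        # Taking the minimum over repeats reduces noise from other processes.
        elapsed = min(measure(run_with_cfg, cfg) for _ in range(repeats))
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    return best_cfg, best_time
```

When timing JAX computations, the timed callable should wait for the result (for example by calling `block_until_ready()` on the output array), since dispatch is asynchronous and the call would otherwise return before the kernel finishes.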

candidate_cfgs_gpu(inv: Invocation[FlashAttentionConfig, Array])[source]#

Generate GPU-optimized candidate configurations for autotuning (Triton).

Heuristics:
  • q/kv blocks adapt to head_dim and sequence lengths.

  • If sliding_window is set, kv blocks are capped near the window span.

  • num_warps: 2-8 based on head_dim and block sizes.

  • num_stages: 2-3 (kept low to reduce SMEM pressure).

  • Conservative shared-memory guard to avoid CUDA errors.

  • Backward blocks smaller to reduce register pressure.
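
A rough sketch of how such heuristics might be turned into a candidate list is shown below. Everything in it is an assumption made for illustration: the function name, the block sizes tried, the 96 KiB shared-memory budget, the footprint estimate, and the dictionary keys do not come from the library.

```python
from itertools import product


def gpu_block_candidates(seq_len_q, seq_len_k, head_dim, window=None,
                         smem_limit_bytes=96 * 1024, bytes_per_elem=2):
    """Enumerate (block_q, block_k) pairs that pass simple feasibility checks."""
    candidates = []
    for block_q, block_k in product((32, 64, 128), (32, 64, 128)):
        # Do not tile wider than the sequence (the smallest block is always allowed).
        if block_q > max(seq_len_q, 32) or block_k > max(seq_len_k, 32):
            continue
        # With a sliding window, KV blocks wider than the window are mostly masked.
        if window is not None and block_k > max(window, 32):
            continue
        # Rough footprint: one Q tile plus one K tile and one V tile.
        smem = (block_q + 2 * block_k) * head_dim * bytes_per_elem
        if smem > smem_limit_bytes:
            continue
        candidates.append({
            "block_q": block_q,
            "block_k": block_k,
            "num_warps": 4 if head_dim <= 64 else 8,
            "num_stages": 2,
        })
    return candidates


for cfg in gpu_block_candidates(2048, 2048, head_dim=128, window=64):
    print(cfg)
```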

candidate_cfgs_shard_map_gpu(inv: Invocation[FlashAttentionConfig, Array])#

Generate GPU-optimized candidate configurations for autotuning (Triton).

Heuristics:
  • q/kv blocks adapt to head_dim and sequence lengths.

  • If sliding_window is set, kv blocks are capped near the window span.

  • num_warps: 2-8 based on head_dim and block sizes.

  • num_stages: 2-3 (kept low to reduce SMEM pressure).

  • Conservative shared-memory guard to avoid CUDA errors.

  • Backward blocks smaller to reduce register pressure.

candidate_cfgs_shard_map_tpu(inv: Invocation[FlashAttentionConfig, Array])#

Generate TPU-optimized candidate configurations for autotuning (Pallas).

Heuristics:
  • Favor moderate Q blocks (32-128) and KV blocks (64-256/512).

  • If sliding_window is set, prefer kv blocks ≲ window span.

  • Slightly smaller backward blocks to reduce VMEM/regs.

  • Keep the candidate list compact and ordered for fast convergence.

candidate_cfgs_shard_map_xla(inv: Invocation[FlashAttentionConfig, Array])#

Generate XLA-optimized candidate configurations for autotuning.

Heuristics:
  • Medium blocks (128-256) tend to be robust.

  • If sliding_window is set, keep kv blocks near window span.

  • Backward tiles are smaller.

  • Keep list small and ordered by likely winners.

candidate_cfgs_tpu(inv: Invocation[FlashAttentionConfig, Array])[source]#

Generate TPU-optimized candidate configurations for autotuning (Pallas).

Heuristics:
  • Favor moderate Q blocks (32-128) and KV blocks (64-256/512).

  • If sliding_window is set, prefer kv blocks ≲ window span.

  • Slightly smaller backward blocks to reduce VMEM/regs.

  • Keep the candidate list compact and ordered for fast convergence.

candidate_cfgs_xla(inv: Invocation[FlashAttentionConfig, Array])[source]#

Generate XLA-optimized candidate configurations for autotuning.

Heuristics:
  • Medium blocks (128-256) tend to be robust.

  • If sliding_window is set, keep kv blocks near window span.

  • Backward tiles are smaller.

  • Keep list small and ordered by likely winners.

create_shard_map_wrapper(query: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim'], key: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'], value: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'], attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len_q seq_len_k'] | jaxtyping.Int[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len_q seq_len_k'] | None = None, bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, softmax_scale: float | None = None, dropout_prob: float = 0.0, causal: bool = False, dropout_seed: int | None = None, cum_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, cum_seqlens_k: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, sliding_window: int | tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_heads num_sinks'] | jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, normalize_output: bool = True, precision: ~typing.Union[None, str, ~jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], ~jax._src.lax.lax.DotAlgorithm, ~jax._src.lax.lax.DotAlgorithmPreset] = Precision.DEFAULT, logits_dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.float32'>, platform: ~typing.Optional[~typing.Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, q_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len_q'] | None = None, kv_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len_k'] | None = None, cfg: ejkernel.modules.operations.configs.FlashAttentionConfig | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None, check_vma: bool = False)[source]#

Create a shard_map wrapper specifically for flash attention.

Parameters
  • query – Input tensors to be sharded

  • key – Input tensors to be sharded

  • value – Input tensors to be sharded

  • mesh – JAX device mesh

  • in_specs – Input partition specs (for q, k, v, and optionally mask/bias)

  • out_specs – Output partition spec

  • All other args – Flash attention parameters to be fixed via partial

Returns

Tuple of (shard_map_fn, call_args)

get_impl(cfg: FlashAttentionConfig)[source]#

Get kernel implementation from registry based on configuration.

Parameters

cfg – Configuration specifying platform and backend

Returns

Callable kernel implementation

Raises

ValueError – If no matching implementation is found

heuristic_cfg(inv: Invocation[FlashAttentionConfig, Array]) FlashAttentionConfig[source]#

Provide default configuration based on invocation context.

Selects optimal block sizes based on sequence length and head dimension.

Parameters

inv – Invocation object with arguments and metadata

Returns

Default configuration with block sizes

heuristic_cfg_gpu(inv: Invocation[FlashAttentionConfig, Array]) FlashAttentionConfig[source]#

Provide default configuration based on invocation context.

Selects optimal block sizes based on sequence length and head dimension.

Parameters

inv – Invocation object with arguments and metadata

Returns

Default configuration with block sizes

heuristic_cfg_tpu(inv: Invocation[FlashAttentionConfig, Array]) FlashAttentionConfig[source]#

Provide default configuration based on invocation context.

Selects optimal block sizes based on sequence length and head dimension.

Parameters

inv – Invocation object with arguments and metadata

Returns

Default configuration with block sizes

run(query: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim'], key: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'], value: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'], attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len_q seq_len_k'] | jaxtyping.Int[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len_q seq_len_k'] | None = None, bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, softmax_scale: float | None = None, dropout_prob: float = 0.0, causal: bool = False, dropout_seed: int | None = None, cum_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, cum_seqlens_k: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, sliding_window: int | tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_heads num_sinks'] | jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, normalize_output: bool = True, precision: ~typing.Union[None, str, ~jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], ~jax._src.lax.lax.DotAlgorithm, ~jax._src.lax.lax.DotAlgorithmPreset] = Precision.DEFAULT, logits_dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.float32'>, platform: ~typing.Optional[~typing.Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, q_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len_q'] | None = None, kv_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len_k'] | None = None, cfg: ~ejkernel.modules.operations.configs.FlashAttentionConfig) Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim'][source]#

Execute flash attention with the given configuration.

Parameters
  • query – Query tensor [batch, seq_len_q, num_heads, head_dim]

  • key – Key tensor [batch, seq_len_k, num_kv_heads, head_dim]

  • value – Value tensor [batch, seq_len_k, num_kv_heads, head_dim]

  • attention_mask – Optional attention mask (legacy, prefer bias)

  • bias – Optional attention bias tensor

  • softmax_scale – Scaling factor for attention scores

  • dropout_prob – Dropout probability for attention weights

  • causal – Whether to apply causal masking

  • dropout_seed – Random seed for dropout

  • cum_seqlens_q – Cumulative sequence lengths for variable-length queries

  • cum_seqlens_k – Cumulative sequence lengths for variable-length keys

  • sliding_window – Window size for local attention

  • logits_soft_cap – Optional soft cap value for logits

  • softmax_aux – Optional attention sink logits

  • platform – Specific platform to use ("triton", "pallas", "cuda", "xla", or "auto")

  • cfg – Configuration object specifying platform/backend

  • q_segment_ids, kv_segment_ids – Segment IDs for grouped sequences (TPU-specific)

  • block_sizes – Block sizes for kernel execution (TPU-specific)

Returns

Attention output [batch, seq_len_q, num_heads, head_dim]
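
Because key and value carry num_kv_heads heads while the output has num_heads, each query head must be paired with a key/value head. The sketch below shows the usual contiguous grouping used for grouped-query and multi-query attention. The helper name is hypothetical, and the grouping order is an assumption based on common practice rather than on this module's source.

```python
def kv_head_for_query_head(q_head, num_heads, num_kv_heads):
    """Map a query head index to the key/value head it reads from."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads
    return q_head // group_size


# Grouped-query attention: 8 query heads share 2 key/value heads.
print([kv_head_for_query_head(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
# Multi-query attention: every query head reads the single key/value head.
print([kv_head_for_query_head(h, 8, 1) for h in range(8)])  # [0, 0, 0, 0, 0, 0, 0, 0]
```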

version: str = '1'#
ejkernel.modules.operations.flash_attention.flash_attention(query: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim'], key: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'], value: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'], bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, cum_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, cum_seqlens_k: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_heads num_sinks'] | jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, /, *, mask_info: ejkernel.types.mask.MaskInfo | None = None, softmax_scale: float | None = None, dropout_prob: float = 0.0, causal: bool = False, dropout_seed: int | None = None, sliding_window: int | tuple[int, int] | None = None, logits_soft_cap: float | None = None, normalize_output: bool = True, precision: ~typing.Union[None, str, ~jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], ~jax._src.lax.lax.DotAlgorithm, ~jax._src.lax.lax.DotAlgorithmPreset] = Precision.DEFAULT, logits_dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.float32'>, platform: ~typing.Optional[~typing.Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.FlashAttentionConfig | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec | None, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None) Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim'][source]#

Execute flash attention with automatic optimization.

Convenience function that uses a default executor and flash attention module.

Parameters
  • query – Query tensor [batch, seq_len_q, num_heads, head_dim]

  • key – Key tensor [batch, seq_len_k, num_kv_heads, head_dim]

  • value – Value tensor [batch, seq_len_k, num_kv_heads, head_dim]

  • mask_info – Optional MaskInfo containing attention mask and/or segment IDs

  • bias – Optional attention bias tensor

  • softmax_scale – Scaling factor for attention scores (default: 1/sqrt(head_dim))

  • dropout_prob – Dropout probability for attention weights

  • causal – Whether to apply causal masking

  • dropout_seed – Random seed for dropout

  • cum_seqlens_q – Cumulative sequence lengths for variable-length queries

  • cum_seqlens_k – Cumulative sequence lengths for variable-length keys

  • sliding_window – Window size for local attention (int or (left, right) tuple)

  • logits_soft_cap – Optional soft cap value for logits

  • platform – Specific platform to use ("triton", "pallas", "cuda", "xla", or "auto")

  • cfg – Optional configuration override

  • mesh – JAX device mesh for shard_map execution (optional)

  • in_specs – Input partition specs for shard_map (optional)

  • out_specs – Output partition spec for shard_map (optional)

Returns

Attention output with same shape as query
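
To illustrate how sliding_window and causal interact, the sketch below draws the set of key positions each query position may attend to. It assumes that an integer window w is treated as the symmetric pair (w, w) and that a (left, right) window lets query i see keys i - left through i + right. Both conventions, and all function names, are assumptions and not confirmed behavior of this function.

```python
def normalize_window(sliding_window):
    """Turn None, an int, or a (left, right) pair into None or a (left, right) pair."""
    if sliding_window is None:
        return None
    if isinstance(sliding_window, int):
        return (sliding_window, sliding_window)
    left, right = sliding_window
    return (left, right)


def visible(i, j, window=None, causal=False):
    """True if query position i may attend to key position j."""
    if causal and j > i:
        return False
    if window is not None:
        left, right = window
        if j < i - left or j > i + right:
            return False
    return True


def mask_rows(seq_len, sliding_window=None, causal=False):
    """Render the mask as strings: 'x' for visible, '.' for masked."""
    window = normalize_window(sliding_window)
    return [
        "".join("x" if visible(i, j, window, causal) else "." for j in range(seq_len))
        for i in range(seq_len)
    ]


for row in mask_rows(6, sliding_window=(2, 0)):
    print(row)
# x.....
# xx....
# xxx...
# .xxx..
# ..xxx.
# ...xxx
```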

Example

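>>> # Causal attention
>>> out = flash_attention(query, key, value, causal=True)
>>>
>>> # Dropout with an explicit softmax scale
>>> out = flash_attention(query, key, value, dropout_prob=0.1, softmax_scale=0.125)
>>>
>>> # Variable-length sequences: bias, cum_seqlens_q and cum_seqlens_k are
>>> # positional-only in the signature above, so pass bias=None explicitly
>>> out = flash_attention(query, key, value, None, cu_q, cu_k)
>>>
>>> # Force a specific platform
>>> out = flash_attention(query, key, value, platform="triton")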