ejkernel.modules.operations.flash_attention
Flash Attention module with automatic optimization.
This module implements Flash Attention, a memory-efficient attention mechanism that uses tiling and recomputation to achieve O(N) memory complexity instead of the standard O(N²) for sequence length N.
- Key features of Flash Attention:
Memory-efficient: Uses tiling to process attention in blocks
IO-aware: Minimizes HBM (high bandwidth memory) accesses
Exact: Computes the same result as standard attention (no approximation), matching it up to floating-point rounding
Fast: Often faster than standard attention despite recomputation
- The algorithm works by:
Splitting Q, K, V into blocks along sequence dimension
Computing attention block-by-block with on-the-fly softmax
Using an online softmax (a running row maximum and normalizer, rescaled as each new block arrives) so the result stays exact and numerically stable
Fusing operations to minimize memory transfers
- Supports:
Causal and non-causal masking
Variable sequence lengths via cumulative sequence lengths
Dropout (during training)
Sliding window attention
Multi-query and grouped-query attention patterns
Attention biasing and soft capping
- Mathematical formulation:
Standard: Attention(Q, K, V) = softmax(QK^T/√d)V, where d is the head dimension
Flash: Same output, computed in O(N) memory via tiling
- Reference:
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022) https://arxiv.org/abs/2205.14135
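The block-by-block recurrence described above can be sketched in plain JAX for a single head without masking. naive_attention and tiled_attention are hypothetical helpers written for illustration, not ejkernel's Triton or Pallas kernels: the Python loop only mirrors the online-softmax bookkeeping (running max m, running denominator l, unnormalized accumulator acc) and holds one [seq_q, block_k] score tile at a time instead of the full [seq_q, seq_k] matrix.

```python
import jax
import jax.numpy as jnp


def naive_attention(q, k, v, scale):
    """Standard attention that materializes the full [seq_q, seq_k] score matrix."""
    return jax.nn.softmax((q @ k.T) * scale, axis=-1) @ v


def tiled_attention(q, k, v, scale, block_k=64):
    """Online-softmax attention over K/V blocks (single head, no mask)."""
    seq_q = q.shape[0]
    m = jnp.full((seq_q,), -jnp.inf, dtype=q.dtype)       # running row max
    l = jnp.zeros((seq_q,), dtype=q.dtype)                # running softmax denominator
    acc = jnp.zeros((seq_q, v.shape[-1]), dtype=q.dtype)  # unnormalized output
    for start in range(0, k.shape[0], block_k):
        k_blk = k[start:start + block_k]
        v_blk = v[start:start + block_k]
        s = (q @ k_blk.T) * scale                 # scores for this tile only
        m_new = jnp.maximum(m, s.max(axis=-1))
        correction = jnp.exp(m - m_new)           # rescale earlier blocks to the new max
        p = jnp.exp(s - m_new[:, None])
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ v_blk
        m = m_new
    return acc / l[:, None]


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (128, 64))
k = jax.random.normal(kk, (200, 64))   # 200 is not a multiple of block_k: last tile is partial
v = jax.random.normal(kv, (200, 64))
scale = 64 ** -0.5
out_ref = naive_attention(q, k, v, scale)
out_tiled = tiled_attention(q, k, v, scale, block_k=64)
print(float(jnp.max(jnp.abs(out_ref - out_tiled))))  # rounding-level difference
```

The two outputs agree to within floating-point rounding rather than bit for bit, because the tiled version accumulates the sums in a different order.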
- class ejkernel.modules.operations.flash_attention.FlashAttention
Bases: Kernel[FlashAttentionConfig, Array]
Flash Attention with custom optimization logic.
Memory-efficient exact attention with O(N) memory complexity. Supports causal masking, dropout, sliding windows, and variable-length sequences.
- Features:
Automatic platform/backend selection (Triton/Pallas/XLA)
Configuration caching for consistent performance
Optional autotuning to find optimal implementation
Custom gradient support for efficient backpropagation
Support for variable-length sequences via cumulative sequence lengths
Sliding window attention for local attention patterns
Logits soft capping for numerical stability
Example
>>> from ejkernel.modules import FlashAttention, create_default_executor
>>>
>>> executor = create_default_executor()
>>> attn = FlashAttention()
>>>
>>> # Causal attention with an explicit scale (0.125 == 1/sqrt(64) for head_dim=64)
>>> output = executor(attn, query, key, value, causal=True, softmax_scale=0.125)
>>>
>>> # Variable-length sequences via cumulative sequence lengths ([batch + 1] int arrays)
>>> output = executor(
...     attn, query, key, value,
...     cum_seqlens_q=cu_seqlens_q,
...     cum_seqlens_k=cu_seqlens_k
... )
>>>
>>> # Local attention with a (left, right) sliding window
>>> output = executor(attn, query, key, value, sliding_window=(256, 256))
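The cum_seqlens_q and cum_seqlens_k arrays are typed with a [batch_plus_one] dimension, i.e. one more entry than there are sequences. A minimal sketch of building them, assuming the usual FlashAttention variable-length convention (a leading 0 followed by the running total of sequence lengths); make_cum_seqlens is a hypothetical helper, and how the packed query/key tensors must be laid out should be checked against ejkernel itself.

```python
import jax.numpy as jnp


def make_cum_seqlens(seq_lens):
    """Return [0, len_0, len_0 + len_1, ...] as an int32 array of shape [batch + 1]."""
    lens = jnp.asarray(seq_lens, dtype=jnp.int32)
    return jnp.concatenate(
        [jnp.zeros((1,), dtype=jnp.int32), jnp.cumsum(lens, dtype=jnp.int32)]
    )


# Three sequences of lengths 5, 3 and 7 packed back to back (15 tokens in total).
cu_seqlens_q = make_cum_seqlens([5, 3, 7])
cu_seqlens_k = cu_seqlens_q  # self-attention: keys use the same packing
# Sequence i occupies tokens cu_seqlens_q[i]:cu_seqlens_q[i + 1].
print(cu_seqlens_q.tolist())  # [0, 5, 8, 15]
```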
- candidate_cfgs(inv: Invocation[FlashAttentionConfig, Array])
Generate candidate configurations for autotuning.
Creates multiple block size configurations for benchmarking to find the optimal tiling parameters for the given input shapes.
- Parameters
inv – Invocation object with arguments and metadata
- Returns
Iterable of candidate configurations to test during autotuning
Note
The autotuning system will benchmark each candidate and select the fastest one for the given input configuration.
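A minimal sketch of such a benchmark-and-select loop in pure Python. pick_fastest, fake_run and the dict-shaped configurations are hypothetical stand-ins, not ejkernel's autotuner or FlashAttentionConfig. Skipping candidates that raise is one design choice: a tile that does not fit on the current device simply drops out instead of aborting the search.

```python
import time
from typing import Callable, Iterable, Optional, TypeVar

Cfg = TypeVar("Cfg")


def pick_fastest(candidates: Iterable[Cfg], run: Callable[[Cfg], object],
                 warmup: int = 1, iters: int = 3) -> Cfg:
    """Benchmark each candidate and return the one with the lowest mean time.

    `run(cfg)` must block until the work is done; for JAX outputs call
    `.block_until_ready()`, because JAX dispatch is asynchronous.
    """
    best_cfg: Optional[Cfg] = None
    best_time = float("inf")
    for cfg in candidates:
        try:
            for _ in range(warmup):        # keep compilation out of the timing
                run(cfg)
            start = time.perf_counter()
            for _ in range(iters):
                run(cfg)
            elapsed = (time.perf_counter() - start) / iters
        except Exception:                  # e.g. a tile that exceeds on-chip memory
            continue
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    if best_cfg is None:
        raise ValueError("no candidate configuration ran successfully")
    return best_cfg


def fake_run(cfg):
    """Stand-in workload: larger blocks 'cost' more here, purely for the demo."""
    if cfg["block_q"] > 128:
        raise RuntimeError("simulated out-of-resources failure")
    time.sleep(0.01 * cfg["block_q"] / 64)


best = pick_fastest([{"block_q": 128}, {"block_q": 64}, {"block_q": 256}], fake_run)
print(best)
```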
- candidate_cfgs_gpu(inv: Invocation[FlashAttentionConfig, Array])
Generate GPU-optimized candidate configurations for autotuning (Triton).
Heuristics:
  - q/kv blocks adapt to head_dim and sequence lengths.
  - If sliding_window is set, kv blocks are capped near the window span.
  - num_warps: 2-8 based on head_dim and block sizes.
  - num_stages: 2-3 (kept low to reduce SMEM pressure).
  - Conservative shared-memory guard to avoid CUDA errors.
  - Backward blocks are smaller to reduce register pressure.
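The sliding-window cap and the shared-memory guard can be sketched as a pure-Python filter over block sizes. gpu_block_candidates is hypothetical, and the 96 KiB budget and the SMEM estimate (one Q tile plus num_stages buffered K and V tiles) are illustrative assumptions rather than ejkernel's exact rules; num_warps selection and backward blocks are left out.

```python
from typing import List, Optional, Tuple, Union


def next_pow2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def gpu_block_candidates(
    head_dim: int,
    seq_len_q: int,
    seq_len_k: int,
    sliding_window: Optional[Union[int, Tuple[int, int]]] = None,
    smem_limit_bytes: int = 96 * 1024,   # assumed budget, not a device query
    bytes_per_elem: int = 2,             # fp16/bf16 tiles
    num_stages: int = 2,
) -> List[Tuple[int, int]]:
    """Enumerate (block_q, block_k) pairs under sequence, window and SMEM caps."""
    # Do not tile past the sequence length (rounded up to a power of two).
    max_q = max(32, next_pow2(seq_len_q))
    max_k = max(32, next_pow2(seq_len_k))
    if sliding_window is not None:
        # Assumption: an int window means the same extent on both sides.
        left, right = ((sliding_window, sliding_window)
                       if isinstance(sliding_window, int) else sliding_window)
        max_k = min(max_k, max(32, next_pow2(left + right + 1)))
    candidates = []
    for block_q in (32, 64, 128):
        for block_k in (32, 64, 128, 256):
            if block_q > max_q or block_k > max_k:
                continue
            # Rough SMEM estimate: one Q tile plus num_stages buffered K and V tiles.
            smem = (block_q + 2 * num_stages * block_k) * head_dim * bytes_per_elem
            if smem <= smem_limit_bytes:
                candidates.append((block_q, block_k))
    return candidates


print(gpu_block_candidates(64, 1024, 1024))
print(gpu_block_candidates(64, 1024, 1024, sliding_window=(31, 0)))
```

With head_dim=64 the guard rejects every block_k=256 tile, and a 32-token window caps kv blocks at 32.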
- candidate_cfgs_shard_map_gpu(inv: Invocation[FlashAttentionConfig, Array])
Generate GPU-optimized candidate configurations for autotuning (Triton).
Heuristics: same as candidate_cfgs_gpu().
- candidate_cfgs_shard_map_tpu(inv: Invocation[FlashAttentionConfig, Array])
Generate TPU-optimized candidate configurations for autotuning (Pallas).
Heuristics: same as candidate_cfgs_tpu().
- candidate_cfgs_shard_map_xla(inv: Invocation[FlashAttentionConfig, Array])
Generate XLA-optimized candidate configurations for autotuning.
Heuristics: same as candidate_cfgs_xla().
- candidate_cfgs_tpu(inv: Invocation[FlashAttentionConfig, Array])
Generate TPU-optimized candidate configurations for autotuning (Pallas).
Heuristics:
  - Favor moderate Q blocks (32-128) and KV blocks (64-256/512).
  - If sliding_window is set, prefer kv blocks ≲ window span.
  - Slightly smaller backward blocks to reduce VMEM and register pressure.
  - Keep the candidate list compact and ordered for fast convergence.
- candidate_cfgs_xla(inv: Invocation[FlashAttentionConfig, Array])
Generate XLA-optimized candidate configurations for autotuning.
Heuristics:
  - Medium blocks (128-256) tend to be robust.
  - If sliding_window is set, keep kv blocks near the window span.
  - Backward tiles are smaller.
  - Keep the list small and ordered by likely winners.
- create_shard_map_wrapper(query: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim'], key: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'], value: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'], attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len_q seq_len_k'] | jaxtyping.Int[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len_q seq_len_k'] | None = None, bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, softmax_scale: float | None = None, dropout_prob: float = 0.0, causal: bool = False, dropout_seed: int | None = None, cum_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, cum_seqlens_k: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, sliding_window: int | tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_heads num_sinks'] | jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, normalize_output: bool = True, precision: ~typing.Union[None, str, ~jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], ~jax._src.lax.lax.DotAlgorithm, ~jax._src.lax.lax.DotAlgorithmPreset] = Precision.DEFAULT, logits_dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.float32'>, platform: ~typing.Optional[~typing.Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, q_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len_q'] | None = None, kv_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len_k'] | None = None, cfg: ejkernel.modules.operations.configs.FlashAttentionConfig | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None, check_vma: bool = False)
Create a shard_map wrapper specifically for flash attention.
- Parameters
query – Query tensor to be sharded
key – Key tensor to be sharded
value – Value tensor to be sharded
mesh – JAX device mesh
in_specs – Input partition specs (for q, k, v, and optionally mask/bias)
out_specs – Output partition spec
All other args – Flash attention parameters fixed via functools.partial
- Returns
Tuple of (shard_map_fn, call_args)
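The pattern of fixing the non-array arguments with a partial and sharding only the tensors can be sketched with JAX's public shard_map. dense_attention is a hypothetical stand-in for the flash kernel, the mesh axis name "dp" is an assumption, and unlike ejkernel's wrapper this sketch returns only the sharded function rather than a (shard_map_fn, call_args) tuple. The shard_map import location differs across JAX releases, hence the fallback.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # recent JAX releases export shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map


def dense_attention(q, k, v, softmax_scale):
    """Hypothetical stand-in kernel on [batch, seq, heads, head_dim] shards."""
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * softmax_scale
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


# Fix the scalar argument with partial, then shard q/k/v over the batch axis.
mesh = Mesh(np.array(jax.devices()[:1]), ("dp",))
spec = P("dp", None, None, None)
sharded_attention = shard_map(
    functools.partial(dense_attention, softmax_scale=0.125),  # 1/sqrt(64)
    mesh=mesh, in_specs=(spec, spec, spec), out_specs=spec,
)

x = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 4, 64))
out = sharded_attention(x, x, x)
```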
- get_impl(cfg: FlashAttentionConfig)
Get kernel implementation from registry based on configuration.
- Parameters
cfg – Configuration specifying platform and backend
- Returns
Callable kernel implementation
- Raises
ValueError – If no matching implementation is found
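A registry lookup with this contract can be sketched as a dictionary keyed by operation and platform. register, lookup_impl and the string keys are hypothetical, not ejkernel's internals; the documentation above says the configuration specifies both a platform and a backend, while this sketch keys on the platform only.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry keyed by (operation, platform).
_REGISTRY: Dict[Tuple[str, str], Callable] = {}


def register(name: str, platform: str) -> Callable[[Callable], Callable]:
    """Decorator that records `fn` as the implementation for (name, platform)."""
    def decorator(fn: Callable) -> Callable:
        _REGISTRY[(name, platform)] = fn
        return fn
    return decorator


def lookup_impl(name: str, platform: str) -> Callable:
    """Return the registered implementation, or raise ValueError if none matches."""
    try:
        return _REGISTRY[(name, platform)]
    except KeyError:
        available = sorted(p for n, p in _REGISTRY if n == name)
        raise ValueError(
            f"no {name!r} implementation for platform {platform!r}; available: {available}"
        ) from None


@register("flash_attention", "xla")
def flash_attention_xla(*args, **kwargs):
    return "xla implementation"


impl = lookup_impl("flash_attention", "xla")
```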
- heuristic_cfg(inv: Invocation[FlashAttentionConfig, Array]) → FlashAttentionConfig
Provide default configuration based on invocation context.
Selects default block sizes heuristically from the sequence length and head dimension.
- Parameters
inv – Invocation object with arguments and metadata
- Returns
Default configuration with block sizes
- heuristic_cfg_gpu(inv: Invocation[FlashAttentionConfig, Array]) → FlashAttentionConfig
Provide the default GPU configuration based on invocation context.
Selects default block sizes heuristically from the sequence length and head dimension.
- Parameters
inv – Invocation object with arguments and metadata
- Returns
Default configuration with block sizes
- heuristic_cfg_tpu(inv: Invocation[FlashAttentionConfig, Array]) → FlashAttentionConfig
Provide the default TPU configuration based on invocation context.
Selects default block sizes heuristically from the sequence length and head dimension.
- Parameters
inv – Invocation object with arguments and metadata
- Returns
Default configuration with block sizes
- run(query: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim'], key: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'], value: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'], attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len_q seq_len_k'] | jaxtyping.Int[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len_q seq_len_k'] | None = None, bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, softmax_scale: float | None = None, dropout_prob: float = 0.0, causal: bool = False, dropout_seed: int | None = None, cum_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, cum_seqlens_k: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, sliding_window: int | tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_heads num_sinks'] | jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, normalize_output: bool = True, precision: ~typing.Union[None, str, ~jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], ~jax._src.lax.lax.DotAlgorithm, ~jax._src.lax.lax.DotAlgorithmPreset] = Precision.DEFAULT, logits_dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.float32'>, platform: ~typing.Optional[~typing.Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, q_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len_q'] | None = None, kv_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len_k'] | None = None, cfg: ~ejkernel.modules.operations.configs.FlashAttentionConfig) → Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim']
Execute flash attention with the given configuration.
- Parameters
query – Query tensor [batch, seq_len_q, num_heads, head_dim]
key – Key tensor [batch, seq_len_k, num_kv_heads, head_dim]
value – Value tensor [batch, seq_len_k, num_kv_heads, head_dim]
attention_mask – Optional attention mask (legacy, prefer bias)
bias – Optional attention bias tensor
softmax_scale – Scaling factor for attention scores
dropout_prob – Dropout probability for attention weights
causal – Whether to apply causal masking
dropout_seed – Random seed for dropout
cum_seqlens_q – Cumulative sequence lengths for variable-length queries
cum_seqlens_k – Cumulative sequence lengths for variable-length keys
sliding_window – Window size for local attention (int or (left, right) tuple)
logits_soft_cap – Optional soft cap value for logits
softmax_aux – Optional attention sink logits
platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Configuration object specifying platform/backend (including block sizes)
q_segment_ids – Segment IDs of the queries for packed (grouped) sequences (TPU-specific)
kv_segment_ids – Segment IDs of the keys/values for packed (grouped) sequences (TPU-specific)
- Returns
Attention output [batch, seq_len_q, num_heads, head_dim]
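How causal, sliding_window and the segment IDs combine can be illustrated with an explicit boolean mask for one unbatched sequence. visibility_mask is hypothetical, and its window convention (an int w treated as (w, w), a (left, right) pair letting query i see keys i - left through i + right) is an assumption; fused kernels typically skip fully masked blocks rather than building this matrix.

```python
import jax.numpy as jnp


def visibility_mask(seq_len_q, seq_len_k, causal=False, sliding_window=None,
                    q_segment_ids=None, kv_segment_ids=None):
    """Boolean [seq_len_q, seq_len_k] mask; True where query i may attend to key j."""
    i = jnp.arange(seq_len_q)[:, None]   # query positions (column vector)
    j = jnp.arange(seq_len_k)[None, :]   # key positions (row vector)
    mask = jnp.ones((seq_len_q, seq_len_k), dtype=bool)
    if causal:
        mask = mask & (j <= i)
    if sliding_window is not None:
        # Assumed convention: int w -> (w, w); (left, right) -> keys i-left .. i+right.
        left, right = ((sliding_window, sliding_window)
                       if isinstance(sliding_window, int) else sliding_window)
        mask = mask & (j >= i - left) & (j <= i + right)
    if q_segment_ids is not None and kv_segment_ids is not None:
        # Packed sequences: tokens only attend within their own segment.
        mask = mask & (q_segment_ids[:, None] == kv_segment_ids[None, :])
    return mask


local = visibility_mask(6, 6, causal=True, sliding_window=(2, 0))
segments = jnp.array([0, 0, 0, 1, 1, 1])
packed = visibility_mask(6, 6, causal=True,
                         q_segment_ids=segments, kv_segment_ids=segments)
print(local[4].tolist())   # [False, False, True, True, True, False]
```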
- version: str = '1'
- ejkernel.modules.operations.flash_attention.flash_attention(query: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim'], key: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'], value: ~jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len_k num_kv_heads head_dim'], bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, cum_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, cum_seqlens_k: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_heads num_sinks'] | jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, /, *, mask_info: ejkernel.types.mask.MaskInfo | None = None, softmax_scale: float | None = None, dropout_prob: float = 0.0, causal: bool = False, dropout_seed: int | None = None, sliding_window: int | tuple[int, int] | None = None, logits_soft_cap: float | None = None, normalize_output: bool = True, precision: ~typing.Union[None, str, ~jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], ~jax._src.lax.lax.DotAlgorithm, ~jax._src.lax.lax.DotAlgorithmPreset] = Precision.DEFAULT, logits_dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.float32'>, platform: ~typing.Optional[~typing.Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.FlashAttentionConfig | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec | None, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None) → Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim']
Execute flash attention with automatic optimization.
Convenience function that uses a default executor and flash attention module.
- Parameters
query – Query tensor [batch, seq_len_q, num_heads, head_dim]
key – Key tensor [batch, seq_len_k, num_kv_heads, head_dim]
value – Value tensor [batch, seq_len_k, num_kv_heads, head_dim]
mask_info – Optional MaskInfo containing attention mask and/or segment IDs
bias – Optional attention bias tensor
softmax_scale – Scaling factor for attention scores (default: 1/sqrt(head_dim))
dropout_prob – Dropout probability for attention weights
causal – Whether to apply causal masking
dropout_seed – Random seed for dropout
cum_seqlens_q – Cumulative sequence lengths for variable-length queries
cum_seqlens_k – Cumulative sequence lengths for variable-length keys
sliding_window – Window size for local attention (int or (left, right) tuple)
logits_soft_cap – Optional soft cap value for logits
softmax_aux – Optional attention sink logits
platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Optional configuration override
mesh – JAX device mesh for shard_map execution (optional)
in_specs – Input partition specs for shard_map (optional)
out_specs – Output partition spec for shard_map (optional)
- Returns
Attention output with same shape as query
Example
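>>> # Causal attention
>>> out = flash_attention(query, key, value, causal=True)
>>>
>>> # Attention dropout with an explicit softmax scale
>>> out = flash_attention(query, key, value, dropout_prob=0.1, softmax_scale=0.125)
>>>
>>> # Variable-length sequences: bias, cum_seqlens_q and cum_seqlens_k are positional-only
>>> out = flash_attention(query, key, value, None, cu_q, cu_k)
>>>
>>> # Force a specific backend
>>> out = flash_attention(query, key, value, platform="triton")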
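When checking flash_attention outputs (for example across platforms), a dense reference implementation is useful. A minimal sketch, assuming num_heads is a multiple of num_kv_heads with each K/V head serving consecutive query heads, and that logits_soft_cap uses the common cap * tanh(logits / cap) form applied after scaling; reference_attention is hypothetical and covers only causal masking, grouped heads and soft capping (no bias, dropout, sinks or variable lengths).

```python
import jax
import jax.numpy as jnp


def reference_attention(query, key, value, causal=False, softmax_scale=None,
                        logits_soft_cap=None):
    """Dense attention on [batch, seq, heads, head_dim] tensors, for testing."""
    group = query.shape[2] // key.shape[2]
    # GQA/MQA: assume each K/V head serves `group` consecutive query heads.
    key = jnp.repeat(key, group, axis=2)
    value = jnp.repeat(value, group, axis=2)
    scale = softmax_scale if softmax_scale is not None else query.shape[-1] ** -0.5
    logits = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale
    if logits_soft_cap is not None:
        # Assumed soft-capping form: cap * tanh(logits / cap).
        logits = logits_soft_cap * jnp.tanh(logits / logits_soft_cap)
    if causal:
        keep = jnp.tril(jnp.ones((query.shape[1], key.shape[1]), dtype=bool))
        logits = jnp.where(keep, logits, -jnp.inf)
    weights = jax.nn.softmax(logits, axis=-1).astype(value.dtype)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, value)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 32, 8, 64))
k = jax.random.normal(kk, (2, 32, 2, 64))   # 2 KV heads shared by 8 query heads
v = jax.random.normal(kv, (2, 32, 2, 64))
ref = reference_attention(q, k, v, causal=True, logits_soft_cap=30.0)
# Compare with the fused path using a dtype-appropriate tolerance, e.g.:
# out = flash_attention(q, k, v, causal=True, logits_soft_cap=30.0)
# assert jnp.allclose(out, ref, atol=1e-2)
```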