# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flash Attention module with automatic optimization.
This module implements Flash Attention, a memory-efficient attention mechanism
that uses tiling and recomputation to achieve O(N) memory complexity instead
of the standard O(N²) for sequence length N.
Key features of Flash Attention:
- Memory-efficient: Uses tiling to process attention in blocks
- IO-aware: Minimizes HBM (high bandwidth memory) accesses
- Exact: Produces numerically identical results to standard attention
- Fast: Often faster than standard attention despite recomputation
The algorithm works by:
1. Splitting Q, K, V into blocks along sequence dimension
2. Computing attention block-by-block with on-the-fly softmax
3. Using online softmax correction for numerical stability
4. Fusing operations to minimize memory transfers
Supports:
- Causal and non-causal masking
- Variable sequence lengths via cumulative sequence lengths
- Dropout (during training)
- Sliding window attention
- Multi-query and grouped-query attention patterns
- Attention biasing and soft capping
Mathematical formulation:
Standard: Attention(Q,K,V) = softmax(QK^T/√d)V
Flash: Same output, but computed in O(N) memory via tiling
Reference:
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
(Dao et al., 2022) https://arxiv.org/abs/2205.14135
"""
from __future__ import annotations
import math
import os
from typing import Literal
from jax import lax, shard_map
from jax import numpy as jnp
from jax.sharding import Mesh, PartitionSpec
from jaxtyping import Array, Bool, DTypeLike, Float, Int
from ejkernel.kernels._registry import Backend, kernel_registry
from ejkernel.ops import (
AutotunePolicy,
BwdParams,
ConfigCache,
ConfigSelectorChain,
Executor,
FwdParams,
Invocation,
Kernel,
Tuner,
)
from ejkernel.ops.config.persistent import PersistentCache
from ejkernel.types.mask import MaskInfo
from ..base import detect_platform
from .configs import FlashAttentionConfig
class FlashAttention(Kernel[FlashAttentionConfig, Array]):
    """Flash Attention kernel module with per-platform configuration logic.

    The class looks up a ``"flash_attention"`` implementation in the kernel
    registry, forwards all attention arguments to it, and supplies default
    (heuristic) and autotuning (candidate) configurations for the GPU/Triton,
    TPU/Pallas and XLA targets.

    Supported call options (all forwarded to the selected implementation):
        - causal masking, attention mask and additive bias
        - dropout
        - variable-length sequences via cumulative sequence lengths
        - packed sequences via query/key segment ids
        - sliding-window attention
        - logits soft capping and attention-sink logits (``softmax_aux``)

    Example:
        >>> from ejkernel.modules import FlashAttention, create_default_executor
        >>>
        >>> executor = create_default_executor()
        >>> attn = FlashAttention()
        >>>
        >>> # Causal attention with an explicit softmax scale.
        >>> output = executor(attn, query, key, value, causal=True, softmax_scale=0.125)
        >>>
        >>> # Variable-length sequences.
        >>> output = executor(
        ...     attn, query, key, value,
        ...     cum_seqlens_q=cu_seqlens_q,
        ...     cum_seqlens_k=cu_seqlens_k
        ... )
        >>>
        >>> # Sliding-window (local) attention.
        >>> output = executor(attn, query, key, value, sliding_window=(256, 256))
    """

    # Bump for persistent-cache invalidation: packed (segment-id) support and
    # Triton kernel parameter changes.
    version = "1"

    def __init__(self):
        """Initialize Flash Attention module."""
        super().__init__(op_id="flash_attention")

    def create_shard_map_wrapper(
        self,
        query: Float[Array, "batch seq_len_q num_heads head_dim"],
        key: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
        value: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
        attention_mask: Bool[Array, "batch num_heads_or_1 seq_len_q seq_len_k"]
        | Int[Array, "batch num_heads_or_1 seq_len_q seq_len_k"]
        | None = None,
        bias: Float[Array, "batch num_heads seq_len_q seq_len_k"] | None = None,
        softmax_scale: float | None = None,
        dropout_prob: float = 0.0,
        causal: bool = False,
        dropout_seed: int | None = None,
        cum_seqlens_q: Int[Array, "batch_plus_one"] | None = None,
        cum_seqlens_k: Int[Array, "batch_plus_one"] | None = None,
        sliding_window: int | tuple[int, int] | None = None,
        logits_soft_cap: float | None = None,
        softmax_aux: Float[Array, "num_heads num_sinks"] | Float[Array, "num_sinks"] | None = None,
        normalize_output: bool = True,
        precision: lax.PrecisionLike = lax.Precision.DEFAULT,
        logits_dtype: DTypeLike = jnp.float32,
        platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
        q_segment_ids: Int[Array, "batch seq_len_q"] | None = None,
        kv_segment_ids: Int[Array, "batch seq_len_k"] | None = None,
        cfg: FlashAttentionConfig | None = None,
        mesh: Mesh | None = None,
        in_specs: tuple[PartitionSpec, ...] | None = None,
        out_specs: PartitionSpec | None = None,
        check_vma: bool = False,
    ):
        """Create a ``shard_map`` wrapper specifically for flash attention.

        Array-valued arguments are passed through ``shard_map`` (and therefore
        need a matching entry in ``in_specs``); every other argument is captured
        by closure and stays fixed for the wrapped function.

        Args:
            query, key, value: Input tensors to be sharded.
            attention_mask, bias, softmax_aux, cum_seqlens_q, cum_seqlens_k,
                q_segment_ids, kv_segment_ids: Optional arrays, also sharded.
            mesh: JAX device mesh (required).
            in_specs: One partition spec per sharded argument, in the order
                ``(query, key, value, bias, softmax_aux, cum_seqlens_q,
                cum_seqlens_k, attention_mask, q_segment_ids, kv_segment_ids)``
                (required).
            out_specs: Output partition spec (required).
            check_vma: Forwarded to ``shard_map``.
            All other args: Flash attention parameters closed over by the
                wrapped function and forwarded to :meth:`run`.

        Returns:
            Tuple of ``(shard_map_fn, call_args)`` where ``call_args`` is the
            positional argument tuple to invoke ``shard_map_fn`` with.

        Raises:
            AssertionError: If ``mesh``, ``in_specs`` or ``out_specs`` is
                ``None``, or ``in_specs`` does not have one entry per sharded
                argument.
        """
        assert mesh is not None, "mesh must be provided for shard_map execution"
        assert in_specs is not None, "in_specs must be provided for shard_map execution"
        assert out_specs is not None, "out_specs must be provided for shard_map execution"

        def _wrapped_flash_attn(
            query: Float[Array, "batch seq_len_q num_heads head_dim"],
            key: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
            value: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
            bias: Float[Array, "batch num_heads seq_len_q seq_len_k"] | None = None,
            softmax_aux: Float[Array, "num_heads num_sinks"] | Float[Array, "num_sinks"] | None = None,
            cum_seqlens_q: Int[Array, "batch_plus_one"] | None = None,
            cum_seqlens_k: Int[Array, "batch_plus_one"] | None = None,
            attention_mask: Int[Array, "batch num_heads_or_1 seq_len_q seq_len_k"] | None = None,
            q_segment_ids: Int[Array, "batch seq_len_q"] | None = None,
            kv_segment_ids: Int[Array, "batch seq_len_k"] | None = None,
        ) -> Float[Array, "batch seq_len_q num_heads head_dim"]:
            # Per-shard arrays come in as arguments; scalar/static options are
            # taken from the enclosing call.
            return self.run(
                query=query,
                key=key,
                value=value,
                bias=bias,
                softmax_aux=softmax_aux,
                cum_seqlens_k=cum_seqlens_k,
                cum_seqlens_q=cum_seqlens_q,
                attention_mask=attention_mask,
                softmax_scale=softmax_scale,
                dropout_prob=dropout_prob,
                causal=causal,
                dropout_seed=dropout_seed,
                sliding_window=sliding_window,
                logits_soft_cap=logits_soft_cap,
                normalize_output=normalize_output,
                precision=precision,
                logits_dtype=logits_dtype,
                kv_segment_ids=kv_segment_ids,
                q_segment_ids=q_segment_ids,
                platform=platform,
                cfg=cfg,
            )

        # Order must match the positional parameters of ``_wrapped_flash_attn``.
        call_args = (
            query,
            key,
            value,
            bias,
            softmax_aux,
            cum_seqlens_q,
            cum_seqlens_k,
            attention_mask,
            q_segment_ids,
            kv_segment_ids,
        )
        assert len(in_specs) == len(call_args), f"in_specs length {len(in_specs)} != call_args length {len(call_args)}"

        shard_map_fn = shard_map(
            _wrapped_flash_attn,
            mesh=mesh,
            in_specs=in_specs,
            out_specs=out_specs,
            check_vma=check_vma,
        )
        return shard_map_fn, call_args

    def get_impl(self, cfg: FlashAttentionConfig):
        """Get kernel implementation from registry based on configuration.

        Args:
            cfg: Configuration specifying platform and backend.

        Returns:
            Callable kernel implementation registered for ``"flash_attention"``
            on the platform resolved by ``detect_platform`` and ``cfg.backend``.

        Raises:
            Whatever ``kernel_registry.get`` raises when nothing matches
            (documented upstream as ``ValueError`` -- not verifiable from this
            file).
        """
        return kernel_registry.get(
            algorithm="flash_attention",
            platform=detect_platform("flash_attention", cfg.platform),
            backend=cfg.backend,
        )

    def run(
        self,
        query: Float[Array, "batch seq_len_q num_heads head_dim"],
        key: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
        value: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
        attention_mask: Bool[Array, "batch num_heads_or_1 seq_len_q seq_len_k"]
        | Int[Array, "batch num_heads_or_1 seq_len_q seq_len_k"]
        | None = None,
        bias: Float[Array, "batch num_heads seq_len_q seq_len_k"] | None = None,
        softmax_scale: float | None = None,
        dropout_prob: float = 0.0,
        causal: bool = False,
        dropout_seed: int | None = None,
        cum_seqlens_q: Int[Array, "batch_plus_one"] | None = None,
        cum_seqlens_k: Int[Array, "batch_plus_one"] | None = None,
        sliding_window: int | tuple[int, int] | None = None,
        logits_soft_cap: float | None = None,
        softmax_aux: Float[Array, "num_heads num_sinks"] | Float[Array, "num_sinks"] | None = None,
        normalize_output: bool = True,
        precision: lax.PrecisionLike = lax.Precision.DEFAULT,
        logits_dtype: DTypeLike = jnp.float32,
        platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
        *,
        q_segment_ids: Int[Array, "batch seq_len_q"] | None = None,
        kv_segment_ids: Int[Array, "batch seq_len_k"] | None = None,
        cfg: FlashAttentionConfig,
    ) -> Float[Array, "batch seq_len_q num_heads head_dim"]:
        """Execute flash attention with the given configuration.

        Args:
            query: Query tensor [batch, seq_len_q, num_heads, head_dim]
            key: Key tensor [batch, seq_len_k, num_kv_heads, head_dim]
            value: Value tensor [batch, seq_len_k, num_kv_heads, head_dim]
            attention_mask: Optional attention mask (legacy, prefer bias)
            bias: Optional attention bias tensor
            softmax_scale: Scaling factor for attention scores
            dropout_prob: Dropout probability for attention weights
            causal: Whether to apply causal masking
            dropout_seed: Random seed for dropout
            cum_seqlens_q: Cumulative sequence lengths for variable-length queries
            cum_seqlens_k: Cumulative sequence lengths for variable-length keys
            sliding_window: Window size for local attention
            logits_soft_cap: Optional soft cap value for logits
            softmax_aux: Optional attention sink logits
            normalize_output: Forwarded unchanged to the kernel implementation
            precision: Forwarded unchanged to the kernel implementation
            logits_dtype: Forwarded unchanged to the kernel implementation
            platform: Specific platform to use ("triton", "pallas", "cuda",
                "xla" or "auto"). When given it overrides ``cfg.platform``;
                for "xla" the backend is additionally reset to ``Backend.ANY``.
            q_segment_ids: Optional per-token segment ids for the queries
            kv_segment_ids: Optional per-token segment ids for the keys/values
            cfg: Configuration object specifying platform/backend and the
                forward/backward kernel parameters

        Returns:
            Attention output [batch, seq_len_q, num_heads, head_dim]. When
            ``q_segment_ids`` is given, positions whose segment id is negative
            are set to zero.
        """
        if platform is not None:
            # Explicit platform request wins over the configured one; the
            # tuning parameters are carried over unchanged.
            cfg = FlashAttentionConfig(
                fwd_params=cfg.fwd_params,
                bwd_params=cfg.bwd_params,
                platform=platform,
                backend=Backend.ANY if platform == "xla" else cfg.backend,
            )

        impl = self.get_impl(cfg)
        out = impl(
            query=query,
            key=key,
            value=value,
            attention_mask=attention_mask,
            bias=bias,
            softmax_scale=softmax_scale,
            dropout_prob=dropout_prob,
            causal=causal,
            dropout_seed=dropout_seed,
            cum_seqlens_q=cum_seqlens_q,
            cum_seqlens_k=cum_seqlens_k,
            sliding_window=sliding_window,
            logits_soft_cap=logits_soft_cap,
            softmax_aux=softmax_aux,
            normalize_output=normalize_output,
            precision=precision,
            logits_dtype=logits_dtype,
            q_segment_ids=q_segment_ids,
            kv_segment_ids=kv_segment_ids,
            fwd_params=cfg.fwd_params,
            bwd_params=cfg.bwd_params,
        )

        if q_segment_ids is not None:
            # Negative query segment ids are treated as invalid (presumably
            # padding -- confirm with the MaskInfo producer); zero those rows.
            q_valid = q_segment_ids >= 0
            out = jnp.where(q_valid[:, :, None, None], out, 0)
        return out

    def heuristic_cfg_gpu(self, inv: Invocation[FlashAttentionConfig, Array]) -> FlashAttentionConfig:
        """Provide the default Triton/GPU configuration for an invocation.

        The forward KV block is 64 when segment ids are used or
        ``head_dim >= 128`` and 128 otherwise; all other parameters are fixed.

        Args:
            inv: Invocation object; ``inv.kwargs`` must contain ``"query"``.

        Returns:
            Default configuration with block sizes.
        """
        head_dim = int(inv.kwargs["query"].shape[-1])
        use_segments = (inv.kwargs.get("q_segment_ids") is not None) or (inv.kwargs.get("kv_segment_ids") is not None)
        # Conservative defaults to avoid SMEM launch failures on GPUs with ~99KiB limit.
        kv_block = 64 if (use_segments or head_dim >= 128) else 128
        return FlashAttentionConfig(
            fwd_params=FwdParams(q_blocksize=64, kv_blocksize=kv_block, num_warps=4, num_stages=2),
            bwd_params=BwdParams(
                q_blocksize=32,
                kv_blocksize=32,
                num_warps=4,
                num_stages=2,
            ),
            platform="triton",
            backend="gpu",
        )

    def heuristic_cfg_tpu(self, inv: Invocation[FlashAttentionConfig, Array]) -> FlashAttentionConfig:
        """Provide the default Pallas/TPU configuration.

        The result is a fixed 128x128 tiling for forward and backward; ``inv``
        is accepted for interface compatibility but not inspected.

        Args:
            inv: Invocation object with arguments and metadata (unused).

        Returns:
            Default configuration with block sizes.
        """
        return FlashAttentionConfig(
            fwd_params=FwdParams(
                q_blocksize=128,
                kv_blocksize=128,
                num_warps=None,
                num_stages=None,
            ),
            bwd_params=BwdParams(
                q_blocksize=128,
                kv_blocksize=128,
                num_warps=None,
                num_stages=None,
            ),
            platform="pallas",
            backend="tpu",
        )

    def heuristic_cfg(self, inv: Invocation[FlashAttentionConfig, Array]) -> FlashAttentionConfig:
        """Provide the platform-agnostic default configuration.

        The result is a fixed 128x128 tiling with ``platform="auto"`` and
        ``backend="any"``; ``inv`` is accepted for interface compatibility but
        not inspected.

        Args:
            inv: Invocation object with arguments and metadata (unused).

        Returns:
            Default configuration with block sizes.
        """
        return FlashAttentionConfig(
            fwd_params=FwdParams(
                q_blocksize=128,
                kv_blocksize=128,
                num_warps=None,
                num_stages=None,
            ),
            bwd_params=BwdParams(
                q_blocksize=128,
                kv_blocksize=128,
                num_warps=None,
                num_stages=None,
            ),
            platform="auto",
            backend="any",
        )

    def candidate_cfgs(self, inv: Invocation[FlashAttentionConfig, Array]):
        """Generate candidate configurations for autotuning.

        Produces one configuration per forward block-size pair in
        ``{128, 256} x {128, 256}``; backward blocks are half the forward ones.

        Args:
            inv: Invocation object with arguments and metadata (unused).

        Returns:
            List of candidate configurations to test during autotuning.
        """
        block_configs = [
            (128, 128),
            (128, 256),
            (256, 128),
            (256, 256),
        ]
        return [
            FlashAttentionConfig(
                fwd_params=FwdParams(q_blocksize=chunk_q, kv_blocksize=chunk_k, num_warps=4, num_stages=2),
                bwd_params=BwdParams(q_blocksize=chunk_q // 2, kv_blocksize=chunk_k // 2, num_warps=4, num_stages=2),
                platform="auto",
                backend="any",
            )
            for chunk_q, chunk_k in block_configs
        ]

    @staticmethod
    def _sliding_window_span(sliding_window, causal: bool):
        """Return the total number of keys visible through ``sliding_window``.

        An int window ``w`` spans ``w + 1`` keys when causal and ``2 * w + 1``
        otherwise; a ``(left, right)`` tuple spans ``left + right + 1``.
        Returns ``None`` when no window is set.
        """
        if sliding_window is None:
            return None
        if isinstance(sliding_window, int):
            right = 0 if causal else sliding_window
            return sliding_window + right + 1
        left, right = sliding_window
        return left + right + 1

    @staticmethod
    def _round_to_128(x: int | float) -> int:
        """Round ``x`` to the nearest multiple of 128, never below 128."""
        return 128 * max(1, round(float(x) / 128.0))

    def _window_kv_target(self, inv: Invocation[FlashAttentionConfig, Array]):
        """Return the window-derived KV block target in [128, 512], or ``None`` without a window."""
        span = self._sliding_window_span(
            inv.kwargs.get("sliding_window", None),
            # Default mirrors ``run``'s ``causal=False``.
            bool(inv.kwargs.get("causal", False)),
        )
        if span is None:
            return None
        return max(128, min(512, self._round_to_128(span)))

    def _block_grid_candidates(
        self,
        inv: Invocation[FlashAttentionConfig, Array],
        kv_target,
        preferred: list[tuple[int, int]],
        *,
        limit: int,
        platform: str,
        backend: str,
    ) -> list[FlashAttentionConfig]:
        """Build 128-aligned candidates: ``preferred`` pairs first, then the remaining grid.

        Shared by the TPU and XLA candidate generators, which differ only in
        their preferred pairs, the candidate cap and the target platform.
        """
        q_len = int(inv.kwargs["query"].shape[1])
        k_len = int(inv.kwargs["key"].shape[1])

        q_opts = [128, 256]
        kv_opts = [128, 256, 512]
        if kv_target is not None:
            kv_opts = sorted({*kv_opts, kv_target, min(512, 2 * kv_target)})
        if q_len < 256:
            # NOTE(review): every q option is already <= 256, so this filter
            # currently changes nothing; kept to preserve behavior.
            q_opts = [x for x in q_opts if x <= 256] or [128]
        if k_len < 256:
            kv_opts = [x for x in kv_opts if x <= 256] or [128, 256]

        selected: list[tuple[int, int]] = []
        seen: set[tuple[int, int]] = set()
        for pair in preferred:
            qb, kb = pair
            if qb in q_opts and kb in kv_opts and pair not in seen:
                selected.append(pair)
                seen.add(pair)
        for qb in q_opts:
            for kb in kv_opts:
                if (qb, kb) not in seen:
                    selected.append((qb, kb))
                    seen.add((qb, kb))
                if len(selected) >= limit:
                    break
            if len(selected) >= limit:
                break

        # Backward tiles are fixed at 128 regardless of the forward tile.
        bwd_tile = 128
        return [
            FlashAttentionConfig(
                fwd_params=FwdParams(q_blocksize=qb, kv_blocksize=kb, num_warps=None, num_stages=None),
                bwd_params=BwdParams(q_blocksize=bwd_tile, kv_blocksize=bwd_tile, num_warps=None, num_stages=None),
                platform=platform,
                backend=backend,
            )
            for qb, kb in selected
        ]

    def candidate_cfgs_gpu(self, inv: Invocation[FlashAttentionConfig, Array]):
        """Generate GPU-optimized candidate configurations for autotuning (Triton).

        Heuristics:
            - q/kv blocks adapt to head_dim and sequence lengths.
            - If sliding_window is set, kv blocks are capped near the window span.
            - num_warps: 2-8 based on head_dim and block sizes.
            - num_stages: 2-3 (kept low to reduce SMEM pressure).
            - Conservative shared-memory guard to avoid CUDA errors; the limit
              defaults to 99 KiB and can be overridden with the
              ``EJKERNEL_TRITON_SMEM_LIMIT`` environment variable (bytes).
            - Backward blocks smaller to reduce register pressure.

        Args:
            inv: Invocation object; ``inv.kwargs`` must contain ``"query"`` and
                ``"key"`` and may contain ``"sliding_window"`` / ``"causal"``.

        Returns:
            List of at most 18 candidate configurations, preferred pairs first.
        """
        q = inv.kwargs["query"]
        k = inv.kwargs["key"]
        head_dim = int(q.shape[-1])
        q_len = int(q.shape[1])
        k_len = int(k.shape[1])
        win = self._sliding_window_span(
            inv.kwargs.get("sliding_window", None),
            # Default mirrors ``run``'s ``causal=False``.
            bool(inv.kwargs.get("causal", False)),
        )
        smem_limit = int(os.getenv("EJKERNEL_TRITON_SMEM_LIMIT", str(99 * 1024)))

        # Next power of two >= head_dim (minimum 16), computed exactly on ints.
        block_headdim = max(16, 1 << (max(1, head_dim) - 1).bit_length())
        elem_bytes = 2 if q.dtype in (jnp.float16, jnp.bfloat16) else 4

        def smem_est_bytes(qb: int, kb: int, num_stages: int) -> int:
            # Rough upper bound: K and V tiles plus a fraction of the Q tile,
            # scaled for extra pipeline stages and a safety factor.
            kv_bytes = 2 * kb * block_headdim * elem_bytes
            q_bytes = int(0.25 * qb * block_headdim * elem_bytes)
            base = kv_bytes + q_bytes
            stage_factor = 1.0 + 0.5 * max(0, num_stages - 2)
            fudge = 2.5
            return int(base * stage_factor * fudge)

        def pick_warps_stages(qb: int, kb: int, dh: int) -> tuple[int, int]:
            largest = max(qb, kb)
            if dh <= 64:
                warps = 2 if largest <= 64 else 4
            elif dh <= 128:
                warps = 4 if largest <= 128 else 8
            else:
                warps = 8 if largest >= 128 else 4
            stages = 3 if kb >= 128 else 2
            return warps, stages

        def bwd_block(x: int, cap: int = 128) -> int:
            return max(32, min(cap, x // 2 if x >= 64 else x))

        q_opts = [32, 64, 128] if head_dim <= 192 else [32, 64]
        if win is not None:
            # Largest power of two <= window span, clamped to [32, 256].
            target = max(32, min(256, 1 << (int(math.log2(max(32, win))) if win > 0 else 5)))
            kv_opts = sorted({32, 64, min(128, target), min(256, target)})
        else:
            kv_opts = [32, 64, 128, 256]
        if k_len < 128:
            kv_opts = [x for x in kv_opts if x <= 128] or [64, 128]
        if q_len < 128:
            # NOTE(review): every q option is already <= 128, so this filter
            # currently changes nothing; kept to preserve behavior.
            q_opts = [x for x in q_opts if x <= 128] or [64, 128]

        preferred = [(64, 64), (128, 64), (64, 128), (128, 128)]
        if win is not None:
            win_kb = min(128, max(64, win))
            preferred = [(32, win_kb), (64, win_kb), *preferred]
        hv_pairs = [(qb, kb) for qb, kb in preferred if qb in q_opts and kb in kv_opts]
        grid_pairs = [(qb, kb) for qb in q_opts for kb in kv_opts if (qb, kb) not in hv_pairs]

        max_candidates = 18
        pairs = []
        seen = set()
        for qb, kb in hv_pairs + grid_pairs:
            if (qb, kb) in seen:
                continue
            w, s = pick_warps_stages(qb, kb, head_dim)
            if smem_est_bytes(qb, kb, s) <= smem_limit:
                seen.add((qb, kb))
                pairs.append((qb, kb, w, s))
            if len(pairs) >= max_candidates:
                break

        if not pairs:
            # Nothing passed the shared-memory guard: fall back to 64x64.
            qb, kb = 64, 64
            w, s = pick_warps_stages(qb, kb, head_dim)
            pairs = [(qb, kb, w, s)]

        return [
            FlashAttentionConfig(
                fwd_params=FwdParams(q_blocksize=qb, kv_blocksize=kb, num_warps=w, num_stages=s),
                bwd_params=BwdParams(q_blocksize=bwd_block(qb), kv_blocksize=bwd_block(kb)),
                platform="triton",
                backend="gpu",
            )
            for qb, kb, w, s in pairs
        ]

    def candidate_cfgs_tpu(self, inv: Invocation[FlashAttentionConfig, Array]):
        """Generate TPU-optimized candidate configurations for autotuning (Pallas).

        Heuristics:
            - Q blocks of 128/256 and KV blocks of 128-512, all multiples of 128.
            - If sliding_window is set, window-sized KV blocks are tried first.
            - Backward tiles are fixed at 128.
            - Keep the candidate list compact (at most 16) and ordered for fast
              convergence.

        Args:
            inv: Invocation object; ``inv.kwargs`` must contain ``"query"`` and
                ``"key"`` and may contain ``"sliding_window"`` / ``"causal"``.

        Returns:
            List of candidate configurations with ``platform="pallas"``.
        """
        target = self._window_kv_target(inv)
        preferred: list[tuple[int, int]] = []
        if target is not None:
            preferred += [(128, target), (256, target), (128, min(512, 2 * target))]
        preferred += [(128, 128), (128, 256), (256, 256), (256, 512)]
        return self._block_grid_candidates(inv, target, preferred, limit=16, platform="pallas", backend="tpu")

    def candidate_cfgs_xla(self, inv: Invocation[FlashAttentionConfig, Array]):
        """Generate XLA-optimized candidate configurations for autotuning.

        Heuristics:
            - Medium blocks (128-256) are tried first.
            - If sliding_window is set, window-sized KV blocks are tried first.
            - Backward tiles are fixed at 128.
            - Keep list small (at most 12) and ordered by likely winners.

        Args:
            inv: Invocation object; ``inv.kwargs`` must contain ``"query"`` and
                ``"key"`` and may contain ``"sliding_window"`` / ``"causal"``.

        Returns:
            List of candidate configurations with ``platform="xla"``.
        """
        target = self._window_kv_target(inv)
        preferred: list[tuple[int, int]] = []
        if target is not None:
            preferred += [(128, target), (256, target)]
        preferred += [(128, 128), (128, 256), (256, 256), (256, 128)]
        return self._block_grid_candidates(inv, target, preferred, limit=12, platform="xla", backend="any")

    # shard_map execution reuses the per-platform candidate generators.
    candidate_cfgs_shard_map_gpu = candidate_cfgs_gpu
    candidate_cfgs_shard_map_tpu = candidate_cfgs_tpu
    candidate_cfgs_shard_map_xla = candidate_cfgs_xla
# Module-level executor used by `flash_attention`. It is built once at import
# time with an in-memory config cache, a persistent cache keyed "flash-attn",
# and autotuning enabled. The cache-miss behaviour is read from the
# EJKERNEL_AUTOTUNE_POLICY environment variable (default: "autotune").
_flash_executor: Executor[FlashAttentionConfig, Array] = Executor(
    ConfigSelectorChain(
        cache=ConfigCache(),
        policy=AutotunePolicy(
            allow_autotune=True,
            cache_miss_fallback=os.getenv("EJKERNEL_AUTOTUNE_POLICY", "autotune"),
            validate_backward=True,
        ),
        tuner=Tuner(warmup=5, iters=100),
        persistent=PersistentCache("flash-attn", cfg_type=FlashAttentionConfig),
    )
)
def flash_attention(
    query: Float[Array, "batch seq_len_q num_heads head_dim"],
    key: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
    value: Float[Array, "batch seq_len_k num_kv_heads head_dim"],
    bias: Float[Array, "batch num_heads seq_len_q seq_len_k"] | None = None,
    cum_seqlens_q: Int[Array, "batch_plus_one"] | None = None,
    cum_seqlens_k: Int[Array, "batch_plus_one"] | None = None,
    softmax_aux: Float[Array, "num_heads num_sinks"] | Float[Array, "num_sinks"] | None = None,
    /,
    *,
    mask_info: MaskInfo | None = None,
    softmax_scale: float | None = None,
    dropout_prob: float = 0.0,
    causal: bool = False,
    dropout_seed: int | None = None,
    sliding_window: int | tuple[int, int] | None = None,
    logits_soft_cap: float | None = None,
    normalize_output: bool = True,
    precision: lax.PrecisionLike = lax.Precision.DEFAULT,
    logits_dtype: DTypeLike = jnp.float32,
    platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
    cfg: FlashAttentionConfig | None = None,
    mesh: Mesh | None = None,
    in_specs: tuple[PartitionSpec | None, ...] | None = None,
    out_specs: PartitionSpec | None = None,
) -> Float[Array, "batch seq_len_q num_heads head_dim"]:
    """Execute flash attention with automatic optimization.

    Convenience function that uses the module-level executor and a fresh
    :class:`FlashAttention` module.

    The seven array arguments are positional-only; every option after them is
    keyword-only.

    Args:
        query: Query tensor [batch, seq_len_q, num_heads, head_dim]
        key: Key tensor [batch, seq_len_k, num_kv_heads, head_dim]
        value: Value tensor [batch, seq_len_k, num_kv_heads, head_dim]
        bias: Optional attention bias tensor
        cum_seqlens_q: Cumulative sequence lengths for variable-length queries
        cum_seqlens_k: Cumulative sequence lengths for variable-length keys
        softmax_aux: Optional attention sink logits
        mask_info: Optional MaskInfo containing attention mask and/or segment
            IDs. If it carries segment ids they are used; otherwise its
            attention mask is used (computed on demand when absent).
        softmax_scale: Scaling factor for attention scores (default: 1/sqrt(head_dim))
        dropout_prob: Dropout probability for attention weights
        causal: Whether to apply causal masking
        dropout_seed: Random seed for dropout
        sliding_window: Window size for local attention (int or (left, right) tuple)
        logits_soft_cap: Optional soft cap value for logits
        normalize_output: Forwarded unchanged to the kernel
        precision: Forwarded unchanged to the kernel
        logits_dtype: Forwarded unchanged to the kernel
        platform: Specific platform to use ("triton", "pallas", "cuda", "xla" or "auto")
        cfg: Optional configuration override
        mesh: JAX device mesh for shard_map execution (optional)
        in_specs: Input partition specs for shard_map (optional). Specs for the
            attention mask and the two segment-id arrays are appended here
            automatically, so callers supply specs for the remaining sharded
            inputs only (presumably the seven positional arrays -- confirm
            against ``FlashAttention.create_shard_map_wrapper``).
        out_specs: Output partition spec for shard_map (optional)

    Returns:
        Attention output with same shape as query

    Note:
        shard_map execution is used only when ``mesh``, ``in_specs`` and
        ``out_specs`` are all provided.

    Example:
        >>> # Causal attention
        >>> out = flash_attention(query, key, value, causal=True)
        >>>
        >>> # Dropout with an explicit softmax scale
        >>> out = flash_attention(query, key, value, dropout_prob=0.1, softmax_scale=0.125)
        >>>
        >>> # Variable-length sequences (positional-only: bias, cum_seqlens_q, cum_seqlens_k)
        >>> out = flash_attention(query, key, value, None, cu_q, cu_k)
        >>>
        >>> # Force a specific platform
        >>> out = flash_attention(query, key, value, platform="triton")
    """
    attention_mask = None
    q_segment_ids = None
    kv_segment_ids = None
    if mask_info is not None:
        attention_mask = mask_info._attention_mask
        if mask_info._q_segment_ids is not None or mask_info._kv_segment_ids is not None:
            q_segment_ids, kv_segment_ids = mask_info.get_or_compute_segment_ids(per_head=False)
        elif attention_mask is None:
            # No segment ids and no stored mask: materialize the mask.
            attention_mask = mask_info.get_or_compute_attention_mask()

    method = None
    if mesh is not None and in_specs is not None and out_specs is not None:
        method = "shard_map"
        # Extend the caller's specs with entries for
        # (attention_mask, q_segment_ids, kv_segment_ids).
        if mask_info is None:
            in_specs = (*in_specs, None, None, None)
        else:
            shardings = mask_info.get_shardings(False, mesh=mesh)
            in_specs = (
                *in_specs,
                shardings.attention_mask if attention_mask is not None else None,
                shardings.q_segment_ids if q_segment_ids is not None else None,
                shardings.kv_segment_ids if kv_segment_ids is not None else None,
            )

    return _flash_executor(
        FlashAttention(),
        query=query,
        key=key,
        value=value,
        attention_mask=attention_mask,
        bias=bias,
        softmax_scale=softmax_scale,
        dropout_prob=dropout_prob,
        causal=causal,
        dropout_seed=dropout_seed,
        cum_seqlens_q=cum_seqlens_q,
        cum_seqlens_k=cum_seqlens_k,
        sliding_window=sliding_window,
        logits_soft_cap=logits_soft_cap,
        softmax_aux=softmax_aux,
        normalize_output=normalize_output,
        precision=precision,
        logits_dtype=logits_dtype,
        q_segment_ids=q_segment_ids,
        kv_segment_ids=kv_segment_ids,
        platform=platform,
        method=method,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=out_specs,
        _cfg=cfg,
    )