# Source code for ejkernel.kernels._pallas.tpu.flash_attention._utils

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


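"""Utilities for the Flash Attention TPU kernel.

Contains the block-size configuration, a cost-estimate helper and pure-JAX
reference implementations of multi-head attention (forward and backward).
"""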

from __future__ import annotations

import dataclasses
import functools
import math
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Large negative (but finite) value: -0.7 times the float32 maximum. The
# reference attention functions in this module add it to the logits at
# masked-out positions, which drives the corresponding exp() terms to zero.
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
# NOTE(review): NUM_LANES / NUM_SUBLANES / MIN_BLOCK_SIZE are not used in the
# code visible here; presumably they describe the TPU vector tile layout and
# the smallest supported block size used by the kernels -- confirm at the use
# sites.
NUM_LANES = 128
NUM_SUBLANES = 8
MIN_BLOCK_SIZE = 128
# Has the shape of `jax.lax.dot_general` dimension numbers: contract axis 1 of
# both operands with no batch axes (i.e. `A @ B.T` for 2-D operands).
# NOTE(review): the use sites are outside this view -- confirm.
TRANS_B_DIM_NUMBERS = (((1,), (1,)), ((), ()))


class SegmentIds(NamedTuple):
    """Per-token segment ids for the query and key/value sequences.

    The ids are used to build a segment mask that blocks attention between
    different segments of the input: a query position may attend to a
    key/value position only when both carry the same id. The reference
    implementations in this module index both arrays as
    ``(batch, seq_len)``.

    Attributes:
        q: integer segment ids along the Q sequence.
        kv: integer segment ids along the KV sequence.
    """

    q: jax.Array
    kv: jax.Array
@dataclasses.dataclass(frozen=True)
class BlockSizes:
    """Tile sizes parameterizing FlashAttention kernels.

    These parameters have negligible effect on numerics, but affect
    performance greatly.

    The first four fields are mandatory. The remaining seven are optional and
    default to ``None``; ``has_backward_blocks`` reports whether all of them
    are set.
    NOTE(review): judging by the names, the ``*_dkv`` / ``*_dq`` fields
    configure the dK/dV and dQ backward kernels -- confirm against the kernels.

    Raises:
        ValueError: if, for any (major, minor) pair that is fully specified,
            the minor block is not positive, is larger than the major block,
            or does not divide it.
    """

    block_q: int
    block_k_major: int
    block_k: int
    block_b: int

    block_q_major_dkv: int | None = None
    block_k_major_dkv: int | None = None
    block_k_dkv: int | None = None
    block_q_dkv: int | None = None

    block_k_major_dq: int | None = None
    block_k_dq: int | None = None
    block_q_dq: int | None = None

    def __post_init__(self):
        def verify_major_minor(prefix, suffix, major, minor):
            # Reject non-positive sizes up front: 0 would otherwise surface as
            # a bare ZeroDivisionError from the modulo below, and a negative
            # value would slip through both remaining checks.
            if minor <= 0:
                raise ValueError(f"{prefix}{suffix}={minor} should be a positive integer")
            if minor > major:
                raise ValueError(f"{prefix}{suffix}={minor} should be smaller than {prefix}_major{suffix}={major}")
            if major % minor != 0:
                raise ValueError(f"{prefix}{suffix}={minor} should divide {prefix}_major{suffix}={major}")

        verify_major_minor("block_k", "", self.block_k_major, self.block_k)
        # Optional backward pairs are only validated when both members are set.
        if self.block_q_major_dkv is not None and self.block_q_dkv is not None:
            verify_major_minor("block_q", "_dkv", self.block_q_major_dkv, self.block_q_dkv)
        if self.block_k_major_dkv is not None and self.block_k_dkv is not None:
            verify_major_minor("block_k", "_dkv", self.block_k_major_dkv, self.block_k_dkv)
        if self.block_k_major_dq is not None and self.block_k_dq is not None:
            verify_major_minor("block_k", "_dq", self.block_k_major_dq, self.block_k_dq)

    @property
    def has_backward_blocks(self) -> bool:
        """Whether every optional backward block size has been specified."""
        backward_blocks = (
            self.block_q_major_dkv,
            self.block_k_major_dkv,
            self.block_q_dkv,
            self.block_k_dkv,
            self.block_k_major_dq,
            self.block_k_dq,
            self.block_q_dq,
        )
        return all(b is not None for b in backward_blocks)

    @classmethod
    def get_default(cls, batch_size, num_heads, q_seq_len, kv_len, d_model):
        """Return the default configuration (every tile size 128, ``block_b=1``).

        All arguments are currently ignored; they are accepted so callers can
        pass the problem shape without caring whether it is used.

        Returns:
            An instance of ``cls`` (so subclasses get an instance of their own
            type) with all forward and backward block sizes populated.
        """
        del batch_size, num_heads, q_seq_len, kv_len, d_model
        return cls(
            block_q=128,
            block_k_major=128,
            block_k=128,
            block_b=1,
            block_q_major_dkv=128,
            block_k_major_dkv=128,
            block_k_dkv=128,
            block_q_dkv=128,
            block_k_major_dq=128,
            block_k_dq=128,
            block_q_dq=128,
        )
def _verify_block(block_name, dim_name, block, dim, should_divide=True):
    """Check that a block size fits the dimension it tiles.

    Args:
        block_name: name of the block-size parameter, used in error messages.
        dim_name: name of the dimension, used in error messages.
        block: block size.
        dim: dimension size.
        should_divide: when true, additionally require ``dim`` to be a
            multiple of ``block``.

    Raises:
        ValueError: if ``block`` exceeds ``dim``, or if divisibility is
            required and does not hold.
    """
    if block > dim:
        raise ValueError(f"{block_name}={block} should be smaller or equal to {dim_name}={dim}")
    if not should_divide:
        return
    if dim % block != 0:
        raise ValueError(f"{dim_name}={dim} should be divisible by {block_name}={block}")


def _bytes(x: jax.Array | jax.ShapeDtypeStruct) -> int:
    """Return the size of ``x`` in bytes: element count times item size."""
    num_elements = math.prod(x.shape)
    return num_elements * x.dtype.itemsize


def _fwd_cost_estimate(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    ab: jax.Array | None,
    segment_ids: SegmentIds | None,
    *,
    causal: bool,
    softmax_scale: jax.Array | None,
    kernel_inputs_specs,
    kernel_outputs_specs,
) -> pl.CostEstimate | None:
    """Build a cost estimate for the forward attention kernel.

    FLOPs and transcendentals are taken from ``pl.estimate_cost`` applied to
    ``mha_reference`` with the given operands. Bytes accessed is the summed
    size of every leaf in ``kernel_inputs_specs`` and ``kernel_outputs_specs``.

    Returns:
        A ``pl.CostEstimate`` combining the two.
    """

    def _tree_bytes(specs) -> int:
        # Total footprint of all array-like leaves in a pytree of specs.
        return sum(_bytes(leaf) for leaf in jax.tree.leaves(specs))

    reference_cost = pl.estimate_cost(
        mha_reference,
        q,
        k,
        v,
        ab,
        segment_ids,
        causal=causal,
        softmax_scale=softmax_scale,
    )
    total_bytes = _tree_bytes(kernel_inputs_specs) + _tree_bytes(kernel_outputs_specs)
    return pl.CostEstimate(
        flops=reference_cost.flops,
        transcendentals=reference_cost.transcendentals,
        bytes_accessed=total_bytes,
    )
def mha_reference_no_custom_vjp(
    q,
    k,
    v,
    ab: jax.Array | None = None,
    segment_ids: SegmentIds | None = None,
    *,
    causal: bool = False,
    mask_value: float = DEFAULT_MASK_VALUE,
    softmax_scale: float = 1.0,
    save_residuals: bool = False,
):
    """Plain-JAX multi-head attention without a custom VJP rule.

    Args:
        q: queries, indexed as ``(batch, heads, q_len, head_dim)``.
        k: keys, indexed as ``(batch, heads, kv_len, head_dim)``.
        v: values, indexed as ``(batch, heads, kv_len, head_dim)``.
        ab: optional bias added to the raw ``q . k`` logits.
        segment_ids: optional segment ids; positions with different ids are
            masked out.
        causal: if true, a query at row ``i`` only sees columns ``j <= i``.
        mask_value: value added to the logits at masked-out positions.
        softmax_scale: multiplier applied to the logits *after* the bias has
            been added; skipped when equal to 1.0.
        save_residuals: if true, also return the softmax normalizer and the
            row maximum.

    Returns:
        The attention output, or ``(out, l, m)`` when ``save_residuals`` is
        set, where ``l`` is the row sum of the exponentiated logits and ``m``
        the row maximum of the logits.
    """
    scores = jnp.einsum("bhqc,bhkc->bhqk", q, k)
    if ab is not None:
        scores = scores + ab
    if softmax_scale != 1.0:
        scores = scores * softmax_scale

    keep = None
    if segment_ids is not None:
        # (batch, q_len, kv_len) equality, then a broadcast axis for heads.
        same_segment = segment_ids.q[:, :, None] == segment_ids.kv[:, None, :]
        keep = same_segment[:, None, :, :]

    if causal:
        _, _, q_len, _ = q.shape
        _, _, kv_len, _ = k.shape
        grid = (q_len, kv_len)
        rows = jax.lax.broadcasted_iota(jnp.int32, grid, 0)
        cols = jax.lax.broadcasted_iota(jnp.int32, grid, 1)
        lower_triangle = (cols <= rows)[None, None, :, :]
        keep = lower_triangle if keep is None else jnp.logical_and(keep, lower_triangle)

    if keep is not None:
        # Additive masking: 0 where attention is allowed, mask_value elsewhere.
        scores = scores + jnp.where(keep, 0.0, mask_value)

    # Numerically stable softmax over the key axis.
    row_max = scores.max(axis=-1)
    exp_scores = jnp.exp(scores - row_max[..., None])
    row_sum = exp_scores.sum(axis=-1)
    probs = exp_scores / row_sum[..., None]
    out = jnp.einsum("bhqk,bhkc->bhqc", probs, v)

    return (out, row_sum, row_max) if save_residuals else out
@functools.partial(jax.jit, static_argnames=["causal", "mask_value", "softmax_scale"])
@jax.default_matmul_precision("bfloat16")
def mha_reference(
    q,
    k,
    v,
    ab,
    segment_ids: SegmentIds | None = None,
    causal: bool = False,
    mask_value: float = DEFAULT_MASK_VALUE,
    softmax_scale=1.0,
):
    """Jitted reference multi-head attention.

    Forwards to ``_mha_reference`` (the ``custom_vjp``-wrapped reference) with
    ``save_residuals=False`` and returns its result. ``causal``,
    ``mask_value`` and ``softmax_scale`` are static under ``jax.jit``, and the
    function runs with the default matmul precision set to ``"bfloat16"``.
    """
    # Positional order follows the signature of _mha_reference; the trailing
    # False is save_residuals.
    return _mha_reference(q, k, v, ab, segment_ids, causal, mask_value, softmax_scale, False)
@functools.partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8))
def _mha_reference(
    q,
    k,
    v,
    ab,
    segment_ids: SegmentIds | None,
    causal: bool,
    mask_value: float,
    softmax_scale: float,
    save_residuals: bool,
):
    """Reference attention wrapped in ``jax.custom_vjp``.

    Arguments 5-8 (``causal``, ``mask_value``, ``softmax_scale``,
    ``save_residuals``) are declared non-differentiable. The computation
    itself is delegated to ``mha_reference_no_custom_vjp``.
    """
    options = dict(
        causal=causal,
        mask_value=mask_value,
        softmax_scale=softmax_scale,
        save_residuals=save_residuals,
    )
    return mha_reference_no_custom_vjp(q, k, v, ab, segment_ids, **options)


def _mha_reference_fwd(
    q,
    k,
    v,
    ab,
    segment_ids: SegmentIds | None,
    causal: bool,
    mask_value: float,
    softmax_scale: float,
    save_residuals: bool,
):
    """Forward rule for ``_mha_reference``.

    Re-runs the reference with ``save_residuals=True`` and returns the output
    together with the residual tuple
    ``(q, k, v, ab, segment_ids, out, l, m)``.

    Raises:
        NotImplementedError: if the caller itself requested
            ``save_residuals``.
    """
    if save_residuals:
        raise NotImplementedError

    # Positional order follows the signature of _mha_reference; the trailing
    # True is save_residuals.
    result = _mha_reference(q, k, v, ab, segment_ids, causal, mask_value, softmax_scale, True)
    assert isinstance(result, tuple)
    out, l, m = result
    residuals = (q, k, v, ab, segment_ids, out, l, m)
    return out, residuals
@functools.partial(
    jax.jit,
    static_argnames=[
        "causal",
        "mask_value",
        "softmax_scale",
    ],
)
def mha_reference_bwd(
    q,
    k,
    v,
    ab,
    segment_ids: SegmentIds | None,
    o,
    l,
    m,
    do,
    causal: bool = False,
    mask_value: float = DEFAULT_MASK_VALUE,
    softmax_scale: float = 1.0,
):
    """Reference backward pass of multi-head attention.

    Recomputes the attention probabilities from ``q``, ``k``, the optional
    bias and the masks, using the forward residuals ``l`` (softmax normalizer)
    and ``m`` (row maximum), then derives the input gradients. Intermediate
    math is carried out in float32.

    Args:
        q: queries, indexed as ``(batch, heads, q_len, head_dim)``.
        k: keys, indexed as ``(batch, heads, kv_len, head_dim)``.
        v: values, indexed as ``(batch, heads, kv_len, head_dim)``.
        ab: optional bias that was added to the raw logits in the forward pass.
        segment_ids: optional segment ids used for the segment mask.
        o: forward output.
        l: forward softmax normalizer, one value per query row.
        m: forward row maximum of the logits, one value per query row.
        do: cotangent of the forward output.
        causal: whether the forward pass applied the causal mask.
        mask_value: value added to the logits at masked-out positions.
        softmax_scale: multiplier the forward pass applied to the logits after
            adding the bias.

    Returns:
        ``(dq, dk, dv, dab)``. ``dq`` / ``dk`` / ``dv`` are cast back to the
        dtypes of ``q`` / ``k`` / ``v``. ``dab`` is the float32 gradient with
        respect to the bias, or ``None`` when ``ab`` is ``None``.
    """
    q32 = q.astype(jnp.float32)
    k32 = k.astype(jnp.float32)
    v32 = v.astype(jnp.float32)
    o32 = o.astype(jnp.float32)
    do32 = do.astype(jnp.float32)

    # Rebuild the logits exactly as the forward reference does:
    # (q . k + bias) * scale, then the additive mask.
    logits = jnp.einsum("bhqc,bhkc->bhqk", q32, k32)
    if ab is not None:
        logits += ab
    if softmax_scale != 1.0:
        logits *= softmax_scale

    mask = None
    if segment_ids is not None:
        mask = segment_ids.q[:, :, None] == segment_ids.kv[:, None, :]
        mask = mask[:, None, :, :]

    if causal:
        _, _, q_seq_len, _ = q.shape
        _, _, kv_seq_len, _ = k.shape
        mask_shape = (q_seq_len, kv_seq_len)
        row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
        col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
        causal_mask = (col_ids <= row_ids)[None, None, :, :]
        mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask)

    logits = logits if mask is None else logits + jnp.where(mask, 0.0, mask_value)

    # Softmax probabilities reconstructed from the saved statistics.
    unnormalized = jnp.exp(logits - m[..., None])
    p = unnormalized / l[..., None]

    dv = jnp.einsum("bhpt,bhpd->bhtd", p, do32).astype(v.dtype)
    dp = jnp.einsum("bhpd,bhtd->bhpt", do32, v32)
    # Row-wise sum(o * do) equals sum_t(p_t * dp_t), the softmax-Jacobian term.
    di = jnp.sum(o32 * do32, axis=-1)[..., None]

    # Gradient with respect to the scaled (post-softmax_scale) logits.
    ds = (dp - di) * p
    if softmax_scale != 1.0:
        # Chain rule through `logits *= softmax_scale`: everything upstream of
        # the scaling (q . k and the bias) sees the gradient multiplied by it.
        ds = ds * softmax_scale

    dk = jnp.einsum("bhsd,bhst->bhtd", q32, ds).astype(k.dtype)
    dq = jnp.einsum("bhst,bhtd->bhsd", ds, k32).astype(q.dtype)

    # The bias enters the pre-scale logits additively, so its gradient is ds.
    dab = ds if ab is not None else None
    return dq, dk, dv, dab
def _mha_reference_bwd(
    causal: bool,
    mask_value: float,
    softmax_scale: float,
    save_residuals: bool,
    residuals,
    do,
):
    """Backward rule for ``_mha_reference``.

    Receives the non-differentiable arguments first, then the residuals saved
    by the forward rule and the output cotangent ``do``. Delegates to
    ``mha_reference_bwd`` and returns one cotangent per differentiable input:
    ``(dq, dk, dv, dab, None)``, where the trailing ``None`` corresponds to
    ``segment_ids``.
    """
    del save_residuals  # Unused by the backward computation.
    # Residual layout: (q, k, v, ab, segment_ids, o, l, m), the same order as
    # the leading positional parameters of mha_reference_bwd.
    grads = mha_reference_bwd(
        *residuals,
        do,
        causal=causal,
        mask_value=mask_value,
        softmax_scale=softmax_scale,
    )
    dq, dk, dv, dab = grads
    return (dq, dk, dv, dab, None)


_mha_reference.defvjp(fwd=_mha_reference_fwd, bwd=_mha_reference_bwd)
def below_or_on_diag(r, r_blk_size, c, c_blk_size):
    """Return whether block ``(r, c)`` touches the region on or below the diagonal.

    A block counts as below or on the diagonal as soon as its bottom-left
    element does: that element sits at row ``(r + 1) * r_blk_size - 1`` and
    column ``c * c_blk_size``, and is on or below the diagonal when its row
    index is greater than or equal to its column index.

    Args:
        r: block index along the rows.
        r_blk_size: number of rows per block.
        c: block index along the columns.
        c_blk_size: number of columns per block.

    Returns:
        The result of the comparison (a ``bool`` for Python ints).
    """
    bottom_row = (r + 1) * r_blk_size - 1
    left_col = c * c_blk_size
    # `>=` rather than `>`: a corner element lying exactly on the diagonal
    # still belongs to the unmasked region.
    return bottom_row >= left_col