ejkernel.kernels._pallas.tpu.flash_attention._utils

Utilities for the Flash Attention TPU kernel.

class ejkernel.kernels._pallas.tpu.flash_attention._utils.BlockSizes(block_q: int, block_k_major: int, block_k: int, block_b: int, block_q_major_dkv: int | None = None, block_k_major_dkv: int | None = None, block_k_dkv: int | None = None, block_q_dkv: int | None = None, block_k_major_dq: int | None = None, block_k_dq: int | None = None, block_q_dq: int | None = None)[source]

Bases: object

Tile sizes parameterizing FlashAttention kernels.

These parameters have a negligible effect on numerics but greatly affect performance.

block_b: int
block_k: int
block_k_dkv: int | None = None
block_k_dq: int | None = None
block_k_major: int
block_k_major_dkv: int | None = None
block_k_major_dq: int | None = None
block_q: int
block_q_dkv: int | None = None
block_q_dq: int | None = None
block_q_major_dkv: int | None = None
classmethod get_default(batch_size, num_heads, q_seq_len, kv_len, d_model)[source]
property has_backward_blocks: bool
class ejkernel.kernels._pallas.tpu.flash_attention._utils.SegmentIds(q: jax.Array, kv: jax.Array)[source]

Bases: NamedTuple

SegmentIds for Q and KV sequences.

Segment IDs are used to generate a segment mask, which prevents attention between different segments of the input sequence. Each array holds a sequence of integer IDs; only tokens with the same ID can attend to each other.

q

Segment IDs along the Q sequence.

Type

jax.jaxlib._jax.Array

kv

Segment IDs along the KV sequence.

Type

jax.jaxlib._jax.Array

kv: Array

Alias for field number 1

q: Array

Alias for field number 0

ejkernel.kernels._pallas.tpu.flash_attention._utils.below_or_on_diag(r, r_blk_size, c, c_blk_size)[source]
ejkernel.kernels._pallas.tpu.flash_attention._utils.mha_reference(q, k, v, ab, segment_ids: ejkernel.kernels._pallas.tpu.flash_attention._utils.SegmentIds | None = None, causal: bool = False, mask_value: float = -2.381976426469702e+38, softmax_scale=1.0)[source]
ejkernel.kernels._pallas.tpu.flash_attention._utils.mha_reference_bwd(q, k, v, ab, segment_ids: ejkernel.kernels._pallas.tpu.flash_attention._utils.SegmentIds | None, o, l, m, do, causal: bool = False, mask_value: float = -2.381976426469702e+38, softmax_scale: float = 1.0)[source]
ejkernel.kernels._pallas.tpu.flash_attention._utils.mha_reference_no_custom_vjp(q, k, v, ab: jax.jaxlib._jax.Array | None = None, segment_ids: ejkernel.kernels._pallas.tpu.flash_attention._utils.SegmentIds | None = None, *, causal: bool = False, mask_value: float = -2.381976426469702e+38, softmax_scale: float = 1.0, save_residuals: bool = False)[source]