ejkernel.kernels._pallas.tpu.flash_attention._utils
Flash Attention TPU kernel.
- class ejkernel.kernels._pallas.tpu.flash_attention._utils.BlockSizes(block_q: int, block_k_major: int, block_k: int, block_b: int, block_q_major_dkv: int | None = None, block_k_major_dkv: int | None = None, block_k_dkv: int | None = None, block_q_dkv: int | None = None, block_k_major_dq: int | None = None, block_k_dq: int | None = None, block_q_dq: int | None = None)
Bases: object
Tile sizes parameterizing FlashAttention kernels.
These parameters have a negligible effect on numerics but a large effect on performance.
- block_b: int
- block_k: int
- block_k_dkv: int | None = None
- block_k_dq: int | None = None
- block_k_major: int
- block_k_major_dkv: int | None = None
- block_k_major_dq: int | None = None
- block_q: int
- block_q_dkv: int | None = None
- block_q_dq: int | None = None
- block_q_major_dkv: int | None = None
- property has_backward_blocks: bool
- class ejkernel.kernels._pallas.tpu.flash_attention._utils.SegmentIds(q: jax.Array, kv: jax.Array)
Bases: NamedTuple
SegmentIds for Q and KV sequences.
SegmentIds are used to generate a segment mask, which prevents attention between different segments in the input sequence. Each array holds integer ids, one per token. Only tokens with the same id can attend to each other.
- q
Segment ids along the Q sequence.
- Type
jax.jaxlib._jax.Array
- kv
Segment ids along the KV sequence.
- Type
jax.jaxlib._jax.Array
- kv: Array
Alias for field number 1
- q: Array
Alias for field number 0
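The segment-mask rule described for SegmentIds can be illustrated with a small JAX sketch. It is not taken from ejkernel: the `SegmentIds` class below is a local stand-in with the documented field names, `make_segment_mask` is a hypothetical helper, and the `[batch, seq_len]` layout of the id arrays is an assumption.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SegmentIds(NamedTuple):
    # Local stand-in mirroring the documented fields; not imported from ejkernel.
    q: jax.Array   # assumed shape: [batch, q_len]
    kv: jax.Array  # assumed shape: [batch, kv_len]


def make_segment_mask(segment_ids: SegmentIds) -> jax.Array:
    # Hypothetical helper: compare every Q id with every KV id.
    # [batch, q_len, 1] == [batch, 1, kv_len] -> [batch, q_len, kv_len]
    return segment_ids.q[:, :, None] == segment_ids.kv[:, None, :]


# Two packed segments of two tokens each.
ids = jnp.array([[0, 0, 1, 1]])
segment_mask = make_segment_mask(SegmentIds(q=ids, kv=ids))
print(segment_mask[0].astype(jnp.int32))
# [[1 1 0 0]
#  [1 1 0 0]
#  [0 0 1 1]
#  [0 0 1 1]]
```

A True entry means the query token may attend to that key token; the block-diagonal pattern shows the two segments staying isolated from each other.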
- ejkernel.kernels._pallas.tpu.flash_attention._utils.below_or_on_diag(r, r_blk_size, c, c_blk_size)
- ejkernel.kernels._pallas.tpu.flash_attention._utils.mha_reference(q, k, v, ab, segment_ids: ejkernel.kernels._pallas.tpu.flash_attention._utils.SegmentIds | None = None, causal: bool = False, mask_value: float = -2.381976426469702e+38, softmax_scale=1.0)
- ejkernel.kernels._pallas.tpu.flash_attention._utils.mha_reference_bwd(q, k, v, ab, segment_ids: ejkernel.kernels._pallas.tpu.flash_attention._utils.SegmentIds | None, o, l, m, do, causal: bool = False, mask_value: float = -2.381976426469702e+38, softmax_scale: float = 1.0)
- ejkernel.kernels._pallas.tpu.flash_attention._utils.mha_reference_no_custom_vjp(q, k, v, ab: jax.jaxlib._jax.Array | None = None, segment_ids: ejkernel.kernels._pallas.tpu.flash_attention._utils.SegmentIds | None = None, *, causal: bool = False, mask_value: float = -2.381976426469702e+38, softmax_scale: float = 1.0, save_residuals: bool = False)
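The mha_reference entries above are listed without descriptions. For orientation only, the following sketch shows what a plain (non-tiled) multi-head attention with these parameter names might compute. It is not the ejkernel implementation: `attention_sketch` is a hypothetical name, the `[batch, heads, seq_len, head_dim]` layout is an assumption, and so is the order in which the bias `ab` and `softmax_scale` are applied. The documented default `mask_value` is numerically equal to -0.7 times the largest finite float32.

```python
import jax
import jax.numpy as jnp

# Same value as the documented default mask_value (about -2.38e38).
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.float32).max)


def attention_sketch(q, k, v, ab=None, segment_ids=None, *, causal=False,
                     mask_value=DEFAULT_MASK_VALUE, softmax_scale=1.0):
    # Hypothetical reference; assumes q, k, v are [batch, heads, seq_len, head_dim].
    logits = jnp.einsum("bhqd,bhkd->bhqk",
                        q.astype(jnp.float32), k.astype(jnp.float32))
    # Assumption: scale first, then add the bias; the real order may differ.
    logits = logits * softmax_scale
    if ab is not None:
        logits = logits + ab

    mask = None
    if segment_ids is not None:
        q_ids, kv_ids = segment_ids  # each assumed to be [batch, seq_len]
        mask = (q_ids[:, :, None] == kv_ids[:, None, :])[:, None, :, :]
    if causal:
        q_len, kv_len = logits.shape[-2:]
        causal_mask = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))
        mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
    if mask is not None:
        # Masked positions get a very negative logit, so softmax gives them ~0 weight.
        logits = jnp.where(mask, logits, mask_value)

    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bhqk,bhkd->bhqd", weights, v.astype(jnp.float32))
    return out.astype(q.dtype)


key = jax.random.PRNGKey(0)
q, k, v = (jax.random.normal(sub, (1, 2, 4, 8)) for sub in jax.random.split(key, 3))
out = attention_sketch(q, k, v, causal=True)
```

With `causal=True` the first query position can only see the first key, so `out[:, :, 0]` equals `v[:, :, 0]`, which makes a convenient sanity check when comparing a tiled kernel against a reference of this kind.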