Source code for ejkernel.kernels._pallas.tpu.blocksparse_attention._info

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Mini-mask creation library."""

from __future__ import annotations

import collections
import functools
import math
from collections.abc import Callable
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
from jax._src import util as jax_util

from . import _masks as mask_lib

# mypy: ignore-errors


class MaskInfo(NamedTuple):
    """Contains runtime masking information for the Splash attention kernel.

    The arrays data_next, mask_next and block_mask are placed in TPU
    scalar-memory. This is a scarce resource so the mask creation logic attempts
    to shrink the data-type of these arrays to the smallest possible one.
    This can be: np.int32, np.int16 or np.int8.

    For the arrays data_next, mask_next and block_mask the size of the first
    dimension can be one of the two following values: num_head or
    num_head_shards.
    The first dimension has size:

    * num_head_shards when there is only one unique mask for each head in a
      shard. In this case the three arrays are broadcasted to all the heads in
      the shard.
    * num_heads when there is more than one unique mask for each head in the
      shard.

    Attributes:
        data_next: An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks]
            NumPy array where each entry contains the next `kv` block index to
            prefetch.
        mask_next: An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks]
            NumPy array where each entry contains the next mask block index in
            `partial_mask_blocks` to prefetch.
        block_mask: An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks]
            NumPy array whose entries can be 0, 1 or 2. An entry of 0 indicates
            that the corresponding block in the full mask was all zeros. An
            entry of 1 indicates that the corresponding block in the full mask
            contained both zeros and ones. An entry of 2 indicates the
            corresponding block was entirely ones.
        partial_mask_blocks: A bool[num_partial_blocks, block_q, block_kv] NumPy
            array that contains the blocks of the original mask that contained
            both zeros and ones. The entries in `mask_next` point to indices in
            the first axis of this array.
        q_sequence: A i32[q_sequence_length] NumPy array. When using causal
            masking, this contains the list of indices that correspond to q
            tokens. For plain causal this is just np.arange(q_sequence_length).
        is_dynamic_mask: A bool indicating whether the mask is dynamic or
            static. When True, the leading dimensions of `partial_mask_blocks`
            (num_heads, q_blocks, kv_blocks) are not collapsed, allowing us to
            shard it along those dimensions. Defaults to None when the producer
            does not set it.
    """

    data_next: np.ndarray | jax.Array | None
    mask_next: np.ndarray | jax.Array | None
    block_mask: np.ndarray | jax.Array | None
    partial_mask_blocks: np.ndarray | jax.Array | None
    q_sequence: np.ndarray | None
    # The default stays None (not False) so existing callers that distinguish
    # "unset" from "static" keep working; only the annotation is corrected.
    is_dynamic_mask: bool | None = None
def _downcast_to_small_type(array: np.ndarray) -> np.ndarray: """Downcast numpy array. If possible, downcast the data-type of the input array to the smallest numpy type (among np.int16 and np.int8) that fits the content of the array. Args: array: the array to downcast Returns: The downcasted array. Raises: ValueError: if the input array is not np.int32 or if its elements are not all positive. """ if array.dtype != np.int32: raise ValueError("Expected int32 input.") if not np.all(array >= 0): raise ValueError("Expected non-negative array.") if array.size == 0: return array max_value = np.max(array) if max_value <= np.iinfo(np.int8).max: return array.astype(np.int8) elif max_value <= np.iinfo(np.int16).max: return array.astype(np.int16) else: return array.astype(np.int32) def _check_mask(mask: mask_lib.Mask) -> None: """Check that the given mask is valid. A row of all zeros along the kv dimension would result in a division by zero when computing the softmax. This function is meant to protect against that case. Args: mask: the mask to check. Raises: ValueError: the mask is invalid. """ assert len(mask.shape) == 2 exception_message = ( "Some rows of the mask (along the kv dimension) are all zeros.\nThis is" " would result in a division by zero when computing the attention" " softmax." ) is_row_non_zero = np.zeros(mask.shape[0], dtype=np.bool_) for col in range(mask.shape[1]): is_row_non_zero = np.logical_or( is_row_non_zero, mask[(slice(0, mask.shape[0]), slice(col, col + 1))][:, 0], ) if not is_row_non_zero.all(): raise ValueError(exception_message) class _HashableNDArray: """Helper to make a numpy array hashable: can be added associative containers. Attributes: array: The underlying numpy array. 
""" array: np.ndarray def __init__(self, array: np.ndarray): self.array = array def __hash__(self): return hash(self.array.tobytes()) def __eq__(self, other: object) -> bool: if not isinstance(other, _HashableNDArray): return NotImplemented return np.array_equal(self.array, other.array, equal_nan=True) def _get_mask_info_for_shard( output_shape: tuple[int, int, int], has_mask_next: bool, mask: mask_lib.MultiHeadMask | jax.Array, block_shape: tuple[int, int], coords_to_partial_mask_block_index: dict[tuple[int, int, int], int], masks_per_head_shard: int, head_start: int, num_heads: int, q_seq_start: int, q_seq_shard_size: int, blocked_q_seq_start: int, is_dkv: bool, ) -> tuple[np.ndarray, np.ndarray | None]: """Process a slice of the mask to compute data_next and mask_next. Args: output_shape: The shape of the data_next and mask_next to return has_mask_next: Whether mask_next should be constructed. If False None is returned for mask_next. mask: The full mask to be sliced according to the head and sequence ranges block_shape: Shape of the Pallas grid block. coords_to_partial_mask_block_index: Mapping between the pallas launch grid coordinates and the index of the corresponding block in partial mask block list. masks_per_head_shard: Number of masks per head shards head_start: First head of the current shard. num_heads: Number of heads in the shard. q_seq_start: Start index along the Q sequence for the current shard (in number of tokens). q_seq_shard_size: Number of tokens along the Q sequence for the current shard. blocked_q_seq_start: Start index along the Q sequence for the current shard (in number of grid blocks) is_dkv: True if we are processing the dKV mask Returns: Slice of data_next and mask_next (if required) that correspond to the current mask slice. 
""" _, _, kv_seq_len = mask.shape q_block_size, kv_blocksize = block_shape q_block_count, q_mod = divmod(q_seq_shard_size, q_block_size) kv_block_count, kv_mod = divmod(kv_seq_len, kv_blocksize) assert q_mod == 0 assert kv_mod == 0 blocked_shape = (kv_block_count, num_heads, q_block_count) if is_dkv else (num_heads, q_block_count, kv_block_count) data_coords = [] mask_coords = [] for idx in np.ndindex(blocked_shape): if is_dkv: kv_index, h_index, q_index = idx else: h_index, q_index, kv_index = idx h_index = h_index if masks_per_head_shard == 1 else head_start + h_index chunk = mask[ ( h_index, slice( q_seq_start + q_index * q_block_size, q_seq_start + (q_index + 1) * q_block_size, ), slice(kv_index * kv_blocksize, (kv_index + 1) * kv_blocksize), ) ] if chunk.any(): data_coords.append(idx) if not chunk.all(): mask_coords.append(idx) mask_next = None if has_mask_next: mask_next = np.zeros(output_shape, dtype=np.int32) data_next = np.zeros(output_shape, dtype=np.int32) if not data_coords: return data_next, mask_next data_coords_iter = iter(data_coords) first_j = coord_j = next(data_coords_iter) if mask_next is not None and mask_coords: mask_coords_iter = iter(mask_coords) first_m = coord_m = next(mask_coords_iter) else: first_m, coord_m, mask_coords_iter = None, None, None for idx in np.ndindex(blocked_shape): if is_dkv: kv_index, h_index, q_index = idx chunk_idx: tuple[int, ...] 
= (h_index, q_index, kv_index) data_dim = 2 else: chunk_idx = idx data_dim = 2 is_next = idx > coord_j if is_next: try: coord_j = next(data_coords_iter) except StopIteration: coord_j = first_j data_next[chunk_idx] = coord_j[data_dim] if mask_next is not None and mask_coords: assert coord_m is not None is_next_mask = idx > coord_m if is_next_mask: try: coord_m = next(mask_coords_iter) # type: ignore except StopIteration: coord_m = first_m if is_dkv: assert coord_m is not None coord_m_global = ( coord_m[1] + head_start, coord_m[2] + blocked_q_seq_start, coord_m[0], ) else: assert coord_m is not None coord_m_global = ( coord_m[0] + head_start, coord_m[1] + blocked_q_seq_start, coord_m[2], ) mask_next[chunk_idx] = coords_to_partial_mask_block_index[coord_m_global] return data_next, mask_next def _process_dynamic_mask( mask: jax.Array, block_shape: tuple[int, int], is_dkv: bool, *, downcast_smem_data: bool = True, head_shards: int = 1, q_seq_shards: int = 1, shrink_grid: bool = True, ) -> tuple[MaskInfo, None]: """Similar to `_process_mask` but the mask must be a dynamic array. Since the mask is dynamic, we can't know the exact number of partial mask blocks at trace time. Therefore, the entire mask is materialized in `partial_mask_blocks`. Note that we can still populate MaskInfo to skip fully-masked blocks. Args: mask: A [head_count, q_seq_len, kv_seq_len] jax.Array representing the dense mask to process. block_shape: A Tuple[int, int] representing the shape of the Pallas grid block. is_dkv: True if we are processing the dKV mask downcast_smem_data: If True, downcast the scalar-memory data of MaskInfo to a data type smaller than np.int32 (if possible). head_shards: Number of head shards of the mesh in which the kernel is launched. q_seq_shards: Number of Q sequence shards of the mesh in which the kernel is launched. shrink_grid: Whether or not we should apply the grid shrinking optimization. This is currently ignored. 
Returns: `MaskInfo`, a sparse representation of the dense mask. Raises: ValueError: if the input mask is invalid or the block sizes are not compatible with the mask sizes. """ del shrink_grid if len(mask.shape) != 3: raise ValueError(f"Expected a 3-dim mask, instead got: {mask.shape}.") if mask.dtype != jnp.bool: raise ValueError(f"Expected a bool mask, instead got: {mask.dtype}.") head_count, q_seq_len, kv_seq_len = mask.shape q_block_size, kv_blocksize = block_shape q_blocks_count, q_mod = divmod(q_seq_len, q_block_size) kv_blocks_count, kv_mod = divmod(kv_seq_len, kv_blocksize) if q_mod != 0: raise ValueError(f"{q_block_size=} should divide {q_seq_len=}.") if kv_mod != 0: raise ValueError(f"{kv_blocksize=} should divide {kv_seq_len=}.") q_seq_len_per_shard, mod = divmod(q_seq_len, q_seq_shards) if mod != 0: raise ValueError(f"{q_seq_shards=} should divide {q_seq_len=}.") q_blocks_per_shard, mod = divmod(q_seq_len_per_shard, q_block_size) if mod != 0: raise ValueError(f"{q_block_size=} should divide {q_seq_len_per_shard=}.") heads_per_shard, mod = divmod(head_count, head_shards) if mod != 0: raise ValueError(f"{head_shards=} should divide {head_count=}.") block_mask_shape = ( head_count, q_blocks_count, kv_blocks_count, ) partial_mask_blocks = ( mask.reshape( head_count, q_blocks_count, q_block_size, kv_blocks_count, kv_blocksize, ) .swapaxes(-2, -3) .astype(np.bool_) ) is_full_mask = jnp.all(partial_mask_blocks, axis=(-1, -2)) is_empty_mask = jnp.logical_not(jnp.any(partial_mask_blocks, axis=(-1, -2))) block_mask = jnp.ones(block_mask_shape, dtype=np.int32) block_mask = jnp.where(is_full_mask, 2, block_mask) block_mask = jnp.where(is_empty_mask, 0, block_mask) q_sequence_axis = 1 head_axis = 0 mask_info_slice_shape = (heads_per_shard, q_blocks_per_shard, kv_blocks_count) data_next_per_head_list, mask_next_per_head_list = [], [] for head_shard in range(head_shards): head_start = head_shard * heads_per_shard mask_head_slice = slice(head_start, head_start + 
heads_per_shard) data_next_sequence_slices, mask_next_sequence_slices = [], [] for q_seq_len_shard in range(q_seq_shards): q_seq_len_start = q_seq_len_shard * q_blocks_per_shard blocked_q_seq_len_slice = slice(q_seq_len_start, q_seq_len_start + q_blocks_per_shard) local_block_mask = block_mask[mask_head_slice, blocked_q_seq_len_slice] mask_next_slice = jnp.arange(math.prod(mask_info_slice_shape), dtype=np.int32).reshape(mask_info_slice_shape) mask_next_slice = jnp.where(local_block_mask == 1, mask_next_slice, 0) if is_dkv: data_next_slice = jnp.arange(q_blocks_per_shard, dtype=np.int32)[None, :, None] else: data_next_slice = jnp.arange(kv_blocks_count, dtype=np.int32)[None, None, :] data_next_slice = jnp.broadcast_to(data_next_slice, mask_info_slice_shape) data_next_slice = jnp.where(local_block_mask == 0, 0, data_next_slice) data_next_sequence_slices.append(data_next_slice) mask_next_sequence_slices.append(mask_next_slice) data_next_per_head = jnp.concatenate(data_next_sequence_slices, axis=q_sequence_axis) data_next_per_head_list.append(data_next_per_head) mask_next_per_head = jnp.concatenate(mask_next_sequence_slices, axis=q_sequence_axis) mask_next_per_head_list.append(mask_next_per_head) data_next = jnp.concatenate(data_next_per_head_list, axis=head_axis) mask_next = jnp.concatenate(mask_next_per_head_list, axis=head_axis) if is_dkv: partial_mask_blocks = partial_mask_blocks.swapaxes(-1, -2) def _downcast(array: jax.Array, max_value: int) -> jax.Array: if array.size == 0: return array if array.dtype != np.int32: raise ValueError(f"Expected int32 input, but got {array.dtype}.") if max_value <= np.iinfo(np.int8).max: return array.astype(np.int8) elif max_value <= np.iinfo(np.int16).max: return array.astype(np.int16) else: return array.astype(np.int32) if downcast_smem_data: block_mask = block_mask.astype(np.int8) data_next = _downcast(data_next, q_blocks_per_shard if is_dkv else kv_blocks_count) mask_next = _downcast(mask_next, heads_per_shard * 
q_blocks_per_shard * kv_blocks_count) return ( MaskInfo( data_next=data_next, mask_next=mask_next, block_mask=block_mask, partial_mask_blocks=partial_mask_blocks, q_sequence=None, is_dynamic_mask=True, ), None, ) @functools.lru_cache(maxsize=12) def _process_mask( mask: mask_lib.MultiHeadMask, block_shape: tuple[int, int], is_dkv: bool, *, downcast_smem_data: bool = True, head_shards: int = 1, q_seq_shards: int = 1, shrink_grid: bool = True, ) -> tuple[MaskInfo, jax_util.HashableFunction | None]: """Transform a dense mask into a sparse representation. The number of head and Q sequence shards are needed to create a MaskInfo object that is partitionable (with shmap or PartIR) along these two dimension. In particular for dKV MaskInfo, for each shard the indices of in the data_next array are relative to the current shard. The fwd and dQ MaskInfo objects do not change when sharding along the head or Q dimensions, they would be different if we were to shard along the KV dimension, but the kernel does not support that. Args: mask: Dense mask to process. block_shape: Shape of the Pallas grid block. is_dkv: True if we are processing the dKV mask downcast_smem_data: If True, downcast the scalar-memory data of MaskInfo to a data type smaller than np.int32 (if possible). head_shards: Number of head shards of the mesh in which the kernel is launched. q_seq_shards: Number of Q sequence shards of the mesh in which the kernel is launched. shrink_grid: Whether or not we should apply the grid shrinking optimization. Returns: `MaskInfo`, a sparse representation of the dense mask. `MaskCallable`: a callable that, given in input Q and KV indices, returns the value of the mask at those coordinates. Raises: ValueError: if the input mask is invalid or the block sizes are not compatible with the mask sizes. 
""" if len(mask.shape) != 3: raise ValueError(f"Expected a 3-dim mask, instead got: {mask.shape=}") head_count, q_seq_len, kv_seq_len = mask.shape q_block_size, kv_blocksize = block_shape q_blocks_count, q_mod = divmod(q_seq_len, q_block_size) kv_blocks_count, kv_mod = divmod(kv_seq_len, kv_blocksize) if q_mod != 0: raise ValueError(f"{q_block_size=} should divide {q_seq_len=}.") if kv_mod != 0: raise ValueError(f"{kv_blocksize=} should divide {kv_seq_len=}.") q_seq_len_per_shard, mod = divmod(q_seq_len, q_seq_shards) if mod != 0: raise ValueError(f"{q_seq_shards=} should divide {q_seq_len=}.") q_blocks_per_shard, mod = divmod(q_seq_len_per_shard, q_block_size) if mod != 0: raise ValueError(f"{q_block_size=} should divide {q_seq_len_per_shard=}.") heads_per_shard, mod = divmod(head_count, head_shards) if mod != 0: raise ValueError(f"{head_shards=} should divide {head_count=}.") def assign_unique_ids(objects): id_map = collections.defaultdict(lambda: len(id_map)) return {obj: id_map[obj] for obj in objects} unique_masks_dict: dict[mask_lib.Mask, int] = assign_unique_ids(head_mask for head_mask in mask.masks) head_to_mask_id: list[int] = [0] * head_count head_shard_to_mask_ids: list[set[int]] = [set() for _ in range(head_shards)] mask_id_to_heads: list[list[int]] = [[] for _ in range(len(unique_masks_dict))] mask_id_to_head_shards: list[set[int]] = [set() for _ in range(len(unique_masks_dict))] for head in range(head_count): mask_id = unique_masks_dict[mask.masks[head]] head_to_mask_id[head] = mask_id head_shard = head // heads_per_shard head_shard_to_mask_ids[head_shard].add(mask_id) mask_id_to_heads[mask_id].append(head) mask_id_to_head_shards[mask_id].add(head_shard) max_masks_per_head_shard = max(len(x) for x in head_shard_to_mask_ids) masks_per_head_shard = 1 if max_masks_per_head_shard == 1 else heads_per_shard unique_masks = [pair[0] for pair in sorted(unique_masks_dict.items(), key=lambda x: x[1])] partial_mask_block_ids: dict[_HashableNDArray, int] = 
collections.defaultdict(lambda: len(partial_mask_block_ids)) block_id_to_block_coords: dict[int, list[tuple[int, ...]]] = collections.defaultdict(list) block_mask_shape = ( head_shards if masks_per_head_shard == 1 else head_count, q_blocks_count, kv_blocks_count, ) block_mask = np.zeros(block_mask_shape, dtype=np.int32) def set_block_mask(mask_id: int, q_index: int, kv_index: int, value: int): if masks_per_head_shard == 1: for shard_index in mask_id_to_head_shards[mask_id]: block_mask[shard_index, q_index, kv_index] = value else: for head_index in mask_id_to_heads[mask_id]: block_mask[head_index, q_index, kv_index] = value q_sequence = None mask_function = None if len(unique_masks) == 1: unique_mask = unique_masks[0] assert hasattr(unique_mask, "q_sequence") == hasattr(unique_mask, "mask_function") if hasattr(unique_mask, "q_sequence") and hasattr(unique_mask, "mask_function"): q_sequence = unique_mask.q_sequence mask_function = unique_mask.mask_function for mask_id, unique_mask in enumerate(unique_masks): for coords in np.ndindex((q_blocks_count, kv_blocks_count)): (q_index, kv_index) = coords chunk = unique_mask[ ( slice(q_index * q_block_size, (q_index + 1) * q_block_size), slice(kv_index * kv_blocksize, (kv_index + 1) * kv_blocksize), ) ] has_nonzero = chunk.any() if has_nonzero: all_nonzero = chunk.all() if not all_nonzero: set_block_mask(mask_id, q_index, kv_index, 1) partial_mask_block_id = partial_mask_block_ids[_HashableNDArray(chunk)] for head_index in mask_id_to_heads[mask_id]: block_id_to_block_coords[partial_mask_block_id].append((head_index, *coords)) else: set_block_mask(mask_id, q_index, kv_index, 2) unique_partial_mask_blocks = [pair[0] for pair in sorted(partial_mask_block_ids.items(), key=lambda x: x[1])] coords_to_partial_mask_block_index = {} for partial_mask_block_id, coords in block_id_to_block_coords.items(): for coo in coords: coords_to_partial_mask_block_index[coo] = partial_mask_block_id partial_mask_blocks = None has_mask_next = False if 
len(unique_partial_mask_blocks) >= 1: partial_mask_blocks = [x.array for x in unique_partial_mask_blocks] partial_mask_blocks = np.stack(partial_mask_blocks, axis=0).astype(np.bool_) has_mask_next = True if is_dkv and partial_mask_blocks is not None: partial_mask_blocks = np.swapaxes(partial_mask_blocks, -1, -2) all_head_shards_identical = all(head_shard_to_mask_ids[0] == x and len(x) == 1 for x in head_shard_to_mask_ids) shards_to_process = 1 if all_head_shards_identical else head_shards q_sequence_axis = 1 head_axis = 0 data_next_per_head_list, mask_next_per_head_list = [], [] for head_shard in range(shards_to_process): data_next_sequence_slices, mask_next_sequence_slices = [], [] for q_seq_len_shard in range(q_seq_shards): head_start = head_shard * heads_per_shard q_seq_len_shard_size = q_blocks_per_shard * q_block_size q_seq_len_start = q_seq_len_shard * q_seq_len_shard_size blocked_q_seq_len_start = q_seq_len_shard * q_blocks_per_shard blocked_q_seq_len_slice = slice( blocked_q_seq_len_start, (q_seq_len_shard + 1) * q_blocks_per_shard, ) if masks_per_head_shard == 1: unique_mask = unique_masks[head_to_mask_id[head_start]] unique_mask = mask_lib.MultiHeadMask((unique_mask,)) current_mask = unique_mask mask_head_slice = slice(head_shard, head_shard + 1) else: current_mask = mask mask_head_slice = slice(head_start, (head_shard + 1) * heads_per_shard) mask_info_slice_shape = ( mask_head_slice.stop - mask_head_slice.start, blocked_q_seq_len_slice.stop - blocked_q_seq_len_slice.start, kv_blocks_count, ) data_next_slice, mask_next_slice = _get_mask_info_for_shard( output_shape=mask_info_slice_shape, has_mask_next=has_mask_next, mask=current_mask, block_shape=block_shape, coords_to_partial_mask_block_index=coords_to_partial_mask_block_index, head_start=head_start, masks_per_head_shard=masks_per_head_shard, num_heads=1 if masks_per_head_shard == 1 else heads_per_shard, q_seq_start=q_seq_len_start, q_seq_shard_size=q_seq_len_shard_size, 
blocked_q_seq_start=blocked_q_seq_len_start, is_dkv=is_dkv, ) data_next_sequence_slices.append(data_next_slice) mask_next_sequence_slices.append(mask_next_slice) data_next_per_head = np.concatenate(data_next_sequence_slices, axis=q_sequence_axis) data_next_per_head_list.append(data_next_per_head) if has_mask_next: mask_next_per_head = np.concatenate(mask_next_sequence_slices, axis=q_sequence_axis) mask_next_per_head_list.append(mask_next_per_head) mask_next = None if all_head_shards_identical: assert len(data_next_per_head_list) == 1 data_next_shard = data_next_per_head_list[0] assert data_next_shard.shape == (1, q_blocks_count, kv_blocks_count) data_next = np.broadcast_to( data_next_shard, (head_shards, q_blocks_count, kv_blocks_count), ) if has_mask_next: assert len(mask_next_per_head_list) == 1 mask_next_shard = mask_next_per_head_list[0] assert mask_next_shard.shape == (1, q_blocks_count, kv_blocks_count) mask_next = np.broadcast_to( mask_next_shard, (head_shards, q_blocks_count, kv_blocks_count), ) else: data_next = np.concatenate(data_next_per_head_list, axis=head_axis) if has_mask_next: mask_next = np.concatenate(mask_next_per_head_list, axis=head_axis) if shrink_grid and block_mask.shape[0] == head_shards and len(unique_masks) == 1: rows_per_q_shard = block_mask.shape[1] // q_seq_shards block_mask_shards = [] data_next_shards = [] mask_next_shards = [] for q_seq_len_shard in range(q_seq_shards): rows = slice( q_seq_len_shard * rows_per_q_shard, (q_seq_len_shard + 1) * rows_per_q_shard, ) current_block_mask = block_mask[:, rows, :] current_data_next = data_next[:, rows, :] current_mask_next = mask_next[:, rows, :] if mask_next is not None else None shrink_function = _shrink_mask_info_dkv if is_dkv else _shrink_mask_info current_block_mask, current_data_next, current_mask_next = shrink_function( block_mask=current_block_mask, data_next=current_data_next, mask_next=current_mask_next, head_shards=head_shards, ) assert current_block_mask.size > 0 assert 
current_data_next.size > 0 assert current_mask_next is None or current_mask_next.size > 0 assert current_block_mask.shape == current_data_next.shape assert current_mask_next is None or current_block_mask.shape == current_mask_next.shape block_mask_shards.append(current_block_mask) data_next_shards.append(current_data_next) mask_next_shards.append(current_mask_next) if q_seq_shards == 1: block_mask = block_mask_shards[0] data_next = data_next_shards[0] mask_next = mask_next_shards[0] else: padding_axis = 1 if is_dkv else 2 max_size = max(x.shape[padding_axis] for x in block_mask_shards) padded_block_mask_shards = [] padded_data_next_shards = [] padded_mask_next_shards = [] assert len(block_mask_shards) == len(data_next_shards) == len(mask_next_shards) for ( current_block_mask, current_data_next, current_mask_next, ) in zip(block_mask_shards, data_next_shards, mask_next_shards, strict=False): if is_dkv: pad_width = ( (0, 0), (0, max_size - current_block_mask.shape[padding_axis]), (0, 0), ) else: pad_width = ( (0, 0), (0, 0), (0, max_size - current_block_mask.shape[padding_axis]), ) padded_block_mask_shards.append(np.pad(current_block_mask, pad_width=pad_width, constant_values=0)) padded_data_next_shards.append(np.pad(current_data_next, pad_width=pad_width, mode="edge")) if current_mask_next is not None: padded_mask_next_shards.append(np.pad(current_mask_next, pad_width=pad_width, mode="edge")) block_mask = np.concatenate(padded_block_mask_shards, axis=1) data_next = np.concatenate(padded_data_next_shards, axis=1) mask_next = np.concatenate(padded_mask_next_shards, axis=1) if downcast_smem_data: data_next = _downcast_to_small_type(data_next) block_mask = _downcast_to_small_type(block_mask) if mask_next is not None: mask_next = _downcast_to_small_type(mask_next) assert (mask_function is not None) == (q_sequence is not None) return ( MaskInfo( data_next=data_next, mask_next=mask_next if mask_function is None else None, block_mask=block_mask, 
partial_mask_blocks=partial_mask_blocks if mask_function is None else None, q_sequence=q_sequence, ), mask_function, ) def _shrink_mask_info( *, block_mask: np.ndarray, data_next: np.ndarray, mask_next: np.ndarray, head_shards: int, ): assert block_mask.ndim == 3 assert data_next.ndim == 3 assert mask_next is None or mask_next.ndim == 3 head_block_mask = block_mask[0] grouped_non_zero_cols = [] for row_index in range(head_block_mask.shape[0]): head_block_mask_row = head_block_mask[row_index, :] non_zero_cols = np.nonzero(head_block_mask_row)[0] grouped_non_zero_cols.append(non_zero_cols) max_non_zero_cols = max(len(x) for x in grouped_non_zero_cols) padded_non_zero_cols = [] padding = -1 for row in grouped_non_zero_cols: padded_non_zero_cols.append( np.pad( row, pad_width=(0, max_non_zero_cols - row.shape[0]), constant_values=padding, ) ) padded_non_zero_cols = np.stack(padded_non_zero_cols, axis=0) assert padded_non_zero_cols.shape[0] == block_mask.shape[1], ( padded_non_zero_cols.shape, block_mask.shape, ) def select_cols(array): assert array.ndim == 2 assert padded_non_zero_cols.ndim == 2 assert array.shape[0] == padded_non_zero_cols.shape[0] assert array.shape[1] >= padded_non_zero_cols.shape[1] selected_rows = [] for row in range(array.shape[0]): col = padded_non_zero_cols[row] selected = array[row][col] selected = np.where(col != padding, selected, 0) selected_rows.append(selected) return np.stack(selected_rows, axis=0) return _slice_mask_info( block_mask=block_mask, data_next=data_next, mask_next=mask_next, head_shards=head_shards, slice_function=select_cols, ) def _shrink_mask_info_dkv( *, block_mask: np.ndarray, data_next: np.ndarray, mask_next: np.ndarray, head_shards: int, ): assert block_mask.ndim == 3 assert data_next.ndim == 3 assert mask_next is None or mask_next.ndim == 3 head_block_mask = block_mask[0] grouped_non_zero_rows = [] for col_index in range(head_block_mask.shape[1]): col = head_block_mask[:, col_index] non_zero_rows = np.nonzero(col)[0] 
grouped_non_zero_rows.append(non_zero_rows) max_non_zero_rows = max(len(x) for x in grouped_non_zero_rows) padded_non_zero_rows = [] padding = -1 for col in grouped_non_zero_rows: padded_non_zero_rows.append( np.pad( col, pad_width=(max_non_zero_rows - col.shape[0], 0), constant_values=padding, ) ) padded_non_zero_rows = np.stack(padded_non_zero_rows, axis=1) assert padded_non_zero_rows.shape[1] == block_mask.shape[2], ( padded_non_zero_rows.shape, block_mask.shape, ) def select_rows(array): assert array.ndim == 2 assert padded_non_zero_rows.ndim == 2 assert array.shape[1] == padded_non_zero_rows.shape[1] assert array.shape[0] >= padded_non_zero_rows.shape[0] selected_cols = [] for col in range(array.shape[1]): row = padded_non_zero_rows[:, col] selected = array[:, col][row] selected = np.where(row != padding, selected, 0) selected_cols.append(selected) return np.stack(selected_cols, axis=1) return _slice_mask_info( block_mask=block_mask, data_next=data_next, mask_next=mask_next, head_shards=head_shards, slice_function=select_rows, ) def _slice_mask_info( *, block_mask: np.ndarray, data_next: np.ndarray, mask_next: np.ndarray, head_shards: int, slice_function: Callable[[np.ndarray], np.ndarray], ): new_block_mask = [] new_data_next = [] new_mask_next = [] for head_shard in range(head_shards): head_block_mask = block_mask[head_shard] head_block_mask = slice_function(head_block_mask) new_block_mask.append(head_block_mask) head_data_next = data_next[head_shard] head_data_next = slice_function(head_data_next) new_data_next.append(head_data_next) if mask_next is not None: head_mask_next = mask_next[head_shard] head_mask_next = slice_function(head_mask_next) new_mask_next.append(head_mask_next) block_mask = np.stack(new_block_mask, axis=0) data_next = np.stack(new_data_next, axis=0) if mask_next is not None: mask_next = np.stack(new_mask_next, axis=0) return block_mask, data_next, mask_next process_mask = functools.partial(_process_mask, is_dkv=False) process_mask_dkv = 
functools.partial(_process_mask, is_dkv=True) process_dynamic_mask = functools.partial(_process_dynamic_mask, is_dkv=False) process_dynamic_mask_dkv = functools.partial(_process_dynamic_mask, is_dkv=True)