ejkernel.kernels._pallas.tpu.blocksparse_attention._info
Mini-mask creation library.
- class ejkernel.kernels._pallas.tpu.blocksparse_attention._info.MaskInfo(data_next: np.ndarray | jax.Array | None, mask_next: np.ndarray | jax.Array | None, block_mask: np.ndarray | jax.Array | None, partial_mask_blocks: np.ndarray | jax.Array | None, q_sequence: np.ndarray | None, is_dynamic_mask: bool = None)
Bases: NamedTuple
Contains runtime masking information for the Splash attention kernel.
The arrays data_next, mask_next and block_mask are placed in TPU scalar memory. Scalar memory is scarce, so the mask creation logic tries to shrink the data type of these arrays to the smallest one possible: np.int32, np.int16 or np.int8.
For the arrays data_next, mask_next and block_mask, the size of the first dimension is either num_heads or num_head_shards:
* num_head_shards when all the heads in a shard share a single unique mask. In this case, the three arrays are broadcast to all the heads in the shard.
* num_heads when the heads in a shard have more than one unique mask.
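The downcasting and head-broadcasting rules above can be sketched in NumPy. This is a minimal illustration, not the library's implementation. The helper names `smallest_int_dtype` and `maybe_collapse_heads` are hypothetical, and the collapse check assumes that heads are assigned to shards in contiguous groups.

```python
import numpy as np


def smallest_int_dtype(arr: np.ndarray) -> np.dtype:
    """Hypothetical helper: narrowest of int8/int16/int32 holding every value of arr."""
    if arr.size == 0:
        return np.dtype(np.int8)
    lo, hi = int(arr.min()), int(arr.max())
    for dtype in (np.int8, np.int16, np.int32):
        info = np.iinfo(dtype)
        if info.min <= lo and hi <= info.max:
            return np.dtype(dtype)
    raise ValueError(f"values in [{lo}, {hi}] do not fit in int32")


def maybe_collapse_heads(per_head: np.ndarray, head_shards: int) -> np.ndarray:
    """Hypothetical helper: return [num_head_shards, ...] when all heads in each
    (contiguous) shard carry identical data, else keep [num_heads, ...]."""
    num_heads = per_head.shape[0]
    if num_heads % head_shards:
        raise ValueError("num_heads must be divisible by head_shards")
    grouped = per_head.reshape(head_shards, num_heads // head_shards, *per_head.shape[1:])
    # Compare every head in a shard against the shard's first head.
    if np.all(grouped == grouped[:, :1]):
        return grouped[:, 0]
    return per_head


block_mask = np.array([[2, 1, 0], [2, 2, 1]])
print(smallest_int_dtype(block_mask))       # int8: values are only 0, 1 or 2
print(smallest_int_dtype(np.arange(1001)))  # int16: indices up to 1000

same, other = np.zeros((2, 3), np.int8), np.ones((2, 3), np.int8)
print(maybe_collapse_heads(np.stack([same, same, other, other]), 2).shape)  # (2, 2, 3)
print(maybe_collapse_heads(np.stack([same, other, same, other]), 2).shape)  # (4, 2, 3)
```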
- data_next
An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks] NumPy array where each entry contains the next kv block index to prefetch.
- Type
numpy.ndarray | jax.jaxlib._jax.Array | None
- mask_next
An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks] NumPy array where each entry contains the next mask block index in partial_mask_blocks to prefetch.
- Type
numpy.ndarray | jax.jaxlib._jax.Array | None
- block_mask
An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks] NumPy array whose entries can be 0, 1 or 2. An entry of 0 indicates that the corresponding block in the full mask was all zeros. An entry of 1 indicates that the corresponding block in the full mask contained both zeros and ones. An entry of 2 indicates that the corresponding block in the full mask was all ones.
- Type
numpy.ndarray | jax.jaxlib._jax.Array | None
- partial_mask_blocks
A bool[num_partial_blocks, block_q, block_kv] NumPy array that contains the blocks of the original mask that contained both zeros and ones. The entries in mask_next point to indices in the first axis of this array.
- Type
numpy.ndarray | jax.jaxlib._jax.Array | None
- q_sequence
An i32[q_sequence_length] NumPy array. When using causal masking, this contains the list of indices that correspond to q tokens. For plain causal masking, this is just np.arange(q_sequence_length).
- Type
numpy.ndarray | None
- is_dynamic_mask
A bool indicating whether the mask is dynamic or static. When True, the leading dimensions of partial_mask_blocks (num_heads, q_blocks, kv_blocks) are not collapsed, which allows partial_mask_blocks to be sharded along those dimensions.
- Type
bool
- block_mask: numpy.ndarray | jax.jaxlib._jax.Array | None
Alias for field number 2
- data_next: numpy.ndarray | jax.jaxlib._jax.Array | None
Alias for field number 0
- is_dynamic_mask: bool
Alias for field number 5
- mask_next: numpy.ndarray | jax.jaxlib._jax.Array | None
Alias for field number 1
- partial_mask_blocks: numpy.ndarray | jax.jaxlib._jax.Array | None
Alias for field number 3
- q_sequence: numpy.ndarray | None
Alias for field number 4
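The NumPy sketch below shows how block_mask, partial_mask_blocks and mask_next fit together. It tiles a single-head dense mask built from a plain-causal q_sequence. This is an illustration under stated assumptions, not the library's algorithm:
* the function name `sketch_block_sparsity` is hypothetical;
* duplicate partial tiles are merged with a simple linear search;
* mask_next entries for non-partial blocks are left at 0.

```python
import numpy as np


def sketch_block_sparsity(mask: np.ndarray, block_q: int, block_kv: int):
    """Hypothetical sketch for one head: mask is a dense bool [q_len, kv_len] array."""
    q_len, kv_len = mask.shape
    nq, nkv = q_len // block_q, kv_len // block_kv
    # View the mask as a grid of tiles with shape [nq, nkv, block_q, block_kv].
    blocks = mask.reshape(nq, block_q, nkv, block_kv).transpose(0, 2, 1, 3)
    any_true = blocks.any(axis=(2, 3))
    all_true = blocks.all(axis=(2, 3))
    # 0 = all zeros, 1 = mixed (partial), 2 = all ones.
    block_mask = np.where(all_true, 2, np.where(any_true, 1, 0)).astype(np.int32)

    partial_blocks = []  # unique mixed tiles, in first-seen order
    mask_next = np.zeros((nq, nkv), dtype=np.int32)  # unused entries stay 0
    for i in range(nq):
        for j in range(nkv):
            if block_mask[i, j] != 1:
                continue
            tile = blocks[i, j]
            for k, seen in enumerate(partial_blocks):
                if np.array_equal(seen, tile):
                    mask_next[i, j] = k  # reuse an identical stored tile
                    break
            else:
                mask_next[i, j] = len(partial_blocks)
                partial_blocks.append(tile)
    partial_mask_blocks = (np.stack(partial_blocks) if partial_blocks
                           else np.zeros((0, block_q, block_kv), dtype=bool))
    return block_mask, partial_mask_blocks, mask_next


q_sequence = np.arange(8)  # plain causal: q token i sits at position i
causal = q_sequence[:, None] >= np.arange(8)[None, :]
block_mask, partial_mask_blocks, mask_next = sketch_block_sparsity(causal, 4, 4)
print(block_mask.tolist())        # [[1, 0], [2, 1]]
print(partial_mask_blocks.shape)  # (1, 4, 4): both diagonal tiles are identical
```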
- ejkernel.kernels._pallas.tpu.blocksparse_attention._info.process_dynamic_mask(mask: jax.Array, block_shape: tuple[int, int], *, is_dkv: bool = False, downcast_smem_data: bool = True, head_shards: int = 1, q_seq_shards: int = 1, shrink_grid: bool = True) → tuple[MaskInfo, None]
Similar to _process_mask, but the mask must be a dynamic array.
Because the mask is dynamic, the exact number of partial mask blocks cannot be known at trace time. The entire mask is therefore materialized in partial_mask_blocks.
Note that MaskInfo can still be populated so that fully masked blocks are skipped.
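Under jax.jit, array shapes are static while array contents are traced. The block classification can therefore still be computed for a dynamic mask, even though the number of partial blocks cannot become a static shape. The sketch below makes these assumptions:
* the input is a [heads, q_seq_len, kv_seq_len] boolean mask whose lengths divide evenly by the block sizes;
* `sketch_dynamic_blocks` is a hypothetical name;
* the uncollapsed [heads, q_blocks, kv_blocks, block_q, block_kv] tile layout mirrors the is_dynamic_mask description above, not a verified internal format.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("block_q", "block_kv"))
def sketch_dynamic_blocks(mask: jax.Array, block_q: int, block_kv: int):
    """Hypothetical sketch: classify the blocks of a traced [H, q, kv] bool mask."""
    h, q_len, kv_len = mask.shape  # shapes are static even under jit
    nq, nkv = q_len // block_q, kv_len // block_kv
    # Keep every tile; leading dims (heads, q_blocks, kv_blocks) are not collapsed.
    blocks = mask.reshape(h, nq, block_q, nkv, block_kv).transpose(0, 1, 3, 2, 4)
    any_true = blocks.any(axis=(3, 4))
    all_true = blocks.all(axis=(3, 4))
    # 0 = skip (all zeros), 1 = partial, 2 = all ones; computed from traced values.
    block_mask = jnp.where(all_true, 2, jnp.where(any_true, 1, 0)).astype(jnp.int8)
    return block_mask, blocks


mask = jnp.tril(jnp.ones((2, 8, 8), dtype=bool))  # causal mask for 2 heads
block_mask, partial_mask_blocks = sketch_dynamic_blocks(mask, block_q=4, block_kv=4)
print(block_mask.shape, partial_mask_blocks.shape)  # (2, 2, 2) (2, 2, 2, 4, 4)
```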
- Parameters
mask – A [head_count, q_seq_len, kv_seq_len] jax.Array representing the dense mask to process.
block_shape – A Tuple[int, int] representing the shape of the Pallas grid block.
is_dkv – True if we are processing the dKV mask.
downcast_smem_data – If True, downcast the scalar-memory data of MaskInfo to a data type smaller than np.int32 (if possible).
head_shards – Number of head shards of the mesh in which the kernel is launched.
q_seq_shards – Number of Q sequence shards of the mesh in which the kernel is launched.
shrink_grid – Whether or not we should apply the grid shrinking optimization. This is currently ignored.
- Returns
MaskInfo, a sparse representation of the dense mask.
- Raises
ValueError – if the input mask is invalid or the block sizes are not compatible with the mask sizes.
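The exact ValueError conditions are not spelled out here. The checks below are a hypothetical sketch of the kind of validation one might expect: a 3-D mask whose dimensions divide evenly by the block and shard counts. Both `validate_mask_shape` and the specific conditions are assumptions, not the library's actual checks.

```python
def validate_mask_shape(mask_shape, block_shape, head_shards=1, q_seq_shards=1):
    """Hypothetical validation sketch; returns the (q_blocks, kv_blocks) grid."""
    if len(mask_shape) != 3:
        raise ValueError(f"expected [heads, q_seq_len, kv_seq_len], got {mask_shape}")
    heads, q_len, kv_len = mask_shape
    block_q, block_kv = block_shape
    if q_len % block_q or kv_len % block_kv:
        raise ValueError(f"block shape {block_shape} does not divide mask shape {mask_shape}")
    if heads % head_shards:
        raise ValueError(f"{heads} heads cannot be split into {head_shards} shards")
    q_blocks = q_len // block_q
    if q_blocks % q_seq_shards:
        raise ValueError(f"{q_blocks} q blocks cannot be split into {q_seq_shards} shards")
    return q_blocks, kv_len // block_kv


print(validate_mask_shape((8, 1024, 2048), (128, 256)))  # (8, 8)
try:
    validate_mask_shape((8, 1000, 2048), (128, 256))  # 1000 is not a multiple of 128
except ValueError as err:
    print("rejected:", err)
```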
- ejkernel.kernels._pallas.tpu.blocksparse_attention._info.process_dynamic_mask_dkv(mask: jax.Array, block_shape: tuple[int, int], *, is_dkv: bool = True, downcast_smem_data: bool = True, head_shards: int = 1, q_seq_shards: int = 1, shrink_grid: bool = True) → tuple[MaskInfo, None]
Similar to _process_mask, but the mask must be a dynamic array.
Because the mask is dynamic, the exact number of partial mask blocks cannot be known at trace time. The entire mask is therefore materialized in partial_mask_blocks.
Note that MaskInfo can still be populated so that fully masked blocks are skipped.
- Parameters
mask – A [head_count, q_seq_len, kv_seq_len] jax.Array representing the dense mask to process.
block_shape – A Tuple[int, int] representing the shape of the Pallas grid block.
is_dkv – True if we are processing the dKV mask.
downcast_smem_data – If True, downcast the scalar-memory data of MaskInfo to a data type smaller than np.int32 (if possible).
head_shards – Number of head shards of the mesh in which the kernel is launched.
q_seq_shards – Number of Q sequence shards of the mesh in which the kernel is launched.
shrink_grid – Whether or not we should apply the grid shrinking optimization. This is currently ignored.
- Returns
MaskInfo, a sparse representation of the dense mask.
- Raises
ValueError – if the input mask is invalid or the block sizes are not compatible with the mask sizes.
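process_dynamic_mask_dkv carries the same documentation as process_dynamic_mask and differs only in the default value of is_dkv. The signatures are consistent with both entries being functools.partial specializations of one shared function. The sketch below shows that pattern with a hypothetical stub (`_process_dynamic_mask_stub`, `fwd_entry` and `dkv_entry` are made-up names). It does not claim that ejkernel defines them this way.

```python
import functools


def _process_dynamic_mask_stub(mask, block_shape, *, is_dkv=False,
                               downcast_smem_data=True, head_shards=1,
                               q_seq_shards=1, shrink_grid=True):
    # Hypothetical stand-in for a shared implementation; it only reports is_dkv.
    return is_dkv


# Two entry points that differ only in the default value of is_dkv.
fwd_entry = functools.partial(_process_dynamic_mask_stub, is_dkv=False)
dkv_entry = functools.partial(_process_dynamic_mask_stub, is_dkv=True)

print(fwd_entry(None, (128, 128)))                # False
print(dkv_entry(None, (128, 128)))                # True
print(dkv_entry(None, (128, 128), is_dkv=False))  # False: call-site keyword wins
```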
- ejkernel.kernels._pallas.tpu.blocksparse_attention._info.process_mask(mask: mask_lib.MultiHeadMask, block_shape: tuple[int, int], *, is_dkv: bool = False, downcast_smem_data: bool = True, head_shards: int = 1, q_seq_shards: int = 1, shrink_grid: bool = True) → tuple[MaskInfo, jax_util.HashableFunction | None]
Transform a dense mask into a sparse representation.
The numbers of head and Q sequence shards are needed to create a MaskInfo object that is partitionable (with shmap or PartIR) along these two dimensions. In particular, for the dKV MaskInfo, the indices in each shard's data_next array are relative to that shard. The fwd and dQ MaskInfo objects do not change when sharding along the head or Q dimensions. They would differ if we sharded along the KV dimension, but the kernel does not support that.
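This is a minimal sketch of what "relative to the current shard" can mean. It assumes that the dKV data_next holds Q-block indices and that the Q blocks are split into q_seq_shards equal, contiguous chunks. The function name `to_shard_relative` is hypothetical.

```python
import numpy as np


def to_shard_relative(q_block_idx: np.ndarray, num_q_blocks: int, q_seq_shards: int):
    """Hypothetical: map global Q-block indices to (shard id, index within shard)."""
    if num_q_blocks % q_seq_shards:
        raise ValueError("num_q_blocks must be divisible by q_seq_shards")
    blocks_per_shard = num_q_blocks // q_seq_shards
    return q_block_idx // blocks_per_shard, q_block_idx % blocks_per_shard


shard_id, local_idx = to_shard_relative(np.arange(8), num_q_blocks=8, q_seq_shards=2)
print(shard_id.tolist())   # [0, 0, 0, 0, 1, 1, 1, 1]
print(local_idx.tolist())  # [0, 1, 2, 3, 0, 1, 2, 3]
```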
- Parameters
mask – Dense mask to process.
block_shape – Shape of the Pallas grid block.
is_dkv – True if we are processing the dKV mask.
downcast_smem_data – If True, downcast the scalar-memory data of MaskInfo to a data type smaller than np.int32 (if possible).
head_shards – Number of head shards of the mesh in which the kernel is launched.
q_seq_shards – Number of Q sequence shards of the mesh in which the kernel is launched.
shrink_grid – Whether or not we should apply the grid shrinking optimization.
- Returns
MaskInfo, a sparse representation of the dense mask.
MaskCallable, a callable that, given Q and KV indices as input, returns the value of the mask at those coordinates.
- Raises
ValueError – if the input mask is invalid or the block sizes are not compatible with the mask sizes.
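The returned MaskCallable is evaluated on Q and KV indices. The sketch below uses a hypothetical causal callable (`causal_mask_fn`) and evaluates it on broadcast index grids to rebuild one [block_q, block_kv] tile. The real callable's exact signature and index types may differ.

```python
import numpy as np


def causal_mask_fn(q_ids: np.ndarray, kv_ids: np.ndarray) -> np.ndarray:
    """Hypothetical mask callable: True where query position q may attend to key kv."""
    return q_ids >= kv_ids


# Rebuild the tile at grid position (i, j) by evaluating the callable on index grids.
block_q, block_kv, i, j = 4, 4, 1, 0
q_ids = i * block_q + np.arange(block_q)[:, None]     # shape [block_q, 1]
kv_ids = j * block_kv + np.arange(block_kv)[None, :]  # shape [1, block_kv]
tile = causal_mask_fn(q_ids, kv_ids)                  # broadcasts to [block_q, block_kv]
print(tile.all())  # True: every query in rows 4-7 can see keys 0-3
```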
- ejkernel.kernels._pallas.tpu.blocksparse_attention._info.process_mask_dkv(mask: mask_lib.MultiHeadMask, block_shape: tuple[int, int], *, is_dkv: bool = True, downcast_smem_data: bool = True, head_shards: int = 1, q_seq_shards: int = 1, shrink_grid: bool = True) → tuple[MaskInfo, jax_util.HashableFunction | None]
Transform a dense mask into a sparse representation.
The numbers of head and Q sequence shards are needed to create a MaskInfo object that is partitionable (with shmap or PartIR) along these two dimensions. In particular, for the dKV MaskInfo, the indices in each shard's data_next array are relative to that shard. The fwd and dQ MaskInfo objects do not change when sharding along the head or Q dimensions. They would differ if we sharded along the KV dimension, but the kernel does not support that.
- Parameters
mask – Dense mask to process.
block_shape – Shape of the Pallas grid block.
is_dkv – True if we are processing the dKV mask.
downcast_smem_data – If True, downcast the scalar-memory data of MaskInfo to a data type smaller than np.int32 (if possible).
head_shards – Number of head shards of the mesh in which the kernel is launched.
q_seq_shards – Number of Q sequence shards of the mesh in which the kernel is launched.
shrink_grid – Whether or not we should apply the grid shrinking optimization.
- Returns
MaskInfo, a sparse representation of the dense mask.
MaskCallable, a callable that, given Q and KV indices as input, returns the value of the mask at those coordinates.
- Raises
ValueError – if the input mask is invalid or the block sizes are not compatible with the mask sizes.