ejkernel.kernels._pallas.tpu.blocksparse_attention._info

Mini-mask creation library.

class ejkernel.kernels._pallas.tpu.blocksparse_attention._info.MaskInfo(data_next: np.ndarray | jax.Array | None, mask_next: np.ndarray | jax.Array | None, block_mask: np.ndarray | jax.Array | None, partial_mask_blocks: np.ndarray | jax.Array | None, q_sequence: np.ndarray | None, is_dynamic_mask: bool = None)

Bases: NamedTuple

Contains runtime masking information for the Splash attention kernel.

The arrays data_next, mask_next and block_mask are placed in TPU scalar memory. This is a scarce resource, so the mask creation logic attempts to shrink the data type of these arrays to the smallest possible one: np.int32, np.int16 or np.int8.
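A minimal sketch of the dtype-shrinking idea, assuming the narrowest type is chosen from each array's value range alone. The helpers `smallest_int_dtype` and `downcast` are hypothetical and not part of ejkernel.

```python
import numpy as np

def smallest_int_dtype(arr: np.ndarray) -> np.dtype:
    """Return the narrowest of int8/int16/int32 that holds every value in arr."""
    lo, hi = int(arr.min()), int(arr.max())
    for dtype in (np.int8, np.int16, np.int32):
        info = np.iinfo(dtype)
        if info.min <= lo and hi <= info.max:
            return np.dtype(dtype)
    raise ValueError("values do not fit in int32")

def downcast(arr: np.ndarray) -> np.ndarray:
    """Hypothetical helper: copy arr into the smallest integer dtype that fits."""
    return arr.astype(smallest_int_dtype(arr))

# block_mask entries are 0, 1 or 2, so int8 is enough.
block_mask = np.array([[0, 1, 2], [2, 2, 0]], dtype=np.int32)
# Block indices above 127 no longer fit in int8.
data_next = np.array([[0, 150, 300]], dtype=np.int32)

print(downcast(block_mask).dtype)  # int8
print(downcast(data_next).dtype)   # int16
```

Choosing the dtype per array means a small block_mask can shrink further than data_next, whose range grows with the number of KV blocks.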

For the arrays data_next, mask_next and block_mask, the size of the first dimension is either num_heads or num_head_shards:

  • num_head_shards when all heads within each shard share a single unique mask. In this case the three arrays are broadcast to all the heads in the shard.

  • num_heads when the heads within a shard have more than one unique mask.
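The choice between num_head_shards and num_heads can be sketched as below. This illustration rests on two assumptions: each shard owns a contiguous, equally sized group of heads, and after broadcasting a head reads the row of its shard. The functions `leading_dim_size` and `row_for_head` are hypothetical.

```python
import numpy as np

def leading_dim_size(masks: np.ndarray, head_shards: int) -> int:
    """masks: bool[num_heads, q_len, kv_len]. Returns num_head_shards or num_heads."""
    num_heads = masks.shape[0]
    if num_heads % head_shards:
        raise ValueError("num_heads must be divisible by head_shards")
    heads_per_shard = num_heads // head_shards
    # Assumption: shard s owns heads [s * heads_per_shard, (s + 1) * heads_per_shard).
    shards = masks.reshape(head_shards, heads_per_shard, *masks.shape[1:])
    one_mask_per_shard = all(
        np.array_equal(shard[0], head) for shard in shards for head in shard
    )
    return head_shards if one_mask_per_shard else num_heads

def row_for_head(head: int, num_rows: int, num_heads: int) -> int:
    """Row of data_next/mask_next/block_mask that a head uses after broadcasting."""
    return head if num_rows == num_heads else head // (num_heads // num_rows)

causal = np.tril(np.ones((8, 8), dtype=bool))
shared = np.stack([causal] * 4)  # every head uses the causal mask
mixed = shared.copy()
mixed[1] = True                  # head 1 (in shard 0) attends everywhere

print(leading_dim_size(shared, head_shards=2))   # 2: one row per shard
print(leading_dim_size(mixed, head_shards=2))    # 4: one row per head
print(row_for_head(3, num_rows=2, num_heads=4))  # 1
```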

data_next

An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks] NumPy array where each entry contains the next kv block index to prefetch.

Type

numpy.ndarray | jax.jaxlib._jax.Array | None

mask_next

An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks] NumPy array where each entry contains the next mask block index in partial_mask_blocks to prefetch.

Type

numpy.ndarray | jax.jaxlib._jax.Array | None

block_mask

An integer[num_heads_or_shards, num_q_blocks, num_kv_blocks] NumPy array whose entries can be 0, 1 or 2. An entry of 0 indicates that the corresponding block in the full mask was all zeros. An entry of 1 indicates that the corresponding block in the full mask contained both zeros and ones. An entry of 2 indicates the corresponding block was entirely ones.

Type

numpy.ndarray | jax.jaxlib._jax.Array | None
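The 0/1/2 encoding of block_mask can be illustrated for a single head as follows; the real arrays carry an extra leading num_heads_or_shards dimension. `classify_blocks` is a hypothetical helper, not the library's implementation.

```python
import numpy as np

def classify_blocks(mask: np.ndarray, block_q: int, block_kv: int) -> np.ndarray:
    """Return int8[num_q_blocks, num_kv_blocks]: 0 = all zeros, 1 = mixed, 2 = all ones."""
    q_len, kv_len = mask.shape
    if q_len % block_q or kv_len % block_kv:
        raise ValueError("block sizes must divide the mask shape")
    nq, nk = q_len // block_q, kv_len // block_kv
    # [q_len, kv_len] -> [nq, nk, block_q, block_kv]
    blocks = mask.reshape(nq, block_q, nk, block_kv).transpose(0, 2, 1, 3)
    any_set = blocks.any(axis=(2, 3))
    all_set = blocks.all(axis=(2, 3))
    return np.where(all_set, 2, np.where(any_set, 1, 0)).astype(np.int8)

causal = np.tril(np.ones((8, 8), dtype=bool))
print(classify_blocks(causal, block_q=4, block_kv=4))
# [[1 0]
#  [2 1]]
```

For a causal mask the diagonal blocks are mixed (1), blocks above the diagonal can be skipped (0), and blocks below it need no masking at all (2).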

partial_mask_blocks

A bool[num_partial_blocks, block_q, block_kv] NumPy array that contains the blocks of the original mask that contained both zeros and ones. The entries in mask_next point to indices in the first axis of this array.

Type

numpy.ndarray | jax.jaxlib._jax.Array | None
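One way to picture how partial_mask_blocks and mask_next relate is the single-head sketch below. The helper `gather_partial_blocks` is hypothetical: it stores every mixed block without deduplication (a production implementation might share identical blocks), and leaving non-mixed cells of mask_next at 0 is an assumption of this sketch only.

```python
import numpy as np

def gather_partial_blocks(mask: np.ndarray, block_q: int, block_kv: int):
    """Return (partial_mask_blocks, mask_next) for a single-head bool[q_len, kv_len] mask."""
    q_len, kv_len = mask.shape
    nq, nk = q_len // block_q, kv_len // block_kv
    blocks = mask.reshape(nq, block_q, nk, block_kv).transpose(0, 2, 1, 3)
    partial = []
    # Hypothetical convention: cells that are not mixed are left at 0 in this sketch.
    mask_next = np.zeros((nq, nk), dtype=np.int32)
    for i in range(nq):
        for j in range(nk):
            blk = blocks[i, j]
            if blk.any() and not blk.all():  # block mixes zeros and ones
                mask_next[i, j] = len(partial)  # index into the first axis
                partial.append(blk)
    partial_mask_blocks = (
        np.stack(partial) if partial else np.zeros((0, block_q, block_kv), dtype=bool)
    )
    return partial_mask_blocks, mask_next

causal = np.tril(np.ones((8, 8), dtype=bool))
pmb, mask_next = gather_partial_blocks(causal, 4, 4)
print(pmb.shape)           # (2, 4, 4)
print(mask_next.tolist())  # [[0, 0], [0, 1]]
```

The number of stored blocks (here 2, the two diagonal blocks, which happen to be identical) depends on the mask contents, so it is only known once the mask values are known.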

q_sequence

An i32[q_sequence_length] NumPy array. When using causal masking, this contains the list of indices that correspond to q tokens. For plain causal masking this is just np.arange(q_sequence_length).

Type

numpy.ndarray | None
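Assuming the causal predicate compares a query's entry in q_sequence with the key's absolute position (attend when q >= kv), a causal block can be recomputed from indices rather than stored. The function `causal_block` is a hypothetical illustration of that idea.

```python
import numpy as np

def causal_block(q_sequence: np.ndarray, q_block: int, kv_block: int,
                 block_q: int, block_kv: int) -> np.ndarray:
    """Recompute one causal mask block from token indices instead of storing it."""
    q_ids = q_sequence[q_block * block_q:(q_block + 1) * block_q]
    kv_ids = np.arange(kv_block * block_kv, (kv_block + 1) * block_kv)
    # A query attends to a key when its position is at or after the key's position.
    return q_ids[:, None] >= kv_ids[None, :]

q_sequence = np.arange(8, dtype=np.int32)  # plain causal
print(causal_block(q_sequence, 1, 0, 4, 4).all())  # True: block fully visible
print(causal_block(q_sequence, 0, 1, 4, 4).any())  # False: block fully masked
```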

is_dynamic_mask

A bool indicating whether the mask is dynamic or static. When True, the leading dimensions of partial_mask_blocks (num_heads, q_blocks, kv_blocks) are not collapsed, allowing us to shard it along those dimensions.

Type

bool

block_mask: numpy.ndarray | jax.jaxlib._jax.Array | None

Alias for field number 2

data_next: numpy.ndarray | jax.jaxlib._jax.Array | None

Alias for field number 0

is_dynamic_mask: bool

Alias for field number 5

mask_next: numpy.ndarray | jax.jaxlib._jax.Array | None

Alias for field number 1

partial_mask_blocks: numpy.ndarray | jax.jaxlib._jax.Array | None

Alias for field number 3

q_sequence: numpy.ndarray | None

Alias for field number 4

ejkernel.kernels._pallas.tpu.blocksparse_attention._info.process_dynamic_mask(mask: jax.Array, block_shape: tuple[int, int], *, is_dkv: bool = False, downcast_smem_data: bool = True, head_shards: int = 1, q_seq_shards: int = 1, shrink_grid: bool = True) -> tuple[MaskInfo, None]

Similar to _process_mask but the mask must be a dynamic array.

Since the mask is dynamic, we can’t know the exact number of partial mask blocks at trace time. Therefore, the entire mask is materialized in partial_mask_blocks.

Note that we can still populate MaskInfo to skip fully-masked blocks.

Parameters
  • mask – A [head_count, q_seq_len, kv_seq_len] jax.Array representing the dense mask to process.

  • block_shape – A Tuple[int, int] representing the shape of the Pallas grid block.

  • is_dkv – True if we are processing the dKV mask.

  • downcast_smem_data – If True, downcast the scalar-memory data of MaskInfo to a data type smaller than np.int32 (if possible).

  • head_shards – Number of head shards of the mesh in which the kernel is launched.

  • q_seq_shards – Number of Q sequence shards of the mesh in which the kernel is launched.

  • shrink_grid – Whether or not we should apply the grid shrinking optimization. This is currently ignored.

Returns

A tuple (MaskInfo, None), where MaskInfo is a sparse representation of the dense mask.

Raises
  • ValueError – if the input mask is invalid or the block sizes are not compatible with the mask sizes.
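Because the mask values are unknown at trace time, every output shape has to depend only on the mask's shape. The sketch below (hypothetical `dynamic_block_info`, written with NumPy) keeps all blocks in partial_mask_blocks with the (num_heads, q_blocks, kv_blocks) dimensions uncollapsed, and still computes a block_mask that could be used to skip fully masked blocks. The reshape, transpose, any, all and where calls it uses also exist in jax.numpy, so a jnp version could be traced under jax.jit; how ejkernel fills the remaining MaskInfo fields here is not shown.

```python
import numpy as np

def dynamic_block_info(mask: np.ndarray, block_q: int, block_kv: int):
    """mask: bool[num_heads, q_len, kv_len]. Output shapes depend only on mask.shape."""
    h, q_len, kv_len = mask.shape
    if q_len % block_q or kv_len % block_kv:
        raise ValueError("block sizes must divide the mask shape")
    nq, nk = q_len // block_q, kv_len // block_kv
    # Keep every block; leading dims (heads, q_blocks, kv_blocks) are not collapsed.
    partial_mask_blocks = mask.reshape(h, nq, block_q, nk, block_kv).transpose(0, 1, 3, 2, 4)
    any_set = partial_mask_blocks.any(axis=(-2, -1))
    all_set = partial_mask_blocks.all(axis=(-2, -1))
    # Same encoding as block_mask: 0 = all zeros, 1 = mixed, 2 = all ones.
    block_mask = np.where(all_set, 2, np.where(any_set, 1, 0)).astype(np.int8)
    return block_mask, partial_mask_blocks

rng = np.random.default_rng(0)
random_mask = rng.random((2, 8, 8)) < 0.5
causal = np.stack([np.tril(np.ones((8, 8), dtype=bool))] * 2)
for m in (random_mask, causal):
    bm, blocks = dynamic_block_info(m, 4, 4)
    print(bm.shape, blocks.shape)  # (2, 2, 2) (2, 2, 2, 4, 4) for both masks
```

The trade-off is memory: every block is materialized, even the ones a static analysis would have dropped.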

ejkernel.kernels._pallas.tpu.blocksparse_attention._info.process_dynamic_mask_dkv(mask: jax.Array, block_shape: tuple[int, int], *, is_dkv: bool = True, downcast_smem_data: bool = True, head_shards: int = 1, q_seq_shards: int = 1, shrink_grid: bool = True) -> tuple[MaskInfo, None]

Similar to _process_mask but the mask must be a dynamic array.

Since the mask is dynamic, we can’t know the exact number of partial mask blocks at trace time. Therefore, the entire mask is materialized in partial_mask_blocks.

Note that we can still populate MaskInfo to skip fully-masked blocks.

Parameters
  • mask – A [head_count, q_seq_len, kv_seq_len] jax.Array representing the dense mask to process.

  • block_shape – A Tuple[int, int] representing the shape of the Pallas grid block.

  • is_dkv – True if we are processing the dKV mask.

  • downcast_smem_data – If True, downcast the scalar-memory data of MaskInfo to a data type smaller than np.int32 (if possible).

  • head_shards – Number of head shards of the mesh in which the kernel is launched.

  • q_seq_shards – Number of Q sequence shards of the mesh in which the kernel is launched.

  • shrink_grid – Whether or not we should apply the grid shrinking optimization. This is currently ignored.

Returns

A tuple (MaskInfo, None), where MaskInfo is a sparse representation of the dense mask.

Raises
  • ValueError – if the input mask is invalid or the block sizes are not compatible with the mask sizes.

ejkernel.kernels._pallas.tpu.blocksparse_attention._info.process_mask(mask: mask_lib.MultiHeadMask, block_shape: tuple[int, int], *, is_dkv: bool = False, downcast_smem_data: bool = True, head_shards: int = 1, q_seq_shards: int = 1, shrink_grid: bool = True) -> tuple[MaskInfo, jax_util.HashableFunction | None]

Transform a dense mask into a sparse representation.

The numbers of head and Q sequence shards are needed to create a MaskInfo object that is partitionable (with shmap or PartIR) along these two dimensions. In particular, for the dKV MaskInfo, the indices in each shard's data_next array are relative to that shard. The forward and dQ MaskInfo objects do not change when sharding along the head or Q dimensions; they would differ if we sharded along the KV dimension, but the kernel does not support that.

Parameters
  • mask – Dense mask to process.

  • block_shape – Shape of the Pallas grid block.

  • is_dkv – True if we are processing the dKV mask.

  • downcast_smem_data – If True, downcast the scalar-memory data of MaskInfo to a data type smaller than np.int32 (if possible).

  • head_shards – Number of head shards of the mesh in which the kernel is launched.

  • q_seq_shards – Number of Q sequence shards of the mesh in which the kernel is launched.

  • shrink_grid – Whether or not we should apply the grid shrinking optimization.

Returns

MaskInfo: a sparse representation of the dense mask.

MaskCallable: a callable that, given Q and KV indices as input, returns the value of the mask at those coordinates.

Raises
  • ValueError – if the input mask is invalid or the block sizes are not compatible with the mask sizes.
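A sketch of what shard-relative data_next indices could look like for the dKV mask when sharding along Q. It assumes a hypothetical [num_kv_blocks, num_q_blocks] layout holding global q-block indices, contiguous equally sized Q shards, and that each shard's entries only point at its own q blocks; `split_data_next_by_q_shard` is not an ejkernel function.

```python
import numpy as np

def split_data_next_by_q_shard(data_next: np.ndarray, q_seq_shards: int) -> np.ndarray:
    """data_next: int[num_kv_blocks, num_q_blocks] of global q-block indices.

    Returns int[q_seq_shards, num_kv_blocks, num_q_blocks // q_seq_shards] whose
    indices are relative to each shard's first q block.
    """
    num_kv_blocks, num_q_blocks = data_next.shape
    if num_q_blocks % q_seq_shards:
        raise ValueError("num_q_blocks must be divisible by q_seq_shards")
    per_shard = num_q_blocks // q_seq_shards
    shards = []
    for s in range(q_seq_shards):
        cols = data_next[:, s * per_shard:(s + 1) * per_shard]
        # Assumption: entries in shard s only point at q blocks owned by shard s.
        shards.append(cols - s * per_shard)
    return np.stack(shards)

# 2 kv blocks x 4 q blocks, each entry pointing at its own q block.
global_next = np.tile(np.arange(4), (2, 1))
local_next = split_data_next_by_q_shard(global_next, q_seq_shards=2)
print(local_next.tolist())  # [[[0, 1], [0, 1]], [[0, 1], [0, 1]]]
```

Presumably the reason for shard-relative indices is that a kernel running on one shard only sees its local slice of Q, so a global index such as 3 would fall outside a two-block local range.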

ejkernel.kernels._pallas.tpu.blocksparse_attention._info.process_mask_dkv(mask: mask_lib.MultiHeadMask, block_shape: tuple[int, int], *, is_dkv: bool = True, downcast_smem_data: bool = True, head_shards: int = 1, q_seq_shards: int = 1, shrink_grid: bool = True) -> tuple[MaskInfo, jax_util.HashableFunction | None]

Transform a dense mask into a sparse representation.

The numbers of head and Q sequence shards are needed to create a MaskInfo object that is partitionable (with shmap or PartIR) along these two dimensions. In particular, for the dKV MaskInfo, the indices in each shard's data_next array are relative to that shard. The forward and dQ MaskInfo objects do not change when sharding along the head or Q dimensions; they would differ if we sharded along the KV dimension, but the kernel does not support that.

Parameters
  • mask – Dense mask to process.

  • block_shape – Shape of the Pallas grid block.

  • is_dkv – True if we are processing the dKV mask.

  • downcast_smem_data – If True, downcast the scalar-memory data of MaskInfo to a data type smaller than np.int32 (if possible).

  • head_shards – Number of head shards of the mesh in which the kernel is launched.

  • q_seq_shards – Number of Q sequence shards of the mesh in which the kernel is launched.

  • shrink_grid – Whether or not we should apply the grid shrinking optimization.

Returns

MaskInfo: a sparse representation of the dense mask.

MaskCallable: a callable that, given Q and KV indices as input, returns the value of the mask at those coordinates.

Raises
  • ValueError – if the input mask is invalid or the block sizes are not compatible with the mask sizes.