ejkernel.kernels._pallas.tpu.grouped_matmulv2._pallas_impl
Grouped matrix multiplication kernels for TPU written in Pallas.
- ejkernel.kernels._pallas.tpu.grouped_matmulv2._pallas_impl.grouped_matmul(lhs: Array, rhs: Array, group_sizes: Array, preferred_element_type: dtype, tiling: tuple[int, int, int] | collections.abc.Callable[[int, int, int], tuple[int, int, int] | None] | None = (128, 128, 128), input_buffer_count: int = 2, group_offset: jax.jaxlib._jax.Array | None = None, transpose_rhs: bool = False, interpret: bool = False) → Array
Compute lhs[sizes[i-1]:sizes[i], :] @ rhs[i] for each group i, where sizes is the cumulative sum of group_sizes (with sizes[-1] taken as 0 for the first group).
- Parameters
lhs – A 2d, jax.Array with shape [m, k].
rhs – A 3d, jax.Array with shape [num_groups, k, n].
group_sizes – A 1d, jax.Array with shape [num_groups] and jnp.int32 dtype.
preferred_element_type – jnp.dtype, the element type for the output matrix.
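tiling – 3-tuple of ints. The m, k and n-dimension tile sizes. Per the signature, a callable taking three ints and returning such a tuple (or None) is also accepted.
group_offset – The index in group_sizes of the first group to compute. This is particularly useful when the num_groups dimension of rhs is sharded.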
transpose_rhs – True if the rhs needs to be transposed.
interpret – Whether or not to run the kernel in interpret mode, helpful for testing and debugging.
- Returns
A 2d, jax.Array with shape [m, n].
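The semantics described above can be illustrated with a plain NumPy reference. This is only a sketch for checking shapes and results, not the Pallas kernel: `reference_grouped_matmul` is a hypothetical helper, and it ignores tiling, group_offset, transpose_rhs and preferred_element_type.

```python
import numpy as np

def reference_grouped_matmul(lhs, rhs, group_sizes):
    """Hypothetical NumPy reference for the grouped matmul semantics."""
    m, k = lhs.shape
    num_groups, k_rhs, n = rhs.shape
    assert k == k_rhs, "contracting dimensions must match"
    # offsets[i] is the first lhs row of group i; offsets[-1] == sum(group_sizes).
    offsets = np.concatenate([[0], np.cumsum(group_sizes)])
    out = np.zeros((m, n), dtype=np.result_type(lhs, rhs))
    for i in range(num_groups):
        start, end = offsets[i], offsets[i + 1]
        # Rows belonging to group i are multiplied by that group's rhs slice.
        out[start:end] = lhs[start:end] @ rhs[i]
    return out

rng = np.random.default_rng(0)
lhs = rng.standard_normal((6, 4))        # [m, k]
rhs = rng.standard_normal((3, 4, 5))     # [num_groups, k, n]
group_sizes = np.array([2, 0, 4], dtype=np.int32)  # sums to m = 6
out = reference_grouped_matmul(lhs, rhs, group_sizes)
```

Here the second group is empty, so rhs[1] does not contribute to any output row.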
- ejkernel.kernels._pallas.tpu.grouped_matmulv2._pallas_impl.make_group_metadata(*, group_sizes: Array, m: int, tm: int, start_group: Array, num_nonzero_groups: int, visit_empty_groups: bool) → tuple[tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jax.jaxlib._jax.Array], jax.jaxlib._jax.Array]
Create the metadata needed for grouped matmul computation.
- Parameters
group_sizes – A 1d, jax.Array with shape [num_groups] and jnp.int32 dtype.
m – The number of rows in lhs.
tm – The m-dimension tile size being used.
start_group – The index in group_sizes of the first group to compute. This is particularly useful when the num_groups dimension of rhs is sharded.
num_nonzero_groups – Number of groups in group_sizes to compute on. Useful in combination with start_group.
visit_empty_groups – If True, do not squeeze tiles for empty groups out of the metadata. This is necessary for transposed_grouped_matmul, where we at least need to zero the output for each group.
- Returns
A tuple ((group_offsets, group_ids, m_tile_ids), num_tiles):
- group_offsets: A 1d, jax.Array with shape [num_groups+1] and jnp.int32 dtype. group_offsets[i] indicates the row at which group i starts in the lhs matrix, and group_offsets[num_groups] = m.
- group_ids: A 1d, jax.Array with shape [m_tiles + num_groups] and jnp.int32 dtype. group_ids[i] indicates which group grid index i will work on.
- m_tile_ids: A 1d, jax.Array with shape [m_tiles + num_groups] and jnp.int32 dtype. m_tile_ids[i] indicates which m-dimension tile grid index i will work on.
- num_tiles: The number of m-dimension tiles to execute.
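To make the returned metadata more concrete, the following NumPy sketch enumerates which (group, m-tile) pairs need to be visited for a given tile size. It is an illustration under simplifying assumptions, not the library's implementation: `sketch_group_metadata` is a hypothetical helper, it skips empty groups (as if visit_empty_groups were False), ignores start_group and num_nonzero_groups, and returns unpadded Python lists rather than fixed-length arrays of shape [m_tiles + num_groups].

```python
import numpy as np

def sketch_group_metadata(group_sizes, tm):
    """Hypothetical sketch: list the (group, m-tile) pairs a grid would visit."""
    group_sizes = np.asarray(group_sizes)
    # group_offsets[i] is the first row of group i; the last entry equals m.
    group_offsets = np.concatenate([[0], np.cumsum(group_sizes)])
    group_ids, m_tile_ids = [], []
    for g, size in enumerate(group_sizes):
        if size == 0:
            continue  # assumption: empty groups are squeezed out
        first_tile = group_offsets[g] // tm
        last_tile = (group_offsets[g + 1] - 1) // tm
        # A tile that straddles a group boundary is visited once per group.
        for t in range(first_tile, last_tile + 1):
            group_ids.append(g)
            m_tile_ids.append(int(t))
    return group_offsets, group_ids, m_tile_ids

# m = 8 rows, tm = 4, so there are two m-tiles (rows 0-3 and rows 4-7).
offsets, gids, tids = sketch_group_metadata([3, 0, 5], tm=4)
num_tiles = len(gids)
```

In this example tile 0 is shared by groups 0 and 2, so it is visited twice. This is why the number of visits can exceed m_tiles, and why the arrays are sized m_tiles + num_groups.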
- ejkernel.kernels._pallas.tpu.grouped_matmulv2._pallas_impl.transposed_grouped_matmul(lhs: Array, rhs: Array, group_sizes: Array, preferred_element_type: dtype, tiling: tuple[int, int, int] | collections.abc.Callable[[int, int, int], tuple[int, int, int] | None] | None = (128, 128, 128), input_buffer_count: int = 2, group_offset: jax.jaxlib._jax.Array | None = None, num_actual_groups: int | None = None, interpret: bool = False) → Array
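Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :] for each group i, where sizes is the cumulative sum of group_sizes (with sizes[-1] taken as 0 for the first group).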
- Parameters
lhs – A 2d, jax.Array with shape [k, m].
rhs – A 2d, jax.Array with shape [m, n].
group_sizes – A 1d, jax.Array with shape [num_groups] and jnp.int32 dtype.
preferred_element_type – jnp.dtype, the element type for the output matrix.
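tiling – 3-tuple of ints. The m, k and n-dimension tile sizes. Per the signature, a callable taking three ints and returning such a tuple (or None) is also accepted.
group_offset – The index in group_sizes of the first group to compute. This is particularly useful when num_groups is sharded.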
num_actual_groups – For when num_groups is sharded and we should only compute the groups that are local, starting from group_offset.
interpret – Whether or not to run the kernel in interpret mode, helpful for testing and debugging.
- Returns
A 3d, jax.Array with shape [num_groups, k, n].
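As with grouped_matmul, the transposed variant can be sketched in NumPy to show the intended result. `reference_transposed_grouped_matmul` is a hypothetical helper written for illustration only; it ignores tiling, group_offset, num_actual_groups and preferred_element_type.

```python
import numpy as np

def reference_transposed_grouped_matmul(lhs, rhs, group_sizes):
    """Hypothetical NumPy reference for the transposed grouped matmul semantics."""
    k, m = lhs.shape
    m_rhs, n = rhs.shape
    assert m == m_rhs, "grouped (contracting) dimensions must match"
    num_groups = len(group_sizes)
    # offsets[i] is the first index of group i along the m dimension.
    offsets = np.concatenate([[0], np.cumsum(group_sizes)])
    out = np.zeros((num_groups, k, n), dtype=np.result_type(lhs, rhs))
    for i in range(num_groups):
        start, end = offsets[i], offsets[i + 1]
        # Contract only over the slice of m owned by group i.
        # An empty slice yields an all-zero [k, n] block.
        out[i] = lhs[:, start:end] @ rhs[start:end, :]
    return out

rng = np.random.default_rng(0)
lhs = rng.standard_normal((4, 6))   # [k, m]
rhs = rng.standard_normal((6, 5))   # [m, n]
group_sizes = np.array([2, 0, 4], dtype=np.int32)  # sums to m = 6
out = reference_transposed_grouped_matmul(lhs, rhs, group_sizes)
```

Because the groups partition the m dimension, summing the per-group outputs recovers the full product lhs @ rhs, which makes a convenient sanity check.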