ejkernel.kernels._pallas.tpu.grouped_matmulv2._pallas_impl


Grouped matrix multiplication kernels for TPU written in Pallas.

ejkernel.kernels._pallas.tpu.grouped_matmulv2._pallas_impl.grouped_matmul(lhs: Array, rhs: Array, group_sizes: Array, preferred_element_type: dtype, tiling: tuple[int, int, int] | collections.abc.Callable[[int, int, int], tuple[int, int, int] | None] | None = (128, 128, 128), input_buffer_count: int = 2, group_offset: jax.jaxlib._jax.Array | None = None, transpose_rhs: bool = False, interpret: bool = False) -> Array

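Compute lhs[start_i:end_i, :] @ rhs[i] for each group i, where group i owns the consecutive block of lhs rows [start_i, end_i) and end_i - start_i = group_sizes[i].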

Parameters
  • lhs – A 2D jax.Array with shape [m, k].

  • rhs – A 3D jax.Array with shape [num_groups, k, n], or [num_groups, n, k] when transpose_rhs is True.

  • group_sizes – A 1D jax.Array with shape [num_groups] and jnp.int32 dtype. group_sizes[i] is the number of consecutive lhs rows that belong to group i.

  • preferred_element_type – The jnp.dtype of the output matrix.

  • tiling – A 3-tuple of ints giving the m, k and n-dimension tile sizes. The signature also accepts None, or a callable that takes three ints and returns such a tuple or None.

  • group_offset – The index of the group in group_sizes from which to start computing. This is useful when the num_groups dimension of rhs is sharded.

  • transpose_rhs – If True, each rhs[i] is transposed before the multiplication.

  • interpret – Whether to run the kernel in interpret mode, which is helpful for testing and debugging.

Returns

A 2D jax.Array with shape [m, n] and dtype preferred_element_type.
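To make the per-group semantics concrete, here is a minimal NumPy reference sketch of what grouped_matmul computes. It is not the Pallas kernel, and grouped_matmul_reference is a hypothetical helper name. It assumes that with transpose_rhs=True each rhs[i] is stored as [n, k]. Rows beyond sum(group_sizes) are simply left at zero here; this sketch makes no claim about what the kernel writes to such rows.

```python
import numpy as np


def grouped_matmul_reference(lhs, rhs, group_sizes, transpose_rhs=False):
    """NumPy reference for grouped_matmul semantics (hypothetical helper).

    lhs: [m, k]; rhs: [num_groups, k, n], or [num_groups, n, k] if transpose_rhs.
    group_sizes: [num_groups] ints; lhs rows are assigned to groups in order.
    """
    lhs = np.asarray(lhs)
    rhs = np.asarray(rhs)
    if transpose_rhs:
        # Assumption: with transpose_rhs=True each rhs[i] is stored as [n, k].
        rhs = np.swapaxes(rhs, 1, 2)
    m, n = lhs.shape[0], rhs.shape[2]
    out = np.zeros((m, n), dtype=np.result_type(lhs, rhs))
    # Cumulative boundaries: group i owns lhs rows [offsets[i], offsets[i + 1]).
    offsets = np.concatenate([[0], np.cumsum(group_sizes)])
    for i in range(len(group_sizes)):
        start, end = offsets[i], offsets[i + 1]
        # Each group's rows are multiplied by that group's own rhs matrix.
        out[start:end] = lhs[start:end] @ rhs[i]
    return out


# Usage: row 0 belongs to group 0 (identity), rows 1-3 to group 1 (doubling).
lhs = np.arange(12, dtype=np.float32).reshape(4, 3)
rhs = np.stack([np.eye(3, dtype=np.float32), 2 * np.eye(3, dtype=np.float32)])
out = grouped_matmul_reference(lhs, rhs, np.array([1, 3]))
```

Comparing such a reference against the kernel on small inputs, for example with interpret=True, is one way to sanity-check group boundaries.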

ejkernel.kernels._pallas.tpu.grouped_matmulv2._pallas_impl.make_group_metadata(*, group_sizes: Array, m: int, tm: int, start_group: Array, num_nonzero_groups: int, visit_empty_groups: bool) -> tuple[tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jax.jaxlib._jax.Array], jax.jaxlib._jax.Array]

Create the metadata needed for grouped matmul computation.

Parameters
  • group_sizes – A 1D jax.Array with shape [num_groups] and jnp.int32 dtype.

  • m – The number of rows in lhs.

  • tm – The m-dimension tile size being used.

  • start_group – The index of the group in group_sizes from which to start computing. This is useful when the num_groups dimension of rhs is sharded.

  • num_nonzero_groups – The number of groups in group_sizes to compute on, starting from start_group.

  • visit_empty_groups – If True, do not squeeze tiles for empty groups out of the metadata. This is necessary for transposed_grouped_matmul, which must at least zero the output for every group, including empty ones.

Returns

A tuple ((group_offsets, group_ids, m_tile_ids), num_tiles), where:

  • group_offsets – A 1D jax.Array with shape [num_groups+1] and jnp.int32 dtype. group_offsets[i] indicates the row at which group i starts in the lhs matrix, and group_offsets[-1] = m.

  • group_ids – A 1D jax.Array with shape [m_tiles + num_groups] and jnp.int32 dtype. group_ids[i] indicates which group grid index i will work on.

  • m_tile_ids – A 1D jax.Array with shape [m_tiles + num_groups] and jnp.int32 dtype. m_tile_ids[i] indicates which m-dimension tile grid index i will work on.

  • num_tiles – The number of m-dimension tiles to execute.

Return type

tuple[tuple[jax.Array, jax.Array, jax.Array], jax.Array]
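The visit schedule described by group_ids and m_tile_ids can be illustrated with a simplified pure-Python sketch. This is not the library's implementation, and make_group_metadata_sketch is a hypothetical name. It assumes visit_empty_groups=False, and it returns only the visited prefix as Python lists rather than fixed-length padded arrays. The rule it encodes: every m-tile that overlaps a non-empty group's rows is visited once for that group. A tile straddling a group boundary is therefore visited once per group.

```python
def make_group_metadata_sketch(group_sizes, m, tm, start_group=0, num_nonzero_groups=None):
    """Simplified (group, m-tile) visit schedule (hypothetical helper).

    Assumes visit_empty_groups=False and returns only the visited prefix,
    not the padded arrays that the kernel metadata uses.
    """
    num_groups = len(group_sizes)
    if num_nonzero_groups is None:
        num_nonzero_groups = num_groups - start_group
    end_group = start_group + num_nonzero_groups  # exclusive bound

    # group_offsets[i] is the first lhs row of group i; the last entry is the row total.
    group_offsets = [0]
    for size in group_sizes:
        group_offsets.append(group_offsets[-1] + size)
    assert group_offsets[-1] <= m, "group sizes exceed the number of lhs rows"

    group_ids, m_tile_ids = [], []
    for g in range(start_group, end_group):
        start, end = group_offsets[g], group_offsets[g + 1]
        if start == end:
            continue  # empty groups are squeezed out of the schedule
        # Visit every tm-row tile overlapping rows [start, end) for group g.
        for t in range(start // tm, (end - 1) // tm + 1):
            group_ids.append(g)
            m_tile_ids.append(t)
    num_tiles = len(group_ids)
    return (group_offsets, group_ids, m_tile_ids), num_tiles
```

For group_sizes=[3, 0, 5], m=8 and tm=4, the sketch visits m-tile 0 twice (for groups 0 and 2) and m-tile 1 once, so num_tiles is 3. In general, each group boundary that is not a multiple of tm adds at most one extra visit. The total is therefore at most m_tiles + num_groups - 1, which fits in arrays of length m_tiles + num_groups.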

ejkernel.kernels._pallas.tpu.grouped_matmulv2._pallas_impl.transposed_grouped_matmul(lhs: Array, rhs: Array, group_sizes: Array, preferred_element_type: dtype, tiling: tuple[int, int, int] | collections.abc.Callable[[int, int, int], tuple[int, int, int] | None] | None = (128, 128, 128), input_buffer_count: int = 2, group_offset: jax.jaxlib._jax.Array | None = None, num_actual_groups: int | None = None, interpret: bool = False) -> Array

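For each group i, compute lhs[:, start_i:end_i] @ rhs[start_i:end_i, :] and store it as output[i]. Here [start_i, end_i) is the block of the shared m dimension owned by group i, with end_i - start_i = group_sizes[i].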

Parameters
  • lhs – A 2D jax.Array with shape [k, m].

  • rhs – A 2D jax.Array with shape [m, n].

  • group_sizes – A 1D jax.Array with shape [num_groups] and jnp.int32 dtype.

  • preferred_element_type – The jnp.dtype of the output matrix.

  • tiling – A 3-tuple of ints giving the m, k and n-dimension tile sizes. The signature also accepts None, or a callable that takes three ints and returns such a tuple or None.

  • group_offset – The index of the group in group_sizes from which to start computing. This is useful when the num_groups dimension is sharded.

  • num_actual_groups – The number of groups to compute, starting from group_offset. Use this when num_groups is sharded and only the local groups should be computed.

  • interpret – Whether to run the kernel in interpret mode, which is helpful for testing and debugging.

Returns

A 3D jax.Array with shape [num_groups, k, n] and dtype preferred_element_type.
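A minimal NumPy reference sketch of the transposed semantics follows. It is not the Pallas kernel, and transposed_grouped_matmul_reference is a hypothetical helper name. The sketch ignores group_offset and num_actual_groups and computes every group. An empty group contracts over zero rows, so its output slice is all zeros, consistent with the requirement that each group's output be at least zeroed.

```python
import numpy as np


def transposed_grouped_matmul_reference(lhs, rhs, group_sizes):
    """NumPy reference for transposed_grouped_matmul semantics (hypothetical helper).

    lhs: [k, m]; rhs: [m, n]; group_sizes: [num_groups].
    Returns [num_groups, k, n] with out[i] = lhs[:, s_i:e_i] @ rhs[s_i:e_i, :].
    """
    lhs = np.asarray(lhs)
    rhs = np.asarray(rhs)
    k, n = lhs.shape[0], rhs.shape[1]
    # Cumulative boundaries along the shared m dimension.
    offsets = np.concatenate([[0], np.cumsum(group_sizes)])
    out = np.zeros((len(group_sizes), k, n), dtype=np.result_type(lhs, rhs))
    for i in range(len(group_sizes)):
        start, end = offsets[i], offsets[i + 1]
        # For an empty group this is a [k, 0] @ [0, n] product, i.e. all zeros.
        out[i] = lhs[:, start:end] @ rhs[start:end, :]
    return out


# Usage: three groups over m = 5, with the middle group empty.
lhs = np.arange(10, dtype=np.float64).reshape(2, 5)
rhs = np.arange(15, dtype=np.float64).reshape(5, 3)
out = transposed_grouped_matmul_reference(lhs, rhs, np.array([2, 0, 3]))
```

One way to read this operation: when the groups cover all m rows, the per-group products sum to the full lhs @ rhs. Passing the transposed activations and the output gradient of a grouped_matmul yields a per-group weight gradient.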