# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Grouped matrix multiplication kernels for TPU written in Pallas."""
import functools
import json
from collections.abc import Callable
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from ejkernel.callib import buffered_pallas_call
def _validate_args(
    lhs: jax.Array,
    rhs: jax.Array,
    group_sizes: jax.Array,
    *,
    expected_rhs_dims: int = 3,
) -> jax.Array:
    """Check operand ranks and the dtype of ``group_sizes``.

    Args:
        lhs: Left operand; must be a rank-2 array.
        rhs: Right operand; must have exactly ``expected_rhs_dims`` dimensions.
        group_sizes: Per-group sizes; dtype must be ``int32`` or ``uint32``.
        expected_rhs_dims: Rank required of ``rhs``.

    Returns:
        ``group_sizes`` cast to ``jnp.int32``.

    Raises:
        ValueError: If ``lhs`` is not rank 2, ``rhs`` does not have the
            expected rank, or ``group_sizes`` is not a 32-bit integer array.
    """
    if (lhs_rank := lhs.ndim) != 2:
        raise ValueError(f"Expected 2-tensor for 'lhs' but got {lhs_rank}-tensor.")

    if (rhs_rank := rhs.ndim) != expected_rhs_dims:
        raise ValueError(f"Expected {expected_rhs_dims}-tensor for 'rhs' but got {rhs_rank}-tensor.")

    accepted_dtypes = (jnp.int32, jnp.uint32)
    sizes_dtype = group_sizes.dtype
    if sizes_dtype not in accepted_dtypes:
        raise ValueError(f"Expected 32-bit integer 'group_sizes' but got {sizes_dtype}.")

    # Downstream code works on signed 32-bit sizes, so uint32 is normalised here.
    return group_sizes.astype(jnp.int32)
GroupMetadata = tuple[jax.Array, jax.Array, jax.Array]


def _get_store_mask(
    *,
    grid_id: jax.Array,
    group_metadata: GroupMetadata,
    tm: int,
    tn: int,
) -> jax.Array:
    """Build the row mask selecting the current group's rows inside one tile.

    A tile of ``tm`` rows may straddle a group boundary, so only rows whose
    global index lies in the current group's ``[start, end)`` range may be
    written or contribute to a result.

    Args:
        grid_id: Index of the grid step being processed.
        group_metadata: ``(group_offsets, group_ids, m_tile_ids)`` describing
            which group and which m-tile each grid step handles.
        tm: Number of rows in a tile.
        tn: Number of columns in a tile.

    Returns:
        A boolean array of shape ``[tm, tn]``; an entry is True when its row
        belongs to the group handled at ``grid_id``.
    """
    offsets, ids, tile_ids = group_metadata

    current_group = ids[grid_id]
    # Global index of the first row covered by this tile.
    first_row = tile_ids[grid_id] * tm
    # Row index of every element, constant along the column axis.
    rows = first_row + jax.lax.broadcasted_iota(jnp.int32, (tm, tn), 0)

    at_or_after_start = rows >= offsets[current_group]
    before_end = rows < offsets[current_group + 1]
    return at_or_after_start & before_end
_TilingFn = Callable[[int, int, int], tuple[int, int, int] | None]


@functools.partial(
    jax.jit,
    static_argnames=[
        "preferred_element_type",
        "tiling",
        "transpose_rhs",
        "interpret",
        "input_buffer_count",
    ],
)
def grouped_matmul(
    lhs: jax.Array,
    rhs: jax.Array,
    group_sizes: jax.Array,
    preferred_element_type: jnp.dtype,
    tiling: tuple[int, int, int] | _TilingFn | None = (128, 128, 128),
    input_buffer_count: int = 2,
    group_offset: jax.Array | None = None,
    transpose_rhs: bool = False,
    interpret: bool = False,
) -> jax.Array:
    """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.

    Args:
        lhs: A 2d, jax.Array with shape [m, k].
        rhs: A 3d, jax.Array with shape [num_groups, k, n], or
            [num_groups, n, k] when ``transpose_rhs`` is True.
        group_sizes: A 1d, jax.Array with shape [num_groups] and a 32-bit
            integer dtype.
        preferred_element_type: jnp.dtype, the element type for the output
            matrix.
        tiling: Either a 3-tuple of ints giving the m, k and n-dimension tile
            sizes, or a callable ``(m, k, n) -> (tm, tk, tn) | None`` that
            picks them from the problem size.
        input_buffer_count: Forwarded to ``buffered_pallas_call`` as the first
            entry of its ``input_buffer_count`` pair (the second entry is
            fixed at 2).
        group_offset: A ()-shaped array: the group in group sizes to start
            computing from. This is particularly useful for when rhs
            num_groups is sharded.
        transpose_rhs: True if the rhs needs to be transposed.
        interpret: Whether or not to run the kernel in interpret mode, helpful
            for testing and debugging.

    Returns:
        A 2d, jax.Array with shape [m, n].

    Raises:
        ValueError: If the operands fail ``_validate_args``, if
            ``group_offset`` is not ()-shaped, or if no tiling is available.
    """
    group_sizes = _validate_args(lhs, rhs, group_sizes)

    if group_offset is None:
        group_offset = jnp.array([0], dtype=jnp.int32)
    else:
        if group_offset.shape:
            raise ValueError(f"group_offset must be a ()-shaped array. Got: {group_offset.shape}.")
        # Promote the scalar to shape (1,) so it can be indexed as group_offset[0].
        group_offset = group_offset[None]

    m, k = lhs.shape
    n = rhs.shape[1 if transpose_rhs else 2]

    if callable(tiling):
        tiling = tiling(m, k, n)
    if tiling is None:
        raise ValueError(f"No tuned tiling found for (m, k, n) = ({m}, {k}, {n})")
    tm, tk, tn = tiling

    tiles_k = pl.cdiv(k, tk)
    tiles_n = pl.cdiv(n, tn)

    # group_metadata is (group_offsets, group_ids, m_tile_ids); it is handed to
    # the kernel and the index maps as a scalar-prefetch operand.
    group_metadata, num_active_tiles = make_group_metadata(
        group_sizes=group_sizes,
        m=m,
        tm=tm,
        start_group=group_offset[0],
        num_nonzero_groups=rhs.shape[0],
        visit_empty_groups=False,
    )
    group_offsets, group_ids, _ = group_metadata

    def kernel(group_metadata, _, lhs_ref, rhs_ref, out_ref, acc_scratch):
        # Contract lhs dim 1 with rhs dim 1 (transposed) or rhs dim 0.
        if transpose_rhs:
            dimension_numbers = (((1,), (1,)), ((), ()))
        else:
            dimension_numbers = (((1,), (0,)), ((), ()))

        def dot_general(x, y, preferred_element_type):
            return jax.lax.dot_general(
                x,
                y,
                dimension_numbers=dimension_numbers,
                preferred_element_type=preferred_element_type,
            )

        grid_id = pl.program_id(1)
        k_i = pl.program_id(2)

        # The accumulator is reset at the first k-tile of every (n, m-tile) step.
        @pl.when(k_i == 0)
        def _zero_acc():
            acc_scratch[...] = jnp.zeros_like(acc_scratch)

        def accum(is_last_k_tile):
            with jax.named_scope(f"accum-last_k_tile={is_last_k_tile}"):
                lhs = jax.tree.map(lambda x: x[...], lhs_ref)
                rhs = jax.tree.map(lambda x: x[...], rhs_ref)

                # When k is not a multiple of tk the last k-tile reads past
                # the end of the contraction dim; zero that padding so it does
                # not pollute the dot product.
                if is_last_k_tile and (k_rem := k % tk) != 0:

                    def iota(x, d):
                        return lax.broadcasted_iota(jnp.int32, x.shape, d)

                    lhs = jnp.where(iota(lhs, 1) < k_rem, lhs, 0)
                    rhs = jnp.where(iota(rhs, 1 if transpose_rhs else 0) < k_rem, rhs, 0)

                def is_int(x):
                    return jnp.issubdtype(x.dtype, jnp.integer)

                acc_dtype = jnp.int32 if is_int(lhs) and is_int(rhs) else jnp.float32
                out = dot_general(lhs, rhs, acc_dtype)
                acc_scratch[...] += out.astype(acc_scratch.dtype)

                if is_last_k_tile:
                    # Only rows of the current group are overwritten; other
                    # rows of the tile keep what is already in out_ref.
                    mask = _get_store_mask(grid_id=grid_id, group_metadata=group_metadata, tm=tm, tn=tn)
                    acc = acc_scratch[...]
                    acc = jax.lax.select(mask, acc, out_ref[...].astype(acc.dtype))
                    out_ref[...] = acc.astype(preferred_element_type)

        lax.cond(k_i == tiles_k - 1, lambda: accum(True), lambda: accum(False))

    def lhs_index_map(n_i, grid_id, k_i, group_metadata, group_offset):
        del n_i, group_offset
        _, _, m_tile_ids = group_metadata
        return m_tile_ids[grid_id], k_i

    def rhs_index_map(n_i, grid_id, k_i, group_metadata, group_offset):
        _, group_ids, _ = group_metadata
        if transpose_rhs:
            k_i, n_i = n_i, k_i
        # rhs only holds the local groups, so group ids are rebased by the offset.
        return group_ids[grid_id] - group_offset[0], k_i, n_i

    def out_index_map(n_i, grid_id, k_i, group_metadata, group_offset):
        del k_i, group_offset
        _, _, m_tile_ids = group_metadata
        return m_tile_ids[grid_id], n_i

    lhs_block_spec = pl.BlockSpec((tm, tk), lhs_index_map)
    if transpose_rhs:
        rhs_block_spec = pl.BlockSpec((None, tn, tk), rhs_index_map)
    else:
        rhs_block_spec = pl.BlockSpec((None, tk, tn), rhs_index_map)
    out_block_spec = pl.BlockSpec((tm, tn), out_index_map)

    lhs_bytes = jax.tree.reduce(lambda acc, x: acc + x.size * x.itemsize, lhs, 0)
    # Bytes of a single group's [k, n] slab; scaled by the tile count below.
    rhs_bytes = k * n * rhs.itemsize
    out_bytes = (m * n) * jnp.dtype(preferred_element_type).itemsize
    bytes_accessed = lhs_bytes * tiles_n + rhs_bytes * group_ids.size + out_bytes
    cost_estimate = pl.CostEstimate(flops=2 * m * k * n, bytes_accessed=bytes_accessed, transcendentals=0)

    kernel_name = f"gmm_{tm}x{tk}x{tn}"
    if transpose_rhs:
        kernel_name += "_transpose_rhs"

    metadata = dict(
        prefer_element_type=jnp.dtype(preferred_element_type).name,
        tiling=dict(tile_m=tm, tile_k=tk, tile_n=tn),
        transpose_rhs=transpose_rhs,
    )

    call_gmm = buffered_pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=2,
            in_specs=[lhs_block_spec, rhs_block_spec],
            out_specs=out_block_spec,
            grid=(tiles_n, num_active_tiles, tiles_k),
            # The accumulator scratch is float32 regardless of the input dtype.
            scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)],
        ),
        compiler_params=pltpu.CompilerParams(dimension_semantics=("parallel", "arbitrary", "arbitrary")),
        interpret=interpret,
        cost_estimate=cost_estimate,
        name=kernel_name,
        metadata=dict(xprof_metadata=json.dumps(metadata)),
        input_buffer_count=(input_buffer_count, 2),
    )

    with jax.named_scope(kernel_name):
        out = call_gmm(group_metadata, group_offset, lhs, rhs)

    # When rhs holds fewer groups than group_sizes describes, zero every
    # output row that lies outside the range of groups computed here.
    if rhs.shape[0] < group_sizes.shape[0]:
        group_start = group_offsets[group_offset[0]]
        group_end = group_offsets[group_offset[0] + rhs.shape[0]]
        row_idxs = jnp.arange(out.shape[0], dtype=jnp.int32)
        valid_mask = (row_idxs >= group_start) & (row_idxs < group_end)
        out = jnp.where(valid_mask[:, None], out, 0)
    return out
@functools.partial(
    jax.jit,
    static_argnames=[
        "preferred_element_type",
        "tiling",
        "num_actual_groups",
        "interpret",
        "input_buffer_count",
    ],
)
def transposed_grouped_matmul(
    lhs: jax.Array,
    rhs: jax.Array,
    group_sizes: jax.Array,
    preferred_element_type: jnp.dtype,
    tiling: tuple[int, int, int] | _TilingFn | None = (128, 128, 128),
    input_buffer_count: int = 2,
    group_offset: jax.Array | None = None,
    num_actual_groups: int | None = None,
    interpret: bool = False,
) -> jax.Array:
    """Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :].

    Args:
        lhs: A 2d, jax.Array with shape [k, m].
        rhs: A 2d, jax.Array with shape [m, n].
        group_sizes: A 1d, jax.Array with shape [num_groups] and a 32-bit
            integer dtype.
        preferred_element_type: jnp.dtype, the element type for the output
            matrix.
        tiling: Either a 3-tuple of ints giving the m, k and n-dimension tile
            sizes, or a callable ``(m, k, n) -> (tm, tk, tn) | None`` that
            picks them from the problem size.
        input_buffer_count: Forwarded to ``buffered_pallas_call`` for both
            inputs.
        group_offset: A ()-shaped array: the group in group sizes to start
            computing from. This is particularly useful for when rhs
            num_groups is sharded.
        num_actual_groups: For when num_groups is sharded and we should only
            compute the groups that are local, starting from group_offset.
            Defaults to ``num_groups``.
        interpret: Whether or not to run the kernel in interpret mode, helpful
            for testing and debugging.

    Returns:
        A 3d, jax.Array with shape [num_actual_groups, k, n].

    Raises:
        ValueError: If the operands fail ``_validate_args``, if
            ``group_offset`` is not ()-shaped, or if no tiling is available.
    """
    group_sizes = _validate_args(lhs, rhs, group_sizes, expected_rhs_dims=2)

    if group_offset is None:
        group_offset = jnp.array([0], dtype=jnp.int32)
    else:
        # Same contract as grouped_matmul: a non-scalar offset would make
        # group_offset[0] a vector and fail later with a less clear error.
        if group_offset.shape:
            raise ValueError(f"group_offset must be a ()-shaped array. Got: {group_offset.shape}.")
        group_offset = group_offset[None]

    k, m = lhs.shape
    n = rhs.shape[1]
    # Transpose lhs to [m, k] so both operands are tiled along m by the same
    # m_tile_ids.
    lhs = jax.tree.map(lambda x: x.mT, lhs)

    num_groups = group_sizes.shape[0]
    num_actual_groups = num_actual_groups if num_actual_groups is not None else num_groups

    if callable(tiling):
        tiling = tiling(m, k, n)
    if tiling is None:
        raise ValueError(f"No tuned tiling found for (m, k, n) = ({m}, {k}, {n})")
    tm, tk, tn = tiling

    tiles_k = pl.cdiv(k, tk)
    tiles_n = pl.cdiv(n, tn)

    # Empty groups are visited too, so their output slab is still written
    # (with zeros from the freshly reset accumulator).
    group_metadata, num_active_tiles = make_group_metadata(
        group_sizes=group_sizes,
        m=m,
        tm=tm,
        start_group=group_offset[0],
        num_nonzero_groups=num_actual_groups,
        visit_empty_groups=True,
    )

    def kernel(group_metadata, _, lhs_ref, rhs_ref, out_ref, acc_scratch):
        grid_id = pl.program_id(2)
        group_offsets, group_ids, _ = group_metadata
        group = group_ids[grid_id]
        # Clamp the look-back index at 0 for the first grid step.
        prev_group = group_ids[jnp.where(grid_id > 0, grid_id - 1, 0)]

        # Start a fresh accumulation whenever a new group begins.
        @pl.when((grid_id == 0) | (group != prev_group))
        def _zero_acc():
            acc_scratch[...] = jnp.zeros_like(acc_scratch)

        group_size = group_offsets[group + 1] - group_offsets[group]

        @pl.when(group_size > 0)
        def _do():
            def dot(x, y, preferred_element_type):
                return lax.dot(x, y, preferred_element_type=preferred_element_type)

            lhs = jax.tree.map(lambda x: x[...], lhs_ref)
            rhs = jax.tree.map(lambda x: x[...], rhs_ref)

            # Zero the rows of this m-tile that belong to other groups so they
            # do not contribute to the contraction over m.
            kwargs = dict(grid_id=grid_id, group_metadata=group_metadata, tm=tm)
            lhs = jnp.where(_get_store_mask(**kwargs, tn=tk), lhs, 0)
            rhs = jnp.where(_get_store_mask(**kwargs, tn=tn), rhs, 0)

            def is_int(x):
                return jnp.issubdtype(x.dtype, jnp.integer)

            acc_dtype = jnp.int32 if is_int(lhs) and is_int(rhs) else jnp.float32
            out = dot(lhs.T, rhs, acc_dtype)
            acc_scratch[...] += out.astype(acc_scratch.dtype)

        is_end_of_grid = grid_id == (pl.num_programs(2) - 1)
        # Clamp the look-ahead index at the last grid step.
        next_group = group_ids[jnp.where(is_end_of_grid, grid_id, grid_id + 1)]

        # Write the result once the last m-tile of the group has been added.
        @pl.when(is_end_of_grid | (group != next_group))
        def _store_accum():
            out_ref[...] = acc_scratch[...].astype(preferred_element_type)

    def lhs_index_map(n_i, k_i, grid_id, group_metadata, group_offset):
        del n_i, group_offset
        _, _, m_tile_ids = group_metadata
        return m_tile_ids[grid_id], k_i

    def rhs_index_map(n_i, k_i, grid_id, group_metadata, group_offset):
        del k_i, group_offset
        _, _, m_tile_ids = group_metadata
        return m_tile_ids[grid_id], n_i

    def out_index_map(n_i, k_i, grid_id, group_metadata, group_offset):
        _, group_ids, _ = group_metadata
        # The output only holds the local groups, so ids are rebased by the offset.
        return group_ids[grid_id] - group_offset[0], k_i, n_i

    lhs_block_spec = pl.BlockSpec((tm, tk), lhs_index_map)
    rhs_block_spec = pl.BlockSpec((tm, tn), rhs_index_map)
    out_block_spec = pl.BlockSpec((None, tk, tn), out_index_map)

    lhs_bytes = jax.tree.reduce(lambda acc, x: acc + x.size * x.itemsize, lhs, 0)
    rhs_bytes = jax.tree.reduce(lambda acc, x: acc + x.size * x.itemsize, rhs, 0)
    out_bytes = (num_actual_groups * k * n) * jnp.dtype(preferred_element_type).itemsize
    bytes_accessed = (lhs_bytes * tiles_n) + (rhs_bytes * tiles_k) + out_bytes
    cost_estimate = pl.CostEstimate(flops=2 * m * k * n, bytes_accessed=bytes_accessed, transcendentals=0)

    kernel_name = f"tgmm_{tm}x{tk}x{tn}"
    metadata = dict(
        tiling=dict(tile_m=tm, tile_k=tk, tile_n=tn),
        prefer_element_type=jnp.dtype(preferred_element_type).name,
        num_actual_groups=num_actual_groups,
    )

    call_gmm = buffered_pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((num_actual_groups, k, n), preferred_element_type),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=2,
            in_specs=[lhs_block_spec, rhs_block_spec],
            out_specs=out_block_spec,
            grid=(tiles_n, tiles_k, num_active_tiles),
            # The accumulator scratch is float32 regardless of the input dtype.
            scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)],
        ),
        compiler_params=pltpu.CompilerParams(dimension_semantics=("parallel", "arbitrary", "arbitrary")),
        interpret=interpret,
        cost_estimate=cost_estimate,
        name=kernel_name,
        metadata=dict(xprof_metadata=json.dumps(metadata)),
        input_buffer_count=(input_buffer_count, input_buffer_count),
    )

    with jax.named_scope(kernel_name):
        return call_gmm(group_metadata, group_offset, lhs, rhs)