# Source code for ejkernel.kernels._xla.grouped_matmul._interface

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Grouped matrix multiplication interface using XLA ragged_dot.

This module provides the public API for grouped matrix multiplication
where different row slices of a matrix are multiplied with different
weight matrices. Uses JAX's ragged_dot_general for efficient computation.
"""

from __future__ import annotations

import contextlib
import typing

import jax
import jax.numpy as jnp
import jaxtyping
from beartype import beartype
from jax.experimental import xla_metadata
from jaxtyping import Array, DTypeLike, Float, Int

from ..._registry import Backend, Platform, kernel_registry

if typing.TYPE_CHECKING:
    from ejkernel.kernels._pallas.tpu.grouped_matmul._interface import LutFn

set_xla_metadata = xla_metadata.set_xla_metadata


def _tiling_metadata_context(
    tiling: tuple[int, int, int] | LutFn | None,
    m: int,
    k: int,
    n: int,
) -> contextlib.AbstractContextManager:
    """Build the context manager that attaches a ``ragged_dot_tiling`` hint.

    Args:
        tiling: ``(tm, tk, tn)`` tile sizes, a callable ``(m, k, n) ->
            (tm, tk, tn) | None`` (a tiling lookup table), or ``None``.
        m: Number of lhs rows.
        k: Contraction dimension size.
        n: Output column count.

    Returns:
        ``set_xla_metadata(ragged_dot_tiling="tm,tk,tn")`` when tile sizes are
        available. Otherwise returns ``contextlib.nullcontext()``, meaning no
        hint is attached.
    """
    # A lookup-table callable must be resolved against the concrete problem
    # shape before it can be serialized. Iterating the callable itself would
    # raise TypeError.
    if callable(tiling):
        tiling = tiling(m, k, n)
    # The XLA path does not require tile sizes, so a missing LUT entry just
    # means "no hint".
    if tiling is None:
        return contextlib.nullcontext()
    return set_xla_metadata(ragged_dot_tiling=",".join(str(t) for t in tiling))


@kernel_registry.register("grouped_matmul", Platform.XLA, Backend.ANY)
@kernel_registry.register("grouped_matmulv2", Platform.XLA, Backend.ANY)
@jaxtyping.jaxtyped(typechecker=beartype)
def grouped_matmul(
    lhs: Float[Array, "m k"],
    rhs: Float[Array, "num_groups k n"] | Float[Array, "num_groups n k"],
    group_sizes: Int[Array, "num_groups_or_shards"],
    preferred_element_type: DTypeLike = jnp.float32,
    tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128),
    group_offset: Int[Array, "..."] | None = None,
    existing_out: Float[Array, "m n"] | None = None,
    transpose_rhs: bool = False,
    interpret: bool = False,
    precision: jax.lax.PrecisionLike = jax.lax.Precision.DEFAULT,
) -> Float[Array, "m n"]:
    """Grouped matrix multiplication via ``jax.lax.ragged_dot_general``.

    Consecutive row slices of ``lhs`` are multiplied by different matrices
    from ``rhs``. For each group ``i`` covering rows ``[start_i, end_i)`` (as
    determined by ``group_sizes``)::

        out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i]

    where ``rhs[i]`` is used as ``[k, n]`` (transposed from ``[n, k]`` when
    ``transpose_rhs=True``).

    The contraction is expressed as a single ``ragged_dot_general`` call:
    lhs axis 0 is the ragged dimension and rhs axis 0 is the group dimension.
    This lets XLA choose the implementation.

    Args:
        lhs: ``[m, k]`` matrix whose rows are partitioned into groups.
        rhs: ``[num_groups, k, n]`` per-group weights, or
            ``[num_groups, n, k]`` when ``transpose_rhs=True``.
        group_sizes: 1D integer array. Entry ``i`` is the number of
            consecutive ``lhs`` rows assigned to group ``i``. Expected to sum
            to ``m``; this is not validated here.
        preferred_element_type: Output dtype forwarded to
            ``ragged_dot_general``. Defaults to ``float32``.
        tiling: Optional tiling hint attached as the XLA metadata attribute
            ``ragged_dot_tiling="tm,tk,tn"``. May be a ``(tm, tk, tn)`` tuple,
            a callable ``(m, k, n) -> (tm, tk, tn) | None``, or ``None``. A
            callable returning ``None``, or ``None`` itself, attaches no hint.
            Whether a given XLA backend honors the hint is up to the compiler
            (not verifiable from this module).
        group_offset: Optional starting group index, forwarded unchanged to
            ``ragged_dot_general`` (for sharded execution over groups).
        existing_out: Optional ``[m, n]`` array that is added to the
            ``ragged_dot_general`` result. It is cast to the result dtype
            first.
        transpose_rhs: If True, ``rhs`` is laid out as ``[num_groups, n, k]``
            and is contracted on its last axis.
        interpret: Accepted for interface compatibility with the Pallas
            backends. It has no effect on this XLA implementation.
        precision: Matmul precision forwarded to ``ragged_dot_general``.

    Returns:
        ``[m, n]`` array of dtype ``preferred_element_type`` holding the
        per-group products (plus ``existing_out`` if it was given).

    Example:
        >>> import jax
        >>> lhs = jax.random.normal(jax.random.PRNGKey(0), (300, 64))
        >>> rhs = jax.random.normal(jax.random.PRNGKey(1), (3, 64, 32))
        >>> group_sizes = jnp.array([100, 150, 50], dtype=jnp.int32)
        >>> grouped_matmul(lhs, rhs, group_sizes).shape
        (300, 32)
    """
    m, k = lhs.shape
    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]

    # lhs contracts on axis 1 (k). rhs carries k on axis 1 for [g, k, n] and
    # on axis 2 for [g, n, k]. There are no batch dimensions.
    rhs_contract_dim = 2 if transpose_rhs else 1
    dimension_numbers = jax.lax.RaggedDotDimensionNumbers(
        dot_dimension_numbers=(((1,), (rhs_contract_dim,)), ((), ())),
        lhs_ragged_dimensions=(0,),
        rhs_group_dimensions=(0,),
    )

    with _tiling_metadata_context(tiling, m, k, n):
        out = jax.lax.ragged_dot_general(
            lhs=lhs,
            rhs=rhs,
            group_sizes=group_sizes,
            precision=precision,
            preferred_element_type=preferred_element_type,
            group_offset=group_offset,
            ragged_dot_dimension_numbers=dimension_numbers,
        )

    # Honor the documented "accumulate into existing_out" contract instead of
    # silently discarding the argument. The cast keeps the output dtype
    # equal to preferred_element_type.
    if existing_out is not None:
        out = out + existing_out.astype(out.dtype)
    return out