ejkernel.kernels._pallas.tpu.grouped_matmul._interface
Custom VJP implementation for grouped matrix multiplication.
This module defines the custom forward and backward passes for grouped matrix multiplication operations, enabling efficient automatic differentiation on TPU. It wraps the low-level kernel implementations with JAX’s custom VJP mechanism to provide gradient support.
- ejkernel.kernels._pallas.tpu.grouped_matmul._interface.grouped_matmul(lhs: jaxtyping.Float[jaxlib._jax.Array, 'm k'], rhs: jaxtyping.Float[jaxlib._jax.Array, 'num_groups k n'] | jaxtyping.Float[jaxlib._jax.Array, 'num_groups n k'], group_sizes: jaxtyping.Int[jaxlib._jax.Array, 'num_groups_or_shards'], preferred_element_type: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType] = <class 'jax.numpy.float32'>, tiling: tuple[int, int, int] | collections.abc.Callable[[int, int, int], tuple[int, int, int] | None] | None = (128, 128, 128), group_offset: jaxtyping.Int[jaxlib._jax.Array, '...'] | None = None, existing_out: jaxtyping.Float[jaxlib._jax.Array, 'm n'] | None = None, transpose_rhs: bool = False, interpret: bool = False, precision: typing.Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], jax._src.lax.lax.DotAlgorithm, jax._src.lax.lax.DotAlgorithmPreset] = Precision.DEFAULT) -> Float[jaxlib._jax.Array, 'm n']
Grouped Matrix Multiplication: Compute separate matrix products for each group.
This function performs grouped matrix multiplication where different row slices of the left-hand side matrix are multiplied with different matrices from the right-hand side tensor. Mathematically, for each group i:
out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i, :, :]
where start_i is the sum of group_sizes[:i] and end_i = start_i + group_sizes[i], so the groups occupy consecutive row ranges of lhs.
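The definition above can be checked against a plain loop over groups. The following is a minimal NumPy sketch of the same semantics, not the Pallas kernel itself; the name grouped_matmul_reference is hypothetical and introduced only for illustration.

```python
import numpy as np

def grouped_matmul_reference(lhs, rhs, group_sizes):
    """Reference semantics: out[start_i:end_i] = lhs[start_i:end_i] @ rhs[i]."""
    group_sizes = np.asarray(group_sizes)
    ends = np.cumsum(group_sizes)    # end_i: exclusive end row of group i
    starts = ends - group_sizes      # start_i: first row of group i
    out = np.zeros((lhs.shape[0], rhs.shape[2]), dtype=np.float32)
    for i, (s, e) in enumerate(zip(starts, ends)):
        out[s:e] = lhs[s:e] @ rhs[i]  # empty groups write nothing
    return out

rng = np.random.default_rng(0)
lhs = rng.standard_normal((6, 4)).astype(np.float32)
rhs = rng.standard_normal((3, 4, 2)).astype(np.float32)
out = grouped_matmul_reference(lhs, rhs, [1, 3, 2])
```

A loop like this is useful as a correctness oracle when testing the kernel on small shapes; it makes no attempt at performance.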
The implementation uses Pallas to generate efficient TPU kernels that:
- Process multiple groups in a single kernel launch
- Handle groups that don’t align with tile boundaries
- Support incremental accumulation for memory efficiency
- Optimize memory access patterns for TPU’s memory hierarchy
- Parameters
lhs – Left-hand side matrix of shape [m, k] where m is the total number of rows across all groups and k is the inner dimension.
rhs – Right-hand side tensor of shape [num_groups, k, n] containing a separate matrix for each group. When transpose_rhs=True, the expected shape is [num_groups, n, k] instead.
group_sizes – 1D array of shape [num_groups] with int32 dtype. Each element specifies the number of rows in lhs belonging to that group. Must sum to m (first dimension of lhs).
preferred_element_type – Output dtype. Defaults to float32. The kernel uses float32 for accumulation regardless, then casts to this type.
tiling – Tile sizes as (tm, tk, tn) tuple, or a callable that returns tile sizes based on problem dimensions. Typical TPU tiles are 128x128. The callable signature is (m, k, n) -> (tm, tk, tn) | None.
group_offset – Starting group index for sharded execution. If None (the default), processing starts at group 0. Useful when distributing groups across multiple devices.
existing_out – Optional pre-existing output tensor to accumulate into. Must have shape [m, n] and dtype matching preferred_element_type. Enables incremental computation and memory reuse.
transpose_rhs – If True, expects rhs shape [num_groups, n, k] and transposes during multiplication. Useful for certain memory layouts.
interpret – Run kernel in interpret mode for debugging. Slower but provides better error messages and doesn’t require compilation.
- Returns
Output matrix of shape [m, n] containing the concatenated results of all group matrix multiplications.
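As an illustration of the callable form of tiling described in the parameters, here is a minimal sketch of a function with the (m, k, n) -> (tm, tk, tn) | None signature. The name choose_tiling, the candidate tile sizes, and the fixed 128 for tk and tn are assumptions made for this example; how grouped_matmul reacts to a None return is not specified on this page.

```python
from typing import Optional, Tuple

def choose_tiling(m: int, k: int, n: int) -> Optional[Tuple[int, int, int]]:
    """Hypothetical tiling callable: pick the largest candidate tm that divides m."""
    for tm in (512, 256, 128):
        if m % tm == 0:
            # tk and tn are kept at 128 here purely for simplicity.
            return (tm, 128, 128)
    return None  # no candidate tile size divides m

tiling_small = choose_tiling(384, 64, 32)
tiling_large = choose_tiling(1024, 64, 32)
tiling_none = choose_tiling(300, 64, 32)
```

Such a function would be passed as tiling=choose_tiling in place of the default (128, 128, 128) tuple.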
- Algorithm Overview:
1. Validate inputs and determine computation parameters
2. Create group metadata for efficient tile-to-group mapping
3. Define Pallas kernel that:
   - Loads tiles from lhs and group-specific slices from rhs
   - Accumulates partial products in on-chip memory
   - Masks and stores results respecting group boundaries
4. Launch kernel with appropriate grid dimensions
5. Zero unprocessed regions if doing partial computation
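To make the tile-to-group mapping step concrete, the sketch below lists which row tiles each group touches. It is an assumption about what such metadata might look like, written in NumPy for readability; the function name tile_to_group_map and its return layout are hypothetical and not taken from this module.

```python
import numpy as np

def tile_to_group_map(group_sizes, tm):
    """Return (group_ids, m_tile_ids) with one entry per (group, row tile) visit."""
    group_sizes = np.asarray(group_sizes)
    ends = np.cumsum(group_sizes)
    starts = ends - group_sizes
    group_ids, m_tile_ids = [], []
    for g, (s, e) in enumerate(zip(starts, ends)):
        if e == s:
            continue  # empty group: no tiles to visit
        first_tile = int(s) // tm
        last_tile = (int(e) - 1) // tm
        for t in range(first_tile, last_tile + 1):
            group_ids.append(g)
            m_tile_ids.append(t)
    return np.array(group_ids), np.array(m_tile_ids)

# Group 2 is empty; groups 1 and 3 both touch row tile 2.
group_ids, m_tile_ids = tile_to_group_map([128, 192, 0, 64], tm=128)
```

In this example row tile 2 appears twice, once for group 1 and once for group 3. That is the case where the kernel has to mask its stores so each visit writes only the rows belonging to the current group.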
- TPU Optimizations:
Tiles sized to match TPU’s Matrix Multiply Units (typically 128x128)
Prefetch scheduling for hiding memory latency
VMEM scratch space for accumulation to minimize HBM traffic
Efficient masking for partial tiles using TPU’s predication
Dimension semantics hints for XLA compiler optimization
Example
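>>> import jax
>>> import jax.numpy as jnp
>>> k_lhs, k_rhs = jax.random.split(jax.random.PRNGKey(0))
>>> lhs = jax.random.normal(k_lhs, (384, 64))
>>> rhs = jax.random.normal(k_rhs, (3, 64, 32))
>>> group_sizes = jnp.array([128, 192, 64], dtype=jnp.int32)
>>> result = grouped_matmul(lhs, rhs, group_sizes)  # shape (384, 32)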
Notes
The k dimension can have partial tiles (handled via masking)
The m dimension must be divisible by tm for correctness
Empty groups (size 0) are skipped for efficiency
Cost estimation helps XLA make scheduling decisions
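When m is not a multiple of tm, one possible workaround is to zero-pad lhs and assign the padding rows to the last group, then slice the result back to the original m rows. This is an assumption for illustration, not something this module documents; the helper pad_rows_to_tile is hypothetical. The padding rows are all zeros, so they yield zero output rows that are discarded by the final slice.

```python
import numpy as np

def pad_rows_to_tile(lhs, group_sizes, tm):
    """Zero-pad lhs rows to a multiple of tm; the padding rows join the last group."""
    m = lhs.shape[0]
    pad = (-m) % tm  # rows needed to reach the next multiple of tm
    padded_lhs = np.pad(lhs, ((0, pad), (0, 0)))
    padded_sizes = np.array(group_sizes, dtype=np.int32)
    padded_sizes[-1] += pad
    return padded_lhs, padded_sizes, m

lhs = np.ones((300, 64), dtype=np.float32)
padded_lhs, padded_sizes, m = pad_rows_to_tile(lhs, [100, 150, 50], tm=128)
# Intended use (not executed here):
#   result = grouped_matmul(padded_lhs, rhs, padded_sizes)[:m]
```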