ejkernel.modules.operations.grouped_matmul
Grouped matrix multiplication kernel module with automatic optimization.
This module implements grouped matrix multiplication, an efficient operation for batched matrix multiplication with variable group sizes. This is particularly useful for mixture-of-experts models, grouped convolutions, and other scenarios where different groups of inputs need to be multiplied with different weight matrices.
Unlike standard batched matrix multiplication, which requires every batch element to have the same shape, grouped matmul handles variable-sized groups efficiently by:
Processing groups of different sizes in a single operation
Using memory access patterns optimized for grouped computation
Supporting both transposed and non-transposed RHS matrices
Optionally accumulating into an existing output tensor
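To make the semantics concrete, here is a minimal, naive reference sketch in NumPy. It assumes `sum(group_sizes) == m` (as stated for `run` below) and that `transpose_rhs=True` means `rhs` is stored as `[num_groups, n, k]`. The helper name `grouped_matmul_reference` is hypothetical and not part of ejkernel; it loops over groups instead of fusing them, so it only illustrates what is computed, not how the kernel computes it.

```python
import numpy as np


def grouped_matmul_reference(lhs, rhs, group_sizes, transpose_rhs=False):
    """Hypothetical naive reference: consecutive row blocks of lhs times rhs[g].

    lhs: [m, k]; rhs: [num_groups, k, n] (or [num_groups, n, k] if transpose_rhs);
    group_sizes: [num_groups] integers summing to m.
    """
    m = lhs.shape[0]
    if int(np.sum(group_sizes)) != m:
        raise ValueError("sum(group_sizes) must equal lhs.shape[0]")
    # Bring every group's weights to [k, n] layout.
    weights = np.swapaxes(rhs, 1, 2) if transpose_rhs else rhs
    outputs = []
    start = 0
    for g, size in enumerate(group_sizes):
        size = int(size)
        # Rows [start, start + size) belong to group g (empty groups yield [0, n]).
        outputs.append(lhs[start:start + size] @ weights[g])
        start += size
    return np.concatenate(outputs, axis=0)  # [m, n]
```

Empty groups are allowed here and simply contribute no rows to the output.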
- class ejkernel.modules.operations.grouped_matmul.GroupedMatmul(use_v2: bool = False)
Bases: Kernel[GroupedMatmulConfig, Array]
Grouped matrix multiplication with custom optimization logic.
Performs efficient matrix multiplication for grouped inputs, where each group can have a different size. This is essential for mixture-of-experts (MoE) models where tokens are dynamically routed to different experts.
- Features:
Variable group size support via group_sizes array
Configurable tiling for memory and compute efficiency
Support for RHS transposition
Optional output accumulation (for multi-pass operations)
Group offset for partial computation
Multiple platform support (Triton/Pallas/CUDA/XLA)
- Typical use cases:
MoE layer computation (different tokens to different experts)
Grouped linear layers
Dynamic routing architectures
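A typical MoE layer reaches grouped matmul by sorting tokens by their assigned expert, so that each expert's tokens form a contiguous block of rows, and then counting tokens per expert to obtain `group_sizes`. The sketch below shows this flow in NumPy under simplifying assumptions (top-1 routing, no capacity limits). `moe_forward` is a hypothetical helper, and the loop in step 3 merely stands in for a call such as `grouped_matmul(sorted_tokens, expert_weights, group_sizes)`.

```python
import numpy as np


def moe_forward(tokens, expert_ids, expert_weights):
    """Hypothetical top-1 MoE projection built around one grouped matmul.

    tokens: [m, k]; expert_ids: [m] ints in [0, num_experts);
    expert_weights: [num_experts, k, n].
    """
    num_experts = expert_weights.shape[0]
    # 1. Sort tokens so that all tokens routed to the same expert are contiguous.
    order = np.argsort(expert_ids, kind="stable")
    sorted_tokens = tokens[order]
    # 2. group_sizes[e] = number of tokens routed to expert e.
    group_sizes = np.bincount(expert_ids, minlength=num_experts)
    # 3. Grouped matmul over contiguous row blocks (naive stand-in for the kernel).
    bounds = np.concatenate([[0], np.cumsum(group_sizes)])
    sorted_out = np.concatenate(
        [sorted_tokens[bounds[e]:bounds[e + 1]] @ expert_weights[e]
         for e in range(num_experts)],
        axis=0,
    )
    # 4. Undo the sort so output rows line up with the original token order.
    out = np.empty_like(sorted_out)
    out[order] = sorted_out
    return out, group_sizes
```

The sort and unsort are the only extra work; everything expert-specific happens inside the single grouped call.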
- candidate_cfgs(inv: Invocation[GroupedMatmulConfig, Array])
Generate candidate configurations for autotuning.
Creates configurations with different block sizes to explore the performance space. The best tile sizes for grouped matmul depend on the group-size distribution and the matrix dimensions.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of candidate configurations
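As an illustration of the kind of search space `candidate_cfgs` explores, the following sketch enumerates `(block_m, block_n, block_k)` tile shapes and prunes tiles larger than the problem. `BlockConfig` and `candidate_block_configs` are hypothetical stand-ins; the real `GroupedMatmulConfig` fields, the block sizes ejkernel tries, and its pruning rules may differ.

```python
import itertools
from dataclasses import dataclass


@dataclass(frozen=True)
class BlockConfig:
    """Hypothetical stand-in for GroupedMatmulConfig's tiling fields."""
    block_m: int
    block_n: int
    block_k: int


def candidate_block_configs(m, k, n, sizes=(64, 128, 256)):
    """Enumerate tile shapes, skipping tiles larger than the problem dimension."""
    smallest = min(sizes)

    def fits(block, dim):
        # Always keep the smallest size so tiny problems still get one candidate.
        return block <= max(dim, smallest)

    return [
        BlockConfig(bm, bn, bk)
        for bm, bn, bk in itertools.product(sizes, repeat=3)
        if fits(bm, m) and fits(bn, n) and fits(bk, k)
    ]
```

An autotuner would then time each candidate on the actual inputs and keep the fastest.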
- create_shard_map_wrapper(lhs: Float[jaxlib._jax.Array, 'm k'], rhs: jaxtyping.Float[jaxlib._jax.Array, 'num_groups k n'] | jaxtyping.Float[jaxlib._jax.Array, 'num_groups n k'], group_sizes: Int[jaxlib._jax.Array, 'num_groups_or_shards'], preferred_element_type=None, group_offset: jaxtyping.Int[jaxlib._jax.Array, '...'] | None = None, existing_out: jaxtyping.Float[jaxlib._jax.Array, 'm n'] | None = None, transpose_rhs: bool = False, interpret: bool = False, do_padding: bool = True, precision=None, out_shard_callback: collections.abc.Callable[[jaxtyping.Float[jaxlib._jax.Array, 'm n']], jaxtyping.Float[jaxlib._jax.Array, 'm n']] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: ejkernel.modules.operations.configs.GroupedMatmulConfig | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None, check_vma: bool = False)
Create a shard_map wrapper for grouped matrix multiplication.
- Parameters
mesh – JAX device mesh
in_specs – Input partition specs (must match the number of tensor arguments)
out_specs – Output partition spec
lhs – Left-hand side input [m, k] to be sharded
rhs – Grouped right-hand side matrices to be sharded
group_sizes – Group sizes, shape [num_groups_or_shards]
All other args – Grouped matmul parameters, as in run()
- Returns
Tuple of (shard_map_fn, call_args)
- get_impl(cfg: GroupedMatmulConfig)
Get kernel implementation from registry.
- Parameters
cfg – Configuration specifying platform and backend
- Returns
Callable kernel implementation for grouped matmul
- Raises
ValueError – If no matching implementation is found
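The lookup described above can be pictured as a dictionary from platform name to implementation that raises `ValueError` on a miss. This is a hypothetical sketch: `_GMM_REGISTRY`, `register`, and `lookup_impl` are not ejkernel names, and the real registry also takes the backend into account. It only mirrors the documented contract.

```python
from typing import Callable, Dict

# Hypothetical registry keyed by platform name only.
_GMM_REGISTRY: Dict[str, Callable] = {}


def register(platform: str):
    """Decorator that records an implementation under a platform name."""
    def decorator(fn: Callable) -> Callable:
        _GMM_REGISTRY[platform] = fn
        return fn
    return decorator


@register("xla")
def _gmm_xla(*args, **kwargs):
    # Placeholder body; a real implementation would compute the grouped matmul.
    return "xla impl"


def lookup_impl(platform: str) -> Callable:
    """Return the registered implementation or raise ValueError."""
    try:
        return _GMM_REGISTRY[platform]
    except KeyError:
        raise ValueError(
            f"No grouped_matmul implementation for platform={platform!r}"
        ) from None
```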
- heuristic_cfg(inv: Invocation[GroupedMatmulConfig, Array]) → GroupedMatmulConfig
Provide a default configuration with fixed block sizes.
Selects balanced block sizes suited to typical grouped matmul workloads. The default 128x128x128 tiling is intended to give good cache utilization across most problem sizes.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
Default configuration with 128x128x128 blocks, 4 warps, 2 stages
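A fixed block size is a trade-off when groups are ragged, because the last tile of each group is only partly filled. Under the simplifying assumption that each group is padded on its own up to a multiple of `block_m` (the actual kernel may handle group boundaries differently), the wasted rows can be counted directly. `padded_rows` is a hypothetical helper.

```python
import math


def padded_rows(group_sizes, block_m=128):
    """Rows processed if every group is padded up to a multiple of block_m."""
    return sum(math.ceil(g / block_m) * block_m for g in group_sizes)


group_sizes = [300, 17, 0, 512]
for bm in (64, 128, 256):
    total = padded_rows(group_sizes, bm)
    # Fraction of processed rows that carry real data.
    print(bm, total, f"{sum(group_sizes) / total:.0%} useful")
```

In this model, larger `block_m` wastes more rows on small or ragged groups, while smaller tiles reduce that waste at the cost of less work per tile; a default such as 128 sits between these extremes.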
- run(lhs: Float[jaxlib._jax.Array, 'm k'], rhs: jaxtyping.Float[jaxlib._jax.Array, 'num_groups k n'] | jaxtyping.Float[jaxlib._jax.Array, 'num_groups n k'], group_sizes: Int[jaxlib._jax.Array, 'num_groups_or_shards'], preferred_element_type=None, group_offset: jaxtyping.Int[jaxlib._jax.Array, '...'] | None = None, existing_out: jaxtyping.Float[jaxlib._jax.Array, 'm n'] | None = None, transpose_rhs: bool = False, interpret: bool = False, do_padding: bool = True, precision=None, out_shard_callback: collections.abc.Callable[[jaxtyping.Float[jaxlib._jax.Array, 'm n']], jaxtyping.Float[jaxlib._jax.Array, 'm n']] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: GroupedMatmulConfig) → Float[jaxlib._jax.Array, 'm n']
Execute grouped matrix multiplication.
- Parameters
lhs – Left-hand side matrix [m, k] (typically activations)
rhs – Right-hand side grouped matrices [num_groups, k, n] or [num_groups, n, k]
group_sizes – Size of each group [num_groups], sum(group_sizes) == m
preferred_element_type – Optional dtype for computation
group_offset – Optional offset into groups for partial computation
existing_out – Optional existing output to accumulate into [m, n]
transpose_rhs – Whether RHS matrices are transposed
interpret – Use interpreted mode (for debugging)
precision – Computation precision setting
platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Kernel configuration object (block sizes, warps, stages)
- Returns
Matrix multiplication result [m, n]
Note
The group_sizes array partitions the m dimension of lhs. Each partition is multiplied with the corresponding group matrix from rhs.
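The partition described in this note can be written with a prefix sum: group `g` owns rows `[starts[g], ends[g])` of `lhs`, and every row maps to a group id. The NumPy sketch below (variable names are illustrative) also shows an equivalent dense formulation, `out[i] = lhs[i] @ rhs[row_to_group[i]]`, which materializes one weight matrix per row and is therefore only suitable as a small correctness check.

```python
import numpy as np

group_sizes = np.array([3, 0, 2, 1])
# Group g owns rows [starts[g], ends[g]) of lhs.
ends = np.cumsum(group_sizes)
starts = ends - group_sizes
# Group id of every row of lhs (length m = sum(group_sizes)).
row_to_group = np.repeat(np.arange(len(group_sizes)), group_sizes)

rng = np.random.default_rng(0)
lhs = rng.standard_normal((int(group_sizes.sum()), 4))  # [m, k]
rhs = rng.standard_normal((len(group_sizes), 4, 5))     # [num_groups, k, n]
# Gather each row's weight matrix, then multiply row by matrix.
out = np.einsum("mk,mkn->mn", lhs, rhs[row_to_group])   # [m, n]
```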
- ejkernel.modules.operations.grouped_matmul.grouped_matmul(lhs: Float[jaxlib._jax.Array, 'm k'], rhs: jaxtyping.Float[jaxlib._jax.Array, 'num_groups k n'] | jaxtyping.Float[jaxlib._jax.Array, 'num_groups n k'], group_sizes: Int[jaxlib._jax.Array, 'num_groups_or_shards'], group_offset: jaxtyping.Int[jaxlib._jax.Array, '...'] | None = None, existing_out: jaxtyping.Float[jaxlib._jax.Array, 'm n'] | None = None, /, *, preferred_element_type=None, transpose_rhs: bool = False, interpret: bool = False, do_padding: bool = True, precision=None, use_v2: bool = False, out_shard_callback: collections.abc.Callable[[jaxtyping.Float[jaxlib._jax.Array, 'm n']], jaxtyping.Float[jaxlib._jax.Array, 'm n']] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.GroupedMatmulConfig | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec | None, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None) → Float[jaxlib._jax.Array, 'm n']
Execute grouped matrix multiplication with automatic optimization.
Performs batched matrix multiplication in which each group of lhs rows may have a different size and is multiplied by its own rhs matrix.
- Parameters
lhs – Left-hand side matrix [m, k]
rhs – Right-hand side matrices [num_groups, k, n] or [num_groups, n, k]
group_sizes – Size of each group [num_groups]
preferred_element_type – Preferred dtype for computation
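group_offset – Offset into groups (for partial computation); positional-only
existing_out – Existing output [m, n] to accumulate into; positional-only
transpose_rhs – Whether the RHS matrices are stored transposed as [num_groups, n, k]
interpret – Use interpreted mode (slower but more debuggable)
precision – Computation precision setting
platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Optional kernel configuration object (block sizes, warps, stages)
mesh – Optional JAX device mesh for sharded execution
in_specs – Input partition specs
out_specs – Output partition spec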
- Returns
Matrix multiplication result [m, n]
Example
>>> # Basic grouped matmul
>>> out = grouped_matmul(lhs, rhs, group_sizes)
>>> # RHS stored as [num_groups, n, k]
>>> out = grouped_matmul(lhs, rhs_transposed, group_sizes, transpose_rhs=True)
>>> # Accumulate into a previous result (existing_out is positional-only)
>>> out = grouped_matmul(lhs, rhs, group_sizes, None, prev_out)
>>> # Force a specific backend
>>> out = grouped_matmul(..., platform="pallas")
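The accumulation option supports multi-pass schemes such as splitting the contraction dimension k into chunks and summing the partial products. The NumPy sketch below assumes that existing_out is added to the new product (as "accumulate into" suggests); `gmm` is a hypothetical reference helper, not the ejkernel function. With the real grouped_matmul, existing_out is positional-only, so it would be passed as the fifth positional argument.

```python
import numpy as np


def gmm(lhs, rhs, group_sizes, existing_out=None):
    """Hypothetical naive grouped matmul; existing_out (if given) is added to the result."""
    bounds = np.concatenate([[0], np.cumsum(group_sizes)])
    out = np.concatenate(
        [lhs[bounds[g]:bounds[g + 1]] @ rhs[g] for g in range(rhs.shape[0])],
        axis=0,
    )
    return out if existing_out is None else existing_out + out


rng = np.random.default_rng(0)
lhs = rng.standard_normal((7, 8))     # [m, k]
rhs = rng.standard_normal((2, 8, 3))  # [num_groups, k, n]
group_sizes = np.array([4, 3])

# Two passes over halves of k; the second pass accumulates into the first.
first = gmm(lhs[:, :4], rhs[:, :4, :], group_sizes)
full = gmm(lhs[:, 4:], rhs[:, 4:, :], group_sizes, existing_out=first)
```

Because matrix multiplication is linear in k, summing the chunked products reproduces the single-pass result.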