ejkernel.modules.operations.grouped_matmul

Grouped matrix multiplication kernel module with automatic optimization.

This module implements grouped matrix multiplication, an efficient operation for batched matrix multiplication with variable group sizes. This is particularly useful for mixture-of-experts models, grouped convolutions, and other scenarios where different groups of inputs need to be multiplied with different weight matrices.

Unlike standard batched matrix multiplication, which assumes uniform batch sizes, grouped matmul handles variable-sized groups efficiently by:

  1. Processing groups of different sizes in a single operation

  2. Using memory access patterns optimized for grouped computation

  3. Supporting both transposed and non-transposed RHS matrices

  4. Optionally accumulating into existing output tensors

class ejkernel.modules.operations.grouped_matmul.GroupedMatmul(use_v2: bool = False)

Bases: Kernel[GroupedMatmulConfig, Array]

Grouped Matrix Multiplication with custom optimization logic.

Performs efficient matrix multiplication for grouped inputs, where each group can have a different size. This is essential for mixture-of-experts (MoE) models where tokens are dynamically routed to different experts.

Features:
  • Variable group size support via group_sizes array

  • Configurable tiling for memory and compute efficiency

  • Support for RHS transposition

  • Optional output accumulation (for multi-pass operations)

  • Group offset for partial computation

  • Multiple platform support (Triton/Pallas/CUDA/XLA)

Typical use cases:
  • MoE layer computation (different tokens to different experts)

  • Grouped linear layers

  • Dynamic routing architectures
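For the MoE use case, the inputs are typically prepared by sorting token rows by their assigned expert, so that each expert's tokens form one contiguous block of lhs, and counting the rows per expert to obtain group_sizes. The following is a minimal pure-Python sketch of that preparation step. The helper `route_tokens` is hypothetical and is not part of ejkernel.

```python
def route_tokens(tokens, expert_ids, num_experts):
    """Sort token rows by expert and count rows per expert.

    Hypothetical helper for illustration; not part of ejkernel.
    Returns (sorted_tokens, group_sizes, order), where order[i] is the
    original index of the row that ends up at sorted position i.
    """
    # sorted() is stable, so tokens of the same expert keep their relative order.
    order = sorted(range(len(tokens)), key=lambda i: expert_ids[i])
    sorted_tokens = [tokens[i] for i in order]

    # group_sizes[e] is the number of rows routed to expert e.
    group_sizes = [0] * num_experts
    for e in expert_ids:
        group_sizes[e] += 1
    return sorted_tokens, group_sizes, order


tokens = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, 1.0]]
expert_ids = [2, 0, 2, 1]  # expert chosen for each token
sorted_tokens, group_sizes, order = route_tokens(tokens, expert_ids, num_experts=3)
```

The returned `order` can be used afterwards to scatter the grouped matmul output rows back to the original token positions.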

candidate_cfgs(inv: Invocation[GroupedMatmulConfig, Array])

Generate candidate configurations for autotuning.

Creates configurations with different block sizes to explore the performance space. Grouped matmul benefits from various tile sizes depending on group size distribution and matrix dimensions.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of candidate configurations
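One plausible way to build such a candidate list is to take the Cartesian product of a few block sizes and drop tiles that are larger than the problem. The sketch below illustrates that idea only. `BlockConfig`, its field names, and the size list are assumptions made for illustration and may not match the real GroupedMatmulConfig or the candidates ejkernel actually generates.

```python
from dataclasses import dataclass
from itertools import product


@dataclass(frozen=True)
class BlockConfig:
    """Hypothetical stand-in for GroupedMatmulConfig; field names are assumed."""
    block_m: int
    block_n: int
    block_k: int


def make_candidates(m, n, k, sizes=(64, 128, 256)):
    """Enumerate tile shapes, skipping tiles that exceed a dimension.

    The smallest tile size is always allowed so that tiny problems
    still get at least one candidate.
    """
    smallest = min(sizes)

    def fits(block, dim):
        return block <= max(smallest, dim)

    return [
        BlockConfig(bm, bn, bk)
        for bm, bn, bk in product(sizes, repeat=3)
        if fits(bm, m) and fits(bn, n) and fits(bk, k)
    ]


candidates = make_candidates(m=512, n=128, k=64)
```

An autotuner would then time the kernel once per candidate and keep the fastest configuration for that input shape.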

create_shard_map_wrapper(lhs: Float[jaxlib._jax.Array, 'm k'], rhs: jaxtyping.Float[jaxlib._jax.Array, 'num_groups k n'] | jaxtyping.Float[jaxlib._jax.Array, 'num_groups n k'], group_sizes: Int[jaxlib._jax.Array, 'num_groups_or_shards'], preferred_element_type=None, group_offset: jaxtyping.Int[jaxlib._jax.Array, '...'] | None = None, existing_out: jaxtyping.Float[jaxlib._jax.Array, 'm n'] | None = None, transpose_rhs: bool = False, interpret: bool = False, do_padding: bool = True, precision=None, out_shard_callback: collections.abc.Callable[[jaxtyping.Float[jaxlib._jax.Array, 'm n']], jaxtyping.Float[jaxlib._jax.Array, 'm n']] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: ejkernel.modules.operations.configs.GroupedMatmulConfig | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None, check_vma: bool = False)

Create a shard_map wrapper for grouped matrix multiplication.

Parameters
  • mesh – JAX device mesh

  • in_specs – Input partition specs (must match the number of tensor arguments)

  • out_specs – Output partition spec

  • lhs – Input tensor to be sharded

  • rhs – Input tensor to be sharded

  • group_sizes – Input tensor to be sharded

  • args (All other) – Grouped matmul parameters

Returns

Tuple of (shard_map_fn, call_args)

get_impl(cfg: GroupedMatmulConfig)

Get kernel implementation from registry.

Parameters

cfg – Configuration specifying platform and backend

Returns

Callable kernel implementation for grouped matmul

Raises

ValueError – If no matching implementation is found
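The lookup behaviour described here (return a callable for a platform, or raise ValueError) follows a common registry pattern. The sketch below shows that pattern in isolation. The `_REGISTRY` dictionary, `register`, and `lookup` are hypothetical names chosen for illustration and are not ejkernel's actual registry API.

```python
_REGISTRY = {}  # hypothetical: maps (algorithm, platform) -> callable


def register(algorithm, platform):
    """Decorator that records an implementation under (algorithm, platform)."""
    def wrap(fn):
        _REGISTRY[(algorithm, platform)] = fn
        return fn
    return wrap


def lookup(algorithm, platform):
    """Return the registered implementation, or raise ValueError if none exists."""
    try:
        return _REGISTRY[(algorithm, platform)]
    except KeyError:
        raise ValueError(
            f"no implementation of {algorithm!r} for platform {platform!r}"
        ) from None


@register("grouped_matmul", "xla")
def _grouped_matmul_xla(*args, **kwargs):
    return "xla-result"  # placeholder body for the sketch


impl = lookup("grouped_matmul", "xla")
```

Keying the registry on both the algorithm and the platform lets one kernel name resolve to different backends without the caller importing any backend module directly.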

heuristic_cfg(inv: Invocation[GroupedMatmulConfig, Array]) -> GroupedMatmulConfig

Provide default configuration with block sizes.

Selects balanced block sizes suitable for typical grouped matmul workloads. The default 128x128x128 tiling provides good cache utilization for most problem sizes.

Parameters

inv – Invocation object containing arguments and metadata

Returns

Default configuration with 128x128x128 blocks, 4 warps, 2 stages
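To see why the best tile size depends on the group size distribution, it helps to count tiles under a simple model in which each group's rows are padded up to a multiple of the M block size. This padding model is an assumption made for illustration and is not necessarily how ejkernel's kernels tile the problem. The function names below are hypothetical.

```python
def ceil_div(a, b):
    """Integer ceiling division for non-negative a and positive b."""
    return -(-a // b)


def count_tiles(group_sizes, n, block_m=128, block_n=128):
    """Output tiles needed if every group is padded to a multiple of block_m."""
    row_tiles = sum(ceil_div(g, block_m) for g in group_sizes)
    return row_tiles * ceil_div(n, block_n)


def padded_rows(group_sizes, block_m=128):
    """Rows processed after padding; compare against sum(group_sizes)."""
    return sum(ceil_div(g, block_m) * block_m for g in group_sizes)


balanced = [256, 256, 256, 256]  # 1024 rows, evenly split
skewed = [1000, 8, 8, 8]         # 1024 rows, one dominant group

balanced_tiles = count_tiles(balanced, n=512)
skewed_tiles = count_tiles(skewed, n=512)
skewed_padded_128 = padded_rows(skewed, block_m=128)
skewed_padded_64 = padded_rows(skewed, block_m=64)
```

Under this model both inputs have 1024 real rows, but the skewed distribution needs more tiles at block_m=128 because each tiny group still occupies a full tile. A smaller block_m reduces that padding, which is the kind of trade-off autotuning over block sizes is meant to explore.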

run(lhs: Float[jaxlib._jax.Array, 'm k'], rhs: jaxtyping.Float[jaxlib._jax.Array, 'num_groups k n'] | jaxtyping.Float[jaxlib._jax.Array, 'num_groups n k'], group_sizes: Int[jaxlib._jax.Array, 'num_groups_or_shards'], preferred_element_type=None, group_offset: jaxtyping.Int[jaxlib._jax.Array, '...'] | None = None, existing_out: jaxtyping.Float[jaxlib._jax.Array, 'm n'] | None = None, transpose_rhs: bool = False, interpret: bool = False, do_padding: bool = True, precision=None, out_shard_callback: collections.abc.Callable[[jaxtyping.Float[jaxlib._jax.Array, 'm n']], jaxtyping.Float[jaxlib._jax.Array, 'm n']] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: GroupedMatmulConfig) -> Float[jaxlib._jax.Array, 'm n']

Execute grouped matrix multiplication.

Parameters
  • lhs – Left-hand side matrix [m, k] (typically activations)

  • rhs – Right-hand side grouped matrices [num_groups, k, n] or [num_groups, n, k]

  • group_sizes – Size of each group [num_groups], sum(group_sizes) == m

  • preferred_element_type – Optional dtype for computation

  • tiling – Tile sizes (m_tile, n_tile, k_tile) for blocking strategy

  • group_offset – Optional offset into groups for partial computation

  • existing_out – Optional existing output to accumulate into [m, n]

  • transpose_rhs – Whether RHS matrices are transposed

  • interpret – Use interpreted mode (for debugging)

  • precision – Computation precision setting

  • platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”)

  • cfg – Kernel configuration object

Returns

Matrix multiplication result [m, n]

Note

The group_sizes array partitions the m dimension of lhs. Each partition is multiplied with the corresponding group matrix from rhs.
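The partitioning rule in this note can be written as a slow reference implementation. The sketch below uses plain Python lists so that it runs without JAX. It is meant only to pin down the semantics, not to reflect how the optimized kernels are written, and `grouped_matmul_reference` is a hypothetical name.

```python
def matmul(a, b):
    """Plain list-of-lists matrix product: [r, k] @ [k, n] -> [r, n]."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def grouped_matmul_reference(lhs, rhs, group_sizes):
    """Rows of lhs are split by group_sizes; row block g is multiplied by rhs[g]."""
    assert sum(group_sizes) == len(lhs), "group_sizes must partition the m dimension"
    out, start = [], 0
    for g, size in enumerate(group_sizes):
        # An empty group (size 0) contributes no output rows.
        out.extend(matmul(lhs[start:start + size], rhs[g]))
        start += size
    return out


lhs = [[1, 2], [3, 4], [5, 6]]  # m=3, k=2
rhs = [
    [[1, 0], [0, 1]],           # group 0: identity
    [[2, 0], [0, 2]],           # group 1: scale by 2
]
out = grouped_matmul_reference(lhs, rhs, group_sizes=[1, 2])
```

Here the first row of lhs passes through the identity matrix and the remaining two rows are doubled, so the output keeps the row order of lhs.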

ejkernel.modules.operations.grouped_matmul.grouped_matmul(lhs: Float[jaxlib._jax.Array, 'm k'], rhs: jaxtyping.Float[jaxlib._jax.Array, 'num_groups k n'] | jaxtyping.Float[jaxlib._jax.Array, 'num_groups n k'], group_sizes: Int[jaxlib._jax.Array, 'num_groups_or_shards'], group_offset: jaxtyping.Int[jaxlib._jax.Array, '...'] | None = None, existing_out: jaxtyping.Float[jaxlib._jax.Array, 'm n'] | None = None, /, *, preferred_element_type=None, transpose_rhs: bool = False, interpret: bool = False, do_padding: bool = True, precision=None, use_v2: bool = False, out_shard_callback: collections.abc.Callable[[jaxtyping.Float[jaxlib._jax.Array, 'm n']], jaxtyping.Float[jaxlib._jax.Array, 'm n']] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.GroupedMatmulConfig | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec | None, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None) -> Float[jaxlib._jax.Array, 'm n']

Execute grouped matrix multiplication with automatic optimization.

Performs batched matrix multiplication in which each group may contain a different number of rows.

Parameters
  • lhs – Left-hand side matrix [m, k]

  • rhs – Right-hand side matrices [num_groups, k, n] or [num_groups, n, k]

  • group_sizes – Size of each group [num_groups]

  • preferred_element_type – Preferred dtype for computation

  • tiling – Tile sizes (m_tile, n_tile, k_tile) for blocking

  • group_offset – Offset into groups (for partial computation)

  • existing_out – Existing output to accumulate into

  • transpose_rhs – Whether to transpose RHS matrices

  • interpret – Use interpreted mode (slower but more debuggable)

  • precision – Computation precision setting

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, or “xla”)

Returns

Matrix multiplication result [m, n]
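The effect of transpose_rhs and existing_out can be illustrated on a single group. The sketch below assumes that transpose_rhs=True means the group's matrix is stored as [n, k] and contracted along k, and that existing_out is added elementwise to the product. Both assumptions are inferred from the parameter descriptions rather than from the kernel source, and `group_product` is a hypothetical helper.

```python
def group_product(lhs_rows, rhs_g, transpose_rhs=False, existing_rows=None):
    """Product for one group under the assumed transpose/accumulate semantics."""
    if transpose_rhs:
        # rhs_g is stored as [n, k]: each row already holds one output column's weights.
        cols = rhs_g
    else:
        # rhs_g is stored as [k, n]: gather its columns.
        cols = list(zip(*rhs_g))
    result = [[sum(x * w for x, w in zip(row, col)) for col in cols] for row in lhs_rows]
    if existing_rows is not None:
        # Assumed meaning of existing_out: elementwise accumulation.
        result = [
            [r + e for r, e in zip(res_row, ex_row)]
            for res_row, ex_row in zip(result, existing_rows)
        ]
    return result


x = [[1, 2, 3]]                        # one row, k=3
w = [[1, 0], [0, 1], [1, 1]]           # stored as [k=3, n=2]
w_t = [list(col) for col in zip(*w)]   # same weights stored as [n=2, k=3]

plain = group_product(x, w)
transposed = group_product(x, w_t, transpose_rhs=True)
accumulated = group_product(x, w, existing_rows=[[10, 20]])
```

Under these assumptions the transposed layout produces the same result as the plain layout, so the flag only describes how the weights are stored in memory.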

Example

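>>> # Basic grouped matmul
>>> out = grouped_matmul(lhs, rhs, group_sizes)
>>>
>>> # RHS stored transposed as [num_groups, n, k]
>>> out = grouped_matmul(lhs, rhs_transposed, group_sizes, transpose_rhs=True)
>>>
>>> # Accumulate into an existing output; group_offset and existing_out are
>>> # positional-only in the signature above, so group_offset is passed as None
>>> out = grouped_matmul(lhs, rhs, group_sizes, None, prev_out)
>>>
>>> # Request a specific platform
>>> out = grouped_matmul(..., platform="pallas")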