ejkernel.kernels._pallas.tpu.grouped_matmul._pallas_impl
- ejkernel.kernels._pallas.tpu.grouped_matmul._pallas_impl.grouped_matmul(lhs: jax.jaxlib._jax.Array, rhs: jax.jaxlib._jax.Array, group_sizes: jax.jaxlib._jax.Array, preferred_element_type: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType] = <class 'jax.numpy.float32'>, tiling: tuple[int, int, int] | collections.abc.Callable[[int, int, int], tuple[int, int, int] | None] | None = (128, 128, 128), group_offset: jax.jaxlib._jax.Array | None = None, existing_out: jax.jaxlib._jax.Array | None = None, transpose_rhs: bool = False, interpret: bool = False) → Array
Grouped Matrix Multiplication: Compute separate matrix products for each group.
This function performs grouped matrix multiplication where different row slices of the left-hand side matrix are multiplied with different matrices from the right-hand side tensor. Mathematically, for each group i:
out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i, :, :]
where start_i = sum(group_sizes[:i]) (so start_0 = 0) and end_i = start_i + group_sizes[i].
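For testing, a slow reference that spells out the formula above can be handy. The sketch below is a hypothetical pure-JAX oracle (`reference_grouped_matmul` is not part of ejkernel): it computes every group's full product and keeps, for each row, only the product of that row's group, so it is suitable for small shapes only.

```python
import jax
import jax.numpy as jnp


def reference_grouped_matmul(lhs, rhs, group_sizes):
    """Hypothetical oracle for grouped_matmul semantics (not the Pallas kernel).

    lhs: [m, k], rhs: [num_groups, k, n], group_sizes: [num_groups] int32.
    """
    m = lhs.shape[0]
    # start_i = sum(group_sizes[:i]), end_i = start_i + group_sizes[i]
    ends = jnp.cumsum(group_sizes)
    starts = ends - group_sizes
    rows = jnp.arange(m)
    # mask[i, r] is True when row r lies in [start_i, end_i)
    mask = (rows[None, :] >= starts[:, None]) & (rows[None, :] < ends[:, None])
    # Every group's full product: [num_groups, m, n]
    per_group = jnp.einsum("mk,gkn->gmn", lhs, rhs)
    # Keep each row's own group product and drop the rest.
    return jnp.sum(jnp.where(mask[:, :, None], per_group, 0.0), axis=0)
```

Comparing kernel output against such an oracle with `jnp.allclose` is one way to sanity-check a tiling choice.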
The implementation uses Pallas to generate efficient TPU kernels that:
- Process multiple groups in a single kernel launch
- Handle groups that don’t align with tile boundaries
- Support incremental accumulation for memory efficiency
- Optimize memory access patterns for TPU’s memory hierarchy
- Parameters
lhs – Left-hand side matrix of shape [m, k] where m is the total number of rows across all groups and k is the inner dimension.
rhs – Right-hand side tensor of shape [num_groups, k, n] containing a separate matrix for each group. Can be transposed if transpose_rhs=True.
group_sizes – 1D array of shape [num_groups] with int32 dtype. Each element specifies the number of rows in lhs belonging to that group. Must sum to m (first dimension of lhs).
preferred_element_type – Output dtype. Defaults to float32. The kernel uses float32 for accumulation regardless, then casts to this type.
tiling – Tile sizes as (tm, tk, tn) tuple, or a callable that returns tile sizes based on problem dimensions. Typical TPU tiles are 128x128. The callable signature is (m, k, n) -> (tm, tk, tn) | None.
group_offset – Starting group index for sharded execution. Defaults to 0. Useful when distributing groups across multiple devices.
existing_out – Optional pre-existing output tensor to accumulate into. Must have shape [m, n] and dtype matching preferred_element_type. Enables incremental computation and memory reuse.
transpose_rhs – If True, expects rhs shape [num_groups, n, k] and transposes during multiplication. Useful for certain memory layouts.
interpret – Run kernel in interpret mode for debugging. Slower but provides better error messages and doesn’t require compilation.
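The `tiling` parameter also accepts a callable with the signature `(m, k, n) -> (tm, tk, tn) | None`. A minimal sketch of one such callable follows; `pick_tiling`, its candidate list, and the fixed tk = tn = 128 are illustrative assumptions, and the only constraint it enforces is the documented requirement that tm divide m.

```python
from __future__ import annotations

# Assumption: preferred row-tile sizes, largest first.
_TM_CANDIDATES = (512, 256, 128, 64, 32, 16, 8)


def pick_tiling(m: int, k: int, n: int) -> tuple[int, int, int] | None:
    """Hypothetical tiling callable: (m, k, n) -> (tm, tk, tn) | None.

    Picks the largest candidate tm that divides m; tk and tn are fixed at 128
    (an assumption). Returns None when no candidate divides m.
    """
    for tm in _TM_CANDIDATES:
        if m % tm == 0:
            return tm, 128, 128
    return None
```

It would be passed as `grouped_matmul(lhs, rhs, group_sizes, tiling=pick_tiling)`. Returning None mirrors the `| None` in the documented return type; how the kernel reacts to None is not specified on this page.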
- Returns
Output matrix of shape [m, n] containing the concatenated results of all group matrix multiplications.
- Algorithm Overview:
Validate inputs and determine computation parameters
Create group metadata for efficient tile-to-group mapping
Define Pallas kernel that:
- Loads tiles from lhs and group-specific slices from rhs
- Accumulates partial products in on-chip memory
- Masks and stores results respecting group boundaries
Launch kernel with appropriate grid dimensions
Zero unprocessed regions if doing partial computation
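The tile schedule in the overview can be emulated in plain JAX to make the masking step concrete. The sketch below is hypothetical (`tiled_gmm_sketch` is not an ejkernel function), takes static Python group sizes, and runs eagerly: it visits every m-tile that overlaps a group, accumulates over k tiles in float32, and writes back only the rows that the group owns. The final cast to preferred_element_type is omitted.

```python
import jax
import jax.numpy as jnp


def tiled_gmm_sketch(lhs, rhs, group_sizes, tm, tk):
    """Hypothetical eager emulation of grouped_matmul's tile schedule (not Pallas).

    lhs: [m, k], rhs: [num_groups, k, n], group_sizes: static Python list.
    """
    m, k = lhs.shape
    n = rhs.shape[2]
    assert m % tm == 0, "m must be divisible by tm"
    offsets = [0]
    for size in group_sizes:
        offsets.append(offsets[-1] + size)
    out = jnp.zeros((m, n), jnp.float32)
    for g in range(len(group_sizes)):
        start, end = offsets[g], offsets[g + 1]
        if start == end:
            continue  # empty groups are skipped
        # Every m-tile overlapping rows [start, end)
        for t in range(start // tm, (end - 1) // tm + 1):
            r0, r1 = t * tm, (t + 1) * tm
            acc = jnp.zeros((tm, n), jnp.float32)  # stands in for the VMEM accumulator
            for k0 in range(0, k, tk):  # the last k tile may be partial
                k_end = min(k0 + tk, k)
                acc = acc + (lhs[r0:r1, k0:k_end].astype(jnp.float32)
                             @ rhs[g, k0:k_end].astype(jnp.float32))
            # Store only rows owned by group g; rows of a shared tile that
            # belong to a neighbouring group keep their existing values.
            rows = jnp.arange(r0, r1)[:, None]
            keep = (rows >= start) & (rows < end)
            out = out.at[r0:r1].set(jnp.where(keep, acc, out[r0:r1]))
    return out
```

Because a tile straddling two groups is visited once per group, each visit must leave the other group's rows intact, which is why the store goes through `jnp.where` rather than overwriting the whole tile.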
- TPU Optimizations:
Tiles sized to match TPU’s Matrix Multiply Units (typically 128x128)
Prefetch scheduling for hiding memory latency
VMEM scratch space for accumulation to minimize HBM traffic
Efficient masking for partial tiles using TPU’s predication
Dimension semantics hints for XLA compiler optimization
Example
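>>> import jax
>>> import jax.numpy as jnp
>>> from ejkernel.kernels._pallas.tpu.grouped_matmul._pallas_impl import grouped_matmul
>>> k1, k2 = jax.random.split(jax.random.PRNGKey(0))
>>> lhs = jax.random.normal(k1, (384, 128))  # [m, k]; m divisible by tm = 128
>>> rhs = jax.random.normal(k2, (3, 128, 128))  # [num_groups, k, n]
>>> group_sizes = jnp.array([128, 192, 64], dtype=jnp.int32)  # sums to m
>>> result = grouped_matmul(lhs, rhs, group_sizes)
>>> result.shape
(384, 128)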
Notes
The k dimension can have partial tiles (handled via masking)
The m dimension must be divisible by tm for correctness
Empty groups (size 0) are skipped for efficiency
Cost estimation helps XLA make scheduling decisions
- ejkernel.kernels._pallas.tpu.grouped_matmul._pallas_impl.make_group_metadata(*, group_sizes: Array, m: int, tm: int, start_group: Array, num_nonzero_groups: int, visit_empty_groups: bool = True) → Any
Create metadata for efficient grouped matrix multiplication on TPU.
This function generates the metadata structures needed to efficiently execute grouped matrix multiplication on TPU hardware. It handles the complex mapping between groups and tiles, accounting for groups that may not align with tile boundaries and managing partial tiles that span multiple groups.
The algorithm works by:
1. Computing group offsets (CSR-like row offsets for each group)
2. Calculating tile assignments for each group, handling boundary cases
3. Creating mappings from grid indices to group IDs and tile IDs
4. Accounting for sharding when groups are distributed across devices
- Parameters
group_sizes – 1D array of shape [num_groups] with int32 dtype. Each element specifies the number of rows in the corresponding group. Must sum to m.
m – Total number of rows across all groups in the left-hand side matrix.
tm – Tile size for the m dimension. Must evenly divide m for correctness.
start_group – Scalar indicating the first group to process (0-indexed). Used for sharding groups across multiple devices.
num_nonzero_groups – Number of consecutive groups to process starting from start_group. Enables processing a subset of groups.
visit_empty_groups – If True, allocate tiles for groups with size 0. Required for transposed_grouped_matmul to ensure output is properly zeroed for empty groups. If False, empty groups are skipped (used in grouped_matmul for efficiency).
- Returns
A tuple (group_metadata, num_tiles), where:
- group_metadata: A tuple of three arrays:
group_offsets: Shape [num_groups+1], int32. CSR-style offsets where group_offsets[i] is the starting row of group i, and group_offsets[num_groups] = m.
group_ids: Shape [m_tiles + num_groups - 1], int32, where m_tiles = m // tm. Maps each grid index to its corresponding group ID.
m_tile_ids: Shape [m_tiles + num_groups - 1], int32. Maps each grid index to its corresponding m-dimension tile ID.
- num_tiles: Total number of tiles to execute for the specified groups.
- Return type
Any
- Algorithm Details:
The function handles several complex cases:
- Groups that don’t start or end on tile boundaries require partial tile processing
- Tiles that span multiple groups need to be visited multiple times
- Empty groups may need special handling depending on the operation (grouped_matmul vs transposed_grouped_matmul)
- Sharding requires adjusting the metadata to process only local groups
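The mapping from grid indices to (group, m-tile) pairs can be modelled in a few lines of plain Python. This is a minimal sketch under stated assumptions: `sketch_group_metadata` is hypothetical, works on static Python ints rather than traced arrays, ignores sharding (`start_group`, `num_nonzero_groups`), returns unpadded lists, and the tile assigned to an empty group is our own choice.

```python
from __future__ import annotations


def sketch_group_metadata(group_sizes: list[int], m: int, tm: int,
                          visit_empty_groups: bool = True):
    """Hypothetical model of the grid-index -> (group, m-tile) mapping.

    Returns (offsets, group_ids, m_tile_ids, num_tiles) as plain lists/ints,
    without the padding of the id arrays to m_tiles + num_groups - 1.
    """
    if m % tm:
        raise ValueError(f"m={m} must be divisible by tm={tm}")
    m_tiles = m // tm
    # CSR-style offsets: group g owns rows [offsets[g], offsets[g + 1]).
    offsets = [0]
    for size in group_sizes:
        offsets.append(offsets[-1] + size)
    if offsets[-1] != m:
        raise ValueError("group_sizes must sum to m")

    group_ids, m_tile_ids = [], []
    for g in range(len(group_sizes)):
        start, end = offsets[g], offsets[g + 1]
        if start == end:
            if visit_empty_groups:
                # One visit so the caller can zero this group's output.
                group_ids.append(g)
                m_tile_ids.append(min(start // tm, m_tiles - 1))
            continue
        # Every m-tile overlapping [start, end); a tile shared by two
        # groups appears once per group.
        for t in range(start // tm, (end - 1) // tm + 1):
            group_ids.append(g)
            m_tile_ids.append(t)
    return offsets, group_ids, m_tile_ids, len(group_ids)
```

For group_sizes [128, 192, 64] with tm = 128, group 1 spans tiles 1 and 2, so tile 2 is visited twice (once for group 1, once for group 2), giving four visits in total, within the bound m_tiles + num_groups - 1 = 5.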
- TPU Optimizations:
Tiles are sized to match TPU’s native matrix multiply units (typically 128x128)
Metadata is stored as compact int32 arrays to keep memory accesses to a minimum
Grid layout ensures coalesced memory access and efficient tile reuse
Partial tiles are handled through masking rather than padding to save memory
Example
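>>> import jax.numpy as jnp
>>> from ejkernel.kernels._pallas.tpu.grouped_matmul._pallas_impl import make_group_metadata
>>> group_sizes = jnp.array([128, 192, 64], dtype=jnp.int32)
>>> metadata, num_tiles = make_group_metadata(
...     group_sizes=group_sizes, m=384, tm=128,
...     start_group=jnp.array(0), num_nonzero_groups=3)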
- ejkernel.kernels._pallas.tpu.grouped_matmul._pallas_impl.transposed_grouped_matmul(lhs: jax.jaxlib._jax.Array, rhs: jax.jaxlib._jax.Array, group_sizes: jax.jaxlib._jax.Array, preferred_element_type: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType] = <class 'jax.numpy.float32'>, tiling: tuple[int, int, int] | collections.abc.Callable[[int, int, int], tuple[int, int, int] | None] | None = (128, 128, 128), group_offset: jax.jaxlib._jax.Array | None = None, num_actual_groups: int | None = None, existing_out: jax.jaxlib._jax.Array | None = None, interpret: bool = False) → Array
Transposed Grouped Matrix Multiplication: Compute grouped products with transposed access pattern.
This function performs grouped matrix multiplication where different column slices of the left-hand side matrix are multiplied with different row slices of the right-hand side matrix, producing a separate output matrix for each group. Mathematically, for each group i:
out[i, :, :] = lhs[:, start_i:end_i] @ rhs[start_i:end_i, :]
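where start_i = sum(group_sizes[:i]) and end_i = start_i + group_sizes[i], as in grouped_matmul, but indexing the shared m dimension of lhs (columns) and rhs (rows).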
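A slow reference for the formula above, useful when checking kernel outputs, can be written with a group mask over the m dimension. `reference_transposed_grouped_matmul` is a hypothetical helper, not part of ejkernel; empty groups fall out naturally as all-zero slices.

```python
import jax
import jax.numpy as jnp


def reference_transposed_grouped_matmul(lhs, rhs, group_sizes):
    """Hypothetical oracle: out[i] = lhs[:, start_i:end_i] @ rhs[start_i:end_i, :].

    lhs: [k, m], rhs: [m, n], group_sizes: [num_groups] -> out: [num_groups, k, n].
    """
    m = lhs.shape[1]
    ends = jnp.cumsum(group_sizes)
    starts = ends - group_sizes
    pos = jnp.arange(m)
    # mask[i, j] is 1.0 when position j along m belongs to group i
    mask = ((pos[None, :] >= starts[:, None])
            & (pos[None, :] < ends[:, None])).astype(lhs.dtype)
    # Contract over m; the mask selects each group's slice.
    return jnp.einsum("km,gm,mn->gkn", lhs, mask, rhs)
```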
This operation is particularly useful for:
- Attention mechanisms where different heads process different feature slices
- Expert routing in Mixture-of-Experts models
- Block-sparse operations where groups represent different blocks
The implementation uses Pallas to generate efficient TPU kernels that:
- Process multiple groups while maintaining separate outputs
- Handle empty groups by zeroing their outputs
- Support incremental accumulation across tiles
- Optimize for TPU’s memory hierarchy and compute units
- Parameters
lhs – Left-hand side matrix of shape [k, m] where k is the output dimension and m is the total size across all groups.
rhs – Right-hand side matrix of shape [m, n] where m matches lhs and n is the final output dimension.
group_sizes – 1D array of shape [num_groups] with int32 dtype. Each element specifies the size of that group in the m dimension. Must sum to m.
preferred_element_type – Output dtype. Defaults to float32. Internal accumulation uses float32 regardless, with final cast to this type.
tiling – Tile sizes as (tm, tk, tn) tuple, or a callable that returns tile sizes based on problem dimensions. Standard TPU tiles are 128x128. The callable signature is (m, k, n) -> (tm, tk, tn) | None.
group_offset – Starting group index for sharded execution. Defaults to 0. Enables distributing groups across multiple devices.
num_actual_groups – Number of groups to actually compute starting from group_offset. Defaults to all remaining groups. Useful for sharding.
existing_out – Optional pre-existing output tensor to accumulate into. Must have shape [num_actual_groups, k, n] and matching dtype. Enables incremental computation and gradient accumulation.
interpret – Run kernel in interpret mode for debugging. Slower but provides better error messages and avoids compilation.
- Returns
3D output tensor of shape [num_actual_groups, k, n] where each slice out[i] contains the matrix product for group i.
- Algorithm Overview:
Validate inputs and configure computation parameters
Create group metadata with visit_empty_groups=True to ensure all outputs are properly initialized (even for empty groups)
Define Pallas kernel that:
- Maintains separate accumulator for each group
- Masks inputs based on group boundaries
- Handles group transitions by storing/resetting accumulators
- Zeros output for empty groups
Launch kernel with grid covering all tiles and groups
Handle output accumulation if existing_out provided
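The accumulator handling in the overview (mask by group, accumulate across m-tiles, store and reset on a group transition, zero empty groups) can be emulated eagerly in JAX. `tiled_tgmm_sketch` is hypothetical, uses static Python group sizes, and does not tile k or n; it only illustrates the visit order and the transition logic.

```python
import jax
import jax.numpy as jnp


def tiled_tgmm_sketch(lhs, rhs, group_sizes, tm):
    """Hypothetical eager emulation of the transposed kernel's schedule (not Pallas).

    lhs: [k, m], rhs: [m, n], group_sizes: static Python list summing to m.
    Returns [num_groups, k, n].
    """
    k, m = lhs.shape
    n = rhs.shape[1]
    assert m % tm == 0, "m must be divisible by tm"
    lhs_t = lhs.T  # [m, k]: rows along m, mirroring the internal transpose
    offsets = [0]
    for size in group_sizes:
        offsets.append(offsets[-1] + size)

    # Visit order: each group's overlapping m-tiles; empty groups get one visit.
    visits = []
    for g in range(len(group_sizes)):
        start, end = offsets[g], offsets[g + 1]
        if start == end:
            visits.append((g, min(start // tm, m // tm - 1)))
        else:
            visits.extend((g, t) for t in range(start // tm, (end - 1) // tm + 1))

    out = jnp.zeros((len(group_sizes), k, n), jnp.float32)
    acc = jnp.zeros((k, n), jnp.float32)  # stands in for the VMEM accumulator
    for i, (g, t) in enumerate(visits):
        rows = jnp.arange(t * tm, (t + 1) * tm)
        in_group = ((rows >= offsets[g]) & (rows < offsets[g + 1]))[:, None]
        # Mask rows outside group g so they contribute nothing.
        x = jnp.where(in_group, lhs_t[t * tm:(t + 1) * tm], 0.0)
        y = jnp.where(in_group, rhs[t * tm:(t + 1) * tm], 0.0)
        acc = acc + x.T.astype(jnp.float32) @ y.astype(jnp.float32)
        # Group transition: store this group's result and reset the accumulator.
        if i + 1 == len(visits) or visits[i + 1][0] != g:
            out = out.at[g].set(acc)  # an empty group stores all zeros
            acc = jnp.zeros((k, n), jnp.float32)
    return out
```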
- TPU Optimizations:
Tile operations aligned with TPU’s 128x128 systolic arrays
Accumulation in VMEM (on-chip memory) to minimize HBM bandwidth
Prefetch scheduling to overlap compute and memory operations
Efficient masking using TPU’s predicated execution
Group transitions handled without kernel restarts
- Key Differences from grouped_matmul:
Output is 3D with separate matrix per group (vs 2D concatenated)
Groups index into both lhs columns and rhs rows (vs only lhs rows)
Empty groups must be visited to zero their outputs
Accumulator management includes group transition logic
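The transposed form also appears when differentiating grouped_matmul with respect to rhs: for out = grouped_matmul(x, w, group_sizes), the gradient of sum(out * g) with respect to w[i] is x[start_i:end_i].T @ g[start_i:end_i], which is the transposed formula with lhs = x.T and rhs = g. The check below verifies this identity with hypothetical pure-JAX helpers (`gmm_ref`, `tgmm_loop`); it says nothing about how ejkernel wires its own backward pass.

```python
import jax
import jax.numpy as jnp


def gmm_ref(lhs, rhs, group_sizes):
    """Hypothetical oracle: out[r] = lhs[r] @ rhs[group of row r]; lhs [m, k], rhs [G, k, n]."""
    ends = jnp.cumsum(group_sizes)
    starts = ends - group_sizes
    rows = jnp.arange(lhs.shape[0])
    mask = ((rows[None, :] >= starts[:, None])
            & (rows[None, :] < ends[:, None])).astype(lhs.dtype)
    return jnp.einsum("gm,mk,gkn->mn", mask, lhs, rhs)


def tgmm_loop(lhs, rhs, sizes):
    """Hypothetical oracle: out[i] = lhs[:, start_i:end_i] @ rhs[start_i:end_i, :], static sizes."""
    outs, start = [], 0
    for size in sizes:
        outs.append(lhs[:, start:start + size] @ rhs[start:start + size])
        start += size
    return jnp.stack(outs)
```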
Example
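>>> import jax
>>> import jax.numpy as jnp
>>> from ejkernel.kernels._pallas.tpu.grouped_matmul._pallas_impl import transposed_grouped_matmul
>>> k1, k2 = jax.random.split(jax.random.PRNGKey(0))
>>> lhs = jax.random.normal(k1, (128, 384))  # [k, m]; m divisible by tm = 128
>>> rhs = jax.random.normal(k2, (384, 128))  # [m, n]
>>> group_sizes = jnp.array([128, 192, 64], dtype=jnp.int32)  # sums to m
>>> result = transposed_grouped_matmul(lhs, rhs, group_sizes)
>>> result.shape
(3, 128, 128)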
Notes
The m dimension must be divisible by tm for correctness
Empty groups produce zero matrices in the output
Partial tiles are handled through masking
Cost estimation guides XLA’s scheduling decisions
The lhs matrix is internally transposed for efficient access patterns