ejkernel.kernels._pallas.tpu.grouped_matmul._utils
Utility functions for TPU device detection and dtype validation.
This module provides helper functions to detect TPU hardware capabilities and validate data types for grouped matrix multiplication operations. It handles TPU generation detection and dtype compatibility checks.
- ejkernel.kernels._pallas.tpu.grouped_matmul._utils.assert_is_supported_dtype(dtype: Union[str, type[Any], dtype, SupportsDType]) → None
Validate that a dtype is supported for grouped matrix multiplication.
The grouped matmul kernels are optimized for bfloat16 and float32, the dtypes that give the best performance on TPU hardware. Other dtypes are not supported because of TPU MXU constraints and kernel optimization requirements.
- Parameters
dtype – Dtype-like value to validate: a dtype object, a scalar type such as jnp.float32, or a dtype name string.
- Raises
ValueError – If dtype is not bfloat16 or float32.
Example
>>> assert_is_supported_dtype(jnp.float32)
>>> assert_is_supported_dtype(jnp.bfloat16)
>>> assert_is_supported_dtype(jnp.float64)  # raises ValueError
Note
bfloat16: Preferred for TPU v4+, offers 2x throughput vs float32
float32: Universal support, higher precision
Other dtypes (int8, float64, etc.) are not supported
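A minimal sketch of the check described above, assuming the helper normalizes its argument with `jnp.dtype` and compares it against bfloat16 and float32. `assert_is_supported_dtype_sketch` is a hypothetical name used for illustration, not the library function, and the error message is an assumption.

```python
import jax.numpy as jnp

# Hypothetical stand-in for assert_is_supported_dtype; assumes a plain
# equality check against the two supported dtypes.
_SUPPORTED_DTYPES = (jnp.dtype(jnp.bfloat16), jnp.dtype(jnp.float32))


def assert_is_supported_dtype_sketch(dtype) -> None:
    """Raise ValueError unless `dtype` is bfloat16 or float32."""
    # Normalize scalar types (jnp.float32), names ("float32"), and dtype objects.
    normalized = jnp.dtype(dtype)
    if normalized not in _SUPPORTED_DTYPES:
        raise ValueError(f"Expected bfloat16 or float32 but got {normalized}.")


assert_is_supported_dtype_sketch(jnp.float32)   # passes silently
assert_is_supported_dtype_sketch(jnp.bfloat16)  # passes silently
```

Normalizing first lets the same check accept every form allowed by the `Union[str, type[Any], dtype, SupportsDType]` annotation.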
- ejkernel.kernels._pallas.tpu.grouped_matmul._utils.is_tpu() → bool
Check if the current JAX backend is running on TPU.
- Returns
True if running on TPU hardware, False otherwise.
- Return type
bool
Example
>>> if is_tpu():
...     print("Running on TPU")
... else:
...     print("Not running on TPU")
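One way such a check can be written, assuming it inspects the `device_kind` string of the first visible JAX device. `is_tpu_sketch` is a hypothetical name, and the "TPU" substring heuristic is an assumption, not confirmed by this page.

```python
import jax


def is_tpu_sketch() -> bool:
    """Return True if the first visible JAX device looks like a TPU."""
    # Assumption: TPU device_kind strings contain "TPU", e.g. "TPU v4".
    return "TPU" in jax.devices()[0].device_kind
```

`jax.default_backend() == "tpu"` is an equivalent-in-spirit alternative; which form the library actually uses is not stated here.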
- ejkernel.kernels._pallas.tpu.grouped_matmul._utils.select_input_dtype(lhs: Array, rhs: Array) → dtype
Select the optimal dtype for matrix multiplication inputs.
Determines the best dtype to use for the dot product computation based on:
1. Hardware capabilities (TPU generation)
2. Input tensor dtypes
3. Performance considerations
The function does not cast anything itself; it returns the dtype to which both inputs should be cast, chosen to maximize performance while maintaining numerical stability.
- Parameters
lhs – Left-hand side matrix for multiplication.
rhs – Right-hand side matrix for multiplication.
- Returns
The dtype to which both inputs should be cast before the matrix multiplication: either jnp.bfloat16 or jnp.float32.
- Return type
jnp.dtype
Example
>>> lhs = jnp.ones((10, 20), dtype=jnp.bfloat16)
>>> rhs = jnp.ones((20, 30), dtype=jnp.bfloat16)
>>> select_input_dtype(lhs, rhs)  # on TPU v4+, CPU, or GPU
dtype('bfloat16')
>>> lhs = jnp.ones((10, 20), dtype=jnp.float32)
>>> rhs = jnp.ones((20, 30), dtype=jnp.bfloat16)
>>> select_input_dtype(lhs, rhs)
dtype('float32')
Note
Uses bfloat16 only if:
1. Hardware supports it (TPU v4+ or CPU/GPU)
2. Both inputs are already bfloat16
Falls back to float32 for mixed precision or older TPUs
This ensures optimal performance without unexpected precision loss
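The selection rule in the Note can be sketched as follows. To keep it runnable on any machine, the hardware check is replaced by a hypothetical `bf16_matmul_supported` parameter (standing in for `supports_bfloat16_matmul()`); `select_input_dtype_sketch` is likewise an illustrative name, not the library API.

```python
import jax.numpy as jnp


def select_input_dtype_sketch(lhs, rhs, bf16_matmul_supported: bool = True):
    """Pick the dtype both matmul operands should be cast to.

    `bf16_matmul_supported` is a hypothetical stand-in for the hardware
    check, so the rule can be exercised without a TPU.
    """
    both_bf16 = lhs.dtype == jnp.bfloat16 and rhs.dtype == jnp.bfloat16
    if bf16_matmul_supported and both_bf16:
        return jnp.dtype(jnp.bfloat16)
    # Mixed precision, non-bfloat16 inputs, or pre-v4 TPU: use float32.
    return jnp.dtype(jnp.float32)


# Caller-side usage: cast both operands to the selected dtype, then multiply.
lhs = jnp.ones((10, 20), dtype=jnp.float32)
rhs = jnp.ones((20, 30), dtype=jnp.bfloat16)
compute_dtype = select_input_dtype_sketch(lhs, rhs)
out = jnp.dot(lhs.astype(compute_dtype), rhs.astype(compute_dtype))
```

Upcasting the bfloat16 operand in the mixed case is what avoids the "unexpected precision loss" mentioned above: the float32 operand is never truncated.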
- ejkernel.kernels._pallas.tpu.grouped_matmul._utils.supports_bfloat16_matmul() → bool
Check if the current device supports bfloat16 matrix multiplication.
TPU v4 and later generations have native bfloat16 support in their matrix multiplication units (MXUs), providing significant performance benefits over float32 while maintaining numerical stability for many deep learning workloads.
- Returns
True if the device supports efficient bfloat16 matmul operations, False otherwise.
- Return type
bool
Example
>>> if supports_bfloat16_matmul():
...     dtype = jnp.bfloat16
... else:
...     dtype = jnp.float32
Note
Returns True for non-TPU devices (CPU/GPU) as they typically support bfloat16 operations, though potentially without hardware acceleration.
Returns True for TPU v4 and later (native MXU support).
Returns False for TPU v2/v3 (limited bfloat16 support).
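The three rules in this Note can be expressed as a small pure function over the device kind string. This is a sketch under assumptions: the `"TPU v<N>"` format, the helper names `supports_bf16_for_kind` and `supports_bfloat16_matmul_sketch`, and raising `NotImplementedError` for unparseable TPU strings are all illustrative choices, not confirmed library behavior.

```python
import re

import jax

# Assumed device_kind format for TPUs, e.g. "TPU v4" or "TPU v5 lite".
_TPU_KIND = re.compile(r"TPU v(\d+)")


def supports_bf16_for_kind(device_kind: str) -> bool:
    """Hypothetical helper applying the Note's rules to a kind string."""
    if "TPU" not in device_kind:
        return True  # CPU/GPU: treated as bfloat16-capable, per the Note
    match = _TPU_KIND.match(device_kind)
    if match is None:
        raise NotImplementedError(f"Unrecognized TPU kind: {device_kind!r}")
    return int(match.group(1)) >= 4  # TPU v4 and later only


def supports_bfloat16_matmul_sketch() -> bool:
    """Apply the rule to the first visible JAX device."""
    return supports_bf16_for_kind(jax.devices()[0].device_kind)
```

Separating the string rule from the device query makes the generation logic testable on a machine without a TPU.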
- ejkernel.kernels._pallas.tpu.grouped_matmul._utils.tpu_generation() → int
Extract the generation number of the currently attached TPU.
Parses the TPU device kind string to extract the generation number (e.g., 4 for TPU v4, 5 for TPU v5e/v5p).
- Returns
TPU generation number (e.g., 2, 3, 4, 5).
- Return type
int
- Raises
NotImplementedError – If the device is not a TPU or if the TPU version string format is unrecognized.
Example
>>> tpu_generation()  # on a TPU v4 host
4
>>> tpu_generation()  # on a TPU v5e or v5p host
5
Note
TPU generations have different capabilities:
- TPU v2/v3: Limited bfloat16 support
- TPU v4+: Full bfloat16 matmul support
- TPU v5: Enhanced performance and memory bandwidth
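The parsing step described for `tpu_generation()` can be sketched with a regular expression over the device kind string. The pattern `TPU v(\d+)` and the helper name `tpu_generation_from_kind` are assumptions for illustration. The helper takes the string as an argument so it can run without TPU hardware; on a real host you would pass `jax.devices()[0].device_kind`.

```python
import re

# Assumption: TPU device kinds start with "TPU v<N>", e.g. "TPU v4", "TPU v5p".
_TPU_KIND_PATTERN = re.compile(r"TPU v(\d+)")


def tpu_generation_from_kind(device_kind: str) -> int:
    """Hypothetical helper: parse the generation number from a kind string."""
    match = _TPU_KIND_PATTERN.match(device_kind)
    if match is None:
        # Mirrors the documented NotImplementedError for non-TPU/unknown kinds.
        raise NotImplementedError(f"Not a recognized TPU device kind: {device_kind!r}")
    return int(match.group(1))
```

Only the leading digits after `v` are captured, so variant suffixes such as `e`, `p`, or `lite` map to the same generation number.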
- ejkernel.kernels._pallas.tpu.grouped_matmul._utils.tpu_kind() → str
Query the identification string of the currently attached TPU.
- Returns
TPU device kind string, e.g., “TPU v4”, “TPU v5e”, “TPU v5p”.
- Return type
str
Example
>>> tpu_kind() 'TPU v4'
Note
This function returns the raw device kind string from JAX, which includes the TPU generation and variant information.
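Since the Note says this is the raw device kind string from JAX, a minimal sketch is a one-line query of the first visible device. `tpu_kind_sketch` is a hypothetical name, and reading device 0 (rather than some other device) is an assumption.

```python
import jax


def tpu_kind_sketch() -> str:
    """Return the raw device kind string of the first visible JAX device."""
    # On a TPU host this is something like "TPU v4"; on other backends it
    # is whatever JAX reports for that device (not necessarily a TPU string).
    return jax.devices()[0].device_kind
```

No validation happens here; parsing and TPU checks are left to callers such as `tpu_generation()`.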