ejkernel.kernels._pallas.tpu.grouped_matmul._utils

Utility functions for TPU device detection and dtype validation.

This module provides helper functions that detect TPU hardware capabilities (device kind and generation) and validate data types for grouped matrix multiplication.

ejkernel.kernels._pallas.tpu.grouped_matmul._utils.assert_is_supported_dtype(dtype: Union[str, type[Any], dtype, SupportsDType]) → None

Validate that a dtype is supported for grouped matrix multiplication.

The grouped matmul kernels accept only bfloat16 and float32 inputs, the two dtypes they are optimized for on TPU hardware. Any other dtype is rejected because of TPU MXU constraints and the kernels' optimization requirements.

Parameters

dtype – JAX dtype to validate.

Raises

ValueError – If dtype is not bfloat16 or float32.

Example

>>> import jax.numpy as jnp
>>> assert_is_supported_dtype(jnp.float32)
>>> assert_is_supported_dtype(jnp.bfloat16)
>>> assert_is_supported_dtype(jnp.float64)
Traceback (most recent call last):
    ...
ValueError: ...

Note

  • bfloat16: Preferred for TPU v4+, offers 2x throughput vs float32

  • float32: Universal support, higher precision

  • Other dtypes (int8, float64, etc.) are not supported
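The check described above amounts to normalizing the argument to a dtype and comparing it against the two allowed types. The following is a minimal sketch under that assumption; check_grouped_matmul_dtype is a hypothetical name, not part of ejkernel, and the error message is illustrative rather than the library's actual text.

```python
import jax.numpy as jnp

# Hypothetical re-implementation for illustration only; not ejkernel's source.
_SUPPORTED_DTYPES = (jnp.dtype(jnp.bfloat16), jnp.dtype(jnp.float32))


def check_grouped_matmul_dtype(dtype) -> None:
    """Raise ValueError unless `dtype` normalizes to bfloat16 or float32."""
    # jnp.dtype accepts scalar types such as jnp.float32 as well as strings like "float32".
    normalized = jnp.dtype(dtype)
    if normalized not in _SUPPORTED_DTYPES:
        raise ValueError(f"Expected bfloat16 or float32 but got {normalized}.")


check_grouped_matmul_dtype(jnp.bfloat16)  # passes silently
check_grouped_matmul_dtype("float32")     # strings are normalized too
try:
    check_grouped_matmul_dtype(jnp.float64)
except ValueError as err:
    print(err)
```

Normalizing through jnp.dtype first means the same comparison works whether callers pass a scalar type, a string, or a dtype object.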

ejkernel.kernels._pallas.tpu.grouped_matmul._utils.is_tpu() → bool

Check if the current JAX backend is running on TPU.

Returns

True if running on TPU hardware, False otherwise.

Return type

bool

Example

>>> if is_tpu():
...     print("Running on TPU")
... else:
...     print("Not running on TPU")
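One way to implement such a check is to compare the name of JAX's default backend against "tpu". This is a sketch only: running_on_tpu is a hypothetical helper, and ejkernel's is_tpu may instead inspect the device kind of the first device.

```python
import jax


def running_on_tpu() -> bool:
    """Hypothetical helper: True when JAX's default backend is a TPU."""
    # jax.default_backend() returns the platform name, e.g. "cpu", "gpu" or "tpu".
    return jax.default_backend() == "tpu"


print("Running on TPU" if running_on_tpu() else "Not running on TPU")
```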

ejkernel.kernels._pallas.tpu.grouped_matmul._utils.select_input_dtype(lhs: Array, rhs: Array) → dtype

Select the optimal dtype for matrix multiplication inputs.

Determines the best dtype to use for the dot product computation based on:

  1. Hardware capabilities (TPU generation)

  2. Input tensor dtypes

  3. Performance considerations

The function does not cast anything itself: it returns a dtype to which both inputs can be cast so that the matmul maximizes performance while maintaining numerical stability.

Parameters

  • lhs – Left-hand side matrix for multiplication.

  • rhs – Right-hand side matrix for multiplication.

Returns

The dtype to which both inputs should be cast before the matrix multiplication: either jnp.bfloat16 or jnp.float32.

Return type

jnp.dtype

Example

>>> lhs = jnp.ones((10, 20), dtype=jnp.bfloat16)
>>> rhs = jnp.ones((20, 30), dtype=jnp.bfloat16)
>>> select_input_dtype(lhs, rhs)
dtype('bfloat16')
>>> lhs = jnp.ones((10, 20), dtype=jnp.float32)
>>> rhs = jnp.ones((20, 30), dtype=jnp.bfloat16)
>>> select_input_dtype(lhs, rhs)
dtype('float32')

Note

  • Uses bfloat16 only if both of the following hold:

      1. The hardware supports it (TPU v4+, or CPU/GPU).

      2. Both inputs are already bfloat16.

  • Falls back to float32 for mixed precision or older TPUs

  • This ensures optimal performance without unexpected precision loss
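The selection rule in the notes above can be sketched as a small function followed by the cast it implies. This assumes the hardware check is passed in as a boolean (in the module it would come from supports_bfloat16_matmul()); pick_matmul_dtype is a hypothetical helper, not ejkernel's implementation.

```python
import jax.numpy as jnp


def pick_matmul_dtype(lhs, rhs, bf16_matmul_supported: bool):
    """Hypothetical helper illustrating the documented selection rule."""
    # bfloat16 only when the hardware allows it AND both operands are already bfloat16.
    if bf16_matmul_supported and lhs.dtype == jnp.bfloat16 and rhs.dtype == jnp.bfloat16:
        return jnp.dtype(jnp.bfloat16)
    # Mixed precision or an older TPU: use float32 so no operand loses precision.
    return jnp.dtype(jnp.float32)


lhs = jnp.ones((10, 20), dtype=jnp.float32)
rhs = jnp.ones((20, 30), dtype=jnp.bfloat16)
dtype = pick_matmul_dtype(lhs, rhs, bf16_matmul_supported=True)
# Cast both operands to the selected dtype before the dot product.
out = jnp.dot(lhs.astype(dtype), rhs.astype(dtype))
print(dtype, out.shape)  # float32 (10, 30)
```

Promoting the bfloat16 operand to float32, rather than downcasting the float32 one, is what avoids the unexpected precision loss mentioned above.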

ejkernel.kernels._pallas.tpu.grouped_matmul._utils.supports_bfloat16_matmul() → bool

Check if the current device supports bfloat16 matrix multiplication.

TPU v4 and later generations have native bfloat16 support in their matrix multiplication units (MXUs), providing significant performance benefits over float32 while maintaining numerical stability for many deep learning workloads.

Returns

True if the device supports efficient bfloat16 matmul operations, False otherwise.

Return type

bool

Example

>>> if supports_bfloat16_matmul():
...     dtype = jnp.bfloat16
... else:
...     dtype = jnp.float32

Note

  • Returns True for non-TPU devices (CPU/GPU) as they typically support bfloat16 operations, though potentially without hardware acceleration.

  • Returns True for TPU v4 and later (native MXU support).

  • Returns False for TPU v2/v3 (limited bfloat16 support).
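Expressed as a truth table over the platform and TPU generation, the rule in the notes above looks like the following sketch. The function name and its platform/tpu_generation parameters are hypothetical; the platform strings follow the values of JAX's Device.platform ("cpu", "gpu", "tpu").

```python
from typing import Optional


def bf16_matmul_supported(platform: str, tpu_generation: Optional[int] = None) -> bool:
    """Hypothetical helper mirroring the documented supports_bfloat16_matmul() rule."""
    # Non-TPU platforms ("cpu", "gpu") report True, even without dedicated hardware.
    if platform != "tpu":
        return True
    if tpu_generation is None:
        raise ValueError("tpu_generation is required when platform is 'tpu'")
    # TPU v4 and later -> True; TPU v2/v3 -> False.
    return tpu_generation >= 4


for platform, gen in [("cpu", None), ("gpu", None), ("tpu", 3), ("tpu", 4), ("tpu", 5)]:
    print(platform, gen, bf16_matmul_supported(platform, gen))
```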

ejkernel.kernels._pallas.tpu.grouped_matmul._utils.tpu_generation() → int

Extract the generation number of the currently attached TPU.

Parses the TPU device kind string to extract the generation number (e.g., 4 for TPU v4, 5 for TPU v5e/v5p).

Returns

TPU generation number (e.g., 2, 3, 4, 5).

Return type

int

Raises

NotImplementedError – If the device is not a TPU or if the TPU version string format is unrecognized.

Example

>>> tpu_generation()  # on a TPU v4 host
4
>>> tpu_generation()  # on a TPU v5e or v5p host
5

Note

TPU generations have different capabilities:

  • TPU v2/v3: Limited bfloat16 support

  • TPU v4+: Full bfloat16 matmul support

  • TPU v5: Enhanced performance and memory bandwidth
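Parsing the generation out of the device-kind string can be sketched with a regular expression. This assumes the string starts with "TPU v<N>", as in the "TPU v4", "TPU v5e" and "TPU v5p" examples documented for tpu_kind(); generation_from_kind is a hypothetical helper that takes the string as an argument so it can run without TPU hardware.

```python
import re

# Assumed format: device-kind strings begin with "TPU v<N>", optionally followed by a variant.
_TPU_KIND_PATTERN = re.compile(r"TPU v(\d+)")


def generation_from_kind(device_kind: str) -> int:
    """Hypothetical helper: extract N from a "TPU vN..." device-kind string."""
    match = _TPU_KIND_PATTERN.match(device_kind)
    if match is None:
        # Mirrors the documented behavior for non-TPU or unrecognized strings.
        raise NotImplementedError(f"Unrecognized TPU device kind: {device_kind!r}")
    return int(match.group(1))


for kind in ("TPU v2", "TPU v4", "TPU v5e", "TPU v5p"):
    print(kind, "->", generation_from_kind(kind))
```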

ejkernel.kernels._pallas.tpu.grouped_matmul._utils.tpu_kind() → str

Query the identification string of the currently attached TPU.

Returns

TPU device kind string, e.g., “TPU v4”, “TPU v5e”, “TPU v5p”.

Return type

str

Example

>>> tpu_kind()
'TPU v4'

Note

This function returns the raw device kind string from JAX, which includes the TPU generation and variant information.
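The raw string comes from the device_kind attribute that every JAX Device exposes. A minimal sketch, assuming tpu_kind() reads the first entry of jax.devices(); current_device_kind is a hypothetical name. On a non-TPU host the same attribute holds a non-TPU string (for example "cpu"), so the string alone does not guarantee that a TPU is attached.

```python
import jax


def current_device_kind() -> str:
    """Hypothetical helper: device-kind string of the first visible JAX device."""
    return jax.devices()[0].device_kind


# List every visible device with its platform ("cpu", "gpu", "tpu") and kind string.
for device in jax.devices():
    print(device.id, device.platform, device.device_kind)

print(current_device_kind())
```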