ejkernel.callib._pallas_call

Buffered Pallas call utilities for TPU kernel execution.

This module provides utilities for creating optimized Pallas kernel calls on TPU with advanced features like double/triple buffering, scalar prefetching, and pipeline emission for hiding memory latency.

Key Features:
  • Scalar memory prefetching from SMEM for fast grid parameter access

  • Configurable input buffering (double/triple) to hide memory latency

  • Automatic pipeline scheduling with emit_pipeline

  • Proper memory space mapping (SMEM, HBM) for TPU memory hierarchy
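To make the buffering idea concrete, the following is a minimal pure-Python sketch of the event order in a multi-buffered loop. It is illustrative only and is not how this module is implemented: the function name `simulate_buffered_loop` and its log strings are hypothetical, and real TPU copies are asynchronous DMAs rather than logged events.

```python
def simulate_buffered_loop(num_blocks, buffer_count=2):
    """Return the fetch/compute event order for a software-pipelined loop.

    Hypothetical illustration: each "fetch" is only logged, so the schedule
    is visible without any hardware.
    """
    if buffer_count < 1:
        raise ValueError("buffer_count must be >= 1")
    events = []
    # Prologue: fill all but one slot before the first compute step.
    for block in range(min(buffer_count - 1, num_blocks)):
        events.append(f"fetch block {block} -> slot {block % buffer_count}")
    for step in range(num_blocks):
        nxt = step + buffer_count - 1
        if nxt < num_blocks:
            # Issue the next fetch before computing so the two can overlap.
            # The target slot held block (step - 1), which is already consumed.
            events.append(f"fetch block {nxt} -> slot {nxt % buffer_count}")
        events.append(f"compute block {step} from slot {step % buffer_count}")
    return events


for event in simulate_buffered_loop(num_blocks=3, buffer_count=2):
    print(event)
```

With `buffer_count=2`, the fetch of block 1 is issued before block 0 is computed, which is the overlap that double buffering relies on. A larger `buffer_count` issues fetches further ahead at the cost of more buffer slots.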

Functions:

  • buffered_pallas_call: Create a buffered Pallas call with a custom prefetch configuration

Example

>>> grid_spec = pltpu.PrefetchScalarGridSpec(
...     num_scalar_prefetch=2,
...     in_specs=[lhs_spec, rhs_spec],
...     out_specs=out_spec,
...     grid=(tiles_n, tiles_m, tiles_k),
... )
>>> call_fn = buffered_pallas_call(kernel, out_shape, grid_spec, compiler_params)
>>> result = call_fn(group_metadata, grid_metadata, lhs, rhs)
ejkernel.callib._pallas_call.buffered_pallas_call(kernel: Callable[[...], Any], out_shape: ShapeDtypeStruct, grid_spec: PrefetchScalarGridSpec, compiler_params: CompilerParams, input_buffer_count: collections.abc.Sequence[int] | None = None, **kw)

Create a buffered Pallas call for TPU with custom prefetch and pipeline configuration.

This function wraps a Pallas kernel with TPU-specific optimizations, including:

  • Scalar memory prefetching for grid parameters

  • Input buffering for hiding memory latency

  • Pipeline scheduling with emit_pipeline for overlapping compute and memory ops

  • Proper memory space mapping (SMEM, HBM) for the TPU memory hierarchy

The wrapper handles the complexity of:

  1. Binding scalar prefetch values from SMEM to kernel index maps

  2. Configuring input buffer counts for double/triple buffering

  3. Setting up proper memory spaces for different argument types

  4. Integrating with TPU’s emit_pipeline for efficient execution
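The first of these steps, binding scalar prefetch values to index maps, can be pictured as partial application. In Pallas, an index map used with PrefetchScalarGridSpec receives the grid indices followed by the scalar prefetch refs. The sketch below assumes that binding means closing over those refs so the resulting map takes grid indices only. The helper `bind_scalar_prefetch`, the metadata lists, and `lhs_index_map` are all hypothetical stand-ins, with plain Python lists in place of SMEM refs.

```python
def bind_scalar_prefetch(index_map, scalar_refs):
    """Return an index map that takes grid indices only.

    Hypothetical helper: closes over the scalar prefetch refs so that they
    are appended after the grid indices on every call.
    """
    def bound(*grid_indices):
        return index_map(*grid_indices, *scalar_refs)
    return bound


# Plain lists standing in for scalar prefetch refs held in SMEM.
group_metadata = [0, 2, 5]
grid_metadata = [0, 0, 1]


def lhs_index_map(n_i, m_i, k_i, group_meta_ref, grid_meta_ref):
    # Pick the LHS row block from prefetched metadata, the column block from k.
    return (grid_meta_ref[m_i], k_i)


bound_lhs = bind_scalar_prefetch(lhs_index_map, (group_metadata, grid_metadata))
print(bound_lhs(0, 2, 4))  # (1, 4)
```

Once bound, the map has the grid-indices-only signature that a pipeline stage without scalar prefetch support would expect.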

Parameters
  • kernel – The Pallas kernel function to execute. It should accept, in order: scalar prefetch refs (from SMEM), input/output refs (the data to process), and scratch refs (temporary workspace).

  • out_shape – Shape and dtype specification for the output tensor(s).

  • grid_spec – PrefetchScalarGridSpec defining:

      - num_scalar_prefetch: Number of scalar values to prefetch

      - in_specs: BlockSpec list for input tensors

      - out_specs: BlockSpec for output tensor(s)

      - grid: Tuple of grid dimensions (can be static ints or dynamic arrays)

      - scratch_shapes: Workspace memory shapes

  • compiler_params – TPU compiler parameters including dimension_semantics for specifying parallelization strategy.

  • input_buffer_count – Optional sequence specifying the buffer count for each input. Its length must match the length of in_specs. Values greater than 2 enable multi-buffering (for example, 3 for triple buffering). Defaults to 2 (double buffering) for every input.

  • **kw – Additional keyword arguments passed through to pl.pallas_call.

Returns

A callable that executes the buffered kernel when invoked with:

  • Scalar prefetch arguments

  • Regular input arguments
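The examples on this page pass the scalar prefetch arguments first and the regular inputs after them. As a minimal sketch of that calling convention, the hypothetical helper below (not part of this module) splits a positional argument tuple using `num_scalar_prefetch`.

```python
def split_call_args(args, num_scalar_prefetch):
    """Split positional args into (scalar_prefetch_args, input_args).

    Hypothetical helper illustrating the argument order only.
    """
    if len(args) < num_scalar_prefetch:
        raise ValueError(
            f"expected at least {num_scalar_prefetch} scalar prefetch "
            f"arguments, got {len(args)} arguments in total"
        )
    return tuple(args[:num_scalar_prefetch]), tuple(args[num_scalar_prefetch:])


scalars, inputs = split_call_args(
    ("group_metadata", "grid_metadata", "lhs", "rhs"), num_scalar_prefetch=2
)
print(scalars)  # ('group_metadata', 'grid_metadata')
print(inputs)   # ('lhs', 'rhs')
```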

Raises

ValueError – If input_buffer_count length doesn’t match in_specs length.
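The length check and the default of 2 described above can be sketched as follows. This is an assumption about the shape of the validation, not the module's actual code: `resolve_buffer_counts` and `DEFAULT_BUFFER_COUNT` are hypothetical names, and the error message is illustrative.

```python
from collections.abc import Sequence
from typing import Optional

DEFAULT_BUFFER_COUNT = 2  # double buffering, per the documented default


def resolve_buffer_counts(
    num_inputs: int,
    input_buffer_count: Optional[Sequence[int]] = None,
) -> list:
    """Return one buffer count per input, validating the length.

    Hypothetical helper mirroring the documented behaviour.
    """
    if input_buffer_count is None:
        return [DEFAULT_BUFFER_COUNT] * num_inputs
    if len(input_buffer_count) != num_inputs:
        raise ValueError(
            f"input_buffer_count has {len(input_buffer_count)} entries, "
            f"but in_specs has {num_inputs}"
        )
    return list(input_buffer_count)


print(resolve_buffer_counts(2))          # [2, 2]
print(resolve_buffer_counts(2, [2, 3]))  # [2, 3]
```

Passing `[2, 3, 3]` for two inputs would raise ValueError in this sketch, matching the documented failure mode.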

Example

>>> def my_kernel(group_meta, grid_meta, lhs_ref, rhs_ref, out_ref, scratch):
...     # Kernel body omitted.
...     pass
>>> grid_spec = pltpu.PrefetchScalarGridSpec(
...     num_scalar_prefetch=2,
...     in_specs=[lhs_spec, rhs_spec],
...     out_specs=out_spec,
...     grid=(tiles_n, tiles_m, tiles_k),
...     scratch_shapes=[pltpu.VMEM((128, 128), jnp.float32)]
... )
>>> call_fn = buffered_pallas_call(
...     my_kernel,
...     out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
...     grid_spec=grid_spec,
...     compiler_params=pltpu.CompilerParams(...),
...     input_buffer_count=[2, 3]
... )
>>> result = call_fn(group_metadata, grid_metadata, lhs, rhs)

Notes

  • Scalar prefetch values are stored in SMEM for fast access

  • Input/output data uses HBM (high-bandwidth memory)

  • Buffer counts > 2 increase memory usage but can hide more latency

  • The wrapper automatically handles grid dimension binding
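The memory cost of a higher buffer count can be estimated with simple arithmetic. The sketch below assumes each buffer holds exactly one block of the input and that the per-input footprint is the block size times the buffer count; it ignores padding and tiling effects. The helper `buffered_input_bytes` is hypothetical.

```python
import math


def buffered_input_bytes(block_shape, itemsize, buffer_count=2):
    """Estimate buffer memory for one input: block bytes times buffer count.

    Assumes one block per buffer and no padding (a simplification).
    """
    return math.prod(block_shape) * itemsize * buffer_count


# A (128, 128) float32 block is 128 * 128 * 4 = 65536 bytes.
double_buffered = buffered_input_bytes((128, 128), itemsize=4, buffer_count=2)
triple_buffered = buffered_input_bytes((128, 128), itemsize=4, buffer_count=3)
print(double_buffered)  # 131072
print(triple_buffered)  # 196608
```

Under these assumptions, moving one input from double to triple buffering adds one extra block of memory, here 64 KiB.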