ejkernel.callib._pallas_call
Buffered Pallas call utilities for TPU kernel execution.
This module provides utilities for creating optimized Pallas kernel calls on TPU with advanced features like double/triple buffering, scalar prefetching, and pipeline emission for hiding memory latency.
Key Features:
- Scalar memory prefetching from SMEM for fast grid parameter access
- Configurable input buffering (double/triple) to hide memory latency
- Automatic pipeline scheduling with emit_pipeline
- Proper memory space mapping (SMEM, HBM) for the TPU memory hierarchy

Functions:
- buffered_pallas_call: Create a buffered Pallas call with custom prefetch config
Example
>>> grid_spec = pltpu.PrefetchScalarGridSpec(
... num_scalar_prefetch=2,
... in_specs=[lhs_spec, rhs_spec],
... out_specs=out_spec,
... grid=(tiles_n, tiles_m, tiles_k),
... )
>>> call_fn = buffered_pallas_call(kernel, out_shape, grid_spec, compiler_params)
>>> result = call_fn(group_metadata, grid_metadata, lhs, rhs)
ejkernel.callib._pallas_call.buffered_pallas_call(kernel: Callable[[...], Any], out_shape: ShapeDtypeStruct, grid_spec: PrefetchScalarGridSpec, compiler_params: CompilerParams, input_buffer_count: collections.abc.Sequence[int] | None = None, **kw)
Create a buffered Pallas call for TPU with custom prefetch and pipeline configuration.
This function wraps a Pallas kernel with TPU-specific optimizations, including:
- Scalar memory prefetching for grid parameters
- Input buffering for hiding memory latency
- Pipeline scheduling with emit_pipeline for overlapping compute and memory ops
- Proper memory space mapping (SMEM, HBM) for the TPU memory hierarchy
The wrapper handles the complexity of:
1. Binding scalar prefetch values from SMEM to kernel index maps
2. Configuring input buffer counts for double/triple buffering
3. Setting up proper memory spaces for different argument types
4. Integrating with TPU’s emit_pipeline for efficient execution
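To make the effect of input buffering concrete, here is a minimal pure-Python timing model. It is an illustrative sketch, not how emit_pipeline is implemented: the function `pipeline_time` and its assumptions (a single copy engine, a single compute unit, fixed per-tile fetch and compute durations) are hypothetical and chosen only to show why overlapping fetch and compute hides latency.

```python
def pipeline_time(num_tiles, fetch_time, compute_time, buffer_count):
    """Total time to process num_tiles under a simple buffering model.

    Hypothetical model: one copy engine, one compute unit, and fixed
    per-tile durations. Real hardware scheduling is more involved.
    """
    if buffer_count < 1:
        raise ValueError("buffer_count must be at least 1")
    fetch_done = []    # time at which tile i has fully arrived
    compute_done = []  # time at which tile i has been consumed
    for i in range(num_tiles):
        # The copy engine is busy until the previous fetch finishes.
        start = fetch_done[i - 1] if i > 0 else 0
        # Tile i reuses the slot of tile i - buffer_count, so that
        # slot is free only once the earlier tile has been consumed.
        if i >= buffer_count:
            start = max(start, compute_done[i - buffer_count])
        fetch_done.append(start + fetch_time)
        # Compute needs both the tile's data and a free compute unit.
        prev_compute = compute_done[i - 1] if i > 0 else 0
        compute_done.append(max(fetch_done[i], prev_compute) + compute_time)
    return compute_done[-1] if compute_done else 0


single = pipeline_time(num_tiles=4, fetch_time=3, compute_time=3, buffer_count=1)
double = pipeline_time(num_tiles=4, fetch_time=3, compute_time=3, buffer_count=2)
print(single, double)
```

With these inputs the single-buffer schedule serializes every fetch and compute (24 time units), while the double-buffered schedule pays the fetch latency only once up front (15 time units).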
- Parameters
kernel – The Pallas kernel function to execute. Should accept:
  - Scalar prefetch refs (from SMEM)
  - Input/output refs (data to process)
  - Scratch refs (temporary workspace)
out_shape – Shape and dtype specification for the output tensor(s).
grid_spec – PrefetchScalarGridSpec defining:
  - num_scalar_prefetch: Number of scalar values to prefetch
  - in_specs: BlockSpec list for input tensors
  - out_specs: BlockSpec for output tensor(s)
  - grid: Tuple of grid dimensions (can be static ints or dynamic arrays)
  - scratch_shapes: Workspace memory shapes
compiler_params – TPU compiler parameters including dimension_semantics for specifying parallelization strategy.
input_buffer_count – Optional sequence giving the buffer count for each input; its length must match that of in_specs. Values greater than 2 enable multi-buffering (for example, 3 for triple buffering). Defaults to 2 (double buffering) for every input.
**kw – Additional keyword arguments passed through to pl.pallas_call.
- Returns
A callable that executes the buffered kernel when invoked with:
  - Scalar prefetch arguments
  - Regular input arguments
- Raises
ValueError – If input_buffer_count length doesn’t match in_specs length.
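The defaulting and length check described for input_buffer_count can be sketched as follows. This is not the library's code: `resolve_buffer_counts` and `DEFAULT_BUFFER_COUNT` are hypothetical names, and the sketch takes the number of inputs directly rather than a real in_specs list so that it runs without JAX.

```python
DEFAULT_BUFFER_COUNT = 2  # double buffering, the documented default


def resolve_buffer_counts(num_inputs, input_buffer_count=None):
    """Return one buffer count per input (hypothetical helper).

    num_inputs stands in for len(grid_spec.in_specs).
    """
    if input_buffer_count is None:
        # No override given: every input is double-buffered.
        return [DEFAULT_BUFFER_COUNT] * num_inputs
    counts = list(input_buffer_count)
    if len(counts) != num_inputs:
        raise ValueError(
            f"input_buffer_count has {len(counts)} entries, "
            f"but there are {num_inputs} input specs"
        )
    return counts


print(resolve_buffer_counts(2))          # default for two inputs
print(resolve_buffer_counts(2, [2, 3]))  # triple-buffer the second input
```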
Example
>>> def my_kernel(group_meta, grid_meta, lhs_ref, rhs_ref, out_ref, scratch):
...     pass
>>> grid_spec = pltpu.PrefetchScalarGridSpec(
...     num_scalar_prefetch=2,
...     in_specs=[lhs_spec, rhs_spec],
...     out_specs=out_spec,
...     grid=(tiles_n, tiles_m, tiles_k),
...     scratch_shapes=[pltpu.VMEM((128, 128), jnp.float32)]
... )
>>> call_fn = buffered_pallas_call(
...     my_kernel,
...     out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
...     grid_spec=grid_spec,
...     compiler_params=pltpu.CompilerParams(...),
...     input_buffer_count=[2, 3]
... )
>>> result = call_fn(group_metadata, grid_metadata, lhs, rhs)
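The call in the example passes the scalar prefetch arguments first and the regular inputs after them. A small sketch of that convention is shown below. `split_call_args` is a hypothetical helper, not part of ejkernel, and strings stand in for the actual arrays.

```python
def split_call_args(args, num_scalar_prefetch):
    """Split positional args into (scalar_prefetch, inputs).

    Hypothetical helper mirroring the documented calling convention:
    the first num_scalar_prefetch arguments are scalar prefetch values,
    the rest are regular inputs.
    """
    if not 0 <= num_scalar_prefetch <= len(args):
        raise ValueError("num_scalar_prefetch is out of range for the given args")
    scalars = tuple(args[:num_scalar_prefetch])
    inputs = tuple(args[num_scalar_prefetch:])
    return scalars, inputs


# Placeholder strings instead of arrays, matching the example call above.
scalars, inputs = split_call_args(
    ("group_metadata", "grid_metadata", "lhs", "rhs"), num_scalar_prefetch=2
)
print(scalars)
print(inputs)
```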
Notes
- Scalar prefetch values are stored in SMEM for fast access
- Input/output data uses HBM (high-bandwidth memory)
- Buffer counts > 2 increase memory usage but can hide more latency
- The wrapper automatically handles grid dimension binding
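As a rough guide to the memory cost mentioned in the notes, the sketch below estimates the buffer footprint for one input, assuming each buffer holds exactly one block of that input. `buffered_bytes` is a hypothetical helper; the actual allocation made by the wrapper may differ (for example through padding or alignment).

```python
import math


def buffered_bytes(block_shape, itemsize, buffer_count=2):
    """Estimated bytes for buffer_count copies of one block.

    Assumes each buffer stores one full block with no padding.
    """
    return math.prod(block_shape) * itemsize * buffer_count


# A (128, 128) float32 block (4 bytes per element).
double = buffered_bytes((128, 128), itemsize=4, buffer_count=2)
triple = buffered_bytes((128, 128), itemsize=4, buffer_count=3)
print(double, triple)
```

Under this assumption, going from double to triple buffering adds one extra block's worth of memory per input (64 KiB for the block above).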