ejkernel.callib._cute_call
CuTe DSL kernel integration helpers for JAX.
The wrapper mirrors the placement checks used by ejkernel.callib._triton_call.
It enforces that all array arguments are on one device, and when multiple
accelerators are present it requires execution under jax.shard_map.
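The placement rule above can be sketched in plain Python. This is a minimal illustration, not ejkernel's implementation: `check_placement`, its parameters, and the integer device ids are hypothetical stand-ins for whatever the wrapper inspects internally.

```python
from typing import Iterable


def check_placement(
    arg_device_ids: Iterable[int],
    num_accelerators: int,
    inside_shard_map: bool,
) -> None:
    """Hypothetical stand-in for the wrapper's placement checks."""
    devices = set(arg_device_ids)
    # Rule 1: every array argument must live on the same device.
    assert len(devices) <= 1, (
        f"array arguments span multiple devices: {sorted(devices)}"
    )
    # Rule 2: with more than one accelerator visible, the call must
    # run under jax.shard_map.
    assert num_accelerators <= 1 or inside_shard_map, (
        "multiple accelerators detected; run under jax.shard_map"
    )


# Single accelerator, all arguments on device 0: passes.
check_placement([0, 0, 0], num_accelerators=1, inside_shard_map=False)
# Four accelerators, but the call runs inside shard_map: passes.
check_placement([1, 1], num_accelerators=4, inside_shard_map=True)
```

Both failure modes surface as AssertionError in this sketch, matching the exception type documented for cute_call.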
- ejkernel.callib._cute_call.cute_call(*args: Any, call: collections.abc.Callable[[...], Any] | None = None, out_shape: ejkernel.callib._utils.ShapeDtype | collections.abc.Sequence[ejkernel.callib._utils.ShapeDtype] | None = None, out: Any | None = None, name: str | None = None, device: int | None = None) -> Any
Execute a CuTe DSL kernel and return its output(s).
The callable is expected to return output arrays directly. Callers can pass
out_shape or out to define/validate the expected output contract.
The provided call must be primitive-backed and return output arrays. out is accepted for compatibility and treated as an output metadata contract (shape/dtype/tree), not as a destination buffer.
- Parameters
*args – Positional arguments forwarded to the CuTe kernel callable.
call – The CuTe kernel callable to execute. Must be a primitive-backed function that returns output arrays (e.g. from build_cute_ffi_call).
out_shape – Expected output shape/dtype specification(s) used to build the output contract. Can be a single ShapeDtype or a sequence.
out – Optional explicit output array(s) whose shape/dtype/tree are used as the output contract. Not used as a destination buffer.
name – Optional name for the kernel call, used for JAX named scopes and internal caching.
device – Optional device index to validate input placement against.
- Returns
The output(s) produced by the CuTe kernel, unflattened to match the out or out_shape pytree structure.
- Raises
ValueError – If CuTe is not installed, call is None, neither out nor out_shape is provided, or the callable returns None.
AssertionError – If array arguments span multiple devices, or if multiple accelerators are detected without an active jax.shard_map context.
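The output contract and the documented ValueError cases can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: `ShapeDtypeSketch`, `FakeArray`, and `run_with_contract` are hypothetical names, only a single (non-sequence) output is handled, and the choice to prefer `out_shape` over `out` when both are given is a guess. The CuTe-installed check is omitted.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple


@dataclass(frozen=True)
class ShapeDtypeSketch:
    """Hypothetical stand-in for ejkernel.callib._utils.ShapeDtype."""
    shape: Tuple[int, ...]
    dtype: str


@dataclass
class FakeArray:
    """Hypothetical array stand-in carrying shape, dtype, and data."""
    shape: Tuple[int, ...]
    dtype: str
    data: list


def run_with_contract(
    *args: Any,
    call: Optional[Callable[..., Any]] = None,
    out_shape: Optional[ShapeDtypeSketch] = None,
    out: Optional[FakeArray] = None,
) -> Tuple[Any, ShapeDtypeSketch]:
    if call is None:
        raise ValueError("call must be provided")
    if out is None and out_shape is None:
        raise ValueError("either out or out_shape must be provided")
    # `out` contributes metadata only; it is never written to.
    # Assumption: out_shape takes precedence when both are given.
    if out_shape is not None:
        contract = out_shape
    else:
        contract = ShapeDtypeSketch(out.shape, out.dtype)
    result = call(*args)
    if result is None:
        raise ValueError("call returned None; it must return output arrays")
    return result, contract


def double(x: FakeArray) -> FakeArray:
    """Toy 'kernel' that returns a new array instead of writing in place."""
    return FakeArray(x.shape, x.dtype, [2 * v for v in x.data])


x = FakeArray((3,), "float32", [1.0, 2.0, 3.0])
placeholder = FakeArray((3,), "float32", [0.0, 0.0, 0.0])
result, contract = run_with_contract(x, call=double, out=placeholder)
```

Here `placeholder` is left unchanged after the call, which mirrors the documented behavior that out describes the output rather than receiving it.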