ejkernel.callib._cute_call

CuTe DSL kernel integration helpers for JAX.

The wrapper mirrors the placement checks used by ejkernel.callib._triton_call. It enforces that all array arguments are on one device, and when multiple accelerators are present it requires execution under jax.shard_map.
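The placement rules above can be sketched in plain Python as follows. This is a hypothetical illustration, not the wrapper's code. The helper name check_placement and its accelerator_count and in_shard_map arguments are assumptions. It also assumes that each array argument exposes a devices() method returning a set of devices, as jax.Array does. The shard_map context is passed in as an explicit flag because how the real wrapper detects it is not documented here.

```python
from typing import Any, Iterable, Set


def check_placement(args: Iterable[Any], accelerator_count: int,
                    in_shard_map: bool) -> Set[Any]:
    """Hypothetical sketch of the documented placement rules."""
    devices: Set[Any] = set()
    for arg in args:
        # Only array-like arguments (objects with a devices() method, as on
        # jax.Array) are inspected; scalars and other values are skipped.
        if hasattr(arg, "devices"):
            devices |= set(arg.devices())

    # Rule 1: all array arguments must reside on one device.
    if len(devices) > 1:
        raise AssertionError(f"array arguments span multiple devices: {devices}")

    # Rule 2: with several accelerators, the call must run under jax.shard_map.
    # `in_shard_map` is supplied by the caller in this sketch (an assumption).
    if accelerator_count > 1 and not in_shard_map:
        raise AssertionError("multiple accelerators detected; call inside jax.shard_map")

    return devices
```

With real JAX arrays, accelerator_count might come from jax.device_count(), though the wrapper's exact source for that number is not stated here.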

ejkernel.callib._cute_call.cute_call(*args: Any, call: collections.abc.Callable[[...], Any] | None = None, out_shape: ejkernel.callib._utils.ShapeDtype | collections.abc.Sequence[ejkernel.callib._utils.ShapeDtype] | None = None, out: Any | None = None, name: str | None = None, device: int | None = None) → Any

Execute a CuTe DSL kernel and return its output(s).

The provided call must be primitive-backed and must return its output arrays directly. Callers define the expected output contract by passing out_shape or out, and the results are validated against it. out is accepted for compatibility and is treated only as output metadata (shape, dtype, and pytree structure), not as a destination buffer.
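One way to picture the output contract is as a flat list of (shape, dtype) pairs read from out_shape or out. The sketch below is a hypothetical illustration, not ejkernel's implementation. Spec stands in for ShapeDtype, only tuples and lists are treated as containers, and dtypes are compared as strings so that it runs without JAX. When out is given, only its metadata is read and nothing is written into it.

```python
from collections import namedtuple
from typing import Any, List, Tuple

# Hypothetical stand-in for ejkernel.callib._utils.ShapeDtype: any object
# exposing .shape and .dtype would do.
Spec = namedtuple("Spec", ["shape", "dtype"])

Contract = List[Tuple[tuple, str]]


def _leaves(tree: Any) -> List[Any]:
    # Flatten nested tuples/lists (a small subset of JAX pytrees). Spec is itself
    # a tuple subclass, so it is checked first and treated as a leaf.
    if isinstance(tree, Spec) or not isinstance(tree, (tuple, list)):
        return [tree]
    return [leaf for sub in tree for leaf in _leaves(sub)]


def build_contract(out_shape: Any = None, out: Any = None) -> Contract:
    if out is None and out_shape is None:
        raise ValueError("either out or out_shape must be provided")
    # Assumption: when both are given, this sketch reads the metadata of `out`.
    source = out if out is not None else out_shape
    # Only .shape and .dtype are read; `out` is never used as a write target.
    return [(tuple(leaf.shape), str(leaf.dtype)) for leaf in _leaves(source)]


def check_outputs(results: Any, contract: Contract) -> None:
    if results is None:
        raise ValueError("kernel callable returned None")
    got = [(tuple(r.shape), str(r.dtype)) for r in _leaves(results)]
    # Raising on a mismatch is this sketch's choice, not documented behavior.
    if got != contract:
        raise ValueError(f"output mismatch: expected {contract}, got {got}")
```

The missing-contract and None-result errors mirror the documented ValueError cases. The mismatch check is an illustrative assumption.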

Parameters
  • *args – Positional arguments forwarded to the CuTe kernel callable.

  • call – The CuTe kernel callable to execute. Must be a primitive-backed function that returns its output arrays, such as one produced by build_cute_ffi_call.

  • out_shape – Expected output shape/dtype specification(s) used to build the output contract. Can be a single ShapeDtype or a sequence.

  • out – Optional explicit output array(s) whose shape/dtype/tree are used as the output contract. Not used as a destination buffer.

  • name – Optional name for the kernel call, used for JAX named scopes and internal caching.

  • device – Optional device index to validate input placement against.

Returns

The output(s) produced by the CuTe kernel, unflattened to match the out or out_shape pytree structure.
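The minimal helper below illustrates the unflattening step. It is hypothetical and restores tuple/list nesting from a flat list of outputs. In JAX itself, the same effect would typically come from jax.tree_util.tree_structure followed by jax.tree_util.tree_unflatten. This pure-Python version only shows the ordering: leaves are filled depth-first, left to right. It runs without JAX.

```python
from typing import Any, List

_END = object()  # sentinel marking an exhausted iterator


def unflatten_like(template: Any, flat: List[Any]) -> Any:
    """Rebuild flat kernel outputs into the tuple/list structure of `template`."""
    it = iter(flat)

    def build(node: Any) -> Any:
        # Exact tuple/list types are containers (namedtuples such as shape
        # specs count as leaves); every other node marks one output slot.
        if type(node) in (tuple, list):
            return type(node)([build(child) for child in node])
        value = next(it, _END)
        if value is _END:
            raise ValueError("fewer outputs than leaves in the template")
        return value

    result = build(template)
    # Every flat output must be consumed exactly once.
    if next(it, _END) is not _END:
        raise ValueError("more outputs than leaves in the template")
    return result
```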

Raises
  • ValueError – If CuTe is not installed, call is None, neither out nor out_shape is provided, or the callable returns None.

  • AssertionError – If array arguments span multiple devices, or if multiple accelerators are detected without an active jax.shard_map context.