ejkernel.callib._cute_ffi
JAX primitive integration for CuTe DSL kernels via TVM-FFI.
This module provides a Triton-style primitive: abstract evaluation is driven by the caller-supplied output shape contract (output shapes and dtypes are declared rather than inferred from the kernel), and execution is lowered through JAX FFI targets registered by jax_tvm_ffi.
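The shape-contract idea can be illustrated with the standard jax.pure_callback, which also takes its result shape and dtype up front instead of inferring them. This is an analogy under that assumption, not how this module's primitive is implemented: jax.eval_shape returns the declared output without running the opaque function, and only a concrete call (here under jax.jit) executes it. The names host_kernel and opaque_double are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Declared output contract: tracing trusts this shape/dtype instead of
# inspecting the kernel (analogous to output_shape_dtype).
OUT = jax.ShapeDtypeStruct((4,), jnp.float32)


def host_kernel(x):
    # Hypothetical stand-in for an opaque kernel JAX cannot look inside.
    return np.asarray(x, dtype=np.float32) * 2.0


def opaque_double(x):
    # pure_callback receives the result shape/dtype explicitly.
    return jax.pure_callback(host_kernel, OUT, x)


x = jnp.arange(4, dtype=jnp.float32)

# Abstract evaluation only: host_kernel is not called here.
abstract_out = jax.eval_shape(opaque_double, x)
print(abstract_out.shape, abstract_out.dtype)

# Concrete execution under jit actually runs the callback.
y = jax.jit(opaque_double)(x)
print(y)
```

The same separation is what lets a kernel whose internals JAX cannot trace still participate in jit, vmap-free shape inference, and buffer planning.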
- ejkernel.callib._cute_ffi.build_cute_ffi_call(fn: Any, *, output_shape_dtype: Any, input_output_aliases: dict[int, int] | None = None, compile_options: str | None = '--enable-tvm-ffi', **static_kwargs: Any)
Create a callable that dispatches a CuTe kernel through a JAX primitive.
- Parameters
fn – @cute.jit launcher callable.
output_shape_dtype – Output shape/dtype descriptor pytree.
input_output_aliases – Optional alias map from flattened input index to flattened output index.
compile_options – Option string passed to cute.compile (default '--enable-tvm-ffi').
**static_kwargs – Static keyword arguments forwarded at compile time.
- Returns
A callable that accepts runtime input arrays and returns output arrays.
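Because input_output_aliases is keyed by flattened indices, it helps to see which index each pytree entry receives. The sketch below assumes output_shape_dtype is a pytree of jax.ShapeDtypeStruct leaves (as with jax-triton's out_shape); the dict keys, shapes, and the alias map are hypothetical. It uses only jax.tree_util to list flat positions and checks that each aliased pair agrees in shape and dtype, which XLA buffer aliasing generally requires.

```python
import jax
import jax.numpy as jnp

# Hypothetical output descriptor (assumes ShapeDtypeStruct leaves are accepted).
output_shape_dtype = {
    "out": jax.ShapeDtypeStruct((8, 128), jnp.bfloat16),
    "lse": jax.ShapeDtypeStruct((8,), jnp.float32),
}

# JAX flattens dicts in sorted key order, so "lse" becomes flat output 0
# and "out" becomes flat output 1, regardless of insertion order.
flat_outputs, out_treedef = jax.tree_util.tree_flatten(output_shape_dtype)
for i, leaf in enumerate(flat_outputs):
    print("output", i, leaf.shape, leaf.dtype)

# Inputs passed as a tuple keep their positional order when flattened.
inputs = (jnp.zeros((8, 128), jnp.bfloat16), jnp.zeros((8,), jnp.float32))
flat_inputs = jax.tree_util.tree_leaves(inputs)

# Hypothetical alias map: reuse flat input 1's buffer for flat output 0 ("lse").
input_output_aliases = {1: 0}
for i_in, i_out in input_output_aliases.items():
    # Aliased buffers must match in shape and dtype.
    assert flat_inputs[i_in].shape == flat_outputs[i_out].shape
    assert flat_inputs[i_in].dtype == flat_outputs[i_out].dtype
```

Writing {0: 1} here would instead pair the (8, 128) input with the sorted-first "lse" output and fail the check, which is an easy mistake when dict keys are read in insertion order.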