ejkernel.callib._cute_ffi


JAX primitive integration for CuTe DSL kernels via TVM-FFI.

This module provides a Triton-style JAX primitive. Abstract evaluation derives output shapes and dtypes from a declared output shape/dtype contract rather than by running the kernel. Lowering dispatches execution through JAX FFI targets registered by jax_tvm_ffi.
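To show what "abstract evaluation from an output contract" means, the sketch below derives output metadata purely from a declared descriptor pytree, so shape inference never launches the kernel. It is plain Python. `ShapeDtype` is a hypothetical stand-in for `jax.ShapeDtypeStruct`, and `flatten_descriptors` and `abstract_eval` are illustrative names that are not part of ejkernel. The flattening order (sequences in order, dict keys sorted) follows the order `jax.tree_util` uses for these containers.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ShapeDtype:
    """Hypothetical stand-in for jax.ShapeDtypeStruct."""
    shape: tuple[int, ...]
    dtype: str


def flatten_descriptors(tree: Any) -> list[ShapeDtype]:
    """Flatten tuples/lists in order and dicts by sorted key, as jax.tree_util does."""
    if isinstance(tree, ShapeDtype):
        return [tree]
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in flatten_descriptors(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in flatten_descriptors(item)]
    raise TypeError(f"unsupported descriptor node: {type(tree).__name__}")


def abstract_eval(*input_avals: ShapeDtype, output_shape_dtype: Any) -> list[ShapeDtype]:
    """Hypothetical abstract-eval rule: outputs come from the contract, not the kernel."""
    # Input avals are accepted but not used to infer outputs; no kernel is executed.
    del input_avals
    return flatten_descriptors(output_shape_dtype)


# Example: an attention-style kernel declaring two outputs.
q = ShapeDtype((2, 128, 64), "float16")
contract = {"out": ShapeDtype((2, 128, 64), "float16"), "lse": ShapeDtype((2, 128), "float32")}
avals = abstract_eval(q, output_shape_dtype=contract)
# "lse" sorts before "out", so it becomes flattened output 0.
print([a.shape for a in avals])  # [(2, 128), (2, 128, 64)]
```

The flattened positions produced this way are the indices that an output-side alias map refers to.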

ejkernel.callib._cute_ffi.build_cute_ffi_call(fn: Any, *, output_shape_dtype: Any, input_output_aliases: dict[int, int] | None = None, compile_options: str | None = '--enable-tvm-ffi', **static_kwargs: Any)

Create a callable that dispatches a CuTe kernel through a JAX primitive.

Parameters
  • fn – A @cute.jit launcher callable.

  • output_shape_dtype – Output shape/dtype descriptor pytree.

  • input_output_aliases – Optional map from a flattened input index to the flattened output index that shares its buffer.

  • compile_options – Option string passed to cute.compile, or None. Defaults to '--enable-tvm-ffi'.

  • **static_kwargs – Static keyword arguments forwarded at compile time.

Returns

A callable that accepts runtime input arrays and returns the output arrays described by output_shape_dtype.
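Because input_output_aliases is keyed by flattened positions, it can help to check an alias map before building the call. The sketch below is a hypothetical validator that is not part of ejkernel. It assumes the usual XLA aliasing constraints: each output is aliased at most once, and an aliased input and output agree in shape and dtype. `ShapeDtype` again stands in for `jax.ShapeDtypeStruct`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShapeDtype:
    """Hypothetical stand-in for jax.ShapeDtypeStruct."""
    shape: tuple[int, ...]
    dtype: str


def validate_aliases(
    inputs: list[ShapeDtype],
    outputs: list[ShapeDtype],
    aliases: dict[int, int],
) -> None:
    """Raise ValueError if a flattened input -> output alias map is inconsistent."""
    seen_outputs: set[int] = set()
    for in_idx, out_idx in aliases.items():
        if not 0 <= in_idx < len(inputs):
            raise ValueError(f"input index {in_idx} out of range")
        if not 0 <= out_idx < len(outputs):
            raise ValueError(f"output index {out_idx} out of range")
        if out_idx in seen_outputs:
            raise ValueError(f"output {out_idx} is aliased more than once")
        seen_outputs.add(out_idx)
        # An aliased buffer is reused in place, so shape and dtype must match exactly.
        if inputs[in_idx] != outputs[out_idx]:
            raise ValueError(
                f"input {in_idx} {inputs[in_idx]} does not match output {out_idx} {outputs[out_idx]}"
            )


# In-place update: flattened input 1 (an accumulator) is written back as output 0.
x = ShapeDtype((1024,), "float32")
acc = ShapeDtype((1024,), "float32")
validate_aliases([x, acc], [acc], {1: 0})  # passes silently
```

With the real builder, this situation would correspond to passing input_output_aliases={1: 0}. This reference does not state whether build_cute_ffi_call performs checks like these itself.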

ejkernel.callib._cute_ffi.has_cute_ffi_support() → bool

Return whether the CuTe TVM-FFI primitive path can be used.
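A capability probe like this is typically used to choose an implementation once and fall back to a reference path otherwise. The sketch below shows one plausible probe. It assumes that support reduces to the `cutlass` (CuTe DSL) and `jax_tvm_ffi` packages being importable, though the real function may also check GPU availability or versions. `probe_cute_ffi_support`, `choose_impl`, `cute_add`, and `reference_add` are hypothetical names.

```python
import importlib.util
from typing import Callable


def probe_cute_ffi_support(modules: tuple[str, ...] = ("cutlass", "jax_tvm_ffi")) -> bool:
    """Hypothetical probe: True only if every required top-level package is importable."""
    # find_spec locates a top-level module without importing it.
    return all(importlib.util.find_spec(name) is not None for name in modules)


def choose_impl(
    supported: bool,
    cute_impl: Callable[..., object],
    reference_impl: Callable[..., object],
) -> Callable[..., object]:
    """Select the CuTe path when available, otherwise a reference implementation."""
    return cute_impl if supported else reference_impl


def cute_add(a, b):
    # Placeholder for a callable returned by build_cute_ffi_call.
    raise RuntimeError("CuTe path is not wired up in this sketch")


def reference_add(a, b):
    # Portable fallback; real code would typically use jax.numpy here.
    return [x + y for x, y in zip(a, b)]


# In ejkernel code, has_cute_ffi_support() would supply the flag instead of the probe.
impl = choose_impl(probe_cute_ffi_support(), cute_add, reference_add)
```

Resolving the choice once, outside any traced function, keeps the capability check out of the hot path.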