# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX primitive integration for CuTe DSL kernels via TVM-FFI.
This module provides a Triton-style primitive that performs abstract evaluation
from output shape contracts and lowers execution through JAX FFI targets
registered by ``jax_tvm_ffi``.
"""
from __future__ import annotations
import functools
import hashlib
import inspect
import threading
from dataclasses import dataclass
from functools import partial
from typing import Any
import jax
import jax.extend as jex
import jax.numpy as jnp
from jax import tree_util
from jax.interpreters import ad, batching, mlir, xla
# Optional dependency probes. Each flag records whether the corresponding
# package imported cleanly, so the rest of the module can report the CuTe
# path as unavailable instead of failing at import time.
_HAS_CUDA_BINDINGS = False
try:
    import cuda.bindings.driver as cuda

    _HAS_CUDA_BINDINGS = True
except Exception:
    cuda = None

_HAS_JAX_TVM_FFI = False
try:
    import jax_tvm_ffi

    _HAS_JAX_TVM_FFI = True
except Exception:
    jax_tvm_ffi = None

CAN_USE_CUTE_PRIMITIVE = False
try:
    import cutlass
    import cutlass.cute as cute

    CAN_USE_CUTE_PRIMITIVE = True
except ImportError:
    # ``ImportError`` (not only ``ModuleNotFoundError``) so that an installed
    # ``cutlass`` whose ``cutlass.cute`` submodule (or one of its imports)
    # cannot be loaded leaves the flag False instead of breaking this module.
    pass

if CAN_USE_CUTE_PRIMITIVE:
    # JAX dtype -> CuTe DSL numeric type used when building compile-time fake
    # tensors. Dtypes missing from this table are rejected with a TypeError by
    # ``_fake_tensor_from_shaped``.
    _DTYPE_TO_CUTLASS: dict[jnp.dtype, type[cutlass.Numeric]] = {
        jnp.dtype(jnp.float16): cutlass.Float16,
        jnp.dtype(jnp.bfloat16): cutlass.BFloat16,
        jnp.dtype(jnp.float32): cutlass.Float32,
        jnp.dtype(jnp.int8): cutlass.Int8,
        jnp.dtype(jnp.uint8): cutlass.Uint8,
        jnp.dtype(jnp.int32): cutlass.Int32,
        jnp.dtype(jnp.uint32): cutlass.Uint32,
    }
else:  # pragma: no cover
    _DTYPE_TO_CUTLASS = {}
@dataclass(frozen=True)
class _CompiledKernel:
"""Container for a compiled CuTe callable and its FFI target name.
Attributes:
target_name: The unique FFI target name used to register the kernel
with JAX via TVM-FFI.
compiled: The compiled CuTe callable object produced by
``cute.compile``.
"""
target_name: str
compiled: Any
# Process-wide cache of compiled kernels, keyed by the tuple assembled in
# ``_compile_or_get_kernel``; lookups and inserts happen under ``_COMPILE_LOCK``.
_COMPILE_LOCK = threading.Lock()
_COMPILE_CACHE: dict[tuple[Any, ...], _CompiledKernel] = {}

# FFI target names already registered via ``jax_tvm_ffi``; membership checks
# and inserts happen under ``_REGISTERED_TARGETS_LOCK``.
_REGISTERED_TARGETS_LOCK = threading.Lock()
_REGISTERED_TARGETS: set[str] = set()
def _to_shape_dtype_struct(out_shape: Any) -> Any:
"""Normalize output descriptors into ``jax.ShapeDtypeStruct`` leaves.
Args:
out_shape: A pytree of objects with ``shape`` and ``dtype`` attributes.
Returns:
A pytree with the same structure where each leaf is replaced by a
``jax.ShapeDtypeStruct``.
"""
return tree_util.tree_map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), out_shape)
def _shape_dtype_key(shaped: Any) -> tuple[int, jnp.dtype]:
"""Build a stable compile-cache key fragment from a shaped value.
Args:
shaped: An object with ``shape`` and ``dtype`` attributes.
Returns:
A tuple of ``(rank, dtype)`` suitable for use as a cache key component.
"""
return (len(tuple(shaped.shape)), jnp.dtype(shaped.dtype))
def _fake_tensor_from_shaped(shaped: Any):
    """Build a compile-time fake CuTe tensor describing ``shaped``.

    Every dimension is a fresh ``cute.sym_int()``, so the tensor carries the
    rank and element type of ``shaped`` but not its concrete extents. The
    ``stride_order`` is ``(rank - 1, ..., 0)`` (presumably C-contiguous /
    row-major, matching JAX's default layout -- confirm against CuTe docs).

    ``assumed_align=16`` (16 bytes = 128 bits) lets kernels that issue 128-bit
    ``cp.async`` copies via ``CopyG2SOp`` pass MLIR alignment verification.
    NOTE(review): this relies on every device buffer XLA hands to the kernel
    being at least 16-byte aligned; an earlier docstring cited a 128-byte XLA
    guarantee -- confirm, especially for aliased buffers.

    Args:
        shaped: Object exposing ``shape`` and ``dtype``.

    Returns:
        The fake tensor produced by ``cute.runtime.make_fake_compact_tensor``.

    Raises:
        TypeError: If the dtype has no entry in ``_DTYPE_TO_CUTLASS``.
    """
    np_dtype = jnp.dtype(shaped.dtype)
    element_type = _DTYPE_TO_CUTLASS.get(np_dtype)
    if element_type is None:
        raise TypeError(f"Unsupported dtype for CuTe primitive path: {np_dtype}")
    ndim = len(tuple(shaped.shape))
    return cute.runtime.make_fake_compact_tensor(
        element_type,
        tuple(cute.sym_int() for _ in range(ndim)),
        stride_order=tuple(reversed(range(ndim))),
        assumed_align=16,
    )
def _fn_expects_stream(fn: Any) -> bool:
"""Return whether the compiled host launcher expects a leading stream arg.
Inspects the function's signature to determine if the first positional
parameter is named ``stream``.
Args:
fn: A callable whose signature will be inspected.
Returns:
``True`` if the first positional parameter is named ``stream``,
``False`` otherwise or if the signature cannot be inspected.
"""
try:
params = list(inspect.signature(fn).parameters.values())
except Exception:
return False
if not params:
return False
first = params[0]
if first.kind not in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
return False
return first.name == "stream"
def _make_fake_stream() -> Any:
    """Return a stream placeholder for compiling TVM-FFI CuTe kernels.

    The preferred source is ``cute.runtime.make_fake_stream`` with
    ``use_tvm_ffi_env_stream=True``. If that call raises for any reason, a
    null ``cuda.CUstream(0)`` from ``cuda.bindings`` is returned instead.

    Returns:
        The fake stream from ``cute.runtime`` or a null CUDA stream.

    Raises:
        RuntimeError: If the ``cute.runtime`` call fails and ``cuda.bindings``
            is not importable.
    """
    try:
        placeholder = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
    except Exception:
        # Swallowed on purpose: fall back to the CUDA bindings below.
        pass
    else:
        return placeholder
    if not _HAS_CUDA_BINDINGS:
        raise RuntimeError("Unable to create a CuTe fake stream for TVM-FFI compilation.")
    return cuda.CUstream(0)
def _cache_key_hash(cache_key: tuple[Any, ...]) -> str:
"""Build a deterministic hash string from a compile cache key.
Args:
cache_key: A tuple of values representing the compilation parameters.
Returns:
A SHA-256 hex-digest string derived from the cache key's ``repr``.
"""
return hashlib.sha256(repr(cache_key).encode("utf-8")).hexdigest()
def _compile_or_get_kernel(
    *,
    fn: Any,
    in_shaped: tuple[Any, ...],
    out_shaped: tuple[jax.ShapeDtypeStruct, ...],
    compile_options: str | None,
    static_kwargs: tuple[tuple[str, Any], ...],
) -> _CompiledKernel:
    """Compile (or fetch cached) CuTe callable and FFI target metadata.

    The cache key combines the launcher object, the ``(rank, dtype)`` of every
    input and output, the compile options and the static kwargs. Concrete
    extents are not part of the key because the fake tensors use symbolic
    dimensions. On a miss the launcher is compiled against fake tensors while
    ``_COMPILE_LOCK`` is held, so concurrent misses are serialized.

    Args:
        fn: The ``@cute.jit`` launcher callable to compile.
        in_shaped: Tuple of input shaped objects (with ``shape`` and ``dtype``).
        out_shaped: Tuple of ``jax.ShapeDtypeStruct`` for expected outputs.
        compile_options: Optional options string forwarded to ``cute.compile``.
        static_kwargs: Tuple of ``(name, value)`` pairs for static keyword
            arguments passed at compile time.

    Returns:
        A ``_CompiledKernel`` holding the compiled callable and an FFI target
        name that differs from the name of every other cache entry.

    Raises:
        TypeError: If an input/output dtype is unsupported (raised by
            ``_fake_tensor_from_shaped``).
        Exception: Whatever ``cute.compile`` raises when both the
            with-stream and without-stream attempts fail; the first
            attempt's error is attached as the implicit context.
    """
    cache_key = (
        fn,
        tuple(_shape_dtype_key(arg) for arg in in_shaped),
        tuple(_shape_dtype_key(arg) for arg in out_shaped),
        compile_options,
        static_kwargs,
    )
    with _COMPILE_LOCK:
        cached = _COMPILE_CACHE.get(cache_key)
        if cached is not None:
            return cached

        in_fake = [_fake_tensor_from_shaped(arg) for arg in in_shaped]
        out_fake = [_fake_tensor_from_shaped(arg) for arg in out_shaped]
        expects_stream = _fn_expects_stream(fn)
        compile_kwargs = dict(static_kwargs)

        def _compile(add_stream: bool):
            """Run ``cute.compile`` with or without a leading stream argument.

            Args:
                add_stream: If ``True``, prepend a fake stream to the
                    compile arguments.

            Returns:
                The compiled CuTe callable.
            """
            compile_args: list[Any] = []
            if add_stream:
                compile_args.append(_make_fake_stream())
            compile_args.extend(in_fake)
            compile_args.extend(out_fake)
            if compile_options:
                return cute.compile(fn, *compile_args, options=compile_options, **compile_kwargs)
            return cute.compile(fn, *compile_args, **compile_kwargs)

        try:
            compiled = _compile(expects_stream)
        except Exception:
            # Some wrapped launchers do not expose a stable python signature.
            # Retry with the flipped stream assumption before surfacing the error.
            compiled = _compile(not expects_stream)

        digest = _cache_key_hash(cache_key)[:24]
        # ``digest`` hashes ``repr(cache_key)``, and ``repr`` is not injective:
        # distinct keys (e.g. static kwargs with identical reprs) can share a
        # digest. A shared target name would make ``_register_target_once``
        # skip the second registration and dispatch the wrong kernel. Nothing
        # in this module evicts cache entries, so the current cache size is a
        # per-process sequence number that keeps every target name unique.
        target_name = f"ejkernel_cute_tvm_ffi_{digest}_{len(_COMPILE_CACHE)}"
        result = _CompiledKernel(target_name=target_name, compiled=compiled)
        _COMPILE_CACHE[cache_key] = result
        return result
def _register_target_once(kernel: _CompiledKernel) -> None:
    """Register a compiled CuTe callable as a JAX FFI target exactly once.

    Registration goes through ``jax_tvm_ffi.register_ffi_target``. Several
    platform spellings are tried in order (``"gpu"``, ``"cuda"``, ``"CUDA"``,
    then no ``platform`` argument at all); the first attempt that does not
    raise wins. Registered names are remembered in ``_REGISTERED_TARGETS`` so
    later calls for the same target are no-ops.

    Args:
        kernel: The ``_CompiledKernel`` to register.

    Raises:
        ValueError: If ``jax_tvm_ffi`` is not installed.
        RuntimeError: If every registration attempt fails. The message lists
            the error from each attempt, and the last one is chained as the
            ``__cause__``.
    """
    if not _HAS_JAX_TVM_FFI:
        raise ValueError("CuTe primitive path requires `jax_tvm_ffi` (apache-tvm-ffi) to register TVM-FFI targets.")
    # ``None`` means "omit the ``platform`` keyword" and use the library default.
    platforms: tuple[str | None, ...] = ("gpu", "cuda", "CUDA", None)
    failures: list[tuple[str | None, Exception]] = []
    with _REGISTERED_TARGETS_LOCK:
        if kernel.target_name in _REGISTERED_TARGETS:
            return
        for platform in platforms:
            platform_kwargs = {} if platform is None else {"platform": platform}
            try:
                jax_tvm_ffi.register_ffi_target(kernel.target_name, kernel.compiled, **platform_kwargs)
            except Exception as exc:  # pragma: no cover - exercised only in incompatible runtime envs
                # Record every failure so the final error reports all of them,
                # not just the last attempt.
                failures.append((platform, exc))
                continue
            _REGISTERED_TARGETS.add(kernel.target_name)
            return
    attempts = "; ".join(
        f"platform={'<default>' if platform is None else repr(platform)}: {type(exc).__name__}: {exc}"
        for platform, exc in failures
    )
    raise RuntimeError(
        f"Failed to register CuTe TVM-FFI target `{kernel.target_name}`. Attempts: {attempts}"
    ) from failures[-1][1]
def _cute_kernel_call_impl(
    *args_flat,
    fn: Any,
    out_shape_dtype_flat: tuple[jax.ShapeDtypeStruct, ...],
    input_output_aliases: tuple[tuple[int, int], ...],
    compile_options: str | None,
    static_kwargs: tuple[tuple[str, Any], ...],
):
    """Compile, register and invoke a CuTe kernel through ``jax.ffi.ffi_call``.

    This function is used as the primitive's lowering (via
    ``mlir.lower_fun``), so it runs with traced/abstract inputs; only their
    ``shape``/``dtype`` are consulted for compilation.

    Args:
        *args_flat: Flattened input arrays for the kernel.
        fn: The ``@cute.jit`` launcher callable.
        out_shape_dtype_flat: ``jax.ShapeDtypeStruct`` per flattened output.
        input_output_aliases: ``(input_idx, output_idx)`` pairs, converted to
            the dict form ``ffi_call`` expects.
        compile_options: Optional options string for ``cute.compile``.
        static_kwargs: ``(name, value)`` pairs forwarded at compile time.

    Returns:
        The outputs of the FFI call.

    Raises:
        ValueError: If the CUTLASS CuTe DSL could not be imported.
    """
    if not CAN_USE_CUTE_PRIMITIVE:
        raise ValueError("CuTe primitive path requires CUTLASS CuTe DSL.")
    compiled_kernel = _compile_or_get_kernel(
        fn=fn,
        in_shaped=tuple(args_flat),
        out_shaped=out_shape_dtype_flat,
        compile_options=compile_options,
        static_kwargs=static_kwargs,
    )
    _register_target_once(compiled_kernel)
    launcher = jax.ffi.ffi_call(
        compiled_kernel.target_name,
        result_shape_dtypes=out_shape_dtype_flat,
        input_output_aliases=dict(input_output_aliases),
    )
    return launcher(*args_flat)
# JAX primitive that carries a CuTe kernel launch through tracing.
cute_kernel_call_p = jex.core.Primitive("ejkernel_cute_kernel_call")
# A bind returns a sequence of outputs, one per flattened output descriptor.
cute_kernel_call_p.multiple_results = True
# Eager (untraced) binds are dispatched through ``xla.apply_primitive``, which
# presumably compiles the primitive via its registered lowering -- confirm.
cute_kernel_call_p.def_impl(partial(xla.apply_primitive, cute_kernel_call_p))
def _cute_kernel_call_abstract_eval(*_, out_shape_dtype_flat, **__):
    """Compute output avals for ``cute_kernel_call_p`` from its output contract.

    Input avals and all other primitive params are ignored; the result is
    determined entirely by ``out_shape_dtype_flat``.

    Args:
        *_: Input avals (unused).
        out_shape_dtype_flat: ``jax.ShapeDtypeStruct`` per flattened output.
        **__: Remaining primitive params (unused).

    Returns:
        A list with one ``jax.core.ShapedArray`` per output descriptor.
    """
    return [jax.core.ShapedArray(spec.shape, spec.dtype) for spec in out_shape_dtype_flat]


cute_kernel_call_p.def_abstract_eval(_cute_kernel_call_abstract_eval)
def _raise_on_jvp(*args, **kwargs):
"""Raise for unsupported automatic differentiation.
Args:
*args: Unused positional arguments.
**kwargs: Unused keyword arguments.
Raises:
NotImplementedError: Always, as the CuTe TVM-FFI primitive does
not support JVP or transpose rules.
"""
del args, kwargs
raise NotImplementedError(
"CuTe TVM-FFI primitive does not support automatic differentiation. Use `jax.custom_jvp` or `jax.custom_vjp`."
)
def _raise_on_vmap(*args, **kwargs):
"""Raise for unsupported batching.
Args:
*args: Unused positional arguments.
**kwargs: Unused keyword arguments.
Raises:
NotImplementedError: Always, as the CuTe TVM-FFI primitive does
not support batching via ``jax.vmap``.
"""
del args, kwargs
raise NotImplementedError(
"CuTe TVM-FFI primitive does not support batching via `jax.vmap`. Use `jax.custom_batching.custom_vmap`."
)
# CUDA lowering: the primitive is lowered by tracing ``_cute_kernel_call_impl``,
# which emits a ``jax.ffi.ffi_call`` to the registered TVM-FFI target.
mlir.register_lowering(
    cute_kernel_call_p,
    mlir.lower_fun(_cute_kernel_call_impl, multiple_results=True),
    platform="cuda",
)

# Differentiation and batching are unsupported; these rules raise and point
# callers to the ``jax.custom_*`` escape hatches.
ad.primitive_jvps[cute_kernel_call_p] = _raise_on_jvp
ad.primitive_transposes[cute_kernel_call_p] = _raise_on_jvp
batching.primitive_batchers[cute_kernel_call_p] = _raise_on_vmap
def build_cute_ffi_call(
    fn: Any,
    *,
    output_shape_dtype: Any,
    input_output_aliases: dict[int, int] | None = None,
    compile_options: str | None = "--enable-tvm-ffi",
    **static_kwargs: Any,
):
    """Create a callable that dispatches a CuTe kernel through a JAX primitive.

    Args:
        fn: ``@cute.jit`` launcher callable.
        output_shape_dtype: Output shape/dtype descriptor pytree; each leaf
            needs ``shape`` and ``dtype`` attributes.
        input_output_aliases: Optional alias map from flattened input index
            to flattened output index.
        compile_options: Optional options passed to ``cute.compile``.
        **static_kwargs: Static keyword arguments forwarded at compile time.
            They become part of the primitive params, so they should be
            hashable.

    Returns:
        A ``jax.jit(inline=True)`` callable that accepts runtime input arrays
        (any pytree, flattened before binding) and returns outputs in the
        structure of ``output_shape_dtype``.
    """
    out_shape = _to_shape_dtype_struct(output_shape_dtype)
    flat_out_shape, out_tree = tree_util.tree_flatten(out_shape)
    out_shape_dtype_flat = tuple(flat_out_shape)
    # Sort by keyword name so the compile-cache key (which contains these
    # items) does not depend on the order callers spelled the kwargs in.
    # ``cute.compile`` receives them as keywords, so order has no other effect.
    static_items = tuple(sorted(static_kwargs.items(), key=lambda item: item[0]))
    alias_items = tuple(sorted((input_output_aliases or {}).items()))

    @partial(jax.jit, inline=True)
    def _call(*args):
        """Dispatch the CuTe kernel through the JAX primitive.

        Args:
            *args: Runtime input arrays (pytree-flattened internally).

        Returns:
            Output pytree matching the ``output_shape_dtype`` structure.
        """
        args_flat, _ = tree_util.tree_flatten(args)
        out_flat = cute_kernel_call_p.bind(
            *args_flat,
            fn=fn,
            out_shape_dtype_flat=out_shape_dtype_flat,
            input_output_aliases=alias_items,
            compile_options=compile_options,
            static_kwargs=static_items,
        )
        return tree_util.tree_unflatten(out_tree, out_flat)

    return _call
def has_cute_ffi_support() -> bool:
    """Return whether the CuTe TVM-FFI primitive path can be used.

    Both the CUTLASS CuTe DSL (needed to compile kernels) and ``jax_tvm_ffi``
    (needed to register them as JAX FFI targets) must have imported
    successfully when this module was loaded.
    """
    return bool(CAN_USE_CUTE_PRIMITIVE and _HAS_JAX_TVM_FFI)


__all__ = ["build_cute_ffi_call", "has_cute_ffi_support"]