ejkernel.callib._triton_call#
Triton kernel integration with JAX for GPU computation.
This module provides the interface for calling Triton kernels from JAX code, enabling high-performance GPU computation with Triton’s programming model while maintaining JAX’s functional semantics.
- Key Features:
Seamless Triton kernel invocation from JAX
CUDA and ROCm (AMD) GPU support
Automatic kernel compilation and caching
Support for Triton autotuner and heuristics
Input-output aliasing for memory efficiency
Configurable launch parameters (warps, stages, CTAs)
- Key Components:
triton_call: Main entry point for executing Triton kernels from JAX
get_triton_type: Convert JAX/NumPy types to Triton type strings
CompilationResult: Container for compiled kernel binary and metadata
- Supported Platforms:
CUDA: NVIDIA GPUs via PTX compilation
ROCm: AMD GPUs via HSACO compilation
- Type Conversions:
The module handles automatic type conversion between JAX and Triton:
JAX arrays -> Triton pointers (*bf16, *fp32, *i32, etc.)
Python scalars -> Triton scalars (i32, fp32, etc.)
NumPy arrays -> Triton pointers
Example
>>> import jax
>>> import jax.numpy as jnp
>>> import triton
>>> import triton.language as tl
>>> from ejkernel.callib import triton_call
>>>
>>> @triton.jit
... def add_kernel(x_ptr, y_ptr, n, out_ptr, BLOCK_SIZE: tl.constexpr):
...     pid = tl.program_id(0)
...     offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
...     mask = offsets < n  # guard the final, partially filled block
...     x = tl.load(x_ptr + offsets, mask=mask)
...     y = tl.load(y_ptr + offsets, mask=mask)
...     tl.store(out_ptr + offsets, x + y, mask=mask)
>>>
>>> x = jnp.arange(4096, dtype=jnp.float32)
>>> y = jnp.ones_like(x)
>>> n = x.size
>>> # Scalars are passed positionally with the inputs; output pointers are
>>> # assumed to follow them (jax-triton-style argument ordering).
>>> result = triton_call(
...     x, y, n,
...     kernel=add_kernel,
...     out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
...     grid=lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),),
...     BLOCK_SIZE=1024,
... )
Note
Automatic differentiation (JVP/VJP) and vmap are not natively supported. Use jax.custom_vjp for custom gradient rules and jax.custom_batching.custom_vmap for custom batching rules.
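As a hedged illustration of the note above, the sketch below attaches a custom gradient rule with `jax.custom_vjp`. The function `scaled_add` and its pure-JAX body are hypothetical stand-ins for a wrapper around `triton_call`, chosen so the snippet runs without a GPU. In practice the forward body (and possibly the backward body) would launch a kernel instead.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def scaled_add(x, y):
    # Hypothetical stand-in: a real wrapper would call triton_call here.
    return x + 2.0 * y


def scaled_add_fwd(x, y):
    # No residuals are saved because the gradient of this op is constant.
    return scaled_add(x, y), None


def scaled_add_bwd(residuals, g):
    # d(out)/dx = 1 and d(out)/dy = 2, so scale the cotangent accordingly.
    del residuals
    return g, 2.0 * g


scaled_add.defvjp(scaled_add_fwd, scaled_add_bwd)

x = jnp.arange(4.0)
y = jnp.ones(4)
grad_x, grad_y = jax.grad(
    lambda a, b: scaled_add(a, b).sum(), argnums=(0, 1)
)(x, y)
```

Because the backward rule is written by hand, JAX never needs to differentiate through the opaque kernel call itself.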
- class ejkernel.callib._triton_call.CompilationResult(binary: str, name: str, shared_mem_bytes: int, cluster_dims: tuple, ttgir: str | None, llir: str | None)[source]#
Bases: object
Result of Triton kernel compilation containing binary and metadata.
- binary#
Compiled binary code (PTX for CUDA, HSACO path for ROCm).
- Type
str
- name#
Name of the compiled kernel function.
- Type
str
- shared_mem_bytes#
Amount of shared memory required in bytes.
- Type
int
- cluster_dims#
Cluster dimensions for the kernel launch.
- Type
tuple
- ttgir#
Triton GPU IR representation (optional).
- Type
str | None
- llir#
LLVM IR representation (optional).
- Type
str | None
- binary: str#
- cluster_dims: tuple#
- llir: str | None#
- name: str#
- shared_mem_bytes: int#
- ttgir: str | None#
- ejkernel.callib._triton_call.aval_size_bytes(aval)[source]#
Calculate size in bytes for an abstract value.
- Parameters
aval – Abstract value with dtype and size attributes.
- Returns
Size in bytes as integer.
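A minimal sketch of the computation described above, assuming the size is simply the element count multiplied by the item size of the dtype. The name `aval_size_bytes_sketch` and the `SimpleNamespace` stand-in for an abstract value are hypothetical; the library's own implementation may treat special dtypes differently.

```python
from types import SimpleNamespace

import numpy as np


def aval_size_bytes_sketch(aval):
    # Element count times bytes per element (assumed formula).
    return int(aval.size) * np.dtype(aval.dtype).itemsize


# Stand-in for an abstract value describing a float32 array of shape (4, 8).
aval = SimpleNamespace(shape=(4, 8), size=32, dtype=np.float32)
size = aval_size_bytes_sketch(aval)
```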
- ejkernel.callib._triton_call.avals_to_layouts(avals)[source]#
Convert abstract values to layout specifications.
- Parameters
avals – List of abstract values with ndim attribute.
- Returns
List of layout specifications as reversed dimension ranges.
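The "reversed dimension ranges" described above can be sketched as follows. This is an illustrative reimplementation under the assumption that each layout is the list `[ndim - 1, ..., 0]`, which corresponds to a row-major (minor-to-major) ordering. The name `avals_to_layouts_sketch` and the `SimpleNamespace` stand-ins are hypothetical.

```python
from types import SimpleNamespace


def avals_to_layouts_sketch(avals):
    # One layout per abstract value: dimensions listed from last to first.
    return [list(reversed(range(aval.ndim))) for aval in avals]


layouts = avals_to_layouts_sketch(
    [SimpleNamespace(ndim=1), SimpleNamespace(ndim=3)]
)
```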
- ejkernel.callib._triton_call.compile_ttir_inplace(ttir, backend: [triton.backends.nvidia.compiler.CUDABackend | triton.backends.amd.compiler.HIPBackend], options: [triton.backends.nvidia.compiler.CUDAOptions | triton.backends.amd.compiler.HIPOptions], compute_capability, platform)[source]#
Compile Triton IR to platform-specific binary in-place.
- Parameters
ttir – Triton IR module to compile.
backend – Platform-specific backend (CUDA or HIP).
options – Compilation options for the backend.
compute_capability – Target compute capability.
platform – Target platform (‘cuda’ or ‘rocm’).
- Returns
CompilationResult containing compiled binary and metadata.
- Raises
ValueError – For unsupported platforms.
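The platform dispatch described above might look roughly like the sketch below. The function name and the two compiler callables are hypothetical placeholders for the PTX and HSACO compilation routines; only the `'cuda'` / `'rocm'` strings and the `ValueError` behaviour are taken from the description.

```python
def compile_for_platform_sketch(ttir, platform, compile_ptx, compile_hsaco):
    # Route to the platform-specific compiler, rejecting anything else.
    if platform == "cuda":
        return compile_ptx(ttir)
    if platform == "rocm":
        return compile_hsaco(ttir)
    raise ValueError(f"Unsupported platform: {platform!r}")


# Placeholder compilers that only tag their input, for illustration.
cuda_result = compile_for_platform_sketch(
    "ttir", "cuda", lambda m: ("ptx", m), lambda m: ("hsaco", m)
)
rocm_result = compile_for_platform_sketch(
    "ttir", "rocm", lambda m: ("ptx", m), lambda m: ("hsaco", m)
)
```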
- ejkernel.callib._triton_call.compile_ttir_to_hsaco_inplace(ttir, hip_backend: HIPBackend, hip_options: HIPOptions, compute_capability) CompilationResult[source]#
Compile Triton IR to HSACO binary for AMD ROCm devices.
- Parameters
ttir – Triton IR module to compile.
hip_backend – HIP compilation backend.
hip_options – HIP-specific compilation options.
compute_capability – GPU architecture specification.
- Returns
CompilationResult with HSACO binary path and metadata.
- Raises
ValueError – If compilation passes fail.
- ejkernel.callib._triton_call.compile_ttir_to_ptx_inplace(ttir, cuda_backend: CUDABackend, cuda_options: CUDAOptions, compute_capability) CompilationResult[source]#
Compile Triton IR to PTX binary for CUDA devices.
- Parameters
ttir – Triton IR module to compile.
cuda_backend – CUDA compilation backend.
cuda_options – CUDA-specific compilation options.
compute_capability – CUDA compute capability version.
- Returns
CompilationResult with PTX binary and metadata.
- Raises
ValueError – If compilation passes fail.
- ejkernel.callib._triton_call.get_cuda_backend(device, compute_capability)[source]#
Create CUDA backend for Triton compilation.
- Parameters
device – CUDA device identifier.
compute_capability – CUDA compute capability version.
- Returns
CUDABackend instance configured for the device.
- ejkernel.callib._triton_call.get_hip_backend(device, compute_capability)[source]#
Create HIP backend for Triton compilation on AMD GPUs.
- Parameters
device – HIP device identifier.
compute_capability – GPU architecture specification.
- Returns
HIPBackend instance configured for the device.
- ejkernel.callib._triton_call.get_or_create_triton_kernel(backend_init_func, platform, fn, arg_dtypes, scalar_args, device, *, num_warps, num_stages, num_ctas, compute_capability, enable_fp_fusion, metaparams, dump: bool) tuple[triton_kernel_call_lib.TritonKernel, Any][source]#
Get or create a compiled Triton kernel with caching.
- Parameters
backend_init_func – Function to initialize backend for compilation.
platform – Target platform string.
fn – Triton JIT function to compile.
arg_dtypes – Argument data types.
scalar_args – Scalar argument specifications.
device – Target device identifier.
num_warps – Number of warps per thread block.
num_stages – Number of pipeline stages.
num_ctas – Number of cooperative thread arrays (CTAs, i.e. thread blocks) per cluster.
compute_capability – Target compute capability.
enable_fp_fusion – Whether to enable floating point fusion.
metaparams – Kernel metaparameters.
dump – Whether to dump debug information.
- Returns
Tuple of compiled TritonKernel and specialization attributes.
- Raises
ValueError – For unsupported configurations or compilation errors.
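The get-or-create pattern can be sketched as a dictionary keyed on everything that affects compilation. The cache key layout, the `build` callable, and the name `get_or_create_sketch` are assumptions made for illustration; the actual cache key used by the library is not documented here.

```python
_KERNEL_CACHE = {}


def get_or_create_sketch(build, *, fn_name, arg_dtypes, num_warps,
                         num_stages, num_ctas, compute_capability,
                         metaparams):
    # Every value that changes the generated binary goes into the key.
    key = (
        fn_name,
        tuple(arg_dtypes),
        num_warps,
        num_stages,
        num_ctas,
        compute_capability,
        tuple(sorted(metaparams.items())),
    )
    if key not in _KERNEL_CACHE:
        _KERNEL_CACHE[key] = build()  # compile only on a cache miss
    return _KERNEL_CACHE[key]


build_count = 0


def fake_build():
    # Placeholder for the expensive compilation step.
    global build_count
    build_count += 1
    return f"kernel-{build_count}"


common = dict(fn_name="add_kernel", arg_dtypes=["*fp32", "*fp32", "i32"],
              num_stages=3, num_ctas=1, compute_capability=80,
              metaparams={"BLOCK_SIZE": 1024})
first = get_or_create_sketch(fake_build, num_warps=4, **common)
second = get_or_create_sketch(fake_build, num_warps=4, **common)
third = get_or_create_sketch(fake_build, num_warps=8, **common)
```

Changing any launch parameter, such as `num_warps`, produces a different key and therefore a fresh compilation.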
- ejkernel.callib._triton_call.get_triton_type(obj: Any) str[source]#
Get Triton type string representation for a given object.
- Parameters
obj – Object to get type for (ShapedArray, AbstractRef, constexpr, or scalar).
- Returns
String representation of the Triton type.
- Raises
ValueError – For integer overflow cases.
NotImplementedError – For unsupported object types.
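A simplified sketch of this kind of type mapping is shown below. It covers only NumPy arrays, `np.float32` scalars, and Python integers. The function name, the dtype table, and the integer range rules are assumptions for illustration, so the real function may accept more types and choose different strings.

```python
import numpy as np

# Assumed subset of the dtype-to-Triton-name mapping.
_DTYPE_TO_TRITON = {
    np.dtype("float16"): "fp16",
    np.dtype("float32"): "fp32",
    np.dtype("float64"): "fp64",
    np.dtype("int32"): "i32",
    np.dtype("int64"): "i64",
}


def triton_type_sketch(obj):
    if isinstance(obj, np.ndarray):
        # Arrays are passed by pointer, hence the leading '*'.
        return "*" + _DTYPE_TO_TRITON[obj.dtype]
    if isinstance(obj, np.float32):
        return "fp32"
    if isinstance(obj, int) and not isinstance(obj, bool):
        # Pick the narrowest signed integer type that fits the value.
        if -(2**31) <= obj < 2**31:
            return "i32"
        if -(2**63) <= obj < 2**63:
            return "i64"
        raise ValueError(f"integer overflow representing {obj}")
    raise NotImplementedError(f"unsupported object type: {type(obj)}")


pointer_type = triton_type_sketch(np.zeros(3, dtype=np.float32))
small_int_type = triton_type_sketch(7)
large_int_type = triton_type_sketch(2**40)
```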
- ejkernel.callib._triton_call.normalize_grid(grid: int | tuple[int] | tuple[int, int] | tuple[int, int, int] | collections.abc.Callable[[dict[str, Any]], int | tuple[int] | tuple[int, int] | tuple[int, int, int]], metaparams) tuple[int, int, int][source]#
Normalize grid specification to a 3D tuple.
- Parameters
grid – Grid specification as int, tuple, or callable returning grid.
metaparams – Dictionary of metaparameters for callable grid evaluation.
- Returns
3D tuple representing normalized grid dimensions.
- Raises
ValueError – If grid has more than 3 dimensions.
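The normalization described above can be sketched in a few lines. This is an illustrative reimplementation, not the library's code: it assumes missing trailing dimensions are filled with 1, and `normalize_grid_sketch` is a hypothetical name.

```python
def normalize_grid_sketch(grid, metaparams):
    # A callable grid is evaluated against the metaparameters first.
    if callable(grid):
        grid = grid(metaparams)
    # A bare int is treated as a one-dimensional grid.
    if isinstance(grid, int):
        grid = (grid,)
    grid = tuple(grid)
    if len(grid) > 3:
        raise ValueError("grid must have at most 3 dimensions")
    # Pad with ones so the result always has exactly three entries.
    return grid + (1,) * (3 - len(grid))


from_int = normalize_grid_sketch(8, {})
from_pair = normalize_grid_sketch((4, 2), {})
# Ceiling division of 1000 elements by the block size.
from_callable = normalize_grid_sketch(
    lambda meta: (-(-1000 // meta["BLOCK_SIZE"]),), {"BLOCK_SIZE": 256}
)
```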
- ejkernel.callib._triton_call.triton_call(*args: jax.jaxlib._jax.Array | bool | int | float | numpy.float32, kernel: triton.runtime.jit.JITFunction | triton.runtime.autotuner.Heuristics | triton.runtime.autotuner.Autotuner, out_shape: ejkernel.callib._utils.ShapeDtype | collections.abc.Sequence[ejkernel.callib._utils.ShapeDtype], grid: int | tuple[int] | tuple[int, int] | tuple[int, int, int] | collections.abc.Callable[[dict[str, Any]], int | tuple[int] | tuple[int, int] | tuple[int, int, int]], name: str = '', custom_call_target_name: str = 'triton_kernel_call', num_warps: int | None = None, num_stages: int | None = None, num_ctas: int = 1, device: int = 0, compute_capability: int | None = None, enable_fp_fusion: bool = True, input_output_aliases: dict[int, int] | None = None, zeroed_outputs: collections.abc.Sequence[int] | collections.abc.Callable[[dict[str, Any]], collections.abc.Sequence[int]] = (), debug: bool = False, serialized_metadata: bytes = b'', **metaparams: Any) Any[source]#
Call a Triton kernel from JAX with specified parameters.
This is the main entry point for executing Triton kernels within JAX computations. It handles compilation, optimization, and execution of GPU kernels written in Triton.
- Parameters
*args – Input arguments to the kernel (arrays and scalars).
kernel – Triton kernel function, heuristics, or autotuner to execute.
out_shape – Expected output shape(s) and dtype(s).
grid – Kernel launch grid specification or callable returning grid.
name – Optional name for the kernel call.
custom_call_target_name – Target name for the custom call.
num_warps – Number of warps per thread block (None selects the default of 4).
num_stages – Number of pipeline stages (None selects the default of 3).
num_ctas – Number of cooperative thread arrays (CTAs, i.e. thread blocks) per cluster (default: 1).
device – Target device identifier.
compute_capability – Target compute capability (auto-detected if None).
enable_fp_fusion – Whether to enable floating point fusion.
input_output_aliases – Mapping of input indices to output indices for aliasing.
zeroed_outputs – Indices of outputs to zero-initialize or callable returning them.
debug – Whether to enable debug output during compilation.
serialized_metadata – Additional serialized kernel metadata.
**metaparams – Additional kernel metaparameters.
- Returns
JAX array(s) containing the kernel execution results.
- Raises
ValueError – If Triton is not installed or for invalid configurations.
- ejkernel.callib._triton_call.triton_kernel_call_abstract_eval(*_, out_shapes, **__)[source]#
Abstract evaluation function for triton kernel call primitive.
- Parameters
*_ – Unused positional arguments.
out_shapes – Output shape specifications.
**__ – Unused keyword arguments.
- Returns
List of ShapedArray objects for outputs.
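A minimal sketch of such an abstract-evaluation rule, assuming each entry of `out_shapes` exposes `shape` and `dtype` attributes (as `jax.ShapeDtypeStruct` does). The name `abstract_eval_sketch` is hypothetical and the real rule may perform extra validation.

```python
import jax
import jax.numpy as jnp


def abstract_eval_sketch(*_, out_shapes, **__):
    # Outputs are described purely by shape and dtype; inputs are ignored.
    return [jax.core.ShapedArray(s.shape, s.dtype) for s in out_shapes]


outs = abstract_eval_sketch(
    out_shapes=[jax.ShapeDtypeStruct((4, 8), jnp.float32)]
)
```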
- ejkernel.callib._triton_call.triton_kernel_call_lowering(backend_init_func, ctx, *array_args, fn, scalar_args, name, custom_call_target_name, out_shapes, grid, num_warps, num_stages, num_ctas, device, compute_capability, enable_fp_fusion, input_output_aliases, zeroed_outputs, debug, serialized_metadata, **metaparams)[source]#
Lower Triton kernel call to platform-specific implementation.
This function handles the compilation and lowering of Triton kernels for execution, including autotuning support and platform-specific optimizations.
- Parameters
backend_init_func – Function to initialize the compilation backend.
ctx – Lowering context containing type information.
*array_args – Array arguments to the kernel.
fn – Triton kernel function to execute.
scalar_args – Scalar argument specifications.
name – Name for the kernel call.
custom_call_target_name – Target name for the custom call.
out_shapes – Output tensor shapes.
grid – Kernel launch grid specification.
num_warps – Number of warps per thread block.
num_stages – Number of pipeline stages.
num_ctas – Number of cooperative thread arrays (CTAs, i.e. thread blocks) per cluster.
device – Target device identifier.
compute_capability – Target compute capability.
enable_fp_fusion – Whether to enable floating point fusion.
input_output_aliases – Input-output aliasing specifications.
zeroed_outputs – Outputs to zero-initialize.
debug – Whether to enable debug output.
serialized_metadata – Serialized kernel metadata.
**metaparams – Additional kernel metaparameters.
- Returns
Lowered kernel call rule result.