ejkernel.callib._triton_call

Triton kernel integration with JAX for GPU computation.

This module provides the interface for calling Triton kernels from JAX code, enabling high-performance GPU computation with Triton’s programming model while maintaining JAX’s functional semantics.

Key Features:
  • Seamless Triton kernel invocation from JAX

  • CUDA and ROCm (AMD) GPU support

  • Automatic kernel compilation and caching

  • Support for Triton autotuners and heuristics

  • Input-output aliasing for memory efficiency

  • Configurable launch parameters (warps, stages, CTAs)

Key Components:

  • triton_call: Main entry point for executing Triton kernels from JAX

  • get_triton_type: Convert JAX/NumPy types to Triton type strings

  • CompilationResult: Container for compiled kernel binary and metadata

Supported Platforms:
  • CUDA: NVIDIA GPUs via PTX compilation

  • ROCm: AMD GPUs via HSACO compilation

Type Conversions:

The module handles automatic type conversion between JAX and Triton:
  • JAX arrays -> Triton pointers (*bf16, *fp32, *i32, etc.)

  • Python scalars -> Triton scalars (i32, fp32, etc.)

  • NumPy arrays -> Triton pointers
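The pointer conversion can be pictured as a lookup from a dtype name to a Triton type string, prefixed with `*` because arrays reach the kernel as pointers. This is a hypothetical sketch covering only a few common dtypes; the table and the helper name `pointer_type_for` are illustrative and are not ejkernel's actual mapping, which handles more types.

```python
# Hypothetical subset of a dtype-name -> Triton scalar type-string table.
_TRITON_SCALAR = {
    "bfloat16": "bf16",
    "float16": "fp16",
    "float32": "fp32",
    "float64": "fp64",
    "int8": "i8",
    "int16": "i16",
    "int32": "i32",
    "int64": "i64",
}


def pointer_type_for(dtype_name: str) -> str:
    """Return the Triton pointer type string for an array of `dtype_name`."""
    try:
        # Arrays are passed as pointers, hence the leading "*".
        return "*" + _TRITON_SCALAR[dtype_name]
    except KeyError:
        # Mirrors the documented NotImplementedError for unsupported types.
        raise NotImplementedError(f"unsupported dtype: {dtype_name}") from None


print(pointer_type_for("bfloat16"))  # *bf16
```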

Example

>>> import jax
>>> import jax.numpy as jnp
>>> import triton
>>> import triton.language as tl
>>> from ejkernel.callib import triton_call
>>>
>>> @triton.jit
... def add_kernel(x_ptr, y_ptr, n, out_ptr, BLOCK_SIZE: tl.constexpr):
...     # Inputs (x, y, n) come first; the output pointer is appended last.
...     pid = tl.program_id(0)
...     offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
...     mask = offsets < n  # guard the tail block when n % BLOCK_SIZE != 0
...     x = tl.load(x_ptr + offsets, mask=mask)
...     y = tl.load(y_ptr + offsets, mask=mask)
...     tl.store(out_ptr + offsets, x + y, mask=mask)
>>>
>>> n = 4096
>>> x = jnp.arange(n, dtype=jnp.float32)
>>> y = jnp.ones(n, dtype=jnp.float32)
>>> result = triton_call(
...     x, y, n,
...     kernel=add_kernel,
...     out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
...     grid=lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),),
...     BLOCK_SIZE=1024,
... )

Note

Automatic differentiation (JVP/VJP) and vmap are not natively supported. Wrap the call in jax.custom_vjp (or jax.custom_jvp) to define gradient rules, and in jax.custom_batching.custom_vmap to define batching rules.
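A minimal sketch of the custom_vjp route, assuming an elementwise-add kernel whose gradient is the identity with respect to both inputs. The forward body here is plain jnp so the snippet runs without a GPU; in real code the hypothetical `_add_forward_impl` would return the result of triton_call, and the backward pass could itself be another Triton kernel.

```python
import jax
import jax.numpy as jnp


def _add_forward_impl(x, y):
    # Stand-in for a triton_call(...)-backed forward pass (assumption).
    return x + y


@jax.custom_vjp
def add(x, y):
    return _add_forward_impl(x, y)


def add_fwd(x, y):
    # No residuals are needed: d(x + y)/dx = d(x + y)/dy = 1.
    return _add_forward_impl(x, y), None


def add_bwd(_residuals, g):
    # The output cotangent flows unchanged to both inputs.
    return g, g


add.defvjp(add_fwd, add_bwd)

x = jnp.arange(4.0)
y = jnp.ones(4)
gx, gy = jax.grad(lambda a, b: add(a, b).sum(), argnums=(0, 1))(x, y)
```

Because JAX only sees the custom rule, it never tries to differentiate through the kernel primitive itself.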

class ejkernel.callib._triton_call.CompilationResult(binary: str, name: str, shared_mem_bytes: int, cluster_dims: tuple, ttgir: str | None, llir: str | None)

Bases: object

Result of Triton kernel compilation containing binary and metadata.

binary

Compiled binary code (PTX for CUDA, HSACO path for ROCm).

Type

str

name

Name of the compiled kernel function.

Type

str

shared_mem_bytes

Amount of shared memory required in bytes.

Type

int

cluster_dims

Cluster dimensions for the kernel launch.

Type

tuple

ttgir

Triton GPU IR representation (optional).

Type

str | None

llir

LLVM IR representation (optional).

Type

str | None

ejkernel.callib._triton_call.aval_size_bytes(aval)

Calculate size in bytes for an abstract value.

Parameters

aval – Abstract value with dtype and size attributes.

Returns

Size in bytes as integer.
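The size is the element count times the per-element width. A minimal sketch, assuming a hypothetical `FakeAval` stand-in with `shape` and `itemsize`; for a real JAX aval the equivalent expression would presumably be `aval.size * np.dtype(aval.dtype).itemsize`.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeAval:
    """Hypothetical stand-in for a JAX abstract value."""

    shape: tuple
    itemsize: int  # bytes per element, e.g. 2 for bfloat16, 4 for float32

    @property
    def size(self) -> int:
        # Number of elements; math.prod(()) == 1, so scalars count as one.
        return math.prod(self.shape)


def aval_size_bytes_sketch(aval) -> int:
    return aval.size * aval.itemsize


print(aval_size_bytes_sketch(FakeAval((1024, 128), 2)))  # 262144
```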

ejkernel.callib._triton_call.avals_to_layouts(avals)

Convert abstract values to layout specifications.

Parameters

avals – List of abstract values with ndim attribute.

Returns

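List of layout specifications, one per aval, each given as the aval's dimension indices in reverse order (list(reversed(range(ndim)))), i.e. a row-major layout in minor-to-major notation.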
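The reversed range reads as an XLA minor-to-major layout: the last axis is listed first because it varies fastest, which is ordinary row-major (C-order) storage. A minimal sketch, using `types.SimpleNamespace` objects as hypothetical stand-ins for avals that only carry `ndim`:

```python
from types import SimpleNamespace


def avals_to_layouts_sketch(avals):
    # One layout per operand, dimensions listed minor-to-major:
    # for ndim == 3 this is [2, 1, 0], i.e. row-major.
    return [list(reversed(range(aval.ndim))) for aval in avals]


layouts = avals_to_layouts_sketch(
    [SimpleNamespace(ndim=3), SimpleNamespace(ndim=1), SimpleNamespace(ndim=0)]
)
print(layouts)  # [[2, 1, 0], [0], []]
```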

ejkernel.callib._triton_call.compile_ttir_inplace(ttir, backend: [triton.backends.nvidia.compiler.CUDABackend | triton.backends.amd.compiler.HIPBackend], options: [triton.backends.nvidia.compiler.CUDAOptions | triton.backends.amd.compiler.HIPOptions], compute_capability, platform)

Compile Triton IR to platform-specific binary in-place.

Parameters
  • ttir – Triton IR module to compile.

  • backend – Platform-specific backend (CUDA or HIP).

  • options – Compilation options for the backend.

  • compute_capability – Target compute capability.

  • platform – Target platform (‘cuda’ or ‘rocm’).

Returns

CompilationResult containing compiled binary and metadata.

Raises

ValueError – For unsupported platforms.
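The dispatch amounts to routing on the platform string to the PTX path (CUDA) or the HSACO path (ROCm). A minimal sketch under that assumption; `compile_for_platform_sketch` and its callable parameters are hypothetical and stand in for compile_ttir_to_ptx_inplace and compile_ttir_to_hsaco_inplace.

```python
from typing import Any, Callable


def compile_for_platform_sketch(
    platform: str,
    compile_cuda: Callable[..., Any],
    compile_rocm: Callable[..., Any],
    *args: Any,
) -> Any:
    # NVIDIA GPUs go through PTX, AMD GPUs through HSACO.
    if platform == "cuda":
        return compile_cuda(*args)
    if platform == "rocm":
        return compile_rocm(*args)
    # Anything else is rejected, matching the documented ValueError.
    raise ValueError(f"unsupported platform: {platform!r}")
```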

ejkernel.callib._triton_call.compile_ttir_to_hsaco_inplace(ttir, hip_backend: HIPBackend, hip_options: HIPOptions, compute_capability) -> CompilationResult

Compile Triton IR to HSACO binary for AMD ROCm devices.

Parameters
  • ttir – Triton IR module to compile.

  • hip_backend – HIP compilation backend.

  • hip_options – HIP-specific compilation options.

  • compute_capability – GPU architecture specification.

Returns

CompilationResult with HSACO binary path and metadata.

Raises

ValueError – If compilation passes fail.

ejkernel.callib._triton_call.compile_ttir_to_ptx_inplace(ttir, cuda_backend: CUDABackend, cuda_options: CUDAOptions, compute_capability) -> CompilationResult

Compile Triton IR to PTX binary for CUDA devices.

Parameters
  • ttir – Triton IR module to compile.

  • cuda_backend – CUDA compilation backend.

  • cuda_options – CUDA-specific compilation options.

  • compute_capability – CUDA compute capability version.

Returns

CompilationResult with PTX binary and metadata.

Raises

ValueError – If compilation passes fail.

ejkernel.callib._triton_call.get_cuda_backend(device, compute_capability)

Create CUDA backend for Triton compilation.

Parameters
  • device – CUDA device identifier.

  • compute_capability – CUDA compute capability version.

Returns

CUDABackend instance configured for the device.

ejkernel.callib._triton_call.get_hip_backend(device, compute_capability)

Create HIP backend for Triton compilation on AMD GPUs.

Parameters
  • device – HIP device identifier.

  • compute_capability – GPU architecture specification.

Returns

HIPBackend instance configured for the device.

ejkernel.callib._triton_call.get_or_create_triton_kernel(backend_init_func, platform, fn, arg_dtypes, scalar_args, device, *, num_warps, num_stages, num_ctas, compute_capability, enable_fp_fusion, metaparams, dump: bool) -> tuple[triton_kernel_call_lib.TritonKernel, Any]

Get or create a compiled Triton kernel with caching.

Parameters
  • backend_init_func – Function to initialize backend for compilation.

  • platform – Target platform string.

  • fn – Triton JIT function to compile.

  • arg_dtypes – Argument data types.

  • scalar_args – Scalar argument specifications.

  • device – Target device identifier.

  • num_warps – Number of warps per thread block.

  • num_stages – Number of pipeline stages.

  • num_ctas – Number of cooperative thread arrays.

  • compute_capability – Target compute capability.

  • enable_fp_fusion – Whether to enable floating point fusion.

  • metaparams – Kernel metaparameters.

  • dump – Whether to dump debug information.

Returns

Tuple of compiled TritonKernel and specialization attributes.

Raises

ValueError – For unsupported configurations or compilation errors.
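Caching only works if every input that changes the generated code is part of the cache key. A minimal sketch, assuming a hypothetical module-level dict and a caller-supplied `compile_fn`; the helper name and exact key composition are illustrative, not ejkernel's implementation.

```python
from typing import Any, Callable, Dict, Hashable, Sequence, Tuple

# Hypothetical process-wide cache: compilation key -> compiled kernel.
_KERNEL_CACHE: Dict[Tuple[Hashable, ...], Any] = {}


def get_or_create_kernel_sketch(
    compile_fn: Callable[[Tuple[Hashable, ...]], Any],
    platform: str,
    fn: Hashable,
    arg_dtypes: Sequence[str],
    device: int,
    *,
    num_warps: int,
    num_stages: int,
    num_ctas: int,
    compute_capability: int,
    enable_fp_fusion: bool,
    metaparams: Dict[str, Hashable],
) -> Any:
    # Metaparameters are sorted so keyword order does not split the cache.
    key = (
        platform, fn, tuple(arg_dtypes), device,
        num_warps, num_stages, num_ctas,
        compute_capability, enable_fp_fusion,
        tuple(sorted(metaparams.items())),
    )
    if key not in _KERNEL_CACHE:
        _KERNEL_CACHE[key] = compile_fn(key)  # compile only on a cache miss
    return _KERNEL_CACHE[key]
```

A second call with identical arguments returns the cached object without recompiling, while changing any metaparameter (for example BLOCK_SIZE) triggers a fresh compilation.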

ejkernel.callib._triton_call.get_triton_type(obj: Any) -> str

Get Triton type string representation for a given object.

Parameters

obj – Object to get type for (ShapedArray, AbstractRef, constexpr, or scalar).

Returns

String representation of the Triton type.

Raises
  • ValueError – For integer overflow cases.

  • NotImplementedError – For unsupported object types.
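The integer-overflow case arises when a Python int scalar is too large for any Triton integer type. A minimal sketch of range-based selection, assuming i32, then i64, then u64 for large non-negative values; the helper name is hypothetical and the exact ranges and unsigned handling in ejkernel may differ. A real implementation also has to test for `bool` before `int`, since `bool` is a subclass of `int` in Python.

```python
def int_triton_type_sketch(value: int) -> str:
    """Hypothetical: choose a Triton integer type wide enough for value."""
    if -(2**31) <= value < 2**31:
        return "i32"
    if -(2**63) <= value < 2**63:
        return "i64"
    if 0 <= value < 2**64:
        # Only non-negative values this large fit, and only as unsigned.
        return "u64"
    raise ValueError(f"integer overflow representing {value}")
```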

ejkernel.callib._triton_call.normalize_grid(grid: int | tuple[int] | tuple[int, int] | tuple[int, int, int] | collections.abc.Callable[[dict[str, Any]], int | tuple[int] | tuple[int, int] | tuple[int, int, int]], metaparams) -> tuple[int, int, int]

Normalize grid specification to a 3D tuple.

Parameters
  • grid – Grid specification as int, tuple, or callable returning grid.

  • metaparams – Dictionary of metaparameters for callable grid evaluation.

Returns

3D tuple representing normalized grid dimensions.

Raises

ValueError – If grid has more than 3 dimensions.
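Normalization evaluates a callable grid against the metaparameters, wraps a bare int in a tuple, rejects more than three dimensions, and pads the rest with 1. A minimal sketch under those assumptions; `normalize_grid_sketch` is a hypothetical re-implementation, not the library function.

```python
from typing import Any, Callable, Dict, Tuple, Union

GridLike = Union[int, Tuple[int, ...]]
Grid = Union[GridLike, Callable[[Dict[str, Any]], GridLike]]


def normalize_grid_sketch(grid: Grid, metaparams: Dict[str, Any]) -> Tuple[int, int, int]:
    """Hypothetical re-implementation: coerce a grid spec into (x, y, z)."""
    if callable(grid):
        # Callable grids see the metaparameters, e.g. a tuned BLOCK_SIZE.
        grid = grid(metaparams)
    if isinstance(grid, int):
        grid = (grid,)
    grid = tuple(grid)
    if len(grid) > 3:
        raise ValueError(f"grid must have at most 3 dimensions, got {len(grid)}")
    # Missing trailing dimensions launch a single program along that axis.
    return grid + (1,) * (3 - len(grid))


print(normalize_grid_sketch(lambda m: (-(-4096 // m["BLOCK_SIZE"]),), {"BLOCK_SIZE": 1024}))  # (4, 1, 1)
```

The `-(-a // b)` idiom is ceiling division, the same quantity `triton.cdiv` computes for positive sizes.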

ejkernel.callib._triton_call.triton_call(*args: jax.jaxlib._jax.Array | bool | int | float | numpy.float32, kernel: triton.runtime.jit.JITFunction | triton.runtime.autotuner.Heuristics | triton.runtime.autotuner.Autotuner, out_shape: ejkernel.callib._utils.ShapeDtype | collections.abc.Sequence[ejkernel.callib._utils.ShapeDtype], grid: int | tuple[int] | tuple[int, int] | tuple[int, int, int] | collections.abc.Callable[[dict[str, Any]], int | tuple[int] | tuple[int, int] | tuple[int, int, int]], name: str = '', custom_call_target_name: str = 'triton_kernel_call', num_warps: int | None = None, num_stages: int | None = None, num_ctas: int = 1, device: int = 0, compute_capability: int | None = None, enable_fp_fusion: bool = True, input_output_aliases: dict[int, int] | None = None, zeroed_outputs: collections.abc.Sequence[int] | collections.abc.Callable[[dict[str, Any]], collections.abc.Sequence[int]] = (), debug: bool = False, serialized_metadata: bytes = b'', **metaparams: Any) -> Any

Call a Triton kernel from JAX with specified parameters.

This is the main entry point for executing Triton kernels within JAX computations. It handles compilation, optimization, and execution of GPU kernels written in Triton.

Parameters
  • *args – Input arguments to the kernel (arrays and scalars).

  • kernel – Triton kernel function, heuristics, or autotuner to execute.

  • out_shape – Expected output shape(s) and dtype(s).

  • grid – Kernel launch grid specification or callable returning grid.

  • name – Optional name for the kernel call.

  • custom_call_target_name – Target name for the custom call.

  • num_warps – Number of warps per thread block. If None, defaults to 4.

  • num_stages – Number of pipeline stages. If None, defaults to 3.

  • num_ctas – Number of cooperative thread arrays.

  • device – Target device identifier.

  • compute_capability – Target compute capability (auto-detected if None).

  • enable_fp_fusion – Whether to enable floating point fusion.

  • input_output_aliases – Mapping of input indices to output indices for aliasing.

  • zeroed_outputs – Indices of outputs to zero-initialize or callable returning them.

  • debug – Whether to enable debug output during compilation.

  • serialized_metadata – Additional serialized kernel metadata.

  • **metaparams – Additional kernel metaparameters.

Returns

JAX array(s) containing the kernel execution results.

Raises

ValueError – If Triton is not installed, or if the configuration is invalid.
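Two parameters accept a "resolve later" form: num_warps and num_stages fall back to 4 and 3 when None, and zeroed_outputs may be a callable evaluated against the metaparameters. A minimal sketch of that resolution step; the helper `resolve_launch_options_sketch` and the SPLIT_K metaparameter are hypothetical, chosen to show why zeroing might depend on a tuning choice (for example, an output accumulated across several programs must start at zero).

```python
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

ZeroedOutputs = Union[Sequence[int], Callable[[Dict[str, Any]], Sequence[int]]]


def resolve_launch_options_sketch(
    num_warps: Optional[int],
    num_stages: Optional[int],
    zeroed_outputs: ZeroedOutputs,
    metaparams: Dict[str, Any],
) -> Tuple[int, int, Tuple[int, ...]]:
    # None falls back to the documented defaults (4 warps, 3 stages).
    warps = 4 if num_warps is None else num_warps
    stages = 3 if num_stages is None else num_stages
    # A callable picks which output indices to zero from the metaparameters.
    if callable(zeroed_outputs):
        zeroed_outputs = zeroed_outputs(metaparams)
    return warps, stages, tuple(zeroed_outputs)


print(resolve_launch_options_sketch(None, None, lambda m: (0,) if m["SPLIT_K"] > 1 else (), {"SPLIT_K": 4}))  # (4, 3, (0,))
```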

ejkernel.callib._triton_call.triton_kernel_call_abstract_eval(*_, out_shapes, **__)

Abstract evaluation function for triton kernel call primitive.

Parameters
  • *_ – Unused positional arguments.

  • out_shapes – Output shape specifications.

  • **__ – Unused keyword arguments.

Returns

List of ShapedArray objects for outputs.

ejkernel.callib._triton_call.triton_kernel_call_lowering(backend_init_func, ctx, *array_args, fn, scalar_args, name, custom_call_target_name, out_shapes, grid, num_warps, num_stages, num_ctas, device, compute_capability, enable_fp_fusion, input_output_aliases, zeroed_outputs, debug, serialized_metadata, **metaparams)

Lower Triton kernel call to platform-specific implementation.

This function handles the compilation and lowering of Triton kernels for execution, including autotuning support and platform-specific optimizations.

Parameters
  • backend_init_func – Function to initialize the compilation backend.

  • ctx – Lowering context containing type information.

  • *array_args – Array arguments to the kernel.

  • fn – Triton kernel function to execute.

  • scalar_args – Scalar argument specifications.

  • name – Name for the kernel call.

  • custom_call_target_name – Target name for the custom call.

  • out_shapes – Output tensor shapes.

  • grid – Kernel launch grid specification.

  • num_warps – Number of warps per thread block.

  • num_stages – Number of pipeline stages.

  • num_ctas – Number of cooperative thread arrays.

  • device – Target device identifier.

  • compute_capability – Target compute capability.

  • enable_fp_fusion – Whether to enable floating point fusion.

  • input_output_aliases – Input-output aliasing specifications.

  • zeroed_outputs – Outputs to zero-initialize.

  • debug – Whether to enable debug output.

  • serialized_metadata – Serialized kernel metadata.

  • **metaparams – Additional kernel metaparameters.

Returns

Lowered kernel call rule result.

ejkernel.callib._triton_call.triton_kernel_call_raise_on_jvp(*args, **kwargs)

Raise error for automatic differentiation on Triton kernels.

Parameters
  • *args – Unused positional arguments.

  • **kwargs – Unused keyword arguments.

Raises

NotImplementedError – Always, as JVP is not supported.

ejkernel.callib._triton_call.triton_kernel_call_raise_on_vmap(*args, **kwargs)

Raise error for batching with vmap on Triton kernels.

Parameters
  • *args – Unused positional arguments.

  • **kwargs – Unused keyword arguments.

Raises

NotImplementedError – Always, as vmap is not supported.