ejkernel.ops.utils.datacarrier

Data carrier classes for kernel configuration parameters.

This module provides dataclasses that encapsulate forward and backward pass parameters for various kernel operations, particularly attention mechanisms. These parameter carriers enable consistent configuration across different kernel implementations and facilitate autotuning by providing hashable parameter sets.

Classes:

  • FwdParams: Forward pass parameters for kernel configuration

  • BwdParams: Backward pass parameters for kernel configuration

The parameter carriers support:
  • Block size configuration for tiling strategies

  • Warp and pipeline stage configuration for GPU kernels

  • Consistent hashing for configuration caching

  • Optional parameters that can be None for auto-selection

class ejkernel.ops.utils.datacarrier.BwdParams(blocksize_m: int | None = None, blocksize_k: int | None = None, blocksize_n: int | None = None, q_blocksize: int | None = None, kv_blocksize: int | None = None, num_warps: int | None = None, num_stages: int | None = None)

Bases: object

Backward pass parameters for kernel configuration.

Encapsulates block sizes and execution parameters for backward pass kernels, used in gradient computation for attention and matrix multiplication operations.

blocksize_m

Block size for the M dimension (rows of the output matrix)

Type

int | None

blocksize_k

Block size for the K dimension (the reduction dimension)

Type

int | None

blocksize_n

Block size for the N dimension (columns of the output matrix)

Type

int | None

q_blocksize

Block size along the query dimension in attention gradients

Type

int | None

kv_blocksize

Block size along the key/value dimension in attention gradients

Type

int | None

num_warps

Number of GPU warps per thread block

Type

int | None

num_stages

Number of software pipeline stages, used to overlap memory loads with computation

Type

int | None

Note

These values are typically smaller than their forward-pass counterparts because gradient computation has different memory access patterns.

blocksize_k: int | None = None
blocksize_m: int | None = None
blocksize_n: int | None = None
kv_blocksize: int | None = None
num_stages: int | None = None
num_warps: int | None = None
q_blocksize: int | None = None
class ejkernel.ops.utils.datacarrier.FwdParams(blocksize_m: int | None = None, blocksize_k: int | None = None, blocksize_n: int | None = None, q_blocksize: int | None = None, kv_blocksize: int | None = None, blocksize_heads: int | None = None, blocksize_keys: int | None = None, num_key_splits: int | None = None, num_warps: int | None = None, num_stages: int | None = None)

Bases: object

Forward pass parameters for kernel configuration.

Encapsulates block sizes and execution parameters for forward pass kernels, particularly for attention and matrix multiplication operations.

blocksize_m

Block size for the M dimension (rows of the output matrix)

Type

int | None

blocksize_k

Block size for the K dimension (the reduction dimension)

Type

int | None

blocksize_n

Block size for the N dimension (columns of the output matrix)

Type

int | None

q_blocksize

Block size along the query dimension in attention

Type

int | None

kv_blocksize

Block size along the key/value dimension in attention

Type

int | None

blocksize_heads

Block size for the head dimension in multi-head attention

Type

int | None

blocksize_keys

Block size along the key sequence length

Type

int | None

num_key_splits

Number of splits along the key dimension

Type

int | None

num_warps

Number of GPU warps per thread block

Type

int | None

num_stages

Number of software pipeline stages, used to overlap memory loads with computation

Type

int | None

Note

All parameters are optional and default to None, so that values can be selected automatically at kernel launch or during autotuning.

blocksize_heads: int | None = None
blocksize_k: int | None = None
blocksize_keys: int | None = None
blocksize_m: int | None = None
blocksize_n: int | None = None
kv_blocksize: int | None = None
num_key_splits: int | None = None
num_stages: int | None = None
num_warps: int | None = None
q_blocksize: int | None = None
ejkernel.ops.utils.datacarrier.get_safe_hash_int(text, algorithm='md5')

Generate an integer hash of text using the specified algorithm, with safety checks.
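The function name and default argument suggest a `hashlib` digest converted to an integer. A minimal sketch under that assumption follows; `get_safe_hash_int_sketch` is a hypothetical name, and both the stringification of non-string input and the algorithm check are guesses at what the "safety checks" cover.

```python
import hashlib


def get_safe_hash_int_sketch(text, algorithm: str = "md5") -> int:
    """Hypothetical re-implementation: hash str(text) and return the digest as an int."""
    if not isinstance(text, str):
        text = str(text)  # assumption: non-string inputs are stringified first
    if algorithm not in hashlib.algorithms_available:
        raise ValueError(f"unsupported hash algorithm: {algorithm!r}")
    digest = hashlib.new(algorithm, text.encode("utf-8")).hexdigest()
    # Interpret the hexadecimal digest as a (large) non-negative integer.
    return int(digest, 16)


print(get_safe_hash_int_sketch("q128_kv64"))
```

Unlike the built-in `hash()` on strings, which is salted per process unless `PYTHONHASHSEED` is fixed, a `hashlib` digest is stable across interpreter runs. That matters when hashes serve as persistent cache keys.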

ejkernel.ops.utils.datacarrier.hash_fn(self) -> int

Generate a hash for an object based on its dictionary values.