ejkernel.modules.operations.configs

Operation-specific configuration classes.

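This module defines configuration dataclasses for each supported operation (attention variants as well as grouped matrix multiplication, mean pooling, and state-space operations), providing type-safe, operation-specific parameters for kernel execution and autotuning.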

class ejkernel.modules.operations.configs.AttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_q: int = 128, block_k: int = 128, num_warps: int = 4, num_stages: int = 2)[source]#

Bases: BaseOperationConfig

Configuration for basic Attention operation.

Parameters
  • block_q – Query block size (default: 128)

  • block_k – Key block size (default: 128)

  • num_warps – Number of warps for Triton kernels (default: 4)

  • num_stages – Number of pipeline stages (default: 2)

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

block_k: int = 128#
block_q: int = 128#
num_stages: int = 2#
num_warps: int = 4#
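
As a rough sketch of how these fields fit together, the stand-in dataclasses below mirror the documented signatures of BaseOperationConfig and AttentionConfig. They are local re-declarations for illustration only: the names ending in `Sketch` are hypothetical, `frozen=True` is an assumption, and the real classes may add validation or hashing that is not shown here.

```python
from dataclasses import dataclass, replace
from typing import Literal

Platform = Literal["triton", "pallas", "cuda", "xla", "auto"]


@dataclass(frozen=True)
class BaseOperationConfigSketch:
    # Stand-in mirroring the documented base fields; not the ejkernel class.
    platform: Platform = "auto"
    backend: str = "any"


@dataclass(frozen=True)
class AttentionConfigSketch(BaseOperationConfigSketch):
    # Defaults copied from the documented AttentionConfig signature.
    block_q: int = 128
    block_k: int = 128
    num_warps: int = 4
    num_stages: int = 2


default_cfg = AttentionConfigSketch()
# Override only the fields being tuned; everything else keeps its default.
tuned_cfg = replace(default_cfg, platform="triton", block_q=64, num_warps=8)
print(tuned_cfg)
```

Keeping the config immutable and deriving variants with `replace` makes it easy to enumerate candidates during autotuning without mutating a shared default.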
class ejkernel.modules.operations.configs.BaseOperationConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any')[source]#

Bases: object

Base configuration for all operations.

backend: str = 'any'#
platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto'#
class ejkernel.modules.operations.configs.BlockSparseAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, bwd_params: ejkernel.ops.utils.datacarrier.BwdParams | None = None)[source]#

Bases: BaseOperationConfig

Configuration for Block Sparse Attention operation.

Parameters
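  • fwd_params – Forward kernel parameters (default: None)

  • bwd_params – Backward kernel parameters (default: None)

  • platform – Target platform (triton/pallas/cuda/xla/auto)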

  • backend – Backend specification (default: “any”)

bwd_params: ejkernel.ops.utils.datacarrier.BwdParams | None = None#
fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None#
class ejkernel.modules.operations.configs.FlashAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, bwd_params: ejkernel.ops.utils.datacarrier.BwdParams | None = None)[source]#

Bases: BaseOperationConfig

Configuration for Flash Attention operation.

Parameters
  • fwd_params – Forward kernel parameters (uses q_blocksize/kv_blocksize for tiling).

  • bwd_params – Backward kernel parameters (optional).

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

bwd_params: ejkernel.ops.utils.datacarrier.BwdParams | None = None#
fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None#
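
The following is a minimal sketch of how a nested `fwd_params` value might be resolved into concrete tile sizes. `FwdParamsSketch` is a hypothetical stand-in for ejkernel.ops.utils.datacarrier.FwdParams that models only the two field names mentioned above (`q_blocksize`, `kv_blocksize`); the `fallback` value of 128 and the helper `resolve_fwd_blocks` are assumptions made for illustration, not part of the library.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class FwdParamsSketch:
    # Hypothetical stand-in: only the two documented field names are modeled.
    q_blocksize: Optional[int] = None
    kv_blocksize: Optional[int] = None


@dataclass(frozen=True)
class FlashAttentionConfigSketch:
    # Mirrors the documented signature, minus bwd_params for brevity.
    platform: str = "auto"
    backend: str = "any"
    fwd_params: Optional[FwdParamsSketch] = None


def resolve_fwd_blocks(cfg: FlashAttentionConfigSketch, fallback: int = 128) -> Tuple[int, int]:
    """Return (q_block, kv_block), using `fallback` for anything left unset."""
    params = cfg.fwd_params if cfg.fwd_params is not None else FwdParamsSketch()
    q_block = params.q_blocksize if params.q_blocksize is not None else fallback
    kv_block = params.kv_blocksize if params.kv_blocksize is not None else fallback
    return q_block, kv_block


cfg = FlashAttentionConfigSketch(fwd_params=FwdParamsSketch(q_blocksize=64))
print(resolve_fwd_blocks(cfg))  # (64, 128)
```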
class ejkernel.modules.operations.configs.FlashMLAConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_q: int = 128, block_k: int = 128, num_warps: int = 4, num_stages: int = 2)[source]#

Bases: BaseOperationConfig

Configuration for Flash Multi-head Latent Attention operation.

Parameters
  • block_q – Query block size (default: 128)

  • block_k – Key block size (default: 128)

  • num_warps – Number of warps for Triton kernels (default: 4)

  • num_stages – Number of pipeline stages (default: 2)

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

block_k: int = 128#
block_q: int = 128#
num_stages: int = 2#
num_warps: int = 4#
class ejkernel.modules.operations.configs.GLAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_q: int = 64, block_k: int = 64, block_d: int = 64, num_warps: int = 4, num_stages: int = 1)[source]#

Bases: BaseOperationConfig

Configuration for Gated Linear Attention operation.

Parameters
  • block_q – Query block size (default: 64)

  • block_k – Key block size (default: 64)

  • block_d – Head dimension block size (default: 64)

  • num_warps – Number of warps for Triton kernels (default: 4)

  • num_stages – Number of pipeline stages (default: 1)

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

block_d: int = 64#
block_k: int = 64#
block_q: int = 64#
num_stages: int = 1#
num_warps: int = 4#
class ejkernel.modules.operations.configs.GroupedMatmulConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_m: int = 128, block_n: int = 128, block_k: int = 128, num_warps: int = 4, num_stages: int = 2, bypass_xla_tiling: bool = False)[source]#

Bases: BaseOperationConfig

Configuration for Grouped Matrix Multiplication operation.

Parameters
  • block_m – M dimension block size (default: 128)

  • block_n – N dimension block size (default: 128)

  • block_k – K dimension block size (default: 128)

  • num_warps – Number of warps for Triton kernels (default: 4)

  • num_stages – Number of pipeline stages (default: 2)

  • bypass_xla_tiling – Whether to bypass XLA tiling (default: False)

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

block_k: int = 128#
block_m: int = 128#
block_n: int = 128#
bypass_xla_tiling: bool = False#
num_stages: int = 2#
num_warps: int = 4#
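
To make the effect of the block sizes concrete, here is a small sketch that counts output tiles and reduction steps for a single (m, k) @ (k, n) product under standard tiled matrix multiplication. The dataclass is a local stand-in with the documented defaults, and `tile_counts` is a hypothetical helper; how ejkernel actually schedules tiles across groups is not described on this page.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class GroupedMatmulConfigSketch:
    # Defaults copied from the documented signature; stand-in only.
    block_m: int = 128
    block_n: int = 128
    block_k: int = 128


def tile_counts(m: int, n: int, k: int, cfg: GroupedMatmulConfigSketch):
    """Tiles along M and N, plus reduction steps along K, for one matmul."""
    return (
        math.ceil(m / cfg.block_m),  # partial trailing tiles still count
        math.ceil(n / cfg.block_n),
        math.ceil(k / cfg.block_k),
    )


cfg = GroupedMatmulConfigSketch()
print(tile_counts(1000, 4096, 512, cfg))  # (8, 32, 4)
```

Smaller blocks raise the tile count (more parallelism, more scheduling overhead), while larger blocks do the opposite; this is the trade-off that autotuning over these fields explores.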
class ejkernel.modules.operations.configs.KernelDeltaAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any')[source]#

Bases: BaseOperationConfig

Configuration for Kernel Delta Attention (KDA) operation.

Note: This operation currently uses an XLA implementation without tunable block sizes. The config exists primarily for platform/backend selection.

Parameters
  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

class ejkernel.modules.operations.configs.LightningAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_q: int = 64, block_k: int = 64, block_d: int = 64, num_warps: int = 4, num_stages: int = 1)[source]#

Bases: BaseOperationConfig

Configuration for Lightning Attention operation.

Parameters
  • block_q – Query block size (default: 64)

  • block_k – Key block size (default: 64)

  • block_d – Head dimension block size (default: 64)

  • num_warps – Number of warps for Triton kernels (default: 4)

  • num_stages – Number of pipeline stages (default: 1)

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

block_d: int = 64#
block_k: int = 64#
block_q: int = 64#
num_stages: int = 1#
num_warps: int = 4#
class ejkernel.modules.operations.configs.MeanPoolingConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_size: int = 64, num_warps: int = 4, num_stages: int = 1)[source]#

Bases: BaseOperationConfig

Configuration for Mean Pooling operation.

Parameters
  • block_size – Block size for pooling (default: 64)

  • num_warps – Number of warps for Triton kernels (default: 4)

  • num_stages – Number of pipeline stages (default: 1)

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

block_size: int = 64#
num_stages: int = 1#
num_warps: int = 4#
class ejkernel.modules.operations.configs.NativeSparseAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_q: int = 64, block_k: int = 64, block_d: int = 64, block_size: int = 64, num_warps: int = 4, num_stages: int = 1)[source]#

Bases: BaseOperationConfig

Configuration for Native Sparse Attention operation.

Parameters
  • block_q – Query block size (default: 64)

  • block_k – Key block size (default: 64)

  • block_d – Head dimension block size (default: 64)

  • block_size – Size of attention blocks for sparsity (default: 64)

  • num_warps – Number of warps for Triton kernels (default: 4)

  • num_stages – Number of pipeline stages (default: 1)

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

block_d: int = 64#
block_k: int = 64#
block_q: int = 64#
block_size: int = 64#
num_stages: int = 1#
num_warps: int = 4#
class ejkernel.modules.operations.configs.PageAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', num_splits: int = 0, pages_per_compute_block: int | None = None, num_warps: int = 4, num_stages: int = 1)[source]#

Bases: BaseOperationConfig

Configuration for Page Attention operation.

Parameters
  • num_splits – Number of partitions for splitting contexts (default: 0 for auto)

  • pages_per_compute_block – Pages per compute block (default: None)

  • num_warps – Number of warps for Triton kernels (default: 4)

  • num_stages – Number of pipeline stages (default: 1)

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

num_splits: int = 0#
num_stages: int = 1#
num_warps: int = 4#
pages_per_compute_block: int | None = None#
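
The `num_splits = 0` default acts as an "auto" sentinel. The sketch below shows one way such a sentinel could be resolved; the heuristic (one split per `target_split_len` tokens) and the helper name are purely hypothetical and are not taken from ejkernel.

```python
def resolve_num_splits(num_splits: int, context_len: int, target_split_len: int = 512) -> int:
    """Resolve the 0 = auto sentinel into a concrete partition count.

    Assumption: the automatic choice aims for roughly `target_split_len`
    tokens per partition. The real selection logic may differ.
    """
    if num_splits < 0:
        raise ValueError("num_splits must be >= 0")
    if num_splits > 0:
        return num_splits  # an explicit value always wins
    # Ceiling division, with at least one partition for short contexts.
    return max(1, -(-context_len // target_split_len))


print(resolve_num_splits(0, 2048))  # 4
print(resolve_num_splits(4, 10_000))  # 4
```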
class ejkernel.modules.operations.configs.PrefillPageAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', num_warps: int = 4, num_stages: int = 1)[source]#

Bases: BaseOperationConfig

Configuration for Prefill Page Attention operation.

Parameters
  • num_warps – Number of warps for Triton kernels (default: 4)

  • num_stages – Number of pipeline stages (default: 1)

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

num_stages: int = 1#
num_warps: int = 4#
class ejkernel.modules.operations.configs.RaggedDecodeAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None)[source]#

Bases: BaseOperationConfig

Configuration for Ragged Decode Attention operation.

Parameters
  • fwd_params – Forward kernel parameters (default: None)

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None#
class ejkernel.modules.operations.configs.RaggedPageAttentionv2Config(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', num_kv_pages_per_block: int | None = None, num_queries_per_block: int | None = None, num_warps: int = 4, num_stages: int = 1)[source]#

Bases: BaseOperationConfig

Configuration for Ragged Page Attention (v2) operation.

Parameters
  • num_kv_pages_per_block – Number of KV pages to process per compute block (default: None for auto)

  • num_queries_per_block – Number of queries to process per compute block (default: None for auto)

  • num_warps – Number of warps for Triton kernels (default: 4)

  • num_stages – Number of pipeline stages (default: 1)

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

num_kv_pages_per_block: int | None = None#
num_queries_per_block: int | None = None#
num_stages: int = 1#
num_warps: int = 4#
class ejkernel.modules.operations.configs.RaggedPageAttentionv3Config(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', chunk_prefill_size: int | None = None, num_kv_pages_per_block: int | None = None, num_queries_per_block: int | None = None, num_warps: int = 4, num_stages: int = 1)[source]#

Bases: BaseOperationConfig

Configuration for Ragged Page Attention (v3) operation.

Parameters
  • chunk_prefill_size – Chunk size for chunked prefill (default: None)

  • num_kv_pages_per_block – Number of KV pages to process per compute block (default: None for auto)

  • num_queries_per_block – Number of queries to process per compute block (default: None for auto)

  • num_warps – Number of warps for Triton kernels (default: 4)

  • num_stages – Number of pipeline stages (default: 1)

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

chunk_prefill_size: int | None = None#
num_kv_pages_per_block: int | None = None#
num_queries_per_block: int | None = None#
num_stages: int = 1#
num_warps: int = 4#
class ejkernel.modules.operations.configs.RecurrentAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_q: int = 64, block_k: int = 64, block_d: int = 64, num_warps: int = 4, num_stages: int = 1)[source]#

Bases: BaseOperationConfig

Configuration for Recurrent Attention operation.

Parameters
  • block_q – Query block size (default: 64)

  • block_k – Key block size (default: 64)

  • block_d – Head dimension block size (default: 64)

  • num_warps – Number of warps for Triton kernels (default: 4)

  • num_stages – Number of pipeline stages (default: 1)

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

block_d: int = 64#
block_k: int = 64#
block_q: int = 64#
num_stages: int = 1#
num_warps: int = 4#
class ejkernel.modules.operations.configs.RingAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, bwd_params: ejkernel.ops.utils.datacarrier.BwdParams | None = None)[source]#

Bases: BaseOperationConfig

Configuration for Ring Attention operation.

Parameters
  • fwd_params – Forward pass block size parameters

  • bwd_params – Backward pass block size parameters

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

bwd_params: ejkernel.ops.utils.datacarrier.BwdParams | None = None#
fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None#
class ejkernel.modules.operations.configs.ScaledDotProductAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any')[source]#

Bases: BaseOperationConfig

Configuration for Scaled Dot Product Attention operation.

Note: This operation uses XLA primitives directly without tunable block sizes. The config exists primarily for platform/backend selection.

Parameters
  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

class ejkernel.modules.operations.configs.StateSpaceV1Config(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any')[source]#

Bases: BaseOperationConfig

Configuration for SSM1 (Mamba1-style) Selective State Space operation.

Note: This operation primarily uses an XLA implementation without tunable block sizes. The config exists primarily for platform/backend selection.

Parameters
  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

class ejkernel.modules.operations.configs.StateSpaceV2Config(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', n_groups: int = 1, use_gated_rmsnorm: bool = False, rmsnorm_eps: float = 1e-05)[source]#

Bases: BaseOperationConfig

Configuration for SSM2 (Mamba2-style) Selective State Space operation.

Parameters
  • n_groups – Number of groups for B, C parameters (default: 1)

  • use_gated_rmsnorm – Whether to use gated RMSNorm for output (default: False)

  • rmsnorm_eps – Epsilon for RMSNorm stability (default: 1e-5)

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

n_groups: int = 1#
rmsnorm_eps: float = 1e-05#
use_gated_rmsnorm: bool = False#
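
To illustrate what `rmsnorm_eps` guards against, here is a pure-Python sketch of a gated RMSNorm over one vector. It follows one common formulation in Mamba2-style blocks (multiply by SiLU of the gate, then normalize); whether ejkernel applies the gate before or after normalization is not stated on this page, so treat the ordering and the function names as assumptions.

```python
import math
from typing import List


def silu(v: float) -> float:
    # SiLU / swish activation: v * sigmoid(v).
    return v / (1.0 + math.exp(-v))


def gated_rmsnorm(x: List[float], gate: List[float], weight: List[float], eps: float = 1e-5) -> List[float]:
    """Gate first, then RMS-normalize and scale (assumed ordering)."""
    gated = [xi * silu(gi) for xi, gi in zip(x, gate)]
    mean_sq = sum(g * g for g in gated) / len(gated)
    # eps keeps the denominator positive even when every element is zero.
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [g * inv_rms * w for g, w in zip(gated, weight)]


print(gated_rmsnorm([0.0, 0.0], [1.0, 1.0], [1.0, 1.0]))  # [0.0, 0.0]
```

Without `eps`, the all-zero input above would divide by zero; with it, the output is simply zero.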
class ejkernel.modules.operations.configs.UnifiedAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', seq_threshold_3d: int | None = None, num_par_softmax_segments: int | None = None, num_warps: int | None = None, num_stages: int | None = None)[source]#

Bases: BaseOperationConfig

Configuration for vLLM-style unified (paged) attention operation.

Parameters
  • seq_threshold_3d – Threshold (in number of sequences) for selecting the segmented 3D decode kernel on GPU (Triton only; default: None).

  • num_par_softmax_segments – Number of parallel softmax segments used by the segmented 3D decode kernel (Triton only).

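  • num_warps – Optional override for the number of warps in the Triton kernel (default: None)

  • num_stages – Optional override for the number of pipeline stages in the Triton kernel (default: None)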

  • platform – Target platform (triton/pallas/cuda/xla/auto)

  • backend – Backend specification (default: “any”)

num_par_softmax_segments: int | None = None#
num_stages: int | None = None#
num_warps: int | None = None#
seq_threshold_3d: int | None = None#
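
Since every tunable field here defaults to None, a caller typically needs to separate explicit overrides from unset values. The sketch below assumes that None means "no override, let the kernel choose"; the stand-in dataclass and the helper `explicit_overrides` are hypothetical.

```python
from dataclasses import asdict, dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class UnifiedAttentionConfigSketch:
    # Field names and defaults copied from the documented signature.
    seq_threshold_3d: Optional[int] = None
    num_par_softmax_segments: Optional[int] = None
    num_warps: Optional[int] = None
    num_stages: Optional[int] = None


def explicit_overrides(cfg: UnifiedAttentionConfigSketch) -> Dict[str, int]:
    """Keep only the fields the user actually set (assumes None = unset)."""
    return {name: value for name, value in asdict(cfg).items() if value is not None}


print(explicit_overrides(UnifiedAttentionConfigSketch(num_warps=8)))  # {'num_warps': 8}
```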
ejkernel.modules.operations.configs.get_safe_hash_int(text, algorithm='md5')[source]#

Generate a hash of text using the specified algorithm, with safety checks.

ejkernel.modules.operations.configs.hash_fn(self) → int[source]#

Generate a hash for an object based on its dictionary values.
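
The two helpers above are documented only by a one-line summary, so the following is a guess at the general technique rather than the library's implementation: hash a string to an integer with `hashlib`, then hash an object by serializing its attribute dictionary in a stable order. The `_sketch` names, the choice of `repr` for serialization, and the specific safety check are all assumptions.

```python
import hashlib
from dataclasses import dataclass


def safe_hash_int_sketch(text: str, algorithm: str = "md5") -> int:
    """Hash `text` to an int, rejecting unknown algorithm names."""
    if algorithm not in hashlib.algorithms_available:
        raise ValueError(f"unsupported hash algorithm: {algorithm!r}")
    digest = hashlib.new(algorithm, text.encode("utf-8")).hexdigest()
    return int(digest, 16)


def hash_fn_sketch(obj) -> int:
    """Hash an object from its attribute dictionary, in sorted key order."""
    items = sorted((name, repr(value)) for name, value in vars(obj).items())
    return safe_hash_int_sketch(repr(items))


@dataclass(frozen=True)
class ConfigSketch:
    # Hypothetical config used only to exercise the helpers.
    platform: str = "auto"
    backend: str = "any"


# Equal field values give equal hashes, which suits cache keys for autotuning.
print(hash_fn_sketch(ConfigSketch()) == hash_fn_sketch(ConfigSketch()))  # True
```

Unlike the built-in `hash` on strings, a digest-based hash is stable across interpreter runs, which matters if the value is used as a persistent cache key.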