ejkernel.modules.operations.configs
Operation-specific configuration classes.
This module defines configuration dataclasses for each attention operation, providing type-safe, operation-specific parameters for kernel execution and autotuning.
- class ejkernel.modules.operations.configs.AttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_q: int = 128, block_k: int = 128, num_warps: int = 4, num_stages: int = 2)
Bases: BaseOperationConfig
Configuration for the basic Attention operation.
- Parameters
block_q – Query block size (default: 128)
block_k – Key block size (default: 128)
num_warps – Number of warps for Triton kernels (default: 4)
num_stages – Number of pipeline stages (default: 2)
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- block_k: int = 128
- block_q: int = 128
- num_stages: int = 2
- num_warps: int = 4
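A minimal sketch of how these configs compose, based only on the documented signatures. So that it runs without ejkernel installed, it defines local stand-ins that mirror the documented fields of BaseOperationConfig and AttentionConfig. The real classes live in ejkernel.modules.operations.configs and may carry additional behavior. The sketch shows that subclass fields come after the inherited platform/backend fields, and that dataclasses.replace derives a tuning variant without mutating the defaults.

```python
from dataclasses import asdict, dataclass, fields, replace
from typing import Literal

Platform = Literal["triton", "pallas", "cuda", "xla", "auto"]


# Local stand-ins mirroring the documented signatures; not the library source.
@dataclass
class BaseOperationConfig:
    platform: Platform = "auto"
    backend: str = "any"


@dataclass
class AttentionConfig(BaseOperationConfig):
    block_q: int = 128
    block_k: int = 128
    num_warps: int = 4
    num_stages: int = 2


default = AttentionConfig()
# Derive a smaller-tile Triton variant; `default` itself is left unchanged.
tuned = replace(default, platform="triton", block_q=64, block_k=64, num_warps=8)

print([f.name for f in fields(AttentionConfig)])
print(asdict(tuned))
```

The field order printed first matches the documented constructor order: base fields first, then the operation-specific tiling fields.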
- class ejkernel.modules.operations.configs.BaseOperationConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any')
Bases: object
Base configuration for all operations.
- backend: str = 'any'
- platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto'
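Python dataclasses do not enforce Literal annotations at runtime, so a value such as platform='tpu' would be accepted silently unless the class checks it. This page does not say whether ejkernel performs such a check. The sketch below shows one way a hypothetical stand-in (CheckedBaseConfig, not a library class) could validate platform against the documented options using typing.get_args.

```python
from dataclasses import dataclass
from typing import Literal, get_args

Platform = Literal["triton", "pallas", "cuda", "xla", "auto"]
VALID_PLATFORMS = get_args(Platform)  # ('triton', 'pallas', 'cuda', 'xla', 'auto')


# Hypothetical stand-in for BaseOperationConfig with an added runtime check.
@dataclass
class CheckedBaseConfig:
    platform: Platform = "auto"
    backend: str = "any"

    def __post_init__(self) -> None:
        if self.platform not in VALID_PLATFORMS:
            raise ValueError(
                f"platform must be one of {VALID_PLATFORMS}, got {self.platform!r}"
            )


CheckedBaseConfig(platform="pallas")  # accepted
try:
    CheckedBaseConfig(platform="tpu")  # rejected: not a documented platform
except ValueError as err:
    print(err)
```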
- class ejkernel.modules.operations.configs.BlockSparseAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, bwd_params: ejkernel.ops.utils.datacarrier.BwdParams | None = None)
Bases: BaseOperationConfig
Configuration for Block Sparse Attention operation.
- Parameters
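fwd_params – Forward kernel parameters (optional).
bwd_params – Backward kernel parameters (optional).
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)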
- bwd_params: ejkernel.ops.utils.datacarrier.BwdParams | None = None
- fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None
- class ejkernel.modules.operations.configs.FlashAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, bwd_params: ejkernel.ops.utils.datacarrier.BwdParams | None = None)
Bases: BaseOperationConfig
Configuration for Flash Attention operation.
- Parameters
fwd_params – Forward kernel parameters (uses q_blocksize/kv_blocksize for tiling).
bwd_params – Backward kernel parameters (optional).
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- bwd_params: ejkernel.ops.utils.datacarrier.BwdParams | None = None
- fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None
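The kv_blocksize carried by fwd_params controls how the key/value sequence is tiled. The pure-Python sketch below illustrates the online-softmax recurrence that flash-style kernels use so that the output does not depend on the tile size. It covers a single query vector with no masking, and kv_blocksize is a plain function argument here. It illustrates the technique and is not ejkernel's kernel.

```python
import math


def naive_attention(q, keys, values, scale):
    """Reference softmax attention for a single query vector."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def tiled_attention(q, keys, values, scale, kv_blocksize):
    """Same result, visiting keys/values one tile of kv_blocksize rows at a time."""
    running_max, running_sum = -math.inf, 0.0
    acc = [0.0] * len(values[0])  # unnormalized weighted sum of values
    for start in range(0, len(keys), kv_blocksize):
        k_tile = keys[start:start + kv_blocksize]
        v_tile = values[start:start + kv_blocksize]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
        new_max = max(running_max, max(scores))
        # Rescale earlier statistics to the new running maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_tile):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.3, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
scale = 1.0 / math.sqrt(len(q))
reference = naive_attention(q, keys, values, scale)
for kv_blocksize in (1, 2, 3, 8):
    out = tiled_attention(q, keys, values, scale, kv_blocksize)
    assert all(abs(a - b) < 1e-9 for a, b in zip(out, reference))
```

Because the tile size only changes the order of accumulation, it can be tuned freely for speed. In this sketch it does not affect the result beyond floating-point rounding.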
- class ejkernel.modules.operations.configs.FlashMLAConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_q: int = 128, block_k: int = 128, num_warps: int = 4, num_stages: int = 2)
Bases: BaseOperationConfig
Configuration for Flash Multi-head Latent Attention operation.
- Parameters
block_q – Query block size (default: 128)
block_k – Key block size (default: 128)
num_warps – Number of warps for Triton kernels (default: 4)
num_stages – Number of pipeline stages (default: 2)
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- block_k: int = 128
- block_q: int = 128
- num_stages: int = 2
- num_warps: int = 4
- class ejkernel.modules.operations.configs.GLAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_q: int = 64, block_k: int = 64, block_d: int = 64, num_warps: int = 4, num_stages: int = 1)
Bases: BaseOperationConfig
Configuration for Gated Linear Attention operation.
- Parameters
block_q – Query block size (default: 64)
block_k – Key block size (default: 64)
block_d – Head dimension block size (default: 64)
num_warps – Number of warps for Triton kernels (default: 4)
num_stages – Number of pipeline stages (default: 1)
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- block_d: int = 64
- block_k: int = 64
- block_q: int = 64
- num_stages: int = 1
- num_warps: int = 4
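Gated linear attention replaces softmax with a matrix-valued state. At each step, a data-dependent gate decays that state, and the outer product of the key and value is then added to it. The pure-Python sketch below uses per-step, per-key-dimension gates named alphas, which is an assumed name since the config above only exposes tiling parameters. It checks the recurrent form against the equivalent explicit sum over past steps. Chunked kernels tuned by block_q, block_k and block_d compute the same quantity block by block.

```python
def gla_recurrent(qs, ks, vs, alphas):
    """o_t = q_t S_t with S_t = diag(alpha_t) S_{t-1} + k_t^T v_t."""
    dk, dv = len(ks[0]), len(vs[0])
    state = [[0.0] * dv for _ in range(dk)]
    outs = []
    for q, k, v, a in zip(qs, ks, vs, alphas):
        # Decay each key-row of the state, then add the outer product k^T v.
        state = [[a[i] * state[i][j] + k[i] * v[j] for j in range(dv)]
                 for i in range(dk)]
        outs.append([sum(q[i] * state[i][j] for i in range(dk)) for j in range(dv)])
    return outs


def gla_explicit(qs, ks, vs, alphas):
    """Same outputs written as a sum over past steps with cumulative gates."""
    dk, dv = len(ks[0]), len(vs[0])
    outs = []
    for t in range(len(qs)):
        o = [0.0] * dv
        for s in range(t + 1):
            # Product of gates from step s+1 through t (empty product = 1).
            decay = [1.0] * dk
            for r in range(s + 1, t + 1):
                decay = [d * alphas[r][i] for i, d in enumerate(decay)]
            score = sum(qs[t][i] * decay[i] * ks[s][i] for i in range(dk))
            o = [oj + score * vs[s][j] for j, oj in enumerate(o)]
        outs.append(o)
    return outs


qs = [[1.0, 0.5], [0.2, -1.0], [0.7, 0.3], [-0.4, 0.9]]
ks = [[0.3, 1.0], [1.0, -0.2], [0.5, 0.5], [-1.0, 0.4]]
vs = [[1.0, 0.0, 2.0], [0.5, 1.5, -1.0], [2.0, -0.5, 0.3], [0.1, 0.2, 0.3]]
alphas = [[0.9, 0.5], [0.8, 0.95], [0.6, 0.7], [0.99, 0.4]]
rec = gla_recurrent(qs, ks, vs, alphas)
ref = gla_explicit(qs, ks, vs, alphas)
assert all(abs(a - b) < 1e-9 for ra, rb in zip(rec, ref) for a, b in zip(ra, rb))
```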
- class ejkernel.modules.operations.configs.GroupedMatmulConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_m: int = 128, block_n: int = 128, block_k: int = 128, num_warps: int = 4, num_stages: int = 2, bypass_xla_tiling: bool = False)
Bases: BaseOperationConfig
Configuration for Grouped Matrix Multiplication operation.
- Parameters
block_m – M dimension block size (default: 128)
block_n – N dimension block size (default: 128)
block_k – K dimension block size (default: 128)
num_warps – Number of warps for Triton kernels (default: 4)
num_stages – Number of pipeline stages (default: 2)
bypass_xla_tiling – Bypass XLA tiling (default: False)
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- block_k: int = 128
- block_m: int = 128
- block_n: int = 128
- bypass_xla_tiling: bool = False
- num_stages: int = 2
- num_warps: int = 4
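Grouped matmul multiplies contiguous row groups of one operand by per-group weight matrices, as in mixture-of-experts layers. The sketch below assumes a MegaBlox-style contract: lhs of shape [m, k], rhs of shape [num_groups, k, n], and group_sizes summing to m. This page does not state that contract, so treat it as an assumption. The sketch gives a pure-Python reference result. A hypothetical tile_counts helper then shows how many block_m × block_n output tiles and block_k reduction steps a single (m × k) @ (k × n) product needs under the documented default tile sizes.

```python
import math


def grouped_matmul_reference(lhs, rhs, group_sizes):
    """lhs: m rows of length k; rhs: one k x n matrix per group."""
    assert sum(group_sizes) == len(lhs)
    out, row = [], 0
    for g, size in enumerate(group_sizes):
        w = rhs[g]
        # Every row in this contiguous group is multiplied by the same rhs[g].
        for x in lhs[row:row + size]:
            out.append([sum(x[i] * w[i][j] for i in range(len(x)))
                        for j in range(len(w[0]))])
        row += size
    return out


def tile_counts(m, n, k, block_m=128, block_n=128, block_k=128):
    """(output tiles, K-loop steps) for one (m x k) @ (k x n) product."""
    output_tiles = math.ceil(m / block_m) * math.ceil(n / block_n)
    k_steps = math.ceil(k / block_k)
    return output_tiles, k_steps


lhs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
rhs = [[[1.0, 0.0], [0.0, 1.0]],   # group 0: identity
       [[2.0, 0.0], [0.0, 2.0]]]   # group 1: scale by 2
print(grouped_matmul_reference(lhs, rhs, group_sizes=[1, 2]))
# [[1.0, 2.0], [6.0, 8.0], [10.0, 12.0]]
print(tile_counts(300, 512, 1000))  # (12, 8)
```

In this sketch, partial edge tiles (300 rows with block_m=128) still count as full tiles. Tile sizes that divide the problem dimensions therefore avoid wasted work.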
- class ejkernel.modules.operations.configs.KernelDeltaAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any')
Bases: BaseOperationConfig
Configuration for Kernel Delta Attention (KDA) operation.
Note: This operation currently uses an XLA implementation without tunable block sizes. The config exists primarily for platform/backend selection.
- Parameters
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- class ejkernel.modules.operations.configs.LightningAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_q: int = 64, block_k: int = 64, block_d: int = 64, num_warps: int = 4, num_stages: int = 1)
Bases: BaseOperationConfig
Configuration for Lightning Attention operation.
- Parameters
block_q – Query block size (default: 64)
block_k – Key block size (default: 64)
block_d – Head dimension block size (default: 64)
num_warps – Number of warps for Triton kernels (default: 4)
num_stages – Number of pipeline stages (default: 1)
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- block_d: int = 64
- block_k: int = 64
- block_q: int = 64
- num_stages: int = 1
- num_warps: int = 4
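Lightning-style attention computes causal linear attention (no softmax) block by block. Within a block it uses the masked quadratic form. Across blocks it carries a running K^T V state, which keeps the cost linear in sequence length. The pure-Python sketch below checks the blockwise result against the direct causal sum. Here, block stands in for the query/key block size. The sketch omits any decay term that particular Lightning Attention variants may apply.

```python
def causal_linear_naive(qs, ks, vs):
    """o_t = sum_{s <= t} (q_t . k_s) v_s."""
    outs = []
    for t, q in enumerate(qs):
        o = [0.0] * len(vs[0])
        for s in range(t + 1):
            score = sum(a * b for a, b in zip(q, ks[s]))
            o = [oj + score * vj for oj, vj in zip(o, vs[s])]
        outs.append(o)
    return outs


def causal_linear_blockwise(qs, ks, vs, block):
    dk, dv = len(ks[0]), len(vs[0])
    kv = [[0.0] * dv for _ in range(dk)]  # sum of k^T v over all earlier blocks
    outs = []
    for start in range(0, len(qs), block):
        end = min(start + block, len(qs))
        for t in range(start, end):
            # Inter-block term: query against the accumulated state.
            o = [sum(qs[t][i] * kv[i][j] for i in range(dk)) for j in range(dv)]
            # Intra-block term: causal quadratic form inside the current block.
            for s in range(start, t + 1):
                score = sum(a * b for a, b in zip(qs[t], ks[s]))
                o = [oj + score * vj for oj, vj in zip(o, vs[s])]
            outs.append(o)
        # Fold this block's keys and values into the state for later blocks.
        for s in range(start, end):
            for i in range(dk):
                for j in range(dv):
                    kv[i][j] += ks[s][i] * vs[s][j]
    return outs


qs = [[1.0, 0.0], [0.5, 0.5], [0.2, -0.3], [1.0, 1.0], [-0.5, 0.25]]
ks = [[0.1, 0.2], [0.3, -0.1], [1.0, 0.5], [0.0, 1.0], [0.4, 0.4]]
vs = [[1.0, 2.0], [0.5, -1.0], [0.0, 1.0], [2.0, 2.0], [1.5, 0.5]]
ref = causal_linear_naive(qs, ks, vs)
for block in (1, 2, 3, 64):
    out = causal_linear_blockwise(qs, ks, vs, block)
    assert all(abs(a - b) < 1e-9 for ro, rr in zip(out, ref) for a, b in zip(ro, rr))
```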
- class ejkernel.modules.operations.configs.MeanPoolingConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_size: int = 64, num_warps: int = 4, num_stages: int = 1)
Bases: BaseOperationConfig
Configuration for Mean Pooling operation.
- Parameters
block_size – Block size for pooling (default: 64)
num_warps – Number of warps for Triton kernels (default: 4)
num_stages – Number of pipeline stages (default: 1)
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- block_size: int = 64
- num_stages: int = 1
- num_warps: int = 4
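A block size like block_size sets how many rows are processed per step. It affects performance, not the result. The sketch below assumes mean pooling over the sequence axis of a [seq_len, dim] input, since this page does not state the reduction axis. It computes a partial sum for each block and divides once at the end, so every block size yields the same mean.

```python
def mean_pool_blocked(x, block_size=64):
    """Mean over axis 0 of a seq_len x dim list of rows, one block at a time."""
    dim = len(x[0])
    totals = [0.0] * dim
    for start in range(0, len(x), block_size):
        # Partial sum for this block, then fold it into the running total.
        block_sum = [0.0] * dim
        for row in x[start:start + block_size]:
            block_sum = [b + v for b, v in zip(block_sum, row)]
        totals = [t + b for t, b in zip(totals, block_sum)]
    return [t / len(x) for t in totals]


x = [[float(i), float(2 * i)] for i in range(10)]
print(mean_pool_blocked(x, block_size=4))  # [4.5, 9.0]
```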
- class ejkernel.modules.operations.configs.NativeSparseAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_q: int = 64, block_k: int = 64, block_d: int = 64, block_size: int = 64, num_warps: int = 4, num_stages: int = 1)
Bases: BaseOperationConfig
Configuration for Native Sparse Attention operation.
- Parameters
block_q – Query block size (default: 64)
block_k – Key block size (default: 64)
block_d – Head dimension block size (default: 64)
block_size – Size of attention blocks for sparsity (default: 64)
num_warps – Number of warps for Triton kernels (default: 4)
num_stages – Number of pipeline stages (default: 1)
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- block_d: int = 64
- block_k: int = 64
- block_q: int = 64
- block_size: int = 64
- num_stages: int = 1
- num_warps: int = 4
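In Native Sparse Attention, each query attends only to a few key blocks chosen by a cheap importance score. The sketch below groups keys into blocks of block_size tokens. It scores each block by the query's dot product with the block's mean key and keeps the top num_selected blocks. Both num_selected and the mean-key scoring are illustrative assumptions rather than parameters of this config; the original method derives block scores from its compressed-attention branch. The helper names are hypothetical.

```python
def select_blocks(q, keys, block_size=64, num_selected=2):
    """Indices of the highest-scoring key blocks, returned in ascending order."""
    scored = []
    for b, start in enumerate(range(0, len(keys), block_size)):
        blk = keys[start:start + block_size]
        # Simplified block score: dot product with the block's mean key.
        mean_key = [sum(col) / len(blk) for col in zip(*blk)]
        scored.append((sum(a * c for a, c in zip(q, mean_key)), b))
    top = sorted(scored, reverse=True)[:num_selected]
    return sorted(b for _, b in top)


def block_token_range(block_index, block_size, seq_len):
    """Token positions covered by one key block (the last block may be short)."""
    start = block_index * block_size
    return range(start, min(start + block_size, seq_len))


keys = [[0.0, 1.0]] * 4 + [[2.0, 0.0]] * 4 + [[1.0, 0.0]] * 2
chosen = select_blocks([1.0, 0.0], keys, block_size=4, num_selected=2)
print(chosen)  # [1, 2]
print([list(block_token_range(b, 4, len(keys))) for b in chosen])  # [[4, 5, 6, 7], [8, 9]]
```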
- class ejkernel.modules.operations.configs.PageAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', num_splits: int = 0, pages_per_compute_block: int | None = None, num_warps: int = 4, num_stages: int = 1)
Bases: BaseOperationConfig
Configuration for Page Attention operation.
- Parameters
num_splits – Number of partitions for splitting contexts (default: 0 for auto)
pages_per_compute_block – Pages per compute block (default: None)
num_warps – Number of warps for Triton kernels (default: 4)
num_stages – Number of pipeline stages (default: 1)
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- num_splits: int = 0
- num_stages: int = 1
- num_warps: int = 4
- pages_per_compute_block: int | None = None
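Splitting a context into num_splits partitions lets a decode kernel compute attention on each partition independently and then merge the partial results. The sketch below assumes the usual log-sum-exp merge, since this page does not spell out ejkernel's reduction. It splits a single query's scores into num_splits contiguous partitions and shows that the merged output equals unsplit softmax attention. Unlike the documented default of 0 (auto), this sketch expects an explicit positive split count.

```python
import math


def softmax_attention(scores, values):
    """Reference: softmax over all scores, then weighted sum of value rows."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return [sum(wi * v[j] for wi, v in zip(w, values)) / sum(w)
            for j in range(len(values[0]))]


def partial_attention(scores, values):
    """(max score, sum of exp, unnormalized output) for one partition."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    out = [sum(wi * v[j] for wi, v in zip(w, values)) for j in range(len(values[0]))]
    return m, sum(w), out


def split_and_merge(scores, values, num_splits):
    n = len(scores)
    bounds = [i * n // num_splits for i in range(num_splits + 1)]
    parts = [partial_attention(scores[a:b], values[a:b])
             for a, b in zip(bounds, bounds[1:]) if b > a]  # skip empty splits
    g = max(m for m, _, _ in parts)
    # Rescale each partition's statistics to the global max before combining.
    denom = sum(s * math.exp(m - g) for m, s, _ in parts)
    return [sum(o[j] * math.exp(m - g) for m, _, o in parts) / denom
            for j in range(len(values[0]))]


scores = [0.1, 2.0, -1.0, 0.5, 3.0, -0.2, 1.1]
values = [[float(i), float(i * i)] for i in range(len(scores))]
reference = softmax_attention(scores, values)
for num_splits in (1, 2, 3, 7):
    merged = split_and_merge(scores, values, num_splits)
    assert all(abs(a - b) < 1e-9 for a, b in zip(merged, reference))
```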
- class ejkernel.modules.operations.configs.PrefillPageAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', num_warps: int = 4, num_stages: int = 1)
Bases: BaseOperationConfig
Configuration for Prefill Page Attention operation.
- Parameters
num_warps – Number of warps for Triton kernels (default: 4)
num_stages – Number of pipeline stages (default: 1)
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- num_stages: int = 1
- num_warps: int = 4
- class ejkernel.modules.operations.configs.RaggedDecodeAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None)
Bases: BaseOperationConfig
Configuration for Ragged Decode Attention operation.
- Parameters
fwd_params – Forward kernel parameters (optional).
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None
- class ejkernel.modules.operations.configs.RaggedPageAttentionv2Config(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', num_kv_pages_per_block: int | None = None, num_queries_per_block: int | None = None, num_warps: int = 4, num_stages: int = 1)
Bases: BaseOperationConfig
Configuration for Ragged Page Attention (v2) operation.
- Parameters
num_kv_pages_per_block – Number of KV pages to process per compute block (default: None for auto)
num_queries_per_block – Number of queries to process per compute block (default: None for auto)
num_warps – Number of warps for Triton kernels (default: 4)
num_stages – Number of pipeline stages (default: 1)
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- num_kv_pages_per_block: int | None = None
- num_queries_per_block: int | None = None
- num_stages: int = 1
- num_warps: int = 4
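In paged attention, each sequence's KV cache is stored in fixed-size pages, and num_kv_pages_per_block sets how many of those pages one compute block loads. The sketch below assumes a hypothetical page_size (tokens per page), which is not a field of this config, and a hypothetical per-sequence page_table. For a ragged batch of KV lengths, it counts the pages each sequence occupies and the compute blocks needed to cover them. It also maps a logical token position to its physical page and offset.

```python
import math


def kv_blocks_per_sequence(kv_lens, page_size, num_kv_pages_per_block):
    """(pages, compute blocks) for each sequence in a ragged batch."""
    result = []
    for kv_len in kv_lens:
        pages = math.ceil(kv_len / page_size)
        blocks = math.ceil(pages / num_kv_pages_per_block)
        result.append((pages, blocks))
    return result


def physical_slot(page_table, token_pos, page_size):
    """Map a logical token position to (physical page id, offset in page)."""
    return page_table[token_pos // page_size], token_pos % page_size


print(kv_blocks_per_sequence([1, 16, 17, 300], page_size=16, num_kv_pages_per_block=8))
# [(1, 1), (1, 1), (2, 1), (19, 3)]
print(physical_slot([7, 2, 9], token_pos=17, page_size=16))  # (2, 1)
```

Larger num_kv_pages_per_block values mean fewer, bigger blocks. For short sequences, those blocks are mostly empty, which is one reason such values are typically tuned rather than fixed.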
- class ejkernel.modules.operations.configs.RaggedPageAttentionv3Config(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', chunk_prefill_size: int | None = None, num_kv_pages_per_block: int | None = None, num_queries_per_block: int | None = None, num_warps: int = 4, num_stages: int = 1)
Bases: BaseOperationConfig
Configuration for Ragged Page Attention (v3) operation.
- Parameters
chunk_prefill_size – Prefill chunk size (default: None)
num_kv_pages_per_block – Number of KV pages to process per compute block (default: None for auto)
num_queries_per_block – Number of queries to process per compute block (default: None for auto)
num_warps – Number of warps for Triton kernels (default: 4)
num_stages – Number of pipeline stages (default: 1)
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- chunk_prefill_size: int | None = None
- num_kv_pages_per_block: int | None = None
- num_queries_per_block: int | None = None
- num_stages: int = 1
- num_warps: int = 4
- class ejkernel.modules.operations.configs.RecurrentAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', block_q: int = 64, block_k: int = 64, block_d: int = 64, num_warps: int = 4, num_stages: int = 1)
Bases: BaseOperationConfig
Configuration for Recurrent Attention operation.
- Parameters
block_q – Query block size (default: 64)
block_k – Key block size (default: 64)
block_d – Head dimension block size (default: 64)
num_warps – Number of warps for Triton kernels (default: 4)
num_stages – Number of pipeline stages (default: 1)
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- block_d: int = 64
- block_k: int = 64
- block_q: int = 64
- num_stages: int = 1
- num_warps: int = 4
- class ejkernel.modules.operations.configs.RingAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, bwd_params: ejkernel.ops.utils.datacarrier.BwdParams | None = None)
Bases: BaseOperationConfig
Configuration for Ring Attention operation.
- Parameters
fwd_params – Forward pass block size parameters
bwd_params – Backward pass block size parameters
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- bwd_params: ejkernel.ops.utils.datacarrier.BwdParams | None = None
- fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None
- class ejkernel.modules.operations.configs.ScaledDotProductAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any')
Bases: BaseOperationConfig
Configuration for Scaled Dot Product Attention operation.
Note: This operation uses XLA primitives directly without tunable block sizes. The config exists primarily for platform/backend selection.
- Parameters
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- class ejkernel.modules.operations.configs.StateSpaceV1Config(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any')
Bases: BaseOperationConfig
Configuration for SSM1 (Mamba1-style) Selective State Space operation.
Note: This operation primarily uses an XLA implementation without tunable block sizes. The config exists mainly for platform/backend selection.
- Parameters
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- class ejkernel.modules.operations.configs.StateSpaceV2Config(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', n_groups: int = 1, use_gated_rmsnorm: bool = False, rmsnorm_eps: float = 1e-05)
Bases: BaseOperationConfig
Configuration for SSM2 (Mamba2-style) Selective State Space operation.
- Parameters
n_groups – Number of groups for B, C parameters (default: 1)
use_gated_rmsnorm – Whether to use gated RMSNorm for output (default: False)
rmsnorm_eps – Epsilon for RMSNorm stability (default: 1e-5)
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- n_groups: int = 1
- rmsnorm_eps: float = 1e-05
- use_gated_rmsnorm: bool = False
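When use_gated_rmsnorm is enabled, Mamba2-style models normalize the SSM output y together with a gate branch z. The reference Mamba2 implementation, when the gate is applied before normalization, computes rmsnorm(y * silu(z)) * weight. The sketch below uses that ordering and normalizes over the whole feature vector. Both choices are assumptions about ejkernel's behavior, not a description of it. The sketch also shows the usual reading of n_groups: B and C are shared by num_heads // n_groups consecutive heads. The helper names are hypothetical.

```python
import math


def silu(x):
    return x / (1.0 + math.exp(-x))


def gated_rmsnorm(y, z, weight, eps=1e-5):
    """rmsnorm(y * silu(z)) * weight over one feature vector."""
    gated = [yi * silu(zi) for yi, zi in zip(y, z)]
    rms = math.sqrt(sum(g * g for g in gated) / len(gated) + eps)
    return [g / rms * w for g, w in zip(gated, weight)]


def group_of_head(head, num_heads, n_groups=1):
    """Index of the B/C group shared by `head` (heads split evenly across groups)."""
    assert num_heads % n_groups == 0
    return head // (num_heads // n_groups)


out = gated_rmsnorm([1.0, -2.0, 3.0], [0.5, 1.0, -0.5], [1.0, 1.0, 1.0])
print(sum(o * o for o in out) / len(out))  # close to 1.0 with unit weights
print([group_of_head(h, num_heads=8, n_groups=2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

The eps term, which is rmsnorm_eps in the config, only matters when the gated activations are close to zero.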
- class ejkernel.modules.operations.configs.UnifiedAttentionConfig(platform: Literal['triton', 'pallas', 'cuda', 'xla', 'auto'] = 'auto', backend: str = 'any', seq_threshold_3d: int | None = None, num_par_softmax_segments: int | None = None, num_warps: int | None = None, num_stages: int | None = None)
Bases: BaseOperationConfig
Configuration for vLLM-style unified (paged) attention operation.
- Parameters
seq_threshold_3d – Threshold (in number of sequences) for selecting the segmented 3D decode kernel on GPU (Triton only).
num_par_softmax_segments – Number of parallel softmax segments used by the segmented 3D decode kernel (Triton only).
num_warps – Optional override of the Triton kernel's number of warps.
num_stages – Optional override of the Triton kernel's number of pipeline stages.
platform – Target platform (triton/pallas/cuda/xla/auto)
backend – Backend specification (default: “any”)
- num_par_softmax_segments: int | None = None
- num_stages: int | None = None
- num_warps: int | None = None
- seq_threshold_3d: int | None = None
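The sketch below illustrates how a threshold like seq_threshold_3d might drive kernel selection. It assumes the convention of vLLM's Triton unified attention. Under that convention, the segmented 3D kernel handles only decode-only batches (one query token per sequence) whose sequence count is at or below the threshold, because such small batches would otherwise leave the GPU underused. All other batches go to the 2D kernel. The function name choose_decode_kernel and its handling of None are hypothetical and not part of ejkernel's API.

```python
from typing import Literal, Optional


def choose_decode_kernel(
    num_seqs: int,
    max_query_len: int,
    seq_threshold_3d: Optional[int],
    num_par_softmax_segments: Optional[int],
) -> Literal["2d", "3d"]:
    """Hypothetical selector mirroring the documented UnifiedAttentionConfig knobs."""
    # Without both knobs, the segmented path is not configured.
    if seq_threshold_3d is None or num_par_softmax_segments is None:
        return "2d"
    # Prefill or mixed batches (more than one query token per sequence) use 2D,
    # as do decode batches large enough to fill the GPU on their own.
    if max_query_len > 1 or num_seqs > seq_threshold_3d:
        return "2d"
    return "3d"


print(choose_decode_kernel(4, 1, seq_threshold_3d=8, num_par_softmax_segments=16))   # 3d
print(choose_decode_kernel(64, 1, seq_threshold_3d=8, num_par_softmax_segments=16))  # 2d
```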