# Source code for ejkernel.modules.operations.configs

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Operation-specific configuration classes.

This module defines configuration dataclasses for each attention operation,
providing type-safe, operation-specific parameters for kernel execution
and autotuning.
"""

import hashlib
from dataclasses import dataclass
from typing import Literal

from ejkernel.ops import BwdParams, FwdParams


def get_safe_hash_int(text, algorithm="md5"):
    """Return the digest of ``str(text)`` as a non-negative integer.

    Args:
        text: Any object; it is converted with ``str()`` and encoded before hashing.
        algorithm: Name of a hash constructor exposed as an attribute of
            ``hashlib`` (default: "md5").

    Returns:
        The digest bytes interpreted as a big-endian unsigned integer.

    Raises:
        ValueError: If ``hashlib`` exposes no attribute called ``algorithm``
            (or ``algorithm`` is not a valid attribute name).
        Exception: If converting ``text`` or computing the digest fails.
    """
    # Resolve the constructor in its own try block so that only a failed lookup
    # is reported as "unsupported algorithm". Previously an AttributeError raised
    # while stringifying ``text`` was misreported the same way.
    try:
        hasher_factory = getattr(hashlib, algorithm)
    except (AttributeError, TypeError) as e:
        raise ValueError(f"Unsupported hash algorithm: {algorithm}") from e

    try:
        payload = str(text).encode()
        digest = hasher_factory(payload).digest()
        return int.from_bytes(digest, byteorder="big")
    except Exception as e:
        raise Exception(f"Error generating hash: {e!s}") from e
def hash_fn(self) -> int:
    """Compute an integer hash from an object's instance attribute values.

    Only values that are ``float``, ``int``, ``bool``, ``dict`` or ``list``
    instances take part; every other value (for example strings, ``None`` or
    nested objects) is skipped. The participating values are stringified,
    concatenated in ``__dict__`` order without a separator, and handed to
    ``get_safe_hash_int``.

    Returns:
        The integer produced by ``get_safe_hash_int`` for the concatenated text.
    """
    counted_types = (float, int, bool, dict, list)
    fragments = []
    for value in self.__dict__.values():
        if isinstance(value, counted_types):
            fragments.append(str(value))
    return get_safe_hash_int("".join(fragments))
@dataclass
class BaseOperationConfig:
    """Fields shared by every operation configuration.

    Attributes:
        platform: Target platform, one of "triton", "pallas", "cuda", "xla"
            or "auto" (default: "auto").
        backend: Backend specification (default: "any").
    """

    platform: Literal["triton", "pallas", "cuda", "xla", "auto"] = "auto"
    backend: str = "any"

    # Declared explicitly: a dataclass with eq=True and no explicit __hash__
    # would otherwise have __hash__ set to None.
    __hash__ = hash_fn
@dataclass
class FlashAttentionConfig(BaseOperationConfig):
    """Settings for the Flash Attention operation.

    Attributes:
        fwd_params: Forward kernel parameters (uses `q_blocksize`/`kv_blocksize`
            for tiling). A ``dict`` is converted to ``FwdParams`` after init.
        bwd_params: Backward kernel parameters (optional). A ``dict`` is
            converted to ``BwdParams`` after init.
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    fwd_params: FwdParams | None = None
    bwd_params: BwdParams | None = None

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn

    def __post_init__(self):
        """Promote dict-valued parameter fields to their typed containers."""
        fwd, bwd = self.fwd_params, self.bwd_params
        if isinstance(fwd, dict):
            self.fwd_params = FwdParams(**fwd)
        if isinstance(bwd, dict):
            self.bwd_params = BwdParams(**bwd)
@dataclass
class BlockSparseAttentionConfig(BaseOperationConfig):
    """Settings for the Block Sparse Attention operation.

    Attributes:
        fwd_params: Forward kernel parameters (default: None). A ``dict`` is
            converted to ``FwdParams`` after init.
        bwd_params: Backward kernel parameters (default: None). A ``dict`` is
            converted to ``BwdParams`` after init.
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    fwd_params: FwdParams | None = None
    bwd_params: BwdParams | None = None

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn

    def __post_init__(self):
        """Promote dict-valued parameter fields to their typed containers."""
        fwd, bwd = self.fwd_params, self.bwd_params
        if isinstance(fwd, dict):
            self.fwd_params = FwdParams(**fwd)
        if isinstance(bwd, dict):
            self.bwd_params = BwdParams(**bwd)
@dataclass
class NativeSparseAttentionConfig(BaseOperationConfig):
    """Settings for the Native Sparse Attention operation.

    Attributes:
        block_q: Query block size (default: 64)
        block_k: Key block size (default: 64)
        block_d: Head dimension block size (default: 64)
        block_size: Size of attention blocks for sparsity (default: 64)
        num_warps: Number of warps for Triton kernels (default: 4)
        num_stages: Number of pipeline stages (default: 1)
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    # Tiling parameters.
    block_q: int = 64
    block_k: int = 64
    block_d: int = 64
    block_size: int = 64

    # Kernel launch parameters.
    num_warps: int = 4
    num_stages: int = 1

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class RecurrentAttentionConfig(BaseOperationConfig):
    """Settings for the Recurrent Attention operation.

    Attributes:
        block_q: Query block size (default: 64)
        block_k: Key block size (default: 64)
        block_d: Head dimension block size (default: 64)
        num_warps: Number of warps for Triton kernels (default: 4)
        num_stages: Number of pipeline stages (default: 1)
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    # Tiling parameters.
    block_q: int = 64
    block_k: int = 64
    block_d: int = 64

    # Kernel launch parameters.
    num_warps: int = 4
    num_stages: int = 1

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class RingAttentionConfig(BaseOperationConfig):
    """Settings for the Ring Attention operation.

    Attributes:
        fwd_params: Forward pass block size parameters. A ``dict`` is
            converted to ``FwdParams`` after init.
        bwd_params: Backward pass block size parameters. A ``dict`` is
            converted to ``BwdParams`` after init.
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    fwd_params: FwdParams | None = None
    bwd_params: BwdParams | None = None

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn

    def __post_init__(self):
        """Promote dict-valued parameter fields to their typed containers."""
        fwd, bwd = self.fwd_params, self.bwd_params
        if isinstance(fwd, dict):
            self.fwd_params = FwdParams(**fwd)
        if isinstance(bwd, dict):
            self.bwd_params = BwdParams(**bwd)
@dataclass
class PageAttentionConfig(BaseOperationConfig):
    """Settings for the Page Attention operation.

    Attributes:
        num_splits: Number of partitions for splitting contexts
            (default: 0 for auto)
        pages_per_compute_block: Pages per compute block (default: None)
        num_warps: Number of warps for Triton kernels (default: 4)
        num_stages: Number of pipeline stages (default: 1)
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    # Partitioning parameters.
    num_splits: int = 0
    pages_per_compute_block: int | None = None

    # Kernel launch parameters.
    num_warps: int = 4
    num_stages: int = 1

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class UnifiedAttentionConfig(BaseOperationConfig):
    """Settings for the vLLM-style unified (paged) attention operation.

    Every field defaults to ``None``.

    Attributes:
        seq_threshold_3d: Threshold (in #seqs) for selecting the segmented 3D
            decode kernel on GPU (Triton only).
        num_par_softmax_segments: Number of parallel softmax segments used by
            the segmented 3D decode kernel (Triton only).
        num_warps: Optional Triton kernel override.
        num_stages: Optional Triton kernel override.
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    # Segmented 3D decode kernel selection.
    seq_threshold_3d: int | None = None
    num_par_softmax_segments: int | None = None

    # Optional kernel launch overrides.
    num_warps: int | None = None
    num_stages: int | None = None

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class AttentionConfig(BaseOperationConfig):
    """Settings for the basic Attention operation.

    Attributes:
        block_q: Query block size (default: 128)
        block_k: Key block size (default: 128)
        num_warps: Number of warps for Triton kernels (default: 4)
        num_stages: Number of pipeline stages (default: 2)
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    # Tiling parameters.
    block_q: int = 128
    block_k: int = 128

    # Kernel launch parameters.
    num_warps: int = 4
    num_stages: int = 2

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class GroupedMatmulConfig(BaseOperationConfig):
    """Configuration for Grouped Matrix Multiplication operation.

    Args:
        block_m: M dimension block size (default: 128)
        block_n: N dimension block size (default: 128)
        block_k: K dimension block size (default: 128)
        num_warps: Number of warps for Triton kernels (default: 4)
        num_stages: Number of pipeline stages (default: 2)
        bypass_xla_tiling: Boolean flag (default: False).
            NOTE(review): its effect is not visible in this module; confirm
            against the kernel implementation before relying on it.
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    # Tiling parameters.
    block_m: int = 128
    block_n: int = 128
    # The docstring previously claimed 64 for block_k; the field default is 128.
    block_k: int = 128

    # Kernel launch parameters.
    num_warps: int = 4
    num_stages: int = 2

    bypass_xla_tiling: bool = False

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class MeanPoolingConfig(BaseOperationConfig):
    """Settings for the Mean Pooling operation.

    Attributes:
        block_size: Block size for pooling (default: 64)
        num_warps: Number of warps for Triton kernels (default: 4)
        num_stages: Number of pipeline stages (default: 1)
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    # Tiling parameter.
    block_size: int = 64

    # Kernel launch parameters.
    num_warps: int = 4
    num_stages: int = 1

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class RaggedDecodeAttentionConfig(BaseOperationConfig):
    """Configuration for Ragged Decode Attention operation.

    Args:
        fwd_params: Forward kernel parameters (default: None). A ``dict`` is
            converted to ``FwdParams`` after init.
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    fwd_params: FwdParams | None = None

    def __post_init__(self):
        """Promote a dict-valued ``fwd_params`` to ``FwdParams``."""
        if isinstance(self.fwd_params, dict):
            self.fwd_params = FwdParams(**self.fwd_params)

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class RaggedPageAttentionv2Config(BaseOperationConfig):
    """Settings for the Ragged Page Attention (v2) operation.

    Attributes:
        num_kv_pages_per_block: Number of KV pages to process per compute
            block (default: None for auto)
        num_queries_per_block: Number of queries to process per compute
            block (default: None for auto)
        num_warps: Number of warps for Triton kernels (default: 4)
        num_stages: Number of pipeline stages (default: 1)
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    # Per-block work sizes.
    num_kv_pages_per_block: int | None = None
    num_queries_per_block: int | None = None

    # Kernel launch parameters.
    num_warps: int = 4
    num_stages: int = 1

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class RaggedPageAttentionv3Config(BaseOperationConfig):
    """Settings for the Ragged Page Attention (v3) operation.

    Attributes:
        chunk_prefill_size: Optional chunk prefill size (default: None)
        num_kv_pages_per_block: Number of KV pages to process per compute
            block (default: None for auto)
        num_queries_per_block: Number of queries to process per compute
            block (default: None for auto)
        num_warps: Number of warps for Triton kernels (default: 4)
        num_stages: Number of pipeline stages (default: 1)
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    chunk_prefill_size: int | None = None

    # Per-block work sizes.
    num_kv_pages_per_block: int | None = None
    num_queries_per_block: int | None = None

    # Kernel launch parameters.
    num_warps: int = 4
    num_stages: int = 1

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class GLAttentionConfig(BaseOperationConfig):
    """Settings for the Gated Linear Attention operation.

    Attributes:
        block_q: Query block size (default: 64)
        block_k: Key block size (default: 64)
        block_d: Head dimension block size (default: 64)
        num_warps: Number of warps for Triton kernels (default: 4)
        num_stages: Number of pipeline stages (default: 1)
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    # Tiling parameters.
    block_q: int = 64
    block_k: int = 64
    block_d: int = 64

    # Kernel launch parameters.
    num_warps: int = 4
    num_stages: int = 1

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class LightningAttentionConfig(BaseOperationConfig):
    """Settings for the Lightning Attention operation.

    Attributes:
        block_q: Query block size (default: 64)
        block_k: Key block size (default: 64)
        block_d: Head dimension block size (default: 64)
        num_warps: Number of warps for Triton kernels (default: 4)
        num_stages: Number of pipeline stages (default: 1)
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    # Tiling parameters.
    block_q: int = 64
    block_k: int = 64
    block_d: int = 64

    # Kernel launch parameters.
    num_warps: int = 4
    num_stages: int = 1

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class KernelDeltaAttentionConfig(BaseOperationConfig):
    """Settings for the Kernel Delta Attention (KDA) operation.

    Note:
        This operation currently uses an XLA implementation without tunable
        block sizes. The config exists primarily for platform/backend
        selection, so it adds no fields of its own.

    Attributes:
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class FlashMLAConfig(BaseOperationConfig):
    """Settings for the Flash Multi-head Latent Attention operation.

    Attributes:
        block_q: Query block size (default: 128)
        block_k: Key block size (default: 128)
        num_warps: Number of warps for Triton kernels (default: 4)
        num_stages: Number of pipeline stages (default: 2)
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    # Tiling parameters.
    block_q: int = 128
    block_k: int = 128

    # Kernel launch parameters.
    num_warps: int = 4
    num_stages: int = 2

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class ScaledDotProductAttentionConfig(BaseOperationConfig):
    """Settings for the Scaled Dot Product Attention operation.

    Note:
        This operation uses XLA primitives directly without tunable block
        sizes. The config exists primarily for platform/backend selection,
        so it adds no fields of its own.

    Attributes:
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class PrefillPageAttentionConfig(BaseOperationConfig):
    """Settings for the Prefill Page Attention operation.

    Attributes:
        num_warps: Number of warps for Triton kernels (default: 4)
        num_stages: Number of pipeline stages (default: 1)
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    # Kernel launch parameters.
    num_warps: int = 4
    num_stages: int = 1

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class StateSpaceV1Config(BaseOperationConfig):
    """Settings for the SSM1 (Mamba1-style) Selective State Space operation.

    Note:
        This operation uses XLA implementation primarily without tunable
        block sizes. The config exists primarily for platform/backend
        selection, so it adds no fields of its own.

    Attributes:
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn
@dataclass
class StateSpaceV2Config(BaseOperationConfig):
    """Settings for the SSM2 (Mamba2-style) Selective State Space operation.

    Attributes:
        n_groups: Number of groups for B, C parameters (default: 1)
        use_gated_rmsnorm: Whether to use gated RMSNorm for output
            (default: False)
        rmsnorm_eps: Epsilon for RMSNorm stability (default: 1e-5)
        platform: Target platform (triton/pallas/cuda/xla/auto)
        backend: Backend specification (default: "any")
    """

    n_groups: int = 1

    # Output normalization options.
    use_gated_rmsnorm: bool = False
    rmsnorm_eps: float = 1e-5

    # Re-declared here because @dataclass resets __hash__ to None on a class
    # that does not define it explicitly.
    __hash__ = hash_fn