# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SSM2 (Mamba2-style) Selective State Space operation module.

This module provides the StateSpaceV2 operation, implementing the Mamba2
selective state space model architecture used by Mamba2 and FalconH1.

Key characteristics of SSM2:
    - 1D A vector: [num_heads] (per-head scalar)
    - SSM state shape: [batch, num_heads, head_dim, ssm_state_size]
    - B, C with n_groups grouping
    - Output gating via gated RMSNorm or simple multiplication

The algorithm:
    Discretization:
        dA = exp(A * dt)    where A is per-head
        dB = dt * B

    Recurrence (per head):
        dBx = dt * B * x    (outer product form)
        h_t = dA * h_{t-1} + dBx
        y_t = einsum(h_t, C_t) + x_t * D

Features:
    - O(N) complexity through sequential processing
    - Per-head scalar decay (1D A vector)
    - n_groups support for B, C grouping
    - Gated RMSNorm output normalization
    - Conv state passthrough for caching

References:
    - Mamba2: https://arxiv.org/abs/2405.21060
    - FalconH1: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
"""
from __future__ import annotations
import os
import typing
from collections.abc import Callable
from typing import Literal
from jax import lax
from jaxtyping import Array, Float
from ejkernel.kernels._registry import Backend, kernel_registry
from ejkernel.ops import (
AutotunePolicy,
ConfigCache,
ConfigSelectorChain,
Executor,
Invocation,
Kernel,
Tuner,
)
from ejkernel.ops.config.persistent import PersistentCache
from ..base import detect_platform
from .configs import StateSpaceV2Config
class StateSpaceV2(Kernel[StateSpaceV2Config, Array]):
    """Kernel wrapper for the SSM2 (Mamba2-style) selective state space op.

    This class does not compute the recurrence itself. It resolves a concrete
    implementation from ``kernel_registry`` under the operation name
    ``"state_space_v2"`` and forwards the call arguments to it.

    The operation it dispatches to is described by this module as:

        dA  = exp(A * dt)          (A is a per-head scalar, shape [num_heads])
        dBx = dt * B * x           (outer product form)
        h_t = dA * h_{t-1} + dBx
        y_t = einsum(h_t, C_t) + x_t * D

    with B and C grouped by ``n_groups`` and a hidden state of shape
    [batch, num_heads, head_dim, ssm_state_size].
    """

    def __init__(self):
        """Register this kernel under the ``"state_space_v2"`` operation id."""
        super().__init__(op_id="state_space_v2")

    def get_impl(self, cfg: StateSpaceV2Config):
        """Look up the implementation matching ``cfg`` in the kernel registry.

        Args:
            cfg: Configuration whose ``platform`` and ``backend`` fields drive
                the lookup.

        Returns:
            Whatever ``kernel_registry.get`` returns for ``"state_space_v2"``
            on the resolved platform/backend (used as a callable by ``run``).

        Raises:
            Any error raised by ``detect_platform`` or ``kernel_registry.get``
            propagates unchanged. NOTE(review): upstream docs state a
            ``ValueError`` when nothing matches; confirm against the registry.
        """
        resolved_platform = detect_platform("state_space_v2", cfg.platform)
        return kernel_registry.get(
            "state_space_v2",
            platform=resolved_platform,
            backend=cfg.backend,
        )

    def run(
        self,
        x: Float[Array, "batch seq_len num_heads head_dim"],
        A: Float[Array, "num_heads"],
        B: Float[Array, "batch seq_len n_groups ssm_state_size"],
        C: Float[Array, "batch seq_len n_groups ssm_state_size"],
        D: Float[Array, "num_heads"],
        dt: Float[Array, "batch seq_len num_heads"],
        gate: Float[Array, "batch seq_len intermediate_size"] | None = None,
        initial_state: Float[Array, "batch num_heads head_dim ssm_state_size"] | None = None,
        conv_state: Float[Array, "batch conv_dim d_conv"] | None = None,
        n_groups: int = 1,
        act_fn: Callable[[Array], Array] | None = None,
        use_gated_rmsnorm: bool = False,
        rmsnorm_eps: float = 1e-5,
        precision: lax.Precision | None = None,
        platform: Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
        *,
        cfg: StateSpaceV2Config,
    ) -> tuple[
        Float[Array, "batch seq_len intermediate_size"],
        Float[Array, "batch num_heads head_dim ssm_state_size"],
        Float[Array, "batch conv_dim d_conv"] | None,
    ]:
        """Resolve an implementation and run the SSM2 operation with it.

        Args:
            x: Input tensor, [batch, seq_len, num_heads, head_dim].
            A: Per-head A vector, [num_heads].
            B: B parameter, [batch, seq_len, n_groups, ssm_state_size].
            C: C parameter, [batch, seq_len, n_groups, ssm_state_size].
            D: Skip connection parameter, [num_heads].
            dt: Time step, [batch, seq_len, num_heads].
            gate: Optional gating tensor, [batch, seq_len, intermediate_size].
            initial_state: Optional starting SSM state,
                [batch, num_heads, head_dim, ssm_state_size].
            conv_state: Optional convolution state, [batch, conv_dim, d_conv].
                It is handed to the implementation like every other input.
            n_groups: Number of groups for B and C.
            act_fn: Optional activation applied in the gating path.
            use_gated_rmsnorm: Whether the implementation should apply the
                gated RMSNorm variant.
            rmsnorm_eps: Epsilon forwarded for RMSNorm.
            precision: JAX precision forwarded to the implementation.
            platform: Optional platform override. When given, ``cfg`` is
                replaced (see below); it is never forwarded to the
                implementation itself.
            cfg: Kernel configuration used to pick the implementation.

        Returns:
            The implementation's result, annotated as the tuple
            ``(output, ssm_state, conv_state)``.
        """
        if platform is not None:
            # An explicit platform discards the incoming config: a new one is
            # built from the call arguments. Only the backend can survive from
            # the old config, and only when the override is not "xla".
            if platform == "xla":
                backend = Backend.ANY
            else:
                backend = cfg.backend
            cfg = StateSpaceV2Config(
                n_groups=n_groups,
                use_gated_rmsnorm=use_gated_rmsnorm,
                rmsnorm_eps=rmsnorm_eps,
                platform=platform,
                backend=backend,
            )

        kernel_fn = self.get_impl(cfg)

        tensor_inputs = dict(
            x=x,
            A=A,
            B=B,
            C=C,
            D=D,
            dt=dt,
            gate=gate,
            initial_state=initial_state,
            conv_state=conv_state,
        )
        # Options come from the call arguments, not from ``cfg``.
        options = dict(
            n_groups=n_groups,
            act_fn=act_fn,
            use_gated_rmsnorm=use_gated_rmsnorm,
            rmsnorm_eps=rmsnorm_eps,
            precision=precision,
        )
        return kernel_fn(**tensor_inputs, **options)

    def heuristic_cfg(self, inv: Invocation[StateSpaceV2Config, Array]) -> StateSpaceV2Config:
        """Return the fixed default configuration.

        Args:
            inv: Invocation describing the call. It is not inspected here; the
                same defaults are returned for every invocation.

        Returns:
            A config with ``n_groups=1``, ``use_gated_rmsnorm=False``,
            ``rmsnorm_eps=1e-5``, ``platform="auto"`` and ``backend="any"``.
        """
        defaults = dict(
            n_groups=1,
            use_gated_rmsnorm=False,
            rmsnorm_eps=1e-5,
            platform="auto",
            backend="any",
        )
        return StateSpaceV2Config(**defaults)

    def candidate_cfgs(self, inv: Invocation[StateSpaceV2Config, Array]):
        """List the configurations offered to the autotuner.

        Args:
            inv: Invocation describing the call. It is not inspected here.

        Returns:
            Two configs, in order: ``platform="auto"`` then ``platform="xla"``,
            both with ``backend="any"`` and all other fields left at the
            config's own defaults.
        """
        return [
            StateSpaceV2Config(platform=candidate_platform, backend="any")
            for candidate_platform in ("auto", "xla")
        ]
# Module-level executor used by `state_space_v2` to select a configuration and
# run the kernel. It is constructed once, when this module is imported, so the
# EJKERNEL_AUTOTUNE_POLICY environment variable is read at import time only.
_state_space_v2_executor: Executor[StateSpaceV2Config, Array] = Executor(
    ConfigSelectorChain(
        cache=ConfigCache(),
        policy=AutotunePolicy(
            allow_autotune=True,
            # Uses "autotune" when EJKERNEL_AUTOTUNE_POLICY is not set.
            cache_miss_fallback=os.getenv("EJKERNEL_AUTOTUNE_POLICY", "autotune"),
            # NOTE(review): presumably also checks the backward pass of each
            # candidate during tuning; confirm against AutotunePolicy.
            validate_backward=True,
        ),
        # NOTE(review): presumably 5 warmup runs and 100 timed iterations per
        # candidate; confirm against Tuner.
        tuner=Tuner(warmup=5, iters=100),
        # NOTE(review): "state_space_v2" looks like the namespace/key for the
        # on-disk config cache; confirm against PersistentCache.
        persistent=PersistentCache("state_space_v2"),
    )
)
def state_space_v2(
    x: Float[Array, "batch seq_len num_heads head_dim"],
    A: Float[Array, "num_heads"],
    B: Float[Array, "batch seq_len n_groups ssm_state_size"],
    C: Float[Array, "batch seq_len n_groups ssm_state_size"],
    D: Float[Array, "num_heads"],
    dt: Float[Array, "batch seq_len num_heads"],
    /,
    gate: Float[Array, "batch seq_len intermediate_size"] | None = None,
    initial_state: Float[Array, "batch num_heads head_dim ssm_state_size"] | None = None,
    conv_state: Float[Array, "batch conv_dim d_conv"] | None = None,
    *,
    n_groups: int = 1,
    act_fn: Callable[[Array], Array] | None = None,
    use_gated_rmsnorm: bool = False,
    rmsnorm_eps: float = 1e-5,
    precision: lax.Precision | None = None,
    platform: typing.Literal["triton", "pallas", "cuda", "xla", "auto"] | None = None,
    cfg: StateSpaceV2Config | None = None,
) -> tuple[
    Float[Array, "batch seq_len intermediate_size"],
    Float[Array, "batch num_heads head_dim ssm_state_size"],
    Float[Array, "batch conv_dim d_conv"] | None,
]:
    """Run the SSM2 (Mamba2-style) selective state space op via the executor.

    A fresh ``StateSpaceV2`` kernel is created and handed, together with all
    arguments, to the module-level ``_state_space_v2_executor``, which is
    responsible for choosing a configuration and invoking the kernel.

    The first six tensors (``x`` through ``dt``) are positional-only; every
    option from ``n_groups`` onward is keyword-only.

    Args:
        x: Input tensor.
            Shape: [batch, seq_len, num_heads, head_dim]
        A: A vector in real form (typically negative for stability).
            Shape: [num_heads]
        B: B parameter, grouped by ``n_groups``.
            Shape: [batch, seq_len, n_groups, ssm_state_size]
        C: C parameter, grouped by ``n_groups``.
            Shape: [batch, seq_len, n_groups, ssm_state_size]
        D: Skip connection parameter.
            Shape: [num_heads]
        dt: Time step after softplus activation.
            Shape: [batch, seq_len, num_heads]
        gate: Optional gating tensor for output modulation.
            Shape: [batch, seq_len, intermediate_size]
        initial_state: Optional initial SSM state for continuation.
            Shape: [batch, num_heads, head_dim, ssm_state_size]
        conv_state: Optional convolution state for caching.
            Shape: [batch, conv_dim, d_conv]
        n_groups: Number of groups for B and C.
        act_fn: Optional activation function for gating (e.g. ``jax.nn.silu``).
            NOTE(review): the implementation is documented to fall back to
            ``jax.nn.silu`` when ``gate`` is given and ``act_fn`` is None; this
            wrapper does not apply that default itself, so confirm in the kernel.
        use_gated_rmsnorm: If True, request RMSNorm before gating.
        rmsnorm_eps: Epsilon for RMSNorm stability.
        precision: JAX precision for matmul operations.
        platform: Platform override ("triton", "pallas", "cuda", "xla" or
            "auto"); None leaves the choice to the configuration.
        cfg: Optional kernel configuration, forwarded to the executor as
            ``_cfg``.

    Returns:
        Tuple of:
            - output: SSM output [batch, seq_len, intermediate_size]
            - ssm_state: final hidden state
              [batch, num_heads, head_dim, ssm_state_size]
            - conv_state: convolution state (or None)

    Example:
        >>> # Basic usage
        >>> output, ssm_state, _ = state_space_v2(x, A, B, C, D, dt, n_groups=1)
        >>>
        >>> # With gated RMSNorm
        >>> output, ssm_state, _ = state_space_v2(
        ...     x, A, B, C, D, dt,
        ...     gate=gate, n_groups=1,
        ...     use_gated_rmsnorm=True, act_fn=jax.nn.silu,
        ... )
        >>>
        >>> # Inference with cached state
        >>> output, new_state, conv_state = state_space_v2(
        ...     x[:, :1, :, :],
        ...     A, B[:, :1, :, :], C[:, :1, :, :], D, dt[:, :1, :],
        ...     initial_state=ssm_state, conv_state=conv_state,
        ...     n_groups=1,
        ... )
    """
    tensor_inputs = dict(
        x=x,
        A=A,
        B=B,
        C=C,
        D=D,
        dt=dt,
        gate=gate,
        initial_state=initial_state,
        conv_state=conv_state,
    )
    options = dict(
        n_groups=n_groups,
        act_fn=act_fn,
        use_gated_rmsnorm=use_gated_rmsnorm,
        rmsnorm_eps=rmsnorm_eps,
        precision=precision,
        platform=platform,
    )
    # The executor expects the user-supplied config under the `_cfg` keyword.
    return _state_space_v2_executor(
        StateSpaceV2(),
        **tensor_inputs,
        **options,
        _cfg=cfg,
    )