# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""XLA implementation of SSM1 (Mamba1-style) selective state space.
This module provides a pure JAX/XLA implementation of the SSM1 algorithm
used in Mamba and FalconMamba models.
Key characteristics of SSM1:
- 2D A matrix: [intermediate_size, ssm_state_size]
- SSM state shape: [batch, intermediate_size, ssm_state_size]
- Separate dt_proj projection for time step
The algorithm:
Discretization:
dA = exp(A * dt)
dB = dt * B
Recurrence:
h_t = dA * h_{t-1} + dB * x_t
y_t = h_t @ C_t + D * x_t
"""
from collections.abc import Callable
from functools import partial
import jax
import jax.numpy as jnp
import jaxtyping
from beartype import beartype
from jaxtyping import Array, Float
from ..._registry import Backend, Platform, kernel_registry
from ._xla_impl_bwd import _ssm1_bwd
from ._xla_impl_fwd import _ssm1_fwd, _ssm1_single_step_fwd
@partial(jax.custom_vjp, nondiff_argnums=(7,))
def _ssm1_core(
    hidden_states: Float[Array, "batch seq_len intermediate_size"],
    A: Float[Array, "intermediate_size ssm_state_size"],
    B: Float[Array, "batch seq_len ssm_state_size"],
    C: Float[Array, "batch seq_len ssm_state_size"],
    D: Float[Array, "intermediate_size"],
    dt: Float[Array, "batch seq_len intermediate_size"],
    initial_state: Float[Array, "batch intermediate_size ssm_state_size"] | None = None,
    use_single_step: bool = False,
) -> tuple[
    Float[Array, "batch seq_len intermediate_size"],
    Float[Array, "batch intermediate_size ssm_state_size"],
]:
    """Core SSM1 computation with a custom VJP (primal path).

    This primal path deliberately performs the same computation as
    ``_ssm1_fwd_rule``: the same float32 all-zeros default for a missing
    ``initial_state`` and the same single-step gating. As a result, the primal
    function and the forward rule produce outputs with identical structure,
    shapes and dtypes, which ``jax.custom_vjp`` expects.

    Args:
        hidden_states: Input tensor [batch, seq_len, intermediate_size].
        A: A matrix (real form, typically negative) [intermediate_size, ssm_state_size].
        B: B parameter [batch, seq_len, ssm_state_size].
        C: C parameter [batch, seq_len, ssm_state_size].
        D: Skip connection [intermediate_size].
        dt: Time step (after softplus) [batch, seq_len, intermediate_size].
        initial_state: Initial hidden state [batch, intermediate_size, ssm_state_size].
            ``None`` is treated as a float32 all-zeros state.
        use_single_step: Static (non-differentiable, ``nondiff_argnums=(7,)``)
            flag. If True and ``seq_len == 1``, the single-step update is used
            instead of the full-sequence forward.

    Returns:
        Tuple of (output [batch, seq_len, intermediate_size],
        final_state [batch, intermediate_size, ssm_state_size]).
    """
    batch_size, seq_len, intermediate_size = hidden_states.shape
    if initial_state is None:
        # Mirror _ssm1_fwd_rule's default so primal and fwd-rule outputs agree.
        initial_state = jnp.zeros(
            (batch_size, intermediate_size, B.shape[-1]),
            dtype=jnp.float32,
        )

    if use_single_step and seq_len == 1:
        # Single-step mode: update the state with the only timestep.
        y, final_state = _ssm1_single_step_fwd(
            hidden_state=hidden_states[:, 0, :],
            A=A,
            B=B[:, 0, :],
            C=C[:, 0, :],
            D=D,
            dt=dt[:, 0, :],
            ssm_state=initial_state,
        )
        # Re-insert the seq_len axis removed by the [:, 0, :] indexing above.
        return y[:, None, :], final_state

    # Full-sequence mode; the per-step hidden states are only needed by the
    # backward pass, so they are discarded here.
    output, _, final_state = _ssm1_fwd(
        hidden_states=hidden_states,
        A=A,
        B=B,
        C=C,
        D=D,
        dt=dt,
        initial_state=initial_state,
    )
    return output, final_state
def _ssm1_fwd_rule(
    hidden_states: Float[Array, "batch seq_len intermediate_size"],
    A: Float[Array, "intermediate_size ssm_state_size"],
    B: Float[Array, "batch seq_len ssm_state_size"],
    C: Float[Array, "batch seq_len ssm_state_size"],
    D: Float[Array, "intermediate_size"],
    dt: Float[Array, "batch seq_len intermediate_size"],
    initial_state: Float[Array, "batch intermediate_size ssm_state_size"] | None,
    use_single_step: bool,
) -> tuple[
    tuple[Float[Array, "batch seq_len intermediate_size"], Float[Array, "batch intermediate_size ssm_state_size"]],
    tuple,
]:
    """Forward rule for ``_ssm1_core``: compute outputs and save residuals.

    A missing ``initial_state`` is replaced by a float32 all-zeros state for
    the computation. The *original* value (possibly ``None``) is stored in the
    residuals so that ``_ssm1_bwd_rule`` can return a cotangent with the same
    pytree structure as the primal arguments.

    Returns:
        ``((output, final_state), residuals)`` with residuals
        ``(hidden_states, A, B, C, D, dt, all_hidden_states, initial_state)``.
    """
    batch_size, seq_len, intermediate_size = hidden_states.shape
    ssm_state_size = B.shape[-1]
    state = initial_state
    if state is None:
        state = jnp.zeros(
            (batch_size, intermediate_size, ssm_state_size),
            dtype=jnp.float32,
        )
    if use_single_step and seq_len == 1:
        # Single step: no full history of hidden states is produced.
        y, final_state = _ssm1_single_step_fwd(
            hidden_state=hidden_states[:, 0, :],
            A=A,
            B=B[:, 0, :],
            C=C[:, 0, :],
            D=D,
            dt=dt[:, 0, :],
            ssm_state=state,
        )
        # With one timestep, the "all hidden states" history is just the
        # final state with a length-1 seq axis.
        all_hidden_states = final_state[:, None, :, :]
        output = y[:, None, :]
    else:
        output, all_hidden_states, final_state = _ssm1_fwd(
            hidden_states=hidden_states,
            A=A,
            B=B,
            C=C,
            D=D,
            dt=dt,
            initial_state=state,
        )
    # Keep the caller's initial_state (None or array), not the zero-filled
    # stand-in, so the bwd rule knows whether a cotangent is expected.
    residuals = (hidden_states, A, B, C, D, dt, all_hidden_states, initial_state)
    return (output, final_state), residuals


def _ssm1_bwd_rule(
    use_single_step: bool,
    residuals: tuple,
    grads: tuple,
) -> tuple:
    """Backward rule for ``_ssm1_core``, delegating to ``_ssm1_bwd``.

    Args:
        use_single_step: Static flag (``nondiff_argnums``); unused here.
        residuals: Tuple saved by ``_ssm1_fwd_rule``.
        grads: Cotangents ``(do, dfinal_state)`` for ``(output, final_state)``.

    Returns:
        Cotangents for ``(hidden_states, A, B, C, D, dt, initial_state)``.
        The ``initial_state`` entry is ``None`` when the primal call received
        ``initial_state=None``; ``jax.custom_vjp`` requires the bwd output to
        mirror the primal argument pytree structure.
    """
    hidden_states, A, B, C, D, dt, all_hidden_states, initial_state = residuals
    do, dfinal_state = grads
    if initial_state is None:
        # Rebuild the same zero state the forward rule used.
        batch_size, _, intermediate_size = hidden_states.shape
        state = jnp.zeros(
            (batch_size, intermediate_size, B.shape[-1]),
            dtype=jnp.float32,
        )
    else:
        state = initial_state
    dx, dA, dB, dC, dD, ddt, d_state = _ssm1_bwd(
        hidden_states=hidden_states,
        A=A,
        B=B,
        C=C,
        D=D,
        dt=dt,
        all_hidden_states=all_hidden_states,
        initial_state=state,
        do=do,
        dfinal_state=dfinal_state,
    )
    d_initial_state = None if initial_state is None else d_state
    return (dx, dA, dB, dC, dD, ddt, d_initial_state)


_ssm1_core.defvjp(_ssm1_fwd_rule, _ssm1_bwd_rule)
@kernel_registry.register("state_space_v1", Platform.XLA, Backend.ANY)
@kernel_registry.register("ssm1", Platform.XLA, Backend.ANY)
@kernel_registry.register("mamba1", Platform.XLA, Backend.ANY)
@jaxtyping.jaxtyped(typechecker=beartype)
def state_space_v1(
    hidden_states: Float[Array, "batch seq_len intermediate_size"],
    A: Float[Array, "intermediate_size ssm_state_size"],
    B: Float[Array, "batch seq_len ssm_state_size"],
    C: Float[Array, "batch seq_len ssm_state_size"],
    D: Float[Array, "intermediate_size"],
    dt: Float[Array, "batch seq_len intermediate_size"],
    gate: Float[Array, "batch seq_len intermediate_size"] | None = None,
    initial_state: Float[Array, "batch intermediate_size ssm_state_size"] | None = None,
    conv_state: Float[Array, "batch intermediate_size d_conv"] | None = None,
    act_fn: Callable[[Array], Array] | None = None,
) -> tuple[
    Float[Array, "batch seq_len intermediate_size"],
    Float[Array, "batch intermediate_size ssm_state_size"],
    Float[Array, "batch intermediate_size d_conv"] | None,
]:
    """SSM1 (Mamba1-style) selective state space using XLA backend.

    Implements the original Mamba architecture with O(N) complexity.
    Processes tokens sequentially, maintaining a hidden state that
    accumulates information through discretized state transitions.

    The core algorithm:
        Discretization:
            dA = exp(A * dt)
            dB = dt * B
        Recurrence:
            h_t = dA * h_{t-1} + dB * x_t
            y_t = sum(h_t * C_t) + D * x_t
        Output gating (if gate provided):
            y = y * act_fn(gate)

    Args:
        hidden_states: Input tensor after convolution and activation
            Shape: [batch, seq_len, intermediate_size]
        A: A matrix in real form (typically negative for stability)
            Shape: [intermediate_size, ssm_state_size]
        B: B parameter from input projection
            Shape: [batch, seq_len, ssm_state_size]
        C: C parameter from input projection
            Shape: [batch, seq_len, ssm_state_size]
        D: Skip connection parameter
            Shape: [intermediate_size]
        dt: Time step after softplus activation
            Shape: [batch, seq_len, intermediate_size]
        gate: Optional gating tensor for output modulation
            Shape: [batch, seq_len, intermediate_size]
        initial_state: Optional initial SSM state for continuation.
            When omitted, an all-zeros state is used.
            Shape: [batch, intermediate_size, ssm_state_size]
        conv_state: Optional convolution state for caching (returned unchanged)
            Shape: [batch, intermediate_size, d_conv]
        act_fn: Optional activation function for gating (e.g., jax.nn.silu).
            If gate is provided but act_fn is None, defaults to jax.nn.silu.

    Returns:
        Tuple of:
            - output: SSM output [batch, seq_len, intermediate_size],
              cast to ``hidden_states.dtype``
            - ssm_state: Final hidden state [batch, intermediate_size, ssm_state_size],
              cast to ``hidden_states.dtype``
            - conv_state: The ``conv_state`` argument, returned unchanged

    Examples:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from jax import random
        >>>
        >>> # Basic usage
        >>> batch, seq_len, d, n = 2, 64, 512, 16
        >>> hidden_states = random.normal(random.PRNGKey(0), (batch, seq_len, d))
        >>> A = -random.uniform(random.PRNGKey(1), (d, n))  # negative for stability
        >>> B = random.normal(random.PRNGKey(2), (batch, seq_len, n))
        >>> C = random.normal(random.PRNGKey(3), (batch, seq_len, n))
        >>> D = random.normal(random.PRNGKey(4), (d,))
        >>> dt = jax.nn.softplus(random.normal(random.PRNGKey(5), (batch, seq_len, d)))
        >>>
        >>> output, ssm_state, conv_state = state_space_v1(hidden_states, A, B, C, D, dt)
        >>> output.shape
        (2, 64, 512)
        >>> ssm_state.shape
        (2, 512, 16)
        >>>
        >>> # With gating
        >>> gate = random.normal(random.PRNGKey(6), (batch, seq_len, d))
        >>> output, ssm_state, _ = state_space_v1(
        ...     hidden_states, A, B, C, D, dt,
        ...     gate=gate, act_fn=jax.nn.silu,
        ... )
    """
    _, seq_len, _ = hidden_states.shape
    dtype = hidden_states.dtype

    # The single-step kernel updates an existing state, so it is only chosen
    # when exactly one token is processed and a prior state was supplied.
    use_single_step = seq_len == 1 and initial_state is not None

    output, ssm_state = _ssm1_core(
        hidden_states=hidden_states,
        A=A,
        B=B,
        C=C,
        D=D,
        dt=dt,
        initial_state=initial_state,
        use_single_step=use_single_step,
    )

    # Optional output gating; SiLU is the default gate activation.
    if gate is not None:
        if act_fn is None:
            act_fn = jax.nn.silu
        output = output * act_fn(gate)

    return output.astype(dtype), ssm_state.astype(dtype), conv_state