ejkernel.kernels._xla.state_space_v2._interface
XLA implementation of SSM2 (Mamba2-style) selective state space.
This module provides a pure JAX/XLA implementation of the SSM2 algorithm used in Mamba2 and FalconH1 models.
Key characteristics of SSM2:
- 1D A vector of shape [num_heads] (one scalar per head)
- SSM state shape: [batch, num_heads, head_dim, ssm_state_size]
- B, C with n_groups grouping
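- The algorithm:
- Discretization (A is a per-head scalar):
  $dA = \exp(A \cdot dt)$, $dB = dt \cdot B$
- Recurrence (per head):
  $dBx = dt \cdot (x \otimes B)$ (outer product, shape [head_dim, ssm_state_size])
  $h_t = dA \cdot h_{t-1} + dBx$
  $y_t = h_t C_t + x_t \cdot D$ (contraction over ssm_state_size, i.e. einsum(h_t, C_t))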
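Because A is a single scalar per head, unrolling the recurrence shows that the decay applied between two positions depends only on the accumulated time step. The derivation below is standard algebra in the notation above, writing $dA_k = \exp(A \cdot dt_k)$ for the decay at step k and $h_0$ for the initial state; it illustrates the math and is not a description of how this XLA kernel evaluates it.

```latex
h_t = \Big(\prod_{k=1}^{t} dA_k\Big) h_0
      + \sum_{s=1}^{t} \Big(\prod_{k=s+1}^{t} dA_k\Big)\, dBx_s,
\qquad
\prod_{k=s+1}^{t} dA_k = \exp\Big(A \sum_{k=s+1}^{t} dt_k\Big),
\qquad
y_t = h_t C_t + x_t \cdot D
```

With A < 0 and dt > 0, every factor lies in (0, 1), so contributions from older tokens shrink monotonically.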
- ejkernel.kernels._xla.state_space_v2._interface.state_space_v2(x: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], A: Float[jaxlib._jax.Array, 'num_heads'], B: Float[jaxlib._jax.Array, 'batch seq_len n_groups ssm_state_size'], C: Float[jaxlib._jax.Array, 'batch seq_len n_groups ssm_state_size'], D: Float[jaxlib._jax.Array, 'num_heads'], dt: Float[jaxlib._jax.Array, 'batch seq_len num_heads'], gate: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'] | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim ssm_state_size'] | None = None, conv_state: jaxtyping.Float[jaxlib._jax.Array, 'batch conv_dim d_conv'] | None = None, n_groups: int = 1, act_fn: collections.abc.Callable[[jax.jaxlib._jax.Array], jax.jaxlib._jax.Array] | None = None, use_gated_rmsnorm: bool = False, rmsnorm_eps: float = 1e-05, precision: jax._src.lax.lax.Precision | None = None) -> tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim ssm_state_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch conv_dim d_conv'] | None]
SSM2 (Mamba2-style) selective state space using XLA backend.
Implements the Mamba2 recurrence in time linear in the sequence length (O(seq_len)). Tokens are processed sequentially, with a per-head scalar decay.
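Since tokens are processed one at a time, a single token's update can be written as a standalone step, which is also the shape of the computation during autoregressive decoding. The function `ssm2_decode_step` below is a hypothetical sketch, not part of ejkernel; it assumes B_t and C_t have already been repeated from n_groups to num_heads.

```python
import jax.numpy as jnp


def ssm2_decode_step(h_prev, x_t, A, B_t, C_t, D, dt_t):
    """Hypothetical single-token SSM2 update (B_t/C_t already expanded to num_heads).

    h_prev:   [batch, num_heads, head_dim, ssm_state_size]
    x_t:      [batch, num_heads, head_dim]
    A, D:     [num_heads]
    B_t, C_t: [batch, num_heads, ssm_state_size]
    dt_t:     [batch, num_heads]
    Returns (y_t [batch, num_heads, head_dim], h_t with the shape of h_prev).
    """
    # Per-head scalar decay, broadcast over head_dim and ssm_state_size.
    dA = jnp.exp(A * dt_t)[..., None, None]
    # Outer product of x_t and B_t, scaled by dt_t: [batch, num_heads, head_dim, ssm_state_size].
    dBx = dt_t[..., None, None] * x_t[..., :, None] * B_t[..., None, :]
    h_t = dA * h_prev + dBx
    # Contract the state with C_t over ssm_state_size, then add the D skip connection.
    y_t = jnp.einsum("bhpn,bhn->bhp", h_t, C_t) + x_t * D[:, None]
    return y_t, h_t
```

The state carried between calls has a fixed size, independent of how many tokens have been seen.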
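- The core algorithm:
- Discretization (A is a per-head scalar):
  $dA = \exp(A \cdot dt)$, $dB = dt \cdot B$
- Recurrence (per head):
  $dBx = dt \cdot (x \otimes B)$ (outer product, shape [head_dim, ssm_state_size])
  $h_t = dA \cdot h_{t-1} + dBx$
  $y_t = h_t C_t + x_t \cdot D$ (contraction over ssm_state_size, i.e. einsum(h_t, C_t))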
- Output gating (if gate is provided):
  - If use_gated_rmsnorm: y = rmsnorm(y) * act_fn(gate)
  - Otherwise: y = y * act_fn(gate)
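The steps above can be sketched end to end with `jax.lax.scan`, which carries the state across the sequence axis. `ssm2_reference` is a hypothetical helper, not the ejkernel implementation. Two details are assumptions: heads are assigned to B/C groups in contiguous blocks of num_heads // n_groups, and the RMSNorm is weight-free and normalizes over the whole last axis.

```python
import jax
import jax.numpy as jnp


def ssm2_reference(x, A, B, C, D, dt, gate=None, act_fn=jax.nn.silu,
                   use_gated_rmsnorm=False, eps=1e-5, initial_state=None):
    """Hypothetical sequential sketch of the SSM2 recurrence plus output gating.

    x: [batch, seq_len, num_heads, head_dim], A/D: [num_heads],
    B/C: [batch, seq_len, n_groups, ssm_state_size], dt: [batch, seq_len, num_heads],
    gate: [batch, seq_len, num_heads * head_dim] or None.
    """
    batch, seq_len, num_heads, head_dim = x.shape
    n_groups, state_size = B.shape[2], B.shape[3]
    # Assumed layout: each group is shared by num_heads // n_groups consecutive heads.
    B = jnp.repeat(B, num_heads // n_groups, axis=2)
    C = jnp.repeat(C, num_heads // n_groups, axis=2)
    if initial_state is None:
        initial_state = jnp.zeros((batch, num_heads, head_dim, state_size), x.dtype)

    def step(state, inputs):
        x_t, B_t, C_t, dt_t = inputs                    # [b,h,p], [b,h,n], [b,h,n], [b,h]
        dA = jnp.exp(A * dt_t)[..., None, None]         # per-head scalar decay
        dBx = dt_t[..., None, None] * x_t[..., :, None] * B_t[..., None, :]
        state = dA * state + dBx                        # [b, h, p, n]
        y_t = jnp.einsum("bhpn,bhn->bhp", state, C_t) + x_t * D[:, None]
        return state, y_t

    # lax.scan iterates over the leading axis, so move seq_len to the front.
    xs = tuple(jnp.moveaxis(a, 1, 0) for a in (x, B, C, dt))
    h_final, ys = jax.lax.scan(step, initial_state, xs)
    y = jnp.moveaxis(ys, 0, 1).reshape(batch, seq_len, num_heads * head_dim)

    if gate is not None:
        if use_gated_rmsnorm:
            # Weight-free RMSNorm over the last axis (assumption).
            y = y * jax.lax.rsqrt(jnp.mean(y * y, axis=-1, keepdims=True) + eps)
        y = y * act_fn(gate)
    return y, h_final
```

A sequential scan keeps memory per step constant; a production kernel may use a different (for example chunked) evaluation order with the same result.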
- Parameters
x – Input tensor. Shape: [batch, seq_len, num_heads, head_dim]
A – A vector in real form. Typically negative, so that exp(A * dt) < 1 for positive dt, which keeps the recurrence stable. Shape: [num_heads]
B – B parameter, grouped by n_groups. Shape: [batch, seq_len, n_groups, ssm_state_size]
C – C parameter, grouped by n_groups. Shape: [batch, seq_len, n_groups, ssm_state_size]
D – Skip-connection parameter. Shape: [num_heads]
dt – Time step after softplus activation. Shape: [batch, seq_len, num_heads]
gate – Optional gating tensor for output modulation. Shape: [batch, seq_len, intermediate_size]
initial_state – Optional initial SSM state, used to continue from a previous call. Shape: [batch, num_heads, head_dim, ssm_state_size]
conv_state – Optional convolution state for caching; it is passed through unchanged. Shape: [batch, conv_dim, d_conv]
n_groups – Number of groups for B and C. B and C are repeated along the group axis to num_heads.
act_fn – Optional activation function for gating (e.g., jax.nn.silu). If gate is provided but act_fn is None, defaults to jax.nn.silu.
use_gated_rmsnorm – If True, apply RMSNorm before gating
rmsnorm_eps – Epsilon for RMSNorm stability
precision – JAX precision for matmul operations
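To illustrate the n_groups parameter, the hypothetical helper below expands B (or C) from n_groups to num_heads with `jnp.repeat`. It assumes heads are assigned to groups in contiguous blocks of num_heads // n_groups and that num_heads is divisible by n_groups; neither detail is confirmed by this page.

```python
import jax.numpy as jnp


def expand_groups(B, num_heads):
    """Hypothetical helper: repeat B or C from n_groups to num_heads along axis 2.

    [batch, seq_len, n_groups, ssm_state_size] -> [batch, seq_len, num_heads, ssm_state_size]
    """
    n_groups = B.shape[2]
    if num_heads % n_groups != 0:
        raise ValueError("num_heads must be divisible by n_groups")
    # Heads 0..k-1 share group 0, heads k..2k-1 share group 1, etc. (assumed layout).
    return jnp.repeat(B, num_heads // n_groups, axis=2)
```

With n_groups=1, every head sees the same B and C at each position.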
- Returns
output: SSM output [batch, seq_len, intermediate_size], where intermediate_size = num_heads * head_dim
ssm_state: Final hidden state [batch, num_heads, head_dim, ssm_state_size]
conv_state: The conv_state argument, passed through unchanged (for caching)
- Return type
tuple of (output, ssm_state, conv_state)
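The returned ssm_state is what makes continuation possible: feeding it back as initial_state for the next segment should reproduce a single pass over the full sequence. The sketch below checks this property for a hypothetical sequential implementation (`ssm2_scan` and `process_in_chunks` are illustrative names, not ejkernel APIs, and B/C are assumed already expanded to num_heads); that state_space_v2 behaves identically is an assumption based on the description of initial_state.

```python
import jax
import jax.numpy as jnp


def ssm2_scan(x, A, B, C, D, dt, initial_state=None):
    """Hypothetical sequential SSM2 sketch; B/C are assumed already expanded to num_heads."""
    b, l, h, p = x.shape
    n = B.shape[-1]
    h0 = jnp.zeros((b, h, p, n), x.dtype) if initial_state is None else initial_state

    def step(state, inp):
        x_t, B_t, C_t, dt_t = inp
        dA = jnp.exp(A * dt_t)[..., None, None]  # per-head scalar decay
        state = dA * state + dt_t[..., None, None] * x_t[..., None] * B_t[..., None, :]
        y_t = jnp.einsum("bhpn,bhn->bhp", state, C_t) + x_t * D[:, None]
        return state, y_t

    xs = tuple(jnp.moveaxis(a, 1, 0) for a in (x, B, C, dt))  # seq_len first for scan
    state, ys = jax.lax.scan(step, h0, xs)
    return jnp.moveaxis(ys, 0, 1).reshape(b, l, h * p), state


def process_in_chunks(x, A, B, C, D, dt, chunk):
    """Run the sequence in chunks, passing each final state on as the next initial_state."""
    state, outputs = None, []
    for start in range(0, x.shape[1], chunk):
        sl = slice(start, start + chunk)
        y, state = ssm2_scan(x[:, sl], A, B[:, sl], C[:, sl], D, dt[:, sl], initial_state=state)
        outputs.append(y)
    return jnp.concatenate(outputs, axis=1), state
```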
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import random
>>> from ejkernel.kernels._xla.state_space_v2._interface import state_space_v2
>>>
>>> # Basic usage
>>> batch, seq_len, num_heads, head_dim, n_groups, ssm_state_size = 2, 64, 8, 64, 1, 16
>>> x = random.normal(random.PRNGKey(0), (batch, seq_len, num_heads, head_dim))
>>> A = -random.uniform(random.PRNGKey(1), (num_heads,))  # negative for stability
>>> B = random.normal(random.PRNGKey(2), (batch, seq_len, n_groups, ssm_state_size))
>>> C = random.normal(random.PRNGKey(3), (batch, seq_len, n_groups, ssm_state_size))
>>> D = random.normal(random.PRNGKey(4), (num_heads,))
>>> dt = jax.nn.softplus(random.normal(random.PRNGKey(5), (batch, seq_len, num_heads)))
>>>
>>> output, ssm_state, conv_state = state_space_v2(x, A, B, C, D, dt, n_groups=n_groups)
>>> output.shape  # (batch, seq_len, num_heads * head_dim)
(2, 64, 512)
>>> ssm_state.shape
(2, 8, 64, 16)
>>>
>>> # With gated RMSNorm
>>> gate = random.normal(random.PRNGKey(6), (batch, seq_len, num_heads * head_dim))
>>> output, ssm_state, _ = state_space_v2(
...     x, A, B, C, D, dt,
...     gate=gate, n_groups=n_groups,
...     use_gated_rmsnorm=True, act_fn=jax.nn.silu,
... )