ejkernel.kernels._xla.state_space_v2._interface


XLA implementation of the SSM2 (Mamba2-style) selective state space model.

This module provides a pure JAX/XLA implementation of the SSM2 algorithm used in Mamba2 and FalconH1 models.

Key characteristics of SSM2:

  • A is a 1D vector of shape [num_heads] (one scalar per head)

  • The SSM state has shape [batch, num_heads, head_dim, ssm_state_size]

  • B and C are shared across heads in n_groups groups
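To make the n_groups grouping concrete, here is a minimal sketch of expanding B or C from groups to heads. The helper name expand_groups is hypothetical, and the contiguous head-to-group mapping (each group serves a consecutive block of num_heads // n_groups heads) is an assumption borrowed from common Mamba2 implementations, not something this page states.

```python
import jax.numpy as jnp


def expand_groups(bc: jnp.ndarray, num_heads: int) -> jnp.ndarray:
    """Hypothetical helper: [batch, seq_len, n_groups, N] -> [batch, seq_len, num_heads, N].

    Assumes a contiguous mapping: heads [0, num_heads // n_groups) read group 0,
    the next block of heads reads group 1, and so on.
    """
    n_groups = bc.shape[2]
    if num_heads % n_groups != 0:
        raise ValueError(f"num_heads={num_heads} is not a multiple of n_groups={n_groups}")
    # jnp.repeat duplicates each group consecutively along the group axis.
    return jnp.repeat(bc, num_heads // n_groups, axis=2)


# Example: 2 groups shared by 8 heads, so 4 heads per group.
B = jnp.arange(2 * 3 * 2 * 4, dtype=jnp.float32).reshape(2, 3, 2, 4)
B_heads = expand_groups(B, num_heads=8)
print(B_heads.shape)  # (2, 3, 8, 4)
```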

The algorithm:
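Discretization (A is a per-head scalar):

$$dA_t = \exp(A \, dt_t), \qquad dB_t = dt_t \, B_t$$

Recurrence (per head, with h_t of shape [head_dim, ssm_state_size], x_t of shape [head_dim], and B_t, C_t of shape [ssm_state_size]):

$$dBx_t = dt_t \, (x_t \otimes B_t) \quad \text{(outer product form)}$$

$$h_t = dA_t \, h_{t-1} + dBx_t$$

$$y_t = h_t \, C_t + D \, x_t$$

where h_t C_t is the einsum that contracts over ssm_state_size.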
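The recurrence above can be written as a short sequential reference with jax.lax.scan. This is a sketch for understanding and testing, not the ejkernel source: the name ssm2_reference_scan is hypothetical, gating is omitted, a zero initial state is assumed when none is given, and B/C are expanded to heads with the contiguous group mapping assumed earlier.

```python
import jax
import jax.numpy as jnp


def ssm2_reference_scan(x, A, B, C, D, dt, initial_state=None):
    """Hypothetical sequential reference for the SSM2 recurrence (no gating).

    x: [batch, seq_len, num_heads, head_dim]   A, D: [num_heads]
    B, C: [batch, seq_len, n_groups, N]        dt: [batch, seq_len, num_heads]
    Returns y: [batch, seq_len, num_heads, head_dim] and h: [batch, num_heads, head_dim, N].
    """
    batch, _, num_heads, head_dim = x.shape
    n_groups, state_size = B.shape[2], B.shape[3]
    # Assumed contiguous group-to-head mapping.
    B = jnp.repeat(B, num_heads // n_groups, axis=2)
    C = jnp.repeat(C, num_heads // n_groups, axis=2)
    if initial_state is None:
        initial_state = jnp.zeros((batch, num_heads, head_dim, state_size), x.dtype)

    def step(h, inputs):
        x_t, B_t, C_t, dt_t = inputs                          # [b,h,p], [b,h,n], [b,h,n], [b,h]
        dA = jnp.exp(A * dt_t)                                # per-head scalar decay, [b,h]
        dBx = jnp.einsum("bh,bhn,bhp->bhpn", dt_t, B_t, x_t)  # dt * (x outer B)
        h = dA[..., None, None] * h + dBx
        y_t = jnp.einsum("bhpn,bhn->bhp", h, C_t) + x_t * D[:, None]
        return h, y_t

    # lax.scan iterates over the leading axis, so move seq_len to the front.
    xs = tuple(jnp.moveaxis(a, 1, 0) for a in (x, B, C, dt))
    h_final, ys = jax.lax.scan(step, initial_state, xs)
    return jnp.moveaxis(ys, 0, 1), h_final


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (2, 5, 4, 3))
B = jax.random.normal(k2, (2, 5, 1, 6))
C = jax.random.normal(k3, (2, 5, 1, 6))
dt = jax.nn.softplus(jax.random.normal(k4, (2, 5, 4)))
A = -jnp.linspace(0.5, 2.0, 4)
D = jnp.ones(4)
y, h = ssm2_reference_scan(x, A, B, C, D, dt)
```

Because the loop carries only h, memory and time grow linearly with seq_len, which matches the sequential O(N) description of this backend.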

ejkernel.kernels._xla.state_space_v2._interface.state_space_v2(
    x: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'],
    A: Float[jaxlib._jax.Array, 'num_heads'],
    B: Float[jaxlib._jax.Array, 'batch seq_len n_groups ssm_state_size'],
    C: Float[jaxlib._jax.Array, 'batch seq_len n_groups ssm_state_size'],
    D: Float[jaxlib._jax.Array, 'num_heads'],
    dt: Float[jaxlib._jax.Array, 'batch seq_len num_heads'],
    gate: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'] | None = None,
    initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim ssm_state_size'] | None = None,
    conv_state: jaxtyping.Float[jaxlib._jax.Array, 'batch conv_dim d_conv'] | None = None,
    n_groups: int = 1,
    act_fn: collections.abc.Callable[[jax.jaxlib._jax.Array], jax.jaxlib._jax.Array] | None = None,
    use_gated_rmsnorm: bool = False,
    rmsnorm_eps: float = 1e-05,
    precision: jax._src.lax.lax.Precision | None = None,
) -> tuple[
    jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'],
    jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim ssm_state_size'],
    jaxtyping.Float[jaxlib._jax.Array, 'batch conv_dim d_conv'] | None,
]

SSM2 (Mamba2-style) selective state space model using the XLA backend.

Implements the Mamba2 selective state space recurrence in time linear in the sequence length (O(N) for N tokens). Tokens are processed sequentially, and each head's state decays by a scalar factor at every step.

The core algorithm:
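Discretization (A is a per-head scalar):

$$dA_t = \exp(A \, dt_t), \qquad dB_t = dt_t \, B_t$$

Recurrence (per head, with h_t of shape [head_dim, ssm_state_size], x_t of shape [head_dim], and B_t, C_t of shape [ssm_state_size]):

$$dBx_t = dt_t \, (x_t \otimes B_t) \quad \text{(outer product form)}$$

$$h_t = dA_t \, h_{t-1} + dBx_t$$

$$y_t = h_t \, C_t + D \, x_t$$

where h_t C_t is the einsum that contracts over ssm_state_size.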
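A short worked case shows what "per-head scalar decay" means in practice. With zero input (x_t = 0) the recurrence reduces to pure decay, and because each head's decay factor is a scalar, the per-step factors multiply into one exponential of the accumulated time step. Since dt is positive (it comes out of a softplus), a negative A makes each head's state shrink monotonically toward zero:

$$h_t = \Big(\prod_{s=1}^{t} \exp(A \, dt_s)\Big) \, h_0 = \exp\Big(A \sum_{s=1}^{t} dt_s\Big) \, h_0$$

For example, A = -1 with dt_s = 0.5 for four steps gives h_4 = exp(-2) h_0 ≈ 0.135 h_0.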

Output gating (applied only when gate is provided):

  • If use_gated_rmsnorm is True: y = rmsnorm(y) * act_fn(gate)

  • Otherwise: y = y * act_fn(gate)
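The gating rule above can be sketched as a small function. This is an illustration under stated assumptions, not ejkernel's code: apply_output_gate is a hypothetical name, the RMSNorm is assumed to normalize over the last (intermediate_size) axis with no learnable scale, and the SSM output is assumed to be flattened from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, num_heads * head_dim] before gating.

```python
import jax
import jax.numpy as jnp


def apply_output_gate(y, gate, act_fn=None, use_gated_rmsnorm=False, rmsnorm_eps=1e-5):
    """Hypothetical gating step; y and gate are [batch, seq_len, intermediate_size]."""
    if act_fn is None:
        act_fn = jax.nn.silu  # documented default when a gate is given
    if use_gated_rmsnorm:
        # Assumption: RMSNorm over the last axis, no learnable scale.
        mean_sq = jnp.mean(jnp.square(y), axis=-1, keepdims=True)
        y = y * jax.lax.rsqrt(mean_sq + rmsnorm_eps)
    return y * act_fn(gate)


# Assumed flattening of the per-head output into intermediate_size = num_heads * head_dim.
y_heads = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8, 16))
y = y_heads.reshape(2, 4, 8 * 16)
gate = jax.random.normal(jax.random.PRNGKey(1), y.shape)
out = apply_output_gate(y, gate, use_gated_rmsnorm=True)
```

With an activation that always returns 1, the normalized branch yields rows with RMS close to 1; with a zero gate, SiLU(0) = 0 and the output vanishes entirely.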

Parameters
  • x – Input tensor. Shape: [batch, seq_len, num_heads, head_dim].

  • A – Per-head state parameter in real form, one scalar per head. It is typically negative so that exp(A * dt) < 1 and the state decays. Shape: [num_heads].

  • B – Input-to-state parameter, shared across heads in n_groups groups. Shape: [batch, seq_len, n_groups, ssm_state_size].

  • C – State-to-output parameter, shared across heads in n_groups groups. Shape: [batch, seq_len, n_groups, ssm_state_size].

  • D – Per-head skip-connection weight. Shape: [num_heads].

  • dt – Time step, already passed through softplus (and therefore positive). Shape: [batch, seq_len, num_heads].

  • gate – Optional gating tensor that modulates the output. Shape: [batch, seq_len, intermediate_size], where intermediate_size = num_heads * head_dim.

  • initial_state – Optional initial SSM state, used to continue from a previous call. Shape: [batch, num_heads, head_dim, ssm_state_size].

  • conv_state – Optional convolution cache, passed through unchanged so callers can keep it alongside the SSM state. Shape: [batch, conv_dim, d_conv].

  • n_groups – Number of groups for B and C. Each group is repeated across heads so that B and C cover all num_heads heads.

  • act_fn – Optional activation for gating (e.g., jax.nn.silu). If gate is provided and act_fn is None, jax.nn.silu is used.

  • use_gated_rmsnorm – If True, apply RMSNorm to the SSM output before gating. Only takes effect when gate is provided.

  • rmsnorm_eps – Epsilon added in the RMSNorm for numerical stability.

  • precision – JAX precision (jax.lax.Precision) for matmul operations.
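The shape contract in the list above can be checked before calling the kernel. The helper below is hypothetical (ejkernel may already validate inputs in its own way); it only restates the documented shapes, plus the assumptions that intermediate_size = num_heads * head_dim and that num_heads is a multiple of n_groups.

```python
import jax.numpy as jnp


def check_ssm2_inputs(x, A, B, C, D, dt, n_groups=1, gate=None, initial_state=None):
    """Hypothetical validator for the documented state_space_v2 input shapes."""
    batch, seq_len, num_heads, head_dim = x.shape
    state_size = B.shape[-1]
    if num_heads % n_groups != 0:
        raise ValueError(f"num_heads={num_heads} must be a multiple of n_groups={n_groups}")
    expected = {
        "A": (A.shape, (num_heads,)),
        "B": (B.shape, (batch, seq_len, n_groups, state_size)),
        "C": (C.shape, (batch, seq_len, n_groups, state_size)),
        "D": (D.shape, (num_heads,)),
        "dt": (dt.shape, (batch, seq_len, num_heads)),
    }
    if gate is not None:
        # Assumption: intermediate_size = num_heads * head_dim.
        expected["gate"] = (gate.shape, (batch, seq_len, num_heads * head_dim))
    if initial_state is not None:
        expected["initial_state"] = (initial_state.shape, (batch, num_heads, head_dim, state_size))
    for name, (got, want) in expected.items():
        if tuple(got) != want:
            raise ValueError(f"{name}: expected shape {want}, got {tuple(got)}")
    return dict(batch=batch, seq_len=seq_len, num_heads=num_heads,
                head_dim=head_dim, ssm_state_size=state_size)


dims = check_ssm2_inputs(
    x=jnp.zeros((2, 64, 8, 64)), A=-jnp.ones(8),
    B=jnp.zeros((2, 64, 1, 16)), C=jnp.zeros((2, 64, 1, 16)),
    D=jnp.ones(8), dt=jnp.ones((2, 64, 8)),
)
print(dims["num_heads"] * dims["head_dim"])  # 512, the intermediate_size of the output
```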

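Returns

A tuple (output, ssm_state, conv_state):

  • output: SSM output. Shape: [batch, seq_len, intermediate_size].

  • ssm_state: Final hidden state after the last token; pass it as initial_state to continue. Shape: [batch, num_heads, head_dim, ssm_state_size].

  • conv_state: The conv_state argument, passed through unchanged for caching (None if it was not provided).

Return type

tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim ssm_state_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch conv_dim d_conv'] | None]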
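Returning the final state is what allows chunked or streaming processing: feeding it back as initial_state should continue the recurrence exactly. The sketch below checks that property on a hypothetical pure-JAX reference scan (ssm2_scan, not part of ejkernel, with B and C assumed already expanded to one entry per head); it does not call state_space_v2 itself.

```python
import jax
import jax.numpy as jnp


def ssm2_scan(x, A, B, C, D, dt, h0):
    """Hypothetical reference scan; B and C are per-head: [batch, seq_len, num_heads, N]."""
    def step(h, inp):
        x_t, B_t, C_t, dt_t = inp
        # h_t = exp(A * dt) * h_{t-1} + dt * (x_t outer B_t)
        h = jnp.exp(A * dt_t)[..., None, None] * h + jnp.einsum("bh,bhn,bhp->bhpn", dt_t, B_t, x_t)
        return h, jnp.einsum("bhpn,bhn->bhp", h, C_t) + x_t * D[:, None]

    xs = tuple(jnp.moveaxis(a, 1, 0) for a in (x, B, C, dt))  # scan over seq_len
    h, ys = jax.lax.scan(step, h0, xs)
    return jnp.moveaxis(ys, 0, 1), h


keys = jax.random.split(jax.random.PRNGKey(0), 4)
b, L, H, P, N = 1, 8, 2, 3, 4
x = jax.random.normal(keys[0], (b, L, H, P))
B = jax.random.normal(keys[1], (b, L, H, N))
C = jax.random.normal(keys[2], (b, L, H, N))
dt = jax.nn.softplus(jax.random.normal(keys[3], (b, L, H)))
A, D = -jnp.ones(H), jnp.ones(H)
h0 = jnp.zeros((b, H, P, N))

# One pass over all 8 tokens ...
y_full, h_full = ssm2_scan(x, A, B, C, D, dt, h0)
# ... versus two chunks of 4, feeding the first chunk's final state into the second.
y1, h1 = ssm2_scan(x[:, :4], A, B[:, :4], C[:, :4], D, dt[:, :4], h0)
y2, h2 = ssm2_scan(x[:, 4:], A, B[:, 4:], C[:, 4:], D, dt[:, 4:], h1)
print(jnp.allclose(jnp.concatenate([y1, y2], axis=1), y_full, atol=1e-5))  # True
```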

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> from jax import random
>>> from ejkernel.kernels._xla.state_space_v2._interface import state_space_v2
>>>
>>> # Basic usage
>>> batch, seq_len, num_heads, head_dim, n_groups, ssm_state_size = 2, 64, 8, 64, 1, 16
>>> x = random.normal(random.PRNGKey(0), (batch, seq_len, num_heads, head_dim))
>>> A = -random.uniform(random.PRNGKey(1), (num_heads,))  # negative for stability
>>> B = random.normal(random.PRNGKey(2), (batch, seq_len, n_groups, ssm_state_size))
>>> C = random.normal(random.PRNGKey(3), (batch, seq_len, n_groups, ssm_state_size))
>>> D = random.normal(random.PRNGKey(4), (num_heads,))
>>> dt = jax.nn.softplus(random.normal(random.PRNGKey(5), (batch, seq_len, num_heads)))
>>>
>>> output, ssm_state, conv_state = state_space_v2(x, A, B, C, D, dt, n_groups=n_groups)
>>> output.shape  # intermediate_size = num_heads * head_dim = 512
(2, 64, 512)
>>> ssm_state.shape
(2, 8, 64, 16)
>>>
>>> # With gated RMSNorm
>>> gate = random.normal(random.PRNGKey(6), (batch, seq_len, num_heads * head_dim))
>>> output, ssm_state, _ = state_space_v2(
...     x, A, B, C, D, dt,
...     gate=gate, n_groups=n_groups,
...     use_gated_rmsnorm=True, act_fn=jax.nn.silu,
... )
>>> output.shape
(2, 64, 512)