ejkernel.modules.operations.state_space_v2

SSM2 (Mamba2-style) Selective State Space operation module.

This module provides the StateSpaceV2 operation, which implements the selective state space model used in the Mamba2 and FalconH1 architectures.

Key characteristics of SSM2:
  • 1D A vector: [num_heads] (one scalar per head)

  • SSM state shape: [batch, num_heads, head_dim, ssm_state_size]

  • B and C shared across heads via n_groups grouping

  • Output gating via gated RMSNorm or simple multiplication
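The n_groups grouping can be illustrated with a small NumPy sketch. It assumes contiguous grouping, where head h reads group h // (num_heads // n_groups), which corresponds to a repeat-interleave of B and C along the group axis. The helper name expand_groups is hypothetical and not part of ejkernel.

```python
import numpy as np

def expand_groups(bc: np.ndarray, num_heads: int) -> np.ndarray:
    """Repeat B or C from [batch, seq_len, n_groups, N] to [batch, seq_len, num_heads, N].

    Assumes contiguous grouping: head h uses group h // (num_heads // n_groups).
    """
    n_groups = bc.shape[2]
    if num_heads % n_groups != 0:
        raise ValueError("num_heads must be divisible by n_groups")
    # np.repeat along axis 2 turns [g0, g1] into [g0, g0, g1, g1] for 2 repeats.
    return np.repeat(bc, num_heads // n_groups, axis=2)

# Toy sizes: 4 heads sharing 2 groups of B.
B = np.arange(2 * 3 * 2 * 5, dtype=np.float32).reshape(2, 3, 2, 5)
B_heads = expand_groups(B, num_heads=4)
print(B_heads.shape)  # (2, 3, 4, 5)
```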

The algorithm:
Discretization:

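$$dA_t = \exp(A \cdot dt_t), \qquad dB_t = dt_t \cdot B_t$$

where A holds one scalar per head, so dA_t is a single decay factor per head and timestep.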
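A minimal NumPy sketch of the discretization step, assuming B has already been expanded from n_groups to num_heads. The function name discretize is illustrative only.

```python
import numpy as np

def discretize(A, dt, B_heads):
    """Discretize per-head A and B (sketch).

    A:       [num_heads]                      per-head scalar, typically negative
    dt:      [batch, seq_len, num_heads]      step size after softplus
    B_heads: [batch, seq_len, num_heads, N]   B already expanded from n_groups
    """
    dA = np.exp(A[None, None, :] * dt)   # [batch, seq_len, num_heads]
    dB = dt[..., None] * B_heads         # [batch, seq_len, num_heads, N]
    return dA, dB

A = -np.array([0.5, 1.0])
dt = np.full((1, 3, 2), 0.1)
B_heads = np.ones((1, 3, 2, 4))
dA, dB = discretize(A, dt, B_heads)
```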

Recurrence (per head):

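$$dBx_t = dt_t \, (x_t \otimes B_t)$$

$$h_t = dA_t \, h_{t-1} + dBx_t$$

$$y_t = h_t \, C_t + D \, x_t$$

Here x_t has length head_dim and B_t, C_t have length ssm_state_size, so the outer product dBx_t and the state h_t are [head_dim, ssm_state_size] matrices; the product h_t C_t (an einsum over ssm_state_size) gives y_t with length head_dim.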
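One recurrence step for all heads at once can be written as a NumPy sketch of the equations above. The helper ssm2_step is hypothetical, and B_t and C_t are assumed to be already expanded to num_heads. The second call feeds a zero input to show the per-head decay acting on the state.

```python
import numpy as np

def ssm2_step(h, x_t, dt_t, A, B_t, C_t, D):
    """One SSM2 recurrence step (sketch).

    h:        [batch, num_heads, head_dim, N]  previous state h_{t-1}
    x_t:      [batch, num_heads, head_dim]
    dt_t:     [batch, num_heads]
    A, D:     [num_heads]
    B_t, C_t: [batch, num_heads, N]            already expanded from n_groups
    """
    dA = np.exp(A[None, :] * dt_t)                                   # [batch, num_heads]
    # dt-scaled outer product of x_t and B_t: [batch, num_heads, head_dim, N]
    dBx = dt_t[..., None, None] * x_t[..., :, None] * B_t[..., None, :]
    h_new = dA[..., None, None] * h + dBx
    # Contract the state axis with C_t, then add the D skip connection.
    y_t = np.einsum("bhpn,bhn->bhp", h_new, C_t) + D[None, :, None] * x_t
    return y_t, h_new

# Scalar-sized example: batch = heads = head_dim = N = 1.
h = np.zeros((1, 1, 1, 1))
x_t = np.full((1, 1, 1), 2.0)
dt_t = np.full((1, 1), 0.5)
A = np.array([-1.0])
B_t = np.full((1, 1, 1), 3.0)
C_t = np.full((1, 1, 1), 4.0)
D = np.array([1.0])

y_t, h_new = ssm2_step(h, x_t, dt_t, A, B_t, C_t, D)
# Zero input: the state only decays by exp(A * dt).
y_2, h_2 = ssm2_step(h_new, np.zeros_like(x_t), dt_t, A, B_t, C_t, D)
```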

Features:
  • O(N) complexity in sequence length through sequential processing

  • Per-head scalar decay (1D A vector)

  • n_groups support for B, C grouping

  • Gated RMSNorm output normalization

  • Conv state passthrough for caching
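Putting the steps together, a sequential reference loop makes the O(N) cost explicit: one state update per token. This is an illustrative NumPy sketch without gating or conv handling, assuming contiguous n_groups grouping and intermediate_size = num_heads * head_dim; the name ssm2_reference is hypothetical and this is not ejkernel's XLA kernel.

```python
import numpy as np

def ssm2_reference(x, A, B, C, D, dt, initial_state=None):
    """Sequential SSM2 reference (sketch, no gating).

    x: [batch, L, H, P]; A, D: [H]; B, C: [batch, L, G, N]; dt: [batch, L, H]
    Returns y: [batch, L, H * P] and the final state [batch, H, P, N].
    """
    batch, L, H, P = x.shape
    G, N = B.shape[2], B.shape[3]
    # Expand grouped B/C to one copy per head (contiguous grouping assumed).
    Bh = np.repeat(B, H // G, axis=2)
    Ch = np.repeat(C, H // G, axis=2)
    h = np.zeros((batch, H, P, N), dtype=x.dtype) if initial_state is None else initial_state
    ys = []
    for t in range(L):  # one update per token, so cost is linear in L
        dA = np.exp(A[None, :] * dt[:, t])                                  # [batch, H]
        dBx = dt[:, t, :, None, None] * x[:, t, :, :, None] * Bh[:, t, :, None, :]
        h = dA[..., None, None] * h + dBx                                   # [batch, H, P, N]
        y = np.einsum("bhpn,bhn->bhp", h, Ch[:, t]) + D[None, :, None] * x[:, t]
        ys.append(y)
    y = np.stack(ys, axis=1)                                                # [batch, L, H, P]
    return y.reshape(batch, L, H * P), h

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 4, 2))
A = -np.exp(rng.standard_normal(4))
B = rng.standard_normal((2, 5, 2, 3))
C = rng.standard_normal((2, 5, 2, 3))
D = rng.standard_normal(4)
dt = np.logaddexp(0.0, rng.standard_normal((2, 5, 4)))  # softplus keeps dt > 0
y, h_final = ssm2_reference(x, A, B, C, D, dt)
```

In JAX, the Python loop would typically become a jax.lax.scan over the time axis so the recurrence compiles into a single loop.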


class ejkernel.modules.operations.state_space_v2.StateSpaceV2

Bases: Kernel[StateSpaceV2Config, Array]

SSM2 (Mamba2-style) Selective State Space operation.

Implements the Mamba2 recurrence with O(N) complexity in sequence length, processing tokens sequentially with per-head scalar decay.

Features:
  • 1D A vector [num_heads] (per-head scalar)

  • n_groups for B, C grouping

  • Hidden state shape [batch, num_heads, head_dim, ssm_state_size]

  • Gated RMSNorm output normalization option

  • Multi-platform dispatch (XLA is the primary implementation)

The state update mechanism:

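$$dA_t = \exp(A \cdot dt_t), \qquad dBx_t = dt_t \, (x_t \otimes B_t)$$

$$h_t = dA_t \, h_{t-1} + dBx_t, \qquad y_t = h_t \, C_t + D \, x_t$$

where A is a per-head scalar, x_t ⊗ B_t is an outer product, and h_t C_t contracts the ssm_state_size axis (the einsum).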
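The decay factor explains why A is typically negative: with dt > 0 from softplus, a negative A gives 0 < dA < 1, so each head's state shrinks geometrically when no new input arrives, while a positive A makes it grow. The values below are arbitrary; the A = -exp(A_log) parameterization in the comment is a convention of common Mamba2 implementations, assumed here rather than required by this operation.

```python
import numpy as np

# Per-head decay for a few negative A values (e.g. A = -exp(A_log) in many
# Mamba2 implementations) and positive step sizes (dt after softplus).
A = -np.array([0.1, 1.0, 8.0])
dt = np.array([0.01, 0.5, 2.0])
dA = np.exp(A * dt)

# With zero input, h_t = dA * h_{t-1}, so after 10 steps h = dA**10 * h0.
h0 = np.ones(3)
h10 = dA ** 10 * h0

# A positive A would give a growth factor above 1 (unstable recurrence).
A_bad = np.array([0.5])
growth = np.exp(A_bad * 0.5)
```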

candidate_cfgs(inv: Invocation[StateSpaceV2Config, Array])

Generate candidate configurations for autotuning.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of candidate configurations to benchmark during autotuning

Note

SSM2 relies primarily on the XLA implementation, so the candidate set is minimal.
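Conceptually, autotuning benchmarks each candidate configuration and keeps the fastest. The sketch below is a generic, hypothetical loop (pick_fastest is not an ejkernel function); with a single candidate, as is typical for SSM2, it simply returns that candidate.

```python
import time
from typing import Any, Callable, Optional, Sequence

def pick_fastest(candidates: Sequence[Any], run: Callable[[Any], Any], repeats: int = 3) -> Optional[Any]:
    """Hypothetical autotuning loop: time each candidate and keep the fastest."""
    best_cfg, best_time = None, float("inf")
    for cfg in candidates:
        run(cfg)  # warm-up call (in JAX this would trigger compilation)
        start = time.perf_counter()
        for _ in range(repeats):
            # With JAX, `run` should call .block_until_ready() on its outputs
            # so the timing includes device execution.
            run(cfg)
        elapsed = (time.perf_counter() - start) / repeats
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    return best_cfg
```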

get_impl(cfg: StateSpaceV2Config)

Get kernel implementation from registry.

Parameters

cfg – Configuration specifying platform and backend

Returns

Callable kernel implementation for SSM2

Raises

ValueError – If no matching implementation is found
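A registry lookup with the documented ValueError behavior might look like the following. The registry layout, the (name, platform, backend) key, and the register decorator are hypothetical, chosen only to illustrate the lookup-or-raise pattern.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry: (operation, platform, backend) -> implementation.
_REGISTRY: Dict[Tuple[str, str, str], Callable] = {}

def register(name: str, platform: str, backend: str):
    """Decorator that records an implementation under its key."""
    def decorator(fn: Callable) -> Callable:
        _REGISTRY[(name, platform, backend)] = fn
        return fn
    return decorator

def get_impl(name: str, platform: str, backend: str) -> Callable:
    """Return the matching implementation or raise ValueError."""
    try:
        return _REGISTRY[(name, platform, backend)]
    except KeyError:
        raise ValueError(
            f"no implementation of {name!r} for platform={platform!r}, backend={backend!r}"
        ) from None

@register("state_space_v2", "xla", "any")
def _ssm2_xla(*args, **kwargs):
    return "xla impl"
```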

heuristic_cfg(inv: Invocation[StateSpaceV2Config, Array]) → StateSpaceV2Config

Provide default configuration.

Parameters

inv – Invocation object containing arguments and metadata

Returns

Default configuration for SSM2 operation

run(x: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], A: Float[jaxlib._jax.Array, 'num_heads'], B: Float[jaxlib._jax.Array, 'batch seq_len n_groups ssm_state_size'], C: Float[jaxlib._jax.Array, 'batch seq_len n_groups ssm_state_size'], D: Float[jaxlib._jax.Array, 'num_heads'], dt: Float[jaxlib._jax.Array, 'batch seq_len num_heads'], gate: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'] | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim ssm_state_size'] | None = None, conv_state: jaxtyping.Float[jaxlib._jax.Array, 'batch conv_dim d_conv'] | None = None, n_groups: int = 1, act_fn: collections.abc.Callable[[jax.jaxlib._jax.Array], jax.jaxlib._jax.Array] | None = None, use_gated_rmsnorm: bool = False, rmsnorm_eps: float = 1e-05, precision: jax._src.lax.lax.Precision | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: StateSpaceV2Config) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim ssm_state_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch conv_dim d_conv'] | None]

Execute SSM2 selective state space operation.

Parameters
  • x – Input tensor. Shape: [batch, seq_len, num_heads, head_dim]

  • A – A vector in real form (typically negative for stability). Shape: [num_heads]

  • B – B parameter (with n_groups grouping). Shape: [batch, seq_len, n_groups, ssm_state_size]

  • C – C parameter (with n_groups grouping). Shape: [batch, seq_len, n_groups, ssm_state_size]

  • D – Skip connection parameter. Shape: [num_heads]

  • dt – Time step after softplus activation. Shape: [batch, seq_len, num_heads]

  • gate – Optional gating tensor for output modulation. Shape: [batch, seq_len, intermediate_size]

  • initial_state – Optional initial SSM state for continuation. Shape: [batch, num_heads, head_dim, ssm_state_size]

  • conv_state – Optional convolution state for caching (passed through unchanged). Shape: [batch, conv_dim, d_conv]

  • n_groups – Number of groups for B and C (each group is repeated so that B/C cover all num_heads heads)

  • act_fn – Optional activation function for gating (e.g., jax.nn.silu)

  • use_gated_rmsnorm – If True, apply RMSNorm before gating

  • rmsnorm_eps – Epsilon for RMSNorm stability

  • precision – JAX precision for matmul operations

  • platform – Optional platform override

  • cfg – Kernel configuration object
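The two output-gating modes can be sketched as below, assuming y and gate both have shape [batch, seq_len, intermediate_size], normalization over the last axis, and no learnable RMSNorm scale (none appears in the signature). The order in the gated branch follows the use_gated_rmsnorm description; some Mamba2 implementations instead multiply by the gate first and then normalize, so check the kernel source if the order matters. All helper names are hypothetical.

```python
import numpy as np

def silu(z):
    """SiLU / swish: z * sigmoid(z)."""
    return z / (1.0 + np.exp(-z))

def rmsnorm(y, eps=1e-5):
    # Normalize over the last (intermediate_size) axis; no learnable scale here.
    return y / np.sqrt(np.mean(y * y, axis=-1, keepdims=True) + eps)

def apply_gate(y, gate, act_fn=silu, use_gated_rmsnorm=False, eps=1e-5):
    """Sketch of the two gating modes for y, gate: [batch, seq_len, intermediate_size]."""
    if use_gated_rmsnorm:
        # Order taken from the parameter description ("RMSNorm before gating").
        return rmsnorm(y, eps) * act_fn(gate)
    # Simple multiplicative gating.
    return y * act_fn(gate)

y = np.array([[[3.0, 4.0]]])
identity_gate = lambda z: np.ones_like(z)  # makes the gate a no-op for inspection
normed = apply_gate(y, np.zeros_like(y), act_fn=identity_gate, use_gated_rmsnorm=True)
```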

Returns

  • output: SSM output [batch, seq_len, intermediate_size]

  • ssm_state: Final hidden state [batch, num_heads, head_dim, ssm_state_size]

  • conv_state: The conv_state argument, returned unchanged for caching (None if not provided)

Return type

Tuple of (output, ssm_state, conv_state)
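Because the recurrence is strictly sequential, passing the returned ssm_state back as initial_state should let a long sequence be processed in chunks with the same result as a single pass. The NumPy sketch below checks this property on a toy recurrence; the helper scan is hypothetical and sets n_groups equal to num_heads so no repeat is needed.

```python
import numpy as np

def scan(x, A, B, C, D, dt, h0):
    """Hypothetical sequential SSM2 with n_groups == num_heads (no repeat needed)."""
    h, ys = h0, []
    for t in range(x.shape[1]):
        dA = np.exp(A[None, :] * dt[:, t])                                   # [batch, H]
        dBx = dt[:, t, :, None, None] * x[:, t, :, :, None] * B[:, t, :, None, :]
        h = dA[..., None, None] * h + dBx                                    # [batch, H, P, N]
        ys.append(np.einsum("bhpn,bhn->bhp", h, C[:, t]) + D[None, :, None] * x[:, t])
    return np.stack(ys, axis=1), h

rng = np.random.default_rng(0)
b, L, H, P, N = 1, 6, 2, 3, 4
x = rng.standard_normal((b, L, H, P))
A = -np.exp(rng.standard_normal(H))
B = rng.standard_normal((b, L, H, N))
C = rng.standard_normal((b, L, H, N))
D = rng.standard_normal(H)
dt = np.logaddexp(0.0, rng.standard_normal((b, L, H)))  # softplus keeps dt > 0
h0 = np.zeros((b, H, P, N))

# Single pass over all tokens.
y_full, h_full = scan(x, A, B, C, D, dt, h0)

# Two chunks: the final state of the first chunk seeds the second.
k = 4
y1, h1 = scan(x[:, :k], A, B[:, :k], C[:, :k], D, dt[:, :k], h0)
y2, h2 = scan(x[:, k:], A, B[:, k:], C[:, k:], D, dt[:, k:], h1)
y_chunked = np.concatenate([y1, y2], axis=1)
```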

ejkernel.modules.operations.state_space_v2.state_space_v2(x: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], A: Float[jaxlib._jax.Array, 'num_heads'], B: Float[jaxlib._jax.Array, 'batch seq_len n_groups ssm_state_size'], C: Float[jaxlib._jax.Array, 'batch seq_len n_groups ssm_state_size'], D: Float[jaxlib._jax.Array, 'num_heads'], dt: Float[jaxlib._jax.Array, 'batch seq_len num_heads'], /, gate: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'] | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim ssm_state_size'] | None = None, conv_state: jaxtyping.Float[jaxlib._jax.Array, 'batch conv_dim d_conv'] | None = None, *, n_groups: int = 1, act_fn: collections.abc.Callable[[jax.jaxlib._jax.Array], jax.jaxlib._jax.Array] | None = None, use_gated_rmsnorm: bool = False, rmsnorm_eps: float = 1e-05, precision: jax._src.lax.lax.Precision | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.StateSpaceV2Config | None = None) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim ssm_state_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch conv_dim d_conv'] | None]

Execute SSM2 (Mamba2-style) selective state space with automatic optimization.

SSM2 processes a sequence step by step with per-head scalar decay, carrying a hidden state across timesteps, so its cost grows linearly (O(N)) with sequence length.

Parameters
  • x – Input tensor. Shape: [batch, seq_len, num_heads, head_dim]

  • A – A vector in real form (typically negative for stability). Shape: [num_heads]

  • B – B parameter (with n_groups grouping). Shape: [batch, seq_len, n_groups, ssm_state_size]

  • C – C parameter (with n_groups grouping). Shape: [batch, seq_len, n_groups, ssm_state_size]

  • D – Skip connection parameter. Shape: [num_heads]

  • dt – Time step after softplus activation. Shape: [batch, seq_len, num_heads]

  • gate – Optional gating tensor for output modulation. Shape: [batch, seq_len, intermediate_size]

  • initial_state – Optional initial SSM state for continuation. Shape: [batch, num_heads, head_dim, ssm_state_size]

  • conv_state – Optional convolution state for caching (passed through unchanged). Shape: [batch, conv_dim, d_conv]

  • n_groups – Number of groups for B and C (each group is repeated so that B/C cover all num_heads heads)

  • act_fn – Optional activation function for gating (e.g., jax.nn.silu). If gate is provided but act_fn is None, defaults to jax.nn.silu.

  • use_gated_rmsnorm – If True, apply RMSNorm before gating

  • rmsnorm_eps – Epsilon for RMSNorm stability

  • precision – JAX precision for matmul operations

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Optional kernel configuration
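A hypothetical pre-flight shape check that mirrors the parameter list can catch layout mistakes before calling the kernel. It assumes intermediate_size equals num_heads * head_dim and that num_heads must be divisible by n_groups; neither is stated explicitly above, so treat both as assumptions.

```python
import numpy as np

def check_ssm2_shapes(x, A, B, C, D, dt, gate=None, initial_state=None, n_groups=1):
    """Hypothetical validator for the documented SSM2 input shapes."""
    batch, seq_len, num_heads, head_dim = x.shape
    N = B.shape[-1]  # ssm_state_size
    expected = {
        "A": (A.shape, (num_heads,)),
        "B": (B.shape, (batch, seq_len, n_groups, N)),
        "C": (C.shape, (batch, seq_len, n_groups, N)),
        "D": (D.shape, (num_heads,)),
        "dt": (dt.shape, (batch, seq_len, num_heads)),
    }
    if gate is not None:
        # Assumption: intermediate_size == num_heads * head_dim.
        expected["gate"] = (gate.shape, (batch, seq_len, num_heads * head_dim))
    if initial_state is not None:
        expected["initial_state"] = (initial_state.shape, (batch, num_heads, head_dim, N))
    for name, (got, want) in expected.items():
        if tuple(got) != want:
            raise ValueError(f"{name}: expected shape {want}, got {tuple(got)}")
    # Assumption: groups are repeated evenly across heads.
    if num_heads % n_groups != 0:
        raise ValueError("num_heads must be divisible by n_groups")
```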

Returns

  • output: SSM output [batch, seq_len, intermediate_size]

  • ssm_state: Final hidden state [batch, num_heads, head_dim, ssm_state_size]

  • conv_state: The conv_state argument, returned unchanged for caching (None if not provided)

Return type

Tuple of (output, ssm_state, conv_state)

Example

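>>> import jax
>>> from ejkernel.modules.operations.state_space_v2 import state_space_v2
>>>
>>> # Basic usage (x, A, B, C, D, dt prepared with the shapes listed above)
>>> output, ssm_state, _ = state_space_v2(x, A, B, C, D, dt, n_groups=1)
>>>
>>> # With gated RMSNorm
>>> output, ssm_state, _ = state_space_v2(
...     x, A, B, C, D, dt,
...     gate=gate, n_groups=1,
...     use_gated_rmsnorm=True, act_fn=jax.nn.silu,
... )
>>>
>>> # Single-step inference with cached state: feed one token (a length-1
>>> # slice) together with the previous ssm_state. conv_state comes from the
>>> # caller's cache and is returned unchanged.
>>> output, new_state, conv_state = state_space_v2(
...     x[:, :1, :, :],
...     A, B[:, :1, :, :], C[:, :1, :, :], D, dt[:, :1, :],
...     initial_state=ssm_state, conv_state=conv_state,
...     n_groups=1,
... )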
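The examples above assume x, A, B, C, D, dt and gate already exist. A hypothetical helper for building random inputs with the documented shapes (NumPy, softplus-positive dt, negative A) is sketched below; the values are illustrative only.

```python
import numpy as np

def make_ssm2_inputs(batch=2, seq_len=16, num_heads=4, head_dim=8,
                     n_groups=1, state_size=16, seed=0):
    """Hypothetical helper: random float32 inputs with the documented SSM2 shapes."""
    rng = np.random.default_rng(seed)
    f32 = np.float32
    x = rng.standard_normal((batch, seq_len, num_heads, head_dim)).astype(f32)
    A = -np.exp(rng.standard_normal(num_heads)).astype(f32)        # negative per-head decay
    B = rng.standard_normal((batch, seq_len, n_groups, state_size)).astype(f32)
    C = rng.standard_normal((batch, seq_len, n_groups, state_size)).astype(f32)
    D = np.ones(num_heads, dtype=f32)
    # softplus(z) = log(1 + exp(z)), computed stably with logaddexp.
    dt = np.logaddexp(0.0, rng.standard_normal((batch, seq_len, num_heads))).astype(f32)
    # Assumption: intermediate_size == num_heads * head_dim.
    gate = rng.standard_normal((batch, seq_len, num_heads * head_dim)).astype(f32)
    return x, A, B, C, D, dt, gate

x, A, B, C, D, dt, gate = make_ssm2_inputs()
```

Convert the arrays with jax.numpy.asarray before passing them to state_space_v2; the gate width num_heads * head_dim is an assumption inferred from the output shape.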