ejkernel.kernels._xla.state_space_v1._interface
XLA implementation of SSM1 (Mamba1-style) selective state space.
This module provides a pure JAX/XLA implementation of the SSM1 algorithm used in Mamba and FalconMamba models.
Key characteristics of SSM1:
- 2D A matrix: [intermediate_size, ssm_state_size]
- SSM state shape: [batch, intermediate_size, ssm_state_size]
- Separate dt_proj projection for the time step
- The algorithm:
- Discretization:
dA = exp(A * dt)
dB = dt * B
- Recurrence:
h_t = dA * h_{t-1} + dB * x_t
y_t = h_t @ C_t + D * x_t
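The discretization and recurrence above can be sketched as a naive `jax.lax.scan` over the sequence axis. This is a minimal reference written for illustration, not the module's actual implementation; the function name `ssm1_reference` and the `h0` argument are hypothetical, and the broadcasting is an assumption based on the shapes listed under "Key characteristics".

```python
import jax
import jax.numpy as jnp


def ssm1_reference(x, A, B, C, D, dt, h0=None):
    """Naive SSM1 scan (illustrative only, not the library kernel).

    Shapes: x, dt [b, l, d]; A [d, n]; B, C [b, l, n]; D [d]; h0 [b, d, n].
    """
    batch, _, d = x.shape
    n = A.shape[-1]
    if h0 is None:
        h0 = jnp.zeros((batch, d, n), dtype=x.dtype)

    def step(h, inputs):
        x_t, B_t, C_t, dt_t = inputs                   # [b, d], [b, n], [b, n], [b, d]
        dA = jnp.exp(A[None] * dt_t[..., None])        # dA = exp(A * dt)  -> [b, d, n]
        dB = dt_t[..., None] * B_t[:, None, :]         # dB = dt * B       -> [b, d, n]
        h = dA * h + dB * x_t[..., None]               # h_t = dA * h_{t-1} + dB * x_t
        y_t = jnp.einsum("bdn,bn->bd", h, C_t) + D * x_t  # y_t = h_t @ C_t + D * x_t
        return h, y_t

    # lax.scan iterates over the leading axis, so move seq_len to the front.
    xs = tuple(jnp.swapaxes(a, 0, 1) for a in (x, B, C, dt))
    h_final, ys = jax.lax.scan(step, h0, xs)
    return jnp.swapaxes(ys, 0, 1), h_final
```

The state `h` has shape [batch, intermediate_size, ssm_state_size] at every step, which matches the SSM state shape listed above.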
- ejkernel.kernels._xla.state_space_v1._interface.state_space_v1(hidden_states: Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], A: Float[jaxlib._jax.Array, 'intermediate_size ssm_state_size'], B: Float[jaxlib._jax.Array, 'batch seq_len ssm_state_size'], C: Float[jaxlib._jax.Array, 'batch seq_len ssm_state_size'], D: Float[jaxlib._jax.Array, 'intermediate_size'], dt: Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], gate: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'] | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size ssm_state_size'] | None = None, conv_state: jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size d_conv'] | None = None, act_fn: collections.abc.Callable[[jax.jaxlib._jax.Array], jax.jaxlib._jax.Array] | None = None) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size ssm_state_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size d_conv'] | None]
SSM1 (Mamba1-style) selective state space using XLA backend.
Implements the selective scan of the original Mamba architecture, with time complexity O(N) in the sequence length N. Tokens are processed sequentially while a hidden state accumulates information through discretized state transitions.
- The core algorithm:
- Discretization:
dA = exp(A * dt)
dB = dt * B
- Recurrence:
h_t = dA * h_{t-1} + dB * x_t
y_t = sum(h_t * C_t) + D * x_t
- Output gating (if gate provided):
y = y * act_fn(gate)
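The readout and gating steps can be illustrated for a single time step. The sketch below assumes that `sum(h_t * C_t)` reduces over the ssm_state_size axis (the same contraction written as `h_t @ C_t` in the module description) and that the gate falls back to `jax.nn.silu` when `act_fn` is None, as stated for the `act_fn` parameter. The helper name `readout` is hypothetical and not part of the library.

```python
import jax
import jax.numpy as jnp


def readout(h_t, C_t, D, x_t, gate_t=None, act_fn=None):
    """One-step readout (illustrative only).

    Shapes: h_t [b, d, n]; C_t [b, n]; D [d]; x_t, gate_t [b, d].
    """
    # y_t = sum(h_t * C_t) + D * x_t, reducing over the state axis n.
    y_t = jnp.sum(h_t * C_t[:, None, :], axis=-1) + D * x_t
    if gate_t is not None:
        # Assumed default, mirroring the documented act_fn behaviour.
        act = jax.nn.silu if act_fn is None else act_fn
        y_t = y_t * act(gate_t)  # y = y * act_fn(gate)
    return y_t
```

Because SiLU maps 0 to 0, a zero gate suppresses the output entirely, while large positive gate values pass `y_t` through almost unchanged.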
- Parameters
hidden_states – Input tensor after convolution and activation. Shape: [batch, seq_len, intermediate_size]
A – A matrix in real form (typically negative for stability). Shape: [intermediate_size, ssm_state_size]
B – B parameter from input projection. Shape: [batch, seq_len, ssm_state_size]
C – C parameter from input projection. Shape: [batch, seq_len, ssm_state_size]
D – Skip connection parameter. Shape: [intermediate_size]
dt – Time step after softplus activation. Shape: [batch, seq_len, intermediate_size]
gate – Optional gating tensor for output modulation. Shape: [batch, seq_len, intermediate_size]
initial_state – Optional initial SSM state for continuation. Shape: [batch, intermediate_size, ssm_state_size]
conv_state – Optional convolution state for caching (passed through). Shape: [batch, intermediate_size, d_conv]
act_fn – Optional activation function for gating (e.g., jax.nn.silu). If gate is provided but act_fn is None, defaults to jax.nn.silu.
- Returns
output: SSM output [batch, seq_len, intermediate_size]
ssm_state: Final hidden state [batch, intermediate_size, ssm_state_size]
conv_state: Passed through conv_state (for caching)
- Return type
Tuple of (output, ssm_state, conv_state), where conv_state may be None.
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import random
>>> # Import path taken from this module's name
>>> from ejkernel.kernels._xla.state_space_v1._interface import state_space_v1
>>>
>>> # Basic usage
>>> batch, seq_len, d, n = 2, 64, 512, 16
>>> hidden_states = random.normal(random.PRNGKey(0), (batch, seq_len, d))
>>> A = -random.uniform(random.PRNGKey(1), (d, n))  # negative for stability
>>> B = random.normal(random.PRNGKey(2), (batch, seq_len, n))
>>> C = random.normal(random.PRNGKey(3), (batch, seq_len, n))
>>> D = random.normal(random.PRNGKey(4), (d,))
>>> dt = jax.nn.softplus(random.normal(random.PRNGKey(5), (batch, seq_len, d)))
>>>
>>> output, ssm_state, conv_state = state_space_v1(hidden_states, A, B, C, D, dt)
>>> output.shape
(2, 64, 512)
>>> ssm_state.shape
(2, 512, 16)
>>>
>>> # With gating
>>> gate = random.normal(random.PRNGKey(6), (batch, seq_len, d))
>>> output, ssm_state, _ = state_space_v1(
...     hidden_states, A, B, C, D, dt,
...     gate=gate, act_fn=jax.nn.silu,
... )
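The `initial_state` parameter is described as an initial SSM state "for continuation". Assuming it plays the role of h_0 in the recurrence, processing a sequence in chunks and feeding each chunk's final state into the next should reproduce the full-sequence result. The sketch below demonstrates that property with a stand-in recurrence, `scan_ssm`, which is a hypothetical helper written for this example and not the library kernel.

```python
import jax
import jax.numpy as jnp


def scan_ssm(x, A, B, C, D, dt, h0):
    """Stand-in SSM1 recurrence (not the library kernel); returns (y, h_final)."""
    dA = jnp.exp(dt[..., None] * A)                  # [b, l, d, n]
    dBx = (dt * x)[..., None] * B[:, :, None, :]     # dB * x -> [b, l, d, n]

    def step(h, inp):
        dA_t, dBx_t, C_t = inp
        h = dA_t * h + dBx_t                          # h_t = dA * h_{t-1} + dB * x_t
        return h, jnp.sum(h * C_t[:, None, :], axis=-1)

    xs = (dA.swapaxes(0, 1), dBx.swapaxes(0, 1), C.swapaxes(0, 1))
    h_final, ys = jax.lax.scan(step, h0, xs)
    return ys.swapaxes(0, 1) + D * x, h_final


keys = jax.random.split(jax.random.PRNGKey(0), 6)
b, l, d, n = 2, 8, 4, 3
x = jax.random.normal(keys[0], (b, l, d))
A = -jax.random.uniform(keys[1], (d, n))  # negative for stability
B = jax.random.normal(keys[2], (b, l, n))
C = jax.random.normal(keys[3], (b, l, n))
D = jax.random.normal(keys[4], (d,))
dt = jax.nn.softplus(jax.random.normal(keys[5], (b, l, d)))
h0 = jnp.zeros((b, d, n))

# Full sequence in one call.
y_full, h_full = scan_ssm(x, A, B, C, D, dt, h0)

# Same sequence in two chunks, carrying the state across the boundary.
half = l // 2
y_a, h_a = scan_ssm(x[:, :half], A, B[:, :half], C[:, :half], D, dt[:, :half], h0)
y_b, h_b = scan_ssm(x[:, half:], A, B[:, half:], C[:, half:], D, dt[:, half:], h_a)
y_chunked = jnp.concatenate([y_a, y_b], axis=1)
```

With `state_space_v1` itself, the analogous pattern would be to pass the returned `ssm_state` from one call as `initial_state` to the next; whether the outputs match to the same tolerance is not verified here.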