ejkernel.kernels._xla.state_space_v1._interface

XLA implementation of the SSM1 (Mamba1-style) selective state space model.

This module provides a pure JAX/XLA implementation of the SSM1 algorithm used in Mamba and FalconMamba models.

Key characteristics of SSM1:

  • 2D A matrix: [intermediate_size, ssm_state_size]

  • SSM state shape: [batch, intermediate_size, ssm_state_size]

  • Separate dt_proj projection for the time step

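The algorithm, where x_t is hidden_states at step t, i indexes intermediate_size, n indexes ssm_state_size, and the batch index is omitted:

Discretization:

$$dA_{t,i,n} = \exp(A_{i,n} \, dt_{t,i}), \qquad dB_{t,i,n} = dt_{t,i} \, B_{t,n}$$

Recurrence:

$$h_{t,i,n} = dA_{t,i,n} \, h_{t-1,i,n} + dB_{t,i,n} \, x_{t,i}, \qquad y_{t,i} = \sum_{n} h_{t,i,n} \, C_{t,n} + D_i \, x_{t,i}$$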
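The discretization and recurrence above map naturally onto jax.lax.scan over the time axis. The following is a minimal reference sketch written for this page, not the ejkernel source: the function name ssm1_reference and its optional h0 argument (playing the role of initial_state) are illustrative assumptions, and output gating is left out.

```python
import jax
import jax.numpy as jnp


def ssm1_reference(x, A, B, C, D, dt, h0=None):
    """Sequential SSM1 scan (illustrative sketch, not the ejkernel implementation).

    x, dt: [batch, seq_len, d]; A: [d, n]; B, C: [batch, seq_len, n]; D: [d].
    Returns y: [batch, seq_len, d] and the final state h: [batch, d, n].
    """
    batch, _, d = x.shape
    n = A.shape[1]
    if h0 is None:
        h0 = jnp.zeros((batch, d, n), dtype=x.dtype)

    def step(h, inputs):
        x_t, dt_t, B_t, C_t = inputs  # [batch, d], [batch, d], [batch, n], [batch, n]
        dA = jnp.exp(A[None] * dt_t[..., None])            # dA = exp(A * dt): [batch, d, n]
        dB = dt_t[..., None] * B_t[:, None, :]             # dB = dt * B:      [batch, d, n]
        h = dA * h + dB * x_t[..., None]                   # h_t = dA * h_{t-1} + dB * x_t
        y_t = jnp.einsum("bdn,bn->bd", h, C_t) + D * x_t   # sum over n, plus skip term
        return h, y_t

    # lax.scan iterates over the leading axis, so move seq_len to the front.
    xs = tuple(jnp.swapaxes(a, 0, 1) for a in (x, dt, B, C))
    h_final, ys = jax.lax.scan(step, h0, xs)
    return jnp.swapaxes(ys, 0, 1), h_final
```

Each scan step does O(batch · intermediate_size · ssm_state_size) work, and there are seq_len steps, so the total cost is linear in sequence length.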

ejkernel.kernels._xla.state_space_v1._interface.state_space_v1(
    hidden_states: Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'],
    A: Float[jaxlib._jax.Array, 'intermediate_size ssm_state_size'],
    B: Float[jaxlib._jax.Array, 'batch seq_len ssm_state_size'],
    C: Float[jaxlib._jax.Array, 'batch seq_len ssm_state_size'],
    D: Float[jaxlib._jax.Array, 'intermediate_size'],
    dt: Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'],
    gate: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'] | None = None,
    initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size ssm_state_size'] | None = None,
    conv_state: jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size d_conv'] | None = None,
    act_fn: collections.abc.Callable[[jax.jaxlib._jax.Array], jax.jaxlib._jax.Array] | None = None,
) -> tuple[
    jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'],
    jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size ssm_state_size'],
    jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size d_conv'] | None,
]

SSM1 (Mamba1-style) selective state space model using the XLA backend.

Implements the selective state space recurrence of the original Mamba architecture in O(seq_len) time. Tokens are processed sequentially while a hidden state accumulates information through the discretized state transitions.

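The core algorithm (x_t is hidden_states at step t; i indexes intermediate_size and n indexes ssm_state_size; the batch index is omitted):

Discretization:

$$dA_{t,i,n} = \exp(A_{i,n} \, dt_{t,i}), \qquad dB_{t,i,n} = dt_{t,i} \, B_{t,n}$$

Recurrence:

$$h_{t,i,n} = dA_{t,i,n} \, h_{t-1,i,n} + dB_{t,i,n} \, x_{t,i}, \qquad y_{t,i} = \sum_{n} h_{t,i,n} \, C_{t,n} + D_i \, x_{t,i}$$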

Output gating (if gate is provided):

y = y * act_fn(gate)
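A hand-checkable scalar case (intermediate_size = ssm_state_size = 1) makes the recurrence and the gate concrete. The values below (A = -ln 2, dt = 1, B = C = 1, D = 0, a constant input of 1, gate = 2) are arbitrary illustrative choices picked so that dA = 0.5 and dB = 1; they are not kernel defaults, and a hand-written SiLU stands in for act_fn.

```python
import math

# Arbitrary illustrative values (not kernel defaults): dA = exp(A * dt) = 0.5, dB = dt * B = 1.
A, dt, B, C, D = -math.log(2.0), 1.0, 1.0, 1.0, 0.0
dA, dB = math.exp(A * dt), dt * B


def silu(v):
    """SiLU activation, v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


h = 0.0
ys = []
for x_t in [1.0, 1.0, 1.0, 1.0]:
    h = dA * h + dB * x_t      # h_t = dA * h_{t-1} + dB * x_t
    y_t = h * C + D * x_t      # y_t = sum_n h_t * C_t + D * x_t (a single n here)
    ys.append(y_t)

gate = 2.0
gated = [y * silu(gate) for y in ys]   # y = y * act_fn(gate)
```

Here ys comes out as approximately [1.0, 1.5, 1.75, 1.875]: because 0 < dA < 1, the state approaches the fixed point dB * x / (1 - dA) = 2 geometrically. This is the stability a negative A provides, since exp(A * dt) < 1 whenever A < 0 and dt > 0. Each gated output is the corresponding y_t scaled by silu(2) ≈ 1.76.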

Parameters
  • hidden_states – Input tensor after convolution and activation. Shape: [batch, seq_len, intermediate_size]

  • A – State matrix in real form, typically negative so that exp(A * dt) < 1 for positive dt, which keeps the recurrence stable. Shape: [intermediate_size, ssm_state_size]

  • B – Input-dependent B parameter from the input projection. Shape: [batch, seq_len, ssm_state_size]

  • C – Input-dependent C parameter from the input projection. Shape: [batch, seq_len, ssm_state_size]

  • D – Skip-connection parameter. Shape: [intermediate_size]

  • dt – Time step, already passed through softplus. Shape: [batch, seq_len, intermediate_size]

  • gate – Optional gating tensor for output modulation. Shape: [batch, seq_len, intermediate_size]

  • initial_state – Optional initial SSM state for continuing from a previous call. Shape: [batch, intermediate_size, ssm_state_size]

  • conv_state – Optional convolution state for caching; it is passed through and returned unchanged. Shape: [batch, intermediate_size, d_conv]

  • act_fn – Optional activation function for gating (e.g., jax.nn.silu). If gate is provided but act_fn is None, it defaults to jax.nn.silu.
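The parameters above expect A already in real, negative form and dt already passed through softplus. The sketch below shows one way such inputs might be produced. It assumes, as in common Mamba implementations, that A is stored as a log-magnitude (A = -exp(A_log)) and that dt comes from a low-rank input passed through the separate dt_proj linear layer followed by softplus. The names A_log, W_dt, b_dt and dt_rank, the sizes, and the A initialization (each row equal to 1, 2, ..., ssm_state_size) are hypothetical illustrative choices.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes and parameter names, for illustration only.
batch, seq_len, d, n, dt_rank = 2, 8, 16, 4, 2
k1, k2 = jax.random.split(jax.random.PRNGKey(0), 2)

# A is stored as a log-magnitude, so A = -exp(A_log) is always negative.
A_log = jnp.log(jnp.tile(jnp.arange(1, n + 1, dtype=jnp.float32), (d, 1)))  # [d, n]
A = -jnp.exp(A_log)                                                          # [d, n]

# dt: low-rank input -> hypothetical dt_proj weights (W_dt, b_dt) -> softplus (> 0).
dt_low_rank = jax.random.normal(k1, (batch, seq_len, dt_rank))
W_dt = jax.random.normal(k2, (dt_rank, d)) * 0.1
b_dt = jnp.zeros((d,))
dt = jax.nn.softplus(dt_low_rank @ W_dt + b_dt)                              # [batch, seq_len, d]

D = jnp.ones((d,))  # per-channel skip-connection weight
```

Storing log |A| rather than A itself is a design choice that keeps A negative under any unconstrained update to A_log.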

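Returns

A tuple (output, ssm_state, conv_state):

  • output: SSM output. Shape: [batch, seq_len, intermediate_size]

  • ssm_state: Final hidden state, usable as initial_state for a subsequent call. Shape: [batch, intermediate_size, ssm_state_size]

  • conv_state: The conv_state argument, passed through unchanged for caching (None if it was not provided)

Return type

tuple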
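The returned ssm_state is what makes chunked or token-by-token processing possible. The sketch below checks the underlying property on a self-contained reference scan (ssm1_scan is a hypothetical helper written for this page, not ejkernel code): processing a sequence in two chunks, feeding the first chunk's final state into the second, should reproduce a single pass over the whole sequence up to floating-point tolerance. With the kernel itself, the analogous step is passing the first call's ssm_state as initial_state to the next call.

```python
import jax
import jax.numpy as jnp


def ssm1_scan(x, A, B, C, D, dt, h0):
    """Hypothetical reference SSM1 recurrence (no gating); returns (y, final_state)."""

    def step(h, inp):
        x_t, dt_t, B_t, C_t = inp
        dA = jnp.exp(A * dt_t[..., None])         # [batch, d, n]
        dB = dt_t[..., None] * B_t[:, None, :]    # [batch, d, n]
        h = dA * h + dB * x_t[..., None]
        return h, jnp.einsum("bdn,bn->bd", h, C_t) + D * x_t

    xs = tuple(jnp.swapaxes(a, 0, 1) for a in (x, dt, B, C))  # time-major for scan
    h, ys = jax.lax.scan(step, h0, xs)
    return jnp.swapaxes(ys, 0, 1), h


batch, seq_len, d, n = 1, 12, 8, 4
keys = jax.random.split(jax.random.PRNGKey(0), 6)
x = jax.random.normal(keys[0], (batch, seq_len, d))
A = -jax.random.uniform(keys[1], (d, n))
B = jax.random.normal(keys[2], (batch, seq_len, n))
C = jax.random.normal(keys[3], (batch, seq_len, n))
D = jax.random.normal(keys[4], (d,))
dt = jax.nn.softplus(jax.random.normal(keys[5], (batch, seq_len, d)))
h0 = jnp.zeros((batch, d, n))

# A single pass over the full sequence ...
y_full, h_full = ssm1_scan(x, A, B, C, D, dt, h0)

# ... versus two chunks, feeding the first chunk's final state into the second.
m = seq_len // 2
y1, h1 = ssm1_scan(x[:, :m], A, B[:, :m], C[:, :m], D, dt[:, :m], h0)
y2, h2 = ssm1_scan(x[:, m:], A, B[:, m:], C[:, m:], D, dt[:, m:], h1)

same_outputs = bool(jnp.allclose(jnp.concatenate([y1, y2], axis=1), y_full, atol=1e-5))
same_state = bool(jnp.allclose(h2, h_full, atol=1e-5))
```

The same reasoning covers single-token decoding: a chunk of length 1 plus the carried state is one step of the recurrence.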

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> from jax import random
>>> from ejkernel.kernels._xla.state_space_v1._interface import state_space_v1
>>>
>>> # Basic usage
>>> batch, seq_len, d, n = 2, 64, 512, 16
>>> hidden_states = random.normal(random.PRNGKey(0), (batch, seq_len, d))
>>> A = -random.uniform(random.PRNGKey(1), (d, n))  # negative for stability
>>> B = random.normal(random.PRNGKey(2), (batch, seq_len, n))
>>> C = random.normal(random.PRNGKey(3), (batch, seq_len, n))
>>> D = random.normal(random.PRNGKey(4), (d,))
>>> dt = jax.nn.softplus(random.normal(random.PRNGKey(5), (batch, seq_len, d)))
>>>
>>> output, ssm_state, conv_state = state_space_v1(hidden_states, A, B, C, D, dt)
>>> output.shape
(2, 64, 512)
>>> ssm_state.shape
(2, 512, 16)
>>>
>>> # With gating
>>> gate = random.normal(random.PRNGKey(6), (batch, seq_len, d))
>>> output, ssm_state, _ = state_space_v1(
...     hidden_states, A, B, C, D, dt,
...     gate=gate, act_fn=jax.nn.silu,
... )