ejkernel.modules.operations.state_space_v1

SSM1 (Mamba1-style) Selective State Space operation module.

This module provides the StateSpaceV1 operation, which implements the selective state space model from the original Mamba architecture, as used by the Mamba and FalconMamba models.

Key characteristics of SSM1:
  • 2D A matrix: [intermediate_size, ssm_state_size]

  • SSM state shape: [batch, intermediate_size, ssm_state_size]

  • Separate dt_proj projection for the time step

  • Output gating: y * activation(gate)

The algorithm:
Discretization:

$$dA = \exp(A \cdot dt), \qquad dB = dt \cdot B$$

Recurrence:

$$h_t = dA \cdot h_{t-1} + dB \cdot x_t, \qquad y_t = h_t C_t + D \cdot x_t$$
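The discretization and recurrence above can be sketched in pure JAX with `jax.lax.scan`. This is a minimal reference under stated assumptions, not ejkernel's kernel: `selective_scan_ref` is a hypothetical helper, `dt` is assumed to broadcast over the state axis of `A`, `x_t` is assumed to broadcast over the state axis of `dB`, and output gating is omitted.

```python
import jax
import jax.numpy as jnp


def selective_scan_ref(x, A, B, C, D, dt, initial_state=None):
    """Sequential Mamba1-style selective scan (hypothetical reference sketch).

    Shapes follow the documented layout:
      x, dt: [batch, seq_len, intermediate_size]
      A:     [intermediate_size, ssm_state_size]
      B, C:  [batch, seq_len, ssm_state_size]
      D:     [intermediate_size]
    Returns (y [batch, seq_len, intermediate_size],
             h_final [batch, intermediate_size, ssm_state_size]).
    """
    batch, _, d = x.shape
    n = A.shape[1]
    h0 = jnp.zeros((batch, d, n), x.dtype) if initial_state is None else initial_state

    def step(h, inputs):
        x_t, dt_t, B_t, C_t = inputs                 # [b, d], [b, d], [b, n], [b, n]
        dA = jnp.exp(A[None] * dt_t[..., None])      # dA = exp(A * dt) -> [b, d, n]
        dB = dt_t[..., None] * B_t[:, None, :]       # dB = dt * B      -> [b, d, n]
        h = dA * h + dB * x_t[..., None]             # h_t = dA * h_{t-1} + dB * x_t
        y_t = jnp.einsum("bdn,bn->bd", h, C_t) + D * x_t  # y_t = h_t @ C_t + D * x_t
        return h, y_t

    # lax.scan iterates over the leading axis, so move seq_len to the front.
    xs = tuple(jnp.swapaxes(a, 0, 1) for a in (x, dt, B, C))
    h_final, ys = jax.lax.scan(step, h0, xs)
    return jnp.swapaxes(ys, 0, 1), h_final
```

Only one [batch, intermediate_size, ssm_state_size] state is carried between steps, which is where the linear cost in sequence length comes from.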

Features:
  • O(N) complexity in sequence length through sequential processing

  • Stateful computation with hidden state propagation

  • Optional gating with configurable activation

  • Conv state passthrough for caching


class ejkernel.modules.operations.state_space_v1.StateSpaceV1

Bases: Kernel[StateSpaceV1Config, Array]

SSM1 (Mamba1-style) Selective State Space operation.

Implements the original Mamba architecture with O(N) complexity in the sequence length N. Tokens are processed sequentially while a hidden state accumulates information through discretized state transitions.

Features:
  • 2D A matrix [intermediate_size, ssm_state_size]

  • Hidden state propagation across timesteps

  • Optional gating with activation function

  • Conv state passthrough for caching

  • Multiple platform support (XLA is the primary implementation)

The state update mechanism:

$$dA = \exp(A \cdot dt), \qquad dB = dt \cdot B$$

$$h_t = dA \cdot h_{t-1} + dB \cdot x_t, \qquad y_t = h_t C_t + D \cdot x_t$$
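Because each update depends only on h_{t-1} and the current token, running one chunk of a sequence and feeding its final state in as the starting state of the next chunk should reproduce a single full pass; this is the property that `initial_state` relies on. The sketch below checks it with a per-token step. `ssm1_step` and `run_tokens` are hypothetical helpers written for illustration, not ejkernel functions.

```python
import jax
import jax.numpy as jnp


def ssm1_step(h, x_t, dt_t, B_t, C_t, A, D):
    """One SSM1 state update for a single token (hypothetical helper).

    h: [batch, d, n]; x_t, dt_t: [batch, d]; B_t, C_t: [batch, n];
    A: [d, n]; D: [d].
    """
    dA = jnp.exp(A * dt_t[..., None])                  # dA = exp(A * dt)
    dB = dt_t[..., None] * B_t[:, None, :]             # dB = dt * B
    h = dA * h + dB * x_t[..., None]                   # h_t = dA * h_{t-1} + dB * x_t
    y_t = jnp.einsum("bdn,bn->bd", h, C_t) + D * x_t   # y_t = h_t @ C_t + D * x_t
    return h, y_t


def run_tokens(h, x, dt, B, C, A, D):
    """Apply ssm1_step token by token; returns (outputs, final_state)."""
    ys = []
    for t in range(x.shape[1]):
        h, y_t = ssm1_step(h, x[:, t], dt[:, t], B[:, t], C[:, t], A, D)
        ys.append(y_t)
    return jnp.stack(ys, axis=1), h
```

In a decoding loop the same pattern applies with chunks of length one: the returned state from each call becomes the starting state of the next.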

candidate_cfgs(inv: Invocation[StateSpaceV1Config, Array])

Generate candidate configurations for autotuning.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of candidate configurations to benchmark during autotuning

Note

SSM1 primarily uses the XLA implementation, so the candidate list is minimal.
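As a hedged illustration of what autotuning over candidates generally involves (ejkernel's own tuner is not shown on this page), one can compile each candidate, exclude a warm-up call so compilation time is not measured, and keep the fastest. `pick_fastest` and the `build` callback are hypothetical names.

```python
import time
from typing import Any, Callable, Sequence

import jax


def pick_fastest(candidates: Sequence[Any], build: Callable[[Any], Callable],
                 args: tuple, reps: int = 3) -> Any:
    """Return the candidate whose jitted function runs fastest on `args`."""
    best, best_time = None, float("inf")
    for cfg in candidates:
        fn = jax.jit(build(cfg))
        jax.block_until_ready(fn(*args))      # compile and warm up; not timed
        start = time.perf_counter()
        for _ in range(reps):
            jax.block_until_ready(fn(*args))  # wait for async dispatch to finish
        elapsed = (time.perf_counter() - start) / reps
        if elapsed < best_time:
            best, best_time = cfg, elapsed
    return best
```

With only one or two candidates, as the note above suggests for SSM1, a loop like this costs little beyond compilation.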

get_impl(cfg: StateSpaceV1Config)

Get the kernel implementation from the registry.

Parameters

cfg – Configuration specifying platform and backend

Returns

Callable kernel implementation for SSM1

Raises

ValueError – If no matching implementation is found

heuristic_cfg(inv: Invocation[StateSpaceV1Config, Array]) → StateSpaceV1Config

Provide the default configuration.

Parameters

inv – Invocation object containing arguments and metadata

Returns

Default configuration for SSM1 operation

run(hidden_states: Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], A: Float[jaxlib._jax.Array, 'intermediate_size ssm_state_size'], B: Float[jaxlib._jax.Array, 'batch seq_len ssm_state_size'], C: Float[jaxlib._jax.Array, 'batch seq_len ssm_state_size'], D: Float[jaxlib._jax.Array, 'intermediate_size'], dt: Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], gate: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'] | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size ssm_state_size'] | None = None, conv_state: jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size d_conv'] | None = None, act_fn: collections.abc.Callable[[jax.jaxlib._jax.Array], jax.jaxlib._jax.Array] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: StateSpaceV1Config) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size ssm_state_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size d_conv'] | None]

Execute the SSM1 selective state space operation.

Parameters
  • hidden_states – Input tensor after convolution and activation. Shape: [batch, seq_len, intermediate_size]

  • A – A matrix in real form (typically negative for stability). Shape: [intermediate_size, ssm_state_size]

  • B – B parameter from the input projection. Shape: [batch, seq_len, ssm_state_size]

  • C – C parameter from the input projection. Shape: [batch, seq_len, ssm_state_size]

  • D – Skip connection parameter. Shape: [intermediate_size]

  • dt – Time step after softplus activation. Shape: [batch, seq_len, intermediate_size]

  • gate – Optional gating tensor for output modulation. Shape: [batch, seq_len, intermediate_size]

  • initial_state – Optional initial SSM state for continuing from a previous call. Shape: [batch, intermediate_size, ssm_state_size]

  • conv_state – Optional convolution state for caching (passed through unchanged). Shape: [batch, intermediate_size, d_conv]

  • act_fn – Optional activation function for gating (e.g., jax.nn.silu)

  • platform – Optional platform override

  • cfg – Kernel configuration object
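The parameter notes say A is typically negative and dt comes out of a softplus. A minimal sketch of how such inputs are commonly produced, assuming the Mamba-style parameterization A = -exp(A_log) (`A_log` and `prepare_ssm1_inputs` are hypothetical names here), and of why that keeps the recurrence stable:

```python
import jax
import jax.numpy as jnp


def prepare_ssm1_inputs(A_log, dt_raw):
    """Build A and dt under an assumed Mamba-style parameterization.

    A_log:  [intermediate_size, ssm_state_size] learned log-magnitudes
    dt_raw: [batch, seq_len, intermediate_size] output of the dt projection
    """
    A = -jnp.exp(A_log)           # strictly negative real A
    dt = jax.nn.softplus(dt_raw)  # strictly positive time step
    return A, dt


key_a, key_dt = jax.random.split(jax.random.PRNGKey(0))
A_log = jax.random.normal(key_a, (8, 4))
dt_raw = jax.random.normal(key_dt, (2, 16, 8))
A, dt = prepare_ssm1_inputs(A_log, dt_raw)

# With A < 0 and dt > 0, every decay factor dA = exp(A * dt) lies in (0, 1),
# so the contribution of h_{t-1} shrinks at each step instead of growing.
dA = jnp.exp(A[None, None] * dt[..., None])  # [batch, seq_len, intermediate_size, ssm_state_size]
```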

Returns

  • output: SSM output [batch, seq_len, intermediate_size]

  • ssm_state: Final hidden state [batch, intermediate_size, ssm_state_size]

  • conv_state: The input conv_state, passed through unchanged (for caching)

Return type

Tuple of (output, ssm_state, conv_state)

ejkernel.modules.operations.state_space_v1.state_space_v1(hidden_states: Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], A: Float[jaxlib._jax.Array, 'intermediate_size ssm_state_size'], B: Float[jaxlib._jax.Array, 'batch seq_len ssm_state_size'], C: Float[jaxlib._jax.Array, 'batch seq_len ssm_state_size'], D: Float[jaxlib._jax.Array, 'intermediate_size'], dt: Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], /, gate: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'] | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size ssm_state_size'] | None = None, conv_state: jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size d_conv'] | None = None, *, act_fn: collections.abc.Callable[[jax.jaxlib._jax.Array], jax.jaxlib._jax.Array] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.StateSpaceV1Config | None = None) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size ssm_state_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size d_conv'] | None]

Execute the SSM1 (Mamba1-style) selective state space operation with automatic platform and configuration selection.

SSM1 processes sequences statefully, carrying a hidden state across timesteps so that selective state space modeling runs with O(N) complexity in the sequence length N.

Parameters
  • hidden_states – Input tensor after convolution and activation. Shape: [batch, seq_len, intermediate_size]

  • A – A matrix in real form (typically negative for stability). Shape: [intermediate_size, ssm_state_size]

  • B – B parameter from the input projection. Shape: [batch, seq_len, ssm_state_size]

  • C – C parameter from the input projection. Shape: [batch, seq_len, ssm_state_size]

  • D – Skip connection parameter. Shape: [intermediate_size]

  • dt – Time step after softplus activation. Shape: [batch, seq_len, intermediate_size]

  • gate – Optional gating tensor for output modulation. Shape: [batch, seq_len, intermediate_size]

  • initial_state – Optional initial SSM state for continuing from a previous call. Shape: [batch, intermediate_size, ssm_state_size]

  • conv_state – Optional convolution state for caching (passed through unchanged). Shape: [batch, intermediate_size, d_conv]

  • act_fn – Optional activation function for gating (e.g., jax.nn.silu). If gate is provided but act_fn is None, defaults to jax.nn.silu.

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Optional kernel configuration
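The gating rule described for act_fn can be expressed as a small helper. `apply_gate` is hypothetical and only mirrors the documented behavior: no gating when gate is None, and jax.nn.silu when a gate is given without an activation.

```python
import jax
import jax.numpy as jnp


def apply_gate(y, gate=None, act_fn=None):
    """Output gating y * act_fn(gate), defaulting act_fn to SiLU when a gate is given."""
    if gate is None:
        return y                  # no gating requested
    if act_fn is None:
        act_fn = jax.nn.silu      # documented default when gate is provided
    return y * act_fn(gate)


y = jnp.ones((1, 2, 3))
gate = jnp.full((1, 2, 3), 2.0)
out_default = apply_gate(y, gate)                      # uses SiLU implicitly
out_explicit = apply_gate(y, gate, act_fn=jax.nn.silu)
```

Passing act_fn explicitly and relying on the default should give the same result, as the two calls above illustrate.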

Returns

  • output: SSM output [batch, seq_len, intermediate_size]

  • ssm_state: Final hidden state [batch, intermediate_size, ssm_state_size]

  • conv_state: The input conv_state, passed through unchanged (for caching)

Return type

Tuple of (output, ssm_state, conv_state)

Example

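>>> import jax
>>> import jax.numpy as jnp
>>> from ejkernel.modules.operations.state_space_v1 import state_space_v1
>>>
>>> batch, seq_len, d, n, d_conv = 2, 16, 8, 4, 4  # arbitrary illustrative sizes
>>> k = jax.random.split(jax.random.PRNGKey(0), 6)
>>> hidden_states = jax.random.normal(k[0], (batch, seq_len, d))
>>> A = -jnp.exp(jax.random.normal(k[1], (d, n)))  # negative for stability
>>> B = jax.random.normal(k[2], (batch, seq_len, n))
>>> C = jax.random.normal(k[3], (batch, seq_len, n))
>>> D = jnp.ones((d,))
>>> dt = jax.nn.softplus(jax.random.normal(k[4], (batch, seq_len, d)))
>>> gate = jax.random.normal(k[5], (batch, seq_len, d))
>>>
>>> # Basic usage
>>> output, ssm_state, _ = state_space_v1(hidden_states, A, B, C, D, dt)
>>>
>>> # With gating
>>> output, ssm_state, _ = state_space_v1(
...     hidden_states, A, B, C, D, dt,
...     gate=gate, act_fn=jax.nn.silu,
... )
>>>
>>> # Inference with cached state (conv_state is a caller-maintained conv cache)
>>> conv_state = jnp.zeros((batch, d, d_conv))
>>> output, new_state, conv_state = state_space_v1(
...     hidden_states[:, :1, :],
...     A, B[:, :1, :], C[:, :1, :], D, dt[:, :1, :],
...     initial_state=ssm_state, conv_state=conv_state,
... )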