ejkernel.modules.operations.state_space_v1
SSM1 (Mamba1-style) Selective State Space operation module.
This module provides the StateSpaceV1 operation, implementing the original Mamba selective state space model architecture used by Mamba and FalconMamba.
Key characteristics of SSM1:
- 2D A matrix: [intermediate_size, ssm_state_size]
- SSM state shape: [batch, intermediate_size, ssm_state_size]
- Separate dt_proj projection for the time step
- Output gating: y * activation(gate)
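A minimal sketch of how inputs matching the shapes above might be prepared. It assumes Mamba1-style conventions that this page does not spell out: A is stored as a learned A_log and used as A = -exp(A_log), and dt is a low-rank feature projected up by a separate dt_proj weight and passed through softplus. The helper name prepare_ssm1_inputs and the names dt_rank, dt_proj_w and dt_bias are hypothetical, and random tensors stand in for real projections.

```python
import jax
import jax.numpy as jnp


def prepare_ssm1_inputs(key, batch=2, seq_len=5, intermediate_size=8,
                        ssm_state_size=4, dt_rank=2):
    """Hypothetical helper returning (hidden_states, A, B, C, D, dt) with SSM1 shapes.

    Assumes Mamba1-style conventions: A = -exp(A_log) and
    dt = softplus(dt_low_rank @ dt_proj_w + dt_bias).
    """
    k = jax.random.split(key, 6)
    # Input after the causal conv + activation: [batch, seq_len, intermediate_size]
    hidden_states = jax.random.normal(k[0], (batch, seq_len, intermediate_size))
    # exp(...) is positive, so the leading minus makes every entry of A negative
    A_log = jax.random.normal(k[1], (intermediate_size, ssm_state_size))
    A = -jnp.exp(A_log)  # [intermediate_size, ssm_state_size]
    # Low-rank time-step features projected up by a separate dt_proj weight (assumed)
    dt_low_rank = jax.random.normal(k[2], (batch, seq_len, dt_rank))
    dt_proj_w = 0.1 * jax.random.normal(k[3], (dt_rank, intermediate_size))
    dt_bias = jnp.zeros((intermediate_size,))
    # softplus keeps the time step positive: [batch, seq_len, intermediate_size]
    dt = jax.nn.softplus(dt_low_rank @ dt_proj_w + dt_bias)
    # Input-dependent B and C (assumed to come from an input projection)
    B = jax.random.normal(k[4], (batch, seq_len, ssm_state_size))
    C = jax.random.normal(k[5], (batch, seq_len, ssm_state_size))
    D = jnp.ones((intermediate_size,))  # per-channel skip weight
    return hidden_states, A, B, C, D, dt
```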
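- The algorithm (x_t denotes hidden_states at timestep t):
  - Discretization:
    dA = exp(A * dt)
    dB = dt * B
    (dt is broadcast over ssm_state_size and B over intermediate_size, so dA and dB have shape [batch, seq_len, intermediate_size, ssm_state_size])
  - Recurrence:
    h_t = dA * h_{t-1} + dB * x_t
    y_t = h_t @ C_t + D * x_t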
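The discretization and recurrence above can be written as a plain sequential scan. The sketch below is a reference for checking shapes and semantics under the broadcasting just described; it is not ejkernel's kernel, and the function name ssm1_reference is hypothetical.

```python
import jax
import jax.numpy as jnp


def ssm1_reference(hidden_states, A, B, C, D, dt, initial_state=None):
    """Illustrative SSM1 recurrence via jax.lax.scan (not the library kernel).

    Shapes: hidden_states/dt [b, l, d], A [d, n], B/C [b, l, n], D [d].
    Returns y [b, l, d] and the final state h [b, d, n].
    """
    batch, _, d = hidden_states.shape
    n = A.shape[1]
    if initial_state is None:
        h0 = jnp.zeros((batch, d, n), dtype=hidden_states.dtype)
    else:
        h0 = initial_state

    def step(h, inputs):
        x_t, dt_t, B_t, C_t = inputs              # [b, d], [b, d], [b, n], [b, n]
        dA = jnp.exp(dt_t[:, :, None] * A)        # discretized A: [b, d, n]
        dB = dt_t[:, :, None] * B_t[:, None, :]   # discretized B: [b, d, n]
        h = dA * h + dB * x_t[:, :, None]         # h_t = dA * h_{t-1} + dB * x_t
        y_t = jnp.einsum("bdn,bn->bd", h, C_t) + D * x_t  # y_t = h_t @ C_t + D * x_t
        return h, y_t

    # lax.scan iterates over the leading axis, so move seq_len to the front
    xs = tuple(jnp.swapaxes(a, 0, 1) for a in (hidden_states, dt, B, C))
    h_final, ys = jax.lax.scan(step, h0, xs)
    return jnp.swapaxes(ys, 0, 1), h_final
```

The scan body runs once per timestep, which is where the O(N) cost in sequence length comes from; the carried state stays at [batch, intermediate_size, ssm_state_size] throughout.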
- Features:
  - O(N) complexity in sequence length N, via sequential processing
  - Stateful computation with hidden-state propagation
  - Optional gating with a configurable activation
  - Conv state passthrough for caching
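Output gating, as documented for the gate and act_fn parameters of state_space_v1, multiplies the SSM output elementwise by act_fn(gate) and falls back to jax.nn.silu when a gate is given without an activation. A sketch of that rule follows; the helper name apply_output_gate is hypothetical, and returning y unchanged when no gate is given is an assumption.

```python
import jax
import jax.numpy as jnp


def apply_output_gate(y, gate=None, act_fn=None):
    """Hypothetical sketch of SSM1 output gating: y * act_fn(gate).

    y already includes the D * x skip term. Without a gate, y is returned
    unchanged (assumed); with a gate but no act_fn, jax.nn.silu is used.
    """
    if gate is None:
        return y
    if act_fn is None:
        act_fn = jax.nn.silu  # documented default when gate is provided
    return y * act_fn(gate)
```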
References
FalconMamba: https://huggingface.co/tiiuae/falcon-mamba-7b
- class ejkernel.modules.operations.state_space_v1.StateSpaceV1
Bases: Kernel[StateSpaceV1Config, Array]
SSM1 (Mamba1-style) Selective State Space operation.
Implements the selective scan of the original Mamba architecture with O(N) complexity in sequence length. Tokens are processed sequentially, and a hidden state accumulates information through discretized state transitions.
- Features:
  - 2D A matrix [intermediate_size, ssm_state_size]
  - Hidden-state propagation across timesteps
  - Optional gating with an activation function
  - Conv state passthrough for caching
  - Multiple platform support (XLA is the primary backend)
- The state update mechanism:
  dA = exp(A * dt)
  dB = dt * B
  h_t = dA * h_{t-1} + dB * x_t
  y_t = h_t @ C_t + D * x_t
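Why A is typically negative: dt is positive (it comes out of softplus), so with A < 0 every entry of dA = exp(A * dt) lies in (0, 1) and, without input, the state decays geometrically instead of growing. The values below are arbitrary illustrations chosen for this sketch, not defaults of the library.

```python
import jax
import jax.numpy as jnp

# Illustrative values: A < 0 (here A = -exp([0, 1]) = [-1, -e]) and dt > 0 from softplus.
A = -jnp.exp(jnp.array([[0.0, 1.0]]))   # [intermediate_size=1, ssm_state_size=2]
dt = jax.nn.softplus(jnp.array([0.3]))  # [intermediate_size=1], strictly positive
dA = jnp.exp(A * dt[:, None])           # decay factor per (channel, state), shape [1, 2]

# With zero input the recurrence reduces to h_t = dA * h_{t-1},
# so after T steps h_T = dA**T * h_0, which shrinks toward 0 when 0 < dA < 1.
h0 = jnp.ones_like(A)
T = 10
h = h0
for _ in range(T):
    h = dA * h
```

If A were positive, dA would exceed 1 and the same loop would grow without bound, which is why Mamba-style models usually keep A negative.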
- candidate_cfgs(inv: Invocation[StateSpaceV1Config, Array])
Generate candidate configurations for autotuning.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of candidate configurations to benchmark during autotuning
Note
SSM1 uses XLA implementation primarily, so candidates are minimal.
- get_impl(cfg: StateSpaceV1Config)
Get kernel implementation from registry.
- Parameters
cfg – Configuration specifying platform and backend
- Returns
Callable kernel implementation for SSM1
- Raises
ValueError – If no matching implementation is found
- heuristic_cfg(inv: Invocation[StateSpaceV1Config, Array]) → StateSpaceV1Config
Provide default configuration.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
Default configuration for SSM1 operation
- run(hidden_states: Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], A: Float[jaxlib._jax.Array, 'intermediate_size ssm_state_size'], B: Float[jaxlib._jax.Array, 'batch seq_len ssm_state_size'], C: Float[jaxlib._jax.Array, 'batch seq_len ssm_state_size'], D: Float[jaxlib._jax.Array, 'intermediate_size'], dt: Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], gate: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'] | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size ssm_state_size'] | None = None, conv_state: jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size d_conv'] | None = None, act_fn: collections.abc.Callable[[jax.jaxlib._jax.Array], jax.jaxlib._jax.Array] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: StateSpaceV1Config) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size ssm_state_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size d_conv'] | None]
Execute SSM1 selective state space operation.
- Parameters
hidden_states – Input tensor after convolution and activation. Shape: [batch, seq_len, intermediate_size]
A – A matrix in real form (typically negative for stability). Shape: [intermediate_size, ssm_state_size]
B – B parameter from input projection. Shape: [batch, seq_len, ssm_state_size]
C – C parameter from input projection. Shape: [batch, seq_len, ssm_state_size]
D – Skip connection parameter. Shape: [intermediate_size]
dt – Time step after softplus activation. Shape: [batch, seq_len, intermediate_size]
gate – Optional gating tensor for output modulation. Shape: [batch, seq_len, intermediate_size]
initial_state – Optional initial SSM state for continuation. Shape: [batch, intermediate_size, ssm_state_size]
conv_state – Optional convolution state for caching (passed through). Shape: [batch, intermediate_size, d_conv]
act_fn – Optional activation function for gating (e.g., jax.nn.silu)
platform – Optional platform override
cfg – Kernel configuration object
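- Returns
A tuple (output, ssm_state, conv_state), where:
output: SSM output [batch, seq_len, intermediate_size]
ssm_state: Final hidden state [batch, intermediate_size, ssm_state_size]
conv_state: The conv_state argument, returned unchanged for caching (None if it was not provided)
- Return type
tuple[Array, Array, Array | None]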
- ejkernel.modules.operations.state_space_v1.state_space_v1(hidden_states: Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], A: Float[jaxlib._jax.Array, 'intermediate_size ssm_state_size'], B: Float[jaxlib._jax.Array, 'batch seq_len ssm_state_size'], C: Float[jaxlib._jax.Array, 'batch seq_len ssm_state_size'], D: Float[jaxlib._jax.Array, 'intermediate_size'], dt: Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], /, gate: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'] | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size ssm_state_size'] | None = None, conv_state: jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size d_conv'] | None = None, *, act_fn: collections.abc.Callable[[jax.jaxlib._jax.Array], jax.jaxlib._jax.Array] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.StateSpaceV1Config | None = None) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len intermediate_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size ssm_state_size'], jaxtyping.Float[jaxlib._jax.Array, 'batch intermediate_size d_conv'] | None]
Execute SSM1 (Mamba1-style) selective state space with automatic optimization.
SSM1 processes sequences statefully, carrying a hidden state across timesteps, so selective state space modeling runs in O(N) time in the sequence length N.
- Parameters
hidden_states – Input tensor after convolution and activation. Shape: [batch, seq_len, intermediate_size]
A – A matrix in real form (typically negative for stability). Shape: [intermediate_size, ssm_state_size]
B – B parameter from input projection. Shape: [batch, seq_len, ssm_state_size]
C – C parameter from input projection. Shape: [batch, seq_len, ssm_state_size]
D – Skip connection parameter. Shape: [intermediate_size]
dt – Time step after softplus activation. Shape: [batch, seq_len, intermediate_size]
gate – Optional gating tensor for output modulation. Shape: [batch, seq_len, intermediate_size]
initial_state – Optional initial SSM state for continuation. Shape: [batch, intermediate_size, ssm_state_size]
conv_state – Optional convolution state for caching (passed through). Shape: [batch, intermediate_size, d_conv]
act_fn – Optional activation function for gating (e.g., jax.nn.silu). If gate is provided but act_fn is None, defaults to jax.nn.silu.
platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Optional kernel configuration
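- Returns
A tuple (output, ssm_state, conv_state), where:
output: SSM output [batch, seq_len, intermediate_size]
ssm_state: Final hidden state [batch, intermediate_size, ssm_state_size]
conv_state: The conv_state argument, returned unchanged for caching (None if it was not provided)
- Return type
tuple[Array, Array, Array | None]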
Example
>>> import jax
>>> from ejkernel.modules.operations.state_space_v1 import state_space_v1
>>>
>>> # Basic usage
>>> output, ssm_state, _ = state_space_v1(hidden_states, A, B, C, D, dt)
>>>
>>> # With gating
>>> output, ssm_state, _ = state_space_v1(
...     hidden_states, A, B, C, D, dt,
...     gate=gate, act_fn=jax.nn.silu,
... )
>>>
>>> # Inference with cached state (one token, continuing from ssm_state)
>>> output, new_state, conv_state = state_space_v1(
...     hidden_states[:, :1, :],
...     A, B[:, :1, :], C[:, :1, :], D, dt[:, :1, :],
...     initial_state=ssm_state, conv_state=conv_state,
... )
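The cached-state example relies on the recurrence being resumable: running a prefix, then feeding the remaining token with the prefix's final state as initial_state, should reproduce the outputs and final state of a single pass over the whole sequence. The sketch checks this with a compact reference scan on random inputs with the documented shapes; ssm1_scan is a hypothetical stand-in, not the library call, and the sizes are arbitrary.

```python
import jax
import jax.numpy as jnp


def ssm1_scan(hidden_states, A, B, C, D, dt, initial_state=None):
    """Compact reference SSM1 scan (illustrative, not the ejkernel kernel)."""
    b, _, d = hidden_states.shape
    h0 = jnp.zeros((b, d, A.shape[1])) if initial_state is None else initial_state

    def step(h, inp):
        x, dt_t, B_t, C_t = inp
        # h_t = exp(A * dt) * h_{t-1} + (dt * B) * x_t, broadcast to [b, d, n]
        h = jnp.exp(dt_t[..., None] * A) * h + (dt_t[..., None] * B_t[:, None, :]) * x[..., None]
        # y_t = h_t @ C_t + D * x_t
        return h, jnp.einsum("bdn,bn->bd", h, C_t) + D * x

    xs = [jnp.swapaxes(t, 0, 1) for t in (hidden_states, dt, B, C)]
    h_last, ys = jax.lax.scan(step, h0, xs)
    return jnp.swapaxes(ys, 0, 1), h_last


# Hypothetical random inputs with the documented shapes.
keys = jax.random.split(jax.random.PRNGKey(0), 5)
bsz, L, d, n = 2, 6, 4, 3
x = jax.random.normal(keys[0], (bsz, L, d))
A = -jnp.exp(jax.random.normal(keys[1], (d, n)))
B = jax.random.normal(keys[2], (bsz, L, n))
C = jax.random.normal(keys[3], (bsz, L, n))
dt = jax.nn.softplus(jax.random.normal(keys[4], (bsz, L, d)))
D = jnp.ones((d,))

# One pass over the whole sequence ...
y_full, h_full = ssm1_scan(x, A, B, C, D, dt)
# ... versus a prefix, then the last token resumed from the prefix's final state.
y_pre, h_pre = ssm1_scan(x[:, :-1], A, B[:, :-1], C[:, :-1], D, dt[:, :-1])
y_last, h_last = ssm1_scan(x[:, -1:], A, B[:, -1:], C[:, -1:], D, dt[:, -1:], initial_state=h_pre)
```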