ejkernel.kernels._xla.state_space_v1._xla_impl_fwd


Forward-pass implementation of the SSM1 (Mamba1-style) selective state space model.

SSM1 algorithm:
Discretization:

dA = exp(A * dt)
dB = dt * B
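A minimal pure-Python sketch of the discretization for a single time step and a single batch element. It assumes, as in Mamba1, that dt holds one step size per channel (length intermediate_size) and B_t one value per state (length ssm_state_size). These shapes and the helper name `discretize_step` are illustrative assumptions, not part of ejkernel's API.

```python
import math


def discretize_step(A, dt_t, B_t):
    """Discretize SSM1 parameters for one time step (illustrative sketch).

    A:    [intermediate_size][ssm_state_size], time-invariant
    dt_t: [intermediate_size], per-channel step size (assumed shape)
    B_t:  [ssm_state_size], input projection at this step (assumed shape)
    Returns dA and dB, each [intermediate_size][ssm_state_size].
    """
    # dA = exp(A * dt): each row of A is scaled by that channel's step size
    dA = [[math.exp(a * dt) for a in row] for row, dt in zip(A, dt_t)]
    # dB = dt * B: broadcast product of per-channel dt and per-state B
    dB = [[dt * b for b in B_t] for dt in dt_t]
    return dA, dB


# One channel, two state entries
dA, dB = discretize_step([[-1.0, -2.0]], [0.5], [1.0, 3.0])
```

Both outputs share the per-channel, per-state layout of A, which is what lets the recurrence below combine them elementwise with the state.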

Recurrence:

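h_t = dA * h_{t-1} + dB * x_t
y_t = h_t @ C_t + D * x_t

Here * is elementwise multiplication with broadcasting, and @ contracts the state h_t with C_t over the ssm_state_size axis.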
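The recurrence can be written as a sequential reference loop. The sketch below handles one batch element, assumes a zero initial state h_0 and the per-time-step shapes listed in its docstring, and uses the hypothetical name `ssm1_scan`. It is meant for checking small cases by hand, not as a stand-in for the XLA kernel.

```python
import math


def ssm1_scan(x, dt, A, B, C, D):
    """Sequential reference for the SSM1 recurrence (one batch element, sketch).

    x, dt: [seq_len][intermediate_size]   (assumed shapes)
    A:     [intermediate_size][ssm_state_size]
    B, C:  [seq_len][ssm_state_size]      (assumed shapes)
    D:     [intermediate_size]
    Returns y: [seq_len][intermediate_size] and the final state h.
    """
    d_inner, d_state = len(A), len(A[0])
    h = [[0.0] * d_state for _ in range(d_inner)]  # h_0 = 0 (assumption)
    ys = []
    for x_t, dt_t, B_t, C_t in zip(x, dt, B, C):
        y_t = []
        for i in range(d_inner):
            for n in range(d_state):
                dA = math.exp(A[i][n] * dt_t[i])  # dA = exp(A * dt)
                dB = dt_t[i] * B_t[n]             # dB = dt * B
                h[i][n] = dA * h[i][n] + dB * x_t[i]
            # y_t = h_t @ C_t + D * x_t, for channel i
            y_t.append(sum(h[i][n] * C_t[n] for n in range(d_state)) + D[i] * x_t[i])
        ys.append(y_t)
    return ys, h


ys, h = ssm1_scan(x=[[1.0], [2.0]], dt=[[1.0], [1.0]], A=[[0.0]],
                  B=[[1.0], [1.0]], C=[[1.0], [1.0]], D=[0.5])
```

With A = 0 the decay dA is 1, so the state simply accumulates dB * x_t; a negative A makes older inputs fade by exp(A * dt) per step.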

Key characteristics:
  • 2D A matrix: [intermediate_size, ssm_state_size]

  • SSM state shape: [batch, intermediate_size, ssm_state_size]
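To make the shape bookkeeping concrete, here is a hypothetical single-step update over a whole batch, as used when decoding one token at a time. The 2D A is shared by every batch row, while the state carries a leading batch axis. The function name `ssm1_step` and its argument layout are assumptions for illustration.

```python
import math


def ssm1_step(h, x_t, dt_t, B_t, C_t, A, D):
    """One batched SSM1 update (sketch; argument shapes are assumptions).

    h:          [batch][intermediate_size][ssm_state_size]  (SSM state)
    x_t, dt_t:  [batch][intermediate_size]
    B_t, C_t:   [batch][ssm_state_size]
    A:          [intermediate_size][ssm_state_size]  (no batch axis, shared)
    D:          [intermediate_size]
    Returns (new_h, y_t) with y_t: [batch][intermediate_size].
    """
    new_h, y_t = [], []
    for hb, xb, dtb, Bb, Cb in zip(h, x_t, dt_t, B_t, C_t):
        # h_t = exp(A * dt) * h_{t-1} + (dt * B) * x_t, per channel i and state n
        hb_new = [
            [math.exp(a * dtb[i]) * hb[i][n] + dtb[i] * Bb[n] * xb[i]
             for n, a in enumerate(A[i])]
            for i in range(len(A))
        ]
        # y_t = h_t @ C_t + D * x_t
        yb = [sum(hi[n] * Cb[n] for n in range(len(Cb))) + D[i] * xb[i]
              for i, hi in enumerate(hb_new)]
        new_h.append(hb_new)
        y_t.append(yb)
    return new_h, y_t


# batch=2, intermediate_size=1, ssm_state_size=2, starting from a zero state
new_h, y_t = ssm1_step(
    h=[[[0.0, 0.0]], [[0.0, 0.0]]],
    x_t=[[1.0], [2.0]], dt_t=[[1.0], [1.0]],
    B_t=[[1.0, 2.0], [1.0, 1.0]], C_t=[[1.0, 1.0], [0.5, 0.5]],
    A=[[0.0, 0.0]], D=[0.0],
)
```

The returned state keeps the [batch, intermediate_size, ssm_state_size] layout, so it can be fed straight back in for the next token.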