ejkernel.kernels._xla.state_space_v1._xla_impl_fwd
Forward-pass implementation of the SSM1 (Mamba1-style) selective state space model.
- SSM1 algorithm:
  - Discretization:
    dA = exp(A * dt)
    dB = dt * B
  - Recurrence:
    h_t = dA * h_{t-1} + dB * x_t
    y_t = h_t @ C_t + D * x_t
- Key characteristics:
  - 2D A matrix: [intermediate_size, ssm_state_size]
  - SSM state shape: [batch, intermediate_size, ssm_state_size]
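The discretization and recurrence above can be checked against a naive sequential reference. This is a minimal plain-Python sketch for a single batch element, not ejkernel's implementation. It assumes Mamba1-style input layouts that the text does not spell out: `x` and `dt` are per time step and per channel (`[seq_len][intermediate_size]`), and `B` and `C` are per time step (`[seq_len][ssm_state_size]`). The name `ssm1_reference_fwd` and its signature are hypothetical.

```python
import math


def ssm1_reference_fwd(x, dt, A, B, C, D, h0=None):
    """Naive sequential SSM1 scan for ONE batch element (hypothetical helper).

    Assumed shapes (nested lists of floats):
      x, dt : [seq_len][intermediate_size]
      A     : [intermediate_size][ssm_state_size]
      B, C  : [seq_len][ssm_state_size]
      D     : [intermediate_size]
      h0    : [intermediate_size][ssm_state_size], zeros if None
    Returns (y, h_final), where y is [seq_len][intermediate_size].
    """
    d_inner, d_state = len(A), len(A[0])
    # Copy the initial state so the caller's h0 is not mutated.
    if h0 is None:
        h = [[0.0] * d_state for _ in range(d_inner)]
    else:
        h = [row[:] for row in h0]

    ys = []
    for t in range(len(x)):
        y_t = []
        for i in range(d_inner):
            acc = 0.0
            for n in range(d_state):
                dA = math.exp(A[i][n] * dt[t][i])  # dA = exp(A * dt)
                dB = dt[t][i] * B[t][n]            # dB = dt * B
                # h_t = dA * h_{t-1} + dB * x_t
                h[i][n] = dA * h[i][n] + dB * x[t][i]
                # Accumulate (h_t @ C_t)[i] by contracting over the state axis.
                acc += h[i][n] * C[t][n]
            y_t.append(acc + D[i] * x[t][i])       # y_t = h_t @ C_t + D * x_t
        ys.append(y_t)
    return ys, h
```

With `A = 0` (so `dA = 1`), `dt = B = C = 1` and `D = 0`, the recurrence reduces to a running sum, which makes a convenient sanity check. To cover the `batch` dimension of the SSM state, call the function once per batch element. A production kernel would vectorize or parallelize this triple loop; the sketch only pins down the arithmetic.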