ejkernel.kernels._xla.state_space_v1._xla_impl_bwd
Backward pass implementation for the SSM1 (Mamba1-style) selective state space model.
- Gradients for the SSM1 recurrence:
- Forward:
  - dA = exp(A * dt)
  - dBx = dt * B * x
  - h_t = dA * h_{t-1} + dBx_t
  - y_t = sum(h_t * C_t) + D * x_t
- Backward (computed in reverse order):
  - dL/dh_t = dL/dy_t * C_t + dA_{t+1} * dL/dh_{t+1}
  - dL/dx_t = dL/dy_t * D + sum(dL/dh_t * dt * B)
  - dL/dA = sum_t(dL/dh_t * h_{t-1} * dA_t * dt_t)
  - dL/dB_t = sum_d(dL/dh_t * dt * x)
  - dL/dC_t = sum_d(dL/dy_t * h_t)
  - dL/dD = sum_t(dL/dy_t * x_t)
  - dL/ddt_t = sum(dL/dh_t * (A * dA * h_{t-1} + B * x))
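The recurrence and its reverse-time gradients can be illustrated with a small sequential reference in plain Python. This is only a sketch for checking index conventions, not the module's XLA implementation. The function names (`ssm1_forward`, `ssm1_backward`, `finite_diff`), the unbatched shapes (`x`, `dt`: `[T][Dm]`; `A`: `[Dm][N]`; `B`, `C`: `[T][N]`; `D`: `[Dm]`), and the zero initial state are all assumptions made for illustration. Since dA = exp(A * dt), the sketch uses d(dA)/dA = dt * dA and d(dA)/d(dt) = A * dA.

```python
import copy
import math

# Assumed shapes (hypothetical, unbatched): x, dt: [T][Dm]; A: [Dm][N];
# B, C: [T][N]; D: [Dm]. The initial state h_0 is assumed to be zero.

def ssm1_forward(x, dt, A, B, C, D):
    """Sequential forward pass; returns outputs ys and states hs = [h_0, ..., h_T]."""
    T, Dm, N = len(x), len(A), len(A[0])
    hs = [[[0.0] * N for _ in range(Dm)]]
    ys = []
    for t in range(T):
        h_prev = hs[-1]
        # h_t = dA * h_{t-1} + dt * B * x
        h = [[math.exp(A[d][n] * dt[t][d]) * h_prev[d][n]
              + dt[t][d] * B[t][n] * x[t][d]
              for n in range(N)] for d in range(Dm)]
        # y_t = sum_n(h_t * C_t) + D * x_t
        ys.append([sum(h[d][n] * C[t][n] for n in range(N)) + D[d] * x[t][d]
                   for d in range(Dm)])
        hs.append(h)
    return ys, hs

def ssm1_backward(x, dt, A, B, C, D, hs, gy):
    """Reverse-time gradient pass; gy[t][d] holds dL/dy_t."""
    T, Dm, N = len(x), len(A), len(A[0])
    gx = [[0.0] * Dm for _ in range(T)]
    gdt = [[0.0] * Dm for _ in range(T)]
    gA = [[0.0] * N for _ in range(Dm)]
    gB = [[0.0] * N for _ in range(T)]
    gC = [[0.0] * N for _ in range(T)]
    gD = [0.0] * Dm
    carry = [[0.0] * N for _ in range(Dm)]  # holds dA_{t+1} * dL/dh_{t+1}
    for t in reversed(range(T)):
        h_prev, h = hs[t], hs[t + 1]  # state before and after step t
        for d in range(Dm):
            gD[d] += gy[t][d] * x[t][d]
            gx[t][d] += gy[t][d] * D[d]
            for n in range(N):
                dA = math.exp(A[d][n] * dt[t][d])
                gh = gy[t][d] * C[t][n] + carry[d][n]  # dL/dh_t
                gx[t][d] += gh * dt[t][d] * B[t][n]
                gdt[t][d] += gh * (A[d][n] * dA * h_prev[d][n] + B[t][n] * x[t][d])
                gA[d][n] += gh * h_prev[d][n] * dA * dt[t][d]
                gB[t][n] += gh * dt[t][d] * x[t][d]
                gC[t][n] += gy[t][d] * h[d][n]
                carry[d][n] = dA * gh
    return gx, gdt, gA, gB, gC, gD

def finite_diff(args, k, i, j, eps=1e-6):
    """Central difference of L = sum(y) w.r.t. the 2-D input args[k][i][j]."""
    def loss_at(delta):
        a = copy.deepcopy(args)
        a[k][i][j] += delta
        ys, _ = ssm1_forward(*a)
        return sum(map(sum, ys))
    return (loss_at(eps) - loss_at(-eps)) / (2 * eps)

# Small made-up example: T=3 steps, Dm=2 channels, N=2 state dimensions.
x = [[0.5, -1.0], [1.5, 0.3], [-0.7, 0.9]]
dt = [[0.1, 0.2], [0.3, 0.1], [0.2, 0.4]]
A = [[-1.0, -0.5], [-2.0, -0.3]]
B = [[0.4, -0.6], [1.0, 0.2], [-0.3, 0.8]]
C = [[0.7, 0.1], [-0.5, 0.9], [0.6, -0.4]]
D = [0.3, -0.2]
args = [x, dt, A, B, C, D]

ys, hs = ssm1_forward(*args)
gy = [[1.0] * len(D) for _ in x]  # dL/dy for the loss L = sum(y)
gx, gdt, gA, gB, gC, gD = ssm1_backward(*args, hs, gy)
```

Comparing individual entries such as `gA[0][1]` against `finite_diff(args, 2, 0, 1)` is a quick way to confirm that a hand-written backward pass matches the forward recurrence. The helper only handles the two-dimensional inputs; `D` can be checked directly, because with L = sum(y) its gradient is sum_t(x_t).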