ejkernel.kernels._xla.state_space_v2._xla_impl_bwd
Backward pass implementation for SSM2 (Mamba2-style) selective state space.
- Gradients for the SSM2 recurrence:
- Forward:
  - dA = exp(A * dt)
  - dBx = dt * B * x (outer product)
  - h_t = dA * h_{t-1} + dBx
  - y_t = einsum(h_t, C_t) + x_t * D
- Backward (computed in reverse order):
  - dL/dh_t = einsum(dL/dy_t, C_t) + dA_{t+1} * dL/dh_{t+1}
  - dL/dx_t = dL/dy_t * D + einsum(dL/dh_t, dt * B)
  - dL/dA = sum_t(dL/dh_t * h_{t-1} * dt * dA)
  - dL/dB_t = einsum(dL/dh_t, x_t * dt)
  - dL/dC_t = einsum(dL/dy_t, h_t)
  - dL/dD = sum_t(dL/dy_t * x_t)
  - dL/ddt_t = sum(dL/dh_t * (A * dA * h_{t-1} + B * x))
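
The recurrence and its gradients above can be checked on a toy problem. What follows is a minimal plain-Python sketch, not the ejkernel implementation: it assumes a single scalar input channel, a scalar A and D, a zero initial state, and per-step vectors B_t and C_t of state size N. The names `ssm2_forward`, `ssm2_backward`, and `loss` are hypothetical and exist only for this illustration.

```python
import math

def ssm2_forward(x, dt, A, B, C, D):
    """Run the recurrence for one scalar channel; return outputs and all states."""
    n = len(B[0])
    h = [0.0] * n  # assumed zero initial state h_{-1}
    ys, hs = [], []
    for t in range(len(x)):
        dA = math.exp(A * dt[t])
        # h_t = dA * h_{t-1} + dt * B * x
        h = [dA * h[i] + dt[t] * B[t][i] * x[t] for i in range(n)]
        hs.append(h)
        # y_t = <h_t, C_t> + x_t * D
        ys.append(sum(h[i] * C[t][i] for i in range(n)) + x[t] * D)
    return ys, hs

def ssm2_backward(dy, x, dt, A, B, C, D, hs):
    """Reverse-time sweep applying the gradient formulas listed above."""
    T, n = len(x), len(B[0])
    carry = [0.0] * n  # dA_{t+1} * dL/dh_{t+1}; zero past the last step
    g = {"x": [0.0] * T, "dt": [0.0] * T, "A": 0.0, "D": 0.0,
         "B": [None] * T, "C": [None] * T}
    for t in reversed(range(T)):
        dA = math.exp(A * dt[t])
        h_prev = hs[t - 1] if t > 0 else [0.0] * n
        dh = [dy[t] * C[t][i] + carry[i] for i in range(n)]
        g["x"][t] = dy[t] * D + sum(dh[i] * dt[t] * B[t][i] for i in range(n))
        g["A"] += sum(dh[i] * h_prev[i] * dt[t] * dA for i in range(n))
        g["B"][t] = [dh[i] * x[t] * dt[t] for i in range(n)]
        g["C"][t] = [dy[t] * hs[t][i] for i in range(n)]
        g["D"] += dy[t] * x[t]
        g["dt"][t] = sum(dh[i] * (A * dA * h_prev[i] + B[t][i] * x[t])
                         for i in range(n))
        carry = [dA * dh[i] for i in range(n)]  # propagate to step t-1
    return g

# Toy inputs (arbitrary values chosen for illustration): T = 3 steps, N = 2 states.
x = [0.5, -1.0, 0.25]
dt = [0.1, 0.2, 0.3]
A, D = -1.0, 0.7
B = [[1.0, 0.5], [-0.3, 0.8], [0.2, -0.6]]
C = [[0.4, -0.2], [0.9, 0.1], [-0.5, 0.3]]
dy = [1.0, -2.0, 0.5]  # upstream gradient dL/dy_t

def loss(A_, D_=D, dt_=dt, x_=x):
    """Scalar loss L = sum_t dy_t * y_t, so that dL/dy_t equals dy_t."""
    ys, _ = ssm2_forward(x_, dt_, A_, B, C, D_)
    return sum(w * y for w, y in zip(dy, ys))

ys, hs = ssm2_forward(x, dt, A, B, C, D)
grads = ssm2_backward(dy, x, dt, A, B, C, D, hs)

# Compare the analytic dL/dA against a central finite difference.
eps = 1e-6
fd_A = (loss(A + eps) - loss(A - eps)) / (2 * eps)
print("dL/dA analytic:", grads["A"], "finite difference:", fd_A)
```

The sketch stores every h_t from the forward pass so the reverse sweep can read h_{t-1} directly. A production kernel might instead recompute or chunk the states to save memory, but that choice is outside what this page states.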