ejkernel.kernels._xla.state_space_v1._xla_impl_bwd


Backward pass implementation for the SSM1 (Mamba1-style) selective state-space model.

Gradients for the SSM1 recurrence:
Forward:

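$$\mathrm{dA}_t = \exp(A \cdot \mathrm{dt}_t)$$
$$\mathrm{dBx}_t = \mathrm{dt}_t \cdot B_t \cdot x_t$$
$$h_t = \mathrm{dA}_t \cdot h_{t-1} + \mathrm{dBx}_t$$
$$y_t = \sum \left(h_t \cdot C_t\right) + D \cdot x_t$$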
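The forward recurrence can be written as a small reference scan. The sketch below is a hypothetical pure-JAX helper (`ssm1_forward` is not ejkernel's API); it assumes one unbatched sequence with `x` and `dt` of shape `(T, d)`, `A` of shape `(d, N)`, `B` and `C` of shape `(T, N)`, and `D` of shape `(d,)`, and uses `jax.lax.scan` to carry `h_t` across time steps.

```python
import jax
import jax.numpy as jnp


def ssm1_forward(x, dt, A, B, C, D):
    """Reference SSM1 forward pass (hypothetical helper, not ejkernel's API).

    Assumed shapes for one unbatched sequence:
      x, dt: (T, d)   per-channel inputs and step sizes
      A:     (d, N)   per-channel diagonal state matrix
      B, C:  (T, N)   input/output projections per time step
      D:     (d,)     skip-connection weights
    Returns y with shape (T, d) and the hidden states hs with shape (T, d, N).
    """
    # Discretization, broadcast to (T, d, N).
    dA = jnp.exp(A[None] * dt[..., None])
    dBx = dt[..., None] * B[:, None, :] * x[..., None]

    def step(h_prev, inputs):
        dA_t, dBx_t = inputs
        h_t = dA_t * h_prev + dBx_t  # h_t = dA_t * h_{t-1} + dBx_t
        return h_t, h_t              # carry h_t forward and also stack it

    h0 = jnp.zeros_like(A)  # initial state h_{-1} = 0
    _, hs = jax.lax.scan(step, h0, (dA, dBx))
    # y_t = sum over the state axis of h_t * C_t, plus the D * x_t skip term.
    y = jnp.einsum("tdn,tn->td", hs, C) + D * x
    return y, hs
```

Calling `y, hs = ssm1_forward(x, dt, A, B, C, D)` returns every hidden state, which is what a backward pass needs for the `h_{t-1}` terms. Using `jax.lax.scan` keeps the time loop as a single traced loop instead of unrolling `T` Python iterations.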

Backward (computed in reverse time order, from the last step to the first):

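$$\frac{\partial L}{\partial h_t} = \frac{\partial L}{\partial y_t} \cdot C_t + \mathrm{dA}_{t+1} \cdot \frac{\partial L}{\partial h_{t+1}}$$
$$\frac{\partial L}{\partial x_t} = \frac{\partial L}{\partial y_t} \cdot D + \sum \left(\frac{\partial L}{\partial h_t} \cdot \mathrm{dt}_t \cdot B_t\right)$$
$$\frac{\partial L}{\partial A} = \sum_t \left(\frac{\partial L}{\partial h_t} \cdot h_{t-1} \cdot \mathrm{dA}_t \cdot \mathrm{dt}_t\right)$$
$$\frac{\partial L}{\partial B_t} = \sum_d \left(\frac{\partial L}{\partial h_t} \cdot \mathrm{dt}_t \cdot x_t\right)$$
$$\frac{\partial L}{\partial C_t} = \sum_d \left(\frac{\partial L}{\partial y_t} \cdot h_t\right)$$
$$\frac{\partial L}{\partial D} = \sum_t \left(\frac{\partial L}{\partial y_t} \cdot x_t\right)$$
$$\frac{\partial L}{\partial \mathrm{dt}_t} = \sum \left(\frac{\partial L}{\partial h_t} \cdot \left(A \cdot \mathrm{dA}_t \cdot h_{t-1} + B_t \cdot x_t\right)\right)$$

The $A$ gradient carries the factor $\mathrm{dA}_t \cdot \mathrm{dt}_t$ because $\partial\,\mathrm{dA}_t / \partial A = \mathrm{dt}_t \cdot \exp(A \cdot \mathrm{dt}_t) = \mathrm{dt}_t \cdot \mathrm{dA}_t$. All products are elementwise; $\sum_t$ runs over time steps, $\sum_d$ over channels, and an unsubscripted $\sum$ over the state dimension.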
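One way to sanity-check these backward formulas is to implement them by hand and compare against JAX autodiff. The sketch below uses hypothetical helper names (`_ssm1_states`, `ssm1_y`, `ssm1_backward`) and assumes one unbatched sequence with `x`, `dt`, `dy` of shape `(T, d)`, `A` of shape `(d, N)`, `B`, `C` of shape `(T, N)`, and `D` of shape `(d,)`. It recomputes the forward states, accumulates `dL/dh_t` with `jax.lax.scan(..., reverse=True)`, then applies the remaining formulas. It is not ejkernel's kernel, which may organize these steps differently.

```python
import jax
import jax.numpy as jnp


def _ssm1_states(x, dt, A, B):
    """Forward recurrence; returns dA and hidden states hs, both (T, d, N)."""
    dA = jnp.exp(A[None] * dt[..., None])
    dBx = dt[..., None] * B[:, None, :] * x[..., None]

    def step(h_prev, inp):
        dA_t, dBx_t = inp
        h_t = dA_t * h_prev + dBx_t
        return h_t, h_t

    _, hs = jax.lax.scan(step, jnp.zeros_like(A), (dA, dBx))
    return dA, hs


def ssm1_y(x, dt, A, B, C, D):
    """Forward output y_t = sum_n(h_t * C_t) + D * x_t, shape (T, d)."""
    _, hs = _ssm1_states(x, dt, A, B)
    return jnp.einsum("tdn,tn->td", hs, C) + D * x


def ssm1_backward(x, dt, A, B, C, D, dy):
    """Hand-written gradients (dx, ddt, dA, dB, dC, dD) for upstream grad dy."""
    dA, hs = _ssm1_states(x, dt, A, B)
    # h_{t-1} aligned with step t, using h_{-1} = 0.
    hs_prev = jnp.concatenate([jnp.zeros_like(hs[:1]), hs[:-1]], axis=0)
    # dA_{t+1} aligned with step t; zero after the last step.
    dA_next = jnp.concatenate([dA[1:], jnp.zeros_like(dA[:1])], axis=0)

    def bwd_step(g_next, inp):
        dy_t, C_t, dA_next_t = inp
        # dL/dh_t = dL/dy_t * C_t + dA_{t+1} * dL/dh_{t+1}
        g_t = dy_t[:, None] * C_t[None, :] + dA_next_t * g_next
        return g_t, g_t

    # reverse=True walks t = T-1 .. 0 but stores gh[t] at index t.
    _, gh = jax.lax.scan(bwd_step, jnp.zeros_like(A), (dy, C, dA_next), reverse=True)

    dx = dy * D + jnp.sum(gh * dt[..., None] * B[:, None, :], axis=-1)
    dA_grad = jnp.sum(gh * hs_prev * dA * dt[..., None], axis=0)   # sum over t
    dB = jnp.sum(gh * dt[..., None] * x[..., None], axis=1)        # sum over d
    dC = jnp.einsum("td,tdn->tn", dy, hs)                          # sum over d
    dD = jnp.sum(dy * x, axis=0)                                   # sum over t
    ddt = jnp.sum(gh * (A[None] * dA * hs_prev + B[:, None, :] * x[..., None]), axis=-1)
    return dx, ddt, dA_grad, dB, dC, dD
```

On random inputs, the tuple returned by `ssm1_backward(x, dt, A, B, C, D, dy)` can be compared elementwise with `jax.vjp(ssm1_y, x, dt, A, B, C, D)[1](dy)`, which returns gradients in the same argument order. Agreement to float32 tolerance also exercises the `dA_t * dt_t` factor in the `A` gradient.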