ejkernel.kernels._xla.state_space_v2._xla_impl_bwd


Backward pass implementation for SSM2 (Mamba2-style) selective state space.

Gradients for the SSM2 recurrence:
Forward:

    dA  = exp(A * dt)
    dBx = dt * B * x    (outer product)
    h_t = dA * h_{t-1} + dBx
    y_t = einsum(h_t, C_t) + x_t * D
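The forward recurrence can be sketched as a plain loop. This is a minimal NumPy reference for readability, not the module's XLA implementation. The single-head layout (x: [T, P], dt: [T], scalar A and D, B and C: [T, N], zero initial state) and the name ssm2_forward are assumptions made for illustration.

```python
import numpy as np


def ssm2_forward(x, dt, A, B, C, D):
    """Sequential SSM2 forward pass for one head (illustrative only).

    Assumed shapes: x [T, P], dt [T], A scalar, B [T, N], C [T, N], D scalar.
    Returns the outputs ys [T, P] and the hidden states hs [T, P, N].
    """
    T, P = x.shape
    N = B.shape[1]
    h = np.zeros((P, N))            # assumed zero initial state
    hs = np.zeros((T, P, N))
    ys = np.zeros((T, P))
    for t in range(T):
        dA = np.exp(A * dt[t])                 # scalar decay for this step
        dBx = dt[t] * np.outer(x[t], B[t])     # outer product, shape [P, N]
        h = dA * h + dBx                       # h_t = dA * h_{t-1} + dBx
        hs[t] = h
        ys[t] = h @ C[t] + x[t] * D            # y_t = einsum(h_t, C_t) + x_t * D
    return ys, hs
```

Storing every h_t keeps the sketch simple; the backward formulas need both h_t and h_{t-1}.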

Backward (computed in reverse order):

    dL/dh_t  = einsum(dL/dy_t, C_t) + dA_{t+1} * dL/dh_{t+1}
    dL/dx_t  = dL/dy_t * D + einsum(dL/dh_t, dt * B)
    dL/dA    = sum_t(dL/dh_t * h_{t-1} * dt * dA)
    dL/dB_t  = einsum(dL/dh_t, x_t * dt)
    dL/dC_t  = einsum(dL/dy_t, h_t)
    dL/dD    = sum_t(dL/dy_t * x_t)
    dL/ddt_t = sum(dL/dh_t * (A * dA * h_{t-1} + B * x))
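The backward formulas can be checked with a small reverse-time loop. The sketch below is a NumPy reference under assumed conventions (single head, x: [T, P], dt: [T], scalar A and D, B and C: [T, N], zero initial state); the function names are hypothetical and this is not the module's actual code. A compact forward helper is included only so the block runs on its own.

```python
import numpy as np


def ssm2_forward(x, dt, A, B, C, D):
    """Compact forward pass; returns outputs ys [T, P] and states hs [T, P, N]."""
    T, P = x.shape
    N = B.shape[1]
    h = np.zeros((P, N))
    hs = np.zeros((T, P, N))
    ys = np.zeros((T, P))
    for t in range(T):
        h = np.exp(A * dt[t]) * h + dt[t] * np.outer(x[t], B[t])
        hs[t] = h
        ys[t] = h @ C[t] + x[t] * D
    return ys, hs


def ssm2_backward(x, dt, A, B, C, D, hs, dy):
    """Reverse-time gradients given upstream dy = dL/dy with shape [T, P]."""
    T, P = x.shape
    N = B.shape[1]
    dx = np.zeros_like(x)
    ddt = np.zeros_like(dt)
    dB = np.zeros_like(B)
    dC = np.zeros_like(C)
    dA_grad = 0.0
    dD = 0.0
    carry = np.zeros((P, N))        # holds dA_{t+1} * dL/dh_{t+1}
    for t in reversed(range(T)):
        dA = np.exp(A * dt[t])
        h_prev = hs[t - 1] if t > 0 else np.zeros((P, N))
        dh = np.outer(dy[t], C[t]) + carry              # dL/dh_t
        dx[t] = dy[t] * D + dt[t] * (dh @ B[t])         # dL/dx_t
        dA_grad += np.sum(dh * h_prev) * dt[t] * dA     # dL/dA, summed over t
        dB[t] = dt[t] * (x[t] @ dh)                     # dL/dB_t
        dC[t] = dy[t] @ hs[t]                           # dL/dC_t
        dD += np.dot(dy[t], x[t])                       # dL/dD, summed over t
        ddt[t] = np.sum(dh * (A * dA * h_prev + np.outer(x[t], B[t])))
        carry = dA * dh             # passed on to step t - 1
    return dx, ddt, dA_grad, dB, dC, dD
```

Comparing these results against central finite differences of sum(y * dy) is a quick way to confirm that each formula matches the forward recurrence.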