ejkernel.kernels._xla.state_space_v2._xla_impl_fwd
Forward pass implementation for the SSM2 (Mamba2-style) selective state space model.
- SSM2 algorithm:
- Discretization:
dA = exp(A * dt), where A is a per-head scalar
dB = dt * B
- Recurrence (per head):
dBx = dt * B * x (outer product form)
h_t = dA * h_{t-1} + dBx
y_t = einsum(h_t, C_t) + x_t * D
- Key characteristics:
1D A vector: [num_heads] (per-head scalar)
SSM state shape: [batch, num_heads, head_dim, ssm_state_size]
B, C with n_groups grouping: heads are split into n_groups groups, and heads in the same group share B and C
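
The discretization and recurrence listed above can be sketched as a plain sequential loop. This is a minimal NumPy reference for illustration, not the module's XLA code. The function name `ssm2_forward_reference`, the tensor layouts (`x` as [batch, seq_len, num_heads, head_dim], `dt` as [batch, seq_len, num_heads], `B` and `C` as [batch, seq_len, n_groups, ssm_state_size], `D` as [num_heads]), and the zero initial state are all assumptions made for this example.

```python
import numpy as np


def ssm2_forward_reference(x, dt, A, B, C, D, n_groups):
    """Sequential SSM2 forward pass (hypothetical reference, assumed layouts).

    x:  [batch, seq_len, num_heads, head_dim]
    dt: [batch, seq_len, num_heads]
    A:  [num_heads]                      (per-head scalar)
    B:  [batch, seq_len, n_groups, ssm_state_size]
    C:  [batch, seq_len, n_groups, ssm_state_size]
    D:  [num_heads]                      (assumed per-head skip weight)
    """
    batch, seq_len, num_heads, head_dim = x.shape
    state_size = B.shape[-1]
    heads_per_group = num_heads // n_groups

    # Give every head its group's copy of B and C.
    B_h = np.repeat(B, heads_per_group, axis=2)  # [batch, seq_len, num_heads, N]
    C_h = np.repeat(C, heads_per_group, axis=2)

    # SSM state: [batch, num_heads, head_dim, ssm_state_size], assumed to start at zero.
    h = np.zeros((batch, num_heads, head_dim, state_size), dtype=x.dtype)
    ys = np.empty_like(x)

    for t in range(seq_len):
        # Discretization: dA = exp(A * dt), one scalar per (batch, head).
        dA = np.exp(A[None, :] * dt[:, t])
        # dBx = dt * B * x as an outer product over (head_dim, state_size).
        dBx = np.einsum("bh,bhn,bhp->bhpn", dt[:, t], B_h[:, t], x[:, t])
        # Recurrence: h_t = dA * h_{t-1} + dBx.
        h = dA[:, :, None, None] * h + dBx
        # Output: y_t = einsum(h_t, C_t) + x_t * D.
        ys[:, t] = np.einsum("bhpn,bhn->bhp", h, C_h[:, t]) + x[:, t] * D[None, :, None]

    return ys, h
```

A loop like this is mainly useful as a correctness baseline for small inputs; it makes the per-head scalar decay and the [batch, num_heads, head_dim, ssm_state_size] state shape explicit, at the cost of processing one time step at a time.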