ejkernel.kernels._xla.state_space_v2._xla_impl_fwd

Forward pass implementation of the SSM2 (Mamba2-style) selective state space model.

SSM2 algorithm:

Discretization:

  dA = exp(A * dt)    where A is a per-head scalar
  dB = dt * B

Recurrence (per head):

  dBx = dt * B * x    (outer product of x_t and B_t)
  h_t = dA * h_{t-1} + dBx
  y_t = einsum(h_t, C_t) + x_t * D    (einsum contracts the ssm_state_size axis)
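For reference, the recurrence above can be written as a naive sequential loop. This is a minimal NumPy sketch, not ejkernel's XLA implementation. The function name `ssm2_forward_reference` and the tensor layouts (x as [batch, seq_len, num_heads, head_dim], dt as [batch, seq_len, num_heads], and B and C already expanded to one copy per head) are assumptions made for illustration.

```python
import numpy as np


def ssm2_forward_reference(x, dt, A, B, C, D, h0=None):
    """Naive sequential SSM2 forward pass (illustrative sketch only).

    Assumed (hypothetical) shapes:
      x:    [batch, seq_len, num_heads, head_dim]
      dt:   [batch, seq_len, num_heads]
      A:    [num_heads]                  per-head scalar
      B, C: [batch, seq_len, num_heads, ssm_state_size]  already per head
      D:    [num_heads]
      h0:   [batch, num_heads, head_dim, ssm_state_size] or None (zeros)
    Returns y [batch, seq_len, num_heads, head_dim] and the final state h.
    """
    batch, seq_len, num_heads, head_dim = x.shape
    state_size = B.shape[-1]
    if h0 is None:
        h = np.zeros((batch, num_heads, head_dim, state_size))
    else:
        h = h0.copy()
    y = np.empty_like(x)
    for t in range(seq_len):
        dt_t = dt[:, t]                                  # [batch, num_heads]
        dA = np.exp(A[None, :] * dt_t)                   # per-head decay
        # dBx = dt * (x_t outer B_t): [batch, num_heads, head_dim, state_size]
        dBx = dt_t[:, :, None, None] * x[:, t, :, :, None] * B[:, t, :, None, :]
        h = dA[:, :, None, None] * h + dBx
        # Contract the state axis with C_t, then add the D skip connection
        y[:, t] = np.einsum("bhpn,bhn->bhp", h, C[:, t]) + x[:, t] * D[None, :, None]
    return y, h
```

With A = -ln 2 and dt = 1, each step halves the previous state before adding the new input, which makes the decay easy to check by hand. An XLA implementation would more likely express the time loop with something like `jax.lax.scan` or a chunked formulation; the Python loop is used here only to make each step of the recurrence visible.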

Key characteristics:
  • 1D A vector: [num_heads] (per-head scalar)

  • SSM state shape: [batch, num_heads, head_dim, ssm_state_size]

  • B and C grouped over n_groups: each group's B and C are shared by num_heads // n_groups heads
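The n_groups grouping of B and C can be illustrated with a small helper that expands grouped tensors to one copy per head. This sketch assumes B and C are laid out as [batch, seq_len, n_groups, ssm_state_size] and that heads map to groups in contiguous blocks, as in the Mamba2 reference code; `expand_groups` is a hypothetical name, and ejkernel's actual layout may differ.

```python
import numpy as np


def expand_groups(BC, num_heads):
    """Broadcast grouped B or C to one copy per head (illustrative sketch).

    Assumes BC has shape [batch, seq_len, n_groups, ssm_state_size] and that
    head h reads group h // (num_heads // n_groups).
    """
    n_groups = BC.shape[2]
    if num_heads % n_groups != 0:
        raise ValueError("num_heads must be divisible by n_groups")
    heads_per_group = num_heads // n_groups
    # np.repeat along the group axis yields [g0, g0, ..., g1, g1, ...],
    # so consecutive heads share the same group.
    return np.repeat(BC, heads_per_group, axis=2)
```

Under this mapping, n_groups == 1 shares a single B and C across all heads, while n_groups == num_heads gives every head its own.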