ejkernel.kernels._triton.recurrent._triton_impl_bwd

ejkernel.kernels._triton.recurrent._triton_impl_bwd

ejkernel.kernels._triton.recurrent._triton_impl_bwd.bwd_triton_impl(q: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], k: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], v: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], g: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, g_gamma: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads'] | None = None, gk: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, gv: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, o: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, do: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, dht: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim'] | None = None, softmax_scale: float | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim'] | None = None, reverse: bool = False, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None) -> tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None, jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None, jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None, jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None, jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None, jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None, jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim'] | None]