ejkernel.kernels._triton.mean_pooling._triton_impl_bwd

ejkernel.kernels._triton.mean_pooling._triton_impl_bwd.bwd_triton_impl(do: Float[jaxlib._jax.Array, 'batch hidden_dim'], batch_size: int, seq_len: int, chunk_size: int, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None) → Float[jaxlib._jax.Array, 'batch seq_len hidden_dim']

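Execute the mean pooling backward pass using a Triton kernel.

Computes the gradient with respect to the input by distributing the output gradient evenly across the sequence dimension. Because the forward pass averages over the sequence, each pooled position receives the output gradient divided by the number of positions that were averaged.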

Parameters
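  • do – Gradient of the loss with respect to the pooled output, with shape (batch, hidden_dim)

  • batch_size – Number of sequences in the batch

  • seq_len – Length of the sequence dimension of the input

  • chunk_size – Size of chunks for processing

  • cu_seqlens – Optional cumulative sequence lengths, with shape (num_seqs_plus_one,), for variable-length mode. Defaults to None (fixed-length mode)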

Returns

Gradient with respect to the input tensor, with shape (batch, seq_len, hidden_dim)

Return type

jax.Array
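
The semantics described above can be sketched in pure JAX. This is a minimal reference for illustration only, not the Triton kernel: the helper names `mean_pool_bwd_reference` and `mean_pool_bwd_varlen_reference` are hypothetical and not part of ejkernel, `chunk_size` is ignored, and the variable-length layout (all tokens packed along one axis, with `cu_seqlens = [0, len_0, len_0 + len_1, ...]`) is an assumed convention rather than something this page confirms.

```python
import jax
import jax.numpy as jnp


def mean_pool_bwd_reference(do, seq_len):
    """Hypothetical fixed-length reference (not the ejkernel kernel).

    do: (batch, hidden_dim) gradient of the pooled output.
    Returns (batch, seq_len, hidden_dim): every position gets do / seq_len.
    """
    batch, hidden_dim = do.shape
    return jnp.broadcast_to(do[:, None, :] / seq_len, (batch, seq_len, hidden_dim))


def mean_pool_bwd_varlen_reference(do, cu_seqlens, total_tokens):
    """Hypothetical variable-length reference under an ASSUMED packed layout.

    do: (num_seqs, hidden_dim); cu_seqlens: (num_seqs + 1,) starting at 0.
    Returns (total_tokens, hidden_dim).
    """
    lengths = jnp.diff(cu_seqlens)
    token_ids = jnp.arange(total_tokens)
    # Map each packed token index to the sequence that owns it.
    seq_ids = jnp.searchsorted(cu_seqlens, token_ids, side="right") - 1
    return do[seq_ids] / lengths[seq_ids][:, None]


# Compare the fixed-length reference with autodiff of a plain mean over axis 1.
x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 3))
do = jax.random.normal(jax.random.PRNGKey(1), (2, 3))
_, vjp_fn = jax.vjp(lambda t: t.mean(axis=1), x)
(dx_auto,) = vjp_fn(do)
dx_ref = mean_pool_bwd_reference(do, seq_len=4)

# Two packed sequences of lengths 2 and 3 (5 tokens in total).
cu_seqlens = jnp.array([0, 2, 5])
dx_varlen = mean_pool_bwd_varlen_reference(do, cu_seqlens, total_tokens=5)
```

Comparing against `jax.vjp` of a plain mean is a convenient way to sanity-check any backward implementation of this operation; the same comparison could in principle be applied to the output of `bwd_triton_impl` on supported hardware.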