ejkernel.kernels._triton.mean_pooling._triton_impl_bwd
- ejkernel.kernels._triton.mean_pooling._triton_impl_bwd.bwd_triton_impl(do: Float[jaxlib._jax.Array, 'batch hidden_dim'], batch_size: int, seq_len: int, chunk_size: int, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None) → Float[jaxlib._jax.Array, 'batch seq_len hidden_dim']
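Executes the mean pooling backward pass using a Triton kernel.
Computes the gradient with respect to the input by distributing the output gradient evenly across the sequence dimension. Because each pooled output is the mean over the positions of its sequence, every position receives the output gradient scaled by 1/seq_len (or by the length of its own sequence in variable-length mode).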
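The gradient rule above can be checked with a small pure-JAX sketch. This is not the Triton kernel: `mean_pool_bwd_dense` and `mean_pool_bwd_varlen` are hypothetical reference helpers, and the variable-length layout (sequences packed back to back along the token axis, with `do` holding one row per sequence) is an assumption about how `cu_seqlens` is used, not something this page confirms. Both helpers are compared against JAX autodiff of the corresponding forward mean.

```python
import jax
import jax.numpy as jnp


def mean_pool_bwd_dense(do, seq_len):
    """Hypothetical reference gradient for dense mean pooling.

    Forward: y[b, :] = mean(x[b, :, :], axis=0), so dy/dx[b, t, :] = 1/seq_len.
    do: (batch, hidden_dim) -> dx: (batch, seq_len, hidden_dim)
    """
    batch, hidden_dim = do.shape
    # Scale once, then broadcast the same row to every sequence position.
    return jnp.broadcast_to((do / seq_len)[:, None, :], (batch, seq_len, hidden_dim))


def mean_pool_bwd_varlen(do, cu_seqlens):
    """Hypothetical reference gradient for packed variable-length mean pooling.

    ASSUMPTION: sequence i occupies packed tokens cu_seqlens[i]:cu_seqlens[i + 1].
    do: (num_seqs, hidden_dim) -> dx: (total_tokens, hidden_dim)
    """
    lengths = jnp.diff(cu_seqlens)
    total_tokens = int(cu_seqlens[-1])
    # Sequence index of every packed token, e.g. [0, 0, 1, 1, 1].
    seq_ids = jnp.repeat(jnp.arange(lengths.shape[0]), lengths,
                         total_repeat_length=total_tokens)
    # Each token receives its own sequence's gradient divided by that sequence's length.
    return do[seq_ids] / lengths[seq_ids, None].astype(do.dtype)


# Dense mode: compare against autodiff of jnp.mean over the sequence axis.
x = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)
do = jnp.ones((2, 3), dtype=jnp.float32)
_, vjp_dense = jax.vjp(lambda t: jnp.mean(t, axis=1), x)
(dx_dense_ad,) = vjp_dense(do)
dx_dense = mean_pool_bwd_dense(do, seq_len=4)

# Varlen mode: two sequences of lengths 2 and 3 packed into 5 tokens.
cu_seqlens = jnp.array([0, 2, 5], dtype=jnp.int32)
do_var = jnp.array([[2.0, 4.0], [3.0, 6.0]], dtype=jnp.float32)
seg = jnp.array([0, 0, 1, 1, 1])
lens = jnp.diff(cu_seqlens).astype(jnp.float32)
# Forward segment mean: per-sequence sum divided by per-sequence length.
_, vjp_var = jax.vjp(
    lambda t: jax.ops.segment_sum(t, seg, num_segments=2) / lens[:, None],
    jnp.zeros((5, 2), jnp.float32),
)
(dx_var_ad,) = vjp_var(do_var)
dx_var = mean_pool_bwd_varlen(do_var, cu_seqlens)
```

In the dense case every position gets 1/4 of the upstream gradient; in the packed case the first two tokens get `do_var[0] / 2` and the last three get `do_var[1] / 3`, which here are both `[1.0, 2.0]`.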
- Parameters
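do – Gradient of the loss with respect to the pooled output, of shape (batch, hidden_dim)
batch_size – Size of the batch dimension
seq_len – Length of the sequence dimension of the returned gradient
chunk_size – Chunk size used by the kernel when processing the sequence dimension
cu_seqlens – Optional cumulative sequence lengths of shape (num_seqs + 1,); when provided, the kernel runs in variable-length mode (default: None)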
- Returns
Gradient with respect to the input tensor, of shape (batch, seq_len, hidden_dim)
- Return type