ejkernel.kernels._triton.mean_pooling._triton_impl_fwd
- ejkernel.kernels._triton.mean_pooling._triton_impl_fwd.fwd_triton_impl(x: Float[jaxlib._jax.Array, 'batch seq_len hidden_dim'], chunk_size: int, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None) → Float[jaxlib._jax.Array, 'batch hidden_dim']
Execute the mean pooling forward pass using a Triton kernel.
Launches the Triton kernel that averages x over its sequence dimension. It handles both the standard batched format and the variable-length (packed) format, which is selected by passing cu_seqlens.
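To make the semantics concrete, here is a minimal NumPy sketch of what the kernel is expected to compute. It is not the ejkernel implementation. `mean_pool_reference` is a hypothetical name, and two details are assumptions: that chunk_size controls how many sequence positions are summed per step, and that in variable-length mode the packed tokens sit in a single batch row, producing one output row per sequence.

```python
import numpy as np


def mean_pool_reference(x, chunk_size, cu_seqlens=None):
    """Hypothetical NumPy reference for mean pooling over the sequence axis.

    Dense mode: x is [batch, seq_len, hidden_dim] -> [batch, hidden_dim].
    Varlen mode (assumed layout): x is [1, total_tokens, hidden_dim] and
    cu_seqlens is [num_seqs + 1]; the result is [num_seqs, hidden_dim].
    """
    x = np.asarray(x, dtype=np.float32)
    if cu_seqlens is None:
        # Each batch row is one segment spanning the full sequence.
        seq_len = x.shape[1]
        segments = [(b, 0, seq_len) for b in range(x.shape[0])]
    else:
        # Assumption: packed tokens live in batch row 0; sequence i
        # occupies positions cu_seqlens[i] .. cu_seqlens[i + 1] - 1.
        cu = np.asarray(cu_seqlens)
        segments = [(0, int(cu[i]), int(cu[i + 1])) for i in range(len(cu) - 1)]

    out = np.zeros((len(segments), x.shape[-1]), dtype=np.float32)
    for i, (b, start, end) in enumerate(segments):
        acc = np.zeros(x.shape[-1], dtype=np.float32)
        # Accumulate partial sums chunk by chunk, as a tiled kernel might.
        for c in range(start, end, chunk_size):
            acc += x[b, c:min(c + chunk_size, end)].sum(axis=0)
        # Divide by the true segment length (guard against empty segments).
        out[i] = acc / max(end - start, 1)
    return out


x = np.arange(2 * 5 * 3, dtype=np.float32).reshape(2, 5, 3)
dense = mean_pool_reference(x, chunk_size=2)                  # shape (2, 3)
packed = mean_pool_reference(x[:1], chunk_size=2, cu_seqlens=[0, 2, 5])  # shape (2, 3)
```

Because the per-chunk partial sums are simply added before dividing by the segment length, the result should not depend on chunk_size beyond floating-point rounding; chunking only changes how the work is tiled.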
- Parameters
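x – Input tensor of shape [batch, seq_len, hidden_dim], or in packed format when cu_seqlens is provided
chunk_size – Size of the chunks into which each sequence is split for processing
cu_seqlens – Optional cumulative sequence lengths of shape [num_seqs + 1]; when provided, the kernel runs in variable-length mode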
- Returns
Mean-pooled output tensor of shape [batch, hidden_dim], with the sequence dimension reduced
- Return type