ejkernel.kernels._triton.mean_pooling._triton_impl_fwd


ejkernel.kernels._triton.mean_pooling._triton_impl_fwd.fwd_triton_impl(x: Float[jaxlib._jax.Array, 'batch seq_len hidden_dim'], chunk_size: int, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None) → Float[jaxlib._jax.Array, 'batch hidden_dim']

Execute the mean pooling forward pass using a Triton kernel.

Launches the Triton kernel that averages the input over its sequence axis, handling both standard (fixed-length) batches and variable-length packed sequences.
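For checking results, the standard-mode computation can be sketched in plain NumPy. This is an illustrative sketch, not ejkernel's kernel: the helper name `mean_pool_reference` is hypothetical, and the chunk loop assumes that `chunk_size` tiles the sequence axis, with partial sums accumulated chunk by chunk before dividing by the sequence length.

```python
import numpy as np


def mean_pool_reference(x: np.ndarray, chunk_size: int) -> np.ndarray:
    """Hypothetical reference: mean over the sequence axis of a
    [batch, seq_len, hidden_dim] array, accumulated chunk by chunk."""
    batch, seq_len, hidden_dim = x.shape
    # Accumulate in float32 so low-precision inputs do not lose accuracy.
    acc = np.zeros((batch, hidden_dim), dtype=np.float32)
    for start in range(0, seq_len, chunk_size):
        # The last chunk may be shorter than chunk_size.
        chunk = x[:, start:start + chunk_size, :].astype(np.float32)
        acc += chunk.sum(axis=1)
    # Divide the accumulated sums by the full sequence length.
    return (acc / seq_len).astype(x.dtype)


x = np.arange(2 * 5 * 3, dtype=np.float32).reshape(2, 5, 3)
out = mean_pool_reference(x, chunk_size=2)
print(out.shape)  # (2, 3)
```

In this sketch chunking only changes the order of summation, so the result matches `x.mean(axis=1)` up to floating-point rounding; `chunk_size` affects how the work is tiled, not the value being computed.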

Parameters
  • x – Input tensor of shape [batch, seq_len, hidden_dim], or, in variable-length mode, a packed tensor whose sequence boundaries are given by cu_seqlens

  • chunk_size – Number of sequence positions the kernel processes per chunk along the sequence axis

  • cu_seqlens – Optional cumulative sequence lengths (num_seqs + 1 entries) that enable variable-length mode; defaults to None, which selects standard mode
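In variable-length mode, cumulative sequence lengths are conventionally the prefix sums of the per-sequence lengths with a leading zero, which is consistent with the annotation giving cu_seqlens num_seqs + 1 entries. The sketch below builds such an array from a list of lengths and averages each segment of a packed input. It assumes, without confirmation from this page, that the packed input has shape [1, total_tokens, hidden_dim] and that the output holds one row per sequence; both helper names are hypothetical.

```python
import numpy as np


def make_cu_seqlens(lengths: list[int]) -> np.ndarray:
    """Hypothetical helper: [3, 5, 2] -> [0, 3, 8, 10]."""
    return np.concatenate([[0], np.cumsum(lengths)]).astype(np.int32)


def varlen_mean_pool_reference(x: np.ndarray, cu_seqlens: np.ndarray) -> np.ndarray:
    """Hypothetical reference: mean of each packed segment.

    Assumes x has shape [1, total_tokens, hidden_dim] and returns
    [num_seqs, hidden_dim]. Sequences are assumed to be non-empty.
    """
    rows = []
    for i in range(len(cu_seqlens) - 1):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        # Average only the tokens that belong to sequence i.
        rows.append(x[0, start:end, :].mean(axis=0))
    return np.stack(rows)


lengths = [3, 5, 2]
cu_seqlens = make_cu_seqlens(lengths)
print(cu_seqlens)  # [ 0  3  8 10]
x = np.ones((1, sum(lengths), 4), dtype=np.float32)
x[0, 3:8] = 2.0  # tokens of the second sequence
print(varlen_mean_pool_reference(x, cu_seqlens)[:, 0])  # [1. 2. 1.]
```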

Returns

Mean-pooled output tensor of shape [batch, hidden_dim], with the sequence dimension averaged out

Return type

jax.Array