ejkernel.kernels._xla.mean_pooling._interface
Mean pooling interface for sequence embedding aggregation.
This module provides the public API for mean pooling over the sequence dimension. It supports both fixed-length (padded) and variable-length (packed) sequences and uses a custom VJP for efficient gradient computation.
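The following is a minimal sketch of what a custom VJP for fixed-length mean pooling computes. It is not ejkernel's implementation; `mean_pool_sketch`, `_mean_pool_fwd`, and `_mean_pool_bwd` are hypothetical names. Because the mean is linear, the backward pass only has to divide the output cotangent by the sequence length and broadcast it back over the sequence axis.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def mean_pool_sketch(x):
    """Hypothetical sketch: mean over the sequence axis (-2) of a (..., seq_len, hidden_dim) array."""
    return jnp.mean(x, axis=-2)


def _mean_pool_fwd(x):
    # For simplicity, keep x as the residual; the backward rule only reads its static shape.
    return jnp.mean(x, axis=-2), x


def _mean_pool_bwd(x, g):
    # d(mean)/dx_t = 1/seq_len for every token t, so the cotangent g of shape
    # (..., hidden_dim) is scaled and broadcast over the sequence axis.
    seq_len = x.shape[-2]
    grad_x = jnp.broadcast_to(jnp.expand_dims(g, -2) / seq_len, x.shape)
    return (grad_x.astype(x.dtype),)


mean_pool_sketch.defvjp(_mean_pool_fwd, _mean_pool_bwd)
```

Differentiating `jnp.sum(mean_pool_sketch(x) ** 2)` with `jax.grad` should match the gradient JAX derives automatically for `jnp.mean(x, axis=-2)`, which is a convenient way to check a hand-written rule.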
- ejkernel.kernels._xla.mean_pooling._interface.mean_pooling(x: Float[jaxlib._jax.Array, '... hidden_dim'], chunk_size: int = 32, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None) → Float[jaxlib._jax.Array, 'batch hidden_dim']
Performs mean pooling over the sequence dimension using JAX/XLA.
This function calculates the mean of token embeddings for each sequence in a batch. It supports both standard (padded) and variable-length sequences.
- Parameters
x – The input tensor of shape (batch_size, sequence_length, hidden_dim). If cu_seqlens is provided for variable-length inputs, the shape should be (total_tokens, hidden_dim).
chunk_size – Performance-tuning parameter. It is ignored by this XLA implementation and is used only by the Triton kernel.
cu_seqlens – An optional 1-D array of cumulative sequence lengths describing variable-length sequences packed along the first axis of x. Example: [0, len_seq1, len_seq1+len_seq2, …]. If provided, the function computes one mean-pooled embedding per packed sequence.
- Returns
A tensor of shape (batch_size, hidden_dim) containing the mean-pooled embedding of each sequence. If cu_seqlens is provided, the leading dimension equals the number of sequences it defines, len(cu_seqlens) - 1.
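The packed layout described by cu_seqlens can be illustrated with a pure-JAX reference sketch, assuming the semantics stated above: sequence i occupies rows cu_seqlens[i] through cu_seqlens[i + 1] - 1 of x. `packed_mean_pool_sketch` is a hypothetical helper for checking results, not part of ejkernel. It sums each sequence's rows with `jax.ops.segment_sum` and divides by the lengths obtained from `jnp.diff(cu_seqlens)`.

```python
import jax
import jax.numpy as jnp


def packed_mean_pool_sketch(x, cu_seqlens):
    """Hypothetical reference: mean-pool packed tokens x of shape (total_tokens, hidden_dim).

    cu_seqlens holds cumulative boundaries [0, len_1, len_1 + len_2, ...], so
    sequence i owns rows cu_seqlens[i]:cu_seqlens[i + 1].
    """
    num_seqs = cu_seqlens.shape[0] - 1  # static, so usable as num_segments
    token_ids = jnp.arange(x.shape[0])
    # Map each token row to the index of the sequence that contains it.
    segment_ids = jnp.searchsorted(cu_seqlens, token_ids, side="right") - 1
    # Sum embeddings per sequence, then divide by each sequence's length.
    sums = jax.ops.segment_sum(x, segment_ids, num_segments=num_seqs)
    lengths = jnp.diff(cu_seqlens).astype(x.dtype)
    return sums / lengths[:, None]
```

With x = jnp.ones((25, 128)) and cu_seqlens = jnp.array([0, 10, 25]), the result has shape (2, 128), matching the documented output shape. Note that a zero-length sequence in this sketch would produce NaN (0 / 0).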
Examples
>>> import jax.numpy as jnp
>>> from ejkernel.kernels._xla.mean_pooling._interface import mean_pooling
>>> x = jnp.ones((2, 10, 128))
>>> out = mean_pooling(x)
>>> out.shape
(2, 128)

>>> x = jnp.ones((25, 128))
>>> cu_seqlens = jnp.array([0, 10, 25])
>>> out = mean_pooling(x, cu_seqlens=cu_seqlens)
>>> out.shape
(2, 128)