ejkernel.kernels._xla.mean_pooling._interface


Mean pooling interface for sequence embedding aggregation.

This module provides the public API for mean pooling over the sequence dimension. It supports both fixed-length and variable-length (packed) sequences and uses a custom VJP for efficient gradient computation.

ejkernel.kernels._xla.mean_pooling._interface.mean_pooling(x: Float[jaxlib._jax.Array, '... hidden_dim'], chunk_size: int = 32, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None) → Float[jaxlib._jax.Array, 'batch hidden_dim']

Performs mean pooling over the sequence dimension using JAX/XLA.

This function calculates the mean of token embeddings for each sequence in a batch. It supports both standard (padded) and variable-length sequences.

Parameters
  • x – The input tensor of shape (batch_size, sequence_length, hidden_dim). If cu_seqlens is provided for variable-length inputs, the shape should be (total_tokens, hidden_dim).

  • chunk_size – Performance tuning parameter. This XLA implementation ignores it; only the Triton implementation uses it.

  • cu_seqlens – An optional 1D tensor of cumulative sequence lengths for handling variable-length sequences in a packed format. Example: [0, len_seq1, len_seq1+len_seq2, …]. If provided, the function will compute the mean pooling for each of the packed sequences.
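The cumulative layout described for cu_seqlens can be built from per-sequence lengths. The following is a minimal sketch, assuming the lengths are already known; the name seq_lens and its values are hypothetical and not part of the library.

```python
import jax.numpy as jnp

# Hypothetical token counts for three sequences packed back to back.
seq_lens = jnp.array([10, 15, 7], dtype=jnp.int32)

# Prepend 0 to the running sum so that sequence i occupies
# rows cu_seqlens[i] : cu_seqlens[i + 1] of the packed input.
cu_seqlens = jnp.concatenate(
    [jnp.zeros((1,), dtype=jnp.int32), jnp.cumsum(seq_lens)]
)
```

The resulting array has one more entry than there are sequences, and its last entry equals total_tokens, the leading dimension of the packed x.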

Returns

A tensor of shape (batch_size, hidden_dim) containing the mean-pooled embeddings for each sequence. If cu_seqlens is used, the batch size in the output shape will correspond to the number of sequences defined by cu_seqlens (i.e., len(cu_seqlens) - 1).

Examples

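>>> import jax.numpy as jnp
>>> from ejkernel.kernels._xla.mean_pooling._interface import mean_pooling
>>>
>>> # Fixed-length input: (batch_size, sequence_length, hidden_dim)
>>> x = jnp.ones((2, 10, 128))
>>> out = mean_pooling(x)
>>> out.shape
(2, 128)
>>>
>>> # Packed variable-length input: two sequences of 10 and 15 tokens
>>> x = jnp.ones((25, 128))
>>> cu_seqlens = jnp.array([0, 10, 25])
>>> out = mean_pooling(x, cu_seqlens=cu_seqlens)
>>> out.shape
(2, 128)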
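To make the expected values concrete, and not only the shapes, the computation can be written in plain JAX. This is a reference sketch, not the library's implementation: reference_mean_pooling is a hypothetical helper, it assumes every position of a fixed-length input counts toward the mean, and it assumes no packed sequence is empty (an empty one would divide by zero).

```python
import jax
import jax.numpy as jnp


def reference_mean_pooling(x, cu_seqlens=None):
    """Plain-JAX reference for mean pooling (hypothetical, not part of ejkernel)."""
    if cu_seqlens is None:
        # Fixed-length input (batch, seq_len, hidden_dim): average over axis 1.
        return x.mean(axis=1)

    num_seqs = cu_seqlens.shape[0] - 1
    # Map every packed row to the index of the sequence that contains it.
    positions = jnp.arange(x.shape[0])
    segment_ids = jnp.searchsorted(cu_seqlens, positions, side="right") - 1
    # Sum the rows of each sequence, then divide by that sequence's length.
    sums = jax.ops.segment_sum(x, segment_ids, num_segments=num_seqs)
    lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).astype(x.dtype)
    return sums / lengths[:, None]


# Packed example: 25 tokens split into sequences of 10 and 15 tokens.
x = jnp.arange(100, dtype=jnp.float32).reshape(25, 4)
cu_seqlens = jnp.array([0, 10, 25])
out = reference_mean_pooling(x, cu_seqlens)
```

Row i of out should match x[cu_seqlens[i]:cu_seqlens[i + 1]].mean(axis=0), which gives a simple check to compare against the output of mean_pooling on the same inputs.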