ejkernel.kernels._triton.mean_pooling._interface


Mean pooling operations using Triton kernels.

This module provides GPU-accelerated mean pooling over sequence dimensions, commonly used in NLP tasks to aggregate token embeddings into fixed-size representations. The implementation uses custom Triton kernels for optimal performance on GPUs.

Mean pooling computes the average of all token embeddings in a sequence, producing a single vector representation. This is particularly useful for:

  • Sentence/document embeddings in classification tasks
  • Block compression in sparse attention mechanisms
  • Sequence-level representations for downstream tasks
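As a point of comparison, the averaging described above can be sketched in plain JAX without any Triton kernel. The helper name `reference_mean_pooling` is hypothetical and not part of ejkernel; it only mirrors the description on this page and says nothing about how `chunk_size` affects the real kernel.

```python
import jax.numpy as jnp


def reference_mean_pooling(x):
    """Plain-JAX sketch (hypothetical helper, not the ejkernel implementation)."""
    # x: (batch_size, sequence_length, hidden_dim) -> (batch_size, hidden_dim)
    return jnp.mean(x, axis=1)


# Two sequences of three tokens each, hidden_dim = 4
x = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
pooled = reference_mean_pooling(x)
print(pooled.shape)  # (2, 4)
```

A reference like this may be handy for checking the kernel's output numerically on small inputs.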

The implementation supports:

  • Standard batched sequences with uniform lengths
  • Variable-length sequences via cumulative sequence lengths (cu_seqlens)
  • Efficient GPU parallelization via Triton
  • Full automatic differentiation support
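To illustrate what differentiating through a mean-pooling step looks like, here is a minimal sketch that uses `jnp.mean` as a stand-in for the kernel. The function `pooled_sum` is hypothetical; whether `mean_pooling` can be dropped in unchanged is an assumption based on the autodiff support stated above, not something verified here.

```python
import jax
import jax.numpy as jnp


def pooled_sum(x):
    # Scalar loss: sum of the sequence-mean embeddings.
    # jnp.mean stands in for the Triton-backed mean_pooling (assumption).
    return jnp.mean(x, axis=1).sum()


x = jnp.ones((2, 4, 3))  # (batch_size, sequence_length, hidden_dim)
grad = jax.grad(pooled_sum)(x)

# Each token contributes 1 / sequence_length to its sequence's mean,
# so every gradient entry equals 1 / 4.
print(grad.shape)  # (2, 4, 3)
```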

Example

>>> import jax.numpy as jnp
>>> from ejkernel.kernels._triton.mean_pooling import mean_pooling
>>>
>>> # Batched input with uniform sequence lengths
>>> batch, seq_len, hidden_dim = 4, 128, 768
>>> x = jnp.ones((batch, seq_len, hidden_dim))
>>>
>>> # Mean-pool over the sequence dimension
>>> pooled = mean_pooling(x, chunk_size=32)
>>> print(pooled.shape)

ejkernel.kernels._triton.mean_pooling._interface.mean_pooling(x: Float[jaxlib._jax.Array, '... hidden_dim'], chunk_size: int = 32, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None) → Float[jaxlib._jax.Array, '... hidden_dim']

Performs mean pooling over the sequence dimension using a Triton kernel.

This function calculates the mean of token embeddings for each sequence in a batch. It is optimized for GPUs using a custom Triton kernel and supports both standard (padded) and variable-length sequences.

Parameters
  • x – The input tensor of shape (batch_size, sequence_length, hidden_dim). If cu_seqlens is provided for variable-length inputs, the shape should be (total_tokens, hidden_dim).

  • chunk_size – A performance-tuning parameter for the Triton kernel that determines how the input is chunked for processing.

  • cu_seqlens – An optional 1D tensor of cumulative sequence lengths for handling variable-length sequences in a packed format. Example: [0, len_seq1, len_seq1+len_seq2, …]. If provided, the function will compute the mean pooling for each of the packed sequences.

Returns

A tensor of shape (batch_size, hidden_dim) containing the mean-pooled embeddings for each sequence. If cu_seqlens is used, the batch size in the output shape will correspond to the number of sequences defined by cu_seqlens (i.e., len(cu_seqlens) - 1).
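The packed, variable-length case can be sketched in plain JAX to show how `cu_seqlens` maps tokens to sequences and why the output has `len(cu_seqlens) - 1` rows. The helper `reference_varlen_mean_pooling` is hypothetical and not part of ejkernel; it assumes a packed input of shape (total_tokens, hidden_dim) as described in the parameters, and that no sequence is empty.

```python
import jax
import jax.numpy as jnp


def reference_varlen_mean_pooling(x, cu_seqlens):
    """Plain-JAX sketch for packed sequences (hypothetical helper)."""
    num_seqs = cu_seqlens.shape[0] - 1
    token_idx = jnp.arange(x.shape[0])
    # Index of the sequence that each packed token belongs to.
    segment_ids = jnp.searchsorted(cu_seqlens, token_idx, side="right") - 1
    # Sum the embeddings of each sequence, then divide by its length.
    sums = jax.ops.segment_sum(x, segment_ids, num_segments=num_seqs)
    lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).astype(x.dtype)
    # Assumes every sequence has at least one token (no division by zero).
    return sums / lengths[:, None]


# Two packed sequences of lengths 2 and 3, hidden_dim = 2
x = jnp.arange(10, dtype=jnp.float32).reshape(5, 2)
cu_seqlens = jnp.array([0, 2, 5])
pooled = reference_varlen_mean_pooling(x, cu_seqlens)
print(pooled.shape)  # (2, 2)
```

Here the first output row averages tokens 0 and 1, and the second averages tokens 2 through 4.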