# Source code for ejkernel.kernels._triton.mean_pooling._interface

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Mean pooling operations using Triton kernels.

This module provides GPU-accelerated mean pooling over sequence dimensions,
commonly used in NLP tasks to aggregate token embeddings into fixed-size
representations. The implementation uses custom Triton kernels for optimal
performance on GPUs.

Mean pooling computes the average of all token embeddings in a sequence,
producing a single vector representation. This is particularly useful for:
- Sentence/document embeddings in classification tasks
- Block compression in sparse attention mechanisms
- Sequence-level representations for downstream tasks

The implementation supports:
- Standard batched sequences with uniform lengths
- Variable-length sequences via cumulative sequence lengths (cu_seqlens)
- Efficient GPU parallelization via Triton
- Full automatic differentiation support

Example:
    >>> import jax.numpy as jnp
    >>> from ejkernel.kernels._triton.mean_pooling import mean_pooling
    >>>
    >>>
    >>> batch, seq_len, hidden_dim = 4, 128, 768
    >>> x = jnp.ones((batch, seq_len, hidden_dim))
    >>>
    >>>
    >>> pooled = mean_pooling(x, chunk_size=32)
    >>> print(pooled.shape)
"""

from functools import partial

import jax
import jaxtyping
from beartype import beartype
from jaxtyping import Array, Float, Int

from ..._registry import Backend, Platform, kernel_registry
from ._triton_impl_bwd import bwd_triton_impl
from ._triton_impl_fwd import fwd_triton_impl


def _fwd_call(
    x: Float[Array, "... hidden_dim"],
    chunk_size: int,
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
) -> tuple[Float[Array, "batch hidden_dim"], tuple[int, int, Int[Array, "num_seqs_plus_one"] | None]]:
    """
    Forward rule of the custom VJP registered on `_mean_pooling`.

    Runs the Triton forward kernel and bundles the values the backward rule
    needs later.

    Args:
        x: Input tensor to pool.
        chunk_size: Chunk size handed to the forward kernel.
        cu_seqlens: Optional cumulative sequence lengths for packed,
            variable-length input.

    Returns:
        A pair `(output, residual)`. `residual` is the tuple
        `(x.shape[0], x.shape[1], cu_seqlens)`; `_bwd_call` passes its first
        two entries to the backward kernel as `batch_size` and `seq_len`.
    """
    pooled = fwd_triton_impl(x=x, chunk_size=chunk_size, cu_seqlens=cu_seqlens)

    # Only the leading two extents of `x` are kept, not `x` itself.
    # NOTE(review): for a 2-D packed input the second extent is the hidden
    # size rather than a sequence length -- confirm the backward kernel
    # expects that when `cu_seqlens` is given.
    dim0 = x.shape[0]
    dim1 = x.shape[1]
    saved_for_bwd = (dim0, dim1, cu_seqlens)
    return pooled, saved_for_bwd


def _bwd_call(
    chunk_size: int,
    residual: tuple[int, int, Int[Array, "num_seqs_plus_one"] | None],
    do: Float[Array, "batch hidden_dim"],
) -> Float[Array, "batch seq_len hidden_dim"]:
    """
    Backward rule of the custom VJP registered on `_mean_pooling`.

    Unpacks the residuals produced by `_fwd_call` and delegates to the Triton
    backward kernel.

    Args:
        chunk_size: The non-differentiable chunk size used in the forward pass.
        residual: `(batch_size, seq_len, cu_seqlens)` as saved by `_fwd_call`.
        do: Gradient flowing in from the pooled output.

    Returns:
        Whatever `bwd_triton_impl` returns, i.e. the gradient with respect to
        the input tensor `x`.
    """
    batch_size, seq_len, cu_seqlens = residual

    # NOTE(review): `jax.custom_vjp` expects the bwd rule to return a tuple
    # with one cotangent per differentiable primal argument (here `x` and
    # `cu_seqlens`). This forwards the kernel's result unchanged -- confirm
    # that `bwd_triton_impl` already returns that structure.
    return bwd_triton_impl(
        do=do,
        batch_size=batch_size,
        seq_len=seq_len,
        chunk_size=chunk_size,
        cu_seqlens=cu_seqlens,
    )


@partial(jax.custom_vjp, nondiff_argnums=(1,))
@partial(jax.jit, static_argnums=(1,))
def _mean_pooling(
    x: Float[Array, "... hidden_dim"],
    chunk_size: int,
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
) -> Float[Array, "batch hidden_dim"]:
    """
    Internal JIT-compiled entry point for mean pooling with a custom VJP.

    The primal computation simply invokes the Triton forward kernel. The
    function is wrapped in `jax.custom_vjp`, with `chunk_size` (argument
    index 1) marked both as a static JIT argument and as non-differentiable.

    Args:
        x: Input tensor to pool.
        chunk_size: Chunk size for the kernel; static under JIT.
        cu_seqlens: Optional cumulative sequence lengths for packed,
            variable-length input.

    Returns:
        The output of the Triton forward kernel (the mean-pooled tensor).
    """
    pooled = fwd_triton_impl(x=x, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
    return pooled


# Attach the differentiation rules: `_fwd_call` computes the output together
# with the residuals, and `_bwd_call` turns those residuals plus the output
# gradient into the input gradient.
_mean_pooling.defvjp(_fwd_call, _bwd_call)


@kernel_registry.register("mean_pooling", Platform.TRITON, Backend.GPU)
@jaxtyping.jaxtyped(typechecker=beartype)
def mean_pooling(
    x: Float[Array, "... hidden_dim"],
    chunk_size: int = 32,
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
) -> Float[Array, "... hidden_dim"]:
    """
    Performs mean pooling over the sequence dimension using a Triton kernel.

    This function calculates the mean of token embeddings for each sequence in
    a batch. It is optimized for GPUs using a custom Triton kernel and supports
    both standard (padded) and variable-length sequences.

    Args:
        x: The input tensor of shape `(batch_size, sequence_length, hidden_dim)`.
            If `cu_seqlens` is provided for variable-length inputs, the shape
            should be `(total_tokens, hidden_dim)`.
        chunk_size: A performance-tuning parameter for the Triton kernel that
            determines how the input is chunked for processing.
        cu_seqlens: An optional 1D tensor of cumulative sequence lengths for
            handling variable-length sequences in a packed format.
            Example: `[0, len_seq1, len_seq1+len_seq2, ...]`. If provided, the
            function will compute the mean pooling for each of the packed
            sequences.

    Returns:
        A tensor of shape `(batch_size, hidden_dim)` containing the mean-pooled
        embeddings for each sequence. If `cu_seqlens` is used, the batch size
        in the output shape will correspond to the number of sequences defined
        by `cu_seqlens` (i.e., `len(cu_seqlens) - 1`).
    """
    # Thin public wrapper: argument types are checked by the jaxtyped/beartype
    # decorator above, then the work is delegated to the JIT-compiled,
    # custom-VJP core function.
    return _mean_pooling(x=x, chunk_size=chunk_size, cu_seqlens=cu_seqlens)