Source code for ejkernel.kernels._xla.mean_pooling._interface

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mean pooling interface for sequence embedding aggregation.

This module provides the public API for mean pooling operations over
sequence dimensions. Supports both fixed-length and variable-length
(packed) sequences with custom VJP for efficient gradient computation.
"""

from functools import partial

import jax
import jax.numpy as jnp
import jaxtyping
from beartype import beartype
from jaxtyping import Array, Float, Int

from ..._registry import Backend, Platform, kernel_registry


def _mean_pooling_varlen(
    x: Float[Array, "total_tokens hidden_dim"],
    cu_seqlens: Int[Array, "num_seqs_plus_one"],
) -> Float[Array, "num_seqs hidden_dim"]:
    """
    Mean pooling for variable-length sequences.

    Args:
        x: Input tensor of shape [total_tokens, hidden_dim]
        cu_seqlens: Cumulative sequence lengths [num_seqs + 1]

    Returns:
        Mean-pooled tensor of shape [num_seqs, hidden_dim]
    """
    num_seqs = len(cu_seqlens) - 1
    max_seq_len = jnp.max(cu_seqlens[1:] - cu_seqlens[:-1])

    def pool_sequence(i):
        start = cu_seqlens[i]
        end = cu_seqlens[i + 1]
        seq_len = end - start

        seq_tokens = jax.lax.dynamic_slice(x, (start, 0), (max_seq_len, x.shape[-1]))

        mask = jnp.arange(max_seq_len) < seq_len

        masked_tokens = jnp.where(mask[:, None], seq_tokens, 0)
        return jnp.sum(masked_tokens, axis=0) / seq_len

    return jax.vmap(pool_sequence)(jnp.arange(num_seqs))


def _mean_pooling_fixed(
    x: Float[Array, "batch seq_len hidden_dim"],
) -> Float[Array, "batch hidden_dim"]:
    """
    Mean pooling for fixed-length sequences.

    Args:
        x: Input tensor of shape [batch, seq_len, hidden_dim]

    Returns:
        Mean-pooled tensor of shape [batch, hidden_dim]
    """
    return jnp.mean(x, axis=1)


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _mean_pooling_core(
    x: Float[Array, "... hidden_dim"],
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
) -> Float[Array, "batch hidden_dim"]:
    """Core mean pooling implementation with custom VJP."""
    if cu_seqlens is not None:
        return _mean_pooling_varlen(x, cu_seqlens)
    else:
        return _mean_pooling_fixed(x)


def _mean_pooling_fwd(
    x: Float[Array, "... hidden_dim"],
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
) -> tuple[Float[Array, "batch hidden_dim"], tuple]:
    """Forward pass for mean pooling with residuals."""
    out = _mean_pooling_core(x, cu_seqlens)
    residual = (x.shape, cu_seqlens)
    return out, residual


def _mean_pooling_bwd(
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None,
    residual: tuple,
    g: Float[Array, "batch hidden_dim"],
) -> tuple[Float[Array, "... hidden_dim"]]:
    """Backward pass for mean pooling."""
    x_shape, _ = residual

    if cu_seqlens is not None:
        num_seqs = len(cu_seqlens) - 1

        def grad_sequence(i):
            start = cu_seqlens[i]
            end = cu_seqlens[i + 1]
            seq_len = end - start

            return jnp.tile(g[i] / seq_len, (seq_len, 1))

        dx_list = [grad_sequence(i) for i in range(num_seqs)]
        dx = jnp.concatenate(dx_list, axis=0)
    else:
        seq_len = x_shape[1]
        dx = jnp.tile(g[:, None, :], (1, seq_len, 1)) / seq_len

    return (dx,)


_mean_pooling_core.defvjp(_mean_pooling_fwd, _mean_pooling_bwd)


[docs]@kernel_registry.register("mean_pooling", Platform.XLA, Backend.ANY) @jaxtyping.jaxtyped(typechecker=beartype) def mean_pooling( x: Float[Array, "... hidden_dim"], chunk_size: int = 32, cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None, ) -> Float[Array, "batch hidden_dim"]: """ Performs mean pooling over the sequence dimension using JAX/XLA. This function calculates the mean of token embeddings for each sequence in a batch. It supports both standard (padded) and variable-length sequences. Args: x: The input tensor of shape `(batch_size, sequence_length, hidden_dim)`. If `cu_seqlens` is provided for variable-length inputs, the shape should be `(total_tokens, hidden_dim)`. chunk_size: Performance tuning parameter (ignored in XLA, only used by Triton). cu_seqlens: An optional 1D tensor of cumulative sequence lengths for handling variable-length sequences in a packed format. Example: `[0, len_seq1, len_seq1+len_seq2, ...]`. If provided, the function will compute the mean pooling for each of the packed sequences. Returns: A tensor of shape `(batch_size, hidden_dim)` containing the mean-pooled embeddings for each sequence. If `cu_seqlens` is used, the batch size in the output shape will correspond to the number of sequences defined by `cu_seqlens` (i.e., `len(cu_seqlens) - 1`). Examples: >>> >>> x = jnp.ones((2, 10, 128)) >>> out = mean_pooling(x) >>> out.shape (2, 128) >>> >>> x = jnp.ones((25, 128)) >>> cu_seqlens = jnp.array([0, 10, 25]) >>> out = mean_pooling(x, cu_seqlens=cu_seqlens) >>> out.shape (2, 128) """ return _mean_pooling_core(x, cu_seqlens)