ejkernel.modules.operations.pooling

Pooling operation modules with automatic optimization.

This module implements efficient pooling operations for sequence data, optimized for JAX execution. Mean pooling is particularly useful for:

  • Sentence embeddings in NLP (pooling token representations)

  • Sequence classification (reducing sequence to fixed-size representation)

  • Feature aggregation across time steps

  • Dimensionality reduction in transformer outputs

The implementation supports variable-length sequences via cumulative sequence lengths, enabling efficient batched processing of sequences with different lengths.
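As an illustration of the cumulative-length format, the sketch below builds a `cu_seqlens` array from per-sequence lengths. It assumes the common convention of a leading zero followed by running totals, which matches the `[num_seqs + 1]` shape documented here; the `lengths` values are made up for the example.

```python
import jax.numpy as jnp

# Hypothetical per-sequence token counts for three sequences.
lengths = jnp.array([3, 5, 2], dtype=jnp.int32)

# Assumed convention: a leading 0 followed by the running total, giving
# num_seqs + 1 entries. Sequence i spans [cu_seqlens[i], cu_seqlens[i + 1]).
cu_seqlens = jnp.concatenate(
    [jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(lengths)]
)

# The individual lengths can be recovered from adjacent boundaries.
recovered = jnp.diff(cu_seqlens)
```

Storing boundaries rather than lengths lets a kernel find the start and end of any sequence with two lookups.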

class ejkernel.modules.operations.pooling.MeanPooling

Bases: Kernel[MeanPoolingConfig, Array]

Mean Pooling with custom optimization logic.

Computes the mean of sequence elements along the sequence dimension, with support for variable-length sequences and chunked processing for memory efficiency.

Features:
  • Efficient mean computation over sequence dimension

  • Support for variable-length sequences via cu_seqlens

  • Configurable chunk size for memory-efficient processing

  • Automatic platform selection (Triton/Pallas/XLA/CUDA)

  • Proper handling of padding in variable-length scenarios

This is commonly used to convert variable-length token sequences into fixed-size representations for classification or embedding tasks.
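For orientation, here is a pure-JAX reference for what mean pooling with chunked processing computes. It is a sketch, not the kernel that `MeanPooling` dispatches to: the function names are hypothetical, and the float32 accumulator is an assumption made for numerical stability.

```python
import jax
import jax.numpy as jnp

def mean_pool_reference(x):
    """Plain mean over the sequence axis: [batch, seq_len, hidden] -> [batch, hidden]."""
    return jnp.mean(x, axis=1)

def mean_pool_chunked(x, chunk_size=32):
    """Accumulate per-chunk sums, then divide once by seq_len."""
    batch, seq_len, hidden = x.shape
    total = jnp.zeros((batch, hidden), dtype=jnp.float32)
    for start in range(0, seq_len, chunk_size):
        # The last chunk may be shorter than chunk_size; slicing handles that.
        chunk = x[:, start:start + chunk_size, :]
        total = total + jnp.sum(chunk.astype(jnp.float32), axis=1)
    return (total / seq_len).astype(x.dtype)

x = jax.random.normal(jax.random.PRNGKey(0), (2, 100, 8))
```

Both functions should agree up to floating-point rounding; the chunked form only changes how much of the sequence is reduced at a time.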

candidate_cfgs(inv: Invocation[MeanPoolingConfig, Array])

Generate candidate configurations for autotuning.

Mean pooling has a tunable block_size for chunked processing. This method generates configurations with varying block sizes so the autotuner can find the best-performing one for the specific hardware and input dimensions.

Parameters

inv – Invocation object with arguments and metadata

Returns

List of candidate configurations with different block sizes (32, 64, 128) and corresponding warp/stage configurations

Note

Smaller block sizes (32) reduce memory usage but may have lower throughput. Larger block sizes (128) improve throughput for large sequences.
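The idea of choosing among candidate block sizes can be sketched as a small timing loop. This is not ejkernel's autotuner: `blocked_mean` and `pick_fastest` are hypothetical stand-ins, and the sketch assumes `seq_len` is divisible by every candidate block size.

```python
import time
import jax
import jax.numpy as jnp

BLOCK_SIZES = (32, 64, 128)  # the candidate block sizes listed above

def blocked_mean(x, block_size):
    """Stand-in workload: reduce within blocks, then across blocks."""
    batch, seq_len, hidden = x.shape
    blocks = x.reshape(batch, seq_len // block_size, block_size, hidden)
    return blocks.sum(axis=2).sum(axis=1) / seq_len

def pick_fastest(x, block_sizes=BLOCK_SIZES, repeats=3):
    """Time each candidate and return (best_block_size, timings)."""
    timings = {}
    for bs in block_sizes:
        fn = jax.jit(lambda a, bs=bs: blocked_mean(a, bs))
        fn(x).block_until_ready()  # compile and warm up outside the timed region
        start = time.perf_counter()
        for _ in range(repeats):
            fn(x).block_until_ready()
        timings[bs] = (time.perf_counter() - start) / repeats
    return min(timings, key=timings.get), timings

x = jax.random.normal(jax.random.PRNGKey(0), (4, 512, 64))
best, timings = pick_fastest(x)
```

Which block size wins depends on the device and input shape, which is why the result is measured rather than fixed.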

get_impl(cfg: MeanPoolingConfig)

Get kernel implementation from registry.

Parameters

cfg – Configuration specifying platform and backend

Returns

Callable kernel implementation for mean pooling

Raises

ValueError – If no matching implementation is found
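A registry lookup of this kind can be sketched as follows. The registry layout, the `register` decorator, and `get_impl_sketch` are hypothetical illustrations; ejkernel's actual registry API may look different.

```python
import jax.numpy as jnp

# Hypothetical registry keyed by (algorithm, platform).
_REGISTRY = {}

def register(algorithm, platform):
    """Decorator that records an implementation under (algorithm, platform)."""
    def decorator(fn):
        _REGISTRY[(algorithm, platform)] = fn
        return fn
    return decorator

@register("mean_pooling", "xla")
def _mean_pooling_xla(x):
    # Plain XLA fallback: mean over the sequence axis.
    return jnp.mean(x, axis=1)

def get_impl_sketch(algorithm, platform):
    """Return the registered callable, or raise ValueError if none matches."""
    try:
        return _REGISTRY[(algorithm, platform)]
    except KeyError:
        raise ValueError(
            f"No implementation registered for {algorithm!r} on {platform!r}"
        ) from None

impl = get_impl_sketch("mean_pooling", "xla")
```

Raising `ValueError` for a missing entry mirrors the behavior documented above and lets callers fall back to another platform.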

heuristic_cfg(inv: Invocation[MeanPoolingConfig, Array]) → MeanPoolingConfig

Provide default configuration with block sizes.

Selects default block size and warp configuration based on typical sequence pooling workloads. These defaults work well for most cases but can be overridden via autotuning.

Parameters

inv – Invocation object with arguments and metadata

Returns

Default configuration with block_size=64, num_warps=4, num_stages=1

run(x: Float[jaxlib._jax.Array, 'batch seq_len hidden_dim'], chunk_size: int = 32, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: MeanPoolingConfig) → Float[jaxlib._jax.Array, 'batch hidden_dim']

Execute mean pooling over sequence dimension.

Parameters
  • x – Input tensor [batch, seq_len, hidden_dim]

  • chunk_size – Size of chunks for processing (default: 32)

  • cu_seqlens – Optional cumulative sequence lengths [num_seqs + 1] for variable-length sequences

  • platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Kernel configuration object

Returns

Pooled output [batch, hidden_dim]

Note

When cu_seqlens is provided, padding tokens are excluded from the mean computation, ensuring accurate pooling for variable-length sequences.
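Excluding padding from the mean can be sketched in pure JAX as a masked sum divided by the true lengths. This sketch assumes a padded layout in which row `i` of `x` holds sequence `i` left-aligned and `cu_seqlens` has `batch + 1` entries; the kernel's actual layout may differ, and `masked_mean_pool` is a hypothetical name.

```python
import jax.numpy as jnp

def masked_mean_pool(x, cu_seqlens):
    """x: [batch, seq_len, hidden] (padded); cu_seqlens: [batch + 1]."""
    lengths = jnp.diff(cu_seqlens)                    # [batch]
    positions = jnp.arange(x.shape[1])                # [seq_len]
    mask = positions[None, :] < lengths[:, None]      # [batch, seq_len]
    summed = jnp.sum(x * mask[:, :, None].astype(x.dtype), axis=1)
    # Clamp to 1 so an empty sequence yields zeros instead of NaN.
    denom = jnp.maximum(lengths, 1)[:, None].astype(x.dtype)
    return summed / denom

# The first sequence has 2 real tokens; its third position is padding.
x = jnp.array([[[1.0], [3.0], [100.0]],
               [[2.0], [4.0], [6.0]]])
cu_seqlens = jnp.array([0, 2, 5])
pooled = masked_mean_pool(x, cu_seqlens)  # [[2.0], [4.0]]
```

An unmasked mean over the first row would give about 34.67, so the padded value of 100 visibly distorts the result when lengths are ignored.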

ejkernel.modules.operations.pooling.mean_pooling(x: Float[jaxlib._jax.Array, 'batch seq_len hidden_dim'], cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, /, *, chunk_size: int = 32, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.MeanPoolingConfig | None = None) → Float[jaxlib._jax.Array, 'batch hidden_dim']

Execute mean pooling with automatic optimization.

Efficiently computes the mean of sequence elements along the sequence dimension, optimized for variable-length sequences and chunked processing.

Parameters
  • x – Input tensor [batch, seq_len, hidden_dim]

  • chunk_size – Size of chunks for processing (default: 32)

  • cu_seqlens – Cumulative sequence lengths for variable-length sequences

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, or “xla”)

Returns

Mean pooled output [batch, hidden_dim]

Example

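>>> # Basic mean pooling
>>> pooled = mean_pooling(x)
>>>
>>> # Custom chunk size
>>> pooled = mean_pooling(x, chunk_size=64)
>>>
>>> # Variable-length sequences (cu_seqlens is positional-only in the signature above)
>>> pooled = mean_pooling(x, cu_seqs)
>>>
>>> # Force a specific platform
>>> out = mean_pooling(..., platform="triton")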