ejkernel.modules.operations.pooling
Pooling operation modules with automatic optimization.
This module implements efficient pooling operations for sequence data, optimized for JAX execution. Mean pooling is particularly useful for:
Sentence embeddings in NLP (pooling token representations)
Sequence classification (reducing sequence to fixed-size representation)
Feature aggregation across time steps
Dimensionality reduction in transformer outputs
The implementation supports variable-length sequences via cumulative sequence lengths, enabling efficient batched processing of sequences with different lengths.
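To make the cumulative-length convention concrete, here is a minimal pure-JAX reference sketch. It assumes a packed layout of shape [total_tokens, hidden_dim] in which sequence i occupies rows cu_seqlens[i] up to cu_seqlens[i + 1]; the helper name `packed_mean_pool` and that layout are illustrative assumptions, not ejkernel's kernel or its required input format.

```python
import jax
import jax.numpy as jnp


def packed_mean_pool(x_packed, cu_seqlens):
    """Reference mean pooling over a packed [total_tokens, hidden_dim] array.

    Assumption: sequence i owns rows cu_seqlens[i] : cu_seqlens[i + 1].
    Zero-length sequences would divide by zero and are not handled here.
    """
    num_seqs = cu_seqlens.shape[0] - 1
    total_tokens = x_packed.shape[0]
    # Map every token row to the index of the sequence that owns it.
    token_pos = jnp.arange(total_tokens)
    seg_ids = jnp.searchsorted(cu_seqlens, token_pos, side="right") - 1
    # Sum rows per sequence, then divide by each sequence's length.
    sums = jax.ops.segment_sum(x_packed, seg_ids, num_segments=num_seqs)
    lengths = jnp.diff(cu_seqlens).astype(x_packed.dtype)
    return sums / lengths[:, None]


# Three sequences of lengths 3, 1 and 4 give cu_seqlens = [0, 3, 4, 8].
lengths = jnp.array([3, 1, 4], dtype=jnp.int32)
cu_seqlens = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(lengths)])
x_packed = jnp.arange(16, dtype=jnp.float32).reshape(8, 2)
pooled = packed_mean_pool(x_packed, cu_seqlens)  # shape (3, 2)
```

Each output row is the mean of only the rows belonging to that sequence, so no padding is needed to batch sequences of different lengths.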
- class ejkernel.modules.operations.pooling.MeanPooling
Bases: Kernel[MeanPoolingConfig, Array]
Mean Pooling with custom optimization logic.
Computes the mean of sequence elements along the sequence dimension, with support for variable-length sequences and chunked processing for memory efficiency.
- Features:
Efficient mean computation over sequence dimension
Support for variable-length sequences via cu_seqlens
Configurable chunk size for memory-efficient processing
Automatic platform selection (Triton/Pallas/XLA/CUDA)
Proper handling of padding in variable-length scenarios
This is commonly used to convert variable-length token sequences into fixed-size representations for classification or embedding tasks.
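The idea behind chunked processing can be sketched in plain JAX: accumulate partial sums one chunk at a time, then divide by the true sequence length. This is only an illustration of the accumulation order under the assumption of a dense [batch, seq_len, hidden_dim] input; the function name `chunked_mean_pool` is hypothetical and the actual kernels may be organized differently.

```python
import jax
import jax.numpy as jnp


def chunked_mean_pool(x, chunk_size=32):
    """Mean over the sequence axis of x [batch, seq_len, hidden_dim], chunk by chunk."""
    batch, seq_len, hidden = x.shape
    # Zero-pad the sequence axis to a multiple of chunk_size; zeros do not change the sum.
    pad = (-seq_len) % chunk_size
    x = jnp.pad(x, ((0, 0), (0, pad), (0, 0)))
    # Rearrange to [num_chunks, batch, chunk_size, hidden_dim] so scan iterates over chunks.
    chunks = x.reshape(batch, -1, chunk_size, hidden).transpose(1, 0, 2, 3)

    def step(acc, chunk):
        # Add this chunk's contribution to the running per-batch sum.
        return acc + chunk.sum(axis=1), None

    total, _ = jax.lax.scan(step, jnp.zeros((batch, hidden), x.dtype), chunks)
    # Divide by the original (unpadded) length.
    return total / seq_len


x = jax.random.normal(jax.random.PRNGKey(0), (2, 100, 8))
pooled = chunked_mean_pool(x, chunk_size=32)
```

The result should match `x.mean(axis=1)` up to floating-point rounding; only the order in which partial sums are formed changes.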
- candidate_cfgs(inv: Invocation[MeanPoolingConfig, Array])
Generate candidate configurations for autotuning.
Mean pooling has tunable block_size for chunked processing. Generates configurations with varying block sizes to find optimal performance for the specific hardware and input dimensions.
- Parameters
inv – Invocation object with arguments and metadata
- Returns
List of candidate configurations with different block sizes (32, 64, 128) and corresponding warp/stage configurations
Note
Smaller block sizes (32) reduce memory usage but may have lower throughput. Larger block sizes (128) improve throughput for large sequences.
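As a rough illustration of how a candidate list like this might be used, the sketch below builds one configuration per documented block size and times a toy workload for each. `PoolingCfgSketch`, `pool_with_block`, and `autotune` are hypothetical stand-ins rather than ejkernel APIs, and the warp and stage values are held constant purely for illustration since the per-candidate values are not listed here.

```python
import functools
import time
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class PoolingCfgSketch:  # hypothetical stand-in, not ejkernel's MeanPoolingConfig
    block_size: int
    num_warps: int
    num_stages: int


def candidate_cfgs_sketch():
    # Block sizes come from the documentation; num_warps/num_stages are assumed constants.
    return [PoolingCfgSketch(block_size=b, num_warps=4, num_stages=1) for b in (32, 64, 128)]


def pool_with_block(x, block_size):
    # Toy workload: sum within blocks, then across blocks, then divide by seq_len.
    batch, seq_len, hidden = x.shape
    blocks = x.reshape(batch, seq_len // block_size, block_size, hidden)
    return blocks.sum(axis=2).sum(axis=1) / seq_len


def autotune(x, candidates):
    timings = {}
    for cfg in candidates:
        fn = jax.jit(functools.partial(pool_with_block, block_size=cfg.block_size))
        fn(x).block_until_ready()  # compile and warm up outside the timed region
        start = time.perf_counter()
        fn(x).block_until_ready()
        timings[cfg] = time.perf_counter() - start
    # Pick the configuration with the smallest measured time.
    return min(timings, key=timings.get), timings


x = jnp.ones((4, 256, 64))
candidates = candidate_cfgs_sketch()
best, timings = autotune(x, candidates)
```

Which candidate wins depends on the hardware and input shape, which is why the choice is measured rather than fixed.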
- get_impl(cfg: MeanPoolingConfig)
Get kernel implementation from registry.
- Parameters
cfg – Configuration specifying platform and backend
- Returns
Callable kernel implementation for mean pooling
- Raises
ValueError – If no matching implementation is found
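The lookup-or-raise behaviour described above can be illustrated with a small dictionary-based registry. Everything in this sketch (`_REGISTRY`, `register`, `get_impl_sketch`, and the key format) is hypothetical and only mirrors the documented contract of returning a callable or raising ValueError; it is not ejkernel's registry.

```python
import jax.numpy as jnp

_REGISTRY = {}  # hypothetical: maps (algorithm, platform) -> callable


def register(algorithm, platform):
    """Decorator that records an implementation under (algorithm, platform)."""
    def decorator(fn):
        _REGISTRY[(algorithm, platform)] = fn
        return fn
    return decorator


@register("mean_pooling", "xla")
def _mean_pooling_xla(x):
    # Plain XLA fallback: mean over the sequence axis.
    return jnp.mean(x, axis=1)


def get_impl_sketch(algorithm, platform):
    try:
        return _REGISTRY[(algorithm, platform)]
    except KeyError:
        raise ValueError(
            f"No implementation registered for {algorithm!r} on {platform!r}"
        ) from None


impl = get_impl_sketch("mean_pooling", "xla")
out = impl(jnp.ones((2, 5, 3)))  # shape (2, 3)
```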
- heuristic_cfg(inv: Invocation[MeanPoolingConfig, Array]) → MeanPoolingConfig
Provide default configuration with block sizes.
Selects default block size and warp configuration based on typical sequence pooling workloads. These defaults work well for most cases but can be overridden via autotuning.
- Parameters
inv – Invocation object with arguments and metadata
- Returns
Default configuration with block_size=64, num_warps=4, num_stages=1
- run(x: Float[jaxlib._jax.Array, 'batch seq_len hidden_dim'], chunk_size: int = 32, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: MeanPoolingConfig) → Float[jaxlib._jax.Array, 'batch hidden_dim']
Execute mean pooling over sequence dimension.
- Parameters
x – Input tensor [batch, seq_len, hidden_dim]
chunk_size – Size of chunks for processing (default: 32)
cu_seqlens – Optional cumulative sequence lengths [num_seqs + 1] for variable-length sequences
platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Kernel configuration object
- Returns
Pooled output [batch, hidden_dim]
Note
When cu_seqlens is provided, padding tokens are excluded from the mean computation, ensuring accurate pooling for variable-length sequences.
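To see why excluding padding matters, the following sketch compares a naive mean with a masked mean on a padded batch. It assumes a padded [batch, seq_len, hidden_dim] layout with a per-sequence `lengths` vector (the differences of consecutive cu_seqlens entries); `masked_mean_pool` is a hypothetical reference, not the kernel ejkernel dispatches to.

```python
import jax.numpy as jnp


def masked_mean_pool(x, lengths):
    """x: [batch, seq_len, hidden_dim]; lengths: [batch] counts of valid tokens."""
    seq_len = x.shape[1]
    # mask[b, t] is 1 for real tokens (t < lengths[b]) and 0 for padding.
    mask = (jnp.arange(seq_len)[None, :] < lengths[:, None]).astype(x.dtype)
    summed = (x * mask[:, :, None]).sum(axis=1)
    # Divide by the number of valid tokens, not by seq_len.
    return summed / lengths[:, None].astype(x.dtype)


# The first sequence has two real tokens; its third position is padding (value 100).
x = jnp.array([[[1.0], [3.0], [100.0]],
               [[2.0], [4.0], [6.0]]])
lengths = jnp.array([2, 3])
masked = masked_mean_pool(x, lengths)  # [[2.0], [4.0]]
naive = x.mean(axis=1)                 # first row is skewed by the padding value
```

The fully valid second sequence gives the same answer either way; only the padded first sequence differs.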
- ejkernel.modules.operations.pooling.mean_pooling(x: Float[jaxlib._jax.Array, 'batch seq_len hidden_dim'], cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, /, *, chunk_size: int = 32, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.MeanPoolingConfig | None = None) → Float[jaxlib._jax.Array, 'batch hidden_dim']
Execute mean pooling with automatic optimization.
Efficiently computes the mean of sequence elements along the sequence dimension, optimized for variable-length sequences and chunked processing.
- Parameters
x – Input tensor [batch, seq_len, hidden_dim]
chunk_size – Size of chunks for processing (default: 32)
cu_seqlens – Cumulative sequence lengths [num_seqs + 1] for variable-length sequences
platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Optional kernel configuration object (default: None)
- Returns
Mean pooled output [batch, hidden_dim]
Example
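>>> # Basic mean pooling over the sequence dimension
>>> pooled = mean_pooling(x)
>>>
>>> # Custom chunk size
>>> pooled = mean_pooling(x, chunk_size=64)
>>>
>>> # Variable-length sequences (cu_seqlens is positional-only in the signature above)
>>> pooled = mean_pooling(x, cu_seqs)
>>>
>>> # Force a specific platform
>>> out = mean_pooling(..., platform="triton")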