ejkernel.kernels._triton.page_attention._interface

Page Attention implementation using Triton kernels.

This module implements paged attention, a memory-efficient attention mechanism designed for inference workloads where key-value caches are stored in paged memory blocks. This is particularly useful for serving large language models where managing KV cache memory efficiently is critical.

Paged attention addresses the challenge of dynamic memory allocation during autoregressive generation. Instead of allocating one large contiguous buffer for each sequence’s KV cache, memory is organized into fixed-size pages that can be allocated on demand and, where useful, shared across sequences.

Key concepts:
  • KV Cache Pages: Fixed-size blocks storing key and value vectors
  • Block Tables: Mapping from logical positions to physical page indices
  • Variable Context Lengths: Each sequence can have a different length
  • Memory Efficiency: Pages can be allocated/deallocated dynamically
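The block-table lookup described above can be sketched in plain Python. This is only an illustration of the indexing arithmetic, not the kernel's code: the helper name locate_token and the sample table are hypothetical.

```python
def locate_token(block_table_row, token_pos, block_size):
    """Return (physical_block, offset) for a logical token position.

    block_table_row is one sequence's row of the block table: entry i holds
    the physical page index that stores logical page i.
    """
    logical_block = token_pos // block_size  # which logical page holds the token
    offset = token_pos % block_size          # slot inside that page
    physical_block = block_table_row[logical_block]
    return physical_block, offset


# Hypothetical table: logical pages 0, 1, 2 live in physical pages 7, 2, 9.
block_table_row = [7, 2, 9]
print(locate_token(block_table_row, token_pos=20, block_size=16))  # (2, 4)
```

Token 20 falls in logical page 1 (20 // 16) at slot 4 (20 % 16), and logical page 1 is stored in physical page 2.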

Architecture benefits:
  1. Reduced memory fragmentation
  2. Support for extremely long contexts via pagination
  3. Efficient batching of variable-length sequences
  4. Memory sharing for prefix caching scenarios

The implementation supports two modes:
  1. Single-partition mode (num_splits=1): Direct attention computation
  2. Multi-partition mode (num_splits>1): Splits long contexts for parallelization
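Multi-partition attention is commonly implemented by computing a partial softmax over each slice of the context and then merging the partial results with a shared maximum. The NumPy sketch below shows that general technique for a single query vector and head. It is an assumption about how such a merge can work, not a description of this module's kernels, and the function names are hypothetical.

```python
import numpy as np


def partial_attention(q, k, v, scale):
    """Attention over one partition: (unnormalized output, max logit, exp sum)."""
    logits = (k @ q) * scale          # (tokens,)
    m = logits.max()
    p = np.exp(logits - m)            # numerically stable exponentials
    return p @ v, m, p.sum()


def merge_partitions(parts):
    """Rescale every partition to a common maximum, then normalize once."""
    m_all = max(m for _, m, _ in parts)
    num = sum(o * np.exp(m - m_all) for o, m, _ in parts)
    den = sum(s * np.exp(m - m_all) for _, m, s in parts)
    return num / den


rng = np.random.default_rng(0)
head_dim, context_len, num_splits = 8, 48, 3
q = rng.normal(size=head_dim)
k = rng.normal(size=(context_len, head_dim))
v = rng.normal(size=(context_len, head_dim))
scale = head_dim ** -0.5

# Multi-partition path: one partial result per slice of the context.
parts = [
    partial_attention(q, k_part, v_part, scale)
    for k_part, v_part in zip(np.array_split(k, num_splits),
                              np.array_split(v, num_splits))
]
split_out = merge_partitions(parts)

# Single-partition path over the whole context, for comparison.
out, _, total = partial_attention(q, k, v, scale)
full_out = out / total

print(np.allclose(split_out, full_out))  # True
```

Because the merge is exact up to floating-point rounding, splitting changes how the work is parallelized but not the result.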

Features:
  • Grouped-query attention (GQA) and multi-query attention (MQA)
  • Automatic splitting for long contexts
  • Optimized memory access patterns
  • GPU-accelerated via Triton kernels
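In GQA and MQA, several query heads read from the same key/value head. The sketch below assumes the usual convention that consecutive query heads form a group; the helper kv_head_for is hypothetical and the kernel's actual grouping order is not confirmed here.

```python
def kv_head_for(query_head, num_heads, num_kv_heads):
    """Map a query head index to the KV head it reads from (assumed grouping)."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads  # query heads per KV head
    return query_head // group_size


# GQA: 8 query heads sharing 2 KV heads.
print([kv_head_for(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]

# MQA: every query head shares the single KV head.
print([kv_head_for(h, 8, 1) for h in range(8)])  # [0, 0, 0, 0, 0, 0, 0, 0]
```

With num_kv_heads equal to num_heads the mapping is the identity, which is ordinary multi-head attention.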

Example

>>> import jax.numpy as jnp
>>> from ejkernel.kernels._triton.page_attention import page_attention
>>>
>>> num_seqs, num_heads, head_dim = 4, 8, 64
>>> num_blocks, num_kv_heads, block_size = 100, 8, 16
>>>
>>> # One query vector per sequence (a single decoding step)
>>> query = jnp.ones((num_seqs, num_heads, head_dim))
>>>
>>> # Paged key and value caches
>>> key_cache = jnp.ones((num_blocks, num_kv_heads, block_size, head_dim))
>>> value_cache = jnp.ones((num_blocks, num_kv_heads, block_size, head_dim))
>>>
>>> # Valid context length per sequence and the logical-to-physical block mapping
>>> context_lens = jnp.array([50, 100, 75, 120])
>>> block_tables = jnp.zeros((num_seqs, 10), dtype=jnp.int32)
>>>
>>> output = page_attention(query, key_cache, value_cache, context_lens, block_tables)

Reference:

Efficient Memory Management for Large Language Model Serving with PagedAttention https://arxiv.org/abs/2309.06180

ejkernel.kernels._triton.page_attention._interface.page_attention(query: Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], context_lens: Int[jaxlib._jax.Array, 'num_seqs'], block_tables: Int[jaxlib._jax.Array, 'num_seqs max_blocks'], attn_scale: float | None = None, max_context_len: int | None = None, num_splits: int = 0, *, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, pages_per_compute_block: int | None = None, megacore_mode: str | None = None, inline_seq_dim: bool = True, sliding_window: int | None = None) → Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim']

Compute paged attention with key-value caches stored in memory pages.

This function performs attention where the key-value cache is organized into fixed-size pages, enabling efficient memory management for LLM serving workloads. The implementation automatically decides whether to use single-partition or multi-partition computation based on context lengths and available resources.

Parameters
  • query – Query tensor of shape (num_seqs, num_heads, head_dim). Each row represents the query for one sequence in the batch (typically during autoregressive decoding).

  • key_cache – Paged key cache of shape (num_blocks, num_kv_heads, block_size, head_dim). Keys are stored in fixed-size blocks that can be non-contiguous.

  • value_cache – Paged value cache of shape (num_blocks, num_kv_heads, block_size, head_dim). Values are stored in fixed-size blocks matching key_cache organization.

  • context_lens – Length of context for each sequence, shape (num_seqs,). Specifies how many tokens in the KV cache are valid for each sequence.

  • block_tables – Mapping from logical blocks to physical blocks, shape (num_seqs, max_blocks). For each sequence, maps logical block indices to physical block indices in key_cache/value_cache.

  • attn_scale – Attention scaling factor. If None, defaults to 1/sqrt(head_dim).

  • max_context_len – Maximum context length across all sequences. If None, computed as the maximum of context_lens.

  • num_splits – Number of partitions for splitting long contexts. If 0, the implementation automatically determines the optimal number of splits. Set to 1 to force single-partition mode.

  • mask_value – Value to use for masked positions (default: -2.38e38).

  • attn_logits_soft_cap – Not supported in the Triton implementation; providing a value raises NotImplementedError.

  • pages_per_compute_block – Not supported in the Triton implementation; providing a value raises NotImplementedError.

  • megacore_mode – Not supported in the Triton implementation; providing a value raises NotImplementedError.

  • inline_seq_dim – Must be True for the Triton implementation; False raises NotImplementedError.
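The documented defaults for attn_scale and max_context_len can be written out as a small sketch. The helper resolve_defaults is hypothetical, and the blocks_needed value is an extra illustration (the number of block-table columns a sequence of that length would occupy), not something the function is documented to return.

```python
import math


def resolve_defaults(head_dim, context_lens, block_size,
                     attn_scale=None, max_context_len=None):
    """Apply the documented defaults; blocks_needed is illustrative only."""
    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(head_dim)      # default: 1/sqrt(head_dim)
    if max_context_len is None:
        max_context_len = max(context_lens)         # default: longest sequence
    blocks_needed = -(-max_context_len // block_size)  # ceiling division
    return attn_scale, max_context_len, blocks_needed


print(resolve_defaults(head_dim=64, context_lens=[50, 100, 75, 120], block_size=16))
# (0.125, 120, 8)
```

For the values used in the module-level example, a table with 10 columns therefore has room for the longest sequence, which needs 8 blocks.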

Returns

Attention output of shape (num_seqs, num_heads, head_dim).

Raises
  • NotImplementedError – If unsupported parameters are provided (attn_logits_soft_cap, pages_per_compute_block, megacore_mode, or inline_seq_dim=False).

  • AssertionError – If head_dim (the head size) is not one of 16, 32, 64, 128, 256, or 512, or if block_size constraints are violated.
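To make the semantics above concrete, the following is a slow NumPy reference for paged attention. It is a sketch under stated assumptions, not the Triton implementation: it assumes the (num_blocks, num_kv_heads, block_size, head_dim) cache layout given in the parameter descriptions, assumes consecutive query heads share a KV head, and ignores mask_value, num_splits, and sliding_window. The name reference_page_attention is hypothetical.

```python
import numpy as np


def reference_page_attention(query, key_cache, value_cache,
                             context_lens, block_tables, attn_scale=None):
    """Dense reference: gather each sequence's pages, then run softmax attention."""
    num_seqs, num_heads, head_dim = query.shape
    _, num_kv_heads, block_size, _ = key_cache.shape
    group = num_heads // num_kv_heads               # query heads per KV head
    scale = head_dim ** -0.5 if attn_scale is None else attn_scale
    out = np.zeros_like(query)

    for s in range(num_seqs):
        n = int(context_lens[s])                    # valid tokens for this sequence
        n_blocks = -(-n // block_size)              # pages that hold those tokens
        pages = block_tables[s, :n_blocks]          # physical page indices

        # (n_blocks, kv_heads, block, dim) -> (kv_heads, n_blocks * block, dim),
        # then drop the unused tail of the last page.
        k = key_cache[pages].transpose(1, 0, 2, 3).reshape(num_kv_heads, -1, head_dim)[:, :n]
        v = value_cache[pages].transpose(1, 0, 2, 3).reshape(num_kv_heads, -1, head_dim)[:, :n]

        for h in range(num_heads):
            kv = h // group
            logits = (k[kv] @ query[s, h]) * scale
            p = np.exp(logits - logits.max())       # stable softmax numerator
            out[s, h] = (p / p.sum()) @ v[kv]
    return out


rng = np.random.default_rng(0)
q = rng.normal(size=(2, 4, 8))                      # 2 sequences, 4 heads, dim 8
kc = rng.normal(size=(6, 2, 4, 8))                  # 6 pages, 2 KV heads, block_size 4
vc = rng.normal(size=(6, 2, 4, 8))
bt = np.array([[3, 1], [4, 0]])                     # non-contiguous physical pages
out = reference_page_attention(q, kc, vc, np.array([5, 3]), bt)
print(out.shape)  # (2, 4, 8)
```

A reference of this kind is mainly useful for checking a fast kernel's output on small inputs; the Python loops make it unsuitable for real workloads.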

Example

>>> import jax.numpy as jnp
>>> from ejkernel.kernels._triton.page_attention import page_attention
>>>
>>> # Queries: one vector per sequence and head
>>> num_seqs, num_heads, head_dim = 4, 8, 64
>>> query = jnp.ones((num_seqs, num_heads, head_dim))
>>>
>>> # Paged key and value caches
>>> num_blocks, num_kv_heads, block_size = 100, 8, 16
>>> key_cache = jnp.ones((num_blocks, num_kv_heads, block_size, head_dim))
>>> value_cache = jnp.ones((num_blocks, num_kv_heads, block_size, head_dim))
>>>
>>> # Number of valid tokens per sequence
>>> context_lens = jnp.array([32, 64, 48, 80])
>>>
>>> # Each row lists one sequence's physical blocks; -1 fills the unused slots
>>> block_tables = jnp.array([
...     [0, 1, -1, -1, -1],
...     [2, 3, 4, 5, -1],
...     [6, 7, 8, -1, -1],
...     [9, 10, 11, 12, 13]
... ])
>>>
>>> output = page_attention(query, key_cache, value_cache, context_lens, block_tables)
>>> print(output.shape)
(4, 8, 64)