ejkernel.modules.operations.prefill_page_attention
Prefill Page Attention module with automatic optimization.
This module implements chunked prefill attention with paged KV cache, designed for the prefill phase of LLM inference. It complements the decode-only PageAttention by handling multiple query tokens with causal masking.
- Key Features:
Processes multiple query tokens (chunk) during prefill
Causal masking for autoregressive generation
Sliding window attention support
Paged KV cache for memory efficiency
Grouped Query Attention (GQA) support
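The GQA support listed above means several query heads share one KV head, which is why `query` carries `num_heads` while the caches carry `num_kv_heads`. A minimal sketch of that head mapping follows. The helper name `kv_head_for_query_head` and the contiguous grouping (query heads 0..g-1 share KV head 0, and so on) are assumptions for illustration, not something this page confirms about the kernel's internal layout.

```python
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Return the KV head shared by `q_head`, assuming contiguous GQA groups."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads  # query heads per KV head
    return q_head // group_size


# 8 query heads sharing 2 KV heads: groups of 4.
mapping = [kv_head_for_query_head(h, num_q_heads=8, num_kv_heads=2) for h in range(8)]
print(mapping)
```

With `num_kv_heads == num_heads` the mapping is the identity (plain multi-head attention), and with `num_kv_heads == 1` every query head reads the same KV head (multi-query attention).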
- Use Cases:
Initial prompt processing in LLM serving
Chunked prefill for long contexts
Combined with PageAttention for full inference pipeline
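For the chunked-prefill use case above, a caller has to decide which prompt tokens go into each chunk and what `context_len` to pass for it. The sketch below shows one plausible schedule, assuming `context_len` counts every token up to and including the current chunk, as the parameter description states. `ChunkStep` and `chunk_schedule` are hypothetical helpers, not part of ejkernel.

```python
from typing import Iterator, NamedTuple


class ChunkStep(NamedTuple):
    start: int        # index of the first prompt token in this chunk
    end: int          # one past the last prompt token in this chunk
    context_len: int  # tokens in the KV cache once this chunk is written


def chunk_schedule(prompt_len: int, chunk_size: int) -> Iterator[ChunkStep]:
    """Split a prompt of `prompt_len` tokens into consecutive prefill chunks."""
    if prompt_len < 0 or chunk_size <= 0:
        raise ValueError("prompt_len must be >= 0 and chunk_size must be > 0")
    for start in range(0, prompt_len, chunk_size):
        end = min(start + chunk_size, prompt_len)
        yield ChunkStep(start, end, context_len=end)


steps = list(chunk_schedule(prompt_len=300, chunk_size=128))
for step in steps:
    print(step)
```

The last chunk here is shorter than `chunk_size`. A real serving loop would likely pad it to a fixed shape to avoid recompilation, but how padding is handled is outside what this page documents.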
- Mathematical Foundation:
- For query position q_pos in chunk:
output[q_pos] = sum_{kv_pos <= q_pos} softmax(Q[q_pos] @ K[kv_pos].T) @ V[kv_pos]
- With sliding window (window_size W):
output[q_pos] = sum_{q_pos - W + 1 <= kv_pos <= q_pos} softmax(…) @ V[kv_pos]
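The two formulas above can be checked against a dense, single-head NumPy reference. This is only a sketch for reasoning about the masking rules, not the kernel's implementation: it ignores paging and GQA, takes the absolute position of the first query as a hypothetical `q_start` argument, and applies the documented default scale of 1/sqrt(head_dim), which the formulas leave implicit.

```python
import numpy as np


def reference_prefill_attention(q, k, v, q_start, sliding_window=None,
                                mask_value=-2.381976426469702e38):
    """Dense causal attention for one head.

    q: [chunk_size, head_dim] queries at absolute positions q_start, q_start + 1, ...
    k, v: [context_len, head_dim] keys/values at absolute positions 0..context_len-1.
    """
    chunk_size, head_dim = q.shape
    context_len = k.shape[0]
    logits = (q @ k.T) / np.sqrt(head_dim)            # [chunk_size, context_len]

    q_pos = q_start + np.arange(chunk_size)[:, None]  # absolute query positions
    kv_pos = np.arange(context_len)[None, :]          # absolute key positions
    allowed = kv_pos <= q_pos                         # causal: no future keys
    if sliding_window is not None:
        # keep only q_pos - W + 1 <= kv_pos <= q_pos
        allowed &= kv_pos >= q_pos - sliding_window + 1

    logits = np.where(allowed, logits, mask_value)
    logits = logits - logits.max(axis=-1, keepdims=True)  # stable softmax
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                                # [chunk_size, head_dim]


rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))
k = rng.standard_normal((10, 8))
v = rng.standard_normal((10, 8))

out = reference_prefill_attention(q, k, v, q_start=6)
# With a window of 1, each query can only see its own position.
out_w1 = reference_prefill_attention(q, k, v, q_start=6, sliding_window=1)
```

With `sliding_window=1` each output row equals the value vector at the query's own position, which is a quick sanity check on the window bounds.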
References
JetStream chunked prefill: AI-Hypercomputer/JetStream
PagedAttention (vLLM): https://arxiv.org/abs/2309.06180
- class ejkernel.modules.operations.prefill_page_attention.PrefillPageAttention
Bases: Kernel[PrefillPageAttentionConfig, Array]
Prefill Page Attention with custom optimization logic.
Efficient chunked prefill attention over paged KV cache for serving workloads. Designed for the prefill phase where multiple query tokens are processed.
- Features:
Chunked prefill with causal masking
Paged KV cache management for memory efficiency
Sliding window attention support
Grouped Query Attention (GQA) support
Logit soft capping for numerical stability
TPU optimized with async DMA prefetching
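Logit soft capping, listed among the features above, bounds attention logits before the softmax. The sketch below uses the widely used tanh form, `cap * tanh(logits / cap)`. That the kernel applies exactly this formula for `attn_logits_soft_cap` is an assumption, and `soft_cap` is a hypothetical helper.

```python
import numpy as np


def soft_cap(logits: np.ndarray, cap: float) -> np.ndarray:
    """Squash logits smoothly into the open interval (-cap, cap)."""
    return cap * np.tanh(logits / cap)


logits = np.array([-1000.0, -1.0, 0.0, 1.0, 1000.0])
capped = soft_cap(logits, cap=50.0)
print(capped)
```

Small logits pass through almost unchanged because tanh(x) is close to x near zero, while extreme logits saturate near the cap instead of dominating the softmax.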
- candidate_cfgs(inv: Invocation[PrefillPageAttentionConfig, Array])
Generate candidate configurations for autotuning.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of candidate configurations to benchmark during autotuning
- create_shard_map_wrapper(query: Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], context_len: Int[jaxlib._jax.Array, '1'], page_indices: Int[jaxlib._jax.Array, 'num_pages'], platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: ejkernel.modules.operations.configs.PrefillPageAttentionConfig | None = None, softmax_scale: float | None = None, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, sliding_window: int | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None, check_vma: bool = False)
Create a shard_map wrapper for distributed execution.
- Parameters
query – Query tensor [chunk_size, num_q_heads, head_dim]
key_cache – Paged key cache
value_cache – Paged value cache
context_len – Total context length
page_indices – Page indices for sequence
platform – Platform to use
cfg – Configuration for the kernel
softmax_scale – Attention scaling factor
mask_value – Value for masked positions
attn_logits_soft_cap – Soft cap value
sliding_window – Sliding window size
mesh – JAX mesh for distributed execution
in_specs – Partition specifications for input tensors
out_specs – Partition specifications for output tensor
check_vma – Whether shard_map should run its varying-manual-axes (VMA) validity checks
- Returns
Tuple of (shard_map_fn, call_args)
- get_impl(cfg: PrefillPageAttentionConfig)
Get kernel implementation from registry.
- Parameters
cfg – Configuration specifying platform and backend preferences
- Returns
Callable kernel implementation for prefill page attention
- heuristic_cfg(inv: Invocation[PrefillPageAttentionConfig, Array]) -> PrefillPageAttentionConfig
Provide default configuration optimized for prefill page attention.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
Default configuration suitable for prefill workloads
- run(query: Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], context_len: Int[jaxlib._jax.Array, '1'], page_indices: Int[jaxlib._jax.Array, 'num_pages'], platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: PrefillPageAttentionConfig, softmax_scale: float | None = None, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, sliding_window: int | None = None) -> Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim']
Execute prefill page attention over paged KV cache.
Processes a chunk of query tokens with causal attention over a paged KV cache. Each query can only attend to itself and previous positions.
- Parameters
query – Query tensor [chunk_size, num_q_heads, head_dim] for prefill tokens
key_cache – Paged key cache [num_kv_heads, total_num_pages, page_size, head_dim]
value_cache – Paged value cache [num_kv_heads, total_num_pages, page_size, head_dim]
context_len – Total context length including this chunk [1]
page_indices – Page indices for this sequence [num_pages]
platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Kernel configuration object
softmax_scale – Attention scaling factor (default: 1/sqrt(head_dim))
mask_value – Value used for masked positions (default: -2.381976426469702e+38, a large finite negative number that stands in for -inf)
attn_logits_soft_cap – Optional soft cap for attention logits
sliding_window – If set, only attend to the last sliding_window tokens
- Returns
Attention output [chunk_size, num_q_heads, head_dim]
Example
>>> # Process a chunk of 128 tokens during prefill
>>> output = prefill_page_attention(
...     query,  # [128, 32, 128]
...     key_cache, value_cache,
...     context_len=jnp.array([512]),
...     page_indices=page_indices,
... )
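To make the cache layout in the parameter list concrete, the sketch below gathers one sequence's keys out of a cache shaped [num_kv_heads, total_num_pages, page_size, head_dim]. It assumes the usual paged layout, in which logical token t lives in page `page_indices[t // page_size]` at slot `t % page_size`. This page does not spell that mapping out, and `gather_kv` is a hypothetical helper that takes `context_len` as a plain int rather than a shape-[1] array.

```python
import numpy as np


def gather_kv(cache: np.ndarray, page_indices: np.ndarray, context_len: int) -> np.ndarray:
    """Return the contiguous [num_kv_heads, context_len, head_dim] view of one sequence."""
    num_kv_heads, _, page_size, head_dim = cache.shape
    pages_needed = -(-context_len // page_size)  # ceil(context_len / page_size)
    if pages_needed > len(page_indices):
        raise ValueError("page_indices does not cover context_len tokens")
    # Pick this sequence's physical pages in logical order.
    pages = cache[:, page_indices[:pages_needed]]  # [heads, pages, page_size, dim]
    flat = pages.reshape(num_kv_heads, pages_needed * page_size, head_dim)
    return flat[:, :context_len]                   # drop unused slots in the last page


# Toy cache where every entry encodes (physical page * 100 + slot).
page_ids = np.arange(6)[None, :, None, None] * 100
slot_ids = np.arange(4)[None, None, :, None]
key_cache = np.broadcast_to(page_ids + slot_ids, (2, 6, 4, 3)).astype(np.float32)

page_indices = np.array([5, 2, 0])  # logical page 0 -> physical page 5, and so on
keys = gather_kv(key_cache, page_indices, context_len=10)
print(keys[0, :, 0])
```

A production kernel would avoid materializing this contiguous copy and read pages directly, but the gather is a convenient way to build a dense reference for testing.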
- ejkernel.modules.operations.prefill_page_attention.prefill_page_attention(query: Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], context_len: Int[jaxlib._jax.Array, '1'], page_indices: Int[jaxlib._jax.Array, 'num_pages'], /, *, softmax_scale: float | None = None, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, sliding_window: int | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.PrefillPageAttentionConfig | None = None) -> Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim']
Execute prefill page attention with automatic optimization.
Chunked prefill attention with paged KV cache for the prefill phase of LLM inference. Supports causal masking and sliding window attention.
- Parameters
query – Query tensor [chunk_size, num_q_heads, head_dim]
key_cache – Paged key cache [num_kv_heads, total_num_pages, page_size, head_dim]
value_cache – Paged value cache [num_kv_heads, total_num_pages, page_size, head_dim]
context_len – Total context length including this chunk [1]
page_indices – Page indices for this sequence [num_pages]
softmax_scale – Attention scaling factor (default: 1/sqrt(head_dim))
mask_value – Value for masked positions (default: -2.381976426469702e+38, a large finite negative number that stands in for -inf)
attn_logits_soft_cap – Soft cap value for attention logits
sliding_window – If set, only attend to the last sliding_window tokens
platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Optional configuration override
- Returns
Attention output [chunk_size, num_q_heads, head_dim]
Example
>>> # Basic prefill with 128 token chunk
>>> output = prefill_page_attention(
...     query, key_cache, value_cache,
...     context_len=jnp.array([512]),
...     page_indices=page_indices,
... )
>>> # With sliding window attention
>>> output = prefill_page_attention(
...     query, key_cache, value_cache,
...     context_len=jnp.array([1024]),
...     page_indices=page_indices,
...     sliding_window=256,
... )
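The `mask_value` default in the signature, -2.381976426469702e+38, is numerically 0.7 times the largest finite float32 rather than a true -inf. The sketch below illustrates one likely reason for preferring a finite value: if a row ever ends up fully masked (for example a padding row), a softmax over -inf produces NaN, whereas a finite mask value yields a harmless uniform row. This is an illustration of the general numerical issue, not a statement about how the kernel handles such rows.

```python
import math

import numpy as np

FLOAT32_MAX = float(np.finfo(np.float32).max)
DEFAULT_MASK_VALUE = -2.381976426469702e38  # default taken from the signature


def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over a 1-D array."""
    shifted = x - x.max()
    e = np.exp(shifted)
    return e / e.sum()


fully_masked_inf = np.full(4, -np.inf, dtype=np.float32)
fully_masked_finite = np.full(4, DEFAULT_MASK_VALUE, dtype=np.float32)

with np.errstate(invalid="ignore"):
    probs_inf = softmax(fully_masked_inf)      # (-inf) - (-inf) is NaN
probs_finite = softmax(fully_masked_finite)    # every entry shifts to 0

print(probs_inf, probs_finite)
print(math.isclose(-DEFAULT_MASK_VALUE, 0.7 * FLOAT32_MAX, rel_tol=1e-12))
```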