ejkernel.modules.operations.prefill_page_attention

Prefill Page Attention module with automatic optimization.

This module implements chunked prefill attention with paged KV cache, designed for the prefill phase of LLM inference. It complements the decode-only PageAttention by handling multiple query tokens with causal masking.

Key Features:
  • Processes multiple query tokens (chunk) during prefill

  • Causal masking for autoregressive generation

  • Sliding window attention support

  • Paged KV cache for memory efficiency

  • Grouped Query Attention (GQA) support
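How the paged cache and GQA features fit together can be sketched in plain JAX. This is an illustration only, not the kernel's implementation: `gather_pages` and `expand_kv_heads` are hypothetical helper names, the cache layout `[num_kv_heads, total_num_pages, page_size, head_dim]` is taken from the signatures on this page, and the mapping of consecutive query heads onto one KV head is an assumption.

```python
import jax.numpy as jnp


def gather_pages(cache, page_indices, context_len):
    """Collect one sequence's pages into a contiguous array.

    cache:        [num_kv_heads, total_num_pages, page_size, head_dim]
    page_indices: [num_pages] physical page ids, in logical order
    context_len:  Python int, number of valid tokens
    returns:      [num_kv_heads, context_len, head_dim]
    """
    num_kv_heads, _, _, head_dim = cache.shape
    pages = cache[:, page_indices]  # [num_kv_heads, num_pages, page_size, head_dim]
    flat = pages.reshape(num_kv_heads, -1, head_dim)
    return flat[:, :context_len]  # drop the unused tail of the last page


def expand_kv_heads(kv, num_q_heads):
    """Repeat each KV head so consecutive query heads share it (assumed GQA mapping)."""
    group_size = num_q_heads // kv.shape[0]
    return jnp.repeat(kv, group_size, axis=0)
```

A real kernel would read pages in place rather than materialize the gathered copy; the sketch only shows which cache entries a sequence refers to.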

Use Cases:
  • Initial prompt processing in LLM serving

  • Chunked prefill for long contexts

  • Combined with PageAttention for full inference pipeline
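For chunked prefill, each call sees a context length that includes the chunk being processed. A minimal sketch of such a schedule follows; `chunk_schedule` is a hypothetical helper, not part of ejkernel, and the page count assumes pages are filled densely from the start of the sequence.

```python
def chunk_schedule(prompt_len, chunk_size, page_size):
    """Return (start, stop, context_len, num_pages) for each prefill chunk."""
    schedule = []
    for start in range(0, prompt_len, chunk_size):
        stop = min(start + chunk_size, prompt_len)
        context_len = stop  # context includes the current chunk
        num_pages = -(-context_len // page_size)  # ceiling division
        schedule.append((start, stop, context_len, num_pages))
    return schedule


for start, stop, context_len, num_pages in chunk_schedule(300, 128, 16):
    print(f"tokens [{start}, {stop}) -> context_len={context_len}, pages={num_pages}")
```

The last chunk is shorter than `chunk_size` whenever the prompt length is not a multiple of it, so a caller would need to pad or handle a variable chunk shape.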

Mathematical Foundation:
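For query position q_pos in the chunk, the softmax is normalized over the visible key positions kv_pos <= q_pos:

output[q_pos] = sum_{kv_pos <= q_pos} softmax_{kv_pos}(softmax_scale * Q[q_pos] @ K[kv_pos].T) * V[kv_pos]

With a sliding window (window_size W), both the softmax and the sum run over the last W positions only:

output[q_pos] = sum_{q_pos - W + 1 <= kv_pos <= q_pos} softmax(…) * V[kv_pos]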
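The two formulas can be checked against a dense reference written in plain JAX. This is a sketch under stated assumptions, not the library's kernel: `reference_chunk_attention` is a hypothetical name, K and V are assumed to be already gathered from the paged cache and expanded to one head per query head, and the chunk is assumed to occupy the last `chunk_size` positions of the context.

```python
import jax
import jax.numpy as jnp


def reference_chunk_attention(q, k, v, sliding_window=None, softmax_scale=None):
    """Dense causal attention for one prefill chunk.

    q:    [chunk_size, num_heads, head_dim]
    k, v: [num_heads, context_len, head_dim]
    """
    chunk_size, _, head_dim = q.shape
    context_len = k.shape[1]
    scale = head_dim ** -0.5 if softmax_scale is None else softmax_scale
    logits = jnp.einsum("qhd,hkd->hqk", q, k) * scale

    # Assumption: query i sits at absolute position context_len - chunk_size + i.
    q_pos = (context_len - chunk_size) + jnp.arange(chunk_size)[:, None]
    kv_pos = jnp.arange(context_len)[None, :]
    mask = kv_pos <= q_pos  # causal constraint
    if sliding_window is not None:
        mask = mask & (kv_pos >= q_pos - sliding_window + 1)

    # Masked logits get a large finite negative value before the softmax.
    logits = jnp.where(mask[None], logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("hqk,hkd->qhd", weights, v)
```

With `sliding_window=1` every query attends only to itself, so the output reduces to the value vectors at the chunk's own positions, which makes a convenient sanity check.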

class ejkernel.modules.operations.prefill_page_attention.PrefillPageAttention

Bases: Kernel[PrefillPageAttentionConfig, Array]

Prefill Page Attention with custom optimization logic.

Efficient chunked prefill attention over paged KV cache for serving workloads. Designed for the prefill phase where multiple query tokens are processed.

Features:
  • Chunked prefill with causal masking

  • Paged KV cache management for memory efficiency

  • Sliding window attention support

  • Grouped Query Attention (GQA) support

  • Logit soft capping for numerical stability

  • TPU optimized with async DMA prefetching
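This page does not define the soft-capping function. A common formulation, assumed here purely for illustration, squashes the logits through a scaled tanh so that they stay within `(-cap, cap)`; `soft_cap` is a hypothetical helper name.

```python
import jax.numpy as jnp


def soft_cap(logits, cap):
    """Smoothly bound logits to (-cap, cap); near-identity for |logits| << cap."""
    return cap * jnp.tanh(logits / cap)


print(soft_cap(jnp.array([0.5, 50.0, -500.0]), 30.0))
```

Under this assumption the cap would be applied to the scaled logits before masking and softmax; whether the kernel applies it at exactly that point is not stated here.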

candidate_cfgs(inv: Invocation[PrefillPageAttentionConfig, Array])

Generate candidate configurations for autotuning.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of candidate configurations to benchmark during autotuning

create_shard_map_wrapper(query: Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], context_len: Int[jaxlib._jax.Array, '1'], page_indices: Int[jaxlib._jax.Array, 'num_pages'], platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: ejkernel.modules.operations.configs.PrefillPageAttentionConfig | None = None, softmax_scale: float | None = None, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, sliding_window: int | None = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None, check_vma: bool = False)

Create a shard_map wrapper for distributed execution.

Parameters
  • query – Query tensor [chunk_size, num_q_heads, head_dim]

  • key_cache – Paged key cache

  • value_cache – Paged value cache

  • context_len – Total context length

  • page_indices – Page indices for sequence

  • platform – Platform to use

  • cfg – Configuration for the kernel

  • softmax_scale – Attention scaling factor

  • mask_value – Value for masked positions

  • attn_logits_soft_cap – Soft cap value

  • sliding_window – Sliding window size

  • mesh – JAX mesh for distributed execution

  • in_specs – Partition specifications for input tensors

  • out_specs – Partition specifications for output tensor

  • check_vma – Whether to enable shard_map's varying-manual-axes (VMA) validity checks (default: False)

Returns

Tuple of (shard_map_fn, call_args)
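A mesh and partition specs for this wrapper might be set up as below. Only the `jax.sharding` calls are real API. The axis name `"tp"`, the choice to shard along the head dimensions, and the assumption that `in_specs` holds one spec per array argument in signature order are illustrative guesses, not documented behavior.

```python
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

# One-dimensional mesh over all local devices; "tp" is a hypothetical axis name.
mesh = Mesh(np.array(jax.devices()), axis_names=("tp",))

# Assumed layout: shard query heads and KV heads, replicate the rest.
in_specs = (
    P(None, "tp", None),        # query       [chunk_size, num_heads, head_dim]
    P("tp", None, None, None),  # key_cache   [num_kv_heads, total_num_pages, page_size, head_dim]
    P("tp", None, None, None),  # value_cache [num_kv_heads, total_num_pages, page_size, head_dim]
    P(),                        # context_len [1], replicated
    P(),                        # page_indices [num_pages], replicated
)
out_specs = P(None, "tp", None)  # output has the same layout as query
```

Sharding along heads keeps each device's attention computation independent, provided both `num_heads` and `num_kv_heads` are divisible by the mesh axis size.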

get_impl(cfg: PrefillPageAttentionConfig)

Get kernel implementation from registry.

Parameters

cfg – Configuration specifying platform and backend preferences

Returns

Callable kernel implementation for prefill page attention

heuristic_cfg(inv: Invocation[PrefillPageAttentionConfig, Array]) → PrefillPageAttentionConfig

Provide default configuration optimized for prefill page attention.

Parameters

inv – Invocation object containing arguments and metadata

Returns

Default configuration suitable for prefill workloads

run(query: Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], context_len: Int[jaxlib._jax.Array, '1'], page_indices: Int[jaxlib._jax.Array, 'num_pages'], platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: PrefillPageAttentionConfig, softmax_scale: float | None = None, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, sliding_window: int | None = None) → Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim']

Execute prefill page attention over paged KV cache.

Processes a chunk of query tokens with causal attention over a paged KV cache. Each query can only attend to itself and previous positions.

Parameters
  • query – Query tensor [chunk_size, num_q_heads, head_dim] for prefill tokens

  • key_cache – Paged key cache [num_kv_heads, total_num_pages, page_size, head_dim]

  • value_cache – Paged value cache [num_kv_heads, total_num_pages, page_size, head_dim]

  • context_len – Total context length including this chunk [1]

  • page_indices – Page indices for this sequence [num_pages]

  • platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Kernel configuration object

  • softmax_scale – Attention scaling factor (default: 1/sqrt(head_dim))

  • mask_value – Value used for masked positions (default: -2.381976426469702e+38, a large finite negative number rather than -inf)

  • attn_logits_soft_cap – Optional soft cap for attention logits

  • sliding_window – If set, only attend to the last sliding_window tokens

Returns

Attention output [chunk_size, num_q_heads, head_dim]

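Example

>>> import jax.numpy as jnp
>>> from ejkernel.modules.operations.prefill_page_attention import prefill_page_attention
>>> # Process a chunk of 128 tokens during prefill
>>> output = prefill_page_attention(
...     query,  # [128, 32, 128]
...     key_cache, value_cache,
...     jnp.array([512]),  # context_len (positional-only)
...     page_indices,      # positional-only
... )
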
ejkernel.modules.operations.prefill_page_attention.prefill_page_attention(query: Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], context_len: Int[jaxlib._jax.Array, '1'], page_indices: Int[jaxlib._jax.Array, 'num_pages'], /, *, softmax_scale: float | None = None, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, sliding_window: int | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.PrefillPageAttentionConfig | None = None) → Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim']

Execute prefill page attention with automatic optimization.

Chunked prefill attention with paged KV cache for the prefill phase of LLM inference. Supports causal masking and sliding window attention.

Parameters
  • query – Query tensor [chunk_size, num_q_heads, head_dim]

  • key_cache – Paged key cache [num_kv_heads, total_num_pages, page_size, head_dim]

  • value_cache – Paged value cache [num_kv_heads, total_num_pages, page_size, head_dim]

  • context_len – Total context length including this chunk [1]

  • page_indices – Page indices for this sequence [num_pages]

  • softmax_scale – Attention scaling factor (default: 1/sqrt(head_dim))

  • mask_value – Value for masked positions (default: -2.381976426469702e+38, a large finite negative number rather than -inf)

  • attn_logits_soft_cap – Soft cap value for attention logits

  • sliding_window – If set, only attend to the last sliding_window tokens

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Optional configuration override

Returns

Attention output [chunk_size, num_q_heads, head_dim]

Example

>>> import jax.numpy as jnp
>>> from ejkernel.modules.operations.prefill_page_attention import prefill_page_attention
>>> # Basic prefill with 128 token chunk; the five array arguments
>>> # are positional-only (note the "/" in the signature)
>>> output = prefill_page_attention(
...     query, key_cache, value_cache,
...     jnp.array([512]),  # context_len
...     page_indices,
... )
>>> # With sliding window attention
>>> output = prefill_page_attention(
...     query, key_cache, value_cache,
...     jnp.array([1024]),  # context_len
...     page_indices,
...     sliding_window=256,
... )
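
The examples leave `query`, `key_cache`, `value_cache`, and `page_indices` undefined. One way to build inputs with the documented shapes is sketched below. The concrete sizes, the random contents, and the identity page table are arbitrary illustrative choices. The call itself is left as a comment because it needs ejkernel and a supported backend.

```python
import jax
import jax.numpy as jnp

chunk_size, num_q_heads, num_kv_heads, head_dim = 128, 32, 8, 128
page_size, total_num_pages, context_len = 16, 64, 512

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (chunk_size, num_q_heads, head_dim))
key_cache = jax.random.normal(kk, (num_kv_heads, total_num_pages, page_size, head_dim))
value_cache = jax.random.normal(kv, (num_kv_heads, total_num_pages, page_size, head_dim))

# Enough pages to cover the context; here the sequence simply owns pages 0..N-1.
num_pages = -(-context_len // page_size)  # ceiling division
page_indices = jnp.arange(num_pages, dtype=jnp.int32)
context_len_arr = jnp.array([context_len], dtype=jnp.int32)

# output = prefill_page_attention(
#     query, key_cache, value_cache, context_len_arr, page_indices
# )  # expected shape per the docs above: [chunk_size, num_q_heads, head_dim]
```

In a serving system the page table would come from the cache allocator rather than `jnp.arange`, and the caches would already hold the keys and values written for this sequence.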