ejkernel.kernels._pallas.tpu.prefill_page_attention._interface
- ejkernel.kernels._pallas.tpu.prefill_page_attention._interface.prefill_page_attention(query: Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], context_len: Int[jaxlib._jax.Array, '1'], page_indices: Int[jaxlib._jax.Array, 'num_pages'], *, softmax_scale: float | None = None, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, sliding_window: int | None = None) → Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim']
Chunked prefill attention with paged KV cache for TPU.
This kernel processes a chunk of query tokens during the prefill phase, reading keys and values from a paged KV cache. It applies causal masking and optionally restricts attention to a sliding window.
- Parameters
query – Query tensor [chunk_size, num_q_heads, head_dim] for the prefill tokens (num_q_heads is the num_heads dimension in the signature).
key_cache – Paged key cache [num_kv_heads, total_num_pages, page_size, head_dim].
value_cache – Paged value cache [num_kv_heads, total_num_pages, page_size, head_dim].
context_len – Total context length, including this chunk, as an integer array of shape [1].
page_indices – Page indices for this sequence [num_pages].
softmax_scale – Attention scaling factor (default: 1/sqrt(head_dim)).
mask_value – Value used for masked positions in attention.
attn_logits_soft_cap – Optional soft cap for attention logits.
sliding_window – If set, only attend to the last sliding_window tokens.
- Returns
Attention output [chunk_size, num_q_heads, head_dim].
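The computation described above can be sketched in plain Python for tiny inputs. This is a minimal reference under stated assumptions, not the Pallas kernel: `gather_rows` and `reference_prefill_attention` are hypothetical names, nested lists stand in for arrays, and the sketch assumes that the query chunk occupies the last chunk_size positions of the context, that query heads map to KV heads in contiguous groups, that the sliding window includes the query's own position, and that the soft cap uses the common tanh form. It also slices away masked positions instead of filling them with mask_value.

```python
import math


def gather_rows(cache, kv_head, page_indices, page_size, length):
    """Return rows 0..length-1 of one KV head as a contiguous list.

    `cache[kv_head][page][slot]` is a list of head_dim floats, mirroring the
    documented [num_kv_heads, total_num_pages, page_size, head_dim] layout.
    """
    rows = []
    for pos in range(length):
        page = page_indices[pos // page_size]  # logical page -> physical page
        slot = pos % page_size                 # offset within that page
        rows.append(cache[kv_head][page][slot])
    return rows


def reference_prefill_attention(query, key_cache, value_cache, context_len,
                                page_indices, softmax_scale=None,
                                attn_logits_soft_cap=None, sliding_window=None):
    """Plain-Python reference; hypothetical helper, not part of ejkernel."""
    chunk_size, num_heads, head_dim = len(query), len(query[0]), len(query[0][0])
    num_kv_heads, page_size = len(key_cache), len(key_cache[0][0])
    group = num_heads // num_kv_heads       # assumption: grouped-query head mapping
    scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
    first_q_pos = context_len - chunk_size  # assumption: chunk is the context tail

    out = []
    for i in range(chunk_size):
        q_pos = first_q_pos + i
        # Causal: keys 0..q_pos. Sliding window (assumed to include the query
        # itself): keep only the last `sliding_window` of those positions.
        start = 0 if sliding_window is None else max(0, q_pos - sliding_window + 1)
        heads = []
        for h in range(num_heads):
            kv_head = h // group
            keys = gather_rows(key_cache, kv_head, page_indices,
                               page_size, q_pos + 1)[start:]
            values = gather_rows(value_cache, kv_head, page_indices,
                                 page_size, q_pos + 1)[start:]
            logits = []
            for k in keys:
                s = scale * sum(a * b for a, b in zip(query[i][h], k))
                if attn_logits_soft_cap is not None:
                    # Assumption: the common tanh soft cap.
                    s = attn_logits_soft_cap * math.tanh(s / attn_logits_soft_cap)
                logits.append(s)
            peak = max(logits)  # subtract the max for a numerically stable softmax
            weights = [math.exp(s - peak) for s in logits]
            total = sum(weights)
            heads.append([sum(w * v[d] for w, v in zip(weights, values)) / total
                          for d in range(head_dim)])
        out.append(heads)
    return out
```

With all-zero queries every visible position gets equal weight, so each output row is the mean of the visible value rows, which makes the sketch easy to check by hand.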
Note
This kernel is designed for the prefill phase, where multiple query tokens are processed at once.
Causal masking is applied, so each query attends only to itself and earlier tokens.
The KV cache must already contain the keys and values for this sequence.
chunk_size must be divisible by page_size.
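The constraints in this note can be checked before calling the kernel, and the resulting attention pattern can be made explicit. The sketch below is illustrative only: `validate_prefill_args` and `prefill_mask` are hypothetical helpers, not part of ejkernel. Only the divisibility check comes from the note; the other two checks, the placement of the chunk at the end of the context, and the exact sliding-window boundary are assumptions.

```python
def validate_prefill_args(chunk_size, page_size, context_len, num_pages):
    """Hypothetical pre-flight checks; not part of ejkernel."""
    if chunk_size % page_size != 0:
        raise ValueError("chunk_size must be divisible by page_size")
    if context_len < chunk_size:
        # Assumption: context_len counts the chunk, so it cannot be smaller.
        raise ValueError("context_len must include the current chunk")
    if num_pages * page_size < context_len:
        # Assumption: page_indices must cover every cached position.
        raise ValueError("page_indices does not cover context_len tokens")


def prefill_mask(chunk_size, context_len, sliding_window=None):
    """mask[i][j] is True when query i may attend to cached position j."""
    first_q_pos = context_len - chunk_size  # assumption: chunk is the context tail
    mask = []
    for i in range(chunk_size):
        q_pos = first_q_pos + i
        row = []
        for kv_pos in range(context_len):
            visible = kv_pos <= q_pos  # causal: itself and earlier tokens
            if sliding_window is not None:
                # Assumed boundary: the window counts the query's own position.
                visible = visible and kv_pos > q_pos - sliding_window
            row.append(visible)
        mask.append(row)
    return mask
```

For chunk_size=2, context_len=4 and sliding_window=2, `prefill_mask` returns the rows [False, True, True, False] and [False, False, True, True]. Positions where the mask is False are the ones a kernel would fill with mask_value before the softmax.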