ejkernel.kernels._xla.prefill_page_attention._impl
XLA reference implementation of chunked prefill paged attention.
- ejkernel.kernels._xla.prefill_page_attention._impl.prefill_page_attention(query: Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], context_len: Int[jaxlib._jax.Array, '1'], page_indices: Int[jaxlib._jax.Array, 'num_pages'], *, softmax_scale: float | None = None, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, sliding_window: int | None = None) → Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim']
XLA reference implementation of chunked prefill paged attention.
Processes one chunk of query tokens during the prefill phase, attending over the sequence's KV cache stored in fixed-size pages.
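As an illustration of the paged layout, here is a minimal sketch, assuming only the cache shapes given in the signature, of how a sequence's pages can be gathered into one contiguous key (or value) sequence. The helper name `gather_pages` is hypothetical and not part of ejkernel.

```python
import jax.numpy as jnp


def gather_pages(cache, page_indices):
    """Hypothetical helper: flatten the listed pages into one contiguous sequence.

    cache:        [num_kv_heads, total_num_pages, page_size, head_dim]
    page_indices: [num_pages] integer indices into the total_num_pages axis
    returns:      [num_kv_heads, num_pages * page_size, head_dim]
    """
    num_kv_heads, _, _, head_dim = cache.shape
    pages = cache[:, page_indices]  # [num_kv_heads, num_pages, page_size, head_dim]
    # Concatenate the selected pages, in the order given, along the token axis.
    return pages.reshape(num_kv_heads, -1, head_dim)


# Toy cache: 2 KV heads, 4 physical pages of 3 slots each, head_dim 1.
cache = jnp.arange(24, dtype=jnp.float32).reshape(2, 4, 3, 1)
k = gather_pages(cache, jnp.array([2, 0]))
print(k.shape)     # (2, 6, 1)
print(k[0, :, 0])  # [6. 7. 8. 0. 1. 2.]
```

Physical page 2 comes first because it is listed first in `page_indices`, so logical token order follows the index list, not the physical page order.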
- Parameters
query – Query tensor [chunk_size, num_q_heads, head_dim] for prefill tokens.
key_cache – Paged key cache [num_kv_heads, total_num_pages, page_size, head_dim].
value_cache – Paged value cache [num_kv_heads, total_num_pages, page_size, head_dim].
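context_len – Total context length, including the tokens in this chunk; shape [1].
page_indices – Indices into the total_num_pages axis of key_cache and value_cache, selecting the pages that hold this sequence; shape [num_pages].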
softmax_scale – Attention scaling factor (default: 1/sqrt(head_dim)).
mask_value – Logit value assigned to masked positions before the softmax. The default, -2.381976426469702e+38, equals -0.7 times the largest finite float32.
attn_logits_soft_cap – Optional soft cap for attention logits.
sliding_window – If set, only attend to the last sliding_window tokens.
- Returns
Attention output [chunk_size, num_q_heads, head_dim].
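For intuition, the following is a minimal pure-JAX sketch of what a dense reference computation with these parameters could look like. It is not the ejkernel source; the name `prefill_paged_attention_sketch` is hypothetical, and the sketch assumes (1) the chunk holds the last chunk_size tokens of the context, so query i sits at absolute position context_len - chunk_size + i; (2) grouped-query attention maps query head h to KV head h // (num_q_heads // num_kv_heads); (3) the soft cap is applied as cap * tanh(logits / cap) after scaling; and (4) sliding_window keeps keys whose position lies in (q_pos - sliding_window, q_pos].

```python
import math

import jax
import jax.numpy as jnp


def prefill_paged_attention_sketch(
    query,         # [chunk_size, num_q_heads, head_dim]
    key_cache,     # [num_kv_heads, total_num_pages, page_size, head_dim]
    value_cache,   # [num_kv_heads, total_num_pages, page_size, head_dim]
    context_len,   # [1] int
    page_indices,  # [num_pages] int
    *,
    softmax_scale=None,
    mask_value=-2.381976426469702e38,
    attn_logits_soft_cap=None,
    sliding_window=None,
):
    """Hypothetical dense reference; not the ejkernel implementation."""
    chunk_size, num_q_heads, head_dim = query.shape
    num_kv_heads = key_cache.shape[0]
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)

    # Gather this sequence's pages into contiguous [num_kv_heads, kv_len, head_dim].
    k = key_cache[:, page_indices].reshape(num_kv_heads, -1, head_dim)
    v = value_cache[:, page_indices].reshape(num_kv_heads, -1, head_dim)
    kv_len = k.shape[1]

    # Assumed GQA mapping: query head h reads KV head h // group.
    group = num_q_heads // num_kv_heads
    k = jnp.repeat(k, group, axis=0).astype(jnp.float32)  # [num_q_heads, kv_len, head_dim]
    v = jnp.repeat(v, group, axis=0).astype(jnp.float32)

    logits = jnp.einsum("qhd,hkd->hqk", query.astype(jnp.float32), k) * softmax_scale
    if attn_logits_soft_cap is not None:
        # Assumed tanh soft cap, applied after scaling.
        logits = attn_logits_soft_cap * jnp.tanh(logits / attn_logits_soft_cap)

    # Assumed: the chunk holds the last chunk_size tokens of the context.
    q_pos = context_len[0] - chunk_size + jnp.arange(chunk_size)  # [chunk_size]
    k_pos = jnp.arange(kv_len)                                   # [kv_len]
    # Causal mask; because q_pos <= context_len - 1, it also hides unused page slots.
    mask = k_pos[None, :] <= q_pos[:, None]
    if sliding_window is not None:
        mask = mask & (k_pos[None, :] > q_pos[:, None] - sliding_window)
    logits = jnp.where(mask[None], logits, mask_value)

    probs = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("hqk,hkd->qhd", probs, v)  # [chunk_size, num_q_heads, head_dim]
    return out.astype(query.dtype)
```

Under these assumptions, sliding_window=1 leaves each query attending only to its own position, so the output reduces to the value vector stored at that position, which makes a convenient sanity check against a real kernel.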