ejkernel.kernels._xla.prefill_page_attention._impl


XLA reference implementation of chunked prefill paged attention.

ejkernel.kernels._xla.prefill_page_attention._impl.prefill_page_attention(query: Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], context_len: Int[jaxlib._jax.Array, '1'], page_indices: Int[jaxlib._jax.Array, 'num_pages'], *, softmax_scale: float | None = None, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, sliding_window: int | None = None) → Float[jaxlib._jax.Array, 'chunk_size num_heads head_dim']

XLA reference implementation of chunked prefill paged attention.

This processes a chunk of query tokens during the prefill phase, reading keys and values from a paged KV cache.
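A paged KV cache stores a sequence's keys and values in fixed-size pages spread along the `total_num_pages` axis, and `page_indices` selects which pages belong to the sequence. The sketch below is a minimal illustration, assuming the pages are listed in token order, of how they can be gathered into one contiguous view with plain JAX indexing. `gather_pages` is a hypothetical helper, not part of ejkernel.

```python
import jax.numpy as jnp


def gather_pages(cache, page_indices):
    """Hypothetical helper (not an ejkernel API): flatten one sequence's pages.

    cache:        [num_kv_heads, total_num_pages, page_size, head_dim]
    page_indices: [num_pages] integer indices into the total_num_pages axis,
                  assumed to be listed in token order.
    Returns [num_kv_heads, num_pages * page_size, head_dim]. Slots at or past
    context_len are padding and must be masked out by the caller.
    """
    num_kv_heads, _, _, head_dim = cache.shape
    pages = cache[:, page_indices]  # [num_kv_heads, num_pages, page_size, head_dim]
    return pages.reshape(num_kv_heads, -1, head_dim)


# Tiny cache: 2 KV heads, 4 pages of 2 slots each, head_dim 1.
cache = jnp.arange(16, dtype=jnp.float32).reshape(2, 4, 2, 1)
kv = gather_pages(cache, jnp.array([3, 1]))
print(kv[0, :, 0])  # slots of page 3, then page 1, for KV head 0
```

Because the gathered length is always a multiple of `page_size`, the trailing slots of the last page are usually padding; that is why a context length is needed for masking.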

Parameters
  • query – Query tensor [chunk_size, num_q_heads, head_dim] for prefill tokens.

  • key_cache – Paged key cache [num_kv_heads, total_num_pages, page_size, head_dim].

  • value_cache – Paged value cache [num_kv_heads, total_num_pages, page_size, head_dim].

  • context_len – Total context length, including the tokens in this chunk, as a one-element array [1].

  • page_indices – Indices of this sequence's pages along the total_num_pages axis of the caches [num_pages].

  • softmax_scale – Attention scaling factor (default: 1/sqrt(head_dim)).

  • mask_value – Logit value written at masked positions before the softmax.

  • attn_logits_soft_cap – Optional soft cap for attention logits.

  • sliding_window – If set, only attend to the last sliding_window tokens.
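The parameter list does not spell out how the chunk's tokens are positioned. A common convention, assumed here rather than confirmed by the source, is that the chunk holds the last `chunk_size` tokens of the context, so query row `i` sits at absolute position `context_len - chunk_size + i`, attends causally, and is optionally limited to a sliding window. Logit soft capping is also commonly defined as `cap * tanh(logits / cap)`. `prefill_mask` and `soft_cap` are hypothetical helpers illustrating these assumptions.

```python
import jax.numpy as jnp


def prefill_mask(chunk_size, kv_len, context_len, sliding_window=None):
    """Hypothetical helper: boolean [chunk_size, kv_len] mask, True = attend.

    Assumes the chunk is the last `chunk_size` tokens of a context of length
    `context_len`, so query row i sits at absolute position
    context_len - chunk_size + i.
    """
    q_pos = context_len - chunk_size + jnp.arange(chunk_size)[:, None]  # [chunk, 1]
    k_pos = jnp.arange(kv_len)[None, :]  # [1, kv_len]
    # Causal mask. Since every q_pos < context_len, this also hides padding
    # slots at or past context_len.
    mask = k_pos <= q_pos
    if sliding_window is not None:
        # Keep only the most recent `sliding_window` positions, including self.
        mask = mask & (q_pos - k_pos < sliding_window)
    return mask


def soft_cap(logits, cap):
    """Assumed tanh soft cap: squashes logits smoothly into (-cap, cap)."""
    return cap * jnp.tanh(logits / cap)


# A 2-token chunk ending a 5-token context, with 6 gathered KV slots.
print(prefill_mask(2, 6, 5).astype(int))
print(prefill_mask(2, 6, 5, sliding_window=2).astype(int))
```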

Returns

Attention output [chunk_size, num_q_heads, head_dim].
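For orientation, here is a minimal pure-JAX sketch of a dense computation matching the documented shapes and parameters. It is not ejkernel's code: the token-position convention, the grouped-query head mapping (query head `h` reads KV head `h // (num_q_heads // num_kv_heads)`), applying the soft cap before masking, and float32 accumulation are all assumptions. `prefill_page_attention_sketch` is a hypothetical name.

```python
import math

import jax
import jax.numpy as jnp


def prefill_page_attention_sketch(
    query,          # [chunk_size, num_q_heads, head_dim]
    key_cache,      # [num_kv_heads, total_num_pages, page_size, head_dim]
    value_cache,    # [num_kv_heads, total_num_pages, page_size, head_dim]
    context_len,    # int array of shape [1]
    page_indices,   # [num_pages]
    *,
    softmax_scale=None,
    mask_value=-2.381976426469702e38,
    attn_logits_soft_cap=None,
    sliding_window=None,
):
    """Hypothetical dense reference; the conventions here are assumptions."""
    chunk_size, num_q_heads, head_dim = query.shape
    num_kv_heads = key_cache.shape[0]
    group = num_q_heads // num_kv_heads  # assumes num_q_heads % num_kv_heads == 0
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)

    # Gather this sequence's pages into [num_kv_heads, kv_len, head_dim].
    k = key_cache[:, page_indices].reshape(num_kv_heads, -1, head_dim)
    v = value_cache[:, page_indices].reshape(num_kv_heads, -1, head_dim)
    kv_len = k.shape[1]
    # Expand KV heads so that query head h uses KV head h // group.
    k = jnp.repeat(k, group, axis=0).astype(jnp.float32)
    v = jnp.repeat(v, group, axis=0).astype(jnp.float32)

    q = query.astype(jnp.float32) * softmax_scale
    logits = jnp.einsum("qhd,hkd->hqk", q, k)  # [num_q_heads, chunk, kv_len]
    if attn_logits_soft_cap is not None:
        logits = attn_logits_soft_cap * jnp.tanh(logits / attn_logits_soft_cap)

    # Assumed positions: the chunk is the last chunk_size tokens of the context.
    q_pos = context_len[0] - chunk_size + jnp.arange(chunk_size)[:, None]
    k_pos = jnp.arange(kv_len)[None, :]
    mask = k_pos <= q_pos  # causal; also hides padding slots >= context_len
    if sliding_window is not None:
        mask = mask & (q_pos - k_pos < sliding_window)
    logits = jnp.where(mask[None], logits, mask_value)

    probs = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("hqk,hkd->qhd", probs, v)
    return out.astype(query.dtype)


# Usage: 4 query heads sharing 2 KV heads, 5 pages of 4 slots, head_dim 8.
keys = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(keys[0], (3, 4, 8))
key_cache = jax.random.normal(keys[1], (2, 5, 4, 8))
value_cache = jax.random.normal(keys[2], (2, 5, 4, 8))
out = prefill_page_attention_sketch(
    query, key_cache, value_cache, jnp.array([6]), jnp.array([2, 0])
)
print(out.shape)  # (3, 4, 8)
```

Materializing the full `[num_q_heads, chunk_size, kv_len]` logit tensor keeps this easy to read and check, at the cost of memory; a fused kernel would typically walk the pages in blocks instead.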