ejkernel.kernels._xla.ragged_page_attention_v3._interface
Ragged paged attention v3 interface for mixed prefill and decode.
This module provides the public API for the third-generation ragged paged attention, which supports mixed prefill and decode operations in a single batch. It also includes KV cache update functionality.
- ejkernel.kernels._xla.ragged_page_attention_v3._interface.ragged_page_attention_v3(queries: Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], keys: Float[jaxlib._jax.Array, 'total_tokens num_kv_heads head_dim'], values: Float[jaxlib._jax.Array, 'total_tokens num_kv_heads head_dim'], kv_cache: Float[jaxlib._jax.Array, 'num_pages page_size num_kv_heads_x2_per_kv_packing kv_packing head_dim_padded'], kv_lens: Int32[jaxlib._jax.Array, 'max_num_seqs'], block_tables: Int32[jaxlib._jax.Array, 'max_num_seqs_times_pages_per_seq'], query_start_loc: Int32[jaxlib._jax.Array, 'max_num_seqs_plus_1'], distribution: Int32[jaxlib._jax.Array, '3'], attention_sink: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, *, softmax_scale: float = 1.0, sliding_window: int | None = None, logits_soft_cap: float | None = None, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, chunk_prefill_size: int | None = None, num_kv_pages_per_block: int | None = None, num_queries_per_block: int | None = None, vmem_limit_bytes: int | None = None) -> tuple[jaxtyping.Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'num_pages page_size num_kv_heads_x2_per_kv_packing kv_packing head_dim_padded']]
Ragged paged attention that supports mixed prefill and decode.
- Parameters
queries – the queries of all sequences, concatenated along the token axis.
kv_cache – paged KV cache. Normally in HBM.
kv_lens – padded KV lengths. Only the first num_seqs values are valid.
block_tables – flattened table of page indices: for each sequence, the pages it uses in the KV cache. Only the entries for the first num_seqs sequences are valid.
query_start_loc – the cumulative sum of the effective query lengths. As with kv_lens, only the first num_seqs+1 values are valid.
num_seqs – the dynamic number of sequences.
softmax_scale – the scale factor applied to Q@K^T before the softmax. Defaults to 1.0.
sliding_window – the sliding window size for the attention.
logits_soft_cap – the logit soft cap for the attention.
mask_value – mask value for causal mask.
num_kv_pages_per_block – number of KV pages to be processed in one flash attention block in the Pallas kernel.
num_queries_per_block – number of queries to be processed in one flash attention block in the Pallas kernel.
vmem_limit_bytes – the VMEM limit, in bytes, for the Pallas kernel.
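The padded metadata arrays described above can be sketched as follows. This is a minimal illustration, not the library's own preparation code: `build_ragged_metadata` is a hypothetical helper, `page_size` and `pages_per_seq` are taken from the shape annotations in the signature, and two things are assumed for the sake of the example: pages are allocated sequentially, and each KV length counts the tokens of the current step. NumPy is used to keep the sketch dependency-light; the resulting arrays would need converting (for example with `jnp.asarray`) before being passed to a JAX function. The `distribution` and `kv_cache` arguments are not covered here.

```python
import numpy as np

def build_ragged_metadata(query_lens, kv_lens, page_size, max_num_seqs, pages_per_seq):
    """Hypothetical helper: pack per-sequence lengths into padded metadata arrays."""
    num_seqs = len(query_lens)
    assert num_seqs == len(kv_lens) <= max_num_seqs

    # query_start_loc[i] is the offset of sequence i's first query token in the
    # concatenated queries array; entry num_seqs holds the total token count.
    query_start_loc = np.zeros(max_num_seqs + 1, dtype=np.int32)
    query_start_loc[1 : num_seqs + 1] = np.cumsum(query_lens)

    # kv_lens is padded to max_num_seqs; only the first num_seqs entries are valid.
    padded_kv_lens = np.zeros(max_num_seqs, dtype=np.int32)
    padded_kv_lens[:num_seqs] = kv_lens

    # Assumption: pages are handed out sequentially, and each sequence owns a
    # contiguous run of pages_per_seq slots in the flattened block table.
    block_tables = np.zeros(max_num_seqs * pages_per_seq, dtype=np.int32)
    next_page = 0
    for seq, kv_len in enumerate(kv_lens):
        pages_needed = -(-kv_len // page_size)  # ceiling division
        assert pages_needed <= pages_per_seq
        start = seq * pages_per_seq
        block_tables[start : start + pages_needed] = np.arange(
            next_page, next_page + pages_needed
        )
        next_page += pages_needed
    return query_start_loc, padded_kv_lens, block_tables

# One decode sequence (1 query token, KV length 10) and one prefill
# sequence (5 query tokens, KV length 5), padded to room for 4 sequences.
qsl, kvl, bt = build_ragged_metadata(
    query_lens=[1, 5], kv_lens=[10, 5], page_size=4, max_num_seqs=4, pages_per_seq=3
)
```

Entries past the valid prefix are left at zero here; the parameter descriptions only state that they are ignored, so the padding value is a free choice in this sketch.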
- Returns
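A tuple of the attention output, with shape (total_tokens, num_q_heads, head_dim), and the updated KV cache, with the same shape as kv_cache.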
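For intuition about what softmax_scale, sliding_window, logits_soft_cap and the causal mask value do to the output, here is a dense single-sequence, single-head reference in NumPy. It is a sketch under stated assumptions, not the library's kernel: `reference_attention` is a hypothetical name, the soft cap is assumed to take the common `cap * tanh(logits / cap)` form, the sliding window is assumed to keep keys whose distance from the query is less than the window size, and the query tokens are assumed to be the last positions of the sequence. Paging, KV packing, attention sinks and the q/k/v scales are left out.

```python
import numpy as np

def reference_attention(q, k, v, *, softmax_scale=1.0, sliding_window=None,
                        logits_soft_cap=None, mask_value=-1e30):
    """Dense reference for one sequence and one head (illustrative only).

    q: (q_len, head_dim); k, v: (kv_len, head_dim). The q_len query tokens are
    assumed to be the last q_len positions of the kv_len-token sequence.
    """
    q_len, kv_len = q.shape[0], k.shape[0]
    logits = softmax_scale * (q @ k.T)
    if logits_soft_cap is not None:
        # Assumed soft-cap form: squash logits smoothly into (-cap, cap).
        logits = logits_soft_cap * np.tanh(logits / logits_soft_cap)

    q_pos = np.arange(kv_len - q_len, kv_len)[:, None]  # absolute query positions
    k_pos = np.arange(kv_len)[None, :]
    masked = k_pos > q_pos  # causal mask: no attention to future keys
    if sliding_window is not None:
        # Assumed window convention: keep keys with q_pos - k_pos < sliding_window.
        masked |= (q_pos - k_pos) >= sliding_window
    logits = np.where(masked, mask_value, logits)

    logits -= logits.max(axis=-1, keepdims=True)  # stabilise the softmax
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 8))   # 2 new query tokens
k = rng.standard_normal((6, 8))   # 6 tokens of keys in total
v = rng.standard_normal((6, 8))
out = reference_attention(q, k, v, softmax_scale=8 ** -0.5, sliding_window=3)
```

With `sliding_window=1` each query can only see its own position, so the result collapses to the last two rows of `v`; that makes a convenient sanity check for the masking convention assumed here.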