ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._interface
- ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._interface.ragged_page_attention_v2(queries: ~jaxtyping.Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], kv_pages: ~jaxtyping.Float[jaxlib._jax.Array, 'num_pages page_size num_combined_kv_heads head_dim'], context_lens: ~jaxtyping.Int[jaxlib._jax.Array, 'num_seqs'], block_tables: ~jaxtyping.Int[jaxlib._jax.Array, 'num_seqs pages_per_seq'], query_start_loc: ~jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'], num_seqs: jax.jaxlib._jax.Array | int, *, softmax_scale: float | None = None, logits_soft_cap: float | None = None, compute_dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.bfloat16'>, optimized: bool = False, sliding_window: int | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, mask_value: float | None = None, num_kv_pages_per_block: int | None = None, num_queries_per_block: int | None = None, vmem_limit_bytes: int | None = None, num_warps: int | None = None, num_stages: int | None = None) → Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']
Ragged paged attention over a paged KV cache that supports batches mixing prefill and decode sequences.
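In a mixed batch, a prefill sequence contributes many query tokens and a decode sequence contributes one, and all of them are packed into a single ragged queries array that is described by the metadata parameters documented below. The NumPy sketch that follows builds that metadata for a hypothetical batch of one 3-token prefill and two decode steps. The batch sizes, page_size, the static capacity max_num_seqs, and the specific page indices are made-up illustration values, and filling the unused trailing entries with zeros is an assumption (only the first num_seqs, or num_seqs + 1, entries are meant to be read).

```python
import numpy as np

# Hypothetical mixed batch: sequence 0 is a 3-token prefill,
# sequences 1 and 2 are single-token decode steps.
query_lens = np.array([3, 1, 1], dtype=np.int32)
kv_lens = np.array([3, 7, 5], dtype=np.int32)  # context length incl. the new tokens
page_size = 4
max_num_seqs = 4    # hypothetical static capacity; rows past num_seqs are padding
pages_per_seq = 2

num_seqs = len(query_lens)
total_tokens = int(query_lens.sum())

# Sequence i owns queries[query_start_loc[i]:query_start_loc[i + 1]].
query_start_loc = np.zeros(max_num_seqs + 1, dtype=np.int32)
query_start_loc[1 : num_seqs + 1] = np.cumsum(query_lens)

# Padded per-sequence KV lengths; only the first num_seqs entries are meaningful.
context_lens = np.zeros(max_num_seqs, dtype=np.int32)
context_lens[:num_seqs] = kv_lens

# block_tables[i, j] = index into kv_pages of the j-th page of sequence i.
# Each sequence needs ceil(kv_len / page_size) pages; unused slots stay 0.
block_tables = np.zeros((max_num_seqs, pages_per_seq), dtype=np.int32)
block_tables[0, :1] = [5]     # 3 KV tokens -> 1 page
block_tables[1, :2] = [2, 9]  # 7 KV tokens -> 2 pages
block_tables[2, :2] = [0, 3]  # 5 KV tokens -> 2 pages
```

Because every sequence is addressed only through these offsets and page indices, prefill and decode sequences of different lengths can share one kernel launch without padding the queries themselves.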
- Parameters
queries – the queries of all sequences, concatenated along the token axis.
kv_pages – the paged KV cache, normally resident in HBM. Keys and values are stored together along the num_combined_kv_heads axis.
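context_lens – the KV (context) length of each sequence, padded to a fixed size. Only the first num_seqs values are valid.
block_tables – the page table: block_tables[i, j] is the index into kv_pages of the j-th page of sequence i. Only the first num_seqs rows are valid.
query_start_loc – the cumulative sum of the effective query lengths, starting at 0, so the queries of sequence i are queries[query_start_loc[i]:query_start_loc[i + 1]]. Similar to context_lens, only the first num_seqs+1 values are valid.
num_seqs – the number of valid sequences in the batch; may be a dynamic (traced) value rather than a Python int.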
softmax_scale – the scale applied to Q@K^T before the softmax.
sliding_window – the sliding window size for the attention.
logits_soft_cap – the logit soft cap for the attention.
mask_value – the value written into masked-out logits by the causal mask.
num_kv_pages_per_block – the number of KV pages processed in one flash-attention block in the Pallas kernel.
num_queries_per_block – the number of query tokens processed in one flash-attention block in the Pallas kernel.
vmem_limit_bytes – the VMEM limit, in bytes, for the Pallas kernel.
- Returns
The attention output, with the same shape as queries: [total_tokens, num_q_heads, head_dim].
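For checking results on small inputs, the semantics can be sketched as a slow NumPy reference that loops over sequences. This is not the Pallas kernel and not an ejkernel API; the name ref_ragged_paged_attention is hypothetical, and several details are assumptions rather than facts taken from this page: keys and values are interleaved along the combined-heads axis (even indices are K, odd indices are V, so num_combined_kv_heads = 2 * num_kv_heads), query head h reads KV head h // (num_q_heads // num_kv_heads), each sequence's new tokens are the last positions of its context, softmax_scale falls back to 1/sqrt(head_dim) when None, the soft cap is cap * tanh(logits / cap), and masked logits are set to -inf instead of mask_value.

```python
import math
from typing import Optional

import numpy as np


def ref_ragged_paged_attention(  # hypothetical reference, not the kernel
    queries,          # [total_tokens, num_q_heads, head_dim]
    kv_pages,         # [num_pages, page_size, 2 * num_kv_heads, head_dim] (assumed layout)
    context_lens,     # [>= num_seqs]
    block_tables,     # [>= num_seqs, pages_per_seq]
    query_start_loc,  # [>= num_seqs + 1]
    num_seqs: int,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    logits_soft_cap: Optional[float] = None,
):
    _, num_q_heads, head_dim = queries.shape
    page_size, num_combined, _ = kv_pages.shape[1:]
    group = num_q_heads // (num_combined // 2)  # query heads per KV head (GQA)
    # Assumed default; the kernel's own default for softmax_scale=None may differ.
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(head_dim)
    out = np.zeros(queries.shape, dtype=np.float32)

    for s in range(int(num_seqs)):
        q_start, q_end = int(query_start_loc[s]), int(query_start_loc[s + 1])
        q_len, kv_len = q_end - q_start, int(context_lens[s])
        n_pages = -(-kv_len // page_size)  # ceil division
        # Gather this sequence's pages and flatten them into one token axis.
        kv = kv_pages[block_tables[s, :n_pages]].reshape(-1, num_combined, head_dim)[:kv_len]
        k = np.repeat(kv[:, 0::2], group, axis=1).astype(np.float32)  # assumed: even = K
        v = np.repeat(kv[:, 1::2], group, axis=1).astype(np.float32)  # assumed: odd = V
        q = queries[q_start:q_end].astype(np.float32)

        logits = np.einsum("qhd,khd->hqk", q, k) * scale
        if logits_soft_cap is not None:
            logits = logits_soft_cap * np.tanh(logits / logits_soft_cap)

        # New tokens occupy the last q_len positions of the context.
        q_pos = np.arange(kv_len - q_len, kv_len)[:, None]
        k_pos = np.arange(kv_len)[None, :]
        allowed = k_pos <= q_pos  # causal
        if sliding_window is not None:
            allowed &= q_pos - k_pos < sliding_window
        logits = np.where(allowed[None], logits, -np.inf)

        # Numerically stable softmax over the key axis.
        logits -= logits.max(axis=-1, keepdims=True)
        probs = np.exp(logits)
        probs /= probs.sum(axis=-1, keepdims=True)
        out[q_start:q_end] = np.einsum("hqk,khd->qhd", probs, v)
    return out.astype(queries.dtype)
```

Looping over sequences in Python keeps the ragged bookkeeping explicit at the cost of speed, which suits a correctness reference on tiny shapes but not serving.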