ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._interface

ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._interface.ragged_page_attention_v2(queries: ~jaxtyping.Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], kv_pages: ~jaxtyping.Float[jaxlib._jax.Array, 'num_pages page_size num_combined_kv_heads head_dim'], context_lens: ~jaxtyping.Int[jaxlib._jax.Array, 'num_seqs'], block_tables: ~jaxtyping.Int[jaxlib._jax.Array, 'num_seqs pages_per_seq'], query_start_loc: ~jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'], num_seqs: jax.jaxlib._jax.Array | int, *, softmax_scale: float | None = None, logits_soft_cap: float | None = None, compute_dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.bfloat16'>, optimized: bool = False, sliding_window: int | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, mask_value: float | None = None, num_kv_pages_per_block: int | None = None, num_queries_per_block: int | None = None, vmem_limit_bytes: int | None = None, num_warps: int | None = None, num_stages: int | None = None) → Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']

Ragged paged attention that supports mixed prefill and decode.
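To illustrate the ragged layout, the sketch below packs a mixed batch (a 5-token prefill followed by a 1-token decode step) into the flat index arrays listed under Parameters. The helper name pack_ragged_batch, the sequential page allocator, and the zero-filling of unused (padded) entries are assumptions made for illustration only; they are not part of ejkernel.

```python
import numpy as np


def pack_ragged_batch(query_lens, context_lens, page_size, max_num_seqs, pages_per_seq):
    """Hypothetical helper: build query_start_loc, padded context_lens and block_tables."""
    num_seqs = len(query_lens)

    # Cumulative query lengths starting at 0: sequence i owns rows
    # query_start_loc[i] .. query_start_loc[i + 1] - 1 of the flat queries array.
    query_start_loc = np.zeros(max_num_seqs + 1, dtype=np.int32)
    query_start_loc[1:num_seqs + 1] = np.cumsum(query_lens)

    # KV length per sequence, zero-padded up to max_num_seqs (padding is ignored).
    padded_context_lens = np.zeros(max_num_seqs, dtype=np.int32)
    padded_context_lens[:num_seqs] = context_lens

    # Assumed allocator: hand out physical pages sequentially to each sequence.
    block_tables = np.zeros((max_num_seqs, pages_per_seq), dtype=np.int32)
    next_page = 0
    for i, ctx in enumerate(context_lens):
        n = -(-ctx // page_size)  # pages needed for this sequence (ceil division)
        assert n <= pages_per_seq, "sequence does not fit in pages_per_seq pages"
        block_tables[i, :n] = np.arange(next_page, next_page + n)
        next_page += n

    return query_start_loc, padded_context_lens, block_tables, num_seqs


qsl, ctx, tables, n = pack_ragged_batch(
    query_lens=[5, 1], context_lens=[5, 9], page_size=4, max_num_seqs=4, pages_per_seq=3
)
```

With these inputs, query_start_loc is [0, 5, 6, 0, 0], so total_tokens is query_start_loc[num_seqs] = 6; the prefill sequence uses pages 0 and 1 and the decode sequence uses pages 2, 3 and 4.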

Parameters
  • queries – the queries of all sequences, concatenated along the token axis; shape [total_tokens, num_q_heads, head_dim].

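  • kv_pages – the paged KV cache, normally resident in HBM; shape [num_pages, page_size, num_combined_kv_heads, head_dim], where keys and values share the combined head axis.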

  • context_lens – the KV length of each sequence, padded to a fixed size. Only the first num_seqs values are valid.

  • block_tables – the per-sequence page table: block_tables[i, j] is the index in kv_pages of the j-th page of sequence i. Only the first num_seqs rows are valid.

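  • query_start_loc – the cumulative sum of the per-sequence query lengths, starting at 0, so the queries of sequence i occupy rows query_start_loc[i] through query_start_loc[i + 1] - 1 of queries. As with context_lens, only the first num_seqs + 1 values are valid.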

  • num_seqs – the number of valid sequences in the batch; it may be passed as an array, so it can vary at runtime.

  • softmax_scale – the scale applied to the Q @ K^T logits before the softmax.

  • sliding_window – the sliding-window size for local attention; None disables windowing.

  • logits_soft_cap – the soft cap applied to the attention logits (conventionally soft_cap * tanh(logits / soft_cap)); None disables capping.

  • mask_value – the value assigned to masked-out logits, for example future positions under the causal mask.

  • num_kv_pages_per_block – the number of KV pages processed in one flash-attention block of the Pallas kernel.

  • num_queries_per_block – the number of query tokens processed in one flash-attention block of the Pallas kernel.

  • vmem_limit_bytes – the VMEM limit, in bytes, for the Pallas kernel.
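A pure NumPy reference can make the meaning of these parameters concrete. The sketch below is hypothetical and is not ejkernel's implementation. It assumes that keys and values are interleaved along the combined head axis (K at even indices, V at odd indices, the convention of JAX's Pallas TPU ragged paged attention reference, assumed here to carry over), that each sequence's queries are the last q_len positions of its context, that softmax_scale defaults to 1/sqrt(head_dim), and that mask_value defaults to -1e30.

```python
import math

import numpy as np


def ragged_paged_attention_reference(
    queries,          # [total_tokens, num_q_heads, head_dim]
    kv_pages,         # [num_pages, page_size, num_combined_kv_heads, head_dim]
    context_lens,     # [max_num_seqs]
    block_tables,     # [max_num_seqs, pages_per_seq]
    query_start_loc,  # [max_num_seqs + 1]
    num_seqs,
    softmax_scale=None,
    sliding_window=None,
    logits_soft_cap=None,
    mask_value=-1e30,  # assumed default, not taken from ejkernel
):
    """Hypothetical NumPy reference; loops over sequences for clarity, not speed."""
    _, num_q_heads, head_dim = queries.shape
    num_combined_kv_heads = kv_pages.shape[2]
    page_size = kv_pages.shape[1]
    num_kv_heads = num_combined_kv_heads // 2  # assumed K/V interleaving
    group = num_q_heads // num_kv_heads        # query heads sharing one KV head
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)  # assumed default

    out = np.zeros(queries.shape, dtype=np.float32)
    for i in range(int(num_seqs)):
        q_start, q_end = int(query_start_loc[i]), int(query_start_loc[i + 1])
        q_len, kv_len = q_end - q_start, int(context_lens[i])

        # Gather this sequence's pages through its block table, then flatten.
        n_pages = -(-kv_len // page_size)  # ceil division
        pages = kv_pages[block_tables[i, :n_pages]]
        kv = pages.reshape(-1, num_combined_kv_heads, head_dim)[:kv_len]
        k = np.repeat(kv[:, 0::2], group, axis=1)  # [kv_len, num_q_heads, head_dim]
        v = np.repeat(kv[:, 1::2], group, axis=1)
        q = queries[q_start:q_end].astype(np.float32)

        logits = np.einsum("qhd,khd->hqk", q, k) * softmax_scale
        if logits_soft_cap is not None:
            logits = logits_soft_cap * np.tanh(logits / logits_soft_cap)

        # Causal mask: query j sits at absolute position kv_len - q_len + j.
        q_pos = np.arange(kv_len - q_len, kv_len)[:, None]
        kv_pos = np.arange(kv_len)[None, :]
        masked = kv_pos > q_pos
        if sliding_window is not None:
            masked |= kv_pos <= q_pos - sliding_window
        logits = np.where(masked[None], mask_value, logits)

        # Numerically stable softmax over the KV axis.
        logits -= logits.max(axis=-1, keepdims=True)
        probs = np.exp(logits)
        probs /= probs.sum(axis=-1, keepdims=True)
        out[q_start:q_end] = np.einsum("hqk,khd->qhd", probs, v)
    return out
```

The per-sequence loop is chosen for readability; the Pallas kernel instead tiles queries and KV pages into blocks whose sizes are set by num_queries_per_block and num_kv_pages_per_block.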

Returns

The attention output, with the same shape as queries: [total_tokens, num_q_heads, head_dim].
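Because the output keeps the ragged token layout of queries, query_start_loc can be used to recover per-sequence results. A minimal sketch with hypothetical helper names (split_by_sequence, last_token_rows), shown with NumPy arrays standing in for the JAX output:

```python
import numpy as np


def split_by_sequence(output, query_start_loc, num_seqs):
    """Hypothetical helper: one [q_len_i, num_q_heads, head_dim] slice per sequence."""
    qsl = np.asarray(query_start_loc)
    return [output[qsl[i]:qsl[i + 1]] for i in range(int(num_seqs))]


def last_token_rows(output, query_start_loc, num_seqs):
    """Hypothetical helper: the output row of each sequence's final query token."""
    qsl = np.asarray(query_start_loc)
    return output[qsl[1:int(num_seqs) + 1] - 1]


# A 5-token prefill followed by a 1-token decode, with 2 heads of size 3.
output = np.arange(6 * 2 * 3, dtype=np.float32).reshape(6, 2, 3)
parts = split_by_sequence(output, [0, 5, 6, 0, 0], num_seqs=2)
last = last_token_rows(output, [0, 5, 6, 0, 0], num_seqs=2)
```

Padded entries past query_start_loc[num_seqs] are never read, consistent with only the first num_seqs + 1 values being valid.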