ejkernel.kernels._xla.ragged_page_attention_v3._interface

Ragged paged attention v3 interface for mixed prefill and decode.

This module provides the public API for the third-generation ragged paged attention, which supports mixed prefill and decode operations in a single batch. It also includes KV cache update functionality.

ejkernel.kernels._xla.ragged_page_attention_v3._interface.ragged_page_attention_v3(queries: Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], keys: Float[jaxlib._jax.Array, 'total_tokens num_kv_heads head_dim'], values: Float[jaxlib._jax.Array, 'total_tokens num_kv_heads head_dim'], kv_cache: Float[jaxlib._jax.Array, 'num_pages page_size num_kv_heads_x2_per_kv_packing kv_packing head_dim_padded'], kv_lens: Int32[jaxlib._jax.Array, 'max_num_seqs'], block_tables: Int32[jaxlib._jax.Array, 'max_num_seqs_times_pages_per_seq'], query_start_loc: Int32[jaxlib._jax.Array, 'max_num_seqs_plus_1'], distribution: Int32[jaxlib._jax.Array, '3'], attention_sink: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, *, softmax_scale: float = 1.0, sliding_window: int | None = None, logits_soft_cap: float | None = None, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, chunk_prefill_size: int | None = None, num_kv_pages_per_block: int | None = None, num_queries_per_block: int | None = None, vmem_limit_bytes: int | None = None) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'num_pages page_size num_kv_heads_x2_per_kv_packing kv_packing head_dim_padded']]

Ragged paged attention that supports mixed prefill and decode.

Parameters
  • queries – the queries of all sequences, concatenated along the token dimension.

  • kv_cache – the paged KV cache, normally stored in HBM.

  • kv_lens – padded KV lengths. Only the first num_seqs values are valid.

  • block_tables – the page indices into the KV cache for each sequence, flattened to one dimension. Only the entries for the first num_seqs sequences are valid.

  • query_start_loc – the cumulative sum of the effective query lengths. As with kv_lens, only a prefix is valid: the first num_seqs+1 values.

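  • num_seqs – the dynamic number of sequences. This name does not appear in the signature above, which instead takes distribution (an Int32 array with 3 elements).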

  • softmax_scale – the scale applied to Q@K^T before the softmax.

  • sliding_window – the sliding window size for the attention.

  • logits_soft_cap – the logit soft cap for the attention.

  • mask_value – the mask value for the causal mask. This name does not appear in the signature above.

  • num_kv_pages_per_block – number of kv pages to be processed in one flash attention block in the pallas kernel.

  • num_queries_per_block – number of queries to be processed in one flash attention block in the pallas kernel.

  • vmem_limit_bytes – the vmem limit for the pallas kernel.
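The padded metadata arguments described above (query_start_loc, the KV lengths, and block_tables) can be illustrated with a small plain-Python sketch. The helper name build_ragged_metadata and its arguments are hypothetical and not part of ejkernel. The row-major flattening of block_tables is an assumption inferred from the shape name 'max_num_seqs_times_pages_per_seq', and the padding values are arbitrary because only the leading entries are valid.

```python
from itertools import accumulate


def build_ragged_metadata(query_lens, kv_lens, page_ids_per_seq,
                          max_num_seqs, pages_per_seq):
    """Pad per-sequence metadata to static shapes (hypothetical helper)."""
    num_seqs = len(query_lens)
    assert num_seqs <= max_num_seqs

    # query_start_loc[i] is the offset of sequence i's first query token in
    # the concatenated queries; entry num_seqs is the total token count.
    starts = [0, *accumulate(query_lens)]
    # Padding repeats the last offset; padded entries are never valid.
    query_start_loc = starts + [starts[-1]] * (max_num_seqs - num_seqs)

    # KV lengths padded to max_num_seqs (padding value 0 is an assumption).
    padded_kv_lens = list(kv_lens) + [0] * (max_num_seqs - num_seqs)

    # Assumed layout: sequence i owns the slice
    # [i * pages_per_seq, (i + 1) * pages_per_seq) of the flat table.
    block_tables = [0] * (max_num_seqs * pages_per_seq)
    for i, pages in enumerate(page_ids_per_seq):
        assert len(pages) <= pages_per_seq
        start = i * pages_per_seq
        block_tables[start:start + len(pages)] = pages

    return query_start_loc, padded_kv_lens, block_tables


# One decode sequence (1 new token) and two prefill sequences (5 and 3 tokens).
query_start_loc, padded_kv_lens, block_tables = build_ragged_metadata(
    query_lens=[1, 5, 3],
    kv_lens=[17, 5, 11],
    page_ids_per_seq=[[4, 9, 2], [7], [0, 5]],
    max_num_seqs=4,
    pages_per_seq=3,
)
# Tokens of sequence i sit at rows
# query_start_loc[i] : query_start_loc[i + 1] of the queries array.
```

Before being passed to the kernel, such lists would typically be converted to int32 arrays, for example with jnp.asarray(x, dtype=jnp.int32).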

Returns

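A tuple of the attention output, with shape (total_tokens, num_q_heads, head_dim), and the KV cache, with the same shape as the kv_cache argument.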
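The softmax_scale, sliding_window, and logits_soft_cap parameters all act on individual attention logits. The sketch below is a plain-Python reference for one query row, written under stated assumptions rather than taken from the kernel: soft capping is assumed to be cap * tanh(logit / cap), and the sliding window is assumed to keep the most recent sliding_window key positions, including the query's own. The names masked_logit and softmax_weights are hypothetical.

```python
import math


def masked_logit(score, q_pos, k_pos, *, softmax_scale=1.0,
                 sliding_window=None, logits_soft_cap=None,
                 mask_value=float("-inf")):
    """Transform one raw Q.K score into a logit (illustrative only)."""
    # Causal mask: a query never attends to a later key position.
    if k_pos > q_pos:
        return mask_value
    # Assumed window semantics: keep keys with q_pos - window < k_pos <= q_pos.
    if sliding_window is not None and k_pos <= q_pos - sliding_window:
        return mask_value
    logit = score * softmax_scale
    # Assumed soft cap: squashes the logit smoothly into (-cap, cap).
    if logits_soft_cap is not None:
        logit = logits_soft_cap * math.tanh(logit / logits_soft_cap)
    return logit


def softmax_weights(scores, q_pos, **kwargs):
    """Attention weights of one query over key positions 0..len(scores)-1."""
    logits = [masked_logit(s, q_pos, k, **kwargs) for k, s in enumerate(scores)]
    peak = max(logits)  # finite, because key position q_pos is never masked
    exps = [math.exp(x - peak) for x in logits]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


# A query at position 3 with a window of 2 attends only to keys 2 and 3.
weights = softmax_weights([0.0] * 6, q_pos=3, sliding_window=2)
capped = masked_logit(100.0, 0, 0, logits_soft_cap=30.0)
```

With uniform scores, the weights in this example are split evenly between the two unmasked keys, and the capped logit stays just below the cap of 30.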