ejkernel.kernels._xla.ragged_page_attention_v2._interface
Ragged paged attention v2 interface for variable-length batches.
This module provides the public API for paged attention over ragged (variable-length) sequences. It supports multiple query tokens per sequence and uses a FlashAttention-style online softmax for memory-efficient computation.
- ejkernel.kernels._xla.ragged_page_attention_v2._interface.ragged_page_attention_v2(queries: jaxtyping.Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], kv_pages: jaxtyping.Float[jaxlib._jax.Array, 'num_pages page_size num_combined_kv_heads head_dim'], context_lens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs'], block_tables: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs pages_per_seq'], query_start_loc: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'], num_seqs: jax.jaxlib._jax.Array | int, *, softmax_scale: float | None = None, logits_soft_cap: float | None = None, compute_dtype: typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType] = <class 'jax.numpy.bfloat16'>, optimized: bool = False, sliding_window: int | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, mask_value: float | None = None, num_kv_pages_per_block: int | None = None, num_queries_per_block: int | None = None, vmem_limit_bytes: int | None = None, num_warps: int | None = None, num_stages: int | None = None) -> Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']
Performs paged attention for batched, ragged sequences with optional attention sinks.
This function implements a FlashAttention-style algorithm to compute attention for multiple sequences of varying lengths. The Key-Value (KV) cache for these sequences is stored in non-contiguous memory blocks called “pages”. This is a common technique in LLM inference servers to manage memory efficiently.
Attention is computed by iterating over blocks of queries and, for each query block, over the relevant blocks of key-value pages. An online softmax computes the attention output in a single pass over the KV data, so the full matrix of attention scores is never materialized.
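The single-pass idea can be illustrated with a minimal sketch. It is not the kernel's implementation: it assumes a single head, no masking, no paging, and a hypothetical helper name `online_softmax_attention`. It only shows how a running maximum and a running denominator let each KV block be visited once.

```python
import jax
import jax.numpy as jnp


def online_softmax_attention(q, k, v, block_size):
    """Single-head attention over KV blocks (hypothetical helper, no masking).

    q: [num_q, head_dim]; k and v: [num_kv, head_dim].
    """
    scale = q.shape[-1] ** -0.5
    m = jnp.full((q.shape[0],), -jnp.inf)  # running row-wise max of the scores
    l = jnp.zeros((q.shape[0],))           # running softmax denominator
    acc = jnp.zeros(q.shape)               # running unnormalized output

    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        s = (q @ k_blk.T) * scale           # scores for this block: [num_q, block]
        m_new = jnp.maximum(m, s.max(axis=-1))
        correction = jnp.exp(m - m_new)     # rescales results from earlier blocks
        p = jnp.exp(s - m_new[:, None])
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ v_blk
        m = m_new

    return acc / l[:, None]


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (4, 8))
k = jax.random.normal(key_k, (10, 8))
v = jax.random.normal(key_v, (10, 8))

blockwise = online_softmax_attention(q, k, v, block_size=4)
# Reference: ordinary softmax attention that materializes all scores at once.
reference = jax.nn.softmax((q @ k.T) * 8 ** -0.5, axis=-1) @ v
```

The blockwise result should match the reference up to floating-point tolerance, while only one `[num_q, block_size]` slice of scores exists at any time.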
- Parameters
queries – The query tensor for all sequences, concatenated together. Shape: [total_query_tokens, num_q_heads, head_size].
kv_pages – The paged key/value cache, with keys and values combined along the head axis. Shape: [num_pages, page_size, num_combined_kv_heads, head_size].
context_lens – The total length of each sequence in the KV cache. Shape: [num_seqs].
block_tables – A map from each sequence to its list of physical page indices in the KV cache. Shape: [num_seqs, max_pages_per_sequence].
query_start_loc – The cumulative sum of query lengths for each sequence, used to index into the queries tensor. Shape: [num_seqs + 1].
num_seqs – The number of sequences in the batch, given as a Python int or as an int32 array of shape [1].
softmax_scale – The scaling factor to apply to the attention scores before the softmax operation (typically 1 / sqrt(head_size)).
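logits_soft_cap – An optional soft cap on the attention scores, applied with tanh before the softmax (soft capping is commonly computed as logits_soft_cap * tanh(scores / logits_soft_cap)).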
compute_dtype – The dtype to use for computation (default: bfloat16).
optimized – Whether to use the optimized implementation (default: False).
sliding_window – Optional sliding window size for local attention.
softmax_aux – Optional attention sink logits of shape [num_q_heads]. Each query head has a single sink value that participates in softmax normalization but does not contribute to the output, allowing the model to absorb probability mass.
- Returns
The attention output tensor, with the same shape and dtype as queries. Shape: [total_query_tokens, num_q_heads, head_size].
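As an illustration of the ragged layout described by the parameters above, the following sketch builds `query_start_loc` from per-sequence query lengths and gathers one sequence's cached tokens through `block_tables`. All values are hypothetical toy data, `gather_sequence_kv` is a helper invented for this example, padding unused block-table entries with 0 is an assumption, and the way keys and values are arranged within the combined head axis is not shown.

```python
import jax.numpy as jnp

# Hypothetical batch: three sequences contributing 1, 3 and 2 query tokens.
query_lens = jnp.array([1, 3, 2], dtype=jnp.int32)
query_start_loc = jnp.concatenate(
    [jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(query_lens)]
)
# Sequence i owns queries[query_start_loc[i]:query_start_loc[i + 1]].

page_size, num_pages = 4, 6
# Toy cache with one head and head_dim 1; each slot stores page_id * 10 + slot
# so that gathered tokens are easy to recognize.
kv_pages = (
    jnp.arange(num_pages)[:, None] * 10 + jnp.arange(page_size)[None, :]
).astype(jnp.float32)[..., None, None]  # [num_pages, page_size, 1, 1]

context_lens = jnp.array([5, 9, 2], dtype=jnp.int32)
# Row i lists the physical pages of sequence i; unused entries are padded
# with 0 here (an assumption made for this sketch).
block_tables = jnp.array([[2, 5, 0], [4, 1, 3], [0, 0, 0]], dtype=jnp.int32)


def gather_sequence_kv(kv_pages, block_tables, context_lens, seq_idx):
    """Return the cached tokens of one sequence in logical order."""
    pages = kv_pages[block_tables[seq_idx]]       # [pages_per_seq, page_size, heads, dim]
    tokens = pages.reshape(-1, *pages.shape[2:])  # flatten pages into a token axis
    return tokens[: int(context_lens[seq_idx])]   # drop slots past the context length


seq0 = gather_sequence_kv(kv_pages, block_tables, context_lens, 0)
```

Here sequence 0 has five cached tokens spread over physical pages 2 and 5, so the gather returns the four slots of page 2 followed by the first slot of page 5.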
Note
Attention sinks are learnable parameters that participate in the softmax normalization but do not produce output. They let the model “dump” attention probability mass instead of forcing it onto real tokens, which can improve numerical stability and model behavior.
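A minimal sketch of the sink mechanism, assuming a single head and a hypothetical helper `softmax_with_sink` (not part of this module): the sink logit is appended as an extra column for normalization and then dropped, so the remaining weights sum to less than one.

```python
import jax
import jax.numpy as jnp


def softmax_with_sink(scores, sink_logit):
    """Softmax over keys with one extra sink logit in the denominator only.

    scores: [num_q, num_kv] for a single head; sink_logit: scalar for that head.
    """
    sink = jnp.broadcast_to(sink_logit, scores.shape[:-1] + (1,))
    full = jnp.concatenate([scores, sink], axis=-1)
    probs = jax.nn.softmax(full, axis=-1)
    return probs[..., :-1]  # drop the sink column: it has no value vector


scores = jnp.array([[0.0, 0.0], [1.0, -1.0]])
weights = softmax_with_sink(scores, jnp.float32(0.0))
```

With a sink logit of 0, the first row normalizes over three equal logits, so each real key receives weight 1/3 and the remaining 1/3 is absorbed by the sink.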