ejkernel.modules.operations.unified_attention
Unified (paged) attention module with automatic platform selection.
This operation wraps the vLLM-style unified attention kernel implemented in ejkernel.kernels and provides a high-level API consistent with other ejkernel.modules.operations entry points.
The unified attention kernel targets inference workloads that use a paged KV cache and ragged query packing (variable-length sequences without padding).
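Ragged packing means per-sequence query tensors are concatenated along the token axis, and their boundaries are recorded as an exclusive prefix sum (the query_start_loc argument described below). A minimal sketch in plain JAX of building both inputs; the helper names build_query_start_loc and pack_queries are hypothetical and not part of ejkernel.

```python
import jax.numpy as jnp


def build_query_start_loc(query_lens):
    """Exclusive prefix sum of per-sequence query lengths: [num_seqs] -> [num_seqs + 1]."""
    query_lens = jnp.asarray(query_lens, dtype=jnp.int32)
    zero = jnp.zeros((1,), dtype=jnp.int32)
    return jnp.concatenate([zero, jnp.cumsum(query_lens, dtype=jnp.int32)])


def pack_queries(per_seq_queries):
    """Concatenate [q_len_i, num_q_heads, head_dim] tensors along the token axis, without padding."""
    return jnp.concatenate(per_seq_queries, axis=0)


# A decode step (1 new token) batched together with a prefill chunk (3 new tokens).
q_a = jnp.ones((1, 8, 64))
q_b = jnp.ones((3, 8, 64))
queries = pack_queries([q_a, q_b])               # shape (4, 8, 64)
query_start_loc = build_query_start_loc([1, 3])  # values [0, 1, 4]
# Sequence i occupies queries[query_start_loc[i]:query_start_loc[i + 1]].
```

Because decode and prefill tokens share one packed tensor, a single kernel launch can serve both phases.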
- class ejkernel.modules.operations.unified_attention.UnifiedAttention
Bases: Kernel[UnifiedAttentionConfig, Array]
vLLM-style unified attention over a paged KV cache (inference-only).
This kernel implements unified paged attention for inference serving workloads, supporting both prefill and decode phases with a single kernel. It is designed to work with vLLM-style paged KV caches where key-value tensors are stored in fixed-size blocks that can be dynamically allocated and mapped per sequence.
The unified attention kernel automatically selects between two execution strategies based on sequence length:
- Short sequences: a 2D grid launch for better parallelism.
- Long sequences: a 3D grid launch with parallel softmax reduction.
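The parallel softmax reduction used by the 3D strategy splits a sequence's KV range into segments, computes a partial softmax per segment independently, and then merges the partial results. A minimal single-query sketch of that split-and-merge in plain JAX, assuming a generic flash-decoding style reduction; the function names and the num_segments argument are illustrative (loosely analogous to num_par_softmax_segments) and this is not ejkernel's kernel code.

```python
import jax
import jax.numpy as jnp


def segmented_attention(q, k, v, num_segments):
    """Attention for one query vector, computed per KV segment and then merged.

    q: [head_dim]; k, v: [kv_len, head_dim]. In this simplified sketch kv_len
    must be divisible by num_segments.
    """
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    seg = k.shape[0] // num_segments
    k_s = k.reshape(num_segments, seg, -1)
    v_s = v.reshape(num_segments, seg, -1)

    # Phase 1 (independent per segment): local max, local exp-sum, unnormalized output.
    logits = jnp.einsum("d,snd->sn", q, k_s) * scale  # [S, seg]
    m_s = logits.max(axis=-1)                          # [S]
    p = jnp.exp(logits - m_s[:, None])                 # [S, seg]
    l_s = p.sum(axis=-1)                               # [S]
    o_s = jnp.einsum("sn,snd->sd", p, v_s)             # [S, head_dim]

    # Phase 2 (reduction): rescale every segment to the global max and combine.
    m = m_s.max()
    w = jnp.exp(m_s - m)                               # [S]
    return (w[:, None] * o_s).sum(axis=0) / (w * l_s).sum()


def full_attention(q, k, v):
    """Single-pass softmax attention used as the comparison baseline."""
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(k @ q * scale) @ v


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (64,))
k = jax.random.normal(kk, (128, 64))
v = jax.random.normal(kv, (128, 64))
out_seg = segmented_attention(q, k, v, num_segments=4)
out_ref = full_attention(q, k, v)
```

Phase 1 has no cross-segment dependency, which is what lets long contexts be spread over an extra grid dimension; phase 2 is a cheap reduction over num_segments partial results.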
- Features:
Paged KV cache support with block tables for memory efficiency
Ragged query packing (variable-length sequences without padding)
Automatic 2D/3D grid selection based on sequence characteristics
Support for GQA/MQA (grouped-query/multi-query attention)
Optional sliding window attention
Optional logits soft capping (Gemma-2 style)
Optional ALiBi position biases
Optional attention sinks for streaming inference
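Under GQA several query heads share one KV head (MQA is the special case num_kv_heads = 1). A minimal sketch of the conventional contiguous grouping, where query head h reads KV head h // (num_q_heads // num_kv_heads), assuming ejkernel follows the same convention as vLLM; the helper names are hypothetical.

```python
import jax.numpy as jnp


def kv_head_for_q_head(q_head, num_q_heads, num_kv_heads):
    """Index of the KV head that a given query head attends with under GQA/MQA."""
    assert num_q_heads % num_kv_heads == 0, "num_q_heads must be a multiple of num_kv_heads"
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


def expand_kv_heads(kv, num_q_heads):
    """Repeat each KV head so that kv[..., h, :] lines up with query head h.

    kv: [tokens, num_kv_heads, head_dim] -> [tokens, num_q_heads, head_dim]
    """
    num_kv_heads = kv.shape[-2]
    return jnp.repeat(kv, num_q_heads // num_kv_heads, axis=-2)


# 8 query heads over 2 KV heads: heads 0-3 share KV head 0, heads 4-7 share KV head 1.
mapping = [kv_head_for_q_head(h, 8, 2) for h in range(8)]
k = jnp.arange(6).reshape(1, 2, 3).astype(jnp.float32)  # 1 token, 2 KV heads, head_dim 3
k_expanded = expand_kv_heads(k, 8)                        # shape (1, 8, 3)
```

A fused kernel would normally index the shared KV head directly instead of materializing the repeated tensor; the repeat is only convenient for a reference computation.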
Example
>>> kernel = UnifiedAttention()
>>> output = kernel.run(
...     queries=packed_queries,      # [total_tokens, num_q_heads, head_dim]
...     key_cache=key_cache,         # [num_blocks, block_size, num_kv_heads, head_dim]
...     value_cache=value_cache,     # [num_blocks, block_size, num_kv_heads, head_dim]
...     kv_lens=context_lengths,     # [num_seqs]
...     block_tables=block_tables,   # [num_seqs, max_blocks_per_seq]
...     query_start_loc=cu_seqlens,  # [num_seqs + 1]
...     cfg=UnifiedAttentionConfig(),
... )
- candidate_cfgs(inv: Invocation[UnifiedAttentionConfig, Array])
Return candidate configurations for autotuning.
This operation exposes the main tuning knobs directly via the config, so autotuning is avoided by default to reduce overhead.
- Parameters
inv – Invocation containing the input arguments and metadata.
- Returns
Empty list (autotuning disabled for this kernel).
- get_impl(cfg: UnifiedAttentionConfig)
Get the platform-specific implementation.
- Parameters
cfg – Configuration specifying platform and backend preferences.
- Returns
Callable implementation function from the kernel registry.
- heuristic_cfg(inv: Invocation[UnifiedAttentionConfig, Array]) → UnifiedAttentionConfig
Generate default configuration based on input characteristics.
Follows vLLM’s decode kernel selection heuristic to determine the sequence length threshold for switching between 2D and 3D grid launches.
- Parameters
inv – Invocation containing the input arguments and metadata.
- Returns
UnifiedAttentionConfig with heuristically determined parameters, including:
seq_threshold_3d: sequence length above which the 3D grid is used
num_par_softmax_segments: number of parallel softmax reduction segments
- Return type
UnifiedAttentionConfig
- run(queries: Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_blocks block_size num_kv_heads head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_blocks block_size num_kv_heads head_dim'], kv_lens: Int32[jaxlib._jax.Array, 'num_seqs'], block_tables: Int32[jaxlib._jax.Array, 'num_seqs max_blocks_per_seq'], query_start_loc: Int32[jaxlib._jax.Array, 'num_seqs_plus_1'], *, softmax_scale: float | None = None, causal: bool = True, sliding_window: int | None = None, logits_soft_cap: float | None = None, alibi_slopes: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, qq_bias: jaxtyping.Float[jaxlib._jax.Array, 'num_query_tokens num_query_tokens'] | None = None, attention_sink: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: UnifiedAttentionConfig) → Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']
Execute unified paged attention.
- Parameters
queries – Packed query tensor of shape [total_tokens, num_q_heads, head_dim]. Contains all query tokens from all sequences concatenated together.
key_cache – Paged key cache of shape [num_blocks, block_size, num_kv_heads, head_dim]. Pre-allocated blocks storing key vectors for all sequences.
value_cache – Paged value cache of shape [num_blocks, block_size, num_kv_heads, head_dim]. Pre-allocated blocks storing value vectors for all sequences.
kv_lens – Context lengths per sequence of shape [num_seqs]. Number of valid KV tokens for each sequence.
block_tables – Block index mapping of shape [num_seqs, max_blocks_per_seq]. Maps logical block indices to physical block indices in the cache.
query_start_loc – Cumulative query token counts of shape [num_seqs + 1]. query_start_loc[i] gives the starting token index for sequence i.
softmax_scale – Scaling factor for attention scores. If None, uses 1/sqrt(head_dim).
causal – Whether to apply causal masking. Default True.
sliding_window – Optional sliding window size for local attention. If provided, each query only attends to the last sliding_window KV positions.
logits_soft_cap – Optional soft cap for attention logits (Gemma-2 style). Applies tanh-based capping: logits = soft_cap * tanh(logits / soft_cap).
alibi_slopes – Optional ALiBi slopes per head of shape [num_q_heads]. Adds position-dependent bias: bias[i,j] = slope * (j - i).
qq_bias – Optional query-query bias of shape [num_query_tokens, num_query_tokens]. Added directly to attention logits between query positions.
attention_sink – Optional attention sink values per head of shape [num_q_heads]. Adds constant attention to the first token for streaming inference stability.
platform – Override platform selection. One of “triton”, “pallas”, “cuda”, “xla”, “auto”.
cfg – Kernel configuration with tuning parameters.
- Returns
Output tensor of shape [total_tokens, num_q_heads, head_dim] with attention results.
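The kernel never materializes a contiguous K/V tensor; it reads cache blocks through block_tables. For debugging, a sketch in plain JAX of what that indirection computes for one sequence, assuming entries of a block_tables row beyond ceil(kv_len / block_size) are padding and are ignored; the helper names gather_sequence_kv and physical_slot are hypothetical.

```python
import jax.numpy as jnp


def gather_sequence_kv(cache, block_table, kv_len):
    """Materialize one sequence's contiguous K (or V) from a paged cache.

    cache:       [num_blocks, block_size, num_kv_heads, head_dim]
    block_table: [max_blocks_per_seq] physical block ids for this sequence
    kv_len:      number of valid tokens (Python int)
    returns:     [kv_len, num_kv_heads, head_dim]
    """
    blocks = cache[block_table]                  # [max_blocks, block_size, H_kv, D]
    flat = blocks.reshape(-1, *cache.shape[2:])  # [max_blocks * block_size, H_kv, D]
    return flat[:kv_len]                         # drop the unused tail of the last block


def physical_slot(block_table, token_pos, block_size):
    """(physical_block, offset) that stores logical token position token_pos."""
    return int(block_table[token_pos // block_size]), token_pos % block_size


# 8 blocks of 4 slots; each slot stores its own flat slot index so the gather is easy to read.
cache = jnp.arange(32.0).reshape(8, 4, 1, 1)
block_table = jnp.array([5, 2, 7], dtype=jnp.int32)  # logical blocks 0, 1, 2 -> physical 5, 2, 7
keys = gather_sequence_kv(cache, block_table, kv_len=10)
# keys[:, 0, 0] holds slots 20-23 (block 5), 8-11 (block 2), then 28-29 (block 7).
```

Since block ids are arbitrary, sequences can grow one block at a time without copying existing KV entries.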
- ejkernel.modules.operations.unified_attention.unified_attention(queries: Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_blocks block_size num_kv_heads head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_blocks block_size num_kv_heads head_dim'], kv_lens: Int32[jaxlib._jax.Array, 'num_seqs'], block_tables: Int32[jaxlib._jax.Array, 'num_seqs max_blocks_per_seq'], query_start_loc: Int32[jaxlib._jax.Array, 'num_seqs_plus_1'], /, *, softmax_scale: float | None = None, causal: bool = True, sliding_window: int | None = None, logits_soft_cap: float | None = None, alibi_slopes: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, qq_bias: jaxtyping.Float[jaxlib._jax.Array, 'num_query_tokens num_query_tokens'] | None = None, attention_sink: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.UnifiedAttentionConfig | None = None) → Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']
Execute unified paged attention with automatic platform selection.
This is the main entry point for vLLM-style unified attention, suitable for inference serving workloads with paged KV caches. It handles both prefill and decode phases efficiently using a single unified kernel.
The function automatically selects a platform (Triton, Pallas, or XLA) based on the available hardware and, unless cfg is given, derives the kernel configuration from heuristics.
- Parameters
queries – Packed query tensor of shape [total_tokens, num_q_heads, head_dim]. All query tokens from all sequences are concatenated together without padding.
key_cache – Paged key cache of shape [num_blocks, block_size, num_kv_heads, head_dim]. Pre-allocated memory blocks storing key vectors. Blocks are shared across sequences and mapped via block_tables.
value_cache – Paged value cache of shape [num_blocks, block_size, num_kv_heads, head_dim]. Pre-allocated memory blocks storing value vectors, with the same layout as key_cache.
kv_lens – Context lengths per sequence of shape [num_seqs]. Specifies how many KV tokens are valid for each sequence.
block_tables – Block index mapping of shape [num_seqs, max_blocks_per_seq]. Maps each sequence’s logical block indices to physical block indices in the cache. For sequence i, block_tables[i, j] gives the physical block index for logical block j.
query_start_loc – Cumulative query token counts of shape [num_seqs + 1]. Defines the boundaries of each sequence in the packed queries tensor. Sequence i’s queries span indices [query_start_loc[i], query_start_loc[i+1]).
softmax_scale – Scaling factor for attention scores. Default: 1/sqrt(head_dim).
causal – Whether to apply causal masking. Default: True. When True, each query can only attend to KV positions at or before its position.
sliding_window – Optional sliding window size for local attention. If provided, each query only attends to the most recent sliding_window KV positions.
logits_soft_cap – Optional soft cap for attention logits (Gemma-2 style). Applies: logits = soft_cap * tanh(logits / soft_cap) to prevent extreme values.
alibi_slopes – Optional ALiBi position bias slopes per head [num_q_heads]. Adds linear position-dependent bias to attention scores.
qq_bias – Optional query-query attention bias [num_query_tokens, num_query_tokens]. Directly added to attention logits for query-to-query interactions.
attention_sink – Optional attention sink values per head [num_q_heads]. Provides stable attention to the first token for streaming/infinite context.
platform – Override automatic platform selection. One of:
- “triton”: Force Triton (GPU)
- “pallas”: Force Pallas (TPU/GPU)
- “xla”: Force XLA fallback
- “cuda”: Force CUDA
- “auto”: Automatic selection (default)
- None: Same as “auto”
cfg – Optional configuration override. If None, uses heuristic-based defaults.
- Returns
Output tensor of shape [total_tokens, num_q_heads, head_dim] containing the attention results for all sequences, packed in the same order as queries.
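When validating a backend on small inputs, a slow but readable reference helps. A sketch in plain JAX of what the parameters above describe, under loudly stated assumptions: each sequence's new query tokens are the last q_len positions of its kv_len context (vLLM's convention), soft capping is applied before the ALiBi bias, and sliding_window counts the query's own position. reference_unified_attention is a hypothetical helper, not ejkernel's XLA implementation, and it omits qq_bias and attention_sink.

```python
import math

import jax
import jax.numpy as jnp


def reference_unified_attention(queries, key_cache, value_cache, kv_lens, block_tables,
                                query_start_loc, *, softmax_scale=None, causal=True,
                                sliding_window=None, logits_soft_cap=None, alibi_slopes=None):
    """Naive per-sequence reference (Python loop, not jittable; small inputs only)."""
    num_q_heads, head_dim = queries.shape[1], queries.shape[2]
    num_kv_heads = key_cache.shape[2]
    group = num_q_heads // num_kv_heads
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(head_dim)
    outputs = []
    for s in range(len(kv_lens)):
        q_lo, q_hi = int(query_start_loc[s]), int(query_start_loc[s + 1])
        q_len, kv_len = q_hi - q_lo, int(kv_lens[s])
        q = queries[q_lo:q_hi]                                              # [q_len, Hq, D]
        # Gather this sequence's KV through its block table, then drop padding slots.
        k = key_cache[block_tables[s]].reshape(-1, num_kv_heads, head_dim)[:kv_len]
        v = value_cache[block_tables[s]].reshape(-1, num_kv_heads, head_dim)[:kv_len]
        k = jnp.repeat(k, group, axis=1)                                    # GQA: [kv_len, Hq, D]
        v = jnp.repeat(v, group, axis=1)

        logits = jnp.einsum("qhd,khd->hqk", q, k) * scale                  # [Hq, q_len, kv_len]
        if logits_soft_cap is not None:
            logits = logits_soft_cap * jnp.tanh(logits / logits_soft_cap)
        # Assumption: the q_len new tokens occupy the last q_len positions of the context.
        q_pos = jnp.arange(q_len) + (kv_len - q_len)
        k_pos = jnp.arange(kv_len)
        rel = k_pos[None, :] - q_pos[:, None]                               # j - i
        if alibi_slopes is not None:
            logits = logits + alibi_slopes[:, None, None] * rel[None]
        mask = jnp.ones(rel.shape, dtype=bool)
        if causal:
            mask = mask & (rel <= 0)
        if sliding_window is not None:
            mask = mask & (-rel < sliding_window)
        logits = jnp.where(mask[None], logits, -jnp.inf)
        probs = jax.nn.softmax(logits, axis=-1)
        outputs.append(jnp.einsum("hqk,khd->qhd", probs, v))
    return jnp.concatenate(outputs, axis=0)


# Two sequences: 2 new tokens over a 32-token context, 3 new tokens over a 48-token context.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
queries = jax.random.normal(kq, (5, 8, 64))
key_cache = jax.random.normal(kk, (8, 16, 2, 64))
value_cache = jax.random.normal(kv, (8, 16, 2, 64))
kv_lens = jnp.array([32, 48], dtype=jnp.int32)
block_tables = jnp.array([[0, 1, 0], [2, 3, 4]], dtype=jnp.int32)  # row 0 padded with block 0
query_start_loc = jnp.array([0, 2, 5], dtype=jnp.int32)
out = reference_unified_attention(queries, key_cache, value_cache, kv_lens,
                                  block_tables, query_start_loc)
```

Comparing a backend against such a reference with jnp.allclose on small random inputs catches most indexing mistakes in block tables, query offsets, and masks.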
Example
>>> import jax.numpy as jnp
>>> from ejkernel.modules import unified_attention
>>>
>>> # Setup for 2 sequences with different lengths
>>> total_tokens = 5  # seq1 has 2 query tokens, seq2 has 3
>>> queries = jnp.ones((total_tokens, 8, 64))  # 8 query heads, head_dim 64
>>> key_cache = jnp.ones((100, 16, 2, 64))  # 100 blocks, block_size=16, 2 KV heads
>>> value_cache = jnp.ones((100, 16, 2, 64))
>>> kv_lens = jnp.array([32, 48], dtype=jnp.int32)  # context lengths: 2 and 3 blocks
>>> # block_tables must be rectangular: pad seq1's row with an unused block id (0).
>>> block_tables = jnp.array([[0, 1, 0], [2, 3, 4]], dtype=jnp.int32)
>>> query_start_loc = jnp.array([0, 2, 5], dtype=jnp.int32)  # [0, 2) for seq1, [2, 5) for seq2
>>>
>>> output = unified_attention(
...     queries, key_cache, value_cache,
...     kv_lens, block_tables, query_start_loc,
...     causal=True,
... )
Note
This kernel is optimized for inference and does not support backward passes. For training, use flash_attention instead.