ejkernel.modules.operations.unified_attention

Unified (paged) attention module with automatic platform selection.

This operation wraps the vLLM-style unified attention kernel implemented in ejkernel.kernels and provides a high-level API consistent with other ejkernel.modules.operations entry points.

The unified attention kernel targets inference workloads that use a paged KV cache and ragged query packing (variable-length sequences without padding).
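As a concrete picture of ragged query packing, the NumPy sketch below concatenates per-sequence queries along the token axis and builds `query_start_loc` as 0 followed by the running totals of the per-sequence query lengths. It relies only on the layout described on this page; `pack_queries` and `unpack` are hypothetical helpers for illustration, not part of ejkernel.

```python
import numpy as np


def pack_queries(per_seq_queries):
    """Concatenate per-sequence [q_len, num_q_heads, head_dim] arrays without padding.

    Returns the packed [total_tokens, num_q_heads, head_dim] array and an int32
    query_start_loc of shape [num_seqs + 1]: 0 followed by running totals of q_len.
    """
    q_lens = [q.shape[0] for q in per_seq_queries]
    query_start_loc = np.zeros(len(q_lens) + 1, dtype=np.int32)
    query_start_loc[1:] = np.cumsum(q_lens)
    return np.concatenate(per_seq_queries, axis=0), query_start_loc


def unpack(packed, query_start_loc):
    """Split a packed [total_tokens, ...] array back into per-sequence arrays."""
    return [
        packed[query_start_loc[i]:query_start_loc[i + 1]]
        for i in range(len(query_start_loc) - 1)
    ]


# A 1-token decode step packed together with a 3-token prefill chunk.
rng = np.random.default_rng(0)
seq_a = rng.standard_normal((1, 8, 64)).astype(np.float32)
seq_b = rng.standard_normal((3, 8, 64)).astype(np.float32)
packed, query_start_loc = pack_queries([seq_a, seq_b])
print(packed.shape, query_start_loc)  # (4, 8, 64) [0 1 4]
```

The same `query_start_loc` is what lets the packed output be split back into per-sequence results with `unpack`.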

class ejkernel.modules.operations.unified_attention.UnifiedAttention

Bases: Kernel[UnifiedAttentionConfig, Array]

vLLM-style unified attention over a paged KV cache (inference-only).

This kernel implements unified paged attention for inference serving workloads, supporting both prefill and decode phases with a single kernel. It is designed to work with vLLM-style paged KV caches where key-value tensors are stored in fixed-size blocks that can be dynamically allocated and mapped per sequence.
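To make the block-table indirection concrete, here is a NumPy sketch that reads one sequence's keys (or values) back out of a paged cache: logical token `t` lives in physical block `block_tables[s, t // block_size]` at slot `t % block_size`. A kernel would normally do this lookup block by block in on-chip memory instead of materializing the gathered tensor; `gather_kv` is a hypothetical reference helper.

```python
import numpy as np


def gather_kv(cache, block_table_row, kv_len):
    """Read one sequence's first kv_len tokens out of a paged cache.

    cache:           [num_blocks, block_size, num_kv_heads, head_dim]
    block_table_row: [max_blocks_per_seq] physical block ids owned by the sequence
    Returns:         [kv_len, num_kv_heads, head_dim]
    """
    block_size = cache.shape[1]
    positions = np.arange(kv_len)
    physical_blocks = block_table_row[positions // block_size]  # logical -> physical block
    offsets = positions % block_size                            # slot within that block
    return cache[physical_blocks, offsets]


num_blocks, block_size, num_kv_heads, head_dim = 8, 4, 2, 16
key_cache = np.random.default_rng(0).standard_normal(
    (num_blocks, block_size, num_kv_heads, head_dim))
block_table_row = np.array([5, 2, 7], dtype=np.int32)  # logical blocks 0, 1, 2
keys = gather_kv(key_cache, block_table_row, kv_len=10)
print(keys.shape)  # (10, 2, 16)
```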

The unified attention kernel automatically selects between two execution strategies based on sequence length:

  • For short sequences: a 2D grid launch, for better parallelism

  • For long sequences: a 3D grid launch with parallel softmax reduction

The switch point is the seq_threshold_3d field of UnifiedAttentionConfig, and the number of reduction segments is num_par_softmax_segments.
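The parallel softmax reduction relies on the fact that a softmax over the KV axis can be computed in independent segments and merged exactly afterwards. The NumPy sketch below shows that merge for a single query row using the standard max-rescaling (log-sum-exp) trick. It illustrates the math only, not ejkernel's kernel; the segment count of 4 is arbitrary and plays the role of num_par_softmax_segments.

```python
import numpy as np


def segment_partials(scores, values):
    """Softmax statistics for one segment of a single query row.

    scores: [seg_len] logits; values: [seg_len, head_dim].
    Returns (segment max, sum of exponentials, unnormalized weighted values).
    """
    seg_max = scores.max()
    p = np.exp(scores - seg_max)
    return seg_max, p.sum(), p @ values


def merge_segments(partials):
    """Combine per-segment statistics into the exact softmax-weighted output."""
    maxes = np.array([seg_max for seg_max, _, _ in partials])
    rescale = np.exp(maxes - maxes.max())  # bring each segment to the common max
    denom = sum(r * seg_sum for r, (_, seg_sum, _) in zip(rescale, partials))
    numer = sum(r * seg_out for r, (_, _, seg_out) in zip(rescale, partials))
    return numer / denom


rng = np.random.default_rng(0)
scores = rng.standard_normal(64)        # logits of one query against 64 KV positions
values = rng.standard_normal((64, 16))  # matching value vectors
partials = [segment_partials(s, v)
            for s, v in zip(np.split(scores, 4), np.split(values, 4))]
merged = merge_segments(partials)

weights = np.exp(scores - scores.max())
weights /= weights.sum()
print(np.allclose(merged, weights @ values))  # True
```

Because each segment only needs its own max, sum, and weighted values, segments can be computed by separate programs and reduced in a second pass.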

Features:
  • Paged KV cache support with block tables for memory efficiency

  • Ragged query packing (variable-length sequences without padding)

  • Automatic 2D/3D grid selection based on sequence characteristics

  • Support for GQA/MQA (grouped-query and multi-query attention)

  • Optional sliding window attention

  • Optional logits soft capping (Gemma-2 style)

  • Optional ALiBi position biases

  • Optional attention sinks for streaming inference
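Under GQA/MQA, several query heads share one KV head. The sketch below assumes the contiguous grouping used by common implementations (query head `h` reads KV head `h // (num_q_heads // num_kv_heads)`); this page does not spell out the head layout, so treat it as an assumption, and `kv_head_for_q_head` as a hypothetical helper.

```python
import numpy as np


def kv_head_for_q_head(q_head, num_q_heads, num_kv_heads):
    """KV head read by a query head, assuming contiguous head groups."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads  # query heads sharing one KV head
    return q_head // group_size


print([kv_head_for_q_head(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
print([kv_head_for_q_head(h, 8, 1) for h in range(8)])  # MQA: [0, 0, 0, 0, 0, 0, 0, 0]

# Equivalent dense view: repeat each KV head group_size times along the head axis.
k = np.random.default_rng(0).standard_normal((10, 2, 16))  # [kv_len, num_kv_heads, head_dim]
k_expanded = np.repeat(k, 8 // 2, axis=1)                  # [kv_len, num_q_heads, head_dim]
```

A paged kernel avoids the `np.repeat` copy by indexing the shared KV head directly, which is where GQA/MQA saves memory bandwidth.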

Example

>>> kernel = UnifiedAttention()
>>> output = kernel.run(
...     queries=packed_queries,      # [total_tokens, num_q_heads, head_dim]
...     key_cache=key_cache,         # [num_blocks, block_size, num_kv_heads, head_dim]
...     value_cache=value_cache,     # [num_blocks, block_size, num_kv_heads, head_dim]
...     kv_lens=context_lengths,     # [num_seqs]
...     block_tables=block_tables,   # [num_seqs, max_blocks_per_seq]
...     query_start_loc=cu_seqlens,  # [num_seqs + 1]
...     cfg=UnifiedAttentionConfig(),
... )

candidate_cfgs(inv: Invocation[UnifiedAttentionConfig, Array])

Return candidate configurations for autotuning.

This operation exposes the main tuning knobs directly through the config, so autotuning is disabled by default to avoid its overhead.

Parameters

inv – Invocation containing the input arguments and metadata.

Returns

Empty list (autotuning disabled for this kernel).

get_impl(cfg: UnifiedAttentionConfig)

Get the platform-specific implementation.

Parameters

cfg – Configuration specifying platform and backend preferences.

Returns

Callable implementation function from the kernel registry.

heuristic_cfg(inv: Invocation[UnifiedAttentionConfig, Array]) → UnifiedAttentionConfig

Generate default configuration based on input characteristics.

Follows vLLM’s decode kernel selection heuristic to determine the sequence length threshold for switching between 2D and 3D grid launches.

Parameters

inv – Invocation containing the input arguments and metadata.

Returns

UnifiedAttentionConfig with heuristically determined parameters:

  • seq_threshold_3d: Sequence length above which the 3D grid is used

  • num_par_softmax_segments: Number of parallel softmax reduction segments

Return type

UnifiedAttentionConfig

run(queries: Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_blocks block_size num_kv_heads head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_blocks block_size num_kv_heads head_dim'], kv_lens: Int32[jaxlib._jax.Array, 'num_seqs'], block_tables: Int32[jaxlib._jax.Array, 'num_seqs max_blocks_per_seq'], query_start_loc: Int32[jaxlib._jax.Array, 'num_seqs_plus_1'], *, softmax_scale: float | None = None, causal: bool = True, sliding_window: int | None = None, logits_soft_cap: float | None = None, alibi_slopes: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, qq_bias: jaxtyping.Float[jaxlib._jax.Array, 'num_query_tokens num_query_tokens'] | None = None, attention_sink: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: UnifiedAttentionConfig) → Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']

Execute unified paged attention.

Parameters
  • queries – Packed query tensor of shape [total_tokens, num_q_heads, head_dim]. Contains all query tokens from all sequences concatenated together.

  • key_cache – Paged key cache of shape [num_blocks, block_size, num_kv_heads, head_dim]. Pre-allocated blocks storing key vectors for all sequences.

  • value_cache – Paged value cache of shape [num_blocks, block_size, num_kv_heads, head_dim]. Pre-allocated blocks storing value vectors for all sequences.

  • kv_lens – Context lengths per sequence of shape [num_seqs]. Number of valid KV tokens for each sequence.

  • block_tables – Block index mapping of shape [num_seqs, max_blocks_per_seq]. Maps logical block indices to physical block indices in the cache.

  • query_start_loc – Cumulative query token counts of shape [num_seqs + 1]. query_start_loc[i] gives the starting token index for sequence i.

  • softmax_scale – Scaling factor for attention scores. If None, uses 1/sqrt(head_dim).

  • causal – Whether to apply causal masking. Default True.

  • sliding_window – Optional sliding window size for local attention. If provided, each query only attends to the last sliding_window KV positions.

  • logits_soft_cap – Optional soft cap for attention logits (Gemma-2 style). Applies tanh-based capping: logits = soft_cap * tanh(logits / soft_cap).

  • alibi_slopes – Optional ALiBi slopes per head of shape [num_q_heads]. Adds position-dependent bias: bias[i,j] = slope * (j - i).

  • qq_bias – Optional query-query bias of shape [num_query_tokens, num_query_tokens]. Added directly to attention logits between query positions.

  • attention_sink – Optional attention sink values per head of shape [num_q_heads]. Adds constant attention to the first token for streaming inference stability.

  • platform – Override platform selection. One of “triton”, “pallas”, “cuda”, “xla”, “auto”.

  • cfg – Kernel configuration with tuning parameters.
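The masking and bias options above can be summarized by a dense NumPy reference for a single sequence, which is handy for checking kernel outputs on small inputs. This is a sketch under stated assumptions, not ejkernel's implementation: it assumes a sequence's queries are the last q_len positions of its kv_len-token context, that soft capping is applied to the scaled logits before the ALiBi bias and masks, and that GQA heads are already expanded; qq_bias and attention_sink are omitted. The function name is hypothetical.

```python
import numpy as np


def reference_sequence_attention(q, k, v, *, softmax_scale=None, causal=True,
                                 sliding_window=None, logits_soft_cap=None,
                                 alibi_slopes=None):
    """Dense NumPy reference for ONE sequence (no paging; KV heads already expanded).

    q: [q_len, num_heads, head_dim]; k, v: [kv_len, num_heads, head_dim].
    Assumes the q_len queries are the last q_len tokens of the kv_len-token
    context, i.e. query i sits at absolute position kv_len - q_len + i.
    """
    q_len, _, head_dim = q.shape
    kv_len = k.shape[0]
    if softmax_scale is None:
        softmax_scale = 1.0 / np.sqrt(head_dim)

    logits = np.einsum("qhd,khd->hqk", q, k) * softmax_scale  # [heads, q_len, kv_len]
    if logits_soft_cap is not None:
        logits = logits_soft_cap * np.tanh(logits / logits_soft_cap)

    q_pos = np.arange(kv_len - q_len, kv_len)[:, None]  # absolute query positions
    k_pos = np.arange(kv_len)[None, :]
    if alibi_slopes is not None:
        logits = logits + alibi_slopes[:, None, None] * (k_pos - q_pos)  # slope * (j - i)

    mask = np.ones((q_len, kv_len), dtype=bool)
    if causal:
        mask &= k_pos <= q_pos                    # only current and earlier positions
    if sliding_window is not None:
        mask &= (q_pos - k_pos) < sliding_window  # only the most recent sliding_window
    logits = np.where(mask, logits, -np.inf)

    logits -= logits.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("hqk,khd->qhd", weights, v)


rng = np.random.default_rng(0)
q = rng.standard_normal((3, 4, 16))   # 3 new query tokens, 4 heads
k = rng.standard_normal((10, 4, 16))  # 10-token context (including the 3 new tokens)
v = rng.standard_normal((10, 4, 16))
out = reference_sequence_attention(q, k, v, sliding_window=4, logits_soft_cap=30.0)
print(out.shape)  # (3, 4, 16)
```

Two easy sanity checks follow from the masks: with sliding_window=1 a query sees only its own position, and in a full prefill the first query sees only position 0.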

Returns

Output tensor of shape [total_tokens, num_q_heads, head_dim] with attention results.

ejkernel.modules.operations.unified_attention.unified_attention(queries: Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_blocks block_size num_kv_heads head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_blocks block_size num_kv_heads head_dim'], kv_lens: Int32[jaxlib._jax.Array, 'num_seqs'], block_tables: Int32[jaxlib._jax.Array, 'num_seqs max_blocks_per_seq'], query_start_loc: Int32[jaxlib._jax.Array, 'num_seqs_plus_1'], /, *, softmax_scale: float | None = None, causal: bool = True, sliding_window: int | None = None, logits_soft_cap: float | None = None, alibi_slopes: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, qq_bias: jaxtyping.Float[jaxlib._jax.Array, 'num_query_tokens num_query_tokens'] | None = None, attention_sink: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.UnifiedAttentionConfig | None = None) → Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']

Execute unified paged attention with automatic platform selection.

This is the main entry point for vLLM-style unified attention, suitable for inference serving workloads with paged KV caches. It handles both prefill and decode phases efficiently using a single unified kernel.

The function picks a platform (Triton, Pallas, XLA) automatically based on the available hardware and, unless cfg is given, fills in the configuration from heuristics.

Parameters
  • queries – Packed query tensor of shape [total_tokens, num_q_heads, head_dim]. All query tokens from all sequences are concatenated together without padding.

  • key_cache – Paged key cache of shape [num_blocks, block_size, num_kv_heads, head_dim]. Pre-allocated memory blocks storing key vectors. Blocks are shared across sequences and mapped via block_tables.

  • value_cache – Paged value cache of shape [num_blocks, block_size, num_kv_heads, head_dim]. Pre-allocated memory blocks storing value vectors, with same layout as key_cache.

  • kv_lens – Context lengths per sequence of shape [num_seqs]. Specifies how many KV tokens are valid for each sequence.

  • block_tables – Block index mapping of shape [num_seqs, max_blocks_per_seq]. Maps each sequence’s logical block indices to physical block indices in the cache. For sequence i, block_tables[i, j] gives the physical block index for logical block j.

  • query_start_loc – Cumulative query token counts of shape [num_seqs + 1]. Defines the boundaries of each sequence in the packed queries tensor. Sequence i’s queries span indices [query_start_loc[i], query_start_loc[i+1]).

  • softmax_scale – Scaling factor for attention scores. Default: 1/sqrt(head_dim).

  • causal – Whether to apply causal masking. Default: True. When True, each query can only attend to KV positions at or before its position.

  • sliding_window – Optional sliding window size for local attention. If provided, each query only attends to the most recent sliding_window KV positions.

  • logits_soft_cap – Optional soft cap for attention logits (Gemma-2 style). Applies: logits = soft_cap * tanh(logits / soft_cap) to prevent extreme values.

  • alibi_slopes – Optional ALiBi position bias slopes per head [num_q_heads]. Adds linear position-dependent bias to attention scores.

  • qq_bias – Optional query-query attention bias [num_query_tokens, num_query_tokens]. Directly added to attention logits for query-to-query interactions.

  • attention_sink – Optional attention sink values per head [num_q_heads]. Provides stable attention to the first token for streaming/infinite context.

  • platform – Override automatic platform selection. One of “triton” (force Triton, GPU), “pallas” (force Pallas, TPU/GPU), “xla” (force the XLA fallback), “cuda” (force CUDA), “auto” (automatic selection, the default), or None (same as “auto”).

  • cfg – Optional configuration override. If None, uses heuristic-based defaults.
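The shape relationships among these arguments are easy to get wrong when assembling a batch by hand. The NumPy sketch below checks them before a call; `check_unified_attention_inputs` is a hypothetical helper, and the rule that each sequence's query tokens are counted in its kv_lens entry is an assumption (the vLLM convention), not something this page states.

```python
import numpy as np


def _require(condition, message):
    if not condition:
        raise ValueError(message)


def check_unified_attention_inputs(queries, key_cache, kv_lens, block_tables, query_start_loc):
    """Hypothetical pre-flight checks for the packed/paged layout (NumPy inputs)."""
    total_tokens, num_q_heads, head_dim = queries.shape
    num_blocks, block_size, num_kv_heads, cache_head_dim = key_cache.shape
    num_seqs, max_blocks_per_seq = block_tables.shape

    _require(head_dim == cache_head_dim, "queries and caches must share head_dim")
    _require(num_q_heads % num_kv_heads == 0, "num_q_heads must be a multiple of num_kv_heads")
    _require(kv_lens.shape == (num_seqs,), "kv_lens must have shape [num_seqs]")
    _require(query_start_loc.shape == (num_seqs + 1,), "query_start_loc must have shape [num_seqs + 1]")
    _require(query_start_loc[0] == 0 and query_start_loc[-1] == total_tokens,
             "query_start_loc must start at 0 and end at total_tokens")
    q_lens = np.diff(query_start_loc)
    _require(np.all(q_lens >= 0), "query_start_loc must be non-decreasing")
    # Assumption: the new query tokens are counted in kv_lens (vLLM convention).
    _require(np.all(q_lens <= kv_lens), "a sequence has more queries than KV tokens")
    _require(np.all(kv_lens <= max_blocks_per_seq * block_size), "block_tables is too narrow")
    blocks_needed = -(-kv_lens // block_size)  # ceil division
    for s in range(num_seqs):
        used = block_tables[s, :blocks_needed[s]]
        _require(np.all((used >= 0) & (used < num_blocks)), f"invalid block id in sequence {s}")


# The docstring example's shapes, with block_tables padded to a rectangle.
queries = np.ones((5, 8, 64), dtype=np.float32)
key_cache = np.ones((100, 16, 2, 64), dtype=np.float32)
kv_lens = np.array([32, 48], dtype=np.int32)
block_tables = np.array([[0, 1, 0], [2, 3, 4]], dtype=np.int32)
query_start_loc = np.array([0, 2, 5], dtype=np.int32)
check_unified_attention_inputs(queries, key_cache, kv_lens, block_tables, query_start_loc)
```

Raising `ValueError` rather than using `assert` keeps the checks active even when Python runs with `-O`.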

Returns

Output tensor of shape [total_tokens, num_q_heads, head_dim] containing the attention results for all sequences, packed in the same order as queries.

Example

>>> import jax.numpy as jnp
>>> from ejkernel.modules import unified_attention
>>>
>>> # Setup for 2 sequences with different lengths
>>> total_tokens = 5  # seq1 has 2 query tokens, seq2 has 3 query tokens
>>> queries = jnp.ones((total_tokens, 8, 64))  # 8 query heads, head_dim 64
>>> key_cache = jnp.ones((100, 16, 2, 64))     # 100 blocks, block_size=16, 2 KV heads
>>> value_cache = jnp.ones((100, 16, 2, 64))
>>> kv_lens = jnp.array([32, 48], dtype=jnp.int32)  # context lengths: 2 and 3 blocks
>>> # block_tables must be rectangular [num_seqs, max_blocks_per_seq]; seq1's third
>>> # entry is padding, since kv_lens[0] = 32 fits in 2 blocks of 16
>>> block_tables = jnp.array([[0, 1, 0], [2, 3, 4]], dtype=jnp.int32)
>>> query_start_loc = jnp.array([0, 2, 5], dtype=jnp.int32)  # [0, 2) for seq1, [2, 5) for seq2
>>>
>>> output = unified_attention(
...     queries, key_cache, value_cache,
...     kv_lens, block_tables, query_start_loc,
...     causal=True,
... )
>>> output.shape
(5, 8, 64)
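Because block_tables has the rectangular shape [num_seqs, max_blocks_per_seq], sequences that own different numbers of blocks have to be padded to a common width. The NumPy sketch below does that padding and checks that each sequence has at least ceil(kv_len / block_size) blocks. `build_block_tables` is a hypothetical helper, and it assumes that entries past a sequence's last needed block are never read (kv_lens bounds the valid positions), so the pad value is arbitrary.

```python
import numpy as np


def build_block_tables(per_seq_blocks, kv_lens, block_size, pad_value=0):
    """Pad ragged per-sequence block lists into an int32 [num_seqs, max_blocks_per_seq] table."""
    for seq, (blocks, kv_len) in enumerate(zip(per_seq_blocks, kv_lens)):
        needed = -(-kv_len // block_size)  # ceil(kv_len / block_size)
        if len(blocks) < needed:
            raise ValueError(f"sequence {seq}: {kv_len} tokens need {needed} blocks, got {len(blocks)}")
    width = max(len(blocks) for blocks in per_seq_blocks)
    # Entries past a sequence's last needed block are assumed to be ignored.
    table = np.full((len(per_seq_blocks), width), pad_value, dtype=np.int32)
    for i, blocks in enumerate(per_seq_blocks):
        table[i, :len(blocks)] = blocks
    return table


print(build_block_tables([[0, 1], [2, 3, 4]], kv_lens=[32, 48], block_size=16).tolist())
# [[0, 1, 0], [2, 3, 4]]
```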

Note

This kernel is optimized for inference and does not support backward passes. For training, use flash_attention instead.