ejkernel.kernels._triton.unified_attention._interface

vLLM-style unified (paged) attention implemented in Triton.

This is a JAX/Triton port of vLLM’s triton_unified_attention.py, adapted to ejkernel’s triton_call interface.

Core inputs:
- queries: packed ragged queries, shape [total_tokens, num_q_heads, head_dim]
- key_cache / value_cache: paged KV cache, shape [num_blocks, block_size, num_kv_heads, head_dim]
- query_start_loc: cumulative query offsets, shape [num_seqs + 1] (int32)
- kv_lens: KV lengths per sequence, shape [num_seqs] (int32)
- block_tables: mapping from each sequence to its cache blocks, shape [num_seqs, max_blocks_per_seq] (int32)
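The index arrays above can be illustrated with a small pure-Python sketch. It assumes the usual paged-attention convention, in which logical block j of sequence i lives in physical block block_tables[i][j] and trailing table entries are padding; the helper names `build_query_start_loc` and `kv_slot` are hypothetical and not part of ejkernel.

```python
from itertools import accumulate


def build_query_start_loc(query_lens):
    """Cumulative query offsets [0, q0, q0+q1, ...] with length num_seqs + 1."""
    return [0, *accumulate(query_lens)]


def kv_slot(block_tables, seq_idx, pos, block_size):
    """Map logical KV position `pos` of a sequence to (physical_block, offset).

    Assumption: logical block j of sequence i is stored in block_tables[i][j].
    """
    logical_block, offset = divmod(pos, block_size)
    return block_tables[seq_idx][logical_block], offset


# Three sequences packed into one ragged batch: 3, 1 and 4 query tokens.
query_lens = [3, 1, 4]
query_start_loc = build_query_start_loc(query_lens)  # [0, 3, 4, 8]
total_tokens = query_start_loc[-1]                    # leading dim of `queries`

# Hypothetical cache layout; entries past ceil(kv_len / block_size) are padding.
block_size = 16
kv_lens = [20, 10, 40]
block_tables = [[7, 2, 0], [5, 0, 0], [1, 9, 4]]

# Token 17 of sequence 0 sits in its second logical block (physical block 2).
slot = kv_slot(block_tables, 0, 17, block_size)       # (2, 1)
```

The query rows of sequence i are then the slice queries[query_start_loc[i]:query_start_loc[i + 1]], and its key at position pos is key_cache[physical_block, offset].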

Supported features (inference-only):
- causal masking (required)
- optional sliding window via sliding_window (window length)
- optional logit softcap (logits_soft_cap)
- optional attention sink (attention_sink): contributes to the softmax normalizer only
- optional ALiBi slopes (alibi_slopes)
- optional query-query bias (qq_bias) for TreeAttention-like decode
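To make the masking and normalizer options concrete, here is a minimal pure-Python reference for a single query row. It is a sketch under stated assumptions, not the kernel's implementation: the softcap is taken to be the common cap * tanh(x / cap) form, the sliding window is taken to admit the last sliding_window keys including the query's own position, and the sink is modelled as one extra logit that enters the denominator but has no value vector. ALiBi and qq_bias are left out. The function names are hypothetical.

```python
import math


def soft_cap(x, cap):
    """Assumed softcap form: squashes a logit smoothly into (-cap, cap)."""
    return cap * math.tanh(x / cap)


def attention_weights(scores, q_pos, *, sliding_window=None,
                      logits_soft_cap=None, sink=None):
    """Softmax weights over KV positions 0..len(scores)-1 for one query.

    `scores` are already-scaled q.k logits; `q_pos` is the query's absolute
    position in the KV sequence and must be < len(scores).
    """
    logits = []
    for k_pos, s in enumerate(scores):
        if logits_soft_cap is not None:
            s = soft_cap(s, logits_soft_cap)
        visible = k_pos <= q_pos                      # causal mask
        if sliding_window is not None:                # assumed window convention
            visible = visible and (q_pos - k_pos) < sliding_window
        logits.append(s if visible else -math.inf)

    m = max(logits) if sink is None else max(max(logits), sink)
    exps = [math.exp(l - m) for l in logits]          # masked entries become 0.0
    denom = sum(exps)
    if sink is not None:
        denom += math.exp(sink - m)                   # sink: normalizer only
    return [e / denom for e in exps]


flat = [0.0, 0.0, 0.0, 0.0]
causal_only = attention_weights(flat, q_pos=2)
windowed = attention_weights(flat, q_pos=2, sliding_window=2)
with_sink = attention_weights(flat, q_pos=2, sink=0.0)
```

With a sink the returned weights sum to less than one, because part of the probability mass is absorbed by the sink term rather than by any cached value.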

ejkernel.kernels._triton.unified_attention._interface.unified_attention(
    queries: Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'],
    key_cache: Float[jaxlib._jax.Array, 'num_blocks block_size num_kv_heads head_dim'],
    value_cache: Float[jaxlib._jax.Array, 'num_blocks block_size num_kv_heads head_dim'],
    kv_lens: Int32[jaxlib._jax.Array, 'num_seqs'],
    block_tables: Int32[jaxlib._jax.Array, 'num_seqs max_blocks_per_seq'],
    query_start_loc: Int32[jaxlib._jax.Array, 'num_seqs_plus_1'],
    *,
    softmax_scale: float | None = None,
    causal: bool = True,
    sliding_window: int | None = None,
    logits_soft_cap: float | None = None,
    seq_threshold_3d: int | None = None,
    num_par_softmax_segments: int | None = None,
    alibi_slopes: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None,
    qq_bias: jaxtyping.Float[jaxlib._jax.Array, 'num_query_tokens num_query_tokens'] | None = None,
    attention_sink: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None,
    num_warps: int | None = None,
    num_stages: int | None = None,
) -> Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']
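A possible call, based only on the signature above, might look like the following sketch. The helper `example_shapes` and the function `run_example` are hypothetical; the float16 dtype, the head counts, and the contiguous block layout are assumptions chosen for illustration, and running `run_example` is expected to need a Triton-capable GPU with ejkernel installed. The imports are kept local so that the shape helper works without either.

```python
def example_shapes(query_lens, kv_lens, *, num_q_heads=8, num_kv_heads=2,
                   head_dim=64, block_size=16):
    """Shapes of every array argument for the given per-sequence lengths."""
    num_seqs = len(query_lens)
    blocks_per_seq = [-(-n // block_size) for n in kv_lens]  # ceil division
    cache = (sum(blocks_per_seq), block_size, num_kv_heads, head_dim)
    return {
        "queries": (sum(query_lens), num_q_heads, head_dim),
        "key_cache": cache,
        "value_cache": cache,
        "kv_lens": (num_seqs,),
        "block_tables": (num_seqs, max(blocks_per_seq)),
        "query_start_loc": (num_seqs + 1,),
    }


def run_example():
    """Hypothetical end-to-end call; not executed at import time."""
    import jax
    import jax.numpy as jnp
    from ejkernel.kernels._triton.unified_attention._interface import (
        unified_attention,
    )

    query_lens, kv_lens = [3, 1], [20, 33]   # one prefill chunk, one decode step
    shapes = example_shapes(query_lens, kv_lens)
    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)

    # Assumption: half-precision inputs; the source does not state a dtype.
    queries = jax.random.normal(kq, shapes["queries"], dtype=jnp.float16)
    key_cache = jax.random.normal(kk, shapes["key_cache"], dtype=jnp.float16)
    value_cache = jax.random.normal(kv, shapes["value_cache"], dtype=jnp.float16)

    # Sequence 0 owns blocks 0-1, sequence 1 owns blocks 2-4; 0 pads the row.
    block_tables = jnp.array([[0, 1, 0], [2, 3, 4]], dtype=jnp.int32)
    query_start_loc = jnp.array([0, 3, 4], dtype=jnp.int32)

    out = unified_attention(
        queries,
        key_cache,
        value_cache,
        jnp.array(kv_lens, dtype=jnp.int32),
        block_tables,
        query_start_loc,
        causal=True,
    )
    return out  # per the signature: [total_tokens, num_q_heads, head_dim]


shapes = example_shapes([3, 1], [20, 33])
```

All optional features (sliding_window, logits_soft_cap, alibi_slopes, qq_bias, attention_sink) are keyword-only and default to None, so the call above exercises only the required causal path.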