ejkernel.kernels._xla.unified_attention._interface

Unified attention interface for paged KV cache with mixed workloads.

This module provides the public API for unified attention over ragged batches backed by paged key-value caches. It supports sliding-window attention, ALiBi slopes, and attention sinks.
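The ragged, paged input layout can be illustrated with a small NumPy sketch. It assumes the vLLM-style convention that the argument shapes suggest but this page does not spell out: `query_start_loc` is the exclusive prefix sum of per-sequence query lengths, `kv_lens` counts every token of a sequence (new ones included), and `block_tables[s, j]` names the physical cache block that holds logical tokens `j*block_size` through `(j+1)*block_size - 1` of sequence `s`. The sizes, block IDs, and the `gather_kv` helper are hypothetical and used only for illustration.

```python
import numpy as np

# Hypothetical toy sizes, for illustration only.
num_blocks, block_size, num_kv_heads, head_dim = 8, 4, 2, 8

# Mixed batch (assumed semantics): sequence 0 is a 5-token prefill with an empty cache,
# sequence 1 decodes 1 new token on top of 6 already-cached tokens.
# Assumption: kv_lens counts every token of the sequence, including the new ones.
q_lens = np.array([5, 1], dtype=np.int32)
kv_lens = np.array([5, 7], dtype=np.int32)

# query_start_loc (num_seqs + 1 entries) is the exclusive prefix sum of q_lens:
# sequence s owns rows query_start_loc[s]:query_start_loc[s + 1] of the packed `queries`.
query_start_loc = np.concatenate([[0], np.cumsum(q_lens)]).astype(np.int32)

# block_tables[s, j] is the physical cache block holding logical tokens
# j*block_size .. (j+1)*block_size - 1 of sequence s (IDs chosen arbitrarily).
block_tables = np.array([[3, 6],
                         [1, 5]], dtype=np.int32)

rng = np.random.default_rng(0)
key_cache = rng.standard_normal(
    (num_blocks, block_size, num_kv_heads, head_dim)).astype(np.float32)


def gather_kv(cache, block_tables, kv_lens, s):
    """Hypothetical helper: contiguous [kv_len, num_kv_heads, head_dim] copy of sequence s."""
    bs = cache.shape[1]
    pos = np.arange(int(kv_lens[s]))
    blocks = block_tables[s, pos // bs]  # physical block of each logical token
    return cache[blocks, pos % bs]       # pick the slot within that block


k1 = gather_kv(key_cache, block_tables, kv_lens, 1)
print(query_start_loc)  # [0 5 6]
print(k1.shape)         # (7, 2, 8)
```

Under this convention, the newest token of sequence 1 (logical position 6) lives in block `block_tables[1, 1] = 5` at slot `6 % 4 = 2`, which is why the caches never need to store a sequence contiguously.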

ejkernel.kernels._xla.unified_attention._interface.unified_attention(queries: Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_blocks block_size num_kv_heads head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_blocks block_size num_kv_heads head_dim'], kv_lens: Int32[jaxlib._jax.Array, 'num_seqs'], block_tables: Int32[jaxlib._jax.Array, 'num_seqs max_blocks_per_seq'], query_start_loc: Int32[jaxlib._jax.Array, 'num_seqs_plus_1'], *, softmax_scale: float | None = None, causal: bool = True, sliding_window: int | None = None, logits_soft_cap: float | None = None, seq_threshold_3d: int | None = None, num_par_softmax_segments: int | None = None, alibi_slopes: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, qq_bias: jaxtyping.Float[jaxlib._jax.Array, 'num_query_tokens num_query_tokens'] | None = None, attention_sink: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, num_warps: int | None = None, num_stages: int | None = None) → Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim']
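To make the signature concrete, here is a minimal NumPy reference sketch of what a call plausibly computes. It is not the ejkernel implementation, and `reference_unified_attention` is a hypothetical name. It rests on assumptions this page does not confirm: `softmax_scale=None` means `1/sqrt(head_dim)`; each sequence's new queries are its last `q_len` tokens, so query `i` sits at position `kv_len - q_len + i`; query head `h` reads KV head `h // (num_q_heads // num_kv_heads)` (grouped-query attention); `sliding_window=w` keeps the `w` most recent keys up to and including the query position; `logits_soft_cap=c` maps a logit `x` to `c * tanh(x / c)` before masking and ALiBi; ALiBi adds `slope[h] * (key_pos - query_pos)`; and `attention_sink` adds one extra per-head logit to the softmax denominator without contributing a value. `qq_bias`, `seq_threshold_3d`, `num_par_softmax_segments`, `num_warps`, and `num_stages` are not modeled; the last four presumably only affect how the kernel is scheduled.

```python
import numpy as np


def reference_unified_attention(queries, key_cache, value_cache, kv_lens, block_tables,
                                query_start_loc, *, softmax_scale=None, causal=True,
                                sliding_window=None, logits_soft_cap=None,
                                alibi_slopes=None, attention_sink=None):
    """Hypothetical dense reference; every semantic choice below is an assumption."""
    _, num_q_heads, head_dim = queries.shape
    block_size, num_kv_heads = key_cache.shape[1], key_cache.shape[2]
    group = num_q_heads // num_kv_heads  # query heads sharing one KV head (GQA/MQA)
    scale = 1.0 / np.sqrt(head_dim) if softmax_scale is None else softmax_scale
    out = np.zeros_like(queries)
    for s in range(len(kv_lens)):
        q0, q1 = int(query_start_loc[s]), int(query_start_loc[s + 1])
        if q1 == q0:
            continue  # sequence contributes no query rows
        kv_len = int(kv_lens[s])
        pos = np.arange(kv_len)
        blocks = block_tables[s, pos // block_size]
        k = key_cache[blocks, pos % block_size]    # [kv_len, num_kv_heads, head_dim]
        v = value_cache[blocks, pos % block_size]
        # Assumption: the new queries are the last (q1 - q0) tokens of the sequence.
        q_pos = kv_len - (q1 - q0) + np.arange(q1 - q0)
        rel = pos[None, :] - q_pos[:, None]       # key pos minus query pos, [q_len, kv_len]
        mask = np.ones_like(rel, dtype=bool)
        if causal:
            mask &= rel <= 0
        if sliding_window is not None:
            mask &= rel > -sliding_window         # assumed window convention
        for h in range(num_q_heads):
            kh, vh = k[:, h // group], v[:, h // group]
            logits = (queries[q0:q1, h] @ kh.T) * scale
            if logits_soft_cap is not None:
                logits = logits_soft_cap * np.tanh(logits / logits_soft_cap)
            if alibi_slopes is not None:
                logits = logits + alibi_slopes[h] * rel
            logits = np.where(mask, logits, -np.inf)
            m = logits.max(axis=-1, keepdims=True)
            if attention_sink is not None:
                m = np.maximum(m, attention_sink[h])
            p = np.exp(logits - m)                # numerically stable softmax numerator
            denom = p.sum(axis=-1, keepdims=True)
            if attention_sink is not None:
                denom = denom + np.exp(attention_sink[h] - m)  # sink absorbs probability mass
            out[q0:q1, h] = (p @ vh) / denom
    return out
```

The sketch favors readability over speed: it materializes each sequence's keys and values and loops over heads, whereas a fused kernel would stream cache blocks and accumulate the softmax online. It can still serve as a numerical cross-check for small inputs, after converting JAX arrays with `np.asarray`.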