ejkernel.kernels._triton.ragged_page_attention_v3._interface
ejkernel.kernels._triton.ragged_page_attention_v3._interface
-
ejkernel.kernels._triton.ragged_page_attention_v3._interface.ragged_page_attention_v3(queries: Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], keys: Float[jaxlib._jax.Array, 'total_tokens num_kv_heads head_dim'], values: Float[jaxlib._jax.Array, 'total_tokens num_kv_heads head_dim'], kv_cache: Float[jaxlib._jax.Array, 'num_pages page_size num_kv_heads_x2_per_kv_packing kv_packing head_dim_padded'], kv_lens: Int32[jaxlib._jax.Array, 'max_num_seqs'], block_tables: Int32[jaxlib._jax.Array, 'max_num_seqs_times_pages_per_seq'], query_start_loc: Int32[jaxlib._jax.Array, 'max_num_seqs_plus_1'], distribution: Int32[jaxlib._jax.Array, '3'], attention_sink: jaxtyping.Float[jaxlib._jax.Array, 'num_q_heads'] | None = None, *, softmax_scale: float = 1.0, sliding_window: int | None = None, logits_soft_cap: float | None = None, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, chunk_prefill_size: int | None = None, num_kv_pages_per_block: int | None = None, num_queries_per_block: int | None = None, vmem_limit_bytes: int | None = None) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'total_tokens num_q_heads head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'num_pages page_size num_kv_heads_x2_per_kv_packing kv_packing head_dim_padded']]