ejkernel.kernels._xla.ragged_page_attention_v3._kernel
-
ejkernel.kernels._xla.ragged_page_attention_v3._kernel.align_to(x, a)
-
ejkernel.kernels._xla.ragged_page_attention_v3._kernel.cdiv(a, b)
-
ejkernel.kernels._xla.ragged_page_attention_v3._kernel.get_dtype_bitwidth(dtype)
-
ejkernel.kernels._xla.ragged_page_attention_v3._kernel.get_dtype_packing(dtype)
-
ejkernel.kernels._xla.ragged_page_attention_v3._kernel.merge_kv(k: Array, v: Array) → Array
-
ejkernel.kernels._xla.ragged_page_attention_v3._kernel.ragged_paged_attention(queries: Array, keys: Array, values: Array, kv_cache: Array, kv_lens: Array, block_tables: Array, query_start_loc: Array, distribution: Array, attention_sink: Array | None = None, *, softmax_scale: float = 1.0, sliding_window: int | None = None, logits_soft_cap: float | None = None, mask_value: float | None = -2.381976426469702e+38, q_scale: float | None = None, k_scale: float | None = None, v_scale: float | None = None, chunk_prefill_size: int | None = None, num_kv_pages_per_block: int | None = None, num_queries_per_block: int | None = None, vmem_limit_bytes: int | None = None) → tuple[Array, Array]
-
ejkernel.kernels._xla.ragged_page_attention_v3._kernel.static_validate_inputs(q, k, v, kv_cache, kv_lens, block_tables, query_start_loc, distribution, *, softmax_scale=1.0, sliding_window=None, logits_soft_cap=None, mask_value=-2.381976426469702e+38, q_scale=None, k_scale=None, v_scale=None, chunk_prefill_size=None, num_kv_pages_per_block=None, num_queries_per_block=None, vmem_limit_bytes=None)