ejkernel.kernels._triton.unified_attention._triton_impl_fwd
-
ejkernel.kernels._triton.unified_attention._triton_impl_fwd.unified_attention_triton(*, queries: Array, key_cache: Array, value_cache: Array, block_tables: Array, kv_lens: Array, query_start_loc: Array, softmax_scale: float | None, causal: bool, sliding_window: int | None, logits_soft_cap: float | None, seq_threshold_3d: int | None, num_par_softmax_segments: int | None, alibi_slopes: jax.Array | None, qq_bias: jax.Array | None, attention_sink: jax.Array | None, num_warps: int | None, num_stages: int | None) → Array  [source]