ejkernel.kernels._triton.native_sparse_attention._triton_impl_fwd
ejkernel.kernels._triton.native_sparse_attention._triton_impl_fwd
-
ejkernel.kernels._triton.native_sparse_attention._triton_impl_fwd.fwd_triton_impl(q: Array, k: Array, v: Array, block_indices: Array, block_counts: jax.jaxlib._jax.Array | int, block_size: int, softmax_scale: float, cu_seqlens: jax.jaxlib._jax.Array | None = None, token_indices: jax.jaxlib._jax.Array | None = None) [source]
-
ejkernel.kernels._triton.native_sparse_attention._triton_impl_fwd.nsa_topk(q: Array, k: Array, lse: Array, block_counts: jax.jaxlib._jax.Array | int, block_size: int = 64, softmax_scale: float | None = None, cu_seqlens: jax.jaxlib._jax.Array | None = None) → Array [source]