ejkernel.kernels._triton.blocksparse_attention._utilities
- ejkernel.kernels._triton.blocksparse_attention._utilities.attention_pack_from_cu_static(x: Float[jaxlib._jax.Array, 'batch seq_max num_heads head_dim'], cum_seqlens: Int[jaxlib._jax.Array, 'batch_plus_one'], max_tokens: int | None = None) -> Float[jaxlib._jax.Array, '1 max_tokens num_heads head_dim']
Packs a variable-length batch into a single [1, T, H, D] tensor using cum_seqlens. T (max_tokens) can be any static upper bound on the total token count (e.g., B*S_max). Only the first cum_seqlens[-1] tokens are written; the remaining positions stay zero.
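The packing described above can be sketched in plain JAX. This is a minimal illustration, not the library's implementation: the name `pack_from_cu_sketch` is hypothetical, and the default of `batch * seq_max` for `max_tokens` is an assumption taken from the "e.g., B*S_max" hint.

```python
import jax
import jax.numpy as jnp


def pack_from_cu_sketch(x, cum_seqlens, max_tokens=None):
    """Hypothetical sketch: pack [B, S_max, H, D] into [1, T, H, D]."""
    batch, seq_max, num_heads, head_dim = x.shape
    # Assumption: fall back to B * S_max when no static bound is given.
    total = batch * seq_max if max_tokens is None else max_tokens

    starts = cum_seqlens[:-1]                 # [B] first packed slot of each sequence
    seqlens = cum_seqlens[1:] - starts        # [B] true length of each sequence
    pos = jnp.arange(seq_max)[None, :]        # [1, S_max]
    valid = pos < seqlens[:, None]            # [B, S_max] True for real tokens

    # Padding tokens get the out-of-range slot `total`, which mode="drop" discards.
    dest = jnp.where(valid, starts[:, None] + pos, total).reshape(-1)
    flat = x.reshape(batch * seq_max, num_heads, head_dim)
    packed = jnp.zeros((total, num_heads, head_dim), x.dtype)
    packed = packed.at[dest].set(flat, mode="drop")
    return packed[None]                       # [1, T, H, D]


# Two sequences of lengths 2 and 1, padded to S_max = 3.
x = jnp.arange(1.0, 7.0).reshape(2, 3, 1, 1)
packed = jax.jit(pack_from_cu_sketch)(x, jnp.array([0, 2, 3]))
```

Because every shape depends only on `x.shape` and the static bound, the sketch traces under `jax.jit` without data-dependent shapes.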
- ejkernel.kernels._triton.blocksparse_attention._utilities.attention_pack_with_static_shape(x: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], attention_mask: Bool[jaxlib._jax.Array, 'batch seq_len'], max_tokens: int | None = None) -> Float[jaxlib._jax.Array, '1 max_tokens num_heads head_dim']
Packs an attention tensor by removing the padded positions indicated by the attention mask. The output uses a static maximum shape (max_tokens) so the function stays compatible with JIT compilation.
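A mask-driven variant of the same idea might look like the following. It is a sketch under stated assumptions rather than the library's code: `pack_with_mask_sketch` is a hypothetical name, `True` in the mask is assumed to mark a real token, and `max_tokens` is assumed to default to `batch * seq_len`.

```python
import jax
import jax.numpy as jnp


def pack_with_mask_sketch(x, attention_mask, max_tokens=None):
    """Hypothetical sketch: drop masked-out tokens, keep a static output shape."""
    batch, seq_len, num_heads, head_dim = x.shape
    # Assumption: fall back to B * S when no static bound is given.
    total = batch * seq_len if max_tokens is None else max_tokens

    keep = attention_mask.reshape(-1)                 # [B*S] True for real tokens
    # A running count of kept tokens gives each one its packed slot.
    slot = jnp.cumsum(keep.astype(jnp.int32)) - 1
    # Padding tokens get the out-of-range slot `total`, which mode="drop" discards.
    dest = jnp.where(keep, slot, total)

    flat = x.reshape(batch * seq_len, num_heads, head_dim)
    packed = jnp.zeros((total, num_heads, head_dim), x.dtype)
    packed = packed.at[dest].set(flat, mode="drop")
    return packed[None]                               # [1, T, H, D]


x_masked = jnp.arange(1.0, 7.0).reshape(2, 3, 1, 1)
mask = jnp.array([[True, True, False], [True, False, False]])
packed_masked = jax.jit(pack_with_mask_sketch)(x_masked, mask)
```

The output shape never depends on how many mask entries are `True`, which is what keeps the function traceable under `jax.jit`.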
- ejkernel.kernels._triton.blocksparse_attention._utilities.attention_unpack_with_static_shape(x: Float[jaxlib._jax.Array, '1 max_tokens num_heads head_dim'], cum_seqlens: Int[jaxlib._jax.Array, 'batch_plus_one'], batch_size: int, seqlen: int) -> Float[jaxlib._jax.Array, 'batch seqlen num_heads head_dim']
Unpacks a packed [1, T, H, D] tensor back into [B, seqlen, H, D] using cum_seqlens. Here seqlen is the static padded maximum length; positions past the end of each sequence are left as zeros.
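The inverse operation can be sketched as a gather followed by a mask. Again this is only an illustration: `unpack_sketch` is a hypothetical name, and the shape check on `cum_seqlens` is an added assumption about how `batch_size` relates to it.

```python
import jax.numpy as jnp


def unpack_sketch(x, cum_seqlens, batch_size, seqlen):
    """Hypothetical sketch: unpack [1, T, H, D] into [B, seqlen, H, D]."""
    # Assumption: cum_seqlens holds batch_size + 1 offsets.
    assert cum_seqlens.shape[0] == batch_size + 1
    packed = x[0]                              # [T, H, D]

    starts = cum_seqlens[:-1]                  # [B]
    seqlens = cum_seqlens[1:] - starts         # [B]
    pos = jnp.arange(seqlen)[None, :]          # [1, seqlen]
    valid = pos < seqlens[:, None]             # [B, seqlen] True for real tokens

    # Padded positions read slot 0 as a placeholder and are zeroed afterwards.
    src = jnp.where(valid, starts[:, None] + pos, 0)
    gathered = packed[src]                     # [B, seqlen, H, D]
    return jnp.where(valid[:, :, None, None], gathered, 0)


packed_in = jnp.array([1.0, 2.0, 4.0, 0.0, 0.0, 0.0]).reshape(1, 6, 1, 1)
unpacked = unpack_sketch(packed_in, jnp.array([0, 2, 3]), batch_size=2, seqlen=3)
```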
- ejkernel.kernels._triton.blocksparse_attention._utilities.basic_attention_refrence(q: Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim'], k: Float[jaxlib._jax.Array, 'batch seq_len_k num_heads_kv head_dim'], v: Float[jaxlib._jax.Array, 'batch seq_len_k num_heads_kv head_dim'], attn_bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, query_padding_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch seq_len_q'] | None = None, key_padding_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch seq_len_k'] | None = None, dropout_prob: float = 0.0, dropout_key: jax.jaxlib._jax.Array | None = None, window_size: tuple[int, int] = (-1, -1), causal: bool = False, softcap: float = 0.0) -> Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim']
Reference implementation of attention for testing and validation.
Provides a standard JAX implementation of scaled dot-product attention with support for various masking options, useful for validating the optimized Triton kernels.
- Parameters
q – Query tensor [batch, seq_len_q, num_heads, head_dim]
k – Key tensor [batch, seq_len_k, num_heads_kv, head_dim]
v – Value tensor [batch, seq_len_k, num_heads_kv, head_dim]
attn_bias – Optional attention bias tensor
query_padding_mask – Boolean mask for query positions
key_padding_mask – Boolean mask for key positions
dropout_prob – Dropout probability for attention weights
dropout_key – JAX random key for dropout
window_size – Local attention window (left, right)
causal – Whether to apply causal masking
softcap – Soft capping value for attention scores
- Returns
Attention output with the same shape as q, [batch, seq_len_q, num_heads, head_dim]
- Return type
jnp.ndarray
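For orientation, a stripped-down reference attention in JAX is sketched below. It is not the library's function: `attention_reference_sketch` is a hypothetical name, it covers only `key_padding_mask` and `causal`, and it omits `attn_bias`, `query_padding_mask`, dropout, `window_size`, and `softcap`. The `1/sqrt(head_dim)` scale, the repetition of key/value heads when `num_heads_kv < num_heads`, and the top-left causal alignment are assumptions.

```python
import jax
import jax.numpy as jnp


def attention_reference_sketch(q, k, v, key_padding_mask=None, causal=False):
    """Hypothetical sketch of masked scaled dot-product attention."""
    batch, seq_q, heads_q, head_dim = q.shape
    seq_k, heads_kv = k.shape[1], k.shape[2]

    # Assumption: grouped-query layout, so each KV head serves heads_q // heads_kv query heads.
    k = jnp.repeat(k, heads_q // heads_kv, axis=2)
    v = jnp.repeat(v, heads_q // heads_kv, axis=2)

    # Assumption: the conventional 1/sqrt(head_dim) scaling.
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * (head_dim ** -0.5)

    mask = jnp.ones((batch, 1, seq_q, seq_k), dtype=bool)
    if key_padding_mask is not None:
        mask = mask & key_padding_mask[:, None, None, :]
    if causal:
        # Top-left aligned lower triangle; unambiguous when seq_q == seq_k.
        mask = mask & jnp.tril(jnp.ones((seq_q, seq_k), dtype=bool))[None, None]

    # A large finite negative value avoids NaNs on fully masked rows.
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q_demo = jax.random.normal(kq, (1, 4, 2, 8))   # 2 query heads
k_demo = jax.random.normal(kk, (1, 4, 1, 8))   # 1 shared KV head
v_demo = jax.random.normal(kv, (1, 4, 1, 8))
out_demo = attention_reference_sketch(q_demo, k_demo, v_demo, causal=True)
```

With causal masking the first query position can attend only to the first key, so its output should equal the first value vector; that makes a convenient sanity check when comparing against an optimized kernel.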
- ejkernel.kernels._triton.blocksparse_attention._utilities.calc_bias_strides(bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len_q seq_len_k'] | None, batch: int, nheads_q: int, QSeq: int, KSeq: int) -> tuple[int, int, int]
Calculate memory strides for bias tensor with broadcasting support.
Validates bias tensor dimensions and computes appropriate strides for batch and head dimensions, supporting broadcasting when dimensions are 1.
- Parameters
bias – Optional bias tensor with shape [batch, heads, QSeq, KSeq]
batch – Expected batch size
nheads_q – Number of query attention heads
QSeq – Query sequence length
KSeq – Key sequence length
- Returns
(stride_bz, stride_bh, stride_bm) memory strides
- Return type
tuple
- Raises
ValueError – If bias dimensions are incompatible with expected shapes
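One way to picture the broadcasting rule is the sketch below. It is a guess at the behavior, not the library's code: `bias_strides_sketch` is a hypothetical name, strides are counted in elements of a contiguous row-major bias, a broadcast (size-1) dimension is assumed to get stride 0, and a missing bias is assumed to yield `(0, 0, 0)`.

```python
import jax.numpy as jnp


def bias_strides_sketch(bias, batch, nheads_q, q_seq, k_seq):
    """Hypothetical sketch: element strides for a row-major bias tensor."""
    if bias is None:
        # Assumption: no bias means all strides are zero.
        return 0, 0, 0
    if bias.ndim != 4 or tuple(bias.shape[2:]) != (q_seq, k_seq):
        raise ValueError(f"bias must end in ({q_seq}, {k_seq}), got {bias.shape}")
    bias_batch, bias_heads = bias.shape[:2]
    if bias_batch not in (1, batch) or bias_heads not in (1, nheads_q):
        raise ValueError(f"bias shape {bias.shape} cannot broadcast to "
                         f"({batch}, {nheads_q}, {q_seq}, {k_seq})")

    stride_bm = k_seq                                             # step between query rows
    stride_bh = 0 if bias_heads == 1 else q_seq * k_seq           # 0 reuses one head slice
    stride_bz = 0 if bias_batch == 1 else bias_heads * q_seq * k_seq
    return stride_bz, stride_bh, stride_bm


# Bias shared across the batch but distinct per head.
strides_shared_batch = bias_strides_sketch(jnp.zeros((1, 4, 8, 16)), 2, 4, 8, 16)
```

A stride of 0 lets a kernel index every batch entry or head into the same memory, which is how a size-1 dimension can be broadcast without materializing copies.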
- ejkernel.kernels._triton.blocksparse_attention._utilities.pad_to_block_size(inputs: collections.abc.Sequence[jax.jaxlib._jax.Array] | None, indexs: jax.jaxlib._jax.Array | None, segment_ids: jax.jaxlib._jax.Array | None, block_size: int, pos_fill_value: int, transposed_inputs: bool = False)