ejkernel.kernels._triton.flash_attention._utilities

ejkernel.kernels._triton.flash_attention._utilities.attention_pack_from_cu_static(x: Float[jaxlib._jax.Array, 'batch seq_max num_heads head_dim'], cum_seqlens: Int[jaxlib._jax.Array, 'batch_plus_one'], max_tokens: int | None = None) → Float[jaxlib._jax.Array, '1 max_tokens num_heads head_dim']

Packs a variable-length batch into a single [1, T, H, D] tensor using cum_seqlens. T (max_tokens) can be any static upper bound on the total token count (e.g., B*S_max). Only the first cum_seqlens[-1] token slots are written; the rest stay zero.
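The packing described above can be sketched in plain JAX. This is a minimal illustration, not the library's implementation: the name pack_from_cu_sketch is hypothetical, and the default of batch * seq_max for the static bound is an assumption taken from the "B*S_max" example in the description.

```python
import jax.numpy as jnp


def pack_from_cu_sketch(x, cum_seqlens, max_tokens=None):
    """Hypothetical sketch: pack [B, S, H, D] into [1, T, H, D] with a static T."""
    batch, seq_max, _, _ = x.shape
    # Assumption: fall back to B * S_max when no static bound is given.
    total_slots = batch * seq_max if max_tokens is None else max_tokens

    t = jnp.arange(total_slots)
    # Sequence that owns packed slot t: the last b with cum_seqlens[b] <= t.
    b = jnp.searchsorted(cum_seqlens, t, side="right") - 1
    b = jnp.clip(b, 0, batch - 1)
    # Position of slot t inside its sequence (clipped so the gather stays in range).
    offset = jnp.clip(t - cum_seqlens[b], 0, seq_max - 1)
    # Slots at or beyond cum_seqlens[-1] hold no token and are zeroed.
    valid = t < cum_seqlens[-1]

    gathered = x[b, offset]  # [T, H, D]
    packed = jnp.where(valid[:, None, None], gathered, 0)
    return packed[None]  # [1, T, H, D]


x = jnp.arange(1.0, 7.0).reshape(2, 3, 1, 1)  # sequences [1, 2, 3] and [4, 5, 6]
cum_seqlens = jnp.array([0, 2, 5])            # lengths 2 and 3
packed = pack_from_cu_sketch(x, cum_seqlens)
```

Every shape here depends only on static values, so the sketch can be traced under jax.jit as long as max_tokens is passed as a static argument.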

ejkernel.kernels._triton.flash_attention._utilities.attention_pack_with_static_shape(x: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], attention_mask: Bool[jaxlib._jax.Array, 'batch seq_len'], max_tokens: int | None = None) → Float[jaxlib._jax.Array, '1 max_tokens num_heads head_dim']

Packs an attention tensor by removing the padding positions indicated by attention_mask. The output uses a static maximum shape (max_tokens) so that the function stays compatible with JIT compilation.
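One way to remove padding while keeping a static output shape is a cumulative-sum scatter. The sketch below is an illustration under stated assumptions, not the library's code: pack_with_mask_sketch is a hypothetical name, True in the mask is assumed to mark a real token, and the default bound of batch * seq_len is an assumption.

```python
import jax.numpy as jnp


def pack_with_mask_sketch(x, attention_mask, max_tokens=None):
    """Hypothetical sketch: drop masked-out tokens, output [1, T, H, D]."""
    batch, seq_len, num_heads, head_dim = x.shape
    # Assumption: fall back to B * S when no static bound is given.
    total_slots = batch * seq_len if max_tokens is None else max_tokens

    flat_x = x.reshape(batch * seq_len, num_heads, head_dim)
    flat_mask = attention_mask.reshape(-1)

    # Destination slot of each kept token = number of kept tokens before it.
    dest = jnp.cumsum(flat_mask.astype(jnp.int32)) - 1
    # Padding tokens get an out-of-range index so the scatter drops them.
    dest = jnp.where(flat_mask, dest, total_slots)

    packed = jnp.zeros((total_slots, num_heads, head_dim), x.dtype)
    packed = packed.at[dest].set(flat_x, mode="drop")
    return packed[None]


x = jnp.arange(1.0, 7.0).reshape(2, 3, 1, 1)
mask = jnp.array([[True, True, False], [True, True, True]])
packed = pack_with_mask_sketch(x, mask)
```

The number of kept tokens is data-dependent, which is why the output is sized by a static bound and zero-filled rather than sliced to the exact token count.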

ejkernel.kernels._triton.flash_attention._utilities.attention_unpack_with_static_shape(x: Float[jaxlib._jax.Array, '1 max_tokens num_heads head_dim'], cum_seqlens: Int[jaxlib._jax.Array, 'batch_plus_one'], batch_size: int, seqlen: int) → Float[jaxlib._jax.Array, 'batch seqlen num_heads head_dim']

Unpacks a packed tensor back into [B, seqlen, H, D] using cum_seqlens. seqlen is a static padded maximum length; positions past the end of each sequence are left as zeros.
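The inverse operation can be sketched as a gather from the packed buffer. This is a minimal illustration, not the library's implementation; unpack_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def unpack_sketch(x, cum_seqlens, batch_size, seqlen):
    """Hypothetical sketch: [1, T, H, D] -> [B, seqlen, H, D], zero padded."""
    total_slots = x.shape[1]
    starts = cum_seqlens[:batch_size]
    lengths = cum_seqlens[1 : batch_size + 1] - starts

    s = jnp.arange(seqlen)
    # Packed slot holding token s of sequence b.
    src = starts[:, None] + s[None, :]           # [B, seqlen]
    # Positions past the end of a sequence are padding.
    valid = s[None, :] < lengths[:, None]
    # Clip so the gather stays in range; invalid entries are zeroed below.
    src = jnp.clip(src, 0, total_slots - 1)

    gathered = x[0][src]                         # [B, seqlen, H, D]
    return jnp.where(valid[:, :, None, None], gathered, 0)


packed = jnp.array([1.0, 2.0, 4.0, 5.0, 6.0, 0.0]).reshape(1, 6, 1, 1)
cum_seqlens = jnp.array([0, 2, 5])
unpacked = unpack_sketch(packed, cum_seqlens, batch_size=2, seqlen=3)
```

Both batch_size and seqlen determine the output shape, so they need to be static Python integers when the function is traced.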

ejkernel.kernels._triton.flash_attention._utilities.basic_attention_refrence(q: Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim'], k: Float[jaxlib._jax.Array, 'batch seq_len_k num_heads_kv head_dim'], v: Float[jaxlib._jax.Array, 'batch seq_len_k num_heads_kv head_dim'], attn_bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, query_padding_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch seq_len_q'] | None = None, key_padding_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch seq_len_k'] | None = None, dropout_prob: float = 0.0, dropout_key: jax.jaxlib._jax.Array | None = None, window_size: tuple[int, int] = (-1, -1), causal: bool = False, softcap: float = 0.0) → Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim']

Reference implementation of attention for testing and validation.

Provides a standard JAX implementation of scaled dot-product attention with support for various masking options, useful for validating the optimized Triton kernels.

Parameters
  • q – Query tensor [batch, seq_len_q, num_heads, head_dim]

  • k – Key tensor [batch, seq_len_k, num_heads_kv, head_dim]

  • v – Value tensor [batch, seq_len_k, num_heads_kv, head_dim]

  • attn_bias – Optional attention bias tensor

  • query_padding_mask – Boolean mask for query positions

  • key_padding_mask – Boolean mask for key positions

  • dropout_prob – Dropout probability for attention weights

  • dropout_key – JAX random key for dropout

  • window_size – Local attention window (left, right)

  • causal – Whether to apply causal masking

  • softcap – Soft capping value for attention scores

Returns

Attention output with the same shape as q

Return type

jnp.ndarray
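For orientation, a reduced version of such a reference can be sketched as follows. It is not the library's function: attention_reference_sketch is a hypothetical name, and it covers only key padding, causal masking, and grouped KV heads. Bias, dropout, windowing, and soft capping are omitted. The 1/sqrt(head_dim) scale, the True-means-keep mask convention, and the bottom-right causal alignment when seq_len_q differs from seq_len_k are assumptions.

```python
import math

import jax
import jax.numpy as jnp


def attention_reference_sketch(q, k, v, key_padding_mask=None, causal=False):
    """Hypothetical sketch of scaled dot-product attention in plain JAX."""
    _, seq_q, num_heads, head_dim = q.shape
    seq_k, num_heads_kv = k.shape[1], k.shape[2]

    # Grouped-query attention: each KV head serves num_heads // num_heads_kv query heads.
    group = num_heads // num_heads_kv
    k = jnp.repeat(k, group, axis=2)
    v = jnp.repeat(v, group, axis=2)

    # Assumption: standard 1/sqrt(head_dim) scaling.
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(head_dim)
    neg = jnp.finfo(scores.dtype).min

    if key_padding_mask is not None:
        # Assumption: True marks a real key, False marks padding.
        scores = jnp.where(key_padding_mask[:, None, None, :], scores, neg)
    if causal:
        # Assumption: bottom-right alignment when seq_q != seq_k.
        q_pos = jnp.arange(seq_q)[:, None] + (seq_k - seq_q)
        k_pos = jnp.arange(seq_k)[None, :]
        scores = jnp.where(k_pos <= q_pos, scores, neg)

    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 4, 4, 8))
k = jax.random.normal(kk, (2, 4, 2, 8))
v = jax.random.normal(kv, (2, 4, 2, 8))
out = attention_reference_sketch(q, k, v, causal=True)
```

A reference like this is typically compared against the optimized kernel with jnp.allclose and a tolerance suited to the dtype under test.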

ejkernel.kernels._triton.flash_attention._utilities.calc_bias_strides(bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len_q seq_len_k'] | None, batch: int, nheads_q: int, QSeq: int, KSeq: int) → tuple[int, int, int]

Calculate memory strides for bias tensor with broadcasting support.

Validates bias tensor dimensions and computes appropriate strides for batch and head dimensions, supporting broadcasting when dimensions are 1.

Parameters
  • bias – Optional bias tensor with shape [batch, heads, QSeq, KSeq]

  • batch – Expected batch size

  • nheads_q – Number of query attention heads

  • QSeq – Query sequence length

  • KSeq – Key sequence length

Returns

(stride_bz, stride_bh, stride_bm) memory strides for the batch, head, and query-row dimensions, matching the annotated tuple[int, int, int] return type

Return type

tuple

Raises

ValueError – If bias dimensions are incompatible with expected shapes
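The broadcasting rule can be illustrated with a small sketch. It is not the library's code: bias_strides_sketch is a hypothetical name, strides are expressed in elements, and a contiguous row-major layout is assumed. The three-value return follows the annotated tuple[int, int, int].

```python
import jax.numpy as jnp


def bias_strides_sketch(bias, batch, nheads_q, q_seq, k_seq):
    """Hypothetical sketch: element strides for a row-major [B, H, Q, K] bias."""
    if bias is None:
        return 0, 0, 0

    b, h, m, n = bias.shape
    if m != q_seq or n != k_seq:
        raise ValueError(f"bias has sequence dims {(m, n)}, expected {(q_seq, k_seq)}")
    if b not in (1, batch):
        raise ValueError(f"bias batch dim must be 1 or {batch}, got {b}")
    if h not in (1, nheads_q):
        raise ValueError(f"bias head dim must be 1 or {nheads_q}, got {h}")

    # A stride of 0 makes every batch entry (or head) read the same memory,
    # which is how a size-1 dimension is broadcast.
    stride_bz = 0 if b == 1 else h * m * n
    stride_bh = 0 if h == 1 else m * n
    stride_bm = n
    return stride_bz, stride_bh, stride_bm


per_head = bias_strides_sketch(jnp.zeros((1, 4, 8, 16)), batch=2, nheads_q=4, q_seq=8, k_seq=16)
per_batch = bias_strides_sketch(jnp.zeros((2, 1, 8, 16)), batch=2, nheads_q=4, q_seq=8, k_seq=16)
```

With these inputs, a bias shared across the batch gets a zero batch stride, and a bias shared across heads gets a zero head stride.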