ejkernel.kernels._triton.blocksparse_attention._mask

Sparse mask generation for block-sparse attention.

This module provides utilities for creating and managing sparse attention masks at the block level. Instead of computing full token-level attention masks, block-sparse attention uses coarse-grained masks that specify which blocks of queries attend to which blocks of keys.

The SparseMask dataclass encapsulates four boundary arrays that define:

1. lower_bounds: first KV block each query block attends to (the sparse pattern).
2. upper_bounds: one past the last KV block each query block attends to (the sparse pattern).
3. lower_full_bounds: first fully attended KV block (enables the unmasked fast path).
4. upper_full_bounds: one past the last fully attended KV block (enables the unmasked fast path).
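To make the meaning of the four arrays concrete, here is a brute-force reference for a single batch row. It is only an illustration, not how the library computes masks: the token-level rule (same segment, causal `kv_pos <= q_pos`, window limits on the position difference), the assumption that fully attended blocks form one contiguous run, and the encoding of an empty range as `(start, start)` are all assumptions made for this sketch.

```python
from typing import List, Tuple


def reference_block_bounds(
    q_pos: List[int], q_seg: List[int],
    kv_pos: List[int], kv_seg: List[int],
    q_blocksize: int, kv_blocksize: int,
    causal: bool = True, window_left: int = -1, window_right: int = -1,
) -> Tuple[List[int], List[int], List[int], List[int]]:
    """Brute-force block bounds for ONE batch row (O(Lq * Lkv), illustrative)."""

    def allowed(qi: int, ki: int) -> bool:
        # Assumed token-level rule: same segment, then causal/window limits.
        if q_seg[qi] != kv_seg[ki]:
            return False
        delta = q_pos[qi] - kv_pos[ki]
        if causal and delta < 0:
            return False
        if window_left >= 0 and delta > window_left:
            return False
        if window_right >= 0 and -delta > window_right:
            return False
        return True

    n_q = -(-len(q_pos) // q_blocksize)     # ceil division
    n_kv = -(-len(kv_pos) // kv_blocksize)
    lower, upper, lower_full, upper_full = [], [], [], []
    for qb in range(n_q):
        q_idx = range(qb * q_blocksize, min((qb + 1) * q_blocksize, len(q_pos)))
        any_blocks, full_blocks = [], []
        for kb in range(n_kv):
            k_idx = range(kb * kv_blocksize, min((kb + 1) * kv_blocksize, len(kv_pos)))
            flags = [allowed(qi, ki) for qi in q_idx for ki in k_idx]
            if any(flags):
                any_blocks.append(kb)      # at least one token pair attends
            if all(flags):
                full_blocks.append(kb)     # every token pair attends
        # Half-open ranges; an empty range is encoded as (start, start).
        lo = any_blocks[0] if any_blocks else 0
        hi = any_blocks[-1] + 1 if any_blocks else 0
        lower.append(lo)
        upper.append(hi)
        lower_full.append(full_blocks[0] if full_blocks else lo)
        upper_full.append(full_blocks[-1] + 1 if full_blocks else lo)
    return lower, upper, lower_full, upper_full
```

For 8 causal tokens in 2-token blocks this yields upper bounds `[1, 2, 3, 4]` and full upper bounds `[0, 1, 2, 3]`: the diagonal block of each query block is attended only partially.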

These masks enable significant performance optimizations:
  • Skip computation for masked-out blocks entirely.

  • Use faster kernels for fully attended blocks (no masking logic).

  • Support causal masking, sliding windows, and custom patterns.

  • Handle segment IDs for packed variable-length sequences.
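For the common case of a single unpacked sequence with positions `0 .. seq_len - 1`, the outer bounds of a causal sliding window have a simple closed form. The sketch below assumes that query position `p` sees keys in `[p - window_left, p]` (or `[0, p]` when `window_left == -1`) and that the key sequence has the same length as the query sequence; it is an illustration of the skipping idea, not the library's implementation.

```python
from typing import List, Tuple


def causal_window_bounds(
    seq_len: int, q_blocksize: int, kv_blocksize: int, window_left: int = -1
) -> Tuple[List[int], List[int]]:
    """Outer [lower, upper) KV-block bounds per query block (assumed semantics)."""
    n_q = -(-seq_len // q_blocksize)          # ceil division
    lower, upper = [], []
    for qb in range(n_q):
        first_q = qb * q_blocksize
        last_q = min((qb + 1) * q_blocksize, seq_len) - 1
        # Earliest key any query in this block may see.
        first_k = 0 if window_left < 0 else max(0, first_q - window_left)
        # Latest key: the causal diagonal of the block's last query.
        last_k = last_q
        lower.append(first_k // kv_blocksize)
        upper.append(last_k // kv_blocksize + 1)
    return lower, upper
```

With 512 tokens and 64-token blocks, plain causal attention visits 36 of the 64 (query block, KV block) pairs, while a 100-token left window visits only 21; every other pair is skipped without any computation.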

The mask computation is performed on GPU using Triton kernels for efficiency, and the resulting masks are used by both forward and backward attention passes.

Key functions:
  • create_sparsity_mask: high-level API for generating masks.

  • define_sparse_mask_fn: core mask generation logic.

  • SparseMask.from_inputs: creates a mask from positions and segment IDs.

Example

>>> import jax.numpy as jnp
>>> from ejkernel.kernels._triton.blocksparse_attention._mask import create_sparsity_mask
>>>
>>> batch, seq_len = 2, 512
>>> q_positions = jnp.arange(seq_len).reshape(1, -1).repeat(batch, 0)
>>> kv_positions = jnp.arange(seq_len).reshape(1, -1).repeat(batch, 0)
>>> q_segment_ids = jnp.zeros((batch, seq_len), dtype=jnp.int32)
>>> kv_segment_ids = jnp.zeros((batch, seq_len), dtype=jnp.int32)
>>>
>>> # Causal block-sparse masks with 64-token query and KV blocks
>>> masks = create_sparsity_mask(
...     q_positions, q_segment_ids,
...     kv_positions, kv_segment_ids,
...     q_blocksize=64, kv_blocksize=64,
...     causal=True
... )
class ejkernel.kernels._triton.blocksparse_attention._mask.SparseMask(lower_bounds: Optional[ArrayLike], upper_bounds: Optional[ArrayLike], lower_full_bounds: Optional[ArrayLike], upper_full_bounds: Optional[ArrayLike])

Bases: Mapping

Sparse attention mask at the block level.

This dataclass represents a sparse attention pattern by defining which blocks of keys/values each block of queries should attend to. It uses boundary arrays to efficiently encode the sparsity pattern without materializing a full mask.
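A quick back-of-the-envelope comparison shows why boundary arrays are cheap. This sketch assumes a dense mask stores one byte per boolean (ignoring heads) and that the four bounds are int32 arrays of shape (batch, 1, num_q_blocks) as documented below; the helper names are hypothetical.

```python
def dense_mask_bytes(batch: int, q_len: int, kv_len: int, bytes_per_elem: int = 1) -> int:
    # One boolean per (query token, key token) pair.
    return batch * q_len * kv_len * bytes_per_elem


def block_bounds_bytes(batch: int, q_len: int, q_blocksize: int, bytes_per_elem: int = 4) -> int:
    # Four arrays of shape (batch, 1, num_q_blocks), assumed int32.
    num_q_blocks = -(-q_len // q_blocksize)
    return 4 * batch * num_q_blocks * bytes_per_elem
```

For batch 2 and 512 tokens with 64-token blocks, a dense mask needs 524,288 bytes while the bounds need 256 bytes; at 32,768 tokens the dense mask grows to 2 GiB and the bounds to 16 KiB.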

lower_bounds

First KV block index each query block attends to, shape (batch, 1, num_q_blocks). Defines the start of the attention range.

Type

Optional[jax.typing.ArrayLike]

upper_bounds

Exclusive end of the attention range: one past the last KV block index each query block attends to, shape (batch, 1, num_q_blocks).

Type

Optional[jax.typing.ArrayLike]

lower_full_bounds

First KV block index that is fully attended (no partial masking), shape (batch, 1, num_q_blocks). Blocks from here up to upper_full_bounds can skip per-token masking.

Type

Optional[jax.typing.ArrayLike]

upper_full_bounds

One past the last fully attended KV block index, shape (batch, 1, num_q_blocks). Together with lower_full_bounds, it marks the range that needs no per-token masking.

Type

Optional[jax.typing.ArrayLike]

The bounds define half-open intervals [lower, upper) for each query block. Blocks outside these bounds are completely masked out and never computed. Blocks in [lower_full_bounds, upper_full_bounds) can be processed without per-token masking logic; the remaining blocks inside [lower, upper) still apply the per-token mask.
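The interval semantics can be illustrated by labelling each (query block, KV block) pair. The helper below is a hypothetical sketch that assumes [lower_full, upper_full) lies inside [lower, upper); a real kernel would loop only over [lower, upper) instead of testing every KV block.

```python
from typing import List


def plan_kv_blocks(
    lower: List[int], upper: List[int],
    lower_full: List[int], upper_full: List[int],
    num_kv_blocks: int,
) -> List[List[str]]:
    """Label every KV block as 'skip', 'full' or 'masked' for each query block."""
    plan = []
    for lo, hi, flo, fhi in zip(lower, upper, lower_full, upper_full):
        row = []
        for kb in range(num_kv_blocks):
            if kb < lo or kb >= hi:
                row.append("skip")      # outside [lower, upper): never computed
            elif flo <= kb < fhi:
                row.append("full")      # fully attended: no per-token mask
            else:
                row.append("masked")    # partially attended: apply the mask
        plan.append(row)
    return plan
```

With the causal bounds for four equal-sized blocks (upper `[1, 2, 3, 4]`, full upper `[0, 1, 2, 3]`), 6 of the 16 block pairs are skipped, 6 run without masking, and only the 4 diagonal pairs need the per-token mask.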

Example

A causal mask for a sequence divided into 4 equal-sized query and KV blocks (token positions increasing from 0) might have:
  • Block 0: [0, 1) with [0, 0) fully attended (empty; the diagonal block needs causal masking)

  • Block 1: [0, 2) with [0, 1) fully attended

  • Block 2: [0, 3) with [0, 2) fully attended

  • Block 3: [0, 4) with [0, 3) fully attended

classmethod from_inputs(q_positions: ArrayLike, q_segment_ids: ArrayLike, kv_positions: ArrayLike, kv_segment_ids: ArrayLike, kv_blocksize: int, q_blocksize: int, calculate_dkdv_mask: bool = False, causal: bool = True, window_left: int = -1, window_right: int = -1, mesh: jax.sharding.Mesh | None = None)

Create a SparseMask from query and key-value positions and segments.

This factory method generates a sparse attention mask by analyzing the positions and segment IDs of queries and keys/values. It automatically determines which blocks should attend to each other based on the specified attention pattern (causal, windowed, etc.).

Parameters
  • q_positions – Query token positions, shape (batch, q_seq_len).

  • q_segment_ids – Query segment IDs for packed sequences, shape (batch, q_seq_len).

  • kv_positions – Key/value token positions, shape (batch, kv_seq_len).

  • kv_segment_ids – Key/value segment IDs, shape (batch, kv_seq_len).

  • kv_blocksize – Size of key/value blocks in tokens.

  • q_blocksize – Size of query blocks in tokens.

  • calculate_dkdv_mask – If True, compute mask for gradient computation with respect to keys/values (backward pass).

  • causal – If True, apply causal masking (lower triangular).

  • window_left – Left window size for sliding window attention (-1 for unlimited).

  • window_right – Right window size for sliding window attention (-1 for unlimited).

  • mesh – Optional device mesh for distributed computation.
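When several variable-length sequences are packed into one batch row, each token needs a position and a segment ID. The helper below is hypothetical (it is not part of ejkernel) and assumes the convention that positions restart at 0 in every segment; its output could be wrapped with `jnp.asarray(...)[None]` to obtain the (batch, seq_len) shape these parameters expect.

```python
from typing import List, Tuple


def pack_sequences(lengths: List[int]) -> Tuple[List[int], List[int]]:
    """Positions and segment IDs for sequences packed into one row.

    Assumed convention: positions restart at 0 per segment and each
    sequence gets its own integer segment ID.
    """
    positions: List[int] = []
    segment_ids: List[int] = []
    for seg_id, length in enumerate(lengths):
        positions.extend(range(length))          # 0, 1, ..., length - 1
        segment_ids.extend([seg_id] * length)    # same ID for the whole sequence
    return positions, segment_ids
```

Packing lengths `[3, 5]` gives positions `[0, 1, 2, 0, 1, 2, 3, 4]` and segment IDs `[0, 0, 0, 1, 1, 1, 1, 1]`, so tokens of the second sequence never attend to the first.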

Returns

A SparseMask instance with computed boundary arrays.

Example

>>> import jax.numpy as jnp
>>> from ejkernel.kernels._triton.blocksparse_attention._mask import SparseMask
>>> batch, seq_len = 2, 256
>>> positions = jnp.arange(seq_len).reshape(1, -1).repeat(batch, 0)
>>> segments = jnp.zeros((batch, seq_len), dtype=jnp.int32)
>>>
>>> mask = SparseMask.from_inputs(
...     positions, segments, positions, segments,
...     kv_blocksize=64, q_blocksize=64, causal=True
... )
from_tuple()
items() – a set-like object providing a view on D's items
keys() – a set-like object providing a view on D's keys
lower_bounds: Optional[ArrayLike]
lower_full_bounds: Optional[ArrayLike]
replace(**kwargs)
to_tuple()
upper_bounds: Optional[ArrayLike]
upper_full_bounds: Optional[ArrayLike]
values() – an object providing a view on D's values
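SparseMask combines a dataclass with the read-only Mapping interface (keys, items, values) plus to_tuple, from_tuple, and replace. The class below is a hypothetical minimal stand-in written only to illustrate that pattern with a frozen dataclass; it is not ejkernel's implementation, and the signatures given to `from_tuple` and `replace` here are assumptions (the listing above shows `from_tuple()` without parameters).

```python
import dataclasses
from collections.abc import Iterator, Mapping
from typing import Any, Optional, Tuple


@dataclasses.dataclass(frozen=True)
class ToySparseMask(Mapping):
    """Hypothetical stand-in illustrating a Mapping-style dataclass."""

    lower_bounds: Optional[Any] = None
    upper_bounds: Optional[Any] = None
    lower_full_bounds: Optional[Any] = None
    upper_full_bounds: Optional[Any] = None

    @classmethod
    def _names(cls) -> Tuple[str, ...]:
        # Field names in declaration order double as the mapping keys.
        return tuple(f.name for f in dataclasses.fields(cls))

    # Mapping protocol: look values up by field name.
    def __getitem__(self, key: str) -> Any:
        if key not in self._names():
            raise KeyError(key)
        return getattr(self, key)

    def __iter__(self) -> Iterator[str]:
        return iter(self._names())

    def __len__(self) -> int:
        return len(self._names())

    def to_tuple(self) -> tuple:
        return tuple(getattr(self, n) for n in self._names())

    @classmethod
    def from_tuple(cls, values: tuple) -> "ToySparseMask":
        return cls(*values)

    def replace(self, **kwargs: Any) -> "ToySparseMask":
        # Frozen dataclass: return a modified copy instead of mutating.
        return dataclasses.replace(self, **kwargs)
```

Because the object is a Mapping with string keys, it can be converted with `dict(...)` or unpacked with `**` into a function whose keyword names match the four field names.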
ejkernel.kernels._triton.blocksparse_attention._mask.create_sparsity_mask(q_positions: ArrayLike, q_segment_ids: ArrayLike, kv_positions: ArrayLike, kv_segment_ids: ArrayLike, mesh: jax.sharding.Mesh | None = None, kv_blocksize: int = 64, q_blocksize: int = 64, causal: bool = True, window_left: int = -1, window_right: int = -1) → tuple[SparseMask, ...]

Creates attention masks for forward and (optionally) backward block-sparse attention kernels.

This function generates the required attention masks based on the query and key-value (KV) positions and segment IDs. The masks are used for both the forward and backward passes in flash attention to improve computational efficiency while respecting segment boundaries.

Parameters
  • q_positions (ArrayLike) – The positions of the query tokens of shape: (batch_size, query_seq_length).

  • q_segment_ids (ArrayLike) – Segment IDs for query tokens of shape: (batch_size, query_seq_length).

  • kv_positions (ArrayLike) – The positions of the key and value tokens of shape: (batch_size, kv_seq_length).

  • kv_segment_ids (ArrayLike) – Segment IDs for key and value tokens of shape: (batch_size, kv_seq_length).

  • mesh (Mesh | None, optional) – Device mesh configuration for distributed execution. If None, the mesh is taken from the global context. Defaults to None.

  • kv_blocksize (int, optional) – Size of key/value blocks in tokens. Defaults to 64.

  • q_blocksize (int, optional) – Size of query blocks in tokens. Defaults to 64.

  • causal (bool, optional) – If True, apply causal masking. Defaults to True.

  • window_left (int, optional) – Left window size for sliding window attention (-1 for unlimited). Defaults to -1.

  • window_right (int, optional) – Right window size for sliding window attention (-1 for unlimited). Defaults to -1.

Returns

A tuple containing:
  • The forward attention mask.

  • (Optional) The backward mask for dquery, when backward masks are computed.

  • (Optional) The backward mask for dkey and dvalue, when backward masks are computed.

Return type

tuple[SparseMask, …]

Notes

  • When backward masks are computed, the masks for dquery, dkey, and dvalue follow the forward mask in the returned tuple.
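Backward kernels typically accumulate dK and dV by iterating over KV blocks and visiting the query blocks that attend to each one, so a dK/dV mask can be viewed as the "transpose" of the forward ranges. The sketch below derives such transposed ranges from forward [lower, upper) bounds to illustrate the idea behind `calculate_dkdv_mask`; it assumes the attending query blocks of each KV block form one contiguous run (true for causal and sliding-window patterns) and is not the library's code.

```python
from typing import List, Tuple


def transpose_bounds(
    lower: List[int], upper: List[int], num_kv_blocks: int
) -> Tuple[List[int], List[int]]:
    """Per-KV-block [first, last + 1) range of query blocks attending to it."""
    kv_lower, kv_upper = [], []
    for kb in range(num_kv_blocks):
        # Query blocks whose forward range [lo, hi) contains this KV block.
        q_blocks = [qb for qb, (lo, hi) in enumerate(zip(lower, upper)) if lo <= kb < hi]
        if q_blocks:
            kv_lower.append(q_blocks[0])
            kv_upper.append(q_blocks[-1] + 1)
        else:
            kv_lower.append(0)   # empty range encoded as (0, 0)
            kv_upper.append(0)
    return kv_lower, kv_upper
```

For four causal blocks (forward upper bounds `[1, 2, 3, 4]`), KV block k is used by query blocks k through 3, giving transposed ranges with lower `[0, 1, 2, 3]` and upper `[4, 4, 4, 4]`.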

ejkernel.kernels._triton.blocksparse_attention._mask.define_sparse_mask_fn(q_positions: ArrayLike, q_segment_ids: ArrayLike, kv_positions: ArrayLike, kv_segment_ids: ArrayLike, kv_blocksize: int, q_blocksize: int, calculate_dkdv_mask: bool = False, causal: bool = True, window_left: int = -1, window_right: int = -1) → SparseMask