ejkernel.types.mask

Attention Mask Management for JAX/Flax Models.

This module provides comprehensive tools for creating, manipulating, and converting attention masks in transformer models. It supports various attention patterns and provides efficient conversions between different mask representations.

Key Components:
  • MaskInfo: Main dataclass for managing attention masks and segment IDs

  • Conversion functions: Convert between masks and segment IDs

  • Attention patterns: Causal, sliding window, chunked, token-type-based

  • Distributed support: Sharding specifications for multi-device training

  • Visualization: Debug and understand attention patterns

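Common Usage:
>>> import jax.numpy as jnp
>>> from ejkernel.types.mask import MaskInfo
>>> # Build from segment IDs (-1 marks padding)
>>> segment_ids = jnp.array([[0, 0, 1, 1, -1]])
>>> mask_info = MaskInfo.from_segments(segment_ids)
>>> # Build from an existing attention mask (True = may attend)
>>> attention_mask = jnp.ones((1, 1, 5, 5), dtype=jnp.bool_)
>>> mask_info = MaskInfo.from_attention_mask(attention_mask)
>>> # Add a causal constraint (returns a new MaskInfo)
>>> causal_mask_info = mask_info.apply_causal()
>>> # Additive attention bias (0.0 where allowed, dtype min where masked)
>>> bias = mask_info.bias

Mask Representations: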
  1. Attention Mask: 4D boolean/int array of shape (batch, heads or 1, q_len, kv_len). True/1 = valid attention, False/0 = masked.

  2. Segment IDs: 2D int32 arrays of shape (batch, seq_len). Non-negative values give segment membership; -1 marks padding tokens.

Debug Mode:

Enable debug tracing to see which functions are being called:

>>> import os
>>> os.environ['EJKERNEL_MASK_DEBUG'] = '1'  # Enable debug mode
>>> # Now all MaskInfo operations will print debug traces
>>> mask_info = MaskInfo.from_segments(segment_ids)
[MaskInfo Debug] Calling type.from_segments()

Set EJKERNEL_MASK_DEBUG to '0' or remove it to disable debug output.

See also

  • mask_to_segment_ids(): Convert masks to segment IDs

  • segment_ids_to_mask(): Convert segment IDs to masks

  • MaskInfo: Main class for mask management

class ejkernel.types.mask.MaskInfo(_attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch nheads_or_1 q k'] | jaxtyping.Int[jaxlib._jax.Array, 'batch nheads_or_1 q k'] | None = None, _q_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch q'] | None = None, _kv_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch k'] | None = None, _cu_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, _cu_seqlens_kv: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, q_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None = None, kv_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None = None, causal_mask_baked_in: bool = False, sliding_window_baked_in: bool = False, chunked_mask_baked_in: bool = False, token_type_ids_baked_in: bool = False, batch_axis_name: tuple[str] | str | None = ('dp', 'fsdp'), qheads_axis_name: tuple[str] | str | None = 'tp', kvheads_axis_name: tuple[str] | str | None = 'tp', sequence_axis_name: tuple[str] | str | None = 'sp')

Bases: object

Container for attention mask information with utilities for conversion and manipulation.

This dataclass holds both attention masks and their corresponding segment IDs, along with optional position indices for queries and keys/values. It provides methods for converting between representations and for extracting derived information.

attention_mask

The 2D/3D/4D boolean or integer attention mask

q_segment_ids

Query segment IDs (batch, qlen) where -1 indicates padding

kv_segment_ids

Key-value segment IDs (batch, kvlen) where -1 indicates padding

cu_seqlens_q

Cumulative sequence lengths for queries (batch+1,)

cu_seqlens_kv

Cumulative sequence lengths for keys/values (batch+1,)

q_positions

Query position indices (batch, qlen) for positional embeddings

Type

jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None

kv_positions

Key-value position indices (batch, kvlen) for positional embeddings

Type

jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None

causal_mask_baked_in

Flag indicating if causal masking has been applied

Type

bool

sliding_window_baked_in

Flag indicating if sliding window masking has been applied

Type

bool

chunked_mask_baked_in

Flag indicating if chunked masking has been applied

Type

bool

token_type_ids_baked_in

Flag indicating if token type ID masking has been applied

Type

bool

apply_causal(offset: int | jaxtyping.Int[jaxlib._jax.Array, 'batch'] = 0) → MaskInfo

Apply causal (autoregressive) masking to the attention pattern.

Restricts attention so that each query position can only attend to key positions at or before its own position (plus an optional offset). The segment IDs are preserved to maintain grouping structure.

Parameters

offset – Position offset for causal masking (default: 0). Can be:

    • int: scalar offset applied to all batch elements

    • Array of shape (batch,): per-batch offsets

  The offset shifts the causal boundary:

    • offset=0: standard causal (q_i attends to kv_j where j <= i)

    • offset>0: also allows attending to future tokens (j <= i + offset)

    • offset<0: more restrictive causal (j <= i + offset)

Returns

New MaskInfo with causal constraint applied while preserving segment IDs

Raises

ValueError – If mask dimensions are unknown

Example

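>>> import jax.numpy as jnp
>>> from ejkernel.types.mask import MaskInfo
>>> segment_ids = jnp.array([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])  # batch of 3
>>> mask_info = MaskInfo.from_segments(segment_ids)
>>> causal_mask = mask_info.apply_causal()
>>> # Per-batch offsets: one entry per batch element, shape (batch,)
>>> offsets = jnp.array([0, 1, 2])
>>> causal_mask = mask_info.apply_causal(offset=offsets)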
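The offset rule above reduces to a single broadcast comparison. The following sketch builds that predicate for a scalar or per-batch offset. It uses NumPy so it runs standalone; jax.numpy provides the same calls. causal_mask is a hypothetical helper, not part of ejkernel. Intersecting it with the existing mask via logical AND is an assumption about how apply_causal combines the two.

```python
import numpy as np


def causal_mask(q_len, kv_len, offset=0):
    """Hypothetical helper: boolean (batch, 1, q_len, kv_len) causal mask.

    Query i may attend key j iff j <= i + offset. `offset` is a scalar or a
    per-batch array of shape (batch,); a scalar yields a batch dimension of 1.
    """
    offset = np.atleast_1d(np.asarray(offset))           # (batch,)
    q_idx = np.arange(q_len)[None, None, :, None]         # (1, 1, q_len, 1)
    kv_idx = np.arange(kv_len)[None, None, None, :]       # (1, 1, 1, kv_len)
    return kv_idx <= q_idx + offset[:, None, None, None]  # (batch, 1, q_len, kv_len)


# Per-batch offsets 0, 1, 2 for a batch of three length-4 sequences
base = np.ones((3, 1, 4, 4), dtype=bool)  # e.g. a single-segment mask
mask = base & causal_mask(4, 4, offset=np.array([0, 1, 2]))
print(mask[0, 0].astype(int))  # offset 0: lower-triangular
print(mask[1, 0].astype(int))  # offset 1: one extra future token per row
```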
apply_chunked(chunk_size: int, offset: int = 0) → MaskInfo

Apply chunked causal attention and always replace the q/kv segment IDs with chunk IDs.

  • New segment IDs are the chunk indices + 1 (padding stays -1).

  • Attention mask becomes: existing_mask AND (same_chunk AND causal).

  • This makes segment IDs the canonical representation of chunk structure.

Note: segment IDs encode only the chunk grouping. The causal direction within a chunk is not captured by them and must still be enforced from positions or an explicit causal rule.

Parameters
  • chunk_size – Positive chunk size.

  • offset – Optional causal offset (default 0).

Returns

New MaskInfo with updated attention_mask and updated segment IDs.
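The following minimal NumPy sketch implements the two rules listed above. It assumes chunk indices come from the absolute token index (position // chunk_size) and that offset shifts only the causal term. chunked_causal_mask and chunk_segment_ids are hypothetical helpers, not the library's code.

```python
import numpy as np


def chunked_causal_mask(chunk_size, q_len, kv_len=None, offset=0):
    """Hypothetical helper: (q_len, kv_len) mask = same_chunk AND causal."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    kv_len = q_len if kv_len is None else kv_len
    q_idx = np.arange(q_len)[:, None]
    kv_idx = np.arange(kv_len)[None, :]
    same_chunk = (q_idx // chunk_size) == (kv_idx // chunk_size)
    causal = kv_idx <= q_idx + offset
    return same_chunk & causal


def chunk_segment_ids(valid, chunk_size):
    """Hypothetical helper: chunk index + 1 for valid tokens, -1 for padding."""
    positions = np.arange(valid.shape[-1])
    return np.where(valid, positions // chunk_size + 1, -1).astype(np.int32)


mask = chunked_causal_mask(chunk_size=4, q_len=8)
valid = np.array([[True] * 6 + [False] * 2])
seg = chunk_segment_ids(valid, chunk_size=4)  # [[1, 1, 1, 1, 2, 2, -1, -1]]
```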

apply_kv_lengths(kv_lengths: Int[jaxlib._jax.Array, 'batch'], *, q_len: int | None = None, sliding_window: int | None = None) → MaskInfo

Mask out key/value positions beyond per-example lengths and keep a trailing query window.

The method expects a 4D attention mask (batch, heads, q_len, kv_len). For each batch item:
  1. KV positions with index >= kv_lengths[b] are masked out.

  2. The query dimension is sliced to the q_len rows that end at kv_lengths[b], i.e. rows kv_lengths[b] - q_len through kv_lengths[b] - 1.

  3. If sliding_window is provided and smaller than the current KV dimension, only the most recent sliding_window columns are kept.

Segment IDs and position arrays are reused unchanged; only the materialized attention mask is updated.

Parameters
  • kv_lengths – Integer array of shape (batch,) with the number of valid KV tokens per batch element. The implementation assumes kv_lengths[b] >= q_len and does not clamp indices.

  • q_len – Number of query rows to keep. Must be specified and should be <= kv_lengths[b] for all b.

  • sliding_window – Optional maximum number of KV columns to retain after masking. If None, keeps all remaining KV positions.

Returns

New MaskInfo whose attention_mask has shape (batch, 1, q_len, effective_kv_len), where effective_kv_len equals kv_len when sliding_window is None and otherwise min(kv_len, sliding_window).

Raises

ValueError – If the attention mask cannot be materialized.

apply_sliding_window(sliding_window: int | tuple[int, int], *, offset: int | jaxtyping.Int[jaxlib._jax.Array, 'batch'] = 0, mode: Optional[Literal['default', 'decode', 'prefill']] = None, index: int | jaxtyping.Int[jaxlib._jax.Array, 'batch'] | None = None) → MaskInfo

Apply sliding window attention to the attention pattern.

Restricts attention so that each query position can only attend to key positions within a specified window around its position. This is useful for local attention patterns where distant tokens are masked out.

Parameters
  • sliding_window – Window size specification:

    • int: symmetric window (same size left and right)

    • tuple[int, int]: (left_window, right_window) for asymmetric windows, where left_window is how many positions to the left can be attended to and right_window is how many positions to the right can be attended to

  • offset – Row offset for the sliding window calculation (default: 0). Can be:

    • int: scalar offset applied to all batch elements

    • Array of shape (batch,): per-batch offsets

  • mode – Attention mode for dynamic slicing:

    • "default" or None: standard sliding window without slicing

    • "decode": slices KV to the window around the current index

    • "prefill": slices to the last sliding_window positions

  • index – Current position index (required for "decode" mode). Can be:

    • int: scalar index applied to all batch elements

    • Array of shape (batch,): per-batch indices

Returns

New MaskInfo with sliding window constraint applied while preserving segment IDs

Raises

ValueError – If mask dimensions are unknown, window size is invalid, or index is missing in decode mode

Example

>>> import jax.numpy as jnp
>>> from ejkernel.types.mask import MaskInfo
>>> segment_ids = jnp.ones((3, 8), dtype=jnp.int32)  # batch of 3, one segment of length 8
>>> mask_info = MaskInfo.from_segments(segment_ids)
>>> # Symmetric window: each position attends to 2 positions left and 2 right
>>> windowed_mask = mask_info.apply_sliding_window(2)
>>> # Asymmetric window: 3 left, 1 right
>>> windowed_mask = mask_info.apply_sliding_window((3, 1))
>>> # Decode mode at position 5 with window size 3
>>> decode_mask = mask_info.apply_sliding_window(3, mode="decode", index=5)
>>> # Decode mode with per-batch indices, shape (batch,)
>>> batch_indices = jnp.array([5, 7, 3])
>>> decode_mask = mask_info.apply_sliding_window(3, mode="decode", index=batch_indices)
>>> # Prefill mode with window size 4
>>> prefill_mask = mask_info.apply_sliding_window(4, mode="prefill")
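For reference, the plain (mode=None) window predicate can be sketched in NumPy as below. The sketch assumes offset is added to the query row index and that window bounds are inclusive, consistent with "2 positions left and 2 right" above. sliding_window_mask is a hypothetical helper, and the decode/prefill slicing is not modeled.

```python
import numpy as np


def sliding_window_mask(q_len, kv_len, window, offset=0):
    """Hypothetical helper: query i may attend key j iff
    (i + offset) - left <= j <= (i + offset) + right."""
    left, right = (window, window) if isinstance(window, int) else window
    if left < 0 or right < 0:
        raise ValueError("window sizes must be non-negative")
    q_pos = np.arange(q_len)[:, None] + offset
    kv_pos = np.arange(kv_len)[None, :]
    return (kv_pos >= q_pos - left) & (kv_pos <= q_pos + right)


sym = sliding_window_mask(6, 6, 2)                # 2 left, 2 right
asym = sliding_window_mask(6, 6, (3, 1))          # 3 left, 1 right
causal_local = sliding_window_mask(6, 6, (2, 0))  # backward-looking window only
```

Setting right_window to 0 gives the common "local causal" pattern, since no query can see past its own position.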
apply_token_type_ids(token_type_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch q_len'] | tuple[jaxtyping.Int[jaxlib._jax.Array, 'batch q_len'], jaxtyping.Int[jaxlib._jax.Array, 'batch kv_len']], *, combine: Literal['union', 'intersect', 'replace'] = 'union', zero_policy: Literal['q', 'kv', 'both', 'none'] = 'q', update_segment_ids: bool | None = None) → MaskInfo

Integrate token_type_ids into the attention pattern.

  • Builds an equality mask between q and kv token types.

  • Optionally treats token_type_id == 0 as "disabled" (no token-type matching) on the query side, kv side, both, or neither (zero_policy).

  • Combines the result with the current attention mask by union, intersection, or replacement (combine).

  • Optionally updates segment IDs to reflect token types (0 -> -1 padding).

Parameters
  • token_type_ids – Token type IDs:

    • self-attention: (batch, q_len)

    • cross-attention: (q_token_type_ids, kv_token_type_ids)

  • combine – How to combine with the existing mask:

    • "union": base_mask OR token_type_mask

    • "intersect": base_mask AND token_type_mask

    • "replace": token_type_mask only

  • zero_policy – Which side treats token type 0 as disabled:

    • "q" (default): treat q == 0 as disabled (no token-type matching for those queries)

    • "kv": treat kv == 0 as disabled (no matching into those keys/values)

    • "both": treat 0 as disabled on both sides

    • "none": don't treat 0 specially

  • update_segment_ids – Whether to rewrite segment IDs:

    • If None: defaults to False for "union" (a union cannot be encoded in segment IDs) and True for "intersect"/"replace".

    • If True: set q/kv segment IDs from token types with 0 -> -1.

    • If False: keep existing segment IDs.

Returns

New MaskInfo with updated attention_mask (and optionally updated segment IDs).
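The token-type steps above can be sketched in NumPy for the self-attention case, with zero_policy and combine handled explicitly. token_type_mask and combine_masks are hypothetical helpers. The causal base mask in the example is an illustrative choice, not something apply_token_type_ids requires.

```python
import numpy as np


def token_type_mask(q_types, kv_types, zero_policy="q"):
    """Hypothetical helper: (batch, 1, q_len, kv_len) token-type equality mask."""
    q_types, kv_types = np.asarray(q_types), np.asarray(kv_types)
    eq = q_types[:, :, None] == kv_types[:, None, :]
    if zero_policy in ("q", "both"):
        eq &= (q_types != 0)[:, :, None]   # type-0 queries match nothing
    if zero_policy in ("kv", "both"):
        eq &= (kv_types != 0)[:, None, :]  # nothing matches type-0 keys
    return eq[:, None]


def combine_masks(base, tt_mask, combine="union"):
    """Hypothetical helper mirroring the three `combine` modes."""
    if combine == "union":
        return base | tt_mask
    if combine == "intersect":
        return base & tt_mask
    if combine == "replace":
        return tt_mask
    raise ValueError(f"unknown combine mode: {combine}")


# Tokens 1 and 2 share type 1; tokens 0 and 3 are type 0
types = np.array([[0, 1, 1, 0]])
causal = np.tril(np.ones((4, 4), dtype=bool))[None, None]
mask = combine_masks(causal, token_type_mask(types, types), "union")
new_segment_ids = np.where(types == 0, -1, types)  # update_segment_ids=True rule
```

With "union", the type-1 span becomes bidirectional while type-0 tokens keep only the base (here causal) pattern.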

property attention_mask: jax.jaxlib._jax.Array | None
property baked_in_masks: dict[str, bool]

Get a dictionary of all baked-in mask operation flags.

Returns

Dictionary mapping operation names to their baked-in status:

  • 'causal': whether causal masking has been applied

  • 'sliding_window': whether sliding window masking has been applied

  • 'chunked': whether chunked masking has been applied

  • 'token_type_ids': whether token type ID masking has been applied

Return type

dict[str, bool]

Example

>>> mask_info = MaskInfo.from_segments(jnp.array([[1, 1, 1]]))
>>> mask_info = mask_info.apply_causal()
>>> mask_info.baked_in_masks
{'causal': True, 'sliding_window': False, 'chunked': False, 'token_type_ids': False}
batch_axis_name: tuple[str] | str | None = ('dp', 'fsdp')
property batch_size: int | None

Get batch size from available data.

Infers the batch dimension from either segment IDs or attention mask.

Returns

Batch size if available, None otherwise

property bias

Create attention bias from the mask (convenience property).

Returns an attention bias tensor where valid attention positions are 0.0 and masked positions are set to the minimum float value for the dtype.

Returns

Attention bias array with dtype float32

causal_mask_baked_in: bool = False
chunked_mask_baked_in: bool = False
create_bias(dtype: ~numpy.dtype = <class 'jax.numpy.float32'>) → Array

Create attention bias from the mask.

Converts the boolean attention mask into an additive bias tensor suitable for attention score computation. Valid positions (mask=True) get 0.0, while masked positions (mask=False) get a large negative value (dtype.min).
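The following minimal NumPy sketch shows this conversion, plus a softmax demonstrating why dtype.min works as a mask: the masked column ends up with zero weight. mask_to_bias is a hypothetical helper, not the library's implementation.

```python
import numpy as np


def mask_to_bias(mask, dtype=np.float32):
    """Hypothetical helper: 0.0 where mask is True, finfo(dtype).min elsewhere."""
    mask = np.asarray(mask, dtype=bool)
    return np.where(mask, 0.0, np.finfo(dtype).min).astype(dtype)


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerically stable softmax
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


mask = np.array([[True, True, False]])
scores = np.array([[0.5, 1.0, 3.0]], dtype=np.float32)
bias = mask_to_bias(mask)
weights = softmax(scores + bias)  # masked column gets zero weight
```

Using finfo(dtype).min instead of -inf has a practical upside: if every key in a row is masked, the softmax degrades to uniform weights rather than producing NaN from -inf - (-inf).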

Parameters

dtype – Output dtype for the bias tensor. Default: jnp.float32

Returns

Attention bias array where:

  • Valid attention positions: 0.0

  • Masked positions: jnp.finfo(dtype).min

Return type

Array

Example

>>> mask_info = MaskInfo.from_segments(jnp.array([[1, 1, 2, 2]]))
>>> bias = mask_info.create_bias(dtype=jnp.float32)
static create_chunked_attention_mask(chunk_size: int, q_len: int, kv_len: int | None = None, offset: int = 0, dtype=<class 'jax.numpy.bool'>) → Array

Create a chunked causal attention mask (static method).

Generates a 2D attention mask where attention is restricted to tokens within the same chunk, with causal ordering enforced within chunks.

Parameters
  • chunk_size – Size of each chunk (must be positive)

  • q_len – Query sequence length

  • kv_len – Key-value sequence length. If None, uses q_len

  • offset – Causal offset. Default: 0

  • dtype – Output dtype. Default: jnp.bool_

Returns

2D attention mask of shape (q_len, kv_len) with chunked causal pattern

Raises

ValueError – If chunk_size is not positive

Example

>>> mask = MaskInfo.create_chunked_attention_mask(
...     chunk_size=4, q_len=8, kv_len=8
... )
>>> mask.shape
(8, 8)
property cu_seqlens_kv: jax.jaxlib._jax.Array | None
property cu_seqlens_q: jax.jaxlib._jax.Array | None
classmethod dynamic_init(*, mask_info: ejkernel.types.mask.MaskInfo | None = None, input_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seqlen'] | None = None, inputs_embeds: jaxtyping.Float[jaxlib._jax.Array, 'batch seqlen dim'] | None = None, attention_mask: jaxtyping.Int[jaxlib._jax.Array, 'batch seqlen'] | jaxtyping.Bool[jaxlib._jax.Array, 'batch seqlen'] | None = None) → MaskInfo

Dynamically initialize a MaskInfo from various input sources.

This is a convenience factory method that creates a MaskInfo instance from different types of inputs commonly available in transformer models. It prioritizes existing mask_info, then constructs one from attention_mask or input shapes.

Parameters
  • mask_info – Pre-existing MaskInfo to return as-is. If provided, other arguments are ignored.

  • input_ids – Token IDs array with shape (batch, seq_len). Used to infer shape if mask_info and attention_mask are not provided.

  • inputs_embeds – Token embeddings array with shape (batch, seq_len, dim). Used to infer shape if mask_info, attention_mask, and input_ids are not provided.

  • attention_mask – Attention mask array with shape (batch, seq_len). Values should be:

    • 1/True for valid (non-padding) tokens

    • 0/False for padding tokens

    If not provided, an all-ones mask (no padding) is created.

Returns

MaskInfo instance constructed from the provided inputs

Raises

ValueError – If insufficient information is provided (no valid inputs)

Example

>>> input_ids = jnp.array([[1, 2, 3, 0], [4, 5, 0, 0]])
>>> attn_mask = jnp.array([[1, 1, 1, 0], [1, 1, 0, 0]])
>>> mask_info = MaskInfo.dynamic_init(input_ids=input_ids, attention_mask=attn_mask)
>>> mask_info.shape
(2, 4, 4)

Notes

  • This method is useful for model implementations where mask format may vary

  • Automatically converts 2D attention masks to segment-based representation

  • Higher-dimensional masks are handled via from_attention_mask()

classmethod from_attention_mask(attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch nheads_or_1 qlen kvlen'] | jaxtyping.Int[jaxlib._jax.Array, 'batch nheads_or_1 qlen kvlen'], q_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None = None, kv_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None = None, *, mask_is_valid: bool = True, batch_axis_name: tuple[str] | str | None = ('dp', 'fsdp'), qheads_axis_name: tuple[str] | str | None = 'tp', kvheads_axis_name: tuple[str] | str | None = 'tp', sequence_axis_name: tuple[str] | str | None = 'sp') → MaskInfo

Create MaskInfo from an existing attention mask.

For 2D masks this treats the input as a padding mask (1/True = valid, 0/False = padding), converts it to segment IDs (0 for valid tokens, -1 for padding), and materializes a broadcasted 4D pairwise mask.

For 3D/4D masks, padding-style segment IDs are extracted by checking which Q positions attend to at least one KV position (and vice versa). This gives valid/padding information without attempting to recover full segment structure.
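The 2D and 3D/4D conversions described above can be sketched in NumPy as follows. Both helpers are hypothetical. Building the 2D case's pairwise mask as valid_q AND valid_kv is an assumption that follows from every valid token sharing segment ID 0.

```python
import numpy as np


def padding_mask_to_segment_ids(mask2d):
    """Hypothetical helper: 2D padding mask (1/True = valid) -> IDs (0 valid, -1 padding)."""
    return np.where(np.asarray(mask2d).astype(bool), 0, -1).astype(np.int32)


def padding_segment_ids_from_pairwise(mask):
    """Hypothetical helper: 3D/4D pairwise mask -> (q_segment_ids, kv_segment_ids).

    A query is valid if it attends to at least one key (in any head);
    a key is valid if at least one query attends to it.
    """
    m = np.asarray(mask).astype(bool)
    if m.ndim == 3:                # (batch, q, kv) -> (batch, 1, q, kv)
        m = m[:, None]
    q_valid = m.any(axis=(1, 3))   # (batch, q_len)
    kv_valid = m.any(axis=(1, 2))  # (batch, kv_len)
    return np.where(q_valid, 0, -1), np.where(kv_valid, 0, -1)


seg = padding_mask_to_segment_ids([[1, 1, 0]])
# Broadcast to a (batch, 1, q_len, kv_len) pairwise mask: both tokens must be valid
pairwise = (seg[:, None, :, None] >= 0) & (seg[:, None, None, :] >= 0)
q_seg, kv_seg = padding_segment_ids_from_pairwise(
    np.array([[[1, 1, 0], [1, 1, 0], [0, 0, 0]]])
)
```

As the description notes, the 3D/4D path recovers only valid/padding information; two packed sequences in one row would both come back as segment 0.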

Parameters
  • attention_mask – Attention mask array. Supported shapes:

    • (batch, seqlen): 2D padding mask (token mask)

    • (batch, qlen, kvlen): 3D batched mask

    • (batch, heads, qlen, kvlen): 4D multi-head mask

    Values: True/1 = valid attention, False/0 = masked (unless mask_is_valid=False).

  • q_positions – Optional query position indices (batch, qlen)

  • kv_positions – Optional key-value position indices (batch, kvlen)

  • mask_is_valid – If False, treats True/1 entries as masked-out (disallowed) positions and inverts the mask (PyTorch-style attn_mask semantics).

Returns

MaskInfo with derived segment IDs, original attention mask, and optional positions

Raises

ValueError – If attention_mask is not 2D, 3D, or 4D

Example

>>> mask = jnp.array([[[1, 1, 0], [1, 1, 0], [0, 0, 1]]])
>>> mask_info = MaskInfo.from_attention_mask(mask)
>>> mask_info.q_segment_ids.shape
(1, 3)
classmethod from_cu_seqlens(cu_seqlens_q: Int[jaxlib._jax.Array, 'batch_plus_one'], *, max_q_len: int, cu_seqlens_kv: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, max_kv_len: int | None = None, q_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None = None, kv_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None = None, batch_axis_name: tuple[str] | str | None = ('dp', 'fsdp'), qheads_axis_name: tuple[str] | str | None = 'tp', kvheads_axis_name: tuple[str] | str | None = 'tp', sequence_axis_name: tuple[str] | str | None = 'sp') → MaskInfo

Create a padding-style MaskInfo from cumulative sequence lengths.

This reconstructs 2D padding masks (valid tokens are a prefix) and stores a compact padding-style segment-id representation (0 for valid tokens, -1 for padding). The pairwise attention mask can be materialized on demand via attention_mask.
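Because valid tokens form a prefix, the reconstruction only needs the per-example lengths np.diff(cu_seqlens). Here is a NumPy sketch under that assumption. cu_seqlens_to_padding_segment_ids is a hypothetical helper.

```python
import numpy as np


def cu_seqlens_to_padding_segment_ids(cu_seqlens, max_len):
    """Hypothetical helper: (batch + 1,) cumulative lengths -> (batch, max_len) IDs.

    Example b has cu_seqlens[b + 1] - cu_seqlens[b] valid tokens, laid out as a
    prefix of its row (ID 0) and followed by padding (ID -1).
    """
    lengths = np.diff(np.asarray(cu_seqlens))  # (batch,)
    positions = np.arange(max_len)[None, :]    # (1, max_len)
    valid = positions < lengths[:, None]       # (batch, max_len)
    return np.where(valid, 0, -1).astype(np.int32)


seg = cu_seqlens_to_padding_segment_ids(np.array([0, 3, 5]), max_len=4)
```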

classmethod from_random(batch_size: int, q_len: int, kv_len: int | None = None, sparsity: float = 0.5, seed: int = 0, q_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None = None, kv_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None = None, batch_axis_name: tuple[str] | str | None = ('dp', 'fsdp'), qheads_axis_name: tuple[str] | str | None = 'tp', kvheads_axis_name: tuple[str] | str | None = 'tp', sequence_axis_name: tuple[str] | str | None = 'sp') → MaskInfo

Create MaskInfo with random attention pattern.

Generates a random binary attention mask with specified sparsity level. Useful for testing, experimentation, and studying sparse attention patterns.

Parameters
  • batch_size – Batch size

  • q_len – Query sequence length

  • kv_len – Key-value sequence length. If None, uses q_len (self-attention)

  • sparsity – Fraction of attention positions to mask out (0.0 = full attention, 1.0 = fully masked). Default: 0.5 (50% masked)

  • seed – Random seed for reproducibility. Default: 0

  • q_positions – Optional query position indices (batch, qlen)

  • kv_positions – Optional key-value position indices (batch, kvlen)

Returns

MaskInfo with random attention pattern and optional positions

Example

>>> # Self-attention (kv_len defaults to q_len) with 70% of positions masked
>>> mask_info = MaskInfo.from_random(
...     batch_size=2,
...     q_len=128,
...     sparsity=0.7,
...     seed=42
... )
>>> mask_info.attention_mask.shape
(2, 1, 128, 128)
>>> # Cross-attention with kv_len different from q_len
>>> mask_info = MaskInfo.from_random(
...     batch_size=1,
...     q_len=64,
...     kv_len=128,
...     sparsity=0.5,
...     seed=0
... )
>>> mask_info.attention_mask.shape
(1, 1, 64, 128)
classmethod from_segments(q_segment_ids: Int[jaxlib._jax.Array, 'batch qlen'], kv_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None = None, q_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None = None, kv_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None = None, batch_axis_name: tuple[str] | str | None = ('dp', 'fsdp'), qheads_axis_name: tuple[str] | str | None = 'tp', kvheads_axis_name: tuple[str] | str | None = 'tp', sequence_axis_name: tuple[str] | str | None = 'sp', is_attn_mask: bool = False) → MaskInfo

Create MaskInfo from segment IDs.

Constructs a MaskInfo instance from query and key-value segment IDs, automatically generating the corresponding attention mask. Segment IDs group tokens that can attend to each other (same segment ID = can attend).
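The same-segment rule can be sketched as one broadcast comparison plus a padding check. pairwise_mask_from_segments is a hypothetical NumPy helper, distinct from the module's segment_ids_to_mask(). Treating -1 as never attending, even to itself, is an assumption drawn from the padding convention.

```python
import numpy as np


def pairwise_mask_from_segments(q_segment_ids, kv_segment_ids=None):
    """Hypothetical helper: (batch, 1, q_len, kv_len) mask from segment IDs.

    Query i may attend key j iff both are non-padding and share a segment ID.
    """
    q = np.asarray(q_segment_ids)
    kv = q if kv_segment_ids is None else np.asarray(kv_segment_ids)  # self-attention default
    same_segment = q[:, :, None] == kv[:, None, :]
    not_padding = (q[:, :, None] >= 0) & (kv[:, None, :] >= 0)
    return (same_segment & not_padding)[:, None]


mask = pairwise_mask_from_segments(np.array([[1, 1, 2, 2, -1]]))
```

The result is block-diagonal: one 2x2 block per segment, and an all-False row and column for the padding token.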

Parameters
  • q_segment_ids – Query segment IDs of shape (batch, qlen). Values should be:

    • Non-negative integers: segment membership (0, 1, 2, …)

    • -1: padding tokens

  • kv_segment_ids – Key-value segment IDs of shape (batch, kvlen). If None, uses q_segment_ids (self-attention case). Values follow same convention as q_segment_ids.

  • q_positions – Optional query position indices (batch, qlen) for positional embeddings

  • kv_positions – Optional key-value position indices (batch, kvlen) for positional embeddings

  • batch_axis_name – Axis name(s) for batch dimension in distributed sharding. Default: ("dp", "fsdp")

  • qheads_axis_name – Axis name(s) for query heads dimension in distributed sharding. Default: "tp"

  • kvheads_axis_name – Axis name(s) for key-value heads dimension in distributed sharding. Default: "tp"

  • sequence_axis_name – Axis name(s) for sequence dimension in distributed sharding. Default: "sp"

Returns

MaskInfo with segment IDs, computed attention mask, optional positions, and sharding configuration

Example

>>> q_seg = jnp.array([[1, 1, 2, 2, -1]])
>>> mask_info = MaskInfo.from_segments(q_seg)
>>> mask_info.attention_mask.shape
(1, 1, 5, 5)
static get_empty_sharding() → MaskSharding

Create an empty MaskSharding with all specs set to None.

Useful as a default or placeholder when no sharding is needed.

Returns

MaskSharding with all fields set to None

get_or_compute_attention_mask(dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.bool'>) → Array

Get attention mask, using cached data when available and deriving from segment IDs otherwise.

If a materialized attention mask is already stored, it is returned (with dtype conversion if needed). Otherwise, the mask is constructed from the available segment IDs.

Parameters

dtype – Desired output dtype (default: bool)

Returns

Attention mask array

Raises

ValueError – If both attention_mask and segment_ids are None

get_or_compute_positions() → tuple[jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None, jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None]

Get position arrays, computing them if not already available.

Generates position indices for queries and keys/values when not explicitly provided. Position arrays are useful for positional embeddings and rotary position embeddings (RoPE).

Returns

Tuple of (q_positions, kv_positions) where:

  • q_positions: (batch, qlen) position indices for queries, or None if dimensions unknown

  • kv_positions: (batch, kvlen) position indices for keys/values, or None if dimensions unknown

Return type

tuple[Array | None, Array | None]

Example

>>> mask_info = MaskInfo.from_segments(jnp.array([[1, 1, 2, 2]]))
>>> q_pos, kv_pos = mask_info.get_or_compute_positions()
>>> q_pos.shape
(1, 4)
>>> kv_pos[0]
Array([0, 1, 2, 3], dtype=int32)
get_or_compute_qkv_cu_seqlens(*, out_dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.int32'>, max_segments: int = 128) → tuple[jaxtyping.Int[jaxlib._jax.Array, 'max_segments_plus_1'], jaxtyping.Int[jaxlib._jax.Array, 'max_segments_plus_1']]

Derive (cu_seqlens_q, cu_seqlens_kv) from the available mask representation.

Prefers segment IDs if present (padding is unambiguous), otherwise falls back to the materialized attention mask.

Parameters
  • out_dtype – Output dtype for cumulative lengths. Default: int32.

  • max_segments – Maximum number of segments for JIT compatibility. The output will have shape (max_segments + 1,). Default: 128.

Returns

Tuple of (cu_seqlens_q, cu_seqlens_kv), each with shape (max_segments + 1,), in the FlashAttention-style packing format [0, len(seg0), len(seg0) + len(seg1), …].
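To illustrate the packing format, here is a NumPy sketch that treats each contiguous run of one non-negative segment ID as a segment. segments_to_cu_seqlens is hypothetical. It uses a Python loop, so unlike the method above it is not JIT-compatible. Filling unused slots by repeating the final total is an assumption about how the fixed-size output is padded.

```python
import numpy as np


def segments_to_cu_seqlens(segment_ids, max_segments=128):
    """Hypothetical helper: (batch, seq) segment IDs -> (max_segments + 1,) cu_seqlens.

    Each contiguous run of one non-negative ID within a row counts as one
    segment; padding (-1) is dropped. Unused trailing slots repeat the final
    cumulative total (assumed padding scheme).
    """
    lengths = []
    for row in np.asarray(segment_ids):
        ids = row[row >= 0]
        if ids.size == 0:
            continue
        starts = np.flatnonzero(np.diff(ids) != 0) + 1  # where the ID changes
        bounds = np.concatenate([[0], starts, [ids.size]])
        lengths.extend(np.diff(bounds).tolist())
    if len(lengths) > max_segments:
        raise ValueError("more segments than max_segments")
    cu = np.zeros(max_segments + 1, dtype=np.int32)
    totals = np.cumsum(lengths, dtype=np.int64)
    cu[1:len(lengths) + 1] = totals
    cu[len(lengths) + 1:] = totals[-1] if lengths else 0
    return cu


cu = segments_to_cu_seqlens(
    np.array([[0, 0, 0, 1, 1, -1], [0, 0, 0, 0, -1, -1]]), max_segments=4
)
```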

get_or_compute_segment_ids(per_head: bool = False) → tuple[jaxtyping.Int[jaxlib._jax.Array, '...'], jaxtyping.Int[jaxlib._jax.Array, '...']]

Get segment IDs, computing from attention mask if not available.

Parameters

per_head – Forwarded to mask_to_segment_ids when segment IDs are derived from a 4D attention mask. When True, returns per-head segment IDs with shape (batch, heads, seqlen).

Returns

Tuple of (q_segment_ids, kv_segment_ids)

Raises

ValueError – If both attention_mask and segment_ids are None

get_qkv_masks(dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.bool'>) → tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jaxtyping.Bool[jaxlib._jax.Array, 'batch nheads_or_1 qlen kvlen'] | jaxtyping.Int[jaxlib._jax.Array, 'batch nheads_or_1 qlen kvlen']]

Get separate query mask, key-value mask, and attention mask.

Parameters

dtype – Desired output dtype (default: bool)

Returns

Tuple of (q_mask, kv_mask, attention_mask) where:

  • q_mask: (batch, qlen) boolean mask for valid query positions

  • kv_mask: (batch, kvlen) boolean mask for valid key-value positions

  • attention_mask: (batch, 1, qlen, kvlen) 4D pairwise attention mask

Return type

tuple[Array, Array, Array]

Raises

ValueError – If both attention_mask and segment_ids are None

get_shardings(sequence_parallel: bool = False, *, mesh: Mesh) → MaskSharding

Generate sharding specifications for all mask components.

Creates PartitionSpec objects that define how to distribute the mask tensors across devices in a multi-device setup. Uses the axis names configured in the MaskInfo instance.

Parameters
  • sequence_parallel – Whether to shard along the sequence dimension. If True, sequences are split across devices. Default: False

  • mesh – JAX mesh defining the device grid and axis names

Returns

MaskSharding containing partition specs for all mask components

Raises

ValueError – If configured axis names are not present in the mesh, or if attention_mask is not 4D

Example

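>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.sharding import Mesh
>>> # One device-array dimension per axis name, covering all default axis names
>>> devices = np.array(jax.devices()).reshape(-1, 1, 1, 1)
>>> mesh = Mesh(devices, axis_names=('dp', 'fsdp', 'tp', 'sp'))
>>> mask_info = MaskInfo.from_segments(jnp.array([[1, 1, 2, 2]]))
>>> shardings = mask_info.get_shardings(mesh=mesh)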
property is_multi_sequence: bool

Check if the segment IDs represent multiple sequences (packed format).

This property determines whether the mask contains multiple distinct sequences by examining the segment IDs. In packed/multi-sequence format, different sequences are assigned different segment IDs (0, 1, 2, …), whereas single-sequence format uses only segment ID 0 for valid tokens and -1 for padding.

Returns

JAX boolean array (scalar) indicating if multiple sequences are present. Returns True if max segment ID > 0, False otherwise. Returns False if segment IDs are not available.

Note

This property returns a JAX array (not Python bool) to be JIT-compatible. You can use it directly in JAX conditionals or convert to Python bool with bool(mask_info.is_multi_sequence) outside of JIT contexts.

Examples

>>> # Single sequence: all valid tokens have segment ID 0
>>> q_seg = jnp.array([[0, 0, 0, -1, -1]])
>>> mask_info = MaskInfo.from_segments(q_seg)
>>> mask_info.is_multi_sequence
Array(False, dtype=bool)
>>> # Multiple sequences: tokens have different segment IDs
>>> q_seg = jnp.array([[0, 0, 1, 1, -1]])
>>> mask_info = MaskInfo.from_segments(q_seg)
>>> mask_info.is_multi_sequence
Array(True, dtype=bool)
>>> # JIT-compatible usage
>>> @jax.jit
... def process(q_seg):
...     info = MaskInfo.from_segments(q_seg)
...     return info.is_multi_sequence
is_self_attention() → bool

Check if this represents self-attention (same query and key-value sequences).

Returns

True if query and key-value sequences are identical, False otherwise

property kv_len: int | None

Get key-value sequence length.

Infers the key-value sequence dimension from either segment IDs or attention mask.

Returns

Key-value sequence length if available, None otherwise

property kv_lens: jax.jaxlib._jax.Array | None

Get per-segment lengths for keys/values.

For packed sequences with multiple segments (distinct segment IDs), returns the length of each segment. For simple padding masks (all valid tokens have the same segment ID), returns per-batch valid token counts.

Returns

Array with length of each segment/batch.

kv_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None = None
property kv_segment_ids: jax.jaxlib._jax.Array | None
kvheads_axis_name: tuple[str] | str | None = 'tp'
materialize_attention_mask(dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.bool'>) → MaskInfo
materialize_segment_ids(per_head: bool = False) → MaskInfo
property q_attention_mask
property q_len: int | None

Get query sequence length.

Infers the query sequence dimension from either segment IDs or attention mask.

Returns

Query sequence length if available, None otherwise

property q_lens: jax.jaxlib._jax.Array | None

Get per-segment lengths for queries.

For packed sequences with multiple segments (distinct segment IDs), returns the length of each segment. For simple padding masks (all valid tokens have the same segment ID), returns per-batch valid token counts.

Returns

Array with length of each segment/batch.

property q_position_ids: Array

Compute position IDs from the query segment IDs.

Returns per-segment positions that reset at each segment boundary. Padding positions (segment_id == -1) get position -1.

Example

segment_ids    = [[-1, -1, 0, 0, 0, 1, 1, -1]]
q_position_ids = [[-1, -1, 0, 1, 2, 0, 1, -1]]
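One way to compute these per-segment positions is to track the index of the most recent segment start with a running maximum. The following NumPy sketch reproduces the example above; segment_position_ids is a hypothetical helper, not the property's actual implementation.

```python
import numpy as np


def segment_position_ids(segment_ids):
    """Hypothetical helper: positions that restart at 0 at each segment boundary.

    A new segment starts wherever the ID differs from the previous token's ID;
    padding (-1) gets position -1.
    """
    seg = np.asarray(segment_ids)
    batch, length = seg.shape
    idx = np.broadcast_to(np.arange(length), seg.shape)
    prev = np.concatenate([np.full((batch, 1), -2), seg[:, :-1]], axis=1)
    starts = seg != prev
    # Index of the most recent segment start at or before each token
    last_start = np.maximum.accumulate(np.where(starts, idx, 0), axis=1)
    return np.where(seg == -1, -1, idx - last_start)


pos = segment_position_ids(np.array([[-1, -1, 0, 0, 0, 1, 1, -1]]))
```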

q_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None = None
property q_segment_ids: jax.jaxlib._jax.Array | None
qheads_axis_name: tuple[str] | str | None = 'tp'
replace(*, attention_mask=None, q_segment_ids=None, kv_segment_ids=None, cu_seqlens_q=None, cu_seqlens_kv=None, **kw) → MaskInfo

Create a new MaskInfo with specified fields replaced.

This is a convenience method for creating modified copies of MaskInfo instances, using dataclasses.replace(). Only specified fields are updated; others are preserved from the original instance.

Parameters
  • attention_mask – New attention mask array, or None to keep existing

  • q_segment_ids – New query segment IDs, or None to keep existing

  • kv_segment_ids – New key-value segment IDs, or None to keep existing

  • cu_seqlens_q – New cumulative query sequence lengths (batch+1,), or None to keep existing

  • cu_seqlens_kv – New cumulative key/value sequence lengths (batch+1,), or None to keep existing

  • **kw – Additional keyword arguments for other fields:

      - q_positions: New query positions
      - kv_positions: New key-value positions
      - batch_axis_name: New batch axis name(s)
      - qheads_axis_name: New query heads axis name(s)
      - kvheads_axis_name: New key-value heads axis name(s)
      - sequence_axis_name: New sequence axis name(s)

Returns

New MaskInfo instance with specified fields replaced

Example

>>> mask_info = MaskInfo.from_segments(jnp.array([[1, 1, 2, 2]]))
>>> new_mask_info = mask_info.replace(batch_axis_name="dp")
>>> new_mask_info.batch_axis_name
'dp'
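The "None means keep existing" convention can be illustrated with a small frozen dataclass. `MiniMaskInfo` is a hypothetical stand-in with only a few of MaskInfo's fields (public names rather than the real `_attention_mask`-style fields), so treat this as a sketch of the pattern, not the library's code.

```python
import dataclasses
from typing import Any, Optional


@dataclasses.dataclass(frozen=True)
class MiniMaskInfo:
    # Hypothetical stand-in with a subset of MaskInfo's fields.
    attention_mask: Optional[Any] = None
    q_segment_ids: Optional[Any] = None
    batch_axis_name: Any = ("dp", "fsdp")

    def replace(self, *, attention_mask=None, q_segment_ids=None, **kw):
        updates = dict(kw)  # other fields (e.g. axis names) pass straight through
        # For the array arguments, None means "keep the existing value".
        if attention_mask is not None:
            updates["attention_mask"] = attention_mask
        if q_segment_ids is not None:
            updates["q_segment_ids"] = q_segment_ids
        return dataclasses.replace(self, **updates)


m = MiniMaskInfo(q_segment_ids=[[1, 1, 2, 2]])
print(m.replace(batch_axis_name="dp").batch_axis_name)
```

One consequence of this convention is that passing None for an array argument cannot clear that field; it always keeps the current value.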
sequence_axis_name: tuple[str] | str | None = 'sp'#
property shape: tuple[int | None, int | None, int | None]#

Get (batch_size, q_len, kv_len) shape tuple.

Convenience property that returns all three dimensions at once.

Returns

Tuple of (batch_size, query_length, key_value_length)

sliding_window_baked_in: bool = False#
to_dtype(dtype: Union[str, type[Any], dtype, SupportsDType]) MaskInfo[source]#

Convert attention mask to specified dtype, returning a new MaskInfo.

Parameters

dtype – Target dtype (e.g., jnp.float32, jnp.bool_)

Returns

New MaskInfo with converted attention mask

token_type_ids_baked_in: bool = False#
tree_flatten()[source]#

Flatten MaskInfo for JAX pytree registration.

This method is required for JAX pytree support, enabling MaskInfo instances to be used seamlessly in JAX transformations (jit, vmap, grad, etc.). It separates the instance into two parts:

  • Children: array fields that are traced and transformed by JAX

  • Aux data: static metadata that remains constant across transformations

Returns

Tuple of (children, aux_data) where:

  • children: Tuple of (attention_mask, q_segment_ids, kv_segment_ids, cu_seqlens_q, cu_seqlens_kv, q_positions, kv_positions)

  • aux_data: Tuple of (batch_axis_name, qheads_axis_name, kvheads_axis_name, sequence_axis_name)

Notes

  • This method is automatically called by JAX during pytree operations

  • Users typically don’t need to call this directly

  • The counterpart tree_unflatten() reconstructs the MaskInfo from flattened form

classmethod tree_unflatten(aux_data, children)[source]#

Reconstruct MaskInfo from flattened pytree representation.

This method is the inverse of tree_flatten() and is required for JAX pytree support. It reconstructs a MaskInfo instance from its flattened components after JAX transformations have been applied.

Parameters
  • aux_data – Static metadata tuple containing (batch_axis_name, qheads_axis_name, kvheads_axis_name, sequence_axis_name)

  • children – Traced array tuple containing (attention_mask, q_segment_ids, kv_segment_ids, cu_seqlens_q, cu_seqlens_kv, q_positions, kv_positions)

Returns

Reconstructed MaskInfo instance with the provided arrays and metadata

Notes

  • This method is automatically called by JAX during pytree operations

  • Users typically don’t need to call this directly

  • The method signature must match the output format of tree_flatten()
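The flatten/unflatten split can be seen in a minimal JAX pytree class. `TinyMask` is a hypothetical two-field analogue (one array child, one static axis-name field), registered with `jax.tree_util.register_pytree_node_class`; it shows the pattern described above, not MaskInfo's actual implementation.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class TinyMask:
    """Hypothetical two-field analogue of MaskInfo's children/aux_data split."""

    def __init__(self, q_segment_ids, batch_axis_name=("dp", "fsdp")):
        self.q_segment_ids = q_segment_ids      # array: traced child
        self.batch_axis_name = batch_axis_name  # hashable tuple/str: static aux data

    def tree_flatten(self):
        children = (self.q_segment_ids,)
        aux_data = (self.batch_axis_name,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (q_segment_ids,) = children
        (batch_axis_name,) = aux_data
        return cls(q_segment_ids, batch_axis_name)


@jax.jit
def valid_count(mask: TinyMask) -> jax.Array:
    # Inside jit, q_segment_ids is a tracer; batch_axis_name stays a Python tuple.
    return jnp.sum(mask.q_segment_ids >= 0)


print(valid_count(TinyMask(jnp.array([[0, 0, 1, -1]]))))
```

Because aux data participates in JAX's cache key, it must be hashable; changing an axis name triggers a retrace, whereas changing array values does not.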

visualize(block_size: int | tuple[int, int] = 32, batch: int = 0, head: int = 0, fit_in_screen: bool = True, max_rows: int = 32, max_cols: int = 64, charset: Literal['unicode', 'ascii'] = 'unicode', show_segments: bool = True, return_str: bool = False) str | None[source]#

Pretty-print the attention mask as block-aggregated ASCII/Unicode visualization.

Optionally shows aggregated query/key-value segment IDs for each block row/column. Useful for debugging and understanding attention patterns.

Parameters
  • block_size – Size of aggregation blocks. Can be:

      - int: Square blocks of size (block_size, block_size)
      - tuple[int, int]: Rectangular blocks (q_block_size, kv_block_size)

  • batch – Batch index to visualize. Default: 0

  • head – Head index to visualize. Default: 0

  • fit_in_screen – If True, downsample to fit within max_rows/max_cols. Default: True

  • max_rows – Maximum number of block rows to display when fit_in_screen=True. Default: 32

  • max_cols – Maximum number of block columns to display when fit_in_screen=True. Default: 64

  • charset – Character set for visualization. Default: “unicode”

      - “unicode”: Uses box-drawing characters (░░ for partial, ██ for full)
      - “ascii”: Uses ASCII characters (.. for partial, ## for full)

  • show_segments – If True, display segment IDs alongside the mask. Default: True

  • return_str – If True, return the visualization as a string instead of printing. Default: False

Returns

If return_str=True, returns the visualization string. Otherwise, prints and returns None

Block encoding:
  • Empty (no attention): “  ” (two spaces)

  • Partial (some attention): “░░” (unicode) or “..” (ascii)

  • Full (all attention): “██” (unicode) or “##” (ascii)

Segment ID display:
  • If all tokens in a block share the same segment ID: shows that ID

  • Mixed segments: shown as “??” in header, “MIX” on left

  • Padding: shown as -1 or “PAD”

Notes

  • Not JIT-friendly; runs on host (uses numpy and prints)

  • Segment IDs are taken from self.q_segment_ids/self.kv_segment_ids if present, otherwise computed from the mask (may be per-head if H > 1)

Example

>>> mask_info = MaskInfo.from_segments(jnp.ones((2, 128), dtype=jnp.int32))
>>> mask_info.visualize(block_size=16, batch=0)
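The block encoding above (empty / partial / full) can be sketched in a few lines of NumPy. `block_summary` is a hypothetical helper that only reproduces the per-block classification for a single 2D (q, kv) mask; it omits the segment headers, downsampling, and batch/head selection that visualize() provides.

```python
import numpy as np


def block_summary(mask: np.ndarray, block_q: int, block_kv: int, charset: str = "ascii") -> str:
    """Hypothetical helper: one 2-character cell per (block_q, block_kv) block."""
    if charset == "ascii":
        empty, partial, full = "  ", "..", "##"
    else:
        empty, partial, full = "  ", "░░", "██"
    mask = np.asarray(mask, dtype=bool)
    q_len, kv_len = mask.shape
    rows = []
    for q0 in range(0, q_len, block_q):
        cells = []
        for k0 in range(0, kv_len, block_kv):
            block = mask[q0:q0 + block_q, k0:k0 + block_kv]
            if block.all():
                cells.append(full)      # every pair in the block attends
            elif block.any():
                cells.append(partial)   # some pairs attend
            else:
                cells.append(empty)     # no pair attends
        rows.append("".join(cells))
    return "\n".join(rows)


causal = np.tril(np.ones((4, 4), dtype=bool))
print(block_summary(causal, 2, 2))
```

For a causal mask the diagonal blocks come out partial and the blocks below the diagonal come out full, which is the typical triangular picture visualize() prints.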
class ejkernel.types.mask.MaskSharding(attention_mask: PartitionSpec | None, q_segment_ids: PartitionSpec | None, kv_segment_ids: PartitionSpec | None, cu_seqlens_q: PartitionSpec | None, cu_seqlens_kv: PartitionSpec | None, q_positions: PartitionSpec | None, kv_positions: PartitionSpec | None)[source]#

Bases: NamedTuple

Container for sharding specifications of attention mask components.

Used to specify how different parts of the mask should be partitioned across devices in distributed training scenarios.

attention_mask#

Sharding spec for the 4D attention mask (batch, heads, q, kv)

Type

jax.sharding.PartitionSpec | None

q_segment_ids#

Sharding spec for query segment IDs (batch, qlen)

Type

jax.sharding.PartitionSpec | None

kv_segment_ids#

Sharding spec for key-value segment IDs (batch, kvlen)

Type

jax.sharding.PartitionSpec | None

cu_seqlens_q#

Sharding spec for cumulative query sequence lengths (batch+1,)

Type

jax.sharding.PartitionSpec | None

cu_seqlens_kv#

Sharding spec for cumulative key/value sequence lengths (batch+1,)

Type

jax.sharding.PartitionSpec | None

q_positions#

Sharding spec for query positions (batch, qlen)

Type

jax.sharding.PartitionSpec | None

kv_positions#

Sharding spec for key-value positions (batch, kvlen)

Type

jax.sharding.PartitionSpec | None

attention_mask: jax.sharding.PartitionSpec | None#

Alias for field number 0

cu_seqlens_kv: jax.sharding.PartitionSpec | None#

Alias for field number 4

cu_seqlens_q: jax.sharding.PartitionSpec | None#

Alias for field number 3

kv_positions: jax.sharding.PartitionSpec | None#

Alias for field number 6

kv_segment_ids: jax.sharding.PartitionSpec | None#

Alias for field number 2

q_positions: jax.sharding.PartitionSpec | None#

Alias for field number 5

q_segment_ids: jax.sharding.PartitionSpec | None#

Alias for field number 1
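Because MaskSharding is a NamedTuple, each spec is reachable both by name and by the field index listed above. The sketch below uses a hypothetical mirror, `MaskShardingSketch`, with the same field order, and fills it with PartitionSpecs built from MaskInfo's default axis names. Which array dimension is sharded on which mesh axis here is an illustrative assumption, not the layout the library computes.

```python
from typing import NamedTuple, Optional

from jax.sharding import PartitionSpec as P


class MaskShardingSketch(NamedTuple):
    # Hypothetical mirror of MaskSharding's field order (indices 0..6).
    attention_mask: Optional[P]
    q_segment_ids: Optional[P]
    kv_segment_ids: Optional[P]
    cu_seqlens_q: Optional[P]
    cu_seqlens_kv: Optional[P]
    q_positions: Optional[P]
    kv_positions: Optional[P]


batch = ("dp", "fsdp")  # MaskInfo's default batch_axis_name
seq = "sp"              # MaskInfo's default sequence_axis_name

# Assumed layout: batch dim on (dp, fsdp), query sequence on sp,
# heads dim of the mask left unsharded because it may be size 1.
spec = MaskShardingSketch(
    attention_mask=P(batch, None, seq, None),
    q_segment_ids=P(batch, seq),
    kv_segment_ids=P(batch, None),
    cu_seqlens_q=None,
    cu_seqlens_kv=None,
    q_positions=P(batch, seq),
    kv_positions=P(batch, None),
)
print(spec[1], spec.q_segment_ids)
```

NamedTuple's `_replace` gives a cheap way to override a single component while keeping the rest.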

ejkernel.types.mask.attention_mask_to_qkv_cu_seqlens(attention_mask: ~jax.jaxlib._jax.Array, *, reduce_heads: ~typing.Literal['any', 'all', 'first'] = 'any', out_dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.int32'>) tuple[jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'], jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one']][source]#

Derive Q/KV cumulative sequence lengths from an attention mask.

Supported input shapes:
  • (batch, seq_len): padding mask (self-attention)

  • (batch, q_len, kv_len): pairwise mask

  • (batch, heads, q_len, kv_len): pairwise mask

Notes

For pairwise masks, a token is considered “present” if it participates in at least one unmasked attention edge (after optional head reduction). This matches padding-style outer-product masks but may be ambiguous for arbitrary sparse patterns.
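A NumPy sketch of the "present if it has at least one unmasked edge" rule follows. It assumes the batch_plus_one output uses the common varlen layout [0, n_0, n_0 + n_1, …] (as in FlashAttention-style kernels); that layout and the helper names are assumptions, not confirmed by this page.

```python
import numpy as np


def _cumulative(present: np.ndarray) -> np.ndarray:
    # Assumed layout: [0, n_0, n_0 + n_1, ...] from per-row token counts.
    lengths = present.sum(axis=1).astype(np.int32)
    return np.concatenate([np.zeros(1, np.int32), np.cumsum(lengths, dtype=np.int32)])


def mask_to_cu_seqlens_sketch(attention_mask: np.ndarray, reduce_heads: str = "any"):
    """Hypothetical helper mirroring the documented input shapes."""
    m = np.asarray(attention_mask).astype(bool)
    if m.ndim == 2:  # (batch, seq_len) padding mask, self-attention
        return _cumulative(m), _cumulative(m)
    if m.ndim == 4:  # (batch, heads, q, kv): collapse the head axis first
        if reduce_heads == "any":
            m = m.any(axis=1)
        elif reduce_heads == "all":
            m = m.all(axis=1)
        elif reduce_heads == "first":
            m = m[:, 0]
        else:
            raise ValueError(f"unknown reduce_heads: {reduce_heads!r}")
    elif m.ndim != 3:
        raise ValueError(f"expected a 2D, 3D or 4D mask, got shape {m.shape}")
    q_present = m.any(axis=2)   # query row has at least one unmasked edge
    kv_present = m.any(axis=1)  # key column has at least one unmasked edge
    return _cumulative(q_present), _cumulative(kv_present)


padding = np.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=bool)
print(mask_to_cu_seqlens_sketch(padding))
```

Only counts survive this reduction, which is why the Notes warn that arbitrary sparse patterns may not round-trip.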

ejkernel.types.mask.cu_seqlens_to_mask(cu_seqlens: ~jaxtyping.Int[jaxlib._jax.Array, 'batch*2'], max_len: int, dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.bool'>) Array[source]#

Convert start/end position pairs into a 2D mask.

Parameters
  • cu_seqlens – (batch * 2,) array with interleaved [start_0, end_0, start_1, end_1, …].

  • max_len – Output sequence length.

  • dtype – Output dtype for the mask.

Returns

(batch, max_len) mask with True/1 for valid tokens (positions start to end-1) and False/0 for positions outside that range.

Example

>>> cu_seqlens = jnp.array([41, 169])  # valid at positions 41-168
>>> mask = cu_seqlens_to_mask(cu_seqlens, max_len=512)
>>> mask.shape
(1, 512)
>>> [bool(mask[0, i]) for i in (40, 41, 168, 169)]
[False, True, True, False]
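The interleaved start/end convention reduces to a broadcasted range comparison. Here is a host-side NumPy sketch of the documented behavior (the helper name is hypothetical), treating each pair as a half-open interval [start, end).

```python
import numpy as np


def start_end_pairs_to_mask(cu_seqlens, max_len: int) -> np.ndarray:
    """Hypothetical helper: interleaved [start_0, end_0, ...] -> (batch, max_len) bool mask."""
    pairs = np.asarray(cu_seqlens).reshape(-1, 2)  # (batch, 2)
    starts, ends = pairs[:, :1], pairs[:, 1:]      # keep a trailing axis for broadcasting
    positions = np.arange(max_len)[None, :]        # (1, max_len)
    return (positions >= starts) & (positions < ends)


mask = start_end_pairs_to_mask([41, 169], max_len=512)
print(mask.shape, int(mask.sum()))
```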
ejkernel.types.mask.get_debug_mode() bool[source]#

Check if debug mode is currently enabled.

Returns

True if debug mode is enabled, False otherwise.

Examples

>>> from ejkernel.types.mask import get_debug_mode, set_debug_mode
>>>
>>> get_debug_mode()
False
>>>
>>> set_debug_mode(True)
[MaskInfo Debug] Debug mode enabled
>>>
>>> get_debug_mode()
True
ejkernel.types.mask.mask_to_segment_ids(mask: Array, per_head: bool = False) tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array][source]#

Convert attention mask to segment IDs (JIT-friendly).

Analyzes the attention mask structure to extract query and key-value segment IDs. Queries with identical attention patterns are grouped into the same segment, and similarly for keys/values. This conversion is useful for optimized attention implementations that can leverage segment structure.

Input shapes:
  • (Q, K): Single 2D mask

  • (B, Q, K): Batched 2D masks

  • (B, H, Q, K): Batched multi-head masks

Parameters
  • mask – Boolean or integer attention mask array

  • per_head – If True and mask is 4D, compute segment IDs separately per head. If False, merge across heads (default behavior). Default: False

Returns

Tuple of (q_segment_ids, kv_segment_ids) with shapes:

  • If (Q, K): (Q,), (K,)

  • If (B, Q, K): (B, Q), (B, K)

  • If (B, H, Q, K) and per_head=False: (B, Q), (B, K)

  • If (B, H, Q, K) and per_head=True: (B, H, Q), (B, H, K)

Notes

  • Padded rows/cols (all-zero) receive segment ID -1

  • Queries/keys with identical attention patterns share the same segment ID

  • This function is JIT-compatible for use in compiled JAX programs

Raises

ValueError – If mask shape is not 2D, 3D, or 4D

Example

>>> mask = jnp.array([[[1, 1, 0], [1, 1, 0], [0, 0, 1]]])
>>> q_ids, kv_ids = mask_to_segment_ids(mask)
>>> q_ids.shape, kv_ids.shape
((1, 3), (1, 3))
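The grouping rule (identical attention patterns share an ID, all-zero rows/columns get -1) can be sketched on the host with `np.unique(..., axis=0, return_inverse=True)`. This handles only a single (Q, K) mask, is not JIT-friendly, and may number segments differently from the library; the helper name is hypothetical.

```python
import numpy as np


def _group_rows(rows: np.ndarray) -> np.ndarray:
    ids = np.full(rows.shape[0], -1, dtype=np.int32)  # all-zero rows stay -1 (padding)
    valid = rows.any(axis=1)
    if valid.any():
        # Rows with identical patterns map to the same index in the unique table.
        _, inverse = np.unique(rows[valid].astype(np.uint8), axis=0, return_inverse=True)
        ids[valid] = inverse.reshape(-1)
    return ids


def mask_to_segment_ids_sketch(mask: np.ndarray):
    """Hypothetical helper: (Q, K) mask -> (q_ids of shape (Q,), kv_ids of shape (K,))."""
    m = np.asarray(mask).astype(bool)
    return _group_rows(m), _group_rows(m.T)  # queries are rows, keys are columns


q_ids, kv_ids = mask_to_segment_ids_sketch(np.array([[1, 1, 0], [1, 1, 0], [0, 0, 1]]))
print(q_ids, kv_ids)
```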
ejkernel.types.mask.qkv_cu_seqlens_to_attention_mask(cu_seqlens_q: ~jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'], *, max_q_len: int, cu_seqlens_kv: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, max_kv_len: int | None = None, dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.bool'>) Array[source]#

Convert Q/KV cumulative sequence lengths into a broadcastable 4D outer-product attention mask.

Returns

(batch, 1, max_q_len, max_kv_len) mask.

ejkernel.types.mask.qkv_cu_seqlens_to_qkv_masks(cu_seqlens_q: ~jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'], *, max_q_len: int, cu_seqlens_kv: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, max_kv_len: int | None = None, dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.bool'>) tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array][source]#

Convert Q/KV cumulative sequence lengths back into 2D padding masks.
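Both qkv_cu_seqlens_to_qkv_masks and qkv_cu_seqlens_to_attention_mask can be pictured with one NumPy sketch. It assumes the batch_plus_one layout [0, n_0, n_0 + n_1, …] and that each row's valid tokens occupy its first n_i positions; both are assumptions not stated on this page, and the helper name is hypothetical.

```python
import numpy as np


def cu_seqlens_to_masks_sketch(cu_seqlens_q, max_q_len, cu_seqlens_kv=None, max_kv_len=None):
    """Hypothetical helper returning (q_mask, kv_mask, attention_mask)."""
    if cu_seqlens_kv is None:  # self-attention: reuse the query layout
        cu_seqlens_kv = cu_seqlens_q
    if max_kv_len is None:
        max_kv_len = max_q_len
    q_lens = np.diff(np.asarray(cu_seqlens_q))    # (batch,) per-row lengths
    kv_lens = np.diff(np.asarray(cu_seqlens_kv))  # (batch,)
    # Assumed left-aligned tokens: position t is valid iff t < length.
    q_mask = np.arange(max_q_len)[None, :] < q_lens[:, None]     # (batch, max_q_len)
    kv_mask = np.arange(max_kv_len)[None, :] < kv_lens[:, None]  # (batch, max_kv_len)
    # Outer product with a size-1 head axis so it broadcasts across heads.
    attention_mask = q_mask[:, None, :, None] & kv_mask[:, None, None, :]
    return q_mask, kv_mask, attention_mask


q_mask, kv_mask, attn = cu_seqlens_to_masks_sketch(np.array([0, 3, 5]), max_q_len=4)
print(q_mask.astype(int), attn.shape)
```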

ejkernel.types.mask.qkv_masks_to_cu_seqlens(q_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch q_len'] | jaxtyping.Int[jaxlib._jax.Array, 'batch q_len'], kv_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch kv_len'] | jaxtyping.Int[jaxlib._jax.Array, 'batch kv_len'] | None = None, *, out_dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.int32'>) tuple[jaxtyping.Int[jaxlib._jax.Array, 'batch*2'], jaxtyping.Int[jaxlib._jax.Array, 'batch*2']][source]#

Convert per-token Q/KV masks into start/end position pairs.

For each batch element, finds the first and last valid token positions. Returns arrays with interleaved [start_0, end_0, start_1, end_1, …] format.

Parameters
  • q_mask – (batch, q_len) boolean/int mask where non-zero/True means “valid token”.

  • kv_mask – (batch, kv_len) boolean/int mask (defaults to q_mask for self-attention).

  • out_dtype – Output dtype for positions (typically int32).

Returns

(cu_seqlens_q, cu_seqlens_kv), each of shape (batch * 2,). Format: [start_0, end_0, start_1, end_1, …] where valid tokens for batch i are at positions cu_seqlens[2*i] to cu_seqlens[2*i+1]-1.

Example

>>> q_mask = jnp.array([[False, True, True, True, False]])  # valid at 1-3
>>> cu_q, cu_kv = qkv_masks_to_cu_seqlens(q_mask)
>>> cu_q  # [start=1, end=4]
Array([1, 4], dtype=int32)
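Finding the first and last valid positions can be done with two argmax calls, one on the reversed mask. This NumPy sketch follows the documented [start, end) interleaving; returning (0, 0) for a row with no valid tokens is an assumption, as this page does not specify that case. The helper name is hypothetical.

```python
import numpy as np


def _start_end(mask) -> np.ndarray:
    m = np.asarray(mask).astype(bool)
    n = m.shape[1]
    has_any = m.any(axis=1)
    first = np.argmax(m, axis=1)                # index of the first True per row
    last = n - 1 - np.argmax(m[:, ::-1], axis=1)  # index of the last True per row
    start = np.where(has_any, first, 0)         # assumed (0, 0) for empty rows
    end = np.where(has_any, last + 1, 0)        # exclusive end
    return np.stack([start, end], axis=1).reshape(-1).astype(np.int32)


def masks_to_start_end_sketch(q_mask, kv_mask=None):
    """Hypothetical helper: (batch, len) masks -> interleaved [start_0, end_0, ...]."""
    if kv_mask is None:  # self-attention default
        kv_mask = q_mask
    return _start_end(q_mask), _start_end(kv_mask)


cu_q, cu_kv = masks_to_start_end_sketch(np.array([[False, True, True, True, False]]))
print(cu_q)
```

Because only the first and last valid positions are recorded, a hole in the middle of a row is not represented in this format.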
ejkernel.types.mask.segment_ids_to_mask(segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len'] | tuple[jaxtyping.Int[jaxlib._jax.Array, 'batch q_len'], jaxtyping.Int[jaxlib._jax.Array, 'batch kv_len']], dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.bool'>, return_separate_masks: bool = False) jax.jaxlib._jax.Array | tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jax.jaxlib._jax.Array][source]#

Converts segment IDs to an attention mask.

This function creates a 4D attention mask (and, optionally, 2D per-token masks) from segment IDs, where tokens in the same segment can attend to each other. It handles the padding conventions:

  • Segment IDs: -1 indicates padding

  • Attention mask: 0 indicates padding (masked out), 1 indicates valid attention

The function works with both query and key-value segment IDs:

  • If only query segment IDs are provided: creates a square mask where tokens with the same segment ID can attend to each other

  • If both query and key-value segment IDs are provided: creates a rectangular mask allowing cross-attention between matching segments

Parameters
  • segment_ids – Segment IDs array. Can be:

      - 2D: (batch_size, seq_len) for query segment IDs only
      - Tuple of two 2D arrays: (q_segment_ids, kv_segment_ids)

  • dtype – The output dtype for the mask. Common choices:

      - jnp.bool_: Boolean mask (True=attend, False=masked)
      - jnp.float32: Float mask (1.0=attend, 0.0=masked)

  • return_separate_masks – If True, returns (q_mask, kv_mask, attention_mask) tuple where q_mask and kv_mask are 2D masks indicating valid (non-padding) tokens. Default is False, which returns only the attention_mask.

Returns

If return_separate_masks=False (default): an attention mask array with shape:
  • (batch_size, 1, seq_len, seq_len) if segment_ids is 2D

  • (batch_size, 1, q_len, kv_len) if segment_ids is a tuple

The mask is always 4D with shape (batch, 1, q, kv); the second dimension is 1 so that it broadcasts across attention heads.

If return_separate_masks=True: a tuple of (q_mask, kv_mask, attention_mask) where:
  • q_mask: (batch_size, q_len), True for valid query tokens

  • kv_mask: (batch_size, kv_len), True for valid key-value tokens

  • attention_mask: (batch_size, 1, q_len, kv_len), the 4D pairwise attention mask

Examples

>>> import jax.numpy as jnp
>>> from ejkernel.types.mask import segment_ids_to_mask
>>> segment_ids = jnp.array([
...     [1, 1, 2, 2, -1],
...     [1, 1, 1, -1, -1],
... ])
>>> mask = segment_ids_to_mask(segment_ids)
>>> mask.shape
(2, 1, 5, 5)
>>> q_mask, kv_mask, attn_mask = segment_ids_to_mask(segment_ids, return_separate_masks=True)
>>> q_mask.shape, kv_mask.shape, attn_mask.shape
((2, 5), (2, 5), (2, 1, 5, 5))
>>> q_mask[0]
Array([ True,  True,  True,  True, False], dtype=bool)
>>> kv_mask[0]
Array([ True,  True,  True,  True, False], dtype=bool)
>>> q_segment_ids = jnp.array([[1, 2, 3]])
>>> kv_segment_ids = jnp.array([[1, 1, 2, 2, 3]])
>>> mask = segment_ids_to_mask((q_segment_ids, kv_segment_ids))
>>> mask.shape
(1, 1, 3, 5)
>>> mask = segment_ids_to_mask(segment_ids, dtype=jnp.float32)

Notes

  • Segment IDs of -1 are treated as padding

  • Non-negative segment IDs (0, 1, 2, …) identify segments; each distinct ID is a separate segment

  • Tokens can only attend within their own segment

  • The output mask is suitable for use with most attention implementations

  • For additive attention bias, convert: bias = (1.0 - mask) * large_negative_value
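The core of the conversion is a broadcasted equality test plus a padding filter, and the bias note above is one more line. This host-side NumPy sketch follows the documented semantics rather than the library's JAX code; the helper name is hypothetical, and -1e9 stands in for "large_negative_value".

```python
import numpy as np


def segment_ids_to_mask_sketch(q_segment_ids, kv_segment_ids=None):
    """Hypothetical helper returning (q_mask, kv_mask, attention_mask)."""
    q = np.asarray(q_segment_ids)
    kv = q if kv_segment_ids is None else np.asarray(kv_segment_ids)
    q_mask = q >= 0    # (batch, q_len); -1 marks padding
    kv_mask = kv >= 0  # (batch, kv_len)
    same_segment = q[:, :, None] == kv[:, None, :]  # (batch, q_len, kv_len)
    # Without the padding filter, -1 == -1 would let padding attend to padding.
    attn = same_segment & q_mask[:, :, None] & kv_mask[:, None, :]
    return q_mask, kv_mask, attn[:, None]  # add a size-1 head axis


_, _, attn_example = segment_ids_to_mask_sketch(np.array([[1, 1, 2, 2, -1]]))
# Additive bias: 0 where attention is allowed, -1e9 where it is masked.
bias = (1.0 - attn_example.astype(np.float32)) * -1e9
print(attn_example[0, 0].astype(int))
```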

ejkernel.types.mask.segment_ids_to_qkv_masks(q_segment_ids: ~jaxtyping.Int[jaxlib._jax.Array, 'batch q_len'], kv_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch kv_len'] | None = None, dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.bool'>) tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jax.jaxlib._jax.Array][source]#

Converts query and key-value segment IDs to separate Q mask, KV mask, and attention mask.

This is a convenience function that always returns the three masks separately, useful when you need individual control over query and key-value masking.

Parameters
  • q_segment_ids – Query segment IDs of shape (batch_size, q_len). Values of -1 indicate padding.

  • kv_segment_ids – Key-value segment IDs of shape (batch_size, kv_len). If None, uses q_segment_ids (self-attention case). Values of -1 indicate padding.

  • dtype – The output dtype for masks. Common choices:

      - jnp.bool_: Boolean mask (True=attend, False=masked)
      - jnp.float32: Float mask (1.0=attend, 0.0=masked)

Returns

  • q_mask: (batch_size, q_len) - Query mask indicating valid (non-padding) query tokens

  • kv_mask: (batch_size, kv_len) - Key-value mask indicating valid (non-padding) KV tokens

  • attention_mask: (batch_size, 1, q_len, kv_len) - 4D pairwise attention mask where tokens in matching segments can attend to each other

Return type

Tuple of (q_mask, kv_mask, attention_mask)

Examples

>>> import jax.numpy as jnp
>>> from ejkernel.types.mask import segment_ids_to_qkv_masks
>>> segment_ids = jnp.array([[1, 1, 2, -1]])
>>> q_mask, kv_mask, attn_mask = segment_ids_to_qkv_masks(segment_ids)
>>> q_mask.shape, kv_mask.shape, attn_mask.shape
((1, 4), (1, 4), (1, 1, 4, 4))
>>> q_mask[0]
Array([ True,  True,  True, False], dtype=bool)
>>> attn_mask[0, 0, 0, 2]
Array(False, dtype=bool)
>>> q_seg = jnp.array([[1, 2]])
>>> kv_seg = jnp.array([[1, 1, 2, 2, -1]])
>>> q_mask, kv_mask, attn_mask = segment_ids_to_qkv_masks(q_seg, kv_seg)
>>> q_mask.shape, kv_mask.shape, attn_mask.shape
((1, 2), (1, 5), (1, 1, 2, 5))
>>> kv_mask[0]
Array([ True,  True,  True,  True, False], dtype=bool)
>>> attn_mask[0, 0, 0, :2]
Array([ True,  True], dtype=bool)

Notes

  • This function always returns three separate masks for maximum flexibility

  • Segment IDs of -1 are treated as padding

  • Non-negative segment IDs (0, 1, 2, …) identify segments; each distinct ID is a separate segment

  • Tokens can only attend within their own segment

  • For self-attention, q_mask and kv_mask will be identical

ejkernel.types.mask.set_debug_mode(enabled: bool) None[source]#

Enable or disable debug tracing for MaskInfo operations.

When debug mode is enabled, all MaskInfo method calls will be logged to stdout, helping you understand the execution flow and identify performance bottlenecks.

Parameters

enabled – If True, enables debug tracing. If False, disables it.

Examples

>>> from ejkernel.types.mask import set_debug_mode, MaskInfo
>>> import jax.numpy as jnp
>>>
>>> # Enable debug mode
>>> set_debug_mode(True)
[MaskInfo Debug] Debug mode enabled
>>>
>>> # Now operations will print debug traces
>>> mask_info = MaskInfo.from_segments(jnp.array([[1, 1, 2, 2]]))
[MaskInfo Debug] Calling type.from_segments()
>>>
>>> # Disable debug mode
>>> set_debug_mode(False)