ejkernel.types.mask
Attention Mask Management for JAX/Flax Models.
This module provides comprehensive tools for creating, manipulating, and converting attention masks in transformer models. It supports various attention patterns and provides efficient conversions between different mask representations.
- Key Components:
MaskInfo: Main dataclass for managing attention masks and segment IDs
Conversion functions: Convert between masks and segment IDs
Attention patterns: Causal, sliding window, chunked, token-type-based
Distributed support: Sharding specifications for multi-device training
Visualization: Debug and understand attention patterns
- Common Usage:
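>>> # Create from segment IDs
>>> mask_info = MaskInfo.from_segments(segment_ids)
>>> # Create from an existing attention mask
>>> mask_info = MaskInfo.from_attention_mask(attention_mask)
>>> # Apply causal masking
>>> causal_mask_info = mask_info.apply_causal()
>>> # Get the additive attention bias
>>> bias = mask_info.bias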
- Mask Representations:
Attention Mask: 4D boolean/int array of shape (batch, heads, q_len, kv_len)
  - True/1 = valid attention
  - False/0 = masked
Segment IDs: 2D int32 arrays of shape (batch, seq_len)
  - Non-negative = segment membership
  - -1 = padding tokens
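The two representations are related by a simple pairwise rule. The NumPy sketch below shows it; the same operations exist in jax.numpy. It assumes only the conventions stated above: equal non-negative IDs may attend to each other, and -1 marks padding. The helper name segment_ids_to_pairwise_mask is hypothetical and is not ejkernel's segment_ids_to_mask().

```python
import numpy as np


def segment_ids_to_pairwise_mask(q_segment_ids: np.ndarray, kv_segment_ids: np.ndarray) -> np.ndarray:
    """Hypothetical helper: build a (batch, 1, q_len, kv_len) boolean mask from 2D segment IDs."""
    # Tokens may attend to each other only when they share a segment ID.
    same_segment = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :]
    # -1 marks padding on either side; without this check two padding tokens would "match".
    q_valid = (q_segment_ids >= 0)[:, :, None]
    kv_valid = (kv_segment_ids >= 0)[:, None, :]
    mask = same_segment & q_valid & kv_valid
    # Insert a broadcastable head axis of size 1.
    return mask[:, None, :, :]


segment_ids = np.array([[1, 1, 2, 2, -1]], dtype=np.int32)
mask = segment_ids_to_pairwise_mask(segment_ids, segment_ids)
print(mask.shape)  # (1, 1, 5, 5)
```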
- Debug Mode:
Enable debug tracing to see which functions are being called:
>>> import os
>>> os.environ['EJKERNEL_MASK_DEBUG'] = '1'  # Enable debug mode
>>> # Now all MaskInfo operations will print debug traces
>>> mask_info = MaskInfo.from_segments(segment_ids)
[MaskInfo Debug] Calling type.from_segments()
Set EJKERNEL_MASK_DEBUG to '0' or remove it to disable debug output.
See also
mask_to_segment_ids(): Convert masks to segment IDs
segment_ids_to_mask(): Convert segment IDs to masks
MaskInfo: Main class for mask management
- class ejkernel.types.mask.MaskInfo(_attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch nheads_or_1 q k'] | jaxtyping.Int[jaxlib._jax.Array, 'batch nheads_or_1 q k'] | None = None, _q_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch q'] | None = None, _kv_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch k'] | None = None, _cu_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, _cu_seqlens_kv: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, q_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None = None, kv_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None = None, causal_mask_baked_in: bool = False, sliding_window_baked_in: bool = False, chunked_mask_baked_in: bool = False, token_type_ids_baked_in: bool = False, batch_axis_name: tuple[str] | str | None = ('dp', 'fsdp'), qheads_axis_name: tuple[str] | str | None = 'tp', kvheads_axis_name: tuple[str] | str | None = 'tp', sequence_axis_name: tuple[str] | str | None = 'sp')
Bases: object
Container for attention mask information with utilities for conversion and manipulation.
This dataclass holds both attention masks and their corresponding segment IDs, along with optional position indices for queries and keys/values. It provides convenient methods for conversion between representations and extracting derived information.
- attention_mask
The 2D/3D/4D boolean or integer attention mask
- q_segment_ids
Query segment IDs (batch, qlen) where -1 indicates padding
- kv_segment_ids
Key-value segment IDs (batch, kvlen) where -1 indicates padding
- cu_seqlens_q
Cumulative sequence lengths for queries (batch+1,)
- cu_seqlens_kv
Cumulative sequence lengths for keys/values (batch+1,)
- q_positions
Query position indices (batch, qlen) for positional embeddings
- Type
jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None
- kv_positions
Key-value position indices (batch, kvlen) for positional embeddings
- Type
jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None
- causal_mask_baked_in
Flag indicating if causal masking has been applied
- Type
bool
- sliding_window_baked_in
Flag indicating if sliding window masking has been applied
- Type
bool
- chunked_mask_baked_in
Flag indicating if chunked masking has been applied
- Type
bool
- token_type_ids_baked_in
Flag indicating if token type ID masking has been applied
- Type
bool
- apply_causal(offset: int | jaxtyping.Int[jaxlib._jax.Array, 'batch'] = 0) MaskInfo
Apply causal (autoregressive) masking to the attention pattern.
Restricts attention so that each query position can only attend to key positions at or before its own position (plus an optional offset). The segment IDs are preserved to maintain grouping structure.
- Parameters
offset – Position offset for causal masking. Can be:
  - int: Scalar offset applied to all batch elements
  - Array of shape (batch,): Per-batch offsets
  Default: 0. Semantics:
  - offset=0: Standard causal (q_i attends to kv_j where j <= i)
  - offset>0: Allows attending to future tokens (j <= i + offset)
  - offset<0: More restrictive causal (j <= i + offset)
- Returns
New MaskInfo with causal constraint applied while preserving segment IDs
- Raises
ValueError – If mask dimensions are unknown
Example
>>> segment_ids = jnp.array([[1, 1, 1, 1]])
>>> mask_info = MaskInfo.from_segments(segment_ids)
>>> causal_mask = mask_info.apply_causal()
>>> # Per-batch offsets: one entry per batch element
>>> batch_mask_info = MaskInfo.from_segments(jnp.ones((3, 4), dtype=jnp.int32))
>>> offsets = jnp.array([0, 1, 2])
>>> causal_mask = batch_mask_info.apply_causal(offset=offsets)
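The offset rule j <= i + offset, including the per-batch form, can be sketched in NumPy. This is a hypothetical helper (causal_mask) that builds only the causal pattern. It does not reproduce how apply_causal composes that pattern with existing segment IDs. In this sketch, combining with a segment-based mask would be a logical AND.

```python
import numpy as np


def causal_mask(q_len: int, kv_len: int, offset=0) -> np.ndarray:
    """Hypothetical helper: (batch, 1, q_len, kv_len) mask, query i may attend to key j iff j <= i + offset.

    `offset` is a scalar or a per-batch array of shape (batch,); a scalar yields batch size 1.
    """
    offsets = np.atleast_1d(np.asarray(offset))   # (batch,)
    i = np.arange(q_len)[None, :, None]            # query positions
    j = np.arange(kv_len)[None, None, :]           # key positions
    allowed = j <= i + offsets[:, None, None]      # (batch, q_len, kv_len)
    return allowed[:, None, :, :]                  # add a head axis of size 1


# Batch 0 is standard causal; batch 1 may also see one future token.
mask = causal_mask(4, 4, offset=np.array([0, 1]))
print(mask[0, 0, 0].tolist())  # [True, False, False, False]
print(mask[1, 0, 0].tolist())  # [True, True, False, False]
```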
- apply_chunked(chunk_size: int, offset: int = 0) MaskInfo
Apply chunked causal attention and always update the q/kv segment IDs to chunk IDs.
  - New segment IDs are the chunk indices + 1 (padding stays -1).
  - The attention mask becomes: existing_mask AND (same_chunk AND causal).
This makes segment IDs the canonical representation of chunk structure.
Note: segment IDs encode only chunk grouping; the causal direction within a chunk still requires positions or an explicit causal rule.
- Parameters
chunk_size – Positive chunk size.
offset – Optional causal offset (default 0).
- Returns
New MaskInfo with updated attention_mask and updated segment IDs.
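The two rules above can be written out in NumPy as a sketch: the mask is "same chunk AND causal", and the segment IDs become chunk index + 1 with padding kept at -1. Both helper names are hypothetical. The sketch also assumes the chunk index comes from the token's column position; this is not confirmed by the description above.

```python
import numpy as np


def chunked_causal_mask(q_len: int, kv_len: int, chunk_size: int, offset: int = 0) -> np.ndarray:
    """Hypothetical helper: (q_len, kv_len) mask = same chunk AND causal (j <= i + offset)."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    i = np.arange(q_len)[:, None]
    j = np.arange(kv_len)[None, :]
    same_chunk = (i // chunk_size) == (j // chunk_size)
    causal = j <= i + offset
    return same_chunk & causal


def chunk_segment_ids(segment_ids: np.ndarray, chunk_size: int) -> np.ndarray:
    """Hypothetical helper: replace segment IDs with chunk index + 1, keeping -1 for padding.

    Assumes the chunk index is derived from the token's column position.
    """
    chunk_ids = np.arange(segment_ids.shape[-1]) // chunk_size + 1
    return np.where(segment_ids < 0, -1, chunk_ids).astype(np.int32)


print(chunked_causal_mask(4, 4, chunk_size=2).astype(int))
print(chunk_segment_ids(np.array([[0, 0, 0, 0, 0, -1]]), chunk_size=2))  # [[ 1  1  2  2  3 -1]]
```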
- apply_kv_lengths(kv_lengths: Int[jaxlib._jax.Array, 'batch'], *, q_len: int | None = None, sliding_window: int | None = None) MaskInfo
Mask out key/value positions beyond per-example lengths and keep a trailing query window.
- The method expects a 4D attention mask (batch, heads, q_len, kv_len). For each batch item:
KV positions with index >= kv_lengths[b] are masked out.
The query dimension is sliced to the last q_len rows, starting at kv_lengths[b] - q_len.
If sliding_window is provided and smaller than the current KV dimension, only the most recent sliding_window columns are kept.
Segment IDs and position arrays are reused unchanged; only the materialized attention mask is updated.
- Parameters
kv_lengths – Integer array of shape (batch,) with the number of valid KV tokens per batch element. The implementation assumes kv_lengths[b] >= q_len and does not clamp indices.
q_len – Number of query rows to keep. Must be specified and should be <= kv_lengths[b] for all b.
sliding_window – Optional maximum number of KV columns to retain after masking. If None, keeps all remaining KV positions.
- Returns
New MaskInfo whose attention_mask has shape (batch, 1, q_len, effective_kv_len), where effective_kv_len equals kv_len when sliding_window is None and otherwise min(kv_len, sliding_window).
- Raises
ValueError – If the attention mask cannot be materialized.
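The first two steps (masking KV positions >= kv_lengths[b] and keeping the trailing query rows) can be sketched in NumPy. The helper name is hypothetical, and the sliding-window step is omitted. The sketch assumes the input mask already has a head axis of size 1, matching the (batch, 1, ...) output shape described above. It uses a Python loop for readability; a JIT-friendly version would use dynamic slicing such as jax.lax.dynamic_slice instead.

```python
import numpy as np


def apply_kv_lengths_sketch(mask: np.ndarray, kv_lengths: np.ndarray, q_len: int) -> np.ndarray:
    """Hypothetical re-implementation of the first two steps above (no sliding window).

    mask: (batch, 1, total_q, kv_len) boolean, with total_q >= max(kv_lengths).
    Returns a (batch, 1, q_len, kv_len) boolean mask.
    """
    batch, _, _, kv_len = mask.shape
    cols = np.arange(kv_len)
    out = np.zeros((batch, 1, q_len, kv_len), dtype=bool)
    for b in range(batch):
        end = int(kv_lengths[b])
        start = end - q_len                       # assumes kv_lengths[b] >= q_len (no clamping)
        rows = mask[b, 0, start:end]              # keep the trailing q_len query rows
        out[b, 0] = rows & (cols < end)[None, :]  # drop KV positions >= kv_lengths[b]
    return out


base = np.ones((2, 1, 6, 6), dtype=bool)
out = apply_kv_lengths_sketch(base, np.array([4, 6]), q_len=2)
print(out.shape)              # (2, 1, 2, 6)
print(out[0, 0, 0].tolist())  # [True, True, True, True, False, False]
```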
- apply_sliding_window(sliding_window: int | tuple[int, int], *, offset: int | jaxtyping.Int[jaxlib._jax.Array, 'batch'] = 0, mode: Optional[Literal['default', 'decode', 'prefill']] = None, index: int | jaxtyping.Int[jaxlib._jax.Array, 'batch'] | None = None) MaskInfo
Apply sliding window attention to the attention pattern.
Restricts attention so that each query position can only attend to key positions within a specified window around its position. This is useful for local attention patterns where distant tokens are masked out.
- Parameters
sliding_window –
Window size specification:
  - int: Symmetric window (same size left and right)
  - tuple[int, int]: (left_window, right_window) for asymmetric windows
left_window: How many positions to the left can be attended to
right_window: How many positions to the right can be attended to
offset – Row offset for sliding window calculation. Can be:
  - int: Scalar offset applied to all batch elements
  - Array of shape (batch,): Per-batch offsets
  Default: 0
mode – Attention mode for dynamic slicing:
  - "default" or None: Standard sliding window without slicing
  - "decode": Decode mode; slices KV to the window size around the current index
  - "prefill": Prefill mode; slices to the last sliding_window positions
index – Current position index (required for "decode" mode). Can be:
  - int: Scalar index applied to all batch elements
  - Array of shape (batch,): Per-batch indices
- Returns
New MaskInfo with sliding window constraint applied while preserving segment IDs
- Raises
ValueError – If mask dimensions are unknown, window size is invalid, or index is missing in decode mode
Example
>>> segment_ids = jnp.array([[1, 1, 1, 1, 1, 1]])
>>> mask_info = MaskInfo.from_segments(segment_ids)
>>> # Symmetric window: each position attends to 2 positions left and 2 right
>>> windowed_mask = mask_info.apply_sliding_window(2)
>>> # Asymmetric window: 3 left, 1 right
>>> windowed_mask = mask_info.apply_sliding_window((3, 1))
>>> # Decode mode at position 5 with window size 3
>>> decode_mask = mask_info.apply_sliding_window(3, mode="decode", index=5)
>>> # Decode mode with per-batch indices (one index per batch element)
>>> batch_mask_info = MaskInfo.from_segments(jnp.ones((3, 8), dtype=jnp.int32))
>>> batch_indices = jnp.array([5, 7, 3])
>>> decode_mask = batch_mask_info.apply_sliding_window(3, mode="decode", index=batch_indices)
>>> # Prefill mode with window size 4
>>> prefill_mask = mask_info.apply_sliding_window(4, mode="prefill")
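The window rule itself (query i may see keys from i - left_window to i + right_window) can be sketched in NumPy. The helper name is hypothetical. The sketch ignores offset and the decode/prefill slicing modes, which change which rows and columns are kept rather than the rule itself.

```python
import numpy as np


def sliding_window_mask(q_len: int, kv_len: int, sliding_window) -> np.ndarray:
    """Hypothetical helper: (q_len, kv_len) mask, query i may attend to keys in [i - left, i + right]."""
    if isinstance(sliding_window, int):
        left, right = sliding_window, sliding_window  # symmetric window
    else:
        left, right = sliding_window                  # (left_window, right_window)
    if left < 0 or right < 0:
        raise ValueError("window sizes must be non-negative")
    i = np.arange(q_len)[:, None]
    j = np.arange(kv_len)[None, :]
    return (j >= i - left) & (j <= i + right)


asym = sliding_window_mask(6, 6, (3, 1))
print(asym[5].tolist())  # [False, False, True, True, True, True]
# A causal local window is the same rule with right_window = 0.
causal_local = sliding_window_mask(6, 6, (2, 0))
```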
- apply_token_type_ids(token_type_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch q_len'] | tuple[jaxtyping.Int[jaxlib._jax.Array, 'batch q_len'], jaxtyping.Int[jaxlib._jax.Array, 'batch kv_len']], *, combine: Literal['union', 'intersect', 'replace'] = 'union', zero_policy: Literal['q', 'kv', 'both', 'none'] = 'q', update_segment_ids: bool | None = None) MaskInfo
Integrate token_type_ids into the attention pattern.
  - Builds an equality mask between q and kv token types.
  - Optionally treats token_type_id == 0 as "disabled" (no token-type matching) on the query side, kv side, both, or neither (zero_policy).
  - Combines with the current attention mask by union/intersect/replace.
  - Optionally updates segment IDs to reflect token types (0 -> -1 padding).
- Parameters
token_type_ids –
self-attn: (batch, q_len)
cross-attn: (q_token_type_ids, kv_token_type_ids)
combine – How to combine with the existing mask:
  - "union": base_mask OR token_type_mask (legacy behavior)
  - "intersect": base_mask AND token_type_mask
  - "replace": token_type_mask only
zero_policy –
"q": treat q==0 as disabled (no token-type matching for those queries); legacy behavior
"kv": treat kv==0 as disabled (no matching into those keys/values)
"both": treat 0 as disabled on both sides
"none": don't treat 0 specially
update_segment_ids –
  - If None: defaults to False for "union" (a union cannot be encoded in segment IDs) and True for "intersect"/"replace".
  - If True: set q/kv segment IDs from token types with 0 -> -1.
  - If False: keep existing segment IDs.
- Returns
New MaskInfo with updated attention_mask (and optionally updated segment IDs).
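The equality mask, zero_policy and combine modes described above can be sketched in NumPy for the self-attention case. Both helper names are hypothetical. The example treats type-1 tokens as a hypothetical span that should see itself bidirectionally inside an otherwise causal mask.

```python
import numpy as np


def token_type_mask(q_types: np.ndarray, kv_types: np.ndarray, zero_policy: str = "q") -> np.ndarray:
    """Hypothetical helper: (batch, 1, q_len, kv_len) equality mask between token types."""
    match = q_types[:, :, None] == kv_types[:, None, :]
    if zero_policy in ("q", "both"):
        match &= (q_types != 0)[:, :, None]   # type-0 queries get no token-type matches
    if zero_policy in ("kv", "both"):
        match &= (kv_types != 0)[:, None, :]  # type-0 keys/values receive no matches
    return match[:, None, :, :]


def combine_masks(base: np.ndarray, token_mask: np.ndarray, combine: str = "union") -> np.ndarray:
    """Hypothetical helper mirroring the three combine modes."""
    if combine == "union":
        return base | token_mask
    if combine == "intersect":
        return base & token_mask
    if combine == "replace":
        return token_mask
    raise ValueError(f"unknown combine mode: {combine}")


# Type-0 tokens stay causal; the hypothetical type-1 span also attends within itself bidirectionally.
types = np.array([[0, 1, 1, 0]])
causal = np.tril(np.ones((4, 4), dtype=bool))[None, None]
merged = combine_masks(causal, token_type_mask(types, types, zero_policy="q"), "union")
print(merged[0, 0].astype(int))
```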
- property attention_mask: jax.jaxlib._jax.Array | None
- property baked_in_masks: dict[str, bool]
Get a dictionary of all baked-in mask operation flags.
- Returns
Dictionary mapping operation names to their baked-in status:
  - 'causal': Whether causal masking has been applied
  - 'sliding_window': Whether sliding window masking has been applied
  - 'chunked': Whether chunked masking has been applied
  - 'token_type_ids': Whether token type ID masking has been applied
- Return type
dict[str, bool]
Example
>>> mask_info = MaskInfo.from_segments(jnp.array([[1, 1, 1]]))
>>> mask_info = mask_info.apply_causal()
>>> mask_info.baked_in_masks
{'causal': True, 'sliding_window': False, 'chunked': False, 'token_type_ids': False}
- batch_axis_name: tuple[str] | str | None = ('dp', 'fsdp')
- property batch_size: int | None
Get batch size from available data.
Infers the batch dimension from either segment IDs or attention mask.
- Returns
Batch size if available, None otherwise
- property bias
Create attention bias from the mask (convenience property).
Returns an attention bias tensor where valid attention positions are 0.0 and masked positions are set to the minimum float value for the dtype.
- Returns
Attention bias array with dtype float32
- causal_mask_baked_in: bool = False
- chunked_mask_baked_in: bool = False
- create_bias(dtype: numpy.dtype = jax.numpy.float32) Array
Create attention bias from the mask.
Converts the boolean attention mask into an additive bias tensor suitable for attention score computation. Valid positions (mask=True) get 0.0, while masked positions (mask=False) get a large negative value (jnp.finfo(dtype).min).
- Parameters
dtype – Output dtype for the bias tensor. Default: jnp.float32
- Returns
Attention bias array where:
  - Valid attention positions: 0.0
  - Masked positions: jnp.finfo(dtype).min
- Return type
Array
Example
>>> mask_info = MaskInfo.from_segments(jnp.array([[1, 1, 2, 2]]))
>>> bias = mask_info.create_bias(dtype=jnp.float32)
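The mask-to-bias conversion and its effect on softmax can be sketched in NumPy. The helper name is hypothetical. Using the finite finfo(dtype).min rather than -inf means that a fully masked row yields uniform weights instead of NaN, because the row maximum subtracted before exponentiation is then finite.

```python
import numpy as np


def mask_to_bias(mask: np.ndarray, dtype=np.float32) -> np.ndarray:
    """Hypothetical helper: 0.0 where attention is allowed, finfo(dtype).min where it is masked."""
    return np.where(mask, 0.0, np.finfo(dtype).min).astype(dtype)


mask = np.array([[True, True, False]])
bias = mask_to_bias(mask)
scores = np.zeros((1, 3), dtype=np.float32) + bias  # add the bias to raw scores before softmax
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(weights)  # the masked column receives ~0 probability
```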
- static create_chunked_attention_mask(chunk_size: int, q_len: int, kv_len: int | None = None, offset: int = 0, dtype=jax.numpy.bool) Array
Create a chunked causal attention mask (static method).
Generates a 2D attention mask where attention is restricted to tokens within the same chunk, with causal ordering enforced within chunks.
- Parameters
chunk_size – Size of each chunk (must be positive)
q_len – Query sequence length
kv_len – Key-value sequence length. If None, uses q_len
offset – Causal offset. Default: 0
dtype – Output dtype. Default: jnp.bool_
- Returns
2D attention mask of shape (q_len, kv_len) with chunked causal pattern
- Raises
ValueError – If chunk_size is not positive
Example
>>> mask = MaskInfo.create_chunked_attention_mask(
...     chunk_size=4, q_len=8, kv_len=8
... )
>>> mask.shape
(8, 8)
- property cu_seqlens_kv: jax.jaxlib._jax.Array | None
- property cu_seqlens_q: jax.jaxlib._jax.Array | None
- classmethod dynamic_init(*, mask_info: ejkernel.types.mask.MaskInfo | None = None, input_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seqlen'] | None = None, inputs_embeds: jaxtyping.Float[jaxlib._jax.Array, 'batch seqlen dim'] | None = None, attention_mask: jaxtyping.Int[jaxlib._jax.Array, 'batch seqlen'] | jaxtyping.Bool[jaxlib._jax.Array, 'batch seqlen'] | None = None) MaskInfo
Dynamically initialize a MaskInfo from various input sources.
This is a convenience factory method that creates a MaskInfo instance from different types of inputs commonly available in transformer models. It prioritizes existing mask_info, then constructs one from attention_mask or input shapes.
- Parameters
mask_info – Pre-existing MaskInfo to return as-is. If provided, other arguments are ignored.
input_ids – Token IDs array with shape (batch, seq_len). Used to infer shape if mask_info and attention_mask are not provided.
inputs_embeds – Token embeddings array with shape (batch, seq_len, dim). Used to infer shape if mask_info, attention_mask, and input_ids are not provided.
attention_mask – Attention mask array with shape (batch, seq_len). Values should be:
  - 1/True for valid (non-padding) tokens
  - 0/False for padding tokens
  If not provided, creates an all-ones mask (no padding).
- Returns
MaskInfo instance constructed from the provided inputs
- Raises
ValueError – If insufficient information is provided (no valid inputs)
Example
>>> input_ids = jnp.array([[1, 2, 3, 0], [4, 5, 0, 0]])
>>> attn_mask = jnp.array([[1, 1, 1, 0], [1, 1, 0, 0]])
>>> mask_info = MaskInfo.dynamic_init(input_ids=input_ids, attention_mask=attn_mask)
>>> mask_info.shape
(2, 4, 4)
Notes
This method is useful for model implementations where mask format may vary
Automatically converts 2D attention masks to segment-based representation
Higher-dimensional masks are handled via from_attention_mask()
- classmethod from_attention_mask(attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch nheads_or_1 qlen kvlen'] | jaxtyping.Int[jaxlib._jax.Array, 'batch nheads_or_1 qlen kvlen'], q_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None = None, kv_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None = None, *, mask_is_valid: bool = True, batch_axis_name: tuple[str] | str | None = ('dp', 'fsdp'), qheads_axis_name: tuple[str] | str | None = 'tp', kvheads_axis_name: tuple[str] | str | None = 'tp', sequence_axis_name: tuple[str] | str | None = 'sp') MaskInfo
Create MaskInfo from an existing attention mask.
For 2D masks this treats the input as a padding mask (1/True = valid, 0/False = padding), converts it to segment IDs (0 for valid tokens, -1 for padding), and materializes a broadcasted 4D pairwise mask.
For 3D/4D masks, padding-style segment IDs are extracted by checking which Q positions attend to at least one KV position (and vice versa). This gives valid/padding information without attempting to recover full segment structure.
- Parameters
attention_mask – Attention mask array. Supported shapes:
  - (batch, seqlen): 2D padding mask (token mask)
  - (batch, qlen, kvlen): 3D batched mask
  - (batch, heads, qlen, kvlen): 4D multi-head mask
  Values: True/1 = valid attention, False/0 = masked (unless mask_is_valid=False)
q_positions – Optional query position indices (batch, qlen)
kv_positions – Optional key-value position indices (batch, kvlen)
mask_is_valid – If False, treats True/1 entries as masked-out (disallowed) positions and inverts the mask (PyTorch-style attn_mask semantics).
- Returns
MaskInfo with derived segment IDs, original attention mask, and optional positions
- Raises
ValueError – If attention_mask is not 2D, 3D, or 4D
Example
>>> mask = jnp.array([[[1, 1, 0], [1, 1, 0], [0, 0, 1]]])
>>> mask_info = MaskInfo.from_attention_mask(mask)
>>> mask_info.q_segment_ids.shape
(1, 3)
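The padding-style derivation described above can be sketched in NumPy. For a 2D mask, each token maps to 0 (valid) or -1 (padding). For a 3D/4D mask, a query is valid if it attends to any key, and a key is valid if any query attends to it. The helper name is hypothetical. Two details are assumptions: for 4D masks the sketch aggregates over all heads with any(), and it applies the mask_is_valid=False inversion to every mask rank.

```python
import numpy as np


def padding_segment_ids(attention_mask: np.ndarray, mask_is_valid: bool = True):
    """Hypothetical helper: padding-style segment IDs (0 = valid, -1 = padding) from a 2D/3D/4D mask."""
    mask = attention_mask.astype(bool)
    if not mask_is_valid:
        mask = ~mask  # True meant "masked out"; flip so True = allowed

    def to_ids(valid: np.ndarray) -> np.ndarray:
        return np.where(valid, 0, -1).astype(np.int32)

    if mask.ndim == 2:       # (batch, seqlen) token mask
        ids = to_ids(mask)
        return ids, ids
    if mask.ndim == 3:       # (batch, qlen, kvlen) -> add a head axis
        mask = mask[:, None]
    if mask.ndim != 4:
        raise ValueError("attention_mask must be 2D, 3D, or 4D")
    q_valid = mask.any(axis=(1, 3))   # query attends to at least one key (in any head)
    kv_valid = mask.any(axis=(1, 2))  # key is attended by at least one query (in any head)
    return to_ids(q_valid), to_ids(kv_valid)


mask = np.array([[[1, 1, 0], [1, 1, 0], [0, 0, 0]]])
q_ids, kv_ids = padding_segment_ids(mask)
print(q_ids, kv_ids)  # [[ 0  0 -1]] [[ 0  0 -1]]
```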
- classmethod from_cu_seqlens(cu_seqlens_q: Int[jaxlib._jax.Array, 'batch_plus_one'], *, max_q_len: int, cu_seqlens_kv: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, max_kv_len: int | None = None, q_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None = None, kv_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None = None, batch_axis_name: tuple[str] | str | None = ('dp', 'fsdp'), qheads_axis_name: tuple[str] | str | None = 'tp', kvheads_axis_name: tuple[str] | str | None = 'tp', sequence_axis_name: tuple[str] | str | None = 'sp') MaskInfo
Create a padding-style MaskInfo from cumulative sequence lengths.
This reconstructs 2D padding masks (valid tokens are a prefix) and stores a compact padding-style segment-id representation (0 for valid tokens, -1 for padding). The pairwise attention mask can be materialized on demand via attention_mask.
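The reconstruction described above can be sketched in NumPy: consecutive differences of cu_seqlens give per-example lengths, and valid tokens form a prefix of each row. The helper name is hypothetical and stands in for the max_q_len / max_kv_len handling.

```python
import numpy as np


def padding_from_cu_seqlens(cu_seqlens: np.ndarray, max_len: int):
    """Hypothetical helper: rebuild a (batch, max_len) prefix padding mask and 0/-1 segment IDs."""
    lengths = np.diff(cu_seqlens)                                # tokens per batch element
    token_mask = np.arange(max_len)[None, :] < lengths[:, None]  # valid tokens form a prefix
    segment_ids = np.where(token_mask, 0, -1).astype(np.int32)
    return token_mask, segment_ids


token_mask, segment_ids = padding_from_cu_seqlens(np.array([0, 3, 5]), max_len=4)
print(segment_ids.tolist())  # [[0, 0, 0, -1], [0, 0, -1, -1]]
```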
- classmethod from_random(batch_size: int, q_len: int, kv_len: int | None = None, sparsity: float = 0.5, seed: int = 0, q_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None = None, kv_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None = None, batch_axis_name: tuple[str] | str | None = ('dp', 'fsdp'), qheads_axis_name: tuple[str] | str | None = 'tp', kvheads_axis_name: tuple[str] | str | None = 'tp', sequence_axis_name: tuple[str] | str | None = 'sp') MaskInfo
Create MaskInfo with random attention pattern.
Generates a random binary attention mask with specified sparsity level. Useful for testing, experimentation, and studying sparse attention patterns.
- Parameters
batch_size – Batch size
q_len – Query sequence length
kv_len – Key-value sequence length. If None, uses q_len (self-attention)
sparsity – Fraction of attention positions to mask out (0.0 = full attention, 1.0 = fully masked). Default: 0.5 (50% masked)
seed – Random seed for reproducibility. Default: 0
q_positions – Optional query position indices (batch, qlen)
kv_positions – Optional key-value position indices (batch, kvlen)
- Returns
MaskInfo with random attention pattern and optional positions
Example
>>> # Self-attention with 70% of positions masked
>>> mask_info = MaskInfo.from_random(
...     batch_size=2,
...     q_len=128,
...     sparsity=0.7,
...     seed=42
... )
>>> mask_info.attention_mask.shape
(2, 1, 128, 128)
>>> # Cross-attention with different query and key-value lengths
>>> mask_info = MaskInfo.from_random(
...     batch_size=1,
...     q_len=64,
...     kv_len=128,
...     sparsity=0.5,
...     seed=0
... )
>>> mask_info.attention_mask.shape
(1, 1, 64, 128)
- classmethod from_segments(q_segment_ids: Int[jaxlib._jax.Array, 'batch qlen'], kv_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None = None, q_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None = None, kv_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None = None, batch_axis_name: tuple[str] | str | None = ('dp', 'fsdp'), qheads_axis_name: tuple[str] | str | None = 'tp', kvheads_axis_name: tuple[str] | str | None = 'tp', sequence_axis_name: tuple[str] | str | None = 'sp', is_attn_mask: bool = False) MaskInfo
Create MaskInfo from segment IDs.
Constructs a MaskInfo instance from query and key-value segment IDs, automatically generating the corresponding attention mask. Segment IDs group tokens that can attend to each other (same segment ID = can attend).
- Parameters
q_segment_ids – Query segment IDs of shape (batch, qlen). Values should be:
  - Non-negative integers: segment membership (0, 1, 2, …)
  - -1: padding tokens
kv_segment_ids – Key-value segment IDs of shape (batch, kvlen). If None, uses q_segment_ids (self-attention case). Values follow same convention as q_segment_ids.
q_positions – Optional query position indices (batch, qlen) for positional embeddings
kv_positions – Optional key-value position indices (batch, kvlen) for positional embeddings
batch_axis_name – Axis name(s) for batch dimension in distributed sharding. Default: ("dp", "fsdp")
qheads_axis_name – Axis name(s) for query heads dimension in distributed sharding. Default: "tp"
kvheads_axis_name – Axis name(s) for key-value heads dimension in distributed sharding. Default: "tp"
sequence_axis_name – Axis name(s) for sequence dimension in distributed sharding. Default: "sp"
- Returns
MaskInfo with segment IDs, computed attention mask, optional positions, and sharding configuration
Example
>>> q_seg = jnp.array([[1, 1, 2, 2, -1]])
>>> mask_info = MaskInfo.from_segments(q_seg)
>>> mask_info.attention_mask.shape
(1, 1, 5, 5)
- static get_empty_sharding() MaskSharding
Create an empty MaskSharding with all specs set to None.
Useful as a default or placeholder when no sharding is needed.
- Returns
MaskSharding with all fields set to None
- get_or_compute_attention_mask(dtype: jax.typing.DTypeLike = jax.numpy.bool) Array
Get attention mask, using cached data when available and deriving from segment IDs otherwise.
If a materialized attention mask is already stored, it is returned (with dtype conversion if needed). Otherwise, the mask is constructed from the available segment IDs.
- Parameters
dtype – Desired output dtype (default: bool)
- Returns
Attention mask array
- Raises
ValueError – If both attention_mask and segment_ids are None
- get_or_compute_positions() tuple[jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None, jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None]
Get position arrays, computing them if not already available.
Generates position indices for queries and keys/values when not explicitly provided. Position arrays are useful for positional embeddings and rotary position embeddings (RoPE).
- Returns
Tuple of (q_positions, kv_positions) where:
  - q_positions: (batch, qlen) position indices for queries, or None if dimensions unknown
  - kv_positions: (batch, kvlen) position indices for keys/values, or None if dimensions unknown
- Return type
tuple
Example
>>> mask_info = MaskInfo.from_segments(jnp.array([[1, 1, 2, 2]]))
>>> q_pos, kv_pos = mask_info.get_or_compute_positions()
>>> q_pos.shape
(1, 4)
>>> kv_pos[0]
Array([0, 1, 2, 3], dtype=int32)
- get_or_compute_qkv_cu_seqlens(*, out_dtype: jax.typing.DTypeLike = jax.numpy.int32, max_segments: int = 128) tuple[jaxtyping.Int[jaxlib._jax.Array, 'max_segments_plus_1'], jaxtyping.Int[jaxlib._jax.Array, 'max_segments_plus_1']]
Derive (cu_seqlens_q, cu_seqlens_kv) from the available mask representation.
Prefers segment IDs if present (padding is unambiguous), otherwise falls back to the materialized attention mask.
- Parameters
out_dtype – Output dtype for cumulative lengths. Default: int32.
max_segments – Maximum number of segments for JIT compatibility. The output will have shape (max_segments + 1,). Default: 128.
- Returns
Tuple of (cu_seqlens_q, cu_seqlens_kv), each with shape (max_segments + 1,). Format: [0, len(seg0), len(seg0) + len(seg1), …], as used for FlashAttention-style variable-length packing.
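The cumulative-length format can be sketched in NumPy for a single packed row of segment IDs. The fixed output length of max_segments + 1 is what keeps the shape static under JIT. The helper name is hypothetical. Two choices here are assumptions and may differ from the library: unused trailing slots repeat the final total, and -1 padding is simply dropped before the runs are counted.

```python
import numpy as np


def cu_seqlens_from_segment_ids(segment_ids: np.ndarray, max_segments: int = 128, dtype=np.int32) -> np.ndarray:
    """Hypothetical helper: cumulative run lengths of equal segment IDs in a 1D row, padded to max_segments + 1."""
    valid = segment_ids[segment_ids >= 0]            # drop -1 padding
    if valid.size:
        starts = np.flatnonzero(np.diff(valid)) + 1  # indices where the segment ID changes
        lengths = np.diff(np.concatenate(([0], starts, [valid.size])))
    else:
        lengths = np.zeros(0, dtype=dtype)
    if lengths.size > max_segments:
        raise ValueError("more segments than max_segments")
    cu = np.zeros(max_segments + 1, dtype=dtype)
    cu[1 : lengths.size + 1] = np.cumsum(lengths)
    cu[lengths.size + 1 :] = cu[lengths.size]        # assumption: unused slots repeat the total
    return cu


print(cu_seqlens_from_segment_ids(np.array([0, 0, 0, 1, 1, 2, -1, -1]), max_segments=4))  # [0 3 5 6 6]
```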
- get_or_compute_segment_ids(per_head: bool = False) tuple[jaxtyping.Int[jaxlib._jax.Array, '...'], jaxtyping.Int[jaxlib._jax.Array, '...']]
Get segment IDs, computing from attention mask if not available.
- Parameters
per_head – Forwarded to mask_to_segment_ids when segment IDs are derived from a 4D attention mask. When True, returns per-head segment IDs with shape (batch, heads, seqlen).
- Returns
Tuple of (q_segment_ids, kv_segment_ids)
- Raises
ValueError – If both attention_mask and segment_ids are None
- get_qkv_masks(dtype: jax.typing.DTypeLike = jax.numpy.bool) tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jaxtyping.Bool[jaxlib._jax.Array, 'batch nheads_or_1 qlen kvlen'] | jaxtyping.Int[jaxlib._jax.Array, 'batch nheads_or_1 qlen kvlen']]
Get separate query mask, key-value mask, and attention mask.
- Parameters
dtype – Desired output dtype (default: bool)
- Returns
Tuple of (q_mask, kv_mask, attention_mask) where:
  - q_mask: (batch, qlen) boolean mask for valid query positions
  - kv_mask: (batch, kvlen) boolean mask for valid key-value positions
  - attention_mask: (batch, 1, qlen, kvlen) 4D pairwise attention mask
- Return type
tuple
- Raises
ValueError – If both attention_mask and segment_ids are None
- get_shardings(sequence_parallel: bool = False, *, mesh: Mesh) MaskSharding
Generate sharding specifications for all mask components.
Creates PartitionSpec objects that define how to distribute the mask tensors across devices in a multi-device setup. Uses the axis names configured in the MaskInfo instance.
- Parameters
sequence_parallel – Whether to shard along the sequence dimension. If True, sequences are split across devices. Default: False
mesh – JAX mesh defining the device grid and axis names
- Returns
MaskSharding containing partition specs for all mask components
- Raises
ValueError – If configured axis names are not present in the mesh, or if attention_mask is not 4D
Example
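>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh
>>> # The mesh must contain every configured axis name (defaults: dp, fsdp, tp, sp)
>>> devices = np.array(jax.devices()).reshape(-1, 1, 1, 1)
>>> mesh = Mesh(devices, axis_names=('dp', 'fsdp', 'tp', 'sp'))
>>> mask_info = MaskInfo.from_segments(jnp.array([[1, 1, 2, 2]]))
>>> shardings = mask_info.get_shardings(mesh=mesh)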
- property is_multi_sequence: bool
Check if the segment IDs represent multiple sequences (packed format).
This property determines whether the mask contains multiple distinct sequences by examining the segment IDs. In packed/multi-sequence format, different sequences are assigned different segment IDs (0, 1, 2, …), whereas single-sequence format uses only segment ID 0 for valid tokens and -1 for padding.
- Returns
JAX boolean array (scalar) indicating if multiple sequences are present. Returns True if max segment ID > 0, False otherwise. Returns False if segment IDs are not available.
Note
This property returns a JAX array (not Python bool) to be JIT-compatible. You can use it directly in JAX conditionals or convert to Python bool with bool(mask_info.is_multi_sequence) outside of JIT contexts.
Examples
>>> # Single sequence: all valid tokens have segment ID 0
>>> q_seg = jnp.array([[0, 0, 0, -1, -1]])
>>> mask_info = MaskInfo.from_segments(q_seg)
>>> mask_info.is_multi_sequence
Array(False, dtype=bool)
>>> # Multiple sequences: tokens have different segment IDs
>>> q_seg = jnp.array([[0, 0, 1, 1, -1]])
>>> mask_info = MaskInfo.from_segments(q_seg)
>>> mask_info.is_multi_sequence
Array(True, dtype=bool)
>>> # JIT-compatible usage
>>> @jax.jit
... def process(q_seg):
...     info = MaskInfo.from_segments(q_seg)
...     return info.is_multi_sequence
- is_self_attention() bool
Check if this represents self-attention (same query and key-value sequences).
- Returns
True if query and key-value sequences are identical, False otherwise
- property kv_len: int | None
Get key-value sequence length.
Infers the key-value sequence dimension from either segment IDs or attention mask.
- Returns
Key-value sequence length if available, None otherwise
- property kv_lens: jax.jaxlib._jax.Array | None
Get per-segment lengths for keys/values.
For packed sequences with multiple segments (distinct segment IDs), returns the length of each segment. For simple padding masks (all valid tokens have the same segment ID), returns per-batch valid token counts.
- Returns
Array with length of each segment/batch.
- kv_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch kvlen'] | None = None
- property kv_segment_ids: jax.jaxlib._jax.Array | None
- kvheads_axis_name: tuple[str] | str | None = 'tp'
- materialize_attention_mask(dtype: jax.typing.DTypeLike = jax.numpy.bool) MaskInfo
- property q_attention_mask
- property q_len: int | None
Get query sequence length.
Infers the query sequence dimension from either segment IDs or attention mask.
- Returns
Query sequence length if available, None otherwise
- property q_lens: jax.jaxlib._jax.Array | None
Get per-segment lengths for queries.
For packed sequences with multiple segments (distinct segment IDs), returns the length of each segment. For simple padding masks (all valid tokens have the same segment ID), returns per-batch valid token counts.
- Returns
Array with length of each segment/batch.
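For a single row, the per-segment lengths and the is_multi_sequence check can be sketched in NumPy. The helper names are hypothetical, and the sketch assumes segment IDs are contiguous integers starting at 0, so that np.bincount returns one count per segment. For a padding-only row this reduces to the valid-token count.

```python
import numpy as np


def segment_lengths(segment_ids_row: np.ndarray) -> np.ndarray:
    """Hypothetical helper: token count per segment ID for one row (IDs 0..n-1, -1 = padding)."""
    valid = segment_ids_row[segment_ids_row >= 0]
    return np.bincount(valid)


def is_multi_sequence(segment_ids: np.ndarray) -> bool:
    """Hypothetical helper: packed format if any segment ID is greater than 0."""
    return bool(segment_ids.max() > 0)


print(segment_lengths(np.array([0, 0, 0, -1, -1])))   # [3]    padding mask: valid-token count
print(segment_lengths(np.array([0, 0, 1, 1, 1, -1])))  # [2 3]  packed: one length per segment
```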
- property q_position_ids: Array
Compute position IDs from the query segment IDs.
Returns per-segment positions that reset at each segment boundary. Padding positions (segment_id == -1) get position -1.
Example
segment_ids = [[-1, -1, 0, 0, 0, 1, 1, -1]]
q_position_ids = [[-1, -1, 0, 1, 2, 0, 1, -1]]
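The positions in this example can be reproduced with a vectorized NumPy sketch. It marks where each run of equal segment IDs starts, carries the latest run start forward with a running maximum, and subtracts that start from the column index. The helper name is hypothetical. In JAX the running maximum could be expressed with jax.lax.cummax.

```python
import numpy as np


def position_ids_from_segments(segment_ids: np.ndarray) -> np.ndarray:
    """Hypothetical helper: per-segment positions that restart at each boundary; padding (-1) gets -1."""
    seg = np.asarray(segment_ids)
    t = np.arange(seg.shape[-1])
    # A run starts at column 0 and wherever the segment ID differs from the previous token.
    is_start = np.ones(seg.shape, dtype=bool)
    is_start[:, 1:] = seg[:, 1:] != seg[:, :-1]
    # Carry the most recent run-start index forward along the sequence.
    run_start = np.maximum.accumulate(np.where(is_start, t, 0), axis=1)
    return np.where(seg >= 0, t - run_start, -1).astype(np.int32)


print(position_ids_from_segments(np.array([[-1, -1, 0, 0, 0, 1, 1, -1]])).tolist())
# [[-1, -1, 0, 1, 2, 0, 1, -1]]
```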
- q_positions: jaxtyping.Int[jaxlib._jax.Array, 'batch qlen'] | None = None
- property q_segment_ids: jax.jaxlib._jax.Array | None
- qheads_axis_name: tuple[str] | str | None = 'tp'
- replace(*, attention_mask=None, q_segment_ids=None, kv_segment_ids=None, cu_seqlens_q=None, cu_seqlens_kv=None, **kw) MaskInfo
Create a new MaskInfo with specified fields replaced.
This is a convenience method for creating modified copies of MaskInfo instances, using dataclasses.replace(). Only specified fields are updated; others are preserved from the original instance.
- Parameters
attention_mask – New attention mask array, or None to keep existing
q_segment_ids – New query segment IDs, or None to keep existing
kv_segment_ids – New key-value segment IDs, or None to keep existing
cu_seqlens_q – New cumulative query sequence lengths (batch+1,), or None to keep existing
cu_seqlens_kv – New cumulative key/value sequence lengths (batch+1,), or None to keep existing
**kw – Additional keyword arguments for other fields:
  - q_positions: New query positions
  - kv_positions: New key-value positions
  - batch_axis_name: New batch axis name(s)
  - qheads_axis_name: New query heads axis name(s)
  - kvheads_axis_name: New key-value heads axis name(s)
  - sequence_axis_name: New sequence axis name(s)
- Returns
New MaskInfo instance with specified fields replaced
Example
>>> mask_info = MaskInfo.from_segments(jnp.array([[1, 1, 2, 2]]))
>>> new_mask_info = mask_info.replace(batch_axis_name="dp")
>>> new_mask_info.batch_axis_name
'dp'
- sequence_axis_name: tuple[str] | str | None = 'sp'#
- property shape: tuple[int | None, int | None, int | None]#
Get (batch_size, q_len, kv_len) shape tuple.
Convenience property that returns all three dimensions at once.
- Returns
Tuple of (batch_size, query_length, key_value_length)
- sliding_window_baked_in: bool = False#
- to_dtype(dtype: Union[str, type[Any], dtype, SupportsDType]) MaskInfo[source]#
Convert attention mask to specified dtype, returning a new MaskInfo.
- Parameters
dtype – Target dtype (e.g., jnp.float32, jnp.bool_)
- Returns
New MaskInfo with converted attention mask
- token_type_ids_baked_in: bool = False#
- tree_flatten()[source]#
Flatten MaskInfo for JAX pytree registration.
This method is required for JAX pytree support, enabling MaskInfo instances to be used in JAX transformations (jit, vmap, grad, etc.). It splits the instance into two parts:
Children: array fields that JAX traces and transforms
Aux data: static metadata that stays constant across transformations
- Returns
Tuple of (children, aux_data), where:
children: Tuple of (attention_mask, q_segment_ids, kv_segment_ids, cu_seqlens_q, cu_seqlens_kv, q_positions, kv_positions)
aux_data: Tuple of (batch_axis_name, qheads_axis_name, kvheads_axis_name, sequence_axis_name)
Notes
This method is automatically called by JAX during pytree operations
Users typically don’t need to call this directly
The counterpart tree_unflatten() reconstructs the MaskInfo from flattened form
- classmethod tree_unflatten(aux_data, children)[source]#
Reconstruct MaskInfo from flattened pytree representation.
This method is the inverse of tree_flatten() and is required for JAX pytree support. It reconstructs a MaskInfo instance from its flattened components after JAX transformations have been applied.
- Parameters
aux_data – Static metadata tuple containing (batch_axis_name, qheads_axis_name, kvheads_axis_name, sequence_axis_name)
children – Traced array tuple containing (attention_mask, q_segment_ids, kv_segment_ids, cu_seqlens_q, cu_seqlens_kv, q_positions, kv_positions)
- Returns
Reconstructed MaskInfo instance with the provided arrays and metadata
Notes
This method is automatically called by JAX during pytree operations
Users typically don’t need to call this directly
The aux_data and children arguments must follow the layout that tree_flatten() produces
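The flatten/unflatten pair follows JAX's standard custom-pytree protocol. A minimal sketch on a hypothetical two-field class (TinyMask is not part of ejkernel) shows the same split: the array field becomes a traced child, the axis name travels as static aux data, and jax.tree_util.register_pytree_node_class wires the two methods in.

```python
import jax
import jax.numpy as jnp
from dataclasses import dataclass
from typing import Optional


@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class TinyMask:
    """Hypothetical stand-in showing the children / aux_data split used by MaskInfo."""

    segment_ids: Optional[jax.Array]          # child: traced by jit/vmap
    sequence_axis_name: Optional[str] = "sp"  # aux data: static, must be hashable

    def tree_flatten(self):
        return (self.segment_ids,), (self.sequence_axis_name,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (segment_ids,) = children
        (sequence_axis_name,) = aux_data
        return cls(segment_ids, sequence_axis_name)


@jax.jit
def count_valid(m: TinyMask) -> jax.Array:
    # m.segment_ids is a tracer here; m.sequence_axis_name is still a plain str.
    return jnp.sum(m.segment_ids >= 0)


m = TinyMask(jnp.array([[0, 0, 1, -1]]))
print(count_valid(m))
```

Because aux data is part of the tree structure, calling count_valid with a different sequence_axis_name triggers a retrace, while new array values with the same shape and dtype reuse the compiled function.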
- visualize(block_size: int | tuple[int, int] = 32, batch: int = 0, head: int = 0, fit_in_screen: bool = True, max_rows: int = 32, max_cols: int = 64, charset: Literal['unicode', 'ascii'] = 'unicode', show_segments: bool = True, return_str: bool = False) str | None[source]#
Pretty-print the attention mask as a block-aggregated ASCII/Unicode visualization.
Optionally shows aggregated query/key-value segment IDs for each block row/column. Useful for debugging and understanding attention patterns.
- Parameters
block_size – Size of aggregation blocks: an int gives square (block_size, block_size) blocks; a tuple[int, int] gives rectangular (q_block_size, kv_block_size) blocks
batch – Batch index to visualize. Default: 0
head – Head index to visualize. Default: 0
fit_in_screen – If True, downsample to fit within max_rows/max_cols. Default: True
max_rows – Maximum number of block rows to display when fit_in_screen=True. Default: 32
max_cols – Maximum number of block columns to display when fit_in_screen=True. Default: 64
charset – Character set for visualization. Default: “unicode”. “unicode” uses Unicode block characters (░░ for partial, ██ for full); “ascii” uses ASCII characters (.. for partial, ## for full)
show_segments – If True, display segment IDs alongside the mask. Default: True
return_str – If True, return the visualization as a string instead of printing. Default: False
- Returns
If return_str=True, returns the visualization string. Otherwise, prints and returns None
- Block encoding:
Empty (no attention): “  ” (spaces)
Partial (some attention): “░░” (unicode) or “..” (ascii)
Full (all attention): “██” (unicode) or “##” (ascii)
- Segment ID display:
If all tokens in a block share the same segment ID: shows that ID
Mixed segments: shown as “??” in header, “MIX” on left
Padding: shown as -1 or “PAD”
Notes
Not JIT-friendly; runs on host (uses numpy and prints)
Segment IDs are taken from self.q_segment_ids/self.kv_segment_ids if present, otherwise computed from the mask (may be per-head if H > 1)
Example
>>> mask_info = MaskInfo.from_segments(jnp.ones((2, 128), dtype=jnp.int32))
>>> mask_info.visualize(block_size=16, batch=0)
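The empty/partial/full block encoding can be reproduced in a few lines of NumPy. This is a hypothetical re-creation (block_summary is not an ejkernel function) that renders one 2D mask slice with the ASCII charset and omits segment headers and downsampling.

```python
import numpy as np


def block_summary(mask: np.ndarray, block: tuple[int, int] = (2, 2)) -> str:
    """Hypothetical sketch: classify each (bq, bk) tile as empty, partial, or full."""
    bq, bk = block
    q_len, kv_len = mask.shape
    rows = []
    for i in range(0, q_len, bq):
        cells = []
        for j in range(0, kv_len, bk):
            tile = mask[i:i + bq, j:j + bk].astype(bool)
            if tile.all():
                cells.append("##")  # full: every pair in the tile may attend
            elif tile.any():
                cells.append("..")  # partial: some pairs are masked
            else:
                cells.append("  ")  # empty: nothing attends
        rows.append("".join(cells))
    return "\n".join(rows)


causal = np.tril(np.ones((4, 4), dtype=bool))
print(block_summary(causal, block=(2, 2)))
```

For a causal mask the diagonal tiles come out partial, the tile below the diagonal full, and the tile above it empty.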
- class ejkernel.types.mask.MaskSharding(attention_mask: PartitionSpec | None, q_segment_ids: PartitionSpec | None, kv_segment_ids: PartitionSpec | None, cu_seqlens_q: PartitionSpec | None, cu_seqlens_kv: PartitionSpec | None, q_positions: PartitionSpec | None, kv_positions: PartitionSpec | None)[source]#
Bases: NamedTuple
Container for sharding specifications of attention mask components.
Used to specify how different parts of the mask should be partitioned across devices in distributed training scenarios.
- attention_mask#
Sharding spec for the 4D attention mask (batch, heads, q, kv)
- Type
jax.sharding.PartitionSpec | None
- q_segment_ids#
Sharding spec for query segment IDs (batch, qlen)
- Type
jax.sharding.PartitionSpec | None
- kv_segment_ids#
Sharding spec for key-value segment IDs (batch, kvlen)
- Type
jax.sharding.PartitionSpec | None
- cu_seqlens_q#
Sharding spec for cumulative query sequence lengths (batch+1,)
- Type
jax.sharding.PartitionSpec | None
- cu_seqlens_kv#
Sharding spec for cumulative key/value sequence lengths (batch+1,)
- Type
jax.sharding.PartitionSpec | None
- q_positions#
Sharding spec for query positions (batch, qlen)
- Type
jax.sharding.PartitionSpec | None
- kv_positions#
Sharding spec for key-value positions (batch, kvlen)
- Type
jax.sharding.PartitionSpec | None
- attention_mask: jax.sharding.PartitionSpec | None#
Alias for field number 0
- cu_seqlens_kv: jax.sharding.PartitionSpec | None#
Alias for field number 4
- cu_seqlens_q: jax.sharding.PartitionSpec | None#
Alias for field number 3
- kv_positions: jax.sharding.PartitionSpec | None#
Alias for field number 6
- kv_segment_ids: jax.sharding.PartitionSpec | None#
Alias for field number 2
- q_positions: jax.sharding.PartitionSpec | None#
Alias for field number 5
- q_segment_ids: jax.sharding.PartitionSpec | None#
Alias for field number 1
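One possible way to fill these fields for the default mesh axis names (batch over ("dp", "fsdp"), sequence over "sp") is sketched below. It uses a local NamedTuple with the same seven fields in the documented order instead of importing MaskSharding, and the layout itself is an assumption: the mask's head dimension is left replicated because segment-derived masks have a head dimension of 1, key/value lengths are replicated so every query shard sees all keys, and the small cu_seqlens arrays are fully replicated.

```python
from typing import NamedTuple, Optional

from jax.sharding import PartitionSpec as P


class MaskShardingSketch(NamedTuple):
    """Local stand-in with MaskSharding's seven fields, in the same order."""

    attention_mask: Optional[P]
    q_segment_ids: Optional[P]
    kv_segment_ids: Optional[P]
    cu_seqlens_q: Optional[P]
    cu_seqlens_kv: Optional[P]
    q_positions: Optional[P]
    kv_positions: Optional[P]


def sketch_mask_sharding(batch_axes=("dp", "fsdp"), seq_axis="sp") -> MaskShardingSketch:
    """Assumed layout: shard batch and query length; replicate heads, kv length, and cu_seqlens."""
    q_side = P(batch_axes, seq_axis)  # (batch, q_len) arrays
    kv_side = P(batch_axes, None)     # (batch, kv_len) arrays: every query shard needs all keys
    return MaskShardingSketch(
        attention_mask=P(batch_axes, None, seq_axis, None),  # (batch, 1, q, kv)
        q_segment_ids=q_side,
        kv_segment_ids=kv_side,
        cu_seqlens_q=P(),  # replicated
        cu_seqlens_kv=P(),
        q_positions=q_side,
        kv_positions=kv_side,
    )


specs = sketch_mask_sharding()
```

Each spec would typically be wrapped in jax.sharding.NamedSharding(mesh, spec) before being passed to jax.device_put or used as a sharding constraint.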
- ejkernel.types.mask.attention_mask_to_qkv_cu_seqlens(attention_mask: ~jax.jaxlib._jax.Array, *, reduce_heads: ~typing.Literal['any', 'all', 'first'] = 'any', out_dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.int32'>) tuple[jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'], jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one']][source]#
Derive Q/KV cumulative sequence lengths from an attention mask.
- Supported input shapes:
(batch, seq_len): padding mask (self-attention)
(batch, q_len, kv_len): pairwise mask
(batch, heads, q_len, kv_len): pairwise mask
Notes
For pairwise masks, a token is considered “present” if it participates in at least one unmasked attention edge (after optional head reduction). This matches padding-style outer-product masks but may be ambiguous for arbitrary sparse patterns.
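The "present if it has at least one unmasked edge" rule can be sketched as follows. This hypothetical helper is not ejkernel's implementation; it assumes the conventional [0, n0, n0 + n1, ...] cumulative layout suggested by the batch_plus_one annotation and uses the "any" head reduction. Other functions in this module (cu_seqlens_to_mask, qkv_masks_to_cu_seqlens) use interleaved start/end pairs instead, so check which layout the real function produces.

```python
import jax.numpy as jnp


def cu_seqlens_from_mask_sketch(mask):
    """Hypothetical: returns (cu_q, cu_kv) in the [0, n0, n0 + n1, ...] layout."""
    mask = mask.astype(bool)
    if mask.ndim == 2:  # (batch, seq_len) padding mask, self-attention
        q_present = kv_present = mask
    else:
        if mask.ndim == 4:  # (batch, heads, q, kv): reduce heads with "any"
            mask = mask.any(axis=1)
        q_present = mask.any(axis=-1)   # query has at least one unmasked edge
        kv_present = mask.any(axis=-2)  # key has at least one unmasked edge

    def to_cu(present):
        lengths = present.sum(axis=-1, dtype=jnp.int32)
        return jnp.concatenate([jnp.zeros((1,), jnp.int32), jnp.cumsum(lengths)])

    return to_cu(q_present), to_cu(kv_present)


cu_q, cu_kv = cu_seqlens_from_mask_sketch(jnp.array([[1, 1, 1, 0], [1, 1, 0, 0]]))
```

Note that this counts present tokens rather than locating them, so a padding mask with holes still yields a single length per row.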
- ejkernel.types.mask.cu_seqlens_to_mask(cu_seqlens: ~jaxtyping.Int[jaxlib._jax.Array, 'batch*2'], max_len: int, dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.bool'>) Array[source]#
Convert start/end position pairs into a 2D mask.
- Parameters
cu_seqlens – (batch * 2,) array with interleaved [start_0, end_0, start_1, end_1, …].
max_len – Output sequence length.
dtype – Output dtype for the mask.
- Returns
(batch, max_len) mask with True/1 for valid tokens (positions start to end-1) and False/0 for positions outside that range.
Example
>>> cu_seqlens = jnp.array([41, 169])  # valid at positions 41-168
>>> mask = cu_seqlens_to_mask(cu_seqlens, max_len=512)
>>> mask.shape
(1, 512)
>>> mask[0, 40], mask[0, 41], mask[0, 168], mask[0, 169]
(False, True, True, False)
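The start/end-pair layout maps to a mask with two broadcast comparisons. pairs_to_mask_sketch is a hypothetical stand-in, not the library function, and it assumes the interleaved [start_0, end_0, start_1, end_1, …] format described above with exclusive ends.

```python
import jax.numpy as jnp


def pairs_to_mask_sketch(pairs, max_len):
    """Hypothetical: interleaved [start_0, end_0, start_1, end_1, ...] -> (batch, max_len) bool mask."""
    starts = pairs[0::2][:, None]  # (batch, 1)
    ends = pairs[1::2][:, None]    # (batch, 1), exclusive
    pos = jnp.arange(max_len)[None, :]
    return (pos >= starts) & (pos < ends)


mask = pairs_to_mask_sketch(jnp.array([41, 169]), max_len=512)
```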
- ejkernel.types.mask.get_debug_mode() bool[source]#
Check if debug mode is currently enabled.
- Returns
True if debug mode is enabled, False otherwise.
Examples
>>> from ejkernel.types.mask import get_debug_mode, set_debug_mode
>>> get_debug_mode()
False
>>> set_debug_mode(True)
[MaskInfo Debug] Debug mode enabled
>>> get_debug_mode()
True
- ejkernel.types.mask.mask_to_segment_ids(mask: Array, per_head: bool = False) tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array][source]#
Convert attention mask to segment IDs (JIT-friendly).
Analyzes the attention mask structure to extract query and key-value segment IDs. Queries with identical attention patterns are grouped into the same segment, and similarly for keys/values. This conversion is useful for optimized attention implementations that can leverage segment structure.
- Input shapes:
(Q, K): Single 2D mask
(B, Q, K): Batched 2D masks
(B, H, Q, K): Batched multi-head masks
- Parameters
mask – Boolean or integer attention mask array
per_head – If True and mask is 4D, compute segment IDs separately for each head. If False (the default), merge across heads.
- Returns
Tuple of (q_segment_ids, kv_segment_ids) with shapes:
If (Q, K): (Q,), (K,)
If (B, Q, K): (B, Q), (B, K)
If (B, H, Q, K) and per_head=False: (B, Q), (B, K)
If (B, H, Q, K) and per_head=True: (B, H, Q), (B, H, K)
Notes
Padded rows/cols (all-zero) receive segment ID -1
Queries/keys with identical attention patterns share the same segment ID
This function is JIT-compatible for use in compiled JAX programs
- Raises
ValueError – If mask shape is not 2D, 3D, or 4D
Example
>>> mask = jnp.array([[[1, 1, 0], [1, 1, 0], [0, 0, 1]]])
>>> q_ids, kv_ids = mask_to_segment_ids(mask)
>>> q_ids.shape, kv_ids.shape
((1, 3), (1, 3))
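The grouping rule ("identical attention patterns share a segment ID") can be sketched in a JIT-friendly way by comparing every row against every other row. This is a hypothetical 2D-only illustration, not ejkernel's algorithm; numbering segments 0, 1, 2, … in order of first appearance is an assumption. As a consequence of the rule itself, a causal mask gives every query its own segment, since no two causal rows are identical.

```python
import jax.numpy as jnp


def rows_to_segment_ids(mask):
    """Hypothetical: group identical rows of a 2D (Q, K) mask; all-zero rows get -1."""
    m = mask.astype(bool)
    n = m.shape[0]
    same = jnp.all(m[:, None, :] == m[None, :, :], axis=-1)  # (Q, Q) pattern equality
    first = jnp.argmax(same.astype(jnp.int32), axis=1)       # first row with the same pattern
    valid = m.any(axis=1)
    is_rep = (first == jnp.arange(n)) & valid                 # one representative per pattern
    dense = jnp.cumsum(is_rep.astype(jnp.int32)) - 1          # 0, 1, 2, ... in order of appearance
    return jnp.where(valid, dense[first], -1)


def mask_to_segment_ids_sketch(mask):
    """Queries are grouped by row pattern, keys by column pattern."""
    return rows_to_segment_ids(mask), rows_to_segment_ids(mask.T)


q_ids, kv_ids = mask_to_segment_ids_sketch(jnp.array([[1, 1, 0], [1, 1, 0], [0, 0, 1]]))
```

The pairwise comparison materializes a (Q, Q, K) intermediate, so this sketch suits small debugging masks; batched inputs can be handled by wrapping it in jax.vmap.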
- ejkernel.types.mask.qkv_cu_seqlens_to_attention_mask(cu_seqlens_q: ~jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'], *, max_q_len: int, cu_seqlens_kv: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, max_kv_len: int | None = None, dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.bool'>) Array[source]#
Convert Q/KV cumulative sequence lengths into a broadcastable 4D outer-product attention mask.
- Returns
(batch, 1, max_q_len, max_kv_len) mask.
- ejkernel.types.mask.qkv_cu_seqlens_to_qkv_masks(cu_seqlens_q: ~jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'], *, max_q_len: int, cu_seqlens_kv: jaxtyping.Int[jaxlib._jax.Array, 'batch_plus_one'] | None = None, max_kv_len: int | None = None, dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.bool'>) tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array][source]#
Convert Q/KV cumulative sequence lengths back into 2D padding masks.
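Both conversions above can be sketched with per-example lengths and an outer product. This hypothetical helper assumes the conventional [0, n0, n0 + n1, ...] layout implied by the batch_plus_one annotation and left-aligned sequences (each example's tokens start at position 0 of its row); neither assumption is confirmed by this page.

```python
import jax.numpy as jnp


def cu_seqlens_to_masks_sketch(cu_q, max_q_len, cu_kv=None, max_kv_len=None):
    """Hypothetical: [0, n0, n0 + n1, ...] -> left-aligned padding masks plus their outer product."""
    cu_kv = cu_q if cu_kv is None else cu_kv
    max_kv_len = max_q_len if max_kv_len is None else max_kv_len
    q_len = jnp.diff(cu_q)    # (batch,) per-example query lengths
    kv_len = jnp.diff(cu_kv)  # (batch,) per-example key/value lengths
    q_mask = jnp.arange(max_q_len)[None, :] < q_len[:, None]
    kv_mask = jnp.arange(max_kv_len)[None, :] < kv_len[:, None]
    attn = q_mask[:, None, :, None] & kv_mask[:, None, None, :]  # (batch, 1, q, kv)
    return q_mask, kv_mask, attn


q_mask, kv_mask, attn = cu_seqlens_to_masks_sketch(jnp.array([0, 3, 5]), max_q_len=4)
```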
- ejkernel.types.mask.qkv_masks_to_cu_seqlens(q_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch q_len'] | jaxtyping.Int[jaxlib._jax.Array, 'batch q_len'], kv_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch kv_len'] | jaxtyping.Int[jaxlib._jax.Array, 'batch kv_len'] | None = None, *, out_dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.int32'>) tuple[jaxtyping.Int[jaxlib._jax.Array, 'batch*2'], jaxtyping.Int[jaxlib._jax.Array, 'batch*2']][source]#
Convert per-token Q/KV masks into start/end position pairs.
For each batch element, finds the first and last valid token positions. Returns arrays with interleaved [start_0, end_0, start_1, end_1, …] format.
- Parameters
q_mask – (batch, q_len) boolean/int mask where non-zero/True means “valid token”.
kv_mask – (batch, kv_len) boolean/int mask (defaults to q_mask for self-attention).
out_dtype – Output dtype for positions (typically int32).
- Returns
(cu_seqlens_q, cu_seqlens_kv), each of shape (batch * 2,). Format: [start_0, end_0, start_1, end_1, …] where valid tokens for batch i are at positions cu_seqlens[2*i] to cu_seqlens[2*i+1]-1.
Example
>>> q_mask = jnp.array([[False, True, True, True, False]])  # valid at 1-3
>>> cu_q, cu_kv = qkv_masks_to_cu_seqlens(q_mask)
>>> cu_q  # [start=1, end=4]
Array([1, 4], dtype=int32)
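Finding the first and last valid positions reduces to a masked min and max over position indices. mask_to_pairs_sketch is hypothetical, not the library function. Because only the endpoints are kept, a mask with interior holes yields a span that covers the holes; the (L, 0) result for an all-padding row is this sketch's own choice, since the page does not document that case.

```python
import jax.numpy as jnp


def mask_to_pairs_sketch(mask):
    """Hypothetical: (batch, L) mask -> interleaved [start_0, end_0, ...] with exclusive ends."""
    length = mask.shape[-1]
    idx = jnp.arange(length)
    valid = mask.astype(bool)
    start = jnp.min(jnp.where(valid, idx, length), axis=-1)  # first valid position
    end = jnp.max(jnp.where(valid, idx, -1), axis=-1) + 1    # one past the last valid position
    return jnp.stack([start, end], axis=-1).reshape(-1)


pairs = mask_to_pairs_sketch(jnp.array([[False, True, True, True, False]]))
```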
- ejkernel.types.mask.segment_ids_to_mask(segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch seq_len'] | tuple[jaxtyping.Int[jaxlib._jax.Array, 'batch q_len'], jaxtyping.Int[jaxlib._jax.Array, 'batch kv_len']], dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.bool'>, return_separate_masks: bool = False) jax.jaxlib._jax.Array | tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jax.jaxlib._jax.Array][source]#
Converts segment IDs to an attention mask.
This function creates a 4D attention mask (and, optionally, 2D per-token masks) from segment IDs, where tokens in the same segment can attend to each other. It follows the module's padding conventions: in segment IDs, -1 marks padding; in the attention mask, 0 marks a masked-out (padding) position and 1 marks valid attention.
The function works with both query and key-value segment IDs:
If only query segment IDs are provided, it creates a square mask in which tokens with the same segment ID can attend to each other.
If both query and key-value segment IDs are provided, it creates a rectangular mask allowing cross-attention between matching segments.
- Parameters
segment_ids – Segment IDs. Either a 2D (batch_size, seq_len) array of query segment IDs only, or a tuple of two 2D arrays (q_segment_ids, kv_segment_ids).
dtype – The output dtype for the mask. Common choices are jnp.bool_ (True = attend, False = masked) and jnp.float32 (1.0 = attend, 0.0 = masked).
return_separate_masks – If True, returns (q_mask, kv_mask, attention_mask) tuple where q_mask and kv_mask are 2D masks indicating valid (non-padding) tokens. Default is False, which returns only the attention_mask.
- Returns
If return_separate_masks=False (default), an attention mask array with shape:
(batch_size, 1, seq_len, seq_len) if segment_ids is 2D
(batch_size, 1, q_len, kv_len) if segment_ids is a tuple
The mask is always 4D, (batch, 1, q, kv); the second dimension is 1 so that it broadcasts across attention heads.
If return_separate_masks=True, a tuple of (q_mask, kv_mask, attention_mask) where:
q_mask: (batch_size, q_len), True for valid query tokens
kv_mask: (batch_size, kv_len), True for valid key-value tokens
attention_mask: (batch_size, 1, q_len, kv_len), the 4D pairwise attention mask
Examples
>>> segment_ids = jnp.array([
...     [1, 1, 2, 2, -1],
...     [1, 1, 1, -1, -1],
... ])
>>> mask = segment_ids_to_mask(segment_ids)
>>> mask.shape
(2, 1, 5, 5)
>>> q_mask, kv_mask, attn_mask = segment_ids_to_mask(segment_ids, return_separate_masks=True)
>>> q_mask.shape, kv_mask.shape, attn_mask.shape
((2, 5), (2, 5), (2, 1, 5, 5))
>>> q_mask[0]  # True for the four non-padding tokens, False for the trailing -1
>>> kv_mask[0]  # same as q_mask[0] in the self-attention case
>>> q_segment_ids = jnp.array([[1, 2, 3]])
>>> kv_segment_ids = jnp.array([[1, 1, 2, 2, 3]])
>>> mask = segment_ids_to_mask((q_segment_ids, kv_segment_ids))
>>> mask.shape
(1, 1, 3, 5)
>>> mask = segment_ids_to_mask(segment_ids, dtype=jnp.float32)
Notes
Segment IDs of -1 are treated as padding
Non-negative segment IDs (0, 1, 2, …) identify segments
Tokens can only attend within their own segment
The output mask is suitable for use with most attention implementations
For additive attention bias, convert: bias = (1.0 - mask) * large_negative_value
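The segment rule and the bias conversion in these notes can be sketched together. Both helpers are hypothetical, not ejkernel functions. The bias uses the dtype's most negative finite value rather than -inf, which avoids the NaNs that an all -inf row (for example a padding query) produces in softmax.

```python
import jax.numpy as jnp


def segment_mask_sketch(q_seg, kv_seg=None):
    """Hypothetical: same-segment, non-padding pairs may attend; returns (batch, 1, q, kv)."""
    kv_seg = q_seg if kv_seg is None else kv_seg
    same = q_seg[:, :, None] == kv_seg[:, None, :]
    valid = (q_seg != -1)[:, :, None] & (kv_seg != -1)[:, None, :]
    return (same & valid)[:, None, :, :]


def mask_to_bias(mask, dtype=jnp.float32):
    """Additive bias: 0 where attention is allowed, a large finite negative value elsewhere."""
    return jnp.where(mask, 0.0, jnp.finfo(dtype).min).astype(dtype)


seg = jnp.array([[1, 1, 2, 2, -1]])
mask = segment_mask_sketch(seg)
bias = mask_to_bias(mask)
```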
- ejkernel.types.mask.segment_ids_to_qkv_masks(q_segment_ids: ~jaxtyping.Int[jaxlib._jax.Array, 'batch q_len'], kv_segment_ids: jaxtyping.Int[jaxlib._jax.Array, 'batch kv_len'] | None = None, dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.bool'>) tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jax.jaxlib._jax.Array][source]#
Converts query and key-value segment IDs to separate Q mask, KV mask, and attention mask.
This is a convenience function that always returns the three masks separately, useful when you need individual control over query and key-value masking.
- Parameters
q_segment_ids – Query segment IDs of shape (batch_size, q_len). Values of -1 indicate padding.
kv_segment_ids – Key-value segment IDs of shape (batch_size, kv_len). If None, uses q_segment_ids (self-attention case). Values of -1 indicate padding.
dtype – The output dtype for the masks. Common choices are jnp.bool_ (True = attend, False = masked) and jnp.float32 (1.0 = attend, 0.0 = masked).
- Returns
q_mask: (batch_size, q_len) - Query mask indicating valid (non-padding) query tokens
kv_mask: (batch_size, kv_len) - Key-value mask indicating valid (non-padding) KV tokens
attention_mask: (batch_size, 1, q_len, kv_len) - 4D pairwise attention mask where tokens in matching segments can attend to each other
- Return type
Tuple of (q_mask, kv_mask, attention_mask)
Examples
>>> segment_ids = jnp.array([[1, 1, 2, -1]])
>>> q_mask, kv_mask, attn_mask = segment_ids_to_qkv_masks(segment_ids)
>>> q_mask.shape, kv_mask.shape, attn_mask.shape
((1, 4), (1, 4), (1, 1, 4, 4))
>>> q_mask[0]  # True for the three non-padding tokens, False for the trailing -1
>>> attn_mask[0, 0, 0, 2]  # False: tokens 0 and 2 are in different segments
>>> q_seg = jnp.array([[1, 2]])
>>> kv_seg = jnp.array([[1, 1, 2, 2, -1]])
>>> q_mask, kv_mask, attn_mask = segment_ids_to_qkv_masks(q_seg, kv_seg)
>>> q_mask.shape, kv_mask.shape, attn_mask.shape
((1, 2), (1, 5), (1, 1, 2, 5))
>>> kv_mask[0]  # True for the four non-padding keys, False for the trailing -1
>>> attn_mask[0, 0, 0, :2]  # query 0 (segment 1) attends to keys 0 and 1 (segment 1)
Notes
This function always returns three separate masks for maximum flexibility
Segment IDs of -1 are treated as padding
Non-negative segment IDs (0, 1, 2, …) identify segments
Tokens can only attend within their own segment
For self-attention, q_mask and kv_mask will be identical
- ejkernel.types.mask.set_debug_mode(enabled: bool) None[source]#
Enable or disable debug tracing for MaskInfo operations.
When debug mode is enabled, all MaskInfo method calls will be logged to stdout, helping you understand the execution flow and identify performance bottlenecks.
- Parameters
enabled – If True, enables debug tracing. If False, disables it.
Examples
>>> from ejkernel.types.mask import set_debug_mode, MaskInfo
>>> import jax.numpy as jnp
>>> set_debug_mode(True)  # enable debug mode
[MaskInfo Debug] Debug mode enabled
>>> # operations now print debug traces
>>> mask_info = MaskInfo.from_segments(jnp.array([[1, 1, 2, 2]]))
[MaskInfo Debug] Calling type.from_segments()
>>> set_debug_mode(False)  # disable debug mode