ejkernel.kernels._pallas.tpu.blocksparse_attention._masks
- class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.CausalMask(shape: tuple[int, int], offset: int = 0, shard_count: int = 1)
Bases: _ComputableMask
Lazy causal mask that prevents the model from attending to future tokens.
- offset
Offset of the q start with respect to kv. A positive offset shifts the bottom triangle upward; a negative one shifts it downward. A negative offset makes the first |offset| rows of the attention matrix all zeros, which leaves the softmax undefined for those rows.
- Type
int
- class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.ChunkedCausalMask(shape: tuple[int, int], chunk_size: int, shard_count: int = 1)
Bases: _ComputableMask
Lazy chunked causal mask.
Attention is causal within each chunk (0, K), (K, 2K), (2K, 3K), …. Tokens in the same chunk attend to each other causally, but never across chunks. Llama4 models use interleaved chunk attention along with global attention.
- chunk_size
The size of each attention chunk.
- Type
int
- class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.FullMask(_shape: tuple[int, int])
Bases: Mask
Lazy full mask that allows all tokens to attend to all other tokens.
- property shape: tuple[int, ...]
- class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.LocalMask(shape: tuple[int, int], window_size: tuple[int | None, int | None], offset: int, shard_count: int = 1)
Bases: _ComputableMask
Lazy local mask that prevents the model from attending to tokens outside the window.
- window_size
Size of each of the two sides of the local window (None means no limit on that side).
- Type
tuple[int | None, int | None]
- offset
Offset of the q start with respect to kv. A positive offset shifts the bottom triangle upward; a negative one shifts it downward. A negative offset makes the first |offset| rows of the attention matrix all zeros, which leaves the softmax undefined for those rows.
- Type
int
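The window semantics described above can be illustrated with a small dense NumPy sketch. This is not the ejkernel implementation: the helper name `local_mask_sketch` is hypothetical, and the exact inequalities (including how `offset` enters them and which side of the tuple is "left") are assumptions chosen to match the attribute descriptions.

```python
import numpy as np

def local_mask_sketch(shape, window_size, offset=0):
    """Dense local-window mask (hypothetical helper, not the library API).

    Assumes entry (q, kv) is allowed when
    q - left + offset <= kv <= q + right + offset,
    where a side set to None places no limit in that direction.
    """
    q_seq_len, kv_seq_len = shape
    q_idx = np.arange(q_seq_len)[:, None]    # column vector of query positions
    kv_idx = np.arange(kv_seq_len)[None, :]  # row vector of key/value positions
    mask = np.ones((q_seq_len, kv_seq_len), dtype=bool)
    left, right = window_size
    if left is not None:
        mask &= q_idx - left + offset <= kv_idx
    if right is not None:
        mask &= q_idx + right + offset >= kv_idx
    return mask

# One token of left context, no right context.
window = local_mask_sketch((5, 5), window_size=(1, 0))
# Same left limit, unlimited on the right.
unbounded_right = local_mask_sketch((5, 5), window_size=(1, None))
```

Under these assumptions, `window_size=(1, 0)` lets each query see itself and its immediate predecessor, while `None` on a side simply skips that bound.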
- class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.LogicalAnd(left: 'Mask', right: 'Mask')
Bases: Mask
- property shape: tuple[int, ...]
- class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.LogicalOr(left: 'Mask', right: 'Mask')
Bases: Mask
- property shape: tuple[int, ...]
- class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.Mask
Bases: object
A base class for blocksparse_attention masks.
- property shape: tuple[int, ...]
- class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.MultiHeadMask(masks: Sequence[Mask])
Bases: Mask
Lazy multi-head mask that combines multiple lazy masks, one per head.
- property shape: tuple[int, ...]
- class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.NumpyMask(array: ndarray)
Bases: Mask
A mask backed by a dense numpy array.
- property shape: tuple[int, ...]
- ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.make_causal_mask(shape: tuple[int, int], offset: int = 0) → ndarray
Makes a causal attention mask.
- Parameters
shape – Shape of the 2-dim mask: (q_seq_len, kv_seq_len).
offset – Offset of the q start with respect to kv. A positive offset shifts the bottom triangle upward; a negative one shifts it downward. A negative offset makes the first |offset| rows of the attention matrix all zeros, which leaves the softmax undefined for those rows.
- Returns
The causal mask.
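To make the effect of `offset` concrete, here is a minimal NumPy sketch. It does not call ejkernel; `causal_mask_sketch` is a hypothetical stand-in, and the rule "position (q, kv) is allowed when q + offset >= kv" is an assumption inferred from the parameter description.

```python
import numpy as np

def causal_mask_sketch(shape, offset=0):
    """Dense causal mask (hypothetical helper, not the library API).

    Assumes entry (q, kv) is allowed when q + offset >= kv.
    """
    q_seq_len, kv_seq_len = shape
    q_idx = np.arange(q_seq_len)[:, None]    # query positions as a column
    kv_idx = np.arange(kv_seq_len)[None, :]  # key/value positions as a row
    return q_idx + offset >= kv_idx

plain = causal_mask_sketch((4, 4))               # standard lower triangle
shifted = causal_mask_sketch((4, 4), offset=-1)  # triangle shifted downward

# Rows with no allowed key: softmax over such a row is undefined.
empty_rows = ~shifted.any(axis=1)
```

With `offset=-1` the first row has no allowed entries, which is the undefined-softmax case the parameter description warns about.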
- ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.make_chunk_attention_mask(shape: tuple[int, int], chunk_size: int) → ndarray
Makes a chunked causal attention mask.
- Parameters
shape – The desired shape of the mask: (q_seq_len, kv_seq_len).
chunk_size – The size of the attention chunks.
- Returns
A boolean mask of shape `shape` where True indicates attention is allowed according to chunked causal rules, and False otherwise.
- Raises
ValueError – If chunk_size is None or not positive.
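The chunked causal rule can be sketched in NumPy as the intersection of a same-chunk test and a causal test. The helper below is hypothetical and only illustrates the documented behaviour; assigning positions to chunks via integer division by `chunk_size` is an assumption.

```python
import numpy as np

def chunked_causal_mask_sketch(shape, chunk_size):
    """Dense chunked causal mask (hypothetical helper, not the library API).

    Assumes position i belongs to chunk i // chunk_size.
    """
    if chunk_size is None or chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    q_seq_len, kv_seq_len = shape
    q_idx = np.arange(q_seq_len)[:, None]
    kv_idx = np.arange(kv_seq_len)[None, :]
    same_chunk = (q_idx // chunk_size) == (kv_idx // chunk_size)
    causal = q_idx >= kv_idx
    # Attention is allowed only within a chunk and only to earlier positions.
    return same_chunk & causal

chunk_mask = chunked_causal_mask_sketch((6, 6), chunk_size=3)
```

In this sketch the result is block-diagonal with a lower triangle inside each block, so the first token of a new chunk sees only itself.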