ejkernel.kernels._pallas.tpu.blocksparse_attention._masks

class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.CausalMask(shape: tuple[int, int], offset: int = 0, shard_count: int = 1)

Bases: _ComputableMask

Lazy causal mask that prevents the model from attending to future tokens.

offset

Offset of the query start with respect to the key/value start. A positive offset shifts the lower triangle upward; a negative offset shifts it downward. With a negative offset, the first |offset| rows of the attention matrix are entirely masked (all False), which leaves the softmax for those rows undefined.

Type

int

offset: int
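The offset rule can be checked with a dense NumPy sketch. It assumes the lazy mask lets query i attend to key j exactly when i + offset >= j, which is consistent with the description of offset above; dense_causal and masked_softmax are hypothetical helpers, not part of this module. The second helper shows why fully masked rows leave the softmax undefined: every logit in such a row is -inf, so normalization yields NaN.

```python
import numpy as np


def dense_causal(shape, offset=0):
    """Dense sketch of a causal mask (assumed rule): query i sees key j iff i + offset >= j."""
    q_len, kv_len = shape
    q_idx = np.arange(q_len)[:, None]
    kv_idx = np.arange(kv_len)[None, :]
    return q_idx + offset >= kv_idx


def masked_softmax(scores, mask):
    """Row-wise softmax where disallowed positions are set to -inf before normalizing."""
    logits = np.where(mask, scores, -np.inf)
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


mask = dense_causal((4, 4), offset=-2)
print(mask.astype(int))  # rows 0 and 1 (the first |offset| rows) are entirely 0

# -inf - (-inf) is NaN; silence the expected "invalid value" warning.
with np.errstate(invalid="ignore"):
    probs = masked_softmax(np.zeros((4, 4)), mask)
print(np.isnan(probs[:2]).all())  # True: fully masked rows have no valid softmax
```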
class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.ChunkedCausalMask(shape: tuple[int, int], chunk_size: int, shard_count: int = 1)

Bases: _ComputableMask

Lazy chunked causal mask.

Attention is causal within each chunk of chunk_size (K) tokens, [0, K), [K, 2K), [2K, 3K), …: a token attends to itself and to earlier tokens in the same chunk, but never to tokens in other chunks. Llama4 models use this chunked attention interleaved with global attention.

chunk_size

The size of each attention chunk.

Type

int

chunk_size: int
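A dense NumPy sketch of the chunked causal rule, assuming a query and a key share a chunk when their indices have the same value of index // chunk_size (the half-open chunks listed above). dense_chunked_causal is a hypothetical helper, not this module's implementation.

```python
import numpy as np


def dense_chunked_causal(shape, chunk_size):
    """Dense sketch: causal attention restricted to chunks [0, K), [K, 2K), ..."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    q_len, kv_len = shape
    q_idx = np.arange(q_len)[:, None]
    kv_idx = np.arange(kv_len)[None, :]
    # Same chunk: integer division by the chunk size gives the chunk number.
    same_chunk = (q_idx // chunk_size) == (kv_idx // chunk_size)
    # Causal within the chunk: never look at later keys.
    return same_chunk & (q_idx >= kv_idx)


# Two 3x3 lower-triangular blocks on the diagonal, zeros everywhere else.
print(dense_chunked_causal((6, 6), chunk_size=3).astype(int))
```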
class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.FullMask(_shape: tuple[int, int])

Bases: Mask

Lazy full mask that allows every token to attend to every other token.

property shape: tuple[int, ...]
class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.LocalMask(shape: tuple[int, int], window_size: tuple[int | None, int | None], offset: int, shard_count: int = 1)

Bases: _ComputableMask

Lazy local mask that prevents the model from attending to tokens outside the window.

window_size

Sizes of the two sides of the local window, given as (left, right). None means the corresponding side is unbounded.

Type

tuple[int | None, int | None]

offset

Offset of the query start with respect to the key/value start. A positive offset shifts the lower triangle upward; a negative offset shifts it downward. With a negative offset, the first |offset| rows of the attention matrix are entirely masked (all False), which leaves the softmax for those rows undefined.

Type

int

offset: int
window_size: tuple[int | None, int | None]
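A dense sketch of the window rule, under the assumption that window_size = (left, right) counts how many keys before and after the query's aligned position i + offset remain visible. dense_local is hypothetical. Under this reading, window_size=(None, 0) reduces to a plain causal mask and (None, None) to a full mask.

```python
import numpy as np


def dense_local(shape, window_size, offset=0):
    """Dense sketch of a local mask (assumed rule): key j is visible to query i iff
    i + offset - left <= j <= i + offset + right; a None side has no bound."""
    left, right = window_size
    q_len, kv_len = shape
    q_idx = np.arange(q_len)[:, None] + offset  # query positions aligned to kv
    kv_idx = np.arange(kv_len)[None, :]
    mask = np.ones(shape, dtype=bool)
    if left is not None:
        mask &= kv_idx >= q_idx - left
    if right is not None:
        mask &= kv_idx <= q_idx + right
    return mask


# Each query sees itself and one previous token.
print(dense_local((5, 5), window_size=(1, 0)).astype(int))
```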
class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.LogicalAnd(left: 'Mask', right: 'Mask')

Bases: Mask

left: Mask
right: Mask
property shape: tuple[int, ...]
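LogicalAnd has no docstring; from its name and its two Mask operands it presumably represents the elementwise AND of two masks of the same shape, with np.logical_and as the dense analogue. The plain-NumPy example below (no ejkernel calls) intersects a causal mask with a "at most two positions back" constraint to get a causal sliding window.

```python
import numpy as np

q_idx = np.arange(6)[:, None]
kv_idx = np.arange(6)[None, :]

causal = q_idx >= kv_idx          # key at or before the query
recent = kv_idx >= q_idx - 2      # key at most 2 positions back (no upper bound)

# Elementwise AND: the dense analogue of LogicalAnd(left, right).
sliding_window = np.logical_and(causal, recent)
print(sliding_window.astype(int))
```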
class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.LogicalOr(left: 'Mask', right: 'Mask')

Bases: Mask

left: Mask
right: Mask
property shape: tuple[int, ...]
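Likewise, LogicalOr presumably represents the elementwise OR (union) of two same-shaped masks, with np.logical_or as the dense analogue. The sketch keeps a two-token causal window and additionally lets every query see the first key position; the pattern itself is only an illustration.

```python
import numpy as np

q_idx = np.arange(6)[:, None]
kv_idx = np.arange(6)[None, :]

causal = q_idx >= kv_idx
window = causal & (kv_idx >= q_idx - 1)             # current token and the one before it
first_key = np.broadcast_to(kv_idx == 0, (6, 6))    # column 0 visible to every query

# Elementwise OR: the dense analogue of LogicalOr(left, right).
combined = np.logical_or(window, first_key)
print(combined.astype(int))
```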
class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.Mask

Bases: object

Base class for blocksparse_attention masks.

property shape: tuple[int, ...]
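The lazy masks in this module describe their entries rather than storing a dense array. The class below is a hypothetical illustration of that idea, not the Mask interface itself: it keeps only a shape and a predicate over index grids, and materializes a block only when it is sliced. LazyPredicateMask, mask_fn, and the slice-based __getitem__ are all assumptions.

```python
from typing import Callable

import numpy as np


class LazyPredicateMask:
    """Hypothetical lazy mask: stores a shape and a predicate over (q_idx, kv_idx)
    index grids, and materializes only the requested block."""

    def __init__(self, shape: tuple[int, int],
                 mask_fn: Callable[[np.ndarray, np.ndarray], np.ndarray]):
        self._shape = shape
        self.mask_fn = mask_fn

    @property
    def shape(self) -> tuple[int, ...]:
        return self._shape

    def __getitem__(self, idx: tuple[slice, slice]) -> np.ndarray:
        q_slice, kv_slice = idx
        # slice.indices() clamps start/stop/step to the mask's extent.
        q_idx = np.arange(*q_slice.indices(self._shape[0]))[:, None]
        kv_idx = np.arange(*kv_slice.indices(self._shape[1]))[None, :]
        return self.mask_fn(q_idx, kv_idx)


# A 2**20 x 2**20 causal mask: nothing is allocated until a block is requested.
causal = LazyPredicateMask((1 << 20, 1 << 20), lambda q, kv: q >= kv)
block = causal[512:516, 510:514]
print(block.astype(int))
```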
class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.MultiHeadMask(masks: Sequence[Mask])

Bases: Mask

Lazy multi-head mask that combines several lazy masks, one per attention head.

masks: Sequence[Mask]
property shape: tuple[int, ...]
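A dense picture of a per-head mask, assuming MultiHeadMask stacks its masks along a leading head axis, so that its shape is (num_heads, q_seq_len, kv_seq_len) and head h uses masks[h]. The per-head patterns below are arbitrary examples.

```python
import numpy as np

q_idx = np.arange(4)[:, None]
kv_idx = np.arange(4)[None, :]

# One 2-D mask per head (hypothetical head layout).
per_head = [
    q_idx >= kv_idx,                              # head 0: causal
    (q_idx >= kv_idx) & (kv_idx >= q_idx - 1),    # head 1: causal, window of 2
    np.ones((4, 4), dtype=bool),                  # head 2: full attention
]

# All heads must share one (q_seq_len, kv_seq_len) shape to be stacked.
assert len({m.shape for m in per_head}) == 1
stacked = np.stack(per_head)  # (num_heads, q_seq_len, kv_seq_len)
print(stacked.shape)          # (3, 4, 4)
```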
class ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.NumpyMask(array: ndarray)

Bases: Mask

A mask backed by a dense NumPy array.

array: ndarray
property shape: tuple[int, ...]
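Patterns the lazy masks do not cover can be built densely. A minimal sketch, assuming NumpyMask accepts a 2-D boolean array of shape (q_seq_len, kv_seq_len): a causal mask restricted to packed segments, where the segment_ids layout is hypothetical. The resulting array could then be passed as NumpyMask(dense). Unlike the lazy masks, a dense boolean mask occupies q_seq_len × kv_seq_len bytes of host memory.

```python
import numpy as np

# Hypothetical packed sequence: three documents of lengths 3, 2 and 3.
segment_ids = np.array([0, 0, 0, 1, 1, 2, 2, 2])
pos = np.arange(segment_ids.size)

same_segment = segment_ids[:, None] == segment_ids[None, :]
causal = pos[:, None] >= pos[None, :]
dense = same_segment & causal  # 2-D bool array, shape (8, 8)
print(dense.astype(int))
```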
ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.make_causal_mask(shape: tuple[int, int], offset: int = 0) → ndarray

Makes a causal attention mask.

Parameters
  • shape – Shape of the 2-dim mask: (q_seq_len, kv_seq_len).

  • offset – Offset of the query start with respect to the key/value start. A positive offset shifts the lower triangle upward; a negative offset shifts it downward. With a negative offset, the first |offset| rows of the attention matrix are entirely masked (all False), which leaves the softmax for those rows undefined.

Returns

The causal mask, a boolean array of shape (q_seq_len, kv_seq_len) that is True where attention is allowed.
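When q_seq_len < kv_seq_len, for example when a few new queries follow cached keys, a positive offset of kv_seq_len - q_seq_len aligns the last query with the last key. The sketch reimplements the assumed rule q_idx + offset >= kv_idx in NumPy instead of calling make_causal_mask, so it runs without this module; causal_rule is a hypothetical helper.

```python
import numpy as np


def causal_rule(shape, offset=0):
    """Assumed causal rule: query i sees key j iff i + offset >= j."""
    q_len, kv_len = shape
    return np.arange(q_len)[:, None] + offset >= np.arange(kv_len)[None, :]


q_len, kv_len = 2, 5  # 2 new queries, 5 keys (3 cached + 2 new)
aligned = causal_rule((q_len, kv_len), offset=kv_len - q_len)
print(aligned.astype(int))
# [[1 1 1 1 0]
#  [1 1 1 1 1]]
```

Without the offset, the last query would only see the first two keys.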

ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.make_chunk_attention_mask(shape: tuple[int, int], chunk_size: int) → ndarray

Makes a chunked causal attention mask.

Parameters
  • shape – The desired shape of the mask (q_seq_len, kv_seq_len).

  • chunk_size – The size of the attention chunks.

Returns

A boolean mask of the given shape, where True indicates that attention is allowed under the chunked causal rule and False indicates that it is not.

Raises

ValueError – If chunk_size is None or not positive.

ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.make_local_attention_mask(shape: tuple[int, int], window_size: tuple[int | None, int | None], *, offset: int = 0) → ndarray

Makes a local attention mask.

ejkernel.kernels._pallas.tpu.blocksparse_attention._masks.make_random_mask(shape: tuple[int, int], sparsity: float, seed: int) → ndarray

Makes a random attention mask.
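The docstring does not define sparsity. The sketch below assumes it is the expected fraction of masked-out (False) entries, and uses NumPy's Generator API for reproducibility. random_mask_sketch is hypothetical; it illustrates the idea and is not expected to reproduce this module's output for a given seed.

```python
import numpy as np


def random_mask_sketch(shape, sparsity, seed):
    """Each entry is True (attend) with probability 1 - sparsity (assumed meaning)."""
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be in [0, 1]")
    rng = np.random.default_rng(seed)
    # Uniform samples in [0, 1) are >= sparsity with probability 1 - sparsity.
    return rng.random(shape) >= sparsity


m = random_mask_sketch((256, 256), sparsity=0.9, seed=0)
print(m.dtype, m.shape)          # bool (256, 256)
print(round(1.0 - m.mean(), 2))  # close to 0.9
```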