ejkernel.kernels._pallas.tpu.page_attention._interface


ejkernel.kernels._pallas.tpu.page_attention._interface.page_attention(query: Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim'], key_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], value_cache: Float[jaxlib._jax.Array, 'num_kv_heads total_num_pages page_size head_dim'], context_lens: Int[jaxlib._jax.Array, 'num_seqs'], block_tables: Int[jaxlib._jax.Array, 'num_seqs max_blocks'], attn_scale: float | None = None, max_context_len: int | None = None, num_splits: int = 0, *, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, pages_per_compute_block: int | None = None, megacore_mode: str | None = None, inline_seq_dim: bool = True, sliding_window: int | None = None) -> Float[jaxlib._jax.Array, 'num_seqs num_heads head_dim']

Paged grouped query attention.

Parameters
  • query – A [batch_size, num_q_heads, head_dim] jax.Array.

  • key_cache – A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.

  • value_cache – A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.

  • context_lens – An i32[batch_size] jax.Array giving the context length of each sequence in the batch.

  • block_tables – A i32[batch_size, pages_per_sequence] jax.Array. Each entry should be in the range of [0, total_num_pages), indicating where to locate the page in key_cache or value_cache.

  • attn_scale – Attention scaling factor (not used in PALLAS TPU implementation).

  • max_context_len – Maximum context length (not used in PALLAS TPU implementation).

  • num_splits – Number of splits for partitioned attention (not used in PALLAS TPU implementation).

  • mask_value – The value used for padding in attention. By default it is a very negative floating point number.

  • attn_logits_soft_cap – The value used for soft capping the attention logits.

  • pages_per_compute_block – Number of pages processed in one flash attention block in the Pallas kernel.

  • megacore_mode – If set, enables megacore to parallelize the computation. Must be one of ['kv_head', 'batch', None]. Caveat: set this only if megacore is enabled; otherwise the kernel may hang. If you are not sure, leave it as None.

    • None: disable megacore parallelism.

    • kv_head: megacore parallelism on KV heads; requires the number of KV heads to be divisible by 2.

    • batch: megacore parallelism on the batch dimension; requires the batch size to be divisible by 2.

  • inline_seq_dim – whether to fuse kernel instances along the sequence dim into one kernel.

  • sliding_window – if set, only attend to the last sliding_window tokens. This is useful for models with sliding window attention like Mistral.
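
To make the roles of block_tables, context_lens, attn_logits_soft_cap and sliding_window concrete, here is a minimal NumPy sketch of the computation the parameters describe. It is not the Pallas kernel. The function name paged_attention_reference is hypothetical, and it makes several assumptions not stated above: consecutive query heads share one KV head, soft capping takes the common form cap * tanh(logits / cap), every context length is at least 1, and no scaling is applied (consistent with attn_scale being unused, so the caller would pre-scale query).

```python
import numpy as np


def paged_attention_reference(query, key_cache, value_cache, context_lens,
                              block_tables, attn_logits_soft_cap=None,
                              sliding_window=None):
    """Hypothetical NumPy reference for paged grouped query attention."""
    num_seqs, num_q_heads, head_dim = query.shape
    num_kv_heads, _, page_size, _ = key_cache.shape
    # Assumption: query heads are grouped contiguously per KV head.
    group = num_q_heads // num_kv_heads
    out = np.zeros_like(query)
    for s in range(num_seqs):
        length = int(context_lens[s])           # assumed >= 1
        num_pages = -(-length // page_size)     # ceil division
        pages = block_tables[s, :num_pages]     # physical page ids, in order
        # Gather this sequence's pages and flatten them into a token axis.
        k = key_cache[:, pages].reshape(num_kv_heads, -1, head_dim)[:, :length]
        v = value_cache[:, pages].reshape(num_kv_heads, -1, head_dim)[:, :length]
        if sliding_window is not None:
            # Keep only the last `sliding_window` tokens.
            start = max(0, length - sliding_window)
            k, v = k[:, start:], v[:, start:]
        for h in range(num_q_heads):
            kv = h // group
            logits = k[kv] @ query[s, h]        # shape [tokens]
            if attn_logits_soft_cap is not None:
                # Assumed soft-cap form: cap * tanh(logits / cap).
                cap = attn_logits_soft_cap
                logits = cap * np.tanh(logits / cap)
            weights = np.exp(logits - logits.max())
            weights /= weights.sum()
            out[s, h] = weights @ v[kv]
    return out
```

A sketch like this is mainly useful as a slow host-side check: token t of sequence s lives in page block_tables[s, t // page_size] at offset t % page_size, so the pages of one sequence need not be contiguous in key_cache or value_cache.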

Returns

The attention output, a [batch_size, num_q_heads, head_dim] jax.Array.
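
The constraints listed under Parameters can be checked on the host before any arrays reach the kernel. The helper below is a hedged sketch; check_page_attention_inputs is a hypothetical name and is not part of ejkernel. The shape, range and megacore checks follow the parameter descriptions. The requirements that the number of query heads be a multiple of the number of KV heads, and that each context length fit in the pages listed for its sequence, are assumptions.

```python
import numpy as np


def check_page_attention_inputs(query, key_cache, value_cache, context_lens,
                                block_tables, megacore_mode=None):
    """Hypothetical pre-flight checks; raises ValueError on a violation."""
    num_seqs, num_q_heads, head_dim = query.shape
    num_kv_heads, total_num_pages, page_size, kv_head_dim = key_cache.shape
    if value_cache.shape != key_cache.shape:
        raise ValueError("key_cache and value_cache must have the same shape")
    if kv_head_dim != head_dim:
        raise ValueError("head_dim of query and caches must match")
    # Assumption: grouped query attention needs an integer group size.
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    if context_lens.shape != (num_seqs,):
        raise ValueError("context_lens must have shape [batch_size]")
    if block_tables.ndim != 2 or block_tables.shape[0] != num_seqs:
        raise ValueError("block_tables must be [batch_size, pages_per_sequence]")
    # Documented: every entry lies in [0, total_num_pages).
    if block_tables.min() < 0 or block_tables.max() >= total_num_pages:
        raise ValueError("block_tables entries must be in [0, total_num_pages)")
    # Assumption: a context cannot be longer than its listed pages can hold.
    if context_lens.max() > block_tables.shape[1] * page_size:
        raise ValueError("context length exceeds pages_per_sequence * page_size")
    if megacore_mode not in (None, "kv_head", "batch"):
        raise ValueError("megacore_mode must be 'kv_head', 'batch' or None")
    if megacore_mode == "kv_head" and num_kv_heads % 2 != 0:
        raise ValueError("'kv_head' mode requires an even number of KV heads")
    if megacore_mode == "batch" and num_seqs % 2 != 0:
        raise ValueError("'batch' mode requires an even batch size")
```

These checks do not detect whether megacore is actually enabled on the device, so the caveat about leaving megacore_mode as None when unsure still applies.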