ejkernel.kernels._pallas.tpu.page_attention._pallas_impl_fwd

PagedAttention TPU kernel.

class ejkernel.kernels._pallas.tpu.page_attention._pallas_impl_fwd.MultiPageAsyncCopyDescriptor(pages_hbm_ref, vmem_buffer, sem, block_tables, page_indices_start_offset, num_pages_to_load, head_index)

Bases: object

Descriptor for async copy of multiple K/V pages from HBM.

start()

Starts the async copies.

wait_and_get_loaded() → Array

Waits for the async copies to complete and returns the loaded buffer as a jax.Array.
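As a rough illustration of which pages the constructor arguments select, the sketch below performs the same gather synchronously with plain jax.numpy. It is not the Pallas implementation: gather_pages is a hypothetical helper, the [num_kv_heads, total_num_pages, page_size, head_dim] layout is borrowed from the key_cache description further down, and block_tables is assumed to be a flat 1-D array of physical page indices.

```python
import jax.numpy as jnp


def gather_pages(pages, block_tables, start_offset, num_pages_to_load, head_index):
    """Synchronous stand-in for the async multi-page copy (illustrative only)."""
    # Assumed layout: pages is [num_kv_heads, total_num_pages, page_size, head_dim].
    head_pages = pages[head_index]
    # Physical page ids for the logical pages
    # [start_offset, start_offset + num_pages_to_load).
    page_ids = block_tables[start_offset:start_offset + num_pages_to_load]
    # Result has shape [num_pages_to_load, page_size, head_dim].
    return head_pages[page_ids]


# Toy cache: 2 KV heads, 6 physical pages, 4 tokens per page, head_dim of 8.
pages = jnp.arange(2 * 6 * 4 * 8, dtype=jnp.float32).reshape(2, 6, 4, 8)
block_tables = jnp.array([5, 0, 3, 1], dtype=jnp.int32)

# Load logical pages 1 and 2 (physical pages 0 and 3) for KV head 1.
loaded = gather_pages(pages, block_tables, start_offset=1,
                      num_pages_to_load=2, head_index=1)
```

In the real descriptor the copies are presumably issued by start() and only become readable after wait_and_get_loaded(); the sketch collapses both steps into one call.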

ejkernel.kernels._pallas.tpu.page_attention._pallas_impl_fwd.paged_flash_attention_kernel(lengths_ref, page_indices_ref, buffer_index_ref, init_flag_ref, q_ref, k_pages_hbm_ref, v_pages_hbm_ref, o_ref, m_ref, l_ref, k_vmem_buffer, v_vmem_buffer, k_sems, v_sems, *, batch_size: int, pages_per_compute_block: int, pages_per_sequence: int, mask_value: float, attn_logits_soft_cap: float | None, megacore_mode: str | None, sliding_window: int | None = None, program_ids=())

Pallas kernel for paged attention.
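The m_ref and l_ref arguments are named like the running maximum and running normalizer of a flash-attention style online softmax. Assuming that reading is right, the accumulation over compute blocks can be sketched for a single query vector as follows. The function online_softmax_attention and its block_size parameter are hypothetical, and no logit scaling is applied (the query is assumed to be pre-scaled).

```python
import jax
import jax.numpy as jnp


def online_softmax_attention(q, k, v, block_size):
    """Single-query attention over K/V blocks with a running max and normalizer."""
    m = -jnp.inf  # running maximum of the logits seen so far
    l = 0.0       # running softmax normalizer
    o = jnp.zeros(v.shape[-1], dtype=jnp.float32)  # unnormalized output
    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        s = k_blk @ q                    # logits for this block
        m_new = jnp.maximum(m, s.max())
        scale = jnp.exp(m - m_new)       # rescales the previous accumulators
        p = jnp.exp(s - m_new)
        l = l * scale + p.sum()
        o = o * scale + p @ v_blk
        m = m_new
    return o / l


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (8,))
k = jax.random.normal(kk, (12, 8))
v = jax.random.normal(kv, (12, 8))

blocked = online_softmax_attention(q, k, v, block_size=4)
direct = jax.nn.softmax(k @ q) @ v  # one-shot softmax for comparison
```

Processing the keys block by block gives the same result as a one-shot softmax, which is what lets a kernel stream K/V pages through a small on-chip buffer.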

ejkernel.kernels._pallas.tpu.page_attention._pallas_impl_fwd.paged_flash_attention_kernel_inline_seq_dim(lengths_ref, page_indices_ref, buffer_index_ref, init_flag_ref, q_ref, k_pages_hbm_ref, v_pages_hbm_ref, o_ref, m_ref, l_ref, k_vmem_buffer, v_vmem_buffer, k_sems, v_sems, *, batch_size: int, pages_per_compute_block: int, pages_per_sequence: int, mask_value: float, attn_logits_soft_cap: float | None, megacore_mode: str | None, sliding_window: int | None = None)

ejkernel.kernels._pallas.tpu.page_attention._pallas_impl_fwd.ref_paged_attention(query: Array, key_cache: Array, value_cache: Array, context_lens: Array, block_tables: Array, *, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, sliding_window: int | None = None) → Array

Reference implementation of paged attention for testing.

Parameters
  • query – A [batch_size, num_q_heads, head_dim] jax.Array.

  • key_cache – A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.

  • value_cache – A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.

  • context_lens – An i32[batch_size] jax.Array holding the length of each example.

  • block_tables – An i32[batch_size, pages_per_sequence] jax.Array of page indices into the key and value caches for each example.

  • mask_value – The value used for padding in attention.

  • attn_logits_soft_cap – The value used for soft capping the attention logits.

  • sliding_window – If set, only attend to the last sliding_window tokens.

Returns

The attention output, a [batch_size, num_q_heads, head_dim] jax.Array.
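To make the parameter shapes concrete, here is a minimal sketch of a paged attention reference written against the signature above. It is not the library's implementation: paged_attention_sketch is a hypothetical name, the mapping of query heads to KV heads (consecutive groups), the absence of logit scaling, and the tanh form of the soft cap are all assumptions.

```python
import jax
import jax.numpy as jnp


def paged_attention_sketch(query, key_cache, value_cache, context_lens, block_tables,
                           *, mask_value=-2.381976426469702e+38,
                           attn_logits_soft_cap=None, sliding_window=None):
    batch_size, num_q_heads, head_dim = query.shape
    num_kv_heads = key_cache.shape[0]
    group = num_q_heads // num_kv_heads  # query heads per KV head (assumed grouping)
    outputs = []
    for b in range(batch_size):
        # Gather this example's pages and flatten them to
        # [num_kv_heads, pages_per_sequence * page_size, head_dim].
        k = key_cache[:, block_tables[b]].reshape(num_kv_heads, -1, head_dim)
        v = value_cache[:, block_tables[b]].reshape(num_kv_heads, -1, head_dim)
        q = query[b].reshape(num_kv_heads, group, head_dim)
        logits = jnp.einsum("hgd,htd->hgt", q, k).astype(jnp.float32)
        if attn_logits_soft_cap is not None:
            # Assumed soft-cap form: cap * tanh(logits / cap).
            logits = attn_logits_soft_cap * jnp.tanh(logits / attn_logits_soft_cap)
        pos = jnp.arange(k.shape[1])
        valid = pos < context_lens[b]  # ignore padding beyond the context length
        if sliding_window is not None:
            # Keep only the last sliding_window tokens.
            valid = valid & (pos >= context_lens[b] - sliding_window)
        logits = jnp.where(valid[None, None, :], logits, mask_value)
        weights = jax.nn.softmax(logits, axis=-1)
        out = jnp.einsum("hgt,htd->hgd", weights, v)
        outputs.append(out.reshape(num_q_heads, head_dim))
    return jnp.stack(outputs)


# Toy inputs: batch of 2, 4 query heads, 2 KV heads, 6 pages of 4 tokens, head_dim of 8.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (2, 4, 8))
key_cache = jax.random.normal(kk, (2, 6, 4, 8))
value_cache = jax.random.normal(kv, (2, 6, 4, 8))
context_lens = jnp.array([5, 9], dtype=jnp.int32)
block_tables = jnp.array([[0, 1, 2], [3, 4, 5]], dtype=jnp.int32)

out = paged_attention_sketch(query, key_cache, value_cache, context_lens, block_tables)
out_sw = paged_attention_sketch(query, key_cache, value_cache, context_lens,
                                block_tables, sliding_window=1)
```

With sliding_window=1 each query attends only to the final token of its context, so out_sw reduces to the value vector stored at that token's page and offset, which makes a convenient sanity check.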