ejkernel.kernels._pallas.tpu.page_attention._pallas_impl_fwd
PagedAttention TPU kernel.
- class ejkernel.kernels._pallas.tpu.page_attention._pallas_impl_fwd.MultiPageAsyncCopyDescriptor(pages_hbm_ref, vmem_buffer, sem, block_tables, page_indices_start_offset, num_pages_to_load, head_index)
Bases: object
Descriptor for the asynchronous copy of multiple K/V pages from HBM into a VMEM buffer.
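As a rough illustration of what the descriptor moves, the sketch below is a synchronous, pure-JAX stand-in. It selects one KV head and gathers `num_pages_to_load` physical pages, starting at position `page_indices_start_offset` of the page-id table, producing the contents the VMEM buffer would hold once every copy has completed. The function name `gather_pages_sync` is hypothetical. It assumes `block_tables` is a flat 1-D array of physical page ids and that the page array uses the `[num_kv_heads, total_num_pages, page_size, head_dim]` layout documented for `ref_paged_attention`. The real descriptor issues asynchronous DMA copies guarded by `sem` rather than a single gather.

```python
import jax.numpy as jnp


def gather_pages_sync(pages_hbm, block_tables, page_indices_start_offset,
                      num_pages_to_load, head_index):
    """Hypothetical synchronous stand-in for MultiPageAsyncCopyDescriptor.

    pages_hbm:    [num_kv_heads, total_num_pages, page_size, head_dim]
    block_tables: flat 1-D array of physical page ids (assumed layout)
    Returns a [num_pages_to_load, page_size, head_dim] array: what the
    VMEM buffer would contain after all page copies finish.
    """
    # Logical page i of this chunk lives at physical page
    # block_tables[page_indices_start_offset + i].
    ids = block_tables[page_indices_start_offset:
                       page_indices_start_offset + num_pages_to_load]
    # Select the KV head, then gather the physical pages in logical order.
    return pages_hbm[head_index, ids]
```

For example, with `block_tables = [3, 0, 4, 1, 2]`, an offset of 1 and three pages to load, the buffer receives physical pages 0, 4 and 1 of the chosen head, in that order.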
- ejkernel.kernels._pallas.tpu.page_attention._pallas_impl_fwd.paged_flash_attention_kernel(lengths_ref, page_indices_ref, buffer_index_ref, init_flag_ref, q_ref, k_pages_hbm_ref, v_pages_hbm_ref, o_ref, m_ref, l_ref, k_vmem_buffer, v_vmem_buffer, k_sems, v_sems, *, batch_size: int, pages_per_compute_block: int, pages_per_sequence: int, mask_value: float, attn_logits_soft_cap: float | None, megacore_mode: str | None, sliding_window: int | None = None, program_ids=())
Pallas kernel for paged attention.
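The kernel name and its `m_ref` and `l_ref` arguments suggest a flash-attention-style online softmax, in which each compute block of `pages_per_compute_block` pages updates a running maximum (m), a running denominator (l) and a running output. The sketch below shows that recurrence for a single query vector in plain JAX. It is an assumption about the kernel's strategy, not its code; `online_softmax_attention` is a hypothetical name, and masking, soft capping and paging are omitted.

```python
import jax
import jax.numpy as jnp


def online_softmax_attention(q, k, v, block_size):
    """Single-query attention computed block by block (hypothetical sketch).

    q: [head_dim]; k, v: [seq_len, head_dim]. In the paged kernel a block
    would correspond to pages_per_compute_block * page_size tokens (assumed).
    """
    m = jnp.array(-jnp.inf, dtype=jnp.float32)     # running max of logits
    l = jnp.array(0.0, dtype=jnp.float32)          # running softmax denominator
    o = jnp.zeros(q.shape[-1], dtype=jnp.float32)  # running unnormalised output
    for start in range(0, k.shape[0], block_size):
        kb = k[start:start + block_size]
        vb = v[start:start + block_size]
        s = kb @ q                          # logits for this block
        m_new = jnp.maximum(m, s.max())
        alpha = jnp.exp(m - m_new)          # rescales statistics from earlier blocks
        p = jnp.exp(s - m_new)
        l = alpha * l + p.sum()
        o = alpha * o + p @ vb
        m = m_new
    return o / l
```

Because statistics from earlier blocks are rescaled by `exp(m - m_new)` whenever the maximum grows, the result matches `jax.nn.softmax(k @ q) @ v` for any block size.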
- ejkernel.kernels._pallas.tpu.page_attention._pallas_impl_fwd.paged_flash_attention_kernel_inline_seq_dim(lengths_ref, page_indices_ref, buffer_index_ref, init_flag_ref, q_ref, k_pages_hbm_ref, v_pages_hbm_ref, o_ref, m_ref, l_ref, k_vmem_buffer, v_vmem_buffer, k_sems, v_sems, *, batch_size: int, pages_per_compute_block: int, pages_per_sequence: int, mask_value: float, attn_logits_soft_cap: float | None, megacore_mode: str | None, sliding_window: int | None = None)
- ejkernel.kernels._pallas.tpu.page_attention._pallas_impl_fwd.ref_paged_attention(query: Array, key_cache: Array, value_cache: Array, context_lens: Array, block_tables: Array, *, mask_value: float = -2.381976426469702e+38, attn_logits_soft_cap: float | None = None, sliding_window: int | None = None) → Array
Reference implementation of paged attention for testing.
- Parameters
query – A [batch_size, num_q_heads, head_dim] jax.Array.
key_cache – A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
value_cache – A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
context_lens – An i32[batch_size] jax.Array giving the length of each sequence in the batch.
block_tables – A i32[batch_size, pages_per_sequence] jax.Array.
mask_value – The value assigned to masked-out attention logits (padding positions beyond each sequence's length).
attn_logits_soft_cap – The value used to soft-cap the attention logits; None disables soft capping.
sliding_window – If set, only attend to the last sliding_window tokens.
- Returns
The attention output, a [batch_size, num_q_heads, head_dim] jax.Array.
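The following is a compact sketch of the computation the reference function describes, written against the documented shapes. Several details are assumptions not confirmed by this page: query head `i` uses KV head `i // (num_q_heads // num_kv_heads)`, the soft cap is applied as `cap * tanh(logits / cap)`, no `1/sqrt(head_dim)` scaling happens inside the function, and the sliding window keeps positions `[context_len - sliding_window, context_len)`. The name `ref_paged_attention_sketch` is hypothetical. The default mask value numerically equals `-0.7` times the largest finite float32.

```python
import jax
import jax.numpy as jnp

# Numerically equal to the documented default, about -2.381976426469702e+38.
MASK_VALUE = -0.7 * float(jnp.finfo(jnp.float32).max)


def ref_paged_attention_sketch(query, key_cache, value_cache, context_lens,
                               block_tables, *, mask_value=MASK_VALUE,
                               attn_logits_soft_cap=None, sliding_window=None):
    """Hypothetical sketch; see the assumptions in the lead-in."""
    batch_size, num_q_heads, head_dim = query.shape
    num_kv_heads, _, page_size, _ = key_cache.shape
    max_len = block_tables.shape[1] * page_size
    group = num_q_heads // num_kv_heads  # assumed grouped-query layout

    # Gather each sequence's pages: [num_kv_heads, batch, pages, page_size, head_dim],
    # then flatten pages into one token axis: [batch, num_kv_heads, max_len, head_dim].
    k = key_cache[:, block_tables].reshape(num_kv_heads, batch_size, max_len, head_dim)
    v = value_cache[:, block_tables].reshape(num_kv_heads, batch_size, max_len, head_dim)
    k = jnp.repeat(k.transpose(1, 0, 2, 3), group, axis=1)  # one KV head per query head
    v = jnp.repeat(v.transpose(1, 0, 2, 3), group, axis=1)

    # No 1/sqrt(head_dim) scaling here (assumption).
    logits = jnp.einsum("bhd,bhtd->bht",
                        query.astype(jnp.float32), k.astype(jnp.float32))
    if attn_logits_soft_cap is not None:  # assumed tanh soft cap
        logits = attn_logits_soft_cap * jnp.tanh(logits / attn_logits_soft_cap)

    pos = jnp.arange(max_len)[None, :]    # [1, max_len] token positions
    lens = context_lens[:, None]          # [batch, 1]
    valid = pos < lens                    # drop padding past each length
    if sliding_window is not None:        # keep only the last sliding_window tokens
        valid = valid & (pos >= lens - sliding_window)
    logits = jnp.where(valid[:, None, :], logits, mask_value)

    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bht,bhtd->bhd", weights, v.astype(jnp.float32))
    return out.astype(query.dtype)
```

Materialising every page of every sequence keeps the sketch easy to check, which suits a test reference; the Pallas kernels instead stream pages into VMEM and avoid building the full `[batch, heads, max_len, head_dim]` tensors.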