ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._utils
- ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._utils.get_device_name(num_devices: int | None = None)
- ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._utils.get_min_page_size(max_model_len, min_page_size=16)
Returns the recommended minimum page size for the high-performance kernel.
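The docstring does not say how the recommendation is derived. The sketch below shows one plausible heuristic, assuming the kernel prefers a bounded number of pages per sequence: pick a power-of-two page size large enough that a full-length sequence fits in at most `MAX_PAGES_PER_SEQ` pages, and never go below `min_page_size`. `MAX_PAGES_PER_SEQ`, `next_power_of_2`, and `sketch_min_page_size` are hypothetical names, and the formula is an assumption, not the library's implementation.

```python
# Hypothetical cap on the number of pages a single sequence may span.
# This value is an assumption for illustration only.
MAX_PAGES_PER_SEQ = 16


def next_power_of_2(x: int) -> int:
    """Smallest power of two >= x (returns 1 for x <= 1)."""
    if x <= 1:
        return 1
    return 1 << (x - 1).bit_length()


def sketch_min_page_size(max_model_len: int, min_page_size: int = 16) -> int:
    """Assumed heuristic, not the real get_min_page_size.

    Choose a power-of-two page size so that a sequence of max_model_len
    tokens needs at most MAX_PAGES_PER_SEQ pages, floored at min_page_size.
    """
    tokens_per_page = -(-max_model_len // MAX_PAGES_PER_SEQ)  # ceil division
    return max(next_power_of_2(tokens_per_page), min_page_size)


print(sketch_min_page_size(4096))  # 4096 / 16 = 256 tokens per page
print(sketch_min_page_size(100))   # ceil(100 / 16) = 7 -> 8, floored to 16
```

Rounding up to a power of two keeps page sizes aligned with typical block shapes, and the floor keeps very short contexts from producing tiny, inefficient pages.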
- ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._utils.get_tpu_version() → int
Returns the numeric TPU version, or -1 if not running on a TPU.
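A TPU generation number can be recovered from the device-kind string that JAX reports (`jax.devices()[0].device_kind`, for example `"TPU v4"` or `"TPU v5 lite"`). The sketch below assumes that string format and parses it without importing JAX so it runs anywhere; `parse_tpu_version` is a hypothetical helper, not part of ejkernel, and the real `get_tpu_version` may parse the string differently.

```python
import re


def parse_tpu_version(device_kind: str) -> int:
    """Return the TPU generation encoded in a device-kind string.

    Assumes TPU device kinds look like "TPU v<N>" optionally followed by a
    suffix such as " lite"; any other string (e.g. "cpu") yields -1.
    """
    match = re.search(r"TPU v(\d+)", device_kind)
    return int(match.group(1)) if match else -1


# On a live system the input would come from jax.devices()[0].device_kind.
for kind in ["TPU v4", "TPU v5 lite", "cpu"]:
    print(kind, "->", parse_tpu_version(kind))
```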
- ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._utils.get_tuned_block_sizes(q_dtype, kv_dtype, num_q_heads_per_blk, num_kv_heads_per_blk, head_dim, page_size, max_num_batched_tokens, pages_per_seq) → tuple[int, int]
Looks up the best (num_kv_pages_per_blk, num_queries_per_blk) pair in the auto-tuned table.
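A table lookup of this kind can be sketched as a dictionary keyed on the kernel configuration, with a fallback when no tuned entry exists. In the sketch below, `TUNED_TABLE`, `DEFAULT_BLOCK_SIZES`, the key layout, the placeholder values, and the clamping step are all assumptions for illustration; dtypes are passed as strings to keep the example self-contained. It is not the library's table or lookup logic.

```python
from typing import Dict, Tuple

# Hypothetical key: (q_dtype, kv_dtype, num_q_heads_per_blk,
#                    num_kv_heads_per_blk, head_dim, page_size)
TunedKey = Tuple[str, str, int, int, int, int]

# Placeholder entries, NOT measured results:
# key -> (num_kv_pages_per_blk, num_queries_per_blk)
TUNED_TABLE: Dict[TunedKey, Tuple[int, int]] = {
    ("bfloat16", "bfloat16", 8, 1, 128, 16): (64, 32),
}

# Assumed fallback when the configuration has no tuned entry.
DEFAULT_BLOCK_SIZES = (128, 32)


def sketch_tuned_block_sizes(
    q_dtype: str,
    kv_dtype: str,
    num_q_heads_per_blk: int,
    num_kv_heads_per_blk: int,
    head_dim: int,
    page_size: int,
    max_num_batched_tokens: int,
    pages_per_seq: int,
) -> Tuple[int, int]:
    key = (q_dtype, kv_dtype, num_q_heads_per_blk,
           num_kv_heads_per_blk, head_dim, page_size)
    kv_pages_per_blk, queries_per_blk = TUNED_TABLE.get(key, DEFAULT_BLOCK_SIZES)
    # Assumed clamping: a block never spans more KV pages than a sequence
    # has, nor more queries than the batch can contain.
    return (min(kv_pages_per_blk, pages_per_seq),
            min(queries_per_blk, max_num_batched_tokens))


print(sketch_tuned_block_sizes("bfloat16", "bfloat16", 8, 1, 128, 16, 1024, 256))
print(sketch_tuned_block_sizes("float32", "float32", 4, 1, 64, 32, 16, 8))
```

Falling back to a default and clamping to the actual problem size means an untuned configuration still yields valid block sizes, at the cost of possibly lower performance.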