ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._utils

ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._utils.get_device_name(num_devices: int | None = None)

ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._utils.get_min_page_size(max_model_len, min_page_size=16)

Returns the recommended minimum page size for the high-performance kernel.
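The heuristic behind this recommendation is not documented here. As a rough illustration only, the sketch below assumes a hypothetical cap, MAX_PAGES_PER_SEQ, on how many pages a maximum-length sequence may span, and picks the smallest power-of-2 page size that respects that cap without going below min_page_size. The constant, the helper names, and the rule itself are assumptions, not ejkernel's code.

```python
# Hypothetical sketch; ejkernel's real heuristic may differ.
MAX_PAGES_PER_SEQ = 64  # hypothetical cap on pages per max-length sequence


def next_power_of_2(x: int) -> int:
    """Smallest power of 2 >= x (1 for x <= 1)."""
    return 1 if x <= 1 else 1 << (x - 1).bit_length()


def sketch_min_page_size(max_model_len: int, min_page_size: int = 16) -> int:
    """Power-of-2 page size such that a sequence of max_model_len tokens
    needs at most MAX_PAGES_PER_SEQ pages, never smaller than min_page_size."""
    return max(next_power_of_2(max_model_len) // MAX_PAGES_PER_SEQ, min_page_size)
```

Under these assumptions, short models keep the floor (sketch_min_page_size(512) gives 16), while longer ones get larger pages (sketch_min_page_size(3000) gives 64), which keeps the per-sequence page count, and therefore the page-table work per sequence, bounded.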

ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._utils.get_tpu_version() → int

Returns the numeric version of the TPU, or -1 if not on TPU.
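One way to obtain such a number is to parse the device-kind string reported by the runtime. The sketch below assumes strings of the form "TPU v4" or "TPU v5 lite"; the helper name parse_tpu_version and the regex are hypothetical and only mirror the documented contract (numeric version, or -1 when the device is not a TPU).

```python
import re


def parse_tpu_version(device_kind: str) -> int:
    """Extract the TPU generation from a device-kind string such as
    "TPU v4" or "TPU v5 lite"; return -1 for non-TPU devices."""
    match = re.search(r"TPU v(\d+)", device_kind)  # assumed string format
    return int(match.group(1)) if match else -1
```

On a JAX host, the input could come from `jax.devices()[0].device_kind`. The exact strings vary by TPU generation, so the pattern above is an assumption rather than a guarantee.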

ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._utils.get_tuned_block_sizes(q_dtype, kv_dtype, num_q_heads_per_blk, num_kv_heads_per_blk, head_dim, page_size, max_num_batched_tokens, pages_per_seq) → tuple[int, int]

Looks up the best (num_kv_pages_per_blk, num_queries_per_blk) pair in the auto-tuned table.
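A table lookup of this kind can be sketched as a dictionary keyed on the kernel's shape parameters with a fallback for unseen shapes. Everything below is hypothetical: the table entries are placeholders (not measured values), DEFAULT_BLOCK_SIZES is invented, and clamping num_kv_pages_per_blk to pages_per_seq is an assumption about why that argument exists.

```python
from typing import Dict, Tuple

# Hypothetical tuning table: placeholder values, NOT ejkernel's measurements.
# Key: (q_dtype, kv_dtype, q_heads/blk, kv_heads/blk, head_dim, page_size, max_tokens)
TUNED_TABLE: Dict[tuple, Tuple[int, int]] = {
    ("bfloat16", "bfloat16", 8, 2, 128, 16, 512): (64, 32),
    ("bfloat16", "bfloat16", 8, 2, 128, 32, 1024): (32, 64),
}
DEFAULT_BLOCK_SIZES = (16, 32)  # hypothetical fallback for unseen shapes


def sketch_tuned_block_sizes(q_dtype, kv_dtype, num_q_heads_per_blk,
                             num_kv_heads_per_blk, head_dim, page_size,
                             max_num_batched_tokens, pages_per_seq):
    key = (q_dtype, kv_dtype, num_q_heads_per_blk, num_kv_heads_per_blk,
           head_dim, page_size, max_num_batched_tokens)
    num_kv_pages_per_blk, num_queries_per_blk = TUNED_TABLE.get(
        key, DEFAULT_BLOCK_SIZES)
    # Assumption: a KV block should not span more pages than a sequence owns.
    return min(num_kv_pages_per_blk, pages_per_seq), num_queries_per_blk
```

Falling back to a default keeps the kernel usable on shapes that were never tuned, at the cost of possibly suboptimal block sizes.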

ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._utils.next_power_of_2(x: int)

Finds the smallest power of 2 >= x using bit manipulation.

Parameters

x – The input number (should be an integer).

Returns

The smallest integer power of 2 that is >= x.
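A common bit-manipulation approach in Python uses int.bit_length(). The sketch below is one such version; returning 1 for inputs <= 1 is an assumption made here, and the library's handling of non-positive inputs may differ.

```python
def next_power_of_2(x: int) -> int:
    """Smallest power of 2 >= x, computed with bit_length().

    For x >= 2, (x - 1).bit_length() is the number of bits needed to write
    x - 1, so 1 shifted left by that count is the first power of 2 >= x.
    Inputs <= 1 map to 1 (2**0) in this sketch.
    """
    if x <= 1:
        return 1
    return 1 << (x - 1).bit_length()
```

Subtracting 1 first is what makes exact powers of 2 map to themselves: 8 - 1 = 7 needs 3 bits, so the result is 1 << 3 = 8, while 9 - 1 = 8 needs 4 bits and yields 16.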

ejkernel.kernels._pallas.tpu.ragged_page_attention_v2._utils.simplify_key(key)

Simplify the key to reduce the number of combinations.
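Reducing combinations usually means bucketing numeric fields so that nearby shapes share one table entry. The sketch below assumes a hypothetical key layout matching the arguments of get_tuned_block_sizes, rounds counts up to powers of 2, and rounds head_dim up to a multiple of 128 (an assumed bucket size). None of these choices is confirmed by the source.

```python
def next_power_of_2(x: int) -> int:
    return 1 if x <= 1 else 1 << (x - 1).bit_length()


def sketch_simplify_key(key):
    """Bucket a tuning key so nearby shapes map to the same entry.

    Hypothetical key layout: (q_dtype, kv_dtype, num_q_heads_per_blk,
    num_kv_heads_per_blk, head_dim, page_size, max_num_batched_tokens,
    pages_per_seq).
    """
    (q_dtype, kv_dtype, q_heads, kv_heads, head_dim,
     page_size, max_tokens, pages_per_seq) = key
    return (
        str(q_dtype),
        str(kv_dtype),
        next_power_of_2(q_heads),
        next_power_of_2(kv_heads),
        -(-head_dim // 128) * 128,  # round head_dim up to a multiple of 128
        next_power_of_2(page_size),
        next_power_of_2(max_tokens),
        # Collapse page_size and pages_per_seq into one max-length bucket.
        next_power_of_2(page_size * pages_per_seq),
    )
```

With this scheme, (3 q-heads, head_dim 100, 500 tokens, 10 pages of 16) and (4 q-heads, head_dim 128, 512 tokens, 16 pages of 16) produce the same simplified key, so one tuned entry serves both.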