ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils

ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.align_to(x, a) [source]
ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.cdiv(a, b) [source]
ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_device_name(num_devices: int | None = None) [source]
ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_dtype_bitwidth(dtype) [source]
ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_dtype_packing(dtype) [source]
ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_lookup_keys(page_size, q_dtype, kv_dtype, num_q_heads, num_kv_heads, head_dim, max_model_len) [source]

Get the lookup keys for tuned block sizes.

ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_lookup_keys_h64(page_size, q_dtype, kv_dtype, num_q_heads, num_kv_heads, head_dim, max_model_len) [source]

Get the lookup keys for tuned block sizes.

ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_simplified_raw_key(page_size, q_dtype, kv_dtype, actual_num_q_heads, actual_num_kv_heads, head_dim, max_model_len) [source]

Get the simplified key.

ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_simplified_raw_key_h64(page_size, q_dtype, kv_dtype, actual_num_q_heads, actual_num_kv_heads, head_dim, max_model_len) [source]

Get the simplified key.

ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_tpu_version() → int [source]

Returns the numeric version of the TPU, or -1 if not on TPU.

ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_tuned_block_sizes(q_dtype, kv_dtype, actual_num_q_heads, actual_num_kv_heads, head_dim, page_size, max_num_tokens, pages_per_seq) → tuple[int, int] [source]

Look up and return the tuned values of (num_kv_pages_per_blk, num_queries_per_blk).

ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_tuned_block_sizes_h64(q_dtype, kv_dtype, actual_num_q_heads, actual_num_kv_heads, head_dim, page_size, max_num_tokens, pages_per_seq) → tuple[int, int] [source]

Look up and return the tuned values of (num_kv_pages_per_blk, num_queries_per_blk).

ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.next_power_of_2(x: int) [source]

Finds the smallest power of 2 >= x using bit manipulation.

Parameters

x – The input number (should be an integer).

Returns

The smallest integer power of 2 that is >= x.