ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils
- ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_device_name(num_devices: int | None = None)
- ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_lookup_keys(page_size, q_dtype, kv_dtype, num_q_heads, num_kv_heads, head_dim, max_model_len)
Get the lookup keys for tuned block sizes.
- ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_lookup_keys_h64(page_size, q_dtype, kv_dtype, num_q_heads, num_kv_heads, head_dim, max_model_len)
Get the lookup keys for tuned block sizes.
- ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_simplified_raw_key(page_size, q_dtype, kv_dtype, actual_num_q_heads, actual_num_kv_heads, head_dim, max_model_len)
Get the simplified key.
- ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_simplified_raw_key_h64(page_size, q_dtype, kv_dtype, actual_num_q_heads, actual_num_kv_heads, head_dim, max_model_len)
Get the simplified key.
- ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_tpu_version() -> int
Returns the numeric version of the TPU, or -1 if not on TPU.
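As a rough illustration of the contract described above (an integer version, or -1 when no TPU is present), the sketch below parses a device-kind string. The helper name `parse_tpu_version` and the example strings such as "TPU v5 lite" are assumptions made for illustration; they are not taken from this module's source.

```python
import re


def parse_tpu_version(device_kind: str) -> int:
    """Hypothetical helper: extract a TPU generation number from a device-kind string.

    Returns -1 when the string does not describe a TPU.
    """
    # Assumed format: "TPU", then an optional "v", then the generation digits.
    match = re.search(r"TPU\s*v?(\d+)", device_kind)
    return int(match.group(1)) if match else -1


print(parse_tpu_version("TPU v4"))       # 4
print(parse_tpu_version("TPU v5 lite"))  # 5
print(parse_tpu_version("cpu"))          # -1
```

With JAX installed, the input string would typically come from `jax.devices()[0].device_kind`.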
- ejkernel.kernels._pallas.tpu.ragged_page_attention_v3._utils.get_tuned_block_sizes(q_dtype, kv_dtype, actual_num_q_heads, actual_num_kv_heads, head_dim, page_size, max_num_tokens, pages_per_seq) -> tuple[int, int]
Look up the tuned block sizes, returned as (num_kv_pages_per_blk, num_queries_per_blk).
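The general pattern behind a tuned block-size search can be sketched as a table lookup with a fallback. Everything in the sketch below is hypothetical: the table `TUNED_BLOCK_SIZES`, its placeholder values, the default `DEFAULT_BLOCK_SIZES`, the reduced three-field key, and the helper names are assumptions for illustration and do not reflect this module's actual tables or key layout.

```python
# Hypothetical table: (page_size, head_dim, bucketed max_num_tokens) -> block sizes.
# The entries are invented placeholders, not measured tuning results.
TUNED_BLOCK_SIZES: dict[tuple[int, int, int], tuple[int, int]] = {
    (16, 128, 512): (8, 32),
    (32, 128, 1024): (4, 64),
}

# Assumed fallback used when no tuned entry matches.
DEFAULT_BLOCK_SIZES: tuple[int, int] = (1, 16)


def next_power_of_2(x: int) -> int:
    """Round a positive integer up to the nearest power of two."""
    if x <= 0:
        raise ValueError("x must be positive")
    return 1 << (x - 1).bit_length()


def lookup_block_sizes(page_size: int, head_dim: int, max_num_tokens: int) -> tuple[int, int]:
    """Return (num_kv_pages_per_blk, num_queries_per_blk) from the table, or the default."""
    # Bucket the token count so nearby workloads share one table entry.
    key = (page_size, head_dim, next_power_of_2(max_num_tokens))
    return TUNED_BLOCK_SIZES.get(key, DEFAULT_BLOCK_SIZES)


print(lookup_block_sizes(16, 128, 300))  # (8, 32): 300 is bucketed to 512
print(lookup_block_sizes(64, 256, 300))  # (1, 16): no entry, falls back to the default
```

Bucketing a size parameter before the lookup keeps the table small, at the cost of sharing one tuned entry across every workload that falls in the same bucket.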