ejkernel.kernels._xla.ring_attention._utils
- ejkernel.kernels._xla.ring_attention._utils.below_or_on_diag(r: int, r_blk_size: int, c: int, c_blk_size: int, causal_block_size: int)
Checks if the element at (r, c) is below or on the diagonal.
- Parameters
r – Row index.
r_blk_size – Block size of the row.
c – Column index.
c_blk_size – Block size of the column.
causal_block_size – Size of causal blocks.
- Returns
True if the element is below or on the diagonal, False otherwise.
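The kind of check described above can be illustrated with a minimal sketch in plain Python. It is not the library's implementation. It assumes that r and c are block indices, so row block r covers rows r * r_blk_size through (r + 1) * r_blk_size - 1, and that a block counts as "below or on the diagonal" when its bottom-left corner is. The helper name block_below_or_on_diag_sketch is hypothetical, and the causal_block_size parameter is left out because its exact effect is not documented here.

```python
def block_below_or_on_diag_sketch(r: int, r_blk_size: int, c: int, c_blk_size: int) -> bool:
    """Hypothetical helper, not part of ejkernel.

    Assumes r and c are block indices. Returns True when the bottom-left
    corner of block (r, c) lies on or below the main diagonal, i.e. when
    the last row of the block is at or past its first column.
    """
    last_row = (r + 1) * r_blk_size - 1  # bottom row covered by row block r
    first_col = c * c_blk_size           # leftmost column covered by column block c
    return last_row >= first_col


# Block-level causal layout for 3 row blocks and 3 column blocks of size 4.
layout = [
    [block_below_or_on_diag_sketch(r, 4, c, 4) for c in range(3)]
    for r in range(3)
]
```

Under these assumptions, layout is lower triangular at block granularity: a block that fails the check sits entirely above the diagonal, so a causal attention kernel could skip it.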