ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd
- ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd.apply_attention_sinks_block(scores: Float[jaxlib._jax.Array, 'batch q_len heads block_size'], sink_scores: jaxtyping.Float[jaxlib._jax.Array, 'heads num_sinks'] | None = None, num_sinks: int = 0, block_offset: int = 0) → Float[jaxlib._jax.Array, 'batch q_len heads block_size']
Applies attention sink biases to scores for a specific block.
- Parameters
scores – Attention scores for this block [B, Q, H, block_size]
sink_scores – Optional learned biases for sink tokens [H, num_sinks] or [num_sinks]
num_sinks – Number of sink tokens
block_offset – Offset of this block in the full sequence
- Returns
Scores with sink biases applied if this block contains sinks
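To make the role of block_offset concrete, here is a minimal JAX sketch of adding sink biases to one score block. It is not ejkernel's code: the helper name add_sink_biases_block is hypothetical, and it assumes that column j of the block corresponds to global key position block_offset + j and that key position i < num_sinks receives the bias sink_scores[..., i].

```python
import jax.numpy as jnp

def add_sink_biases_block(scores, sink_scores, num_sinks, block_offset):
    """Hypothetical sketch: add per-sink biases to one block of scores.

    scores: [B, Q, H, block_size]; sink_scores: [H, num_sinks] or [num_sinks].
    Assumption: block column j is global key position block_offset + j, and
    key position i < num_sinks gets bias sink_scores[..., i].
    """
    if sink_scores is None or num_sinks == 0:
        return scores
    block_size = scores.shape[-1]
    pos = block_offset + jnp.arange(block_size)      # global key positions
    is_sink = pos < num_sinks                        # columns that are sinks
    idx = jnp.clip(pos, 0, num_sinks - 1)            # safe gather index
    if sink_scores.ndim == 1:
        bias = sink_scores[idx][None, :]             # [1, block_size], shared by heads
    else:
        bias = sink_scores[:, idx]                   # [H, block_size]
    bias = jnp.where(is_sink, bias, 0.0)             # zero out non-sink columns
    return scores + bias[None, None, :, :].astype(scores.dtype)
```

A block that starts at or beyond num_sinks is returned with no bias added, which matches the "if this block contains sinks" wording above.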
- ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd.apply_logits_soft_cap(scores: Float[jaxlib._jax.Array, '... seq_len'], logits_soft_cap: float) → Float[jaxlib._jax.Array, '... seq_len']
Applies soft capping to attention logits.
- Parameters
scores – Attention scores
logits_soft_cap – Soft capping value
- Returns
Soft-capped scores
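A common formulation of logit soft capping (used, for example, by Gemma 2 with a cap of 50.0) is cap * tanh(scores / cap), which is close to the identity for small logits and smoothly bounds them to the interval (-cap, cap). The sketch below assumes apply_logits_soft_cap follows that formulation; the name soft_cap is hypothetical.

```python
import jax.numpy as jnp

def soft_cap(scores, cap):
    # Assumed tanh-based soft cap: ~linear near 0, saturates at +/- cap.
    return cap * jnp.tanh(scores / cap)
```

Because tanh is smooth, gradients still flow through large logits, unlike a hard clip.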
- ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd.create_attention_mask(batch_size: int, q_len: int, kv_len: int, sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], sliding_window: tuple[int, int] | None = None, num_sinks: int = 0) → Float[jaxlib._jax.Array, 'batch q_len 1 kv_len']
Creates an attention mask that combines ragged per-sequence bounds, an optional sliding window, and always-visible attention sink tokens.
- Parameters
batch_size – Batch size
q_len – Query sequence length
kv_len – Key/value sequence length
sequence_start – Start indices for each sequence
sequence_end – End indices for each sequence
sliding_window – Optional (left, right) window size for local attention
num_sinks – Number of attention sink tokens (always attendable)
- Returns
Boolean mask of shape [batch, q_len, 1, kv_len]
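The following JAX sketch shows one way these three constraints can be combined into a [batch, q_len, 1, kv_len] mask. It is illustrative only: make_ragged_mask is a hypothetical name, and it assumes that sequence b occupies key positions [sequence_start[b], sequence_end[b]) with an exclusive end, that the q_len queries sit at the last q_len positions of that range, that sliding_window=(left, right) keeps keys in [q - left, q + right], and that key positions below num_sinks bypass the window. No causal constraint is applied.

```python
import jax.numpy as jnp

def make_ragged_mask(q_len, kv_len, sequence_start, sequence_end,
                     sliding_window=None, num_sinks=0):
    # Hypothetical helper; every rule below is an assumption (see lead-in).
    batch = sequence_start.shape[0]
    start = sequence_start[:, None, None]                 # [B, 1, 1]
    end = sequence_end[:, None, None]                     # [B, 1, 1]
    k = jnp.arange(kv_len)[None, None, :]                 # [1, 1, K] key positions
    q = end - q_len + jnp.arange(q_len)[None, :, None]    # [B, Q, 1] query positions
    mask = (k >= start) & (k < end)                       # ragged bounds (end exclusive)
    if sliding_window is not None:
        left, right = sliding_window
        in_window = (k >= q - left) & (k <= q + right)
        is_sink = k < num_sinks                           # sinks ignore the window
        mask = mask & (in_window | is_sink)
    mask = jnp.broadcast_to(mask, (batch, q_len, kv_len))
    return mask[:, :, None, :]                            # [B, Q, 1, K]
```

For decoding (q_len = 1) the single query is placed at sequence_end - 1, so a window of (128, 0) keeps the last 129 valid keys plus any sinks.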
- ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd.flash_attention_block(carry: tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jax.jaxlib._jax.Array], block_inputs: tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, int], softmax_scale: float, logits_soft_cap: float | None = None, sink_scores: jax.jaxlib._jax.Array | None = None, num_sinks: int = 0) → tuple[tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jax.jaxlib._jax.Array], None]
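Processes one key/value block of flash attention, with optional logit soft capping and attention sinks, and folds the result into the running online-softmax state.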
- Parameters
carry – Tuple of (output, max_logits, normalizer)
block_inputs – Tuple of (queries, keys_block, values_block, mask_block, block_offset)
softmax_scale – Scaling factor for attention
logits_soft_cap – Optional soft capping value
sink_scores – Optional attention sink biases
num_sinks – Number of sink tokens
- Returns
A pair (carry, None), where carry is the updated (output, max_logits, normalizer) tuple
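The (carry, None) return shape matches a jax.lax.scan step. Below is a minimal single-head sketch of such a step using the standard online-softmax update, without masking, soft capping, or sinks. The names block_step and blocked_attention are hypothetical, and it assumes the output accumulator is kept unnormalized and divided by the normalizer only after the last block; ejkernel's actual carry layout may differ.

```python
import functools
import jax
import jax.numpy as jnp

def block_step(q, scale, carry, block):
    # carry: unnormalized output [Q, D], running max logit [Q], normalizer [Q]
    acc, m, l = carry
    k_blk, v_blk = block                          # [Bk, D], [Bk, D]
    s = (q @ k_blk.T) * scale                     # [Q, Bk] logits for this block
    m_new = jnp.maximum(m, s.max(axis=-1))        # updated running max
    p = jnp.exp(s - m_new[:, None])               # block weights relative to m_new
    alpha = jnp.exp(m - m_new)                    # rescales earlier blocks
    l_new = alpha * l + p.sum(axis=-1)
    acc_new = alpha[:, None] * acc + p @ v_blk
    return (acc_new, m_new, l_new), None          # scan-style (carry, output)

def blocked_attention(q, k, v, block_size):
    scale = q.shape[-1] ** -0.5
    n_blocks = k.shape[0] // block_size           # assumes kv_len % block_size == 0
    k_blocks = k.reshape(n_blocks, block_size, -1)
    v_blocks = v.reshape(n_blocks, block_size, -1)
    init = (jnp.zeros((q.shape[0], v.shape[-1]), q.dtype),
            jnp.full((q.shape[0],), -jnp.inf, q.dtype),
            jnp.zeros((q.shape[0],), q.dtype))
    step = functools.partial(block_step, q, scale)
    (acc, _, l), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return acc / l[:, None]                       # normalize once at the end
```

The factor exp(m - m_new) is what lets each block be visited exactly once: earlier contributions are rescaled whenever a larger logit appears, so the final result equals a dense softmax over all keys.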
- ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd.inner_decode_xla(query: Float[jaxlib._jax.Array, 'batch num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, block_size: int = 256, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, '...'] | None = None) → Union[Array, ndarray, bool, number]
JIT-compiled XLA implementation of ragged multi-query attention (MQA) decoding using flash attention.
- Parameters
query – Query tensor [B, H_q, D], optionally with a leading singleton dimension
key – Key tensor [B, S, H_kv, D]
value – Value tensor [B, S, H_kv, D]
sequence_start – Sequence start indices
sequence_end – Sequence end indices
softmax_scale – Scaling factor for attention logits
block_size – Block size for attention computation
sliding_window – Optional (left, right) window for local attention
logits_soft_cap – Optional soft capping for logits (e.g., 50.0)
softmax_aux – Optional attention sink biases of shape [H, num_sinks] or [num_sinks]. The first num_sinks tokens become “attention sinks” with learnable biases.
- Returns
Output tensor with same batch/head structure as query
Examples
    # Basic ragged decode
    output = inner_decode_xla(query, key, value, start, end)

    # Local attention: up to 128 tokens to the left, none to the right
    output = inner_decode_xla(query, key, value, start, end, sliding_window=(128, 0))

    # Soft-capped logits
    output = inner_decode_xla(query, key, value, start, end, logits_soft_cap=50.0)

    # Four attention sinks, each with a bias of 0.1
    sink_biases = jnp.ones(4) * 0.1
    output = inner_decode_xla(query, key, value, start, end, softmax_aux=sink_biases)
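When testing the blocked kernel, a dense (unblocked) reference is often useful. The sketch below covers only the basic call (no sliding window, soft cap, or sinks). reference_ragged_decode is a hypothetical name, and it assumes an exclusive sequence_end, at least one valid key per sequence, a default scale of 1/sqrt(D), and that query head h reads key/value head h // (H_q // H_kv).

```python
import jax
import jax.numpy as jnp

def reference_ragged_decode(query, key, value, sequence_start, sequence_end,
                            softmax_scale=None):
    # Hypothetical dense reference; query [B, H_q, D], key/value [B, S, H_kv, D].
    h_q, h_kv = query.shape[1], key.shape[2]
    if softmax_scale is None:
        softmax_scale = query.shape[-1] ** -0.5         # assumed default scale
    k = jnp.repeat(key, h_q // h_kv, axis=2)            # [B, S, H_q, D]
    v = jnp.repeat(value, h_q // h_kv, axis=2)
    logits = jnp.einsum("bhd,bshd->bhs", query, k) * softmax_scale
    pos = jnp.arange(key.shape[1])[None, None, :]       # [1, 1, S]
    valid = (pos >= sequence_start[:, None, None]) & (pos < sequence_end[:, None, None])
    logits = jnp.where(valid, logits, -jnp.inf)         # drop keys outside the sequence
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhs,bshd->bhd", probs, v)        # [B, H_q, D]
```

A sequence with a single valid key should return exactly that key's value vector, which makes a convenient sanity check against the blocked implementation.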
- ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd.ragged_decode_mqa_xla(query: Float[jaxlib._jax.Array, 'batch num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, '...'] | None = None) → Float[jaxlib._jax.Array, 'batch num_q_heads head_dim']
XLA-compatible ragged multi-query attention (MQA) decoding.
- Parameters
query – Query tensor [B, H_q, D]
key – Key tensor [B, S, H_kv, D]
value – Value tensor [B, S, H_kv, D]
sequence_start – Start indices for each sequence
sequence_end – End indices for each sequence
softmax_scale – Optional scaling factor
fwd_params – Optional forward-pass parameters (FwdParams) controlling the blocked computation, such as the block size
sliding_window – Optional (left, right) window for local attention
logits_soft_cap – Optional soft capping for logits
softmax_aux – Optional attention sink biases
- Returns
Output tensor [B, H_q, D]
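The shapes above allow H_q query heads to share H_kv key/value heads. A minimal sketch of computing decode logits for grouped heads without materializing repeated keys follows; grouped_decode_scores is a hypothetical name, and it assumes H_q is a multiple of H_kv and that consecutive query heads share a key/value head (head h uses head h // (H_q // H_kv)).

```python
import jax.numpy as jnp

def grouped_decode_scores(query, key):
    # Hypothetical sketch: query [B, H_q, D], key [B, S, H_kv, D] -> logits [B, H_q, S].
    b, h_q, d = query.shape
    h_kv = key.shape[2]
    g = h_q // h_kv                                  # query heads per KV head (assumed exact)
    qg = query.reshape(b, h_kv, g, d)                # group consecutive query heads
    s = jnp.einsum("bkgd,bskd->bkgs", qg, key)       # [B, H_kv, G, S]
    return s.reshape(b, h_q, -1)                     # back to [B, H_q, S]
```

With H_kv = 1 this reduces to classic MQA, where every query head attends to the same keys; the result matches repeating each KV head g times along the head axis.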
- ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd.ragged_flash_attention_xla(query: Float[jaxlib._jax.Array, 'batch q_len num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch kv_len num_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch kv_len num_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, block_size: int = 256, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, '...'] | None = None) → Float[jaxlib._jax.Array, 'batch q_len num_heads head_dim']
XLA-compatible ragged flash attention with optional sliding-window attention, logit soft capping, and attention sinks.
- Parameters
query – Query tensor [B, Q, H, D]
key – Key tensor [B, K, H, D]
value – Value tensor [B, K, H, D]
sequence_start – Start indices for each sequence
sequence_end – End indices for each sequence
softmax_scale – Optional scaling factor for attention
block_size – Size of blocks for chunked computation
sliding_window – Optional (left, right) window for local attention
logits_soft_cap – Optional soft capping for logits
softmax_aux – Optional attention sink biases [H, num_sinks] or [num_sinks]
- Returns
Attention output [B, Q, H, D]
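Chunked computation needs the key/value sequence split into block_size pieces, each tagged with its offset in the full sequence (compare the block_offset element of flash_attention_block's block_inputs). A minimal sketch, assuming zero padding of the tail block and a block-major layout suitable for jax.lax.scan; split_into_blocks is a hypothetical name, and padded positions would still need to be masked out.

```python
import jax.numpy as jnp

def split_into_blocks(x, block_size):
    # Hypothetical sketch: [B, K, H, D] -> [num_blocks, B, block_size, H, D] plus offsets.
    b, k, h, d = x.shape
    num_blocks = -(-k // block_size)                        # ceiling division
    pad = num_blocks * block_size - k
    x = jnp.pad(x, ((0, 0), (0, pad), (0, 0), (0, 0)))      # zero-pad the last block
    blocks = x.reshape(b, num_blocks, block_size, h, d).transpose(1, 0, 2, 3, 4)
    offsets = jnp.arange(num_blocks) * block_size           # start position of each block
    return blocks, offsets
```

Putting the block axis first lets a scan iterate over blocks directly, with the offset of each block available for mask and sink computations.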