ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd

ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd.apply_attention_sinks_block(scores: Float[jaxlib._jax.Array, 'batch q_len heads block_size'], sink_scores: jaxtyping.Float[jaxlib._jax.Array, 'heads num_sinks'] | None = None, num_sinks: int = 0, block_offset: int = 0) → Float[jaxlib._jax.Array, 'batch q_len heads block_size']

Applies attention sink biases to scores for a specific block.

Parameters
  • scores – Attention scores for this block [B, Q, H, block_size]

  • sink_scores – Optional learned biases for sink tokens [H, num_sinks] or [num_sinks]

  • num_sinks – Number of sink tokens

  • block_offset – Offset of this block in the full sequence

Returns

Scores with sink biases applied if this block contains sinks
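The description above can be illustrated with a minimal sketch. It assumes the biases are added to the scores of key positions whose absolute index (block_offset plus the position inside the block) is below num_sinks. The source does not confirm that rule, and add_sink_biases_sketch is a hypothetical name, not the library function.

```python
import jax.numpy as jnp

def add_sink_biases_sketch(scores, sink_scores, num_sinks, block_offset):
    """Hypothetical sketch: add sink biases to key positions in this block
    whose absolute index is below num_sinks (an assumed rule)."""
    if sink_scores is None or num_sinks == 0:
        return scores
    if sink_scores.ndim == 1:
        # A [num_sinks] vector is shared by every head.
        sink_scores = jnp.broadcast_to(sink_scores, (scores.shape[2], num_sinks))
    block_size = scores.shape[-1]
    positions = block_offset + jnp.arange(block_size)   # absolute key indices
    is_sink = positions < num_sinks                      # [block_size]
    # Clamp the gather index so non-sink positions read a valid, ignored slot.
    idx = jnp.minimum(positions, num_sinks - 1)
    bias = jnp.where(is_sink[None, :], sink_scores[:, idx], 0.0)  # [H, block_size]
    return scores + bias[None, None, :, :]

scores = jnp.zeros((1, 1, 2, 4))                         # [B, Q, H, block_size]
sink = jnp.array([[0.5, 0.25], [1.0, 2.0]])              # [H, num_sinks]
first = add_sink_biases_sketch(scores, sink, num_sinks=2, block_offset=0)
later = add_sink_biases_sketch(scores, sink, num_sinks=2, block_offset=4)
```

In this sketch only the block starting at offset 0 is changed; a block that starts past the sink positions is returned with unchanged values.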

ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd.apply_logits_soft_cap(scores: Float[jaxlib._jax.Array, '... seq_len'], logits_soft_cap: float) → Float[jaxlib._jax.Array, '... seq_len']

Applies soft capping to attention logits.

Parameters
  • scores – Attention scores

  • logits_soft_cap – Soft capping value

Returns

Soft-capped scores
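As an illustration, soft capping is commonly written as cap * tanh(scores / cap), which bounds logits smoothly to the open interval (-cap, cap) and leaves small values nearly unchanged. The sketch below assumes that formulation; the source does not state the exact formula, and soft_cap_sketch is a hypothetical name.

```python
import jax.numpy as jnp

def soft_cap_sketch(scores, logits_soft_cap):
    """Hypothetical stand-in: bound logits to (-cap, cap) with tanh.
    The tanh form is an assumption, not confirmed by the docstring."""
    return logits_soft_cap * jnp.tanh(scores / logits_soft_cap)

scores = jnp.array([-200.0, -1.0, 0.0, 1.0, 200.0])
capped = soft_cap_sketch(scores, 50.0)
```

Large magnitudes saturate near the cap, while values much smaller than the cap pass through almost untouched.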

ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd.create_attention_mask(batch_size: int, q_len: int, kv_len: int, sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], sliding_window: tuple[int, int] | None = None, num_sinks: int = 0) → Float[jaxlib._jax.Array, 'batch q_len 1 kv_len']

Creates a comprehensive attention mask with ragged sequences, sliding window, and sinks.

Parameters
  • batch_size – Batch size

  • q_len – Query sequence length

  • kv_len – Key/value sequence length

  • sequence_start – Start indices for each sequence

  • sequence_end – End indices for each sequence

  • sliding_window – Optional (left, right) window size for local attention

  • num_sinks – Number of attention sink tokens (always attendable)

Returns

Mask of shape [batch, q_len, 1, kv_len]. The docstring describes it as boolean, while the signature annotates the return as Float.
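A minimal sketch of how such a mask could be assembled follows. Several points are assumptions not confirmed by the source: keys are valid in the half-open range [sequence_start, sequence_end), the last query sits at absolute position sequence_end - 1, and sink keys are the first num_sinks absolute positions, which stay attendable outside the window but not outside the ragged range. ragged_mask_sketch is a hypothetical name.

```python
import jax.numpy as jnp

def ragged_mask_sketch(q_len, kv_len, sequence_start, sequence_end,
                       sliding_window=None, num_sinks=0):
    """Hypothetical ragged + sliding-window + sink mask, [B, Q, 1, K]."""
    batch = sequence_start.shape[0]
    kv_pos = jnp.arange(kv_len)[None, None, None, :]       # [1, 1, 1, K]
    start = sequence_start[:, None, None, None]             # [B, 1, 1, 1]
    end = sequence_end[:, None, None, None]
    # Ragged part: only keys inside [start, end) are valid (assumed half-open).
    mask = (kv_pos >= start) & (kv_pos < end)
    mask = jnp.broadcast_to(mask, (batch, q_len, 1, kv_len))
    if sliding_window is not None:
        left, right = sliding_window
        # Assumption: query i sits at absolute position end - q_len + i.
        q_pos = end - q_len + jnp.arange(q_len)[None, :, None, None]
        in_window = (kv_pos >= q_pos - left) & (kv_pos <= q_pos + right)
        if num_sinks > 0:
            # Sink keys bypass the window, but not the ragged range.
            in_window = in_window | (kv_pos < num_sinks)
        mask = mask & in_window
    return mask

mask = ragged_mask_sketch(
    q_len=1, kv_len=8,
    sequence_start=jnp.array([0, 2]), sequence_end=jnp.array([8, 6]),
    sliding_window=(2, 0), num_sinks=1,
)
```

For the first sequence the query at position 7 sees keys 5 to 7 plus the sink at key 0; for the second, the query at position 5 sees keys 3 to 5, and key 0 stays masked because it lies before sequence_start.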

ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd.flash_attention_block(carry: tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jax.jaxlib._jax.Array], block_inputs: tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, int], softmax_scale: float, logits_soft_cap: float | None = None, sink_scores: jax.jaxlib._jax.Array | None = None, num_sinks: int = 0) → tuple[tuple[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array, jax.jaxlib._jax.Array], None]

Enhanced flash attention block with soft cap and sinks.

Parameters
  • carry – Tuple of (output, max_logits, normalizer)

  • block_inputs – Tuple of (queries, keys_block, values_block, mask_block, block_offset)

  • softmax_scale – Scaling factor for attention

  • logits_soft_cap – Optional soft capping value

  • sink_scores – Optional attention sink biases

  • num_sinks – Number of sink tokens

Returns

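A pair (updated carry, None), matching the return annotation; the updated carry is again (output, max_logits, normalizer)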
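The carry of (output, max_logits, normalizer) matches the usual streaming-softmax recurrence used by flash attention. The sketch below shows one such update for a single head, without soft cap or sinks. It is an illustration under the assumption that the output is kept unnormalized until the last block; online_softmax_step is a hypothetical name, not the library function.

```python
import jax.numpy as jnp

def online_softmax_step(carry, q, k_blk, v_blk, mask_blk, softmax_scale):
    """Hypothetical single-head update. carry = (output [Q, D],
    max_logits [Q], normalizer [Q]); output stays unnormalized."""
    out, m_prev, l_prev = carry
    s = (q @ k_blk.T) * softmax_scale                   # [Q, block]
    s = jnp.where(mask_blk, s, -jnp.inf)
    m_new = jnp.maximum(m_prev, s.max(axis=-1))
    # Guard fully masked rows, whose running max is still -inf.
    m_safe = jnp.where(jnp.isfinite(m_new), m_new, 0.0)
    p = jnp.exp(s - m_safe[:, None])                    # masked entries become 0
    correction = jnp.exp(m_prev - m_safe)               # rescales earlier sums
    l_new = l_prev * correction + p.sum(axis=-1)
    out_new = out * correction[:, None] + p @ v_blk
    return (out_new, m_new, l_new)

q = jnp.array([[1.0, 0.0], [0.0, 1.0]])
k = jnp.arange(8.0).reshape(4, 2) / 4.0
v = jnp.arange(8.0).reshape(4, 2)
mask = jnp.ones((2, 4), dtype=bool)

carry = (jnp.zeros((2, 2)), jnp.full((2,), -jnp.inf), jnp.zeros((2,)))
for lo in range(0, 4, 2):                               # two blocks of two keys
    carry = online_softmax_step(carry, q, k[lo:lo + 2], v[lo:lo + 2],
                                mask[:, lo:lo + 2], softmax_scale=1.0)
blockwise = carry[0] / carry[2][:, None]                # normalize once at the end

# Dense reference: an ordinary softmax over all keys at once.
s = q @ k.T
w = jnp.exp(s - s.max(axis=-1, keepdims=True))
dense = (w / w.sum(axis=-1, keepdims=True)) @ v
```

Processing the keys in two blocks and normalizing at the end should agree with the dense softmax up to floating-point error.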

ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd.inner_decode_xla(query: Float[jaxlib._jax.Array, 'batch num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, block_size: int = 256, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, '...'] | None = None) → Union[Array, ndarray, bool, number]

Enhanced JIT-compiled XLA implementation of ragged MQA Flash Attention.

Parameters
  • query – Query tensor, optionally with leading singleton dimension

  • key – Key tensor [B, S, H_kv, D]

  • value – Value tensor [B, S, H_kv, D]

  • sequence_start – Sequence start indices

  • sequence_end – Sequence end indices

  • softmax_scale – Scaling factor for attention logits

  • block_size – Block size for attention computation

  • sliding_window – Optional (left, right) window for local attention

  • logits_soft_cap – Optional soft capping for logits (e.g., 50.0)

  • softmax_aux – Optional attention sink biases, shape [H, num_sinks] or [num_sinks]. The first few tokens become “attention sinks” with learnable biases

Returns

Output tensor with same batch/head structure as query

Examples

    output = inner_decode_xla(query, key, value, start, end)

    output = inner_decode_xla(
        query, key, value, start, end, sliding_window=(128, 0)
    )

    output = inner_decode_xla(
        query, key, value, start, end, logits_soft_cap=50.0
    )

    sink_biases = jnp.ones(4) * 0.1
    output = inner_decode_xla(
        query, key, value, start, end, softmax_aux=sink_biases
    )

ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd.ragged_decode_mqa_xla(query: Float[jaxlib._jax.Array, 'batch num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, fwd_params: ejkernel.ops.utils.datacarrier.FwdParams | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, '...'] | None = None) → Float[jaxlib._jax.Array, 'batch num_q_heads head_dim']

Enhanced XLA-compatible ragged MQA decoding.

Parameters
  • query – Query tensor [B, H_q, D]

  • key – Key tensor [B, S, H_kv, D]

  • value – Value tensor [B, S, H_kv, D]

  • sequence_start – Start indices for each sequence

  • sequence_end – End indices for each sequence

  • softmax_scale – Optional scaling factor

  • fwd_params – Optional forward-pass parameters (FwdParams); the docstring lists this entry as block_size (block size for computation)

  • sliding_window – Optional sliding window parameters

  • logits_soft_cap – Optional soft capping for logits

  • softmax_aux – Optional attention sink biases

Returns

Output tensor [B, H_q, D]
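For checking shapes and results, a dense reference for ragged decoding can be written in a few lines. The sketch below is not the library implementation and makes several assumptions: keys are valid in the half-open range [sequence_start, sequence_end), the default scale is head_dim ** -0.5, and consecutive groups of H_q // H_kv query heads share one KV head. naive_ragged_decode is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def naive_ragged_decode(query, key, value, sequence_start, sequence_end,
                        softmax_scale=None):
    """Hypothetical dense reference for [B, H_q, D] queries and
    [B, S, H_kv, D] keys/values."""
    B, H_q, D = query.shape
    S, H_kv = key.shape[1], key.shape[2]
    # Assumed default scale; the docstring does not state it.
    scale = D ** -0.5 if softmax_scale is None else softmax_scale
    # Assumption: consecutive groups of query heads share a KV head.
    group = H_q // H_kv
    q = query.reshape(B, H_kv, group, D)
    scores = jnp.einsum("bhgd,bshd->bhgs", q, key) * scale
    pos = jnp.arange(S)[None, :]
    valid = (pos >= sequence_start[:, None]) & (pos < sequence_end[:, None])
    scores = jnp.where(valid[:, None, None, :], scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("bhgs,bshd->bhgd", weights, value)
    return out.reshape(B, H_q, D)

B, S, H_q, H_kv, D = 2, 6, 4, 2, 8
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(k1, (B, H_q, D))
key = jax.random.normal(k2, (B, S, H_kv, D))
value = jax.random.normal(k3, (B, S, H_kv, D))
start = jnp.array([0, 1], dtype=jnp.int32)
end = jnp.array([6, 4], dtype=jnp.int32)
out = naive_ragged_decode(query, key, value, start, end)
```

Because positions outside [start, end) receive zero weight, overwriting the values stored there leaves the output unchanged, which is a convenient sanity check for any ragged kernel.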

ejkernel.kernels._xla.ragged_decode_attention._xla_impl_fwd.ragged_flash_attention_xla(query: Float[jaxlib._jax.Array, 'batch q_len num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch kv_len num_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch kv_len num_heads head_dim'], sequence_start: Int[jaxlib._jax.Array, 'batch'], sequence_end: Int[jaxlib._jax.Array, 'batch'], softmax_scale: float | None = None, block_size: int = 256, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, '...'] | None = None) → Float[jaxlib._jax.Array, 'batch q_len num_heads head_dim']

Enhanced XLA-compatible ragged flash attention with sliding window, soft cap, and sinks.

Parameters
  • query – Query tensor [B, Q, H, D]

  • key – Key tensor [B, K, H, D]

  • value – Value tensor [B, K, H, D]

  • sequence_start – Start indices for each sequence

  • sequence_end – End indices for each sequence

  • softmax_scale – Optional scaling factor for attention

  • block_size – Size of blocks for chunked computation

  • sliding_window – Optional (left, right) window for local attention

  • logits_soft_cap – Optional soft capping for logits

  • softmax_aux – Optional attention sink biases [H, num_sinks] or [num_sinks]

Returns

Attention output [B, Q, H, D]
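The role of block_size can be illustrated with a sketch of chunked computation: keys and values are padded to a multiple of the block size, split into blocks, and consumed by jax.lax.scan while a running maximum and normalizer are carried along. This is a single-sequence, single-head illustration under stated assumptions (valid keys in [start, end), scale head_dim ** -0.5, no sliding window, soft cap, or sinks); blockwise_attention_sketch is a hypothetical name, and the source does not confirm that the library is structured this way.

```python
import jax
import jax.numpy as jnp

def blockwise_attention_sketch(q, k, v, start, end, block_size):
    """Hypothetical sketch. q: [Q, D]; k, v: [K, D]; attends keys in [start, end)."""
    K, D = k.shape
    pad = (-K) % block_size                      # pad K up to a block multiple
    k = jnp.pad(k, ((0, pad), (0, 0)))
    v = jnp.pad(v, ((0, pad), (0, 0)))
    num_blocks = (K + pad) // block_size
    k_blocks = k.reshape(num_blocks, block_size, D)
    v_blocks = v.reshape(num_blocks, block_size, D)
    offsets = jnp.arange(num_blocks) * block_size
    scale = D ** -0.5                            # assumed default scale

    def step(carry, block):
        out, m_prev, l_prev = carry
        k_blk, v_blk, offset = block
        pos = offset + jnp.arange(block_size)    # absolute key positions
        valid = (pos >= start) & (pos < end)     # also hides the padding
        s = jnp.where(valid[None, :], (q @ k_blk.T) * scale, -jnp.inf)
        m_new = jnp.maximum(m_prev, s.max(axis=-1))
        m_safe = jnp.where(jnp.isfinite(m_new), m_new, 0.0)
        p = jnp.exp(s - m_safe[:, None])
        corr = jnp.exp(m_prev - m_safe)
        new_carry = (out * corr[:, None] + p @ v_blk, m_new,
                     l_prev * corr + p.sum(axis=-1))
        return new_carry, None

    init = (jnp.zeros_like(q),
            jnp.full(q.shape[:1], -jnp.inf, dtype=q.dtype),
            jnp.zeros(q.shape[:1], dtype=q.dtype))
    (out, _, l), _ = jax.lax.scan(step, init, (k_blocks, v_blocks, offsets))
    return out / l[:, None]

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(1), 3)
q = jax.random.normal(k1, (3, 4))
k = jax.random.normal(k2, (10, 4))
v = jax.random.normal(k3, (10, 4))
out = blockwise_attention_sketch(q, k, v, start=2, end=9, block_size=4)
```

Here 10 keys are padded to 12 and processed as three blocks of 4; the padded positions fall outside [start, end) and therefore contribute nothing.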