ejkernel.modules.operations.scaled_dot_product_attention

Scaled Dot Product Attention module with automatic optimization.

This module implements the standard scaled dot-product attention mechanism, which is the fundamental building block of transformer architectures. It computes:

Attention(Q,K,V) = softmax((Q @ K^T) / sqrt(d_k)) @ V

where Q, K, V are the query, key, and value matrices, and d_k is the key dimension.
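
As a reference point, the formula above can be written directly in JAX. This is a minimal sketch, not ejkernel's implementation: the name reference_attention is illustrative, and it assumes the [batch, seq, heads, head_dim] layout with the same number of query and key/value heads.

```python
import math

import jax
import jax.numpy as jnp


def reference_attention(query, key, value, softmax_scale=None):
    """Plain softmax(Q K^T / sqrt(d_k)) V for [batch, seq, heads, head_dim] inputs."""
    head_dim = query.shape[-1]
    scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
    # logits[b, h, q, k] = <query[b, q, h, :], key[b, k, h, :]> * scale
    logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) * scale
    weights = jax.nn.softmax(logits, axis=-1)
    # Weighted sum of values, returned in the same layout as the query.
    return jnp.einsum("bhqk,bkhd->bqhd", weights, value)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (2, 8, 4, 16))
key = jax.random.normal(kk, (2, 8, 4, 16))
value = jax.random.normal(kv, (2, 8, 4, 16))
out = reference_attention(query, key, value)
```

With head_dim = 16 the default scale is 1/sqrt(16) = 0.25, and the output has the same shape as the query.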

This implementation provides:
  • Automatic platform selection (XLA, Triton, Pallas, CUDA)

  • Support for various attention patterns (causal, sliding window)

  • Variable-length sequence handling

  • Distributed execution via shard_map

  • Attention biasing and masking

  • Numerical stability through soft capping
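
The causal and sliding-window patterns listed above can both be expressed as boolean masks over (query position, key position) pairs. The sketch below is illustrative only: build_mask is a hypothetical helper, and the window convention (a query at position i sees keys in [i - left, i + right], with an int meaning a symmetric window) is an assumption rather than something this page specifies.

```python
import jax.numpy as jnp


def build_mask(seq_len, kv_len, causal=False, sliding_window=None):
    """Boolean [seq_len, kv_len] mask; True marks keys a query may attend to."""
    q_pos = jnp.arange(seq_len)[:, None]
    k_pos = jnp.arange(kv_len)[None, :]
    mask = jnp.ones((seq_len, kv_len), dtype=bool)
    if causal:
        # A query never looks at later key positions.
        mask = mask & (k_pos <= q_pos)
    if sliding_window is not None:
        # Assumed convention: an int is treated as a symmetric (left, right) window.
        if isinstance(sliding_window, int):
            left, right = sliding_window, sliding_window
        else:
            left, right = sliding_window
        mask = mask & (k_pos >= q_pos - left) & (k_pos <= q_pos + right)
    return mask


causal_mask = build_mask(4, 4, causal=True)
local_mask = build_mask(6, 6, sliding_window=(1, 0))
```

Here causal_mask is lower-triangular, and local_mask lets each query see itself and the one key before it.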

Unlike FlashAttention, which uses tiling for memory efficiency, this implementation relies on platform-specific optimizations (e.g., XLA’s attention primitive).

class ejkernel.modules.operations.scaled_dot_product_attention.ScaledDotProductAttention

Bases: Kernel[ScaledDotProductAttentionConfig, Array]

ScaledDotProductAttention with custom optimization logic.

Supports causal masking, dropout, sliding windows, and variable-length sequences.

Features:
  • Automatic platform/backend selection (currently XLA only)

  • Configuration caching for consistent performance

  • Optional autotuning to find optimal implementation

  • Custom gradient support for efficient backpropagation

  • Support for variable-length sequences via cumulative sequence lengths

  • Sliding window attention for local attention patterns

  • Logits soft capping for numerical stability
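
This page does not spell out the soft-capping formula. A common definition, assumed here, squashes the logits with tanh so that they stay strictly inside (-cap, cap) while remaining nearly unchanged near zero. The soft_cap helper below is a hypothetical illustration, not the library's code.

```python
import jax.numpy as jnp


def soft_cap(logits, cap):
    """Smoothly bound logits to (-cap, cap); assumed form: cap * tanh(logits / cap)."""
    return cap * jnp.tanh(logits / cap)


logits = jnp.array([-1000.0, -1.0, 0.0, 1.0, 1000.0])
capped = soft_cap(logits, cap=30.0)
```

Small logits pass through almost untouched, while extreme values saturate at the cap, which keeps the softmax input bounded.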

Example

>>> from ejkernel.modules import ScaledDotProductAttention, create_default_executor
>>>
>>> executor = create_default_executor()
>>> attn = ScaledDotProductAttention()
>>>
>>> # Causal attention with an explicit softmax scale
>>> output = executor(attn, query, key, value, causal=True, softmax_scale=0.125)
>>>
>>> output = executor(
...     attn, query, key, value,...
... )
>>>
>>> # Sliding-window (local) attention
>>> output = executor(attn, query, key, value, sliding_window=(256, 256))

candidate_cfgs(inv: Invocation[ScaledDotProductAttentionConfig, Array])

Generate candidate configurations for autotuning.

This operation uses XLA primitives directly, with no tunable block sizes, so autotuning provides no benefit. It returns an empty list to skip autotuning.

Parameters

inv – Invocation object with arguments and metadata

Returns

Empty list - no candidates to autotune since XLA handles optimization

Note

XLA’s scaled_dot_product_attention primitive is not parameterized by block sizes, so there are no meaningful configurations to benchmark.

create_shard_map_wrapper(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len kv_len'] | None = None, bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len'] | None = None, cum_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch'] | None = None, cum_seqlens_k: jaxtyping.Int[jaxlib._jax.Array, 'batch'] | None = None, *, mesh: Mesh, in_specs: tuple[jax.sharding.PartitionSpec, ...], out_specs: PartitionSpec, check_vma: bool = False, cfg: ScaledDotProductAttentionConfig, init_bias: Optional[Callable[[], Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len']]] = None, softmax_scale: float | None = None, causal: bool = False, sliding_window: int | tuple[int, int] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None)

Create a shard_map wrapper for distributed ScaledDotProductAttention execution.

Enables efficient distributed execution of attention across multiple devices using JAX’s shard_map functionality. This is particularly useful for model parallelism and handling very large attention computations.

Parameters
  • query – Query tensor [batch, seq_len, num_q_heads, head_dim]

  • key – Key tensor [batch, kv_len, num_kv_heads, head_dim]

  • value – Value tensor [batch, kv_len, num_kv_heads, head_dim]

  • attention_mask – Optional attention mask [batch, num_heads_or_1, seq_len, kv_len]

  • bias – Optional attention bias [batch, num_heads, seq_len, kv_len]

  • cum_seqlens_q – Cumulative sequence lengths for queries [batch]

  • cum_seqlens_k – Cumulative sequence lengths for keys [batch]

  • mesh – JAX mesh defining device topology for distributed execution

  • in_specs – Partition specifications for each input tensor

  • out_specs – Partition specification for output tensor

  • check_vma – Whether shard_map should check varying-manual-axes (VMA) annotations; forwarded to shard_map

  • cfg – Configuration object specifying platform/backend

  • init_bias – Optional callable to initialize bias on-device

  • softmax_scale – Scaling factor for attention scores

  • causal – Whether to apply causal masking

  • sliding_window – Window size for local attention (int or (left, right) tuple)

  • platform – Optional platform override

Returns

Tuple of (shard_map function, call args), where
  • shard_map function: Callable for distributed execution

  • call args: Tuple of arguments to pass to the shard_map function

Note

The shard_map wrapper handles device placement and communication automatically based on the provided mesh and partition specs.
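
To make the roles of mesh, in_specs, and out_specs concrete, the sketch below calls JAX's shard_map directly on a plain attention function, sharding only the batch axis. It does not use create_shard_map_wrapper. The local_attention function and the "dp" axis name are illustrative assumptions, and the import fallback covers JAX releases that only ship shard_map under jax.experimental.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases export shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map


def local_attention(q, k, v):
    """Plain attention on the per-device shard; a stand-in for the real kernel."""
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, axis=-1), v)


devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("dp",))
spec = P("dp", None, None, None)  # shard the batch axis, replicate the rest

sharded_attention = shard_map(
    local_attention, mesh=mesh, in_specs=(spec, spec, spec), out_specs=spec
)

batch = 2 * devices.size  # the batch must divide evenly across the mesh axis
q = k = v = jnp.ones((batch, 8, 4, 16))
out = sharded_attention(q, k, v)
```

Because attention is independent across batch elements, sharding the batch axis needs no cross-device communication. Sharding heads or sequence positions would need different partition specs.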

get_impl(cfg: ScaledDotProductAttentionConfig)

Get kernel implementation from registry based on configuration.

Parameters

cfg – Configuration specifying platform and backend

Returns

Callable kernel implementation

Raises

ValueError – If no matching implementation is found
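
The lookup-or-raise behavior described above follows a common registry pattern. The sketch below is a generic illustration under stated assumptions: _REGISTRY, register, and lookup are hypothetical names, and the (platform, backend) key is a guess at what the configuration carries, not ejkernel's actual registry.

```python
from typing import Callable

# Hypothetical registry keyed by (platform, backend); not ejkernel's data structure.
_REGISTRY: dict[tuple[str, str], Callable] = {}


def register(platform: str, backend: str):
    """Decorator that records an implementation under (platform, backend)."""
    def decorator(fn: Callable) -> Callable:
        _REGISTRY[(platform, backend)] = fn
        return fn
    return decorator


def lookup(platform: str, backend: str) -> Callable:
    """Return the registered implementation, or raise ValueError if none matches."""
    try:
        return _REGISTRY[(platform, backend)]
    except KeyError:
        raise ValueError(
            f"no implementation for platform={platform!r}, backend={backend!r}"
        ) from None


@register("xla", "any")
def xla_attention(query, key, value):
    return "xla result"  # placeholder body for illustration only


impl = lookup("xla", "any")
```

Raising ValueError at lookup time surfaces an unsupported platform/backend pair before any kernel is launched.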

heuristic_cfg(inv: Invocation[ScaledDotProductAttentionConfig, Array]) → ScaledDotProductAttentionConfig

Provide default configuration based on invocation context.

Parameters

inv – Invocation object with arguments and metadata

Returns

Default configuration for platform/backend selection

run(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len kv_len'] | None = None, bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len'] | None = None, init_bias: Optional[Callable[[], Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len']]] = None, softmax_scale: float | None = None, causal: bool = False, sliding_window: int | tuple[int, int] | None = None, cum_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch'] | None = None, cum_seqlens_k: jaxtyping.Int[jaxlib._jax.Array, 'batch'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: ScaledDotProductAttentionConfig) → Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim']

Execute scaled_dot_product_attention with the given configuration.

Parameters
  • query – Query tensor [batch, seq_len, num_q_heads, head_dim]

  • key – Key tensor [batch, kv_len, num_kv_heads, head_dim]

  • value – Value tensor [batch, kv_len, num_kv_heads, head_dim]

  • attention_mask – Optional attention mask [batch, num_heads_or_1, seq_len, kv_len] (legacy, prefer bias)

  • bias – Optional attention bias tensor [batch, num_heads, seq_len, kv_len]

  • softmax_scale – Scaling factor for attention scores

  • dropout_prob – Dropout probability for attention weights

  • causal – Whether to apply causal masking

  • dropout_seed – Random seed for dropout

  • cum_seqlens_q – Cumulative sequence lengths for variable-length queries

  • cum_seqlens_k – Cumulative sequence lengths for variable-length keys

  • sliding_window – Window size for local attention

  • logits_soft_cap – Optional soft cap value for logits

  • softmax_aux – Optional attention sink logits

  • cfg – Configuration object specifying platform/backend

  • segment_ids – Segment IDs for grouped sequences (TPU-specific)

  • block_sizes – Block sizes for kernel execution (TPU-specific)

  • init_bias – Optional callable to initialize bias on-device

  • platform – Optional platform override

Note

The signature above does not accept dropout_prob, dropout_seed, logits_soft_cap, softmax_aux, segment_ids, or block_sizes, although they are listed here.

Returns

ScaledDotProductAttention output [batch, seq_len_q, num_heads, head_dim]

ejkernel.modules.operations.scaled_dot_product_attention.scaled_dot_product_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len'] | None = None, cum_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch'] | None = None, cum_seqlens_k: jaxtyping.Int[jaxlib._jax.Array, 'batch'] | None = None, /, *, mask_info: ejkernel.types.mask.MaskInfo | None = None, init_bias: Optional[Callable[[], Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len']]] = None, softmax_scale: float | None = None, causal: bool = False, sliding_window: int | tuple[int, int] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None) → Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim']

Execute scaled dot product attention with automatic optimization.

Convenience function that uses a default executor and a ScaledDotProductAttention module.

Parameters
  • query – Query tensor [batch, seq_len, num_q_heads, head_dim]

  • key – Key tensor [batch, kv_len, num_kv_heads, head_dim]

  • value – Value tensor [batch, kv_len, num_kv_heads, head_dim]

  • mask_info – Optional MaskInfo containing attention mask and/or segment IDs

  • bias – Optional attention bias tensor

  • softmax_scale – Scaling factor for attention scores (default: 1/sqrt(head_dim))

  • dropout_prob – Dropout probability for attention weights

  • causal – Whether to apply causal masking

  • dropout_seed – Random seed for dropout

  • cum_seqlens_q – Cumulative sequence lengths for variable-length queries

  • cum_seqlens_k – Cumulative sequence lengths for variable-length keys

  • sliding_window – Window size for local attention (int or (left, right) tuple)

  • logits_soft_cap – Optional soft cap value for logits

  • mesh – JAX device mesh for shard_map execution (optional)

  • in_specs – Input partition specs for shard_map (optional)

  • out_specs – Output partition spec for shard_map (optional)

  • init_bias – Optional callable to initialize bias on-device

  • platform – Optional platform override

Note

The signature above does not accept dropout_prob, dropout_seed, or logits_soft_cap, although they are listed here. The arguments query, key, value, bias, cum_seqlens_q, and cum_seqlens_k are positional-only.

Returns

Attention output with the same shape as query

Example

>>> # Causal attention
>>> out = scaled_dot_product_attention(query, key, value, causal=True)
>>>
>>> # Explicit softmax scale
>>> out = scaled_dot_product_attention(query, key, value, softmax_scale=0.125)
>>>
>>> # Variable-length sequences: bias and cum_seqlens_* are positional-only
>>> out = scaled_dot_product_attention(query, key, value, None, cu_q, cu_k)
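
The cu_q and cu_k arrays in the last call hold cumulative sequence lengths. The sketch below shows one way such an array could be built from per-sequence lengths, and how it determines which tokens belong together. The exact convention the kernel expects (for example, whether a leading 0 is included) is not stated on this page, so treat the layout here as an assumption.

```python
import jax.numpy as jnp

# Three packed sequences of lengths 3, 5, and 2 (10 tokens in total).
seq_lens = jnp.array([3, 5, 2], dtype=jnp.int32)

# Running totals mark where each sequence ends along the packed token axis.
# Assumption: some conventions also prepend a leading 0; check what the kernel expects.
cum_seqlens = jnp.cumsum(seq_lens)

# Segment id per token: the number of sequence ends at or before its position.
positions = jnp.arange(int(seq_lens.sum()))
segment_ids = jnp.searchsorted(cum_seqlens, positions, side="right")

# Tokens may attend only to tokens of the same sequence.
same_segment = segment_ids[:, None] == segment_ids[None, :]
```

The resulting same_segment mask is block-diagonal, with one block per packed sequence.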