ejkernel.modules.operations.scaled_dot_product_attention
Scaled Dot Product Attention module with automatic optimization.
This module implements the standard scaled dot-product attention mechanism, which is the fundamental building block of transformer architectures. It computes:
Attention(Q,K,V) = softmax((Q @ K^T) / sqrt(d_k)) @ V
where Q, K, V are the query, key, and value matrices, and d_k is the key dimension.
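As a point of reference, the formula above can be written directly in JAX. This is a minimal sketch, not the kernel that ejkernel dispatches to: `reference_attention` is a hypothetical helper, and it assumes the [batch, seq_len, num_heads, head_dim] layout used in the signatures on this page, with equal query and key/value head counts.

```python
import jax
import jax.numpy as jnp


def reference_attention(query, key, value, softmax_scale=None):
    """softmax((Q @ K^T) * scale) @ V for [batch, seq, heads, head_dim] inputs."""
    head_dim = query.shape[-1]
    # Default scale is 1 / sqrt(d_k), as in the formula above.
    scale = head_dim ** -0.5 if softmax_scale is None else softmax_scale
    # logits[b, h, q, k] = scale * <query[b, q, h, :], key[b, k, h, :]>
    logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) * scale
    # Normalize over the key axis so each query's weights sum to 1.
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, value)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (2, 8, 4, 16))
key = jax.random.normal(kk, (2, 10, 4, 16))
value = jax.random.normal(kv, (2, 10, 4, 16))
out = reference_attention(query, key, value)
```

The output keeps the query's shape: one `head_dim` vector per query position and head.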
- This implementation provides:
Automatic platform selection (XLA, Triton, Pallas, CUDA)
Support for various attention patterns (causal, sliding window)
Variable-length sequence handling
Distributed execution via shard_map
Attention biasing and masking
Numerical stability through soft capping
Unlike FlashAttention, which uses tiling for memory efficiency, this implementation relies on platform-specific optimizations (e.g., XLA’s attention primitive).
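The causal and sliding-window patterns listed above can be pictured as boolean masks over the [seq_len, kv_len] logits. The sketch below is illustrative only. The helper names are hypothetical, the causal mask is assumed to be top-left aligned, and a `(left, right)` window is assumed to mean that query `i` sees keys `i - left` through `i + right`; ejkernel's exact conventions are not stated on this page.

```python
import jax.numpy as jnp


def causal_mask(q_len, kv_len):
    """True where query i may attend to key j, i.e. j <= i."""
    q_idx = jnp.arange(q_len)[:, None]
    k_idx = jnp.arange(kv_len)[None, :]
    return k_idx <= q_idx


def sliding_window_mask(q_len, kv_len, window):
    """True where key j lies within [i - left, i + right] of query i."""
    # An int is treated as a symmetric window (assumption).
    left, right = (window, window) if isinstance(window, int) else window
    q_idx = jnp.arange(q_len)[:, None]
    k_idx = jnp.arange(kv_len)[None, :]
    return (k_idx >= q_idx - left) & (k_idx <= q_idx + right)


def apply_mask(logits, mask):
    """Replace masked-out logits with the most negative finite value."""
    return jnp.where(mask, logits, jnp.finfo(logits.dtype).min)


# Causal attention restricted to the current and previous position.
mask = causal_mask(4, 4) & sliding_window_mask(4, 4, (1, 0))
masked_logits = apply_mask(jnp.zeros((4, 4)), mask)
```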
- class ejkernel.modules.operations.scaled_dot_product_attention.ScaledDotProductAttention
Bases: Kernel[ScaledDotProductAttentionConfig, Array]
ScaledDotProductAttention with custom optimization logic.
Supports causal masking, dropout, sliding windows, and variable-length sequences.
- Features:
Automatic platform/backend selection (currently XLA only)
Configuration caching for consistent performance
Optional autotuning to find optimal implementation
Custom gradient support for efficient backpropagation
Support for variable-length sequences via cumulative sequence lengths
Sliding window attention for local attention patterns
Logits soft capping for numerical stability
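Logits soft capping is commonly implemented by squashing the logits through a scaled tanh before the softmax. The sketch below assumes that formulation, `cap * tanh(logits / cap)`. Whether ejkernel uses exactly this form is not stated here, and `soft_cap_logits` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def soft_cap_logits(logits, cap):
    """Smoothly bound logits to [-cap, cap]; near zero this is close to identity."""
    return cap * jnp.tanh(logits / cap)


logits = jnp.array([-500.0, -1.0, 0.0, 1.0, 500.0])
capped = soft_cap_logits(logits, cap=30.0)
# The softmax now sees bounded inputs, so extreme logits cannot dominate without limit.
weights = jax.nn.softmax(capped)
```

Small logits pass through almost unchanged, while very large ones saturate at the cap.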
Example
>>> from ejkernel.modules import ScaledDotProductAttention, create_default_executor
>>>
>>> # Create an executor and the attention module
>>> executor = create_default_executor()
>>> attn = ScaledDotProductAttention()
>>>
>>> # Causal attention with an explicit softmax scale
>>> output = executor(attn, query, key, value, causal=True, softmax_scale=0.125)
>>>
>>> output = executor(
...     attn, query, key, value,...
... )
>>>
>>> # Sliding-window (local) attention
>>> output = executor(attn, query, key, value, sliding_window=(256, 256))
- candidate_cfgs(inv: Invocation[ScaledDotProductAttentionConfig, Array])
Generate candidate configurations for autotuning.
This operation uses XLA primitives directly, with no tunable block sizes, so autotuning provides no benefit. It returns an empty list to skip autotuning.
- Parameters
inv – Invocation object with arguments and metadata
- Returns
Empty list; there are no candidates to autotune because XLA handles optimization
Note
XLA’s scaled_dot_product_attention primitive is not parameterized by block sizes, so there are no meaningful configurations to benchmark.
- create_shard_map_wrapper(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len kv_len'] | None = None, bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len'] | None = None, cum_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch'] | None = None, cum_seqlens_k: jaxtyping.Int[jaxlib._jax.Array, 'batch'] | None = None, *, mesh: Mesh, in_specs: tuple[jax.sharding.PartitionSpec, ...], out_specs: PartitionSpec, check_vma: bool = False, cfg: ScaledDotProductAttentionConfig, init_bias: Optional[Callable[[], Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len']]] = None, softmax_scale: float | None = None, causal: bool = False, sliding_window: int | tuple[int, int] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None)
Create a shard_map wrapper for distributed ScaledDotProductAttention execution.
Enables efficient distributed execution of attention across multiple devices using JAX’s shard_map functionality. This is particularly useful for model parallelism and handling very large attention computations.
- Parameters
query – Query tensor [batch, seq_len, num_q_heads, head_dim]
key – Key tensor [batch, kv_len, num_kv_heads, head_dim]
value – Value tensor [batch, kv_len, num_kv_heads, head_dim]
attention_mask – Optional boolean attention mask [batch, num_heads_or_1, seq_len, kv_len]
bias – Optional attention bias [batch, num_heads, seq_len, kv_len]
cum_seqlens_q – Cumulative sequence lengths for queries [batch]
cum_seqlens_k – Cumulative sequence lengths for keys [batch]
mesh – JAX mesh defining device topology for distributed execution
in_specs – Partition specifications for each input tensor
out_specs – Partition specification for output tensor
check_vma – Whether shard_map should run its varying-manual-axes (VMA) validity checks on the wrapped function
cfg – Configuration object specifying platform/backend
init_bias – Optional callable to initialize bias on-device
softmax_scale – Scaling factor for attention scores
causal – Whether to apply causal masking
sliding_window – Window size for local attention
platform – Optional platform override
- Returns
Tuple of (shard_map function, call args), where:
shard_map function – Callable for distributed execution
call args – Tuple of arguments to pass to the shard_map function
Note
The shard_map wrapper handles device placement and communication automatically based on the provided mesh and partition specs.
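To show how `mesh`, `in_specs`, and `out_specs` fit together, the sketch below calls JAX's `shard_map` directly on a plain attention function and shards only the batch axis. It does not call `create_shard_map_wrapper`; `local_attention` and the `"dp"` axis name are hypothetical, and the import fallback covers JAX releases that still keep `shard_map` under `jax.experimental`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases export shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map


def local_attention(query, key, value):
    """Plain attention on the batch shard held by one device."""
    scale = query.shape[-1] ** -0.5
    logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) * scale
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, value)


devices = np.array(jax.devices())
mesh = Mesh(devices, ("dp",))  # one-dimensional data-parallel mesh

# Shard only the batch axis; every other axis stays whole on each device.
batch_spec = P("dp", None, None, None)
sharded_attention = shard_map(
    local_attention,
    mesh=mesh,
    in_specs=(batch_spec, batch_spec, batch_spec),
    out_specs=batch_spec,
)

batch = 2 * devices.size  # the batch must divide evenly across devices
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (batch, 8, 4, 16))
key = jax.random.normal(kk, (batch, 10, 4, 16))
value = jax.random.normal(kv, (batch, 10, 4, 16))
out = sharded_attention(query, key, value)
```

Because attention is independent across batch entries, sharding the batch axis needs no cross-device communication, and the result matches the unsharded computation.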
- get_impl(cfg: ScaledDotProductAttentionConfig)
Get kernel implementation from registry based on configuration.
- Parameters
cfg – Configuration specifying platform and backend
- Returns
Callable kernel implementation
- Raises
ValueError – If no matching implementation is found
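The lookup-or-raise behavior described above can be pictured with a small dictionary-based registry. This is a hypothetical sketch, not ejkernel's registry: the `register` decorator, the `(algorithm, platform)` key, and the standalone `get_impl` function below are all assumptions made for illustration.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry keyed by (algorithm name, platform name).
_REGISTRY: Dict[Tuple[str, str], Callable] = {}


def register(algorithm: str, platform: str):
    """Decorator that records an implementation under (algorithm, platform)."""
    def decorator(fn: Callable) -> Callable:
        _REGISTRY[(algorithm, platform)] = fn
        return fn
    return decorator


def get_impl(algorithm: str, platform: str) -> Callable:
    """Return the registered implementation or raise ValueError."""
    try:
        return _REGISTRY[(algorithm, platform)]
    except KeyError:
        raise ValueError(
            f"No implementation of {algorithm!r} for platform {platform!r}"
        ) from None


@register("scaled_dot_product_attention", "xla")
def sdpa_xla(query, key, value):
    # Placeholder body; a real kernel would compute attention here.
    return "xla"


impl = get_impl("scaled_dot_product_attention", "xla")
```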
- heuristic_cfg(inv: Invocation[ScaledDotProductAttentionConfig, Array]) -> ScaledDotProductAttentionConfig
Provide default configuration based on invocation context.
- Parameters
inv – Invocation object with arguments and metadata
- Returns
Default configuration for platform/backend selection
- run(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len kv_len'] | None = None, bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len'] | None = None, init_bias: Optional[Callable[[], Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len']]] = None, softmax_scale: float | None = None, causal: bool = False, sliding_window: int | tuple[int, int] | None = None, cum_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch'] | None = None, cum_seqlens_k: jaxtyping.Int[jaxlib._jax.Array, 'batch'] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: ScaledDotProductAttentionConfig) -> Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim']
Execute scaled_dot_product_attention with the given configuration.
- Parameters
query – Query tensor [batch, seq_len_q, num_heads, head_dim]
key – Key tensor [batch, seq_len_k, num_heads, head_dim]
value – Value tensor [batch, seq_len_k, num_heads, head_dim]
attention_mask – Optional attention mask (legacy, prefer bias)
bias – Optional attention bias tensor
init_bias – Optional callable to initialize bias on-device
softmax_scale – Scaling factor for attention scores
causal – Whether to apply causal masking
sliding_window – Window size for local attention
cum_seqlens_q – Cumulative sequence lengths for variable-length queries
cum_seqlens_k – Cumulative sequence lengths for variable-length keys
platform – Optional platform override
cfg – Configuration object specifying platform/backend
The docstring also lists the following names, which do not appear in the signature above:
dropout_prob – Dropout probability for attention weights
dropout_seed – Random seed for dropout
logits_soft_cap – Optional soft cap value for logits
softmax_aux – Optional attention sink logits
segment_ids – Segment IDs for grouped sequences (TPU-specific)
block_sizes – Block sizes for kernel execution (TPU-specific)
- Returns
ScaledDotProductAttention output [batch, seq_len_q, num_heads, head_dim]
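Since `attention_mask` is marked legacy in favor of `bias`, a boolean mask can be expressed as an additive bias. The sketch below assumes the bias is added to the scaled logits before the softmax, which is the usual convention but is not confirmed on this page; `mask_to_bias` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def mask_to_bias(attention_mask, dtype=jnp.float32):
    """Map True (attend) to 0 and False (masked) to the most negative finite value."""
    return jnp.where(attention_mask, 0.0, jnp.finfo(dtype).min).astype(dtype)


# Shape [batch=1, heads=1, seq_len=1, kv_len=3]; the last key is masked out.
mask = jnp.array([[[[True, True, False]]]])
bias = mask_to_bias(mask)

logits = jnp.zeros((1, 1, 1, 3))
weights = jax.nn.softmax(logits + bias, axis=-1)
```

The masked key receives zero weight, and the remaining weight is shared among the visible keys.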
- ejkernel.modules.operations.scaled_dot_product_attention.scaled_dot_product_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len'] | None = None, cum_seqlens_q: jaxtyping.Int[jaxlib._jax.Array, 'batch'] | None = None, cum_seqlens_k: jaxtyping.Int[jaxlib._jax.Array, 'batch'] | None = None, /, *, mask_info: ejkernel.types.mask.MaskInfo | None = None, init_bias: Optional[Callable[[], Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len']]] = None, softmax_scale: float | None = None, causal: bool = False, sliding_window: int | tuple[int, int] | None = None, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, mesh: jax._src.mesh.Mesh | None = None, in_specs: tuple[jax.sharding.PartitionSpec, ...] | None = None, out_specs: jax.sharding.PartitionSpec | None = None) -> Float[jaxlib._jax.Array, 'batch seq_len_q num_heads head_dim']
Execute scaled dot product attention with automatic optimization.
Convenience function that uses a default executor and the ScaledDotProductAttention module.
- Parameters
query – Query tensor [batch, seq_len, num_heads, head_dim]
key – Key tensor [batch, seq_len_k, num_heads, head_dim]
value – Value tensor [batch, seq_len_k, num_heads, head_dim]
mask_info – Optional MaskInfo containing attention mask and/or segment IDs
bias – Optional attention bias tensor
init_bias – Optional callable to initialize bias on-device
softmax_scale – Scaling factor for attention scores (default: 1/sqrt(head_dim))
causal – Whether to apply causal masking
cum_seqlens_q – Cumulative sequence lengths for variable-length queries
cum_seqlens_k – Cumulative sequence lengths for variable-length keys
sliding_window – Window size for local attention (int or (left, right) tuple)
platform – Optional platform override
mesh – JAX device mesh for shard_map execution (optional)
in_specs – Input partition specs for shard_map (optional)
out_specs – Output partition spec for shard_map (optional)
The docstring also lists the following names, which do not appear in the signature above:
dropout_prob – Dropout probability for attention weights
dropout_seed – Random seed for dropout
logits_soft_cap – Optional soft cap value for logits
- Returns
Attention output with the same shape as query
Example
>>> # Causal attention
>>> out = scaled_dot_product_attention(query, key, value, causal=True)
>>>
>>> # Custom softmax scale
>>> out = scaled_dot_product_attention(query, key, value, softmax_scale=0.125)
>>>
>>> # Variable-length sequences; bias, cum_seqlens_q and cum_seqlens_k are positional-only
>>> out = scaled_dot_product_attention(query, key, value, None, cu_q, cu_k)
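The `cu_q` and `cu_k` arrays in the example are cumulative sequence lengths. The sketch below shows one way such an array could be built from per-sequence lengths and mapped back to per-token segment IDs. It assumes the FlashAttention-style layout with a leading zero (length batch + 1); the layout ejkernel expects is not spelled out here, and both helper names are hypothetical.

```python
import jax.numpy as jnp


def cumulative_seqlens(lengths):
    """[len_0, len_1, ...] -> [0, len_0, len_0 + len_1, ...] as int32."""
    lengths = jnp.asarray(lengths, dtype=jnp.int32)
    return jnp.concatenate([jnp.zeros((1,), jnp.int32), jnp.cumsum(lengths)])


def segment_ids_from_cum_seqlens(cum_seqlens, total_tokens):
    """Assign each packed token position the index of the sequence it belongs to."""
    positions = jnp.arange(total_tokens)
    # side="right" counts the boundaries <= position; subtract 1 for a 0-based ID.
    return jnp.searchsorted(cum_seqlens, positions, side="right") - 1


# Three sequences of lengths 3, 2 and 4 packed into 9 tokens.
cu_q = cumulative_seqlens([3, 2, 4])
segment_ids = segment_ids_from_cum_seqlens(cu_q, total_tokens=9)
```

Tokens with the same segment ID belong to the same original sequence, so attention between different IDs would be masked out.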