ejkernel.modules.operations.attention
Standard multi-head attention module with automatic optimization.
This module implements standard multi-head attention (MHA) with XLA-optimized kernels. It provides a flexible interface supporting various attention patterns including causal masking, dropout, sliding windows, and variable-length sequences.
Unlike FlashAttention which uses tiling for memory efficiency, this implementation leverages XLA’s compiler optimizations for straightforward attention computation.
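As a point of reference, the untiled computation described above can be sketched in plain JAX. This is a minimal illustration, not the library's kernel: the function name `reference_attention` is hypothetical, and the sketch assumes equal query and key/value head counts and a causal mask aligned at position 0.

```python
import jax
import jax.numpy as jnp


def reference_attention(query, key, value, causal=False, softmax_scale=None):
    """Plain multi-head attention that materializes the full score matrix.

    Hypothetical helper for illustration only; not part of ejkernel.
    Shapes: query [batch, q_len, heads, dim], key/value [batch, kv_len, heads, dim].
    """
    head_dim = query.shape[-1]
    # Default scale 1/sqrt(head_dim), as documented for `softmax_scale`.
    scale = softmax_scale if softmax_scale is not None else head_dim ** -0.5
    # Scores for every (query, key) pair: [batch, heads, q_len, kv_len].
    logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) * scale
    if causal:
        q_len, kv_len = logits.shape[-2:]
        # Query i may only attend to keys j <= i.
        allowed = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))
        logits = jnp.where(allowed, logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(logits, axis=-1)
    output = jnp.einsum("bhqk,bkhd->bqhd", weights, value)
    return output, weights


rng_q, rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(rng_q, (2, 8, 4, 16))
key = jax.random.normal(rng_k, (2, 8, 4, 16))
value = jax.random.normal(rng_v, (2, 8, 4, 16))
output, weights = reference_attention(query, key, value, causal=True)
```

Because the full `[q_len, kv_len]` score matrix is held in memory, this formulation trades memory for simplicity and leaves fusion decisions to the compiler.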
- class ejkernel.modules.operations.attention.Attention
Bases: Kernel[AttentionConfig, tuple[Array, Array]]
Attention with custom optimization logic.
Supports causal masking, dropout, sliding windows, and variable-length sequences.
- Features:
Automatic platform/backend selection (XLA only)
Configuration caching for consistent performance
Optional autotuning to find optimal implementation
Custom gradient support for efficient backpropagation
Support for variable-length sequences via cumulative sequence lengths
Sliding window attention for local attention patterns
Logits soft capping for numerical stability
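To illustrate the last feature in the list, logits soft capping is commonly written as `cap * tanh(logits / cap)`. The sketch below assumes that formulation; the page does not state the exact formula the kernel applies for `logits_soft_cap`, and the helper name `soft_cap` is hypothetical.

```python
import jax.numpy as jnp


def soft_cap(logits, cap):
    """Smoothly squash logits into [-cap, cap] (assumed tanh formulation)."""
    return cap * jnp.tanh(logits / cap)


logits = jnp.array([-100.0, -1.0, 0.0, 1.0, 100.0])
capped = soft_cap(logits, cap=30.0)
```

Small logits pass through almost unchanged, while large ones saturate near the cap, which keeps the softmax input bounded.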
Example
>>> from ejkernel.modules import Attention, create_default_executor
>>>
>>> executor = create_default_executor()
>>> attn = Attention()
>>>
>>> output = executor(attn, query, key, value, causal=True, softmax_scale=0.125)
>>>
>>> output = executor(
...     attn, query, key, value,...
... )
>>>
>>> output = executor(attn, query, key, value, sliding_window=(256, 256))
- candidate_cfgs(inv: Invocation[AttentionConfig, Array])
Generate candidate configurations for autotuning.
This operation uses XLA primitives directly, with no tunable block sizes, so autotuning provides no benefit. It returns an empty list to skip autotuning.
- Parameters
inv – Invocation object with arguments and metadata
- Returns
Empty list - no candidates to autotune since XLA handles optimization
Note
XLA’s attention primitive is not parameterized by block sizes, so there are no meaningful configurations to benchmark.
- get_impl(cfg: AttentionConfig)
Get kernel implementation from registry based on configuration.
- Parameters
cfg – Configuration specifying platform and backend
- Returns
Callable kernel implementation
- Raises
ValueError – If no matching implementation is found
- heuristic_cfg(inv: Invocation[AttentionConfig, Array]) -> AttentionConfig
Provide default configuration based on invocation context.
Selects optimal block sizes based on sequence length and head dimension.
- Parameters
inv – Invocation object with arguments and metadata
- Returns
Default configuration with block sizes
- run(query: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: jaxtyping.Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], value: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads vhead_dim'], attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len kv_len'] | None = None, bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len'] | None = None, init_bias: typing.Optional[typing.Callable[[], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len']]] = None, deterministic: bool = True, dropout_rng: typing.Optional[typing.Union[jaxtyping.Key[jaxlib._jax.Array, ''], jaxtyping.UInt32[jaxlib._jax.Array, '2']]] = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_heads num_sinks'] | jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, softmax_scale: float | None = None, logits_soft_cap: float | None = None, dtype: typing.Optional[typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType]] = <class 'jax.numpy.bfloat16'>, softmax_dtype: typing.Optional[typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType]] = None, dropout_prob: float = 0.0, causal: bool = False, sliding_window: int | tuple[int, int] | None = None, *, cfg: ejkernel.modules.operations.configs.AttentionConfig) -> tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len']]
Execute standard attention with the given configuration.
- Parameters
query – Query tensor [batch, seq_len_q, num_heads, head_dim]
key – Key tensor [batch, seq_len_k, num_heads, head_dim]
value – Value tensor [batch, seq_len_k, num_heads, head_dim]
attention_mask – Optional attention mask (legacy, prefer bias)
bias – Optional attention bias tensor
softmax_scale – Scaling factor for attention scores
dropout_prob – Dropout probability for attention weights
sliding_window – Window size for local attention
platform – Specific platform to use (“triton”, “pallas”, “cuda”, or “xla”)
cfg – Configuration object specifying platform/backend
- Returns
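Tuple of the attention output [batch, seq_len_q, num_heads, head_dim] and a second array of shape [batch, num_heads, seq_len_q, seq_len_k], as given by the return annotation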
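The `sliding_window` and `causal` parameters can be pictured as a boolean mask over (query, key) positions. The sketch below is one plausible reading, not the library's implementation: it assumes an `int` means the same window on both sides and that a `(left, right)` tuple counts positions before and after the query. The helper name `sliding_window_mask` is hypothetical.

```python
import jax.numpy as jnp


def sliding_window_mask(q_len, kv_len, sliding_window, causal=False):
    """Boolean [q_len, kv_len] mask; True marks pairs that may attend.

    Assumption: an int window is symmetric, a tuple is (left, right).
    """
    if isinstance(sliding_window, int):
        left, right = sliding_window, sliding_window
    else:
        left, right = sliding_window
    q_pos = jnp.arange(q_len)[:, None]
    kv_pos = jnp.arange(kv_len)[None, :]
    offset = kv_pos - q_pos  # negative: key lies before the query
    mask = (offset >= -left) & (offset <= right)
    if causal:
        # Additionally forbid attending to future positions.
        mask = mask & (offset <= 0)
    return mask


mask = sliding_window_mask(6, 6, (2, 1))
causal_mask = sliding_window_mask(6, 6, (2, 1), causal=True)
```

With a window of `(2, 1)`, the query at position 3 sees keys 1 through 4; adding `causal=True` removes key 4.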
- ejkernel.modules.operations.attention.attention(query: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: jaxtyping.Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], value: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads vhead_dim'], bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len'] | None = None, dropout_rng: typing.Optional[typing.Union[jaxtyping.Key[jaxlib._jax.Array, ''], jaxtyping.UInt32[jaxlib._jax.Array, '2']]] = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_heads num_sinks'] | jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, /, *, mask_info: ejkernel.types.mask.MaskInfo | None = None, init_bias: typing.Optional[typing.Callable[[], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len']]] = None, deterministic: bool = True, softmax_scale: float | None = None, logits_soft_cap: float | None = None, dtype: typing.Optional[typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType]] = <class 'jax.numpy.bfloat16'>, softmax_dtype: typing.Optional[typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType]] = None, dropout_prob: float = 0.0, causal: bool = False, sliding_window: int | tuple[int, int] | None = None) -> Float[jaxlib._jax.Array, 'batch seq_len num_q_heads vhead_dim']
Execute standard attention with automatic optimization.
Convenience function that uses a default executor and the Attention module.
- Parameters
query – Query tensor [batch, seq_len, num_heads, head_dim]
key – Key tensor [batch, seq_len_k, num_heads, head_dim]
value – Value tensor [batch, seq_len_k, num_heads, head_dim]
mask_info – Optional MaskInfo containing attention mask and/or segment IDs
bias – Optional attention bias tensor
softmax_scale – Scaling factor for attention scores (default: 1/sqrt(head_dim))
dropout_prob – Dropout probability for attention weights
sliding_window – Window size for local attention (int or (left, right) tuple)
platform – Specific platform to use (“triton”, “pallas”, “cuda”, or “xla”)
- Returns
Attention output with same shape as query
Example
>>> out = attention(query, key, value)
>>>
>>> out = attention(query, key, value, dropout_prob=0.1, softmax_scale=0.125)
>>>
>>> out = attention(query, key, value, platform="xla")
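For the variable-length case, where `mask_info` may carry segment IDs, the relationship between cumulative sequence lengths and a block-diagonal mask can be sketched in plain JAX. This does not use `MaskInfo`, whose constructor is not documented on this page; the helper names below are hypothetical and the packing layout (sequences concatenated along one axis) is an assumption.

```python
import jax.numpy as jnp


def segment_ids_from_cu_seqlens(cu_seqlens, total_len):
    """Map each packed position to the index of the sequence it belongs to."""
    positions = jnp.arange(total_len)
    # side="right" counts how many boundaries are <= position.
    return jnp.searchsorted(cu_seqlens, positions, side="right") - 1


def segment_mask(segment_ids):
    """True where query and key positions share a segment."""
    return segment_ids[:, None] == segment_ids[None, :]


# Three packed sequences of lengths 3, 2 and 4.
cu_seqlens = jnp.array([0, 3, 5, 9])
segment_ids = segment_ids_from_cu_seqlens(cu_seqlens, total_len=9)
mask = segment_mask(segment_ids)
```

The resulting mask is block-diagonal, so tokens from different packed sequences never attend to each other.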