ejkernel.modules.operations.attention

Standard multi-head attention module with automatic optimization.

This module implements standard multi-head attention (MHA) with XLA-optimized kernels. It provides a flexible interface that supports several attention patterns, including causal masking, dropout, sliding windows, and variable-length sequences.

Unlike FlashAttention, which uses tiling for memory efficiency, this implementation computes attention directly and relies on XLA’s compiler optimizations.
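The computation handed to XLA is ordinary scaled dot-product attention, softmax(Q Kᵀ · scale) V. The sketch below writes it out in plain JAX for the [batch, seq_len, num_heads, head_dim] layout used on this page, which is useful for checking results. It is not ejkernel's implementation: `reference_attention` is a hypothetical helper, it covers only the `causal` and `softmax_scale` options, and it assumes query and key positions are aligned when `causal=True`.

```python
import math

import jax
import jax.numpy as jnp


def reference_attention(query, key, value, causal=False, softmax_scale=None):
    """Scaled dot-product attention on [batch, seq_len, num_heads, head_dim] inputs.

    Hypothetical reference helper, not ejkernel's implementation. Assumes query and
    key positions start at the same index when causal=True.
    """
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(query.shape[-1])
    # Scores: [batch, num_heads, q_len, kv_len]
    scores = jnp.einsum("bqhd,bkhd->bhqk", query, key) * softmax_scale
    if causal:
        q_len, kv_len = scores.shape[-2:]
        # Query i may attend to keys 0..i (lower-triangular mask)
        allowed = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))
        scores = jnp.where(allowed, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    # Weighted sum of values: [batch, q_len, num_heads, vhead_dim]
    output = jnp.einsum("bhqk,bkhd->bqhd", weights, value)
    return output, weights


rng_q, rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(rng_q, (2, 5, 3, 4))
key = jax.random.normal(rng_k, (2, 5, 3, 4))
value = jax.random.normal(rng_v, (2, 5, 3, 6))
output, weights = reference_attention(query, key, value, causal=True)
```

The full [q_len, kv_len] score matrix is materialized here, which is exactly the memory cost that tiled kernels such as FlashAttention avoid.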

class ejkernel.modules.operations.attention.Attention

Bases: Kernel[AttentionConfig, tuple[Array, Array]]

Attention with custom optimization logic.

Supports causal masking, dropout, sliding windows, and variable-length sequences.
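Variable-length sequences are often packed end to end into one row and described by cumulative sequence lengths, e.g. [0, 2, 5] for two sequences of lengths 2 and 3. A generic way to stop packed sequences from attending to each other is a block-diagonal mask derived from per-token segment IDs. The sketch below assumes that packing convention; `cu_seqlens_to_mask` is a hypothetical name and is not necessarily how ejkernel consumes sequence lengths.

```python
import jax.numpy as jnp


def cu_seqlens_to_mask(cu_seqlens, total_len):
    """Hypothetical helper: block-diagonal [total_len, total_len] mask for packed sequences.

    cu_seqlens = [0, l0, l0 + l1, ...]; token t belongs to segment s when
    cu_seqlens[s] <= t < cu_seqlens[s + 1].
    """
    positions = jnp.arange(total_len)
    # searchsorted(side="right") - 1 maps each position to its segment index
    segment_ids = jnp.searchsorted(cu_seqlens, positions, side="right") - 1
    # Tokens may attend only within their own segment
    return segment_ids[:, None] == segment_ids[None, :]


packed_mask = cu_seqlens_to_mask(jnp.array([0, 2, 5]), 5)
```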

Features:
  • Automatic platform/backend selection (currently XLA only)

  • Configuration caching for consistent performance

  • Optional autotuning to find optimal implementation

  • Custom gradient support for efficient backpropagation

  • Support for variable-length sequences via cumulative sequence lengths

  • Sliding window attention for local attention patterns

  • Logits soft capping for numerical stability
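Logits soft capping is commonly implemented as cap · tanh(logits / cap): small logits pass through almost unchanged while large ones are smoothly bounded to (-cap, cap). The sketch below assumes that formulation for `logits_soft_cap`; this page does not confirm that ejkernel uses exactly this formula, and `soft_cap_logits` is a hypothetical helper.

```python
import jax.numpy as jnp


def soft_cap_logits(logits, logits_soft_cap=None):
    """Hypothetical helper: bound logits smoothly to (-cap, cap) via cap * tanh(x / cap)."""
    if logits_soft_cap is None:
        # No capping requested: return the logits unchanged
        return logits
    return logits_soft_cap * jnp.tanh(logits / logits_soft_cap)


scores = jnp.array([-1000.0, -5.0, 0.0, 5.0, 1000.0])
capped = soft_cap_logits(scores, logits_soft_cap=30.0)
```

Unlike a hard clip, whose gradient is zero outside the clipping range, tanh is smooth, so gradients still reach capped logits.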

Example

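>>> import jax.numpy as jnp
>>> from ejkernel.modules import Attention, create_default_executor
>>>
>>> # Illustrative inputs: [batch, seq_len, num_heads, head_dim]
>>> query = key = value = jnp.ones((2, 128, 8, 64), dtype=jnp.bfloat16)
>>>
>>> executor = create_default_executor()
>>> attn = Attention()
>>>
>>> # Causal attention with an explicit softmax scale
>>> output = executor(attn, query, key, value, causal=True, softmax_scale=0.125)
>>>
>>> output = executor(
...     attn, query, key, value,...
... )
>>>
>>> # Sliding-window (local) attention, window given as (left, right)
>>> output = executor(attn, query, key, value, sliding_window=(256, 256))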
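`sliding_window` accepts an int or a (left, right) tuple. One natural reading, assumed in the sketch below rather than taken from ejkernel, is that query i may attend to key j when i - left <= j <= i + right, with an int w meaning (w, w). `sliding_window_mask` is a hypothetical helper.

```python
import jax.numpy as jnp


def sliding_window_mask(q_len, kv_len, sliding_window):
    """Hypothetical helper: boolean [q_len, kv_len] mask, True where query i may attend key j.

    Assumes (left, right) means i - left <= j <= i + right and that an int w means (w, w).
    """
    if isinstance(sliding_window, int):
        left, right = sliding_window, sliding_window
    else:
        left, right = sliding_window
    q_idx = jnp.arange(q_len)[:, None]
    k_idx = jnp.arange(kv_len)[None, :]
    return (k_idx >= q_idx - left) & (k_idx <= q_idx + right)


mask = sliding_window_mask(6, 6, (2, 1))
```

Under this reading, a window of 0 degenerates to each query attending only to its own position.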
candidate_cfgs(inv: Invocation[AttentionConfig, Array])

Generate candidate configurations for autotuning.

This operation uses XLA primitives directly and has no tunable block sizes, so autotuning provides no benefit. It returns an empty list so that autotuning is skipped.

Parameters

inv – Invocation object with arguments and metadata

Returns

Empty list; there are no candidates to autotune because XLA handles optimization

Note

XLA’s attention primitive is not parameterized by block sizes, so there are no meaningful configurations to benchmark.
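The effect of an empty candidate list can be sketched with a generic selector: with nothing to benchmark, the heuristic configuration is used directly; otherwise each candidate is timed and the fastest one wins. `pick_config` and its arguments are hypothetical and do not describe ejkernel's executor API.

```python
import time
from typing import Callable, Sequence, TypeVar

Cfg = TypeVar("Cfg")


def pick_config(candidates: Sequence[Cfg], heuristic: Cfg, run: Callable[[Cfg], object]) -> Cfg:
    """Hypothetical selector: time each candidate, or fall back to the heuristic."""
    if not candidates:
        # Nothing to tune (the candidate_cfgs -> [] case): use the default config
        return heuristic
    timings = []
    for cfg in candidates:
        start = time.perf_counter()
        # For JAX callables, run should block on its result (block_until_ready)
        # so the timing covers device execution, not just dispatch.
        run(cfg)
        timings.append(time.perf_counter() - start)
    return candidates[timings.index(min(timings))]
```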

get_impl(cfg: AttentionConfig)

Get kernel implementation from registry based on configuration.

Parameters

cfg – Configuration specifying platform and backend

Returns

Callable kernel implementation

Raises

ValueError – If no matching implementation is found

heuristic_cfg(inv: Invocation[AttentionConfig, Array]) → AttentionConfig

Provide default configuration based on invocation context.

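Selects the default AttentionConfig (platform and backend) for the invocation. Because this operation has no tunable block sizes (see candidate_cfgs), no block-size search is involved.

Parameters

inv – Invocation object with arguments and metadata

Returns

Default AttentionConfig for the invocation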

run(query: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: jaxtyping.Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], value: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads vhead_dim'], attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len kv_len'] | None = None, bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len'] | None = None, init_bias: typing.Optional[typing.Callable[[], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len']]] = None, deterministic: bool = True, dropout_rng: typing.Optional[typing.Union[jaxtyping.Key[jaxlib._jax.Array, ''], jaxtyping.UInt32[jaxlib._jax.Array, '2']]] = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_heads num_sinks'] | jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, softmax_scale: float | None = None, logits_soft_cap: float | None = None, dtype: typing.Optional[typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType]] = <class 'jax.numpy.bfloat16'>, softmax_dtype: typing.Optional[typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType]] = None, dropout_prob: float = 0.0, causal: bool = False, sliding_window: int | tuple[int, int] | None = None, *, cfg: ejkernel.modules.operations.configs.AttentionConfig) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len']]

Execute standard attention with the given configuration.

Parameters
  • query – Query tensor [batch, seq_len, num_q_heads, head_dim]

  • key – Key tensor [batch, kv_len, num_kv_heads, head_dim]

  • value – Value tensor [batch, kv_len, num_kv_heads, vhead_dim]

  • attention_mask – Optional boolean attention mask [batch, num_heads_or_1, seq_len, kv_len] (legacy, prefer bias)

  • bias – Optional attention bias tensor [batch, num_heads, seq_len, kv_len]

  • softmax_scale – Scaling factor for attention scores (default: 1/sqrt(head_dim))

  • logits_soft_cap – Optional soft cap applied to attention logits

  • dropout_prob – Dropout probability for attention weights

  • dropout_rng – Optional PRNG key used for dropout

  • causal – Whether to apply a causal mask

  • sliding_window – Window size for local attention (int or (left, right) tuple)

  • cfg – Configuration object specifying platform/backend

Returns

Tuple of (attention output [batch, seq_len, num_q_heads, vhead_dim], attention weights [batch, num_heads, seq_len, kv_len])
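The signature distinguishes num_q_heads (query) from num_kv_heads (key/value), which is the grouped-query attention layout. A straightforward way to express it in plain JAX is to repeat each KV head so that every group of num_q_heads // num_kv_heads query heads shares one KV head. This is a sketch only: `expand_kv_heads` is hypothetical, and ejkernel may map heads differently internally.

```python
import jax.numpy as jnp


def expand_kv_heads(kv, num_q_heads):
    """Hypothetical helper: [batch, kv_len, num_kv_heads, d] -> [batch, kv_len, num_q_heads, d].

    Each group of num_q_heads // num_kv_heads consecutive query heads shares one KV head.
    """
    num_kv_heads = kv.shape[2]
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    # repeat (not tile): heads come out as [h0, h0, h0, h1, h1, h1, ...]
    return jnp.repeat(kv, num_q_heads // num_kv_heads, axis=2)


key_tensor = jnp.arange(2 * 3 * 2 * 4, dtype=jnp.float32).reshape(2, 3, 2, 4)
expanded = expand_kv_heads(key_tensor, num_q_heads=6)
```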

ejkernel.modules.operations.attention.attention(query: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: jaxtyping.Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], value: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads vhead_dim'], bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len'] | None = None, dropout_rng: typing.Optional[typing.Union[jaxtyping.Key[jaxlib._jax.Array, ''], jaxtyping.UInt32[jaxlib._jax.Array, '2']]] = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_heads num_sinks'] | jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, /, *, mask_info: ejkernel.types.mask.MaskInfo | None = None, init_bias: typing.Optional[typing.Callable[[], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len']]] = None, deterministic: bool = True, softmax_scale: float | None = None, logits_soft_cap: float | None = None, dtype: typing.Optional[typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType]] = <class 'jax.numpy.bfloat16'>, softmax_dtype: typing.Optional[typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType]] = None, dropout_prob: float = 0.0, causal: bool = False, sliding_window: int | tuple[int, int] | None = None) → Float[jaxlib._jax.Array, 'batch seq_len num_q_heads vhead_dim']

Execute standard attention with automatic optimization.

Convenience function that runs the Attention module through a default executor; platform/backend selection is automatic.

Parameters
  • query – Query tensor [batch, seq_len, num_q_heads, head_dim]

  • key – Key tensor [batch, kv_len, num_kv_heads, head_dim]

  • value – Value tensor [batch, kv_len, num_kv_heads, vhead_dim]

  • mask_info – Optional MaskInfo containing attention mask and/or segment IDs

  • bias – Optional attention bias tensor (positional-only)

  • dropout_rng – Optional PRNG key used for dropout (positional-only)

  • softmax_scale – Scaling factor for attention scores (default: 1/sqrt(head_dim))

  • logits_soft_cap – Optional soft cap applied to attention logits

  • dropout_prob – Dropout probability for attention weights

  • causal – Whether to apply a causal mask

  • sliding_window – Window size for local attention (int or (left, right) tuple)
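With head_dim = 64, the default scale works out to the 0.125 that the examples on this page pass explicitly:

```latex
\text{softmax\_scale} = \frac{1}{\sqrt{\text{head\_dim}}} = \frac{1}{\sqrt{64}} = \frac{1}{8} = 0.125
```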

Returns

Attention output [batch, seq_len, num_q_heads, vhead_dim]; same shape as query when vhead_dim equals head_dim

Example

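>>> import jax
>>> import jax.numpy as jnp
>>> from ejkernel.modules.operations.attention import attention
>>>
>>> # Illustrative inputs: [batch, seq_len, num_heads, head_dim]
>>> query = key = value = jnp.ones((2, 128, 8, 64), dtype=jnp.bfloat16)
>>>
>>> # Default: non-causal attention, softmax_scale = 1/sqrt(head_dim)
>>> out = attention(query, key, value)
>>>
>>> # Dropout: bias and dropout_rng are positional-only; disable deterministic mode
>>> out = attention(query, key, value, None, jax.random.PRNGKey(0),
...                 dropout_prob=0.1, deterministic=False, softmax_scale=0.125)
>>>
>>> # Causal attention (there is no platform argument; selection is automatic)
>>> out = attention(query, key, value, causal=True)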
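The `dropout_prob`, `dropout_rng`, and `deterministic` arguments suggest the usual inverted-dropout scheme on attention weights: each weight is kept with probability 1 - dropout_prob and rescaled by 1 / (1 - dropout_prob), and nothing is dropped in deterministic mode. The sketch below assumes that scheme; this page does not specify ejkernel's exact dropout semantics, and `dropout_weights` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def dropout_weights(weights, dropout_prob, rng, deterministic=True):
    """Hypothetical inverted dropout on attention weights.

    When deterministic is True or dropout_prob == 0, weights pass through unchanged.
    """
    if deterministic or dropout_prob == 0.0:
        return weights
    keep_prob = 1.0 - dropout_prob
    keep = jax.random.bernoulli(rng, keep_prob, weights.shape)
    # Rescale kept entries so each weight keeps its expected value
    return jnp.where(keep, weights / keep_prob, 0.0)


attn_weights = jax.nn.softmax(jax.random.normal(jax.random.PRNGKey(1), (2, 4, 4)), axis=-1)
dropped = dropout_weights(attn_weights, 0.1, jax.random.PRNGKey(0), deterministic=False)
```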