ejkernel.kernels._xla.attention._interface


Standard attention interface using pure JAX/XLA operations.

This module provides the public API for standard multi-head attention implemented using native JAX operations, suitable for fallback computation when specialized kernels are unavailable.

ejkernel.kernels._xla.attention._interface.attention(
    query: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'],
    key: jaxtyping.Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'],
    value: jaxtyping.Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads vhead_dim'],
    attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len kv_len'] | None = None,
    bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len'] | None = None,
    init_bias: collections.abc.Callable[[], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len']] | None = None,
    deterministic: bool = True,
    dropout_rng: typing.Optional[typing.Union[jaxtyping.Key[jaxlib._jax.Array, ''], jaxtyping.UInt32[jaxlib._jax.Array, '2']]] = None,
    softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_heads num_sinks'] | jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None,
    softmax_scale: float | None = None,
    logits_soft_cap: float | None = None,
    dtype: typing.Optional[typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType]] = jax.numpy.bfloat16,
    softmax_dtype: typing.Optional[typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType]] = None,
    dropout_prob: float = 0.0,
    causal: bool = False,
    sliding_window: int | tuple[int, int] | None = None,
) -> tuple[
    jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads vhead_dim'],
    jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len'],
]

Computes multi-head attention using standard JAX operations.

Supports GQA/MQA by reshaping the query tensor to match the number of key/value heads. Applies scaling, optional bias/attention_mask, softmax (potentially in float32), and optional dropout.
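The query-reshaping approach described above can be sketched in a few lines of JAX. This is an illustrative reference under stated assumptions, not the ejkernel implementation: the function name gqa_attention_sketch is hypothetical, query heads are assumed to be grouped contiguously per key/value head, and masking, bias, soft capping, and dropout are left out.

```python
import jax
import jax.numpy as jnp


def gqa_attention_sketch(query, key, value, softmax_scale=None):
    """Reference GQA/MQA attention (hypothetical helper, not the library code)."""
    batch, q_len, num_q_heads, head_dim = query.shape
    _, kv_len, num_kv_heads, _ = key.shape
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    groups = num_q_heads // num_kv_heads
    # Default scale is 1/sqrt(head_dim), as documented for softmax_scale.
    scale = head_dim ** -0.5 if softmax_scale is None else softmax_scale

    # Split query heads into (kv_head, group) so each group shares one K/V head.
    # Assumption: query head index = kv_head * groups + group.
    q = query.reshape(batch, q_len, num_kv_heads, groups, head_dim)

    # Scaled scores with shape [batch, num_kv_heads, groups, q_len, kv_len].
    logits = jnp.einsum("bqhgd,bkhd->bhgqk", q * scale, key)

    # Softmax in float32 for numerical stability, then cast back.
    weights = jax.nn.softmax(logits.astype(jnp.float32), axis=-1)
    weights = weights.astype(value.dtype)

    # Weighted sum of values: [batch, q_len, num_kv_heads, groups, vhead_dim].
    out = jnp.einsum("bhgqk,bkhv->bqhgv", weights, value)

    # Merge (kv_head, group) back into a single query-head axis.
    out = out.reshape(batch, q_len, num_q_heads, value.shape[-1])
    weights = weights.reshape(batch, num_q_heads, q_len, kv_len)
    return out, weights
```

With 8 query heads and 2 key/value heads, each key/value head serves a group of 4 query heads, and the output keeps the query head count with the value head dimension as its last axis.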

Parameters
  • query – Query tensor with shape [batch, seq_len, num_q_heads, head_dim]. The main input sequence to attend from.

  • key – Key tensor with shape [batch, kv_len, num_kv_heads, head_dim]. Keys for attention computation. May have fewer heads than queries (GQA/MQA).

  • value – Value tensor with shape [batch, kv_len, num_kv_heads, vhead_dim]. Values to aggregate based on attention weights. The value head dimension may differ from head_dim.

  • attention_mask – Optional boolean mask with shape [batch, num_heads or 1, seq_len, kv_len]. True marks positions to attend to; False marks masked positions. Used only if bias is not provided.

  • bias – Optional attention bias with shape [batch, num_heads, seq_len, kv_len]. Additive bias applied to attention scores before softmax. Takes precedence over attention_mask.

  • init_bias – Optional callable that returns bias tensor. Used to lazily initialize bias if both attention_mask and bias are None.

  • deterministic – If True, disables dropout (default). If False, applies dropout.

  • dropout_rng – JAX PRNG key for dropout. Required when deterministic=False and dropout_prob > 0.

  • softmax_aux – Optional auxiliary tensor for softmax computation, with shape [num_heads, num_sinks] or [num_sinks].

  • softmax_scale – Optional float for scaling attention scores. If None, uses 1/sqrt(head_dim).

  • logits_soft_cap – Optional float for capping attention logits using tanh. When specified, applies: logits_soft_cap * tanh(logits / logits_soft_cap). This bounds the logits to the open interval (-logits_soft_cap, logits_soft_cap).

  • dtype – Data type for computation. Defaults to bfloat16.

  • softmax_dtype – Data type for softmax computation. If None (the default), float32 is used.

  • dropout_prob – Dropout probability. Only applied when deterministic=False.

  • causal – If True, applies causal masking to the attention scores. Defaults to False.

  • sliding_window – Optional sliding window attention constraint. When specified, each query position can only attend to keys within the window. Can be:

      - int: Symmetric window (same left and right window size)
      - tuple[int, int]: Asymmetric window (left_window, right_window)
      - None: No window constraint (full attention)
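Several of the parameters above reduce to small array transformations. The following sketch shows the documented soft-cap formula, one possible construction of a sliding-window mask, and the conversion of a boolean mask into an additive bias. All three helper names are hypothetical. The window construction assumes that queries and keys share the same position origin and that window bounds are inclusive; the library may align positions differently.

```python
import jax.numpy as jnp


def soft_cap(logits, logits_soft_cap):
    """Apply the documented formula: cap * tanh(logits / cap)."""
    return logits_soft_cap * jnp.tanh(logits / logits_soft_cap)


def sliding_window_mask(q_len, kv_len, sliding_window):
    """Boolean [q_len, kv_len] mask; True marks keys a query may attend to.

    Assumption: query i and key j use the same position origin and the
    window bounds are inclusive.
    """
    if isinstance(sliding_window, int):
        left, right = sliding_window, sliding_window  # symmetric window
    else:
        left, right = sliding_window  # (left_window, right_window)
    q_pos = jnp.arange(q_len)[:, None]
    k_pos = jnp.arange(kv_len)[None, :]
    offset = k_pos - q_pos  # negative: the key lies to the left of the query
    return (offset >= -left) & (offset <= right)


def mask_to_bias(mask, dtype=jnp.float32):
    """Additive bias: 0 where mask is True, a large negative value where False."""
    return jnp.where(mask, 0.0, jnp.finfo(dtype).min).astype(dtype)
```

Under these assumptions, sliding_window=(1, 0) lets each query see itself and one key to its left, which is a causal window of width two.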

Returns

A tuple (attention_outputs, attention_weights):

  • attention_outputs: Float[Array, "batch seq_len num_q_heads vhead_dim"]. The attended representation.

  • attention_weights: Float[Array, "batch num_heads seq_len kv_len"] | None. The attention weights, or None when weights are not returned.

Return type

tuple of (attention_outputs, attention_weights)

Raises

NotImplementedError – If the bias head dimension cannot be reshaped correctly to match the query head structure for GQA/MQA.
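To illustrate the kind of check behind this error, here is a minimal sketch of reshaping a bias to a grouped head layout. The helper reshape_bias_for_gqa and the set of accepted layouts are assumptions made for illustration; the source does not state which bias shapes ejkernel accepts or how it reshapes them.

```python
import jax.numpy as jnp


def reshape_bias_for_gqa(bias, num_q_heads, num_kv_heads):
    """Hypothetical helper: view bias as [batch, kv_heads, groups, q_len, kv_len]."""
    batch, bias_heads, q_len, kv_len = bias.shape
    groups = num_q_heads // num_kv_heads
    if bias_heads == num_q_heads:
        # One bias slice per query head: split heads into (kv_head, group).
        return bias.reshape(batch, num_kv_heads, groups, q_len, kv_len)
    if bias_heads in (num_kv_heads, 1):
        # Insert a singleton group axis; broadcasting covers the rest.
        return bias[:, :, None, :, :]
    raise NotImplementedError(
        f"Cannot map {bias_heads} bias heads onto {num_q_heads} query heads "
        f"grouped over {num_kv_heads} key/value heads."
    )
```

In this sketch, a bias with 3 heads cannot be matched to 8 query heads over 2 key/value heads, so the helper raises rather than guessing a broadcast.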