ejkernel.kernels._xla.attention._interface
Standard attention interface using pure JAX/XLA operations.
This module provides the public API for standard multi-head attention implemented using native JAX operations, suitable for fallback computation when specialized kernels are unavailable.
- ejkernel.kernels._xla.attention._interface.attention(query: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads head_dim'], key: jaxtyping.Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads head_dim'], value: jaxtyping.Float[jaxlib._jax.Array, 'batch kv_len num_kv_heads vhead_dim'], attention_mask: jaxtyping.Bool[jaxlib._jax.Array, 'batch num_heads_or_1 seq_len kv_len'] | None = None, bias: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len'] | None = None, init_bias: collections.abc.Callable[[], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len']] | None = None, deterministic: bool = True, dropout_rng: typing.Optional[typing.Union[jaxtyping.Key[jaxlib._jax.Array, ''], jaxtyping.UInt32[jaxlib._jax.Array, '2']]] = None, softmax_aux: jaxtyping.Float[jaxlib._jax.Array, 'num_heads num_sinks'] | jaxtyping.Float[jaxlib._jax.Array, 'num_sinks'] | None = None, softmax_scale: float | None = None, logits_soft_cap: float | None = None, dtype: typing.Optional[typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType]] = <class 'jax.numpy.bfloat16'>, softmax_dtype: typing.Optional[typing.Union[str, type[typing.Any], numpy.dtype, jax._src.typing.SupportsDType]] = None, dropout_prob: float = 0.0, causal: bool = False, sliding_window: int | tuple[int, int] | None = None) → tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_q_heads vhead_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads seq_len kv_len']]
Computes multi-head attention using standard JAX operations.
Supports GQA/MQA by reshaping the query heads into groups, each of which shares one key/value head. Applies scaling, an optional bias or attention mask, softmax (computed in softmax_dtype, float32 by default), and optional dropout.
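As an illustration of the GQA/MQA reshaping described above, here is a minimal pure-JAX sketch; it is not the ejkernel implementation. It assumes that query head h is served by key/value head h // (num_q_heads // num_kv_heads), which is what a row-major reshape of the head axis produces. The name gqa_attention_sketch is hypothetical, and masking, bias, and dropout are omitted.

```python
import jax
import jax.numpy as jnp


def gqa_attention_sketch(query, key, value, softmax_scale=None):
    """Hypothetical GQA/MQA reference (no mask, bias, or dropout).

    query: [batch, seq_len, num_q_heads, head_dim]
    key:   [batch, kv_len, num_kv_heads, head_dim]
    value: [batch, kv_len, num_kv_heads, vhead_dim]
    Returns (outputs [batch, seq_len, num_q_heads, vhead_dim],
             weights [batch, num_q_heads, seq_len, kv_len]).
    """
    b, s, hq, d = query.shape
    t, hkv = key.shape[1], key.shape[2]
    if hq % hkv != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group = hq // hkv  # query heads per key/value head (hkv == 1 is MQA)
    scale = d ** -0.5 if softmax_scale is None else softmax_scale

    # Split the query head axis: head h -> (kv_head h // group, slot h % group).
    q = query.reshape(b, s, hkv, group, d)
    scores = jnp.einsum("bskgd,btkd->bkgst", q, key) * scale
    # Softmax in float32, mirroring the documented float32 softmax default.
    weights = jax.nn.softmax(scores.astype(jnp.float32), axis=-1)
    out = jnp.einsum("bkgst,btkv->bskgv", weights.astype(value.dtype), value)

    # Merge (kv_head, slot) back into a single query-head axis.
    out = out.reshape(b, s, hq, value.shape[-1])
    weights = weights.reshape(b, hq, s, t)
    return out, weights


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 5, 8, 16))  # 8 query heads
k = jax.random.normal(kk, (2, 7, 2, 16))  # 2 key/value heads -> groups of 4
v = jax.random.normal(kv, (2, 7, 2, 32))  # vhead_dim set differently from head_dim
out, w = gqa_attention_sketch(q, k, v)
print(out.shape, w.shape)  # (2, 5, 8, 32) (2, 8, 5, 7)
```

Because the query is reshaped instead of the keys and values being repeated, K and V are never copied per query head; the result equals ordinary multi-head attention with each key/value head repeated group times along the head axis.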
- Parameters
query – Query tensor with shape [batch, seq_len, num_q_heads, head_dim]. The main input sequence to attend from.
key – Key tensor with shape [batch, kv_len, num_kv_heads, head_dim]. Keys for attention computation. May have fewer heads than the query (GQA/MQA), in which case num_q_heads must be a multiple of num_kv_heads.
value – Value tensor with shape [batch, kv_len, num_kv_heads, vhead_dim]. Values aggregated according to the attention weights.
attention_mask – Optional boolean mask with shape [batch, num_heads or 1, seq_len, kv_len]. True marks positions that may be attended to; False positions are masked out. Used only if bias is not provided.
bias – Optional attention bias with shape [batch, num_heads, seq_len, kv_len]. Additive bias applied to attention scores before softmax. Takes precedence over attention_mask.
init_bias – Optional callable that returns a bias tensor. Used to lazily initialize the bias when both attention_mask and bias are None.
deterministic – If True (the default), dropout is disabled. If False, dropout is applied with probability dropout_prob.
dropout_rng – JAX PRNG key for dropout (a typed key or a raw uint32[2] key). Required when deterministic=False and dropout_prob > 0.
softmax_aux – Optional auxiliary tensor for the softmax (attention sinks), with shape [num_heads, num_sinks] or [num_sinks].
softmax_scale – Optional float for scaling attention scores. If None, uses 1/sqrt(head_dim).
logits_soft_cap – Optional float for soft-capping attention logits with tanh. When specified, applies logits_soft_cap * tanh(logits / logits_soft_cap), which bounds the logits to the open interval (-logits_soft_cap, logits_soft_cap) and keeps attention scores from growing too large.
dtype – Data type for computation. Defaults to bfloat16.
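softmax_dtype – Data type for the softmax computation. If None (the default), float32 is used.
dropout_prob – Dropout probability. Only applied when deterministic=False.
causal – If True, applies a causal mask so that each query position attends only to keys at the same or earlier positions. Defaults to False.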
sliding_window – Optional sliding-window attention constraint. When specified, each query position can attend only to keys within the window. Accepted values:
  - int: symmetric window (same left and right window size)
  - tuple[int, int]: asymmetric window (left_window, right_window)
  - None: no window constraint (full attention)
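The following minimal sketch shows one plausible way the scaling, soft-capping, masking, and dropout options above combine in pure JAX. Everything in it is an assumption for illustration: the helper names (soft_cap, window_mask, masked_scores, dropout_weights) are hypothetical, window bounds are treated as inclusive, causal and sliding-window positions assume the query and key sequences are aligned from position 0, dropout is standard inverted dropout, and only the MHA case (equal head counts) is covered.

```python
import jax
import jax.numpy as jnp


def soft_cap(logits, cap):
    # cap * tanh(logits / cap): values are bounded to (-cap, cap).
    return cap * jnp.tanh(logits / cap)


def window_mask(seq_len, kv_len, causal=False, sliding_window=None):
    """Boolean [seq_len, kv_len] mask; True means "may attend".

    Assumption: query i and key j index the same positions (no offset)
    and window bounds are inclusive.
    """
    q_pos = jnp.arange(seq_len)[:, None]
    k_pos = jnp.arange(kv_len)[None, :]
    mask = jnp.ones((seq_len, kv_len), dtype=bool)
    if causal:
        mask = mask & (k_pos <= q_pos)
    if sliding_window is not None:
        if isinstance(sliding_window, int):
            left = right = sliding_window  # symmetric window
        else:
            left, right = sliding_window  # (left_window, right_window)
        mask = mask & (q_pos - k_pos <= left) & (k_pos - q_pos <= right)
    return mask


def masked_scores(query, key, attention_mask=None, softmax_scale=None,
                  logits_soft_cap=None, causal=False, sliding_window=None):
    """Scaled, optionally soft-capped, masked logits [batch, heads, seq_len, kv_len].

    query/key: [batch, len, heads, head_dim] with equal head counts.
    attention_mask: optional bool array broadcastable to [batch, heads, seq_len, kv_len].
    """
    scale = query.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale
    scores = jnp.einsum("bshd,bthd->bhst", query, key) * scale
    if logits_soft_cap is not None:
        scores = soft_cap(scores, logits_soft_cap)
    mask = window_mask(query.shape[1], key.shape[1], causal, sliding_window)
    if attention_mask is not None:
        mask = mask & attention_mask
    # Masked logits get the most negative finite value, so softmax gives them ~0 weight.
    return jnp.where(mask, scores, jnp.finfo(scores.dtype).min)


def dropout_weights(weights, dropout_prob, rng):
    # Inverted dropout: drop with probability p, rescale survivors by 1 / (1 - p).
    keep = jax.random.bernoulli(rng, 1.0 - dropout_prob, weights.shape)
    return jnp.where(keep, weights / (1.0 - dropout_prob), 0.0)


q = jax.random.normal(jax.random.PRNGKey(0), (1, 5, 2, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (1, 5, 2, 8))
logits = masked_scores(q, k, causal=True, sliding_window=(2, 0), logits_soft_cap=30.0)
weights = jax.nn.softmax(logits, axis=-1)
weights = dropout_weights(weights, 0.1, jax.random.PRNGKey(2))
```

Soft-capping is applied before masking in this sketch so that masked logits stay at the most negative finite value; capping afterwards would pull them up to -logits_soft_cap and give masked keys non-zero weight.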
- Returns
attention_outputs: Float[Array, "batch seq_len num_q_heads vhead_dim"]. The attended representation.
attention_weights: Float[Array, "batch num_heads seq_len kv_len"] | None. The attention weights (if return_weights is True in metadata).
- Return type
AttentionOutput containing attention_outputs and attention_weights.
- Raises
NotImplementedError – If the bias head dimension cannot be reshaped correctly to match the query head structure for GQA/MQA.
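The NotImplementedError condition can be illustrated with a hypothetical bias-grouping helper. This sketch assumes the same head-grouping convention as a row-major reshape (query head h = kv_head * group + slot) and, as an assumption, accepts only a bias with num_q_heads heads or a single broadcast head; the exact layouts the real implementation accepts are not documented here.

```python
import jax.numpy as jnp


def group_bias_sketch(bias, num_q_heads, num_kv_heads):
    """Hypothetical: reshape a [batch, H, seq_len, kv_len] bias for GQA/MQA.

    Returns [batch, num_kv_heads, group, seq_len, kv_len], or a size-1 head
    layout that broadcasts against it; raises NotImplementedError otherwise.
    """
    b, h, s, t = bias.shape
    group = num_q_heads // num_kv_heads
    if h == num_q_heads:
        # Query head h maps to (kv_head h // group, slot h % group).
        return bias.reshape(b, num_kv_heads, group, s, t)
    if h == 1:
        # A single bias head broadcasts over every query head.
        return bias.reshape(b, 1, 1, s, t)
    raise NotImplementedError(
        f"bias has {h} heads; expected 1 or num_q_heads={num_q_heads}"
    )


bias = jnp.arange(2 * 8 * 3 * 4, dtype=jnp.float32).reshape(2, 8, 3, 4)
grouped = group_bias_sketch(bias, num_q_heads=8, num_kv_heads=2)
print(grouped.shape)  # (2, 2, 4, 3, 4)
```

The grouped bias has the same [batch, num_kv_heads, group, seq_len, kv_len] layout as scores computed from a grouped query, so it can be added to them directly before the softmax.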