Source code for ejkernel.kernels._xla.gla._interface

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gated Linear Attention interface using recurrent formulation.

This module provides the public API for GLA implemented via the recurrent
attention mechanism. Wraps the core recurrent implementation with GLA-specific
gating support for linear-time sequence processing.
"""

import jaxtyping
from beartype import beartype
from jaxtyping import Array, Float, Int

from ..._registry import Backend, Platform, kernel_registry
from ..recurrent import recurrent


[docs]@kernel_registry.register("gla", Platform.XLA, Backend.ANY) @jaxtyping.jaxtyped(typechecker=beartype) def recurrent_gla( query: Float[Array, "batch seq_len num_heads qk_head_dim"], key: Float[Array, "batch seq_len num_kv_heads qk_head_dim"], value: Float[Array, "batch seq_len num_kv_heads v_head_dim"], g: Float[Array, "batch seq_len num_heads qk_head_dim"] | None = None, g_gamma: Float[Array, "... num_heads"] | None = None, softmax_scale: float | None = None, initial_state: Float[Array, "... num_heads qk_head_dim v_head_dim"] | None = None, reverse: bool = False, cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None, ) -> tuple[Float[Array, "batch seq_len num_heads v_head_dim"], Float[Array, "... num_heads qk_head_dim v_head_dim"]]: """ Computes Gated Linear Attention (GLA) in a recurrent, linear-time manner using JAX/XLA. This function provides a convenient wrapper around the core `recurrent` implementation, tailored for GLA. It processes sequences step-by-step, making it highly efficient for very long sequences and suitable for autoregressive decoding. It supports both standard batch processing and variable-length sequence processing using cumulative sequence lengths (`cu_seqlens`). Args: query: The query tensor. Expected shape is `(batch, seq_len, num_heads, head_dim)` or `(total_tokens, num_heads, head_dim)` if `cu_seqlens` is used. key: The key tensor. Must have the same shape as `q`. value: The value tensor. Must have the same shape as `q`. g: The gate tensor, specific to Gated Linear Attention. If provided, it should have the same shape as `q`. g_gamma: The gate decay factor. softmax_scale: A scaling factor applied to the query before the recurrent computation. If `None`, it defaults to `1 / sqrt(head_dim)`. initial_state: The initial hidden state for the recurrence. Useful for chunked processing of long sequences. reverse: If `True`, the sequence is processed in reverse order. cu_seqlens: Cumulative sequence lengths for variable-length inputs. This is a 1D tensor like `[0, len_seq1, len_seq1+len_seq2, ...]`. If provided, the input tensors `query, key, value, g` are expected to be "packed" with a shape of `(total_tokens, ...)`. Returns: A tuple containing: - o (jax.Array): The output tensor, with the same shape as `q`. - final_state (jax.Array): The final hidden state of the recurrence, which can be used as `initial_state` for a subsequent segment. Raises: ValueError: If `cu_seqlens` is provided and the batch size of `q` is not 1. ValueError: If `cu_seqlens` is provided and the number of initial states does not match the number of sequences. Examples: >>> >>> q = jnp.ones((2, 100, 8, 64)) >>> k = jnp.ones((2, 100, 8, 64)) >>> v = jnp.ones((2, 100, 8, 64)) >>> g = jnp.ones((2, 100, 8, 64)) >>> output, final_state = recurrent_gla(query, key, value, g=g) >>> output.shape (2, 100, 8, 64) >>> >>> q = jnp.ones((150, 8, 64)) >>> k = jnp.ones((150, 8, 64)) >>> v = jnp.ones((150, 8, 64)) >>> g = jnp.ones((150, 8, 64)) >>> cu_seqlens = jnp.array([0, 50, 100, 150]) >>> output, final_state = recurrent_gla(query, key, value, g=g, cu_seqlens=cu_seqlens) >>> output.shape (150, 8, 64) """ if cu_seqlens is not None: if query.shape[0] != 1 and query.ndim == 4: raise ValueError( f"The batch size is expected to be 1 rather than {query.shape[0]} when using `cu_seqlens`. " f"Please flatten variable-length inputs before processing." ) if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: raise ValueError( f"The number of initial states is expected to be equal to the number of input sequences, " f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." ) if softmax_scale is None: softmax_scale = key.shape[-1] ** -0.5 o, final_state = recurrent( query=query, key=key, value=value, g=g, g_gamma=g_gamma, softmax_scale=softmax_scale, initial_state=initial_state, reverse=reverse, cu_seqlens=cu_seqlens, ) return o, final_state