# Source code for ejkernel.kernels._triton.gla._interface

# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Gated Linear Attention (GLA) implementation using Triton kernels.

This module provides a specialized implementation of Gated Linear Attention,
a variant of linear attention that incorporates learnable gating mechanisms
to improve model expressiveness while maintaining O(N) time complexity.

GLA extends standard linear attention by applying element-wise gates to the
attention computation, allowing the model to dynamically control information
flow. This is particularly useful for capturing long-range dependencies while
maintaining computational efficiency.

The implementation is built on top of the general recurrent linear attention
kernel, configured specifically for GLA's gating patterns.

Key features:
- O(N) time complexity via recurrent formulation
- Learnable gates (g) for enhanced expressiveness
- Optional decay factors (g_gamma) for temporal dynamics
- Support for variable-length sequences
- GPU-optimized Triton kernels

Example:
    >>> import jax.numpy as jnp
    >>> from ejkernel.kernels._triton.gla import recurrent_gla
    >>>
    >>> batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
    >>> q = jnp.ones((batch, seq_len, num_heads, head_dim))
    >>> k = jnp.ones((batch, seq_len, num_heads, head_dim))
    >>> v = jnp.ones((batch, seq_len, num_heads, head_dim))
    >>> g = jnp.ones((batch, seq_len, num_heads, head_dim))
    >>>
    >>> output, final_state = recurrent_gla(q, k, v, g=g)
"""

import jaxtyping
from beartype import beartype
from jaxtyping import Array, Float, Int

from ..._registry import Backend, Platform, kernel_registry
from ..recurrent import recurrent


@kernel_registry.register("gla", Platform.TRITON, Backend.GPU)
@jaxtyping.jaxtyped(typechecker=beartype)
def recurrent_gla(
    query: Float[Array, "batch seq_len num_heads qk_head_dim"],
    key: Float[Array, "batch seq_len num_kv_heads qk_head_dim"],
    value: Float[Array, "batch seq_len num_kv_heads v_head_dim"],
    g: Float[Array, "batch seq_len num_heads qk_head_dim"] | None = None,
    g_gamma: Float[Array, "... num_heads"] | None = None,
    softmax_scale: float | None = None,
    initial_state: Float[Array, "... num_heads qk_head_dim v_head_dim"] | None = None,
    reverse: bool = False,
    cu_seqlens: Int[Array, "num_seqs_plus_one"] | None = None,
) -> tuple[
    Float[Array, "batch seq_len num_heads v_head_dim"],
    Float[Array, "... num_heads qk_head_dim v_head_dim"],
]:
    """Compute Gated Linear Attention (GLA) through the recurrent kernel.

    This is a thin wrapper around `recurrent`: it validates the
    variable-length (`cu_seqlens`) arguments, fills in the default softmax
    scale, and forwards every argument unchanged.

    Args:
        query: Query tensor of shape
            `(batch, seq_len, num_heads, qk_head_dim)`.
        key: Key tensor of shape
            `(batch, seq_len, num_kv_heads, qk_head_dim)`.
        value: Value tensor of shape
            `(batch, seq_len, num_kv_heads, v_head_dim)`.
        g: Optional gate tensor, specific to Gated Linear Attention, with the
            same shape as `query`.
        g_gamma: Optional gate decay factor whose trailing axis is
            `num_heads`. Forwarded as-is to `recurrent`.
        softmax_scale: Scaling factor forwarded to `recurrent`. If `None`, it
            defaults to `key.shape[-1] ** -0.5`, i.e. `1 / sqrt(qk_head_dim)`.
        initial_state: Optional initial hidden state for the recurrence, with
            trailing axes `(num_heads, qk_head_dim, v_head_dim)`. When
            `cu_seqlens` is given, its leading axis must equal the number of
            sequences (`len(cu_seqlens) - 1`).
        reverse: Forwarded to `recurrent`; if `True`, the sequence is
            processed in reverse order.
        cu_seqlens: Optional 1D integer array of cumulative sequence lengths,
            e.g. `[0, len_seq1, len_seq1 + len_seq2, ...]`. When provided, the
            inputs must have a batch size of 1.
            NOTE(review): the sequences are presumably packed along the
            `seq_len` axis, i.e. `(1, total_tokens, ...)` -- confirm against
            `recurrent`.

    Returns:
        The `(o, final_state)` pair produced by `recurrent`:
            - o: The output tensor of shape
              `(batch, seq_len, num_heads, v_head_dim)`.
            - final_state: The final hidden state of the recurrence, with
              trailing axes `(num_heads, qk_head_dim, v_head_dim)`; it has the
              same trailing layout as `initial_state`.

    Raises:
        ValueError: If `cu_seqlens` is provided and the batch size of `query`
            is not 1.
        ValueError: If `cu_seqlens` is provided and the number of initial
            states does not match the number of sequences.
    """
    if cu_seqlens is not None:
        if query.shape[0] != 1:
            # The trailing space keeps the two sentences apart once the
            # adjacent literals are concatenated.
            raise ValueError(
                f"The batch size is expected to be 1 rather than {query.shape[0]} when using `cu_seqlens`. "
                "Please flatten variable-length inputs before processing."
            )
        # `cu_seqlens` holds one boundary more than there are sequences.
        num_seqs = len(cu_seqlens) - 1
        if initial_state is not None and initial_state.shape[0] != num_seqs:
            raise ValueError(
                "The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {num_seqs} rather than {initial_state.shape[0]}."
            )

    if softmax_scale is None:
        # Default scale is derived from the query/key head dimension.
        softmax_scale = key.shape[-1] ** -0.5

    o, final_state = recurrent(
        query=query,
        key=key,
        value=value,
        g=g,
        g_gamma=g_gamma,
        softmax_scale=softmax_scale,
        initial_state=initial_state,
        reverse=reverse,
        cu_seqlens=cu_seqlens,
    )
    return o, final_state