ejkernel.kernels._triton.gla._interface


Gated Linear Attention (GLA) implementation using Triton kernels.

This module provides a specialized implementation of Gated Linear Attention, a variant of linear attention that incorporates learnable gating mechanisms to improve model expressiveness while maintaining O(N) time complexity.

GLA extends standard linear attention by applying element-wise gates to the attention computation, allowing the model to dynamically control information flow. This is particularly useful for capturing long-range dependencies while maintaining computational efficiency.

The implementation is built on top of the general recurrent linear attention kernel, configured specifically for GLA’s gating patterns.

Key features:

  • O(N) time complexity via recurrent formulation

  • Learnable gates (g) for enhanced expressiveness

  • Optional decay factors (g_gamma) for temporal dynamics

  • Support for variable-length sequences

  • GPU-optimized Triton kernels
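To make the recurrent formulation concrete, the sketch below writes out a naive single-head GLA recurrence with `jax.lax.scan`. It is a reference illustration only, not the Triton kernel used by this module. The function name `gla_reference` is hypothetical, and the code assumes that `g` holds log-space gates (so `exp(g)` is the per-feature decay applied to the state), which is a common GLA convention but is not stated on this page.

```python
import jax
import jax.numpy as jnp


def gla_reference(q, k, v, g, scale=None, initial_state=None):
    """Naive single-head GLA recurrence (hypothetical reference, not the kernel).

    Shapes: q, k, g are (seq_len, qk_dim); v is (seq_len, v_dim).
    ASSUMPTION: g contains log-space gates, so exp(g) is the decay factor.
    """
    qk_dim, v_dim = q.shape[-1], v.shape[-1]
    if scale is None:
        scale = qk_dim ** -0.5  # documented default: 1 / sqrt(head_dim)
    if initial_state is None:
        initial_state = jnp.zeros((qk_dim, v_dim), dtype=q.dtype)

    def step(state, inputs):
        q_t, k_t, v_t, g_t = inputs
        # Decay each row of the state by its gate, then add the outer product of k_t and v_t.
        state = jnp.exp(g_t)[:, None] * state + k_t[:, None] * v_t[None, :]
        # Read out with the scaled query.
        o_t = (q_t * scale) @ state
        return state, o_t

    final_state, out = jax.lax.scan(step, initial_state, (q, k, v, g))
    return out, final_state
```

Each step touches only the fixed-size `(qk_dim, v_dim)` state, so the cost grows linearly with `seq_len`. With `g` set to zeros (no decay), the result reduces to plain causal linear attention.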

Example

>>> import jax.numpy as jnp
>>> from ejkernel.kernels._triton.gla import recurrent_gla
>>>
>>> batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
>>> q = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> k = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> v = jnp.ones((batch, seq_len, num_heads, head_dim))
>>> g = jnp.ones((batch, seq_len, num_heads, head_dim))
>>>
>>> output, final_state = recurrent_gla(q, k, v, g=g)

ejkernel.kernels._triton.gla._interface.recurrent_gla(
    query: Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'],
    key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads qk_head_dim'],
    value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads v_head_dim'],
    g: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads qk_head_dim'] | None = None,
    g_gamma: jaxtyping.Float[jaxlib._jax.Array, '... num_heads'] | None = None,
    softmax_scale: float | None = None,
    initial_state: jaxtyping.Float[jaxlib._jax.Array, '... num_heads qk_head_dim v_head_dim'] | None = None,
    reverse: bool = False,
    cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None,
) -> tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads v_head_dim'], jaxtyping.Float[jaxlib._jax.Array, '... num_heads qk_head_dim v_head_dim']]

Computes Gated Linear Attention (GLA) in a recurrent, linear-time manner.

This function provides a convenient wrapper around the core recurrent implementation, tailored for GLA. It processes sequences step-by-step, making it highly efficient for very long sequences and suitable for autoregressive decoding.

It supports both standard batch processing and variable-length sequence processing using cumulative sequence lengths (cu_seqlens).

Parameters
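  • query – The query tensor. Expected shape is (batch, seq_len, num_heads, qk_head_dim). If cu_seqlens is used, the sequences are packed along the token axis, with total_tokens in place of seq_len and a batch size of 1 (see Raises).

  • key – The key tensor. Per the type signature, its shape is (batch, seq_len, num_kv_heads, qk_head_dim): the same layout as query, with num_kv_heads in place of num_heads.

  • value – The value tensor. Per the type signature, its shape is (batch, seq_len, num_kv_heads, v_head_dim).

  • g – The gate tensor, specific to Gated Linear Attention. If provided, it should have the same shape as query.

  • g_gamma – The gate decay factor, with one value per head: per the type signature, its trailing dimension is num_heads.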

  • softmax_scale – A scaling factor applied to the query before the recurrent computation. If None, it defaults to 1 / sqrt(head_dim).

  • initial_state – The initial hidden state for the recurrence. Useful for chunked processing of long sequences.

  • reverse – If True, the sequence is processed in reverse order.

  • cu_seqlens – Cumulative sequence lengths for variable-length inputs. This is a 1D integer tensor of num_seqs + 1 boundaries, such as [0, len_seq1, len_seq1+len_seq2, …]. If provided, the input tensors query, key, value, and g are expected to be “packed” along the token axis, with a shape of (total_tokens, …).
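As an illustration of the cu_seqlens layout described above, the following sketch builds the boundary array from per-sequence lengths and slices one sequence back out of a packed tensor. The lengths and tensor sizes are made up for the example, and the leading batch dimension of 1 on the packed tensor is an assumption inferred from the Raises section rather than something this page states directly.

```python
import jax.numpy as jnp

# Hypothetical lengths of three sequences that are packed together.
seq_lens = jnp.array([3, 5, 2], dtype=jnp.int32)

# Prepend 0 to the running sum to obtain num_seqs + 1 boundaries.
cu_seqlens = jnp.concatenate(
    [jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(seq_lens)]
)

num_heads, head_dim = 2, 4
total_tokens = int(cu_seqlens[-1])

# ASSUMPTION: packed inputs keep a leading batch dimension of size 1.
packed_q = jnp.ones((1, total_tokens, num_heads, head_dim))

# Tokens of sequence i occupy positions cu_seqlens[i] : cu_seqlens[i + 1].
start, end = int(cu_seqlens[1]), int(cu_seqlens[2])
second_seq = packed_q[0, start:end]
```

Here `cu_seqlens` has one more entry than there are sequences, and its last entry equals the total number of packed tokens.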

Returns

A tuple containing:

  • o (jax.Array): The output tensor, with shape (batch, seq_len, num_heads, v_head_dim) per the type signature.

  • final_state (jax.Array): The final hidden state of the recurrence, with trailing dimensions (num_heads, qk_head_dim, v_head_dim). It can be passed as initial_state for a subsequent segment.
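The reason a final state can seed the next segment is that the recurrence only ever depends on the previous state. The sketch below checks this on a naive single-head recurrence written with `jax.lax.scan`: processing a sequence in two halves, with the first half's final state passed to the second, reproduces the single-pass result. The `recurrence` helper is hypothetical and stands in for the kernel; it assumes log-space gates and omits query scaling for brevity.

```python
import jax
import jax.numpy as jnp


def recurrence(q, k, v, g, state):
    """Hypothetical single-head reference recurrence (assumes log-space gates g)."""

    def step(s, x):
        q_t, k_t, v_t, g_t = x
        # Decay the state per feature, add the new key/value outer product.
        s = jnp.exp(g_t)[:, None] * s + k_t[:, None] * v_t[None, :]
        return s, q_t @ s

    final_state, out = jax.lax.scan(step, state, (q, k, v, g))
    return out, final_state


seq_len, qk_dim, v_dim = 8, 4, 3
keys = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(keys[0], (seq_len, qk_dim))
k = jax.random.normal(keys[1], (seq_len, qk_dim))
v = jax.random.normal(keys[2], (seq_len, v_dim))
# Negative log-gates, so exp(g) lies in (0, 1).
g = -jax.nn.softplus(jax.random.normal(keys[3], (seq_len, qk_dim)))

zero_state = jnp.zeros((qk_dim, v_dim))

# Single pass over the whole sequence.
full_out, full_state = recurrence(q, k, v, g, zero_state)

# Two segments, threading the first segment's final state into the second.
half = seq_len // 2
out_a, state_a = recurrence(q[:half], k[:half], v[:half], g[:half], zero_state)
out_b, state_b = recurrence(q[half:], k[half:], v[half:], g[half:], state_a)
chunked_out = jnp.concatenate([out_a, out_b], axis=0)
```

With the real function, the same pattern would presumably pass the returned final_state as the initial_state argument of the next call.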

Raises
  • ValueError – If cu_seqlens is provided and the batch size of query is not 1.

  • ValueError – If cu_seqlens is provided and the number of initial states does not match the number of sequences.
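The two documented conditions can be checked before calling the kernel. The helper below, `check_varlen_inputs`, is hypothetical and is not part of ejkernel; it is a minimal sketch that mirrors the conditions listed above. It assumes that the leading axis of initial_state indexes sequences, which this page does not spell out.

```python
import jax.numpy as jnp


def check_varlen_inputs(query, cu_seqlens=None, initial_state=None):
    """Hypothetical pre-flight check mirroring the documented ValueError cases."""
    if cu_seqlens is None:
        return  # Fixed-length batches have no extra constraints here.

    # Packed inputs must use a batch size of 1.
    if query.shape[0] != 1:
        raise ValueError(
            f"Expected batch size 1 with cu_seqlens, got {query.shape[0]}."
        )

    # cu_seqlens holds num_seqs + 1 boundaries.
    num_seqs = cu_seqlens.shape[0] - 1
    # ASSUMPTION: the leading axis of initial_state indexes sequences.
    if initial_state is not None and initial_state.shape[0] != num_seqs:
        raise ValueError(
            f"Expected {num_seqs} initial states, got {initial_state.shape[0]}."
        )
```

Running a check like this first surfaces shape mistakes with a clear message before any GPU work is launched.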