ejkernel.modules.operations.gated_linear_attention

GLA (Gated Linear Attention) module with automatic optimization.

This module implements Gated Linear Attention, an efficient attention mechanism that uses gating to control information flow. GLA combines linear attention properties with learned gates to achieve both efficiency and expressiveness.

The gating mechanism lets the model dynamically control which information to retain or discard. This makes it particularly effective for long-range dependencies while keeping the cost linear in sequence length in certain configurations.

class ejkernel.modules.operations.gated_linear_attention.GLAttention

Bases: Kernel[GLAttentionConfig, Array]

Gated Linear Attention with custom optimization logic.

Implements gated linear attention combining the efficiency of linear attention with learnable gating mechanisms for better expressiveness. The gating controls information flow at both the query-key interaction and the state update levels.

Features:
  • Gated attention computation with g (query gates) and g_gamma (layer-wise gates)

  • Support for initial hidden states

  • Bidirectional and reverse sequence processing

  • Variable-length sequence handling via cumulative lengths

  • Multiple platform support (Triton/Pallas/CUDA/XLA)

The dual gating mechanism (g and g_gamma) allows fine-grained control:
  • g: Token-level gates applied to query representations

  • g_gamma: Layer-level gates controlling overall attention strength
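As a rough illustration of how the two gates could enter the recurrence, the sketch below implements a naive step-by-step gated linear attention in JAX. It is not ejkernel's kernel. The function name gla_reference, the log-space interpretation of g and g_gamma, their additive combination, the default 1/sqrt(head_dim) scale, and the restriction num_kv_heads == num_heads are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def gla_reference(q, k, v, g=None, g_gamma=None, scale=None, initial_state=None):
    """Naive recurrent gated linear attention (illustrative only).

    q, k, v, g: [batch, seq_len, num_heads, head_dim]; g_gamma: [batch, num_heads].
    ASSUMPTIONS (not confirmed by the ejkernel docs): gates are log-space
    decays, g acts on the key dimension of the state, g_gamma is added to g,
    and num_kv_heads == num_heads.
    """
    B, T, H, D = q.shape
    scale = D ** -0.5 if scale is None else scale  # assumed default scale
    g = jnp.zeros_like(q) if g is None else g  # log(1) = 0 means no decay
    if g_gamma is not None:
        g = g + g_gamma[:, None, :, None]  # broadcast per-head gate over time and dim
    state0 = jnp.zeros((B, H, D, D), q.dtype) if initial_state is None else initial_state

    def step(state, xs):
        q_t, k_t, v_t, g_t = xs  # each [B, H, D]
        # Decay the running key-value state, then add the current outer product.
        state = jnp.exp(g_t)[..., None] * state + k_t[..., :, None] * v_t[..., None, :]
        o_t = jnp.einsum("bhk,bhkv->bhv", q_t * scale, state)
        return state, o_t

    # scan iterates over the leading axis, so move time to the front.
    xs = tuple(jnp.swapaxes(x, 0, 1) for x in (q, k, v, g))
    final_state, out = jax.lax.scan(step, state0, xs)
    return jnp.swapaxes(out, 0, 1), final_state


keys = jax.random.split(jax.random.PRNGKey(0), 4)
q, k, v, g_raw = (jax.random.normal(kk, (2, 8, 4, 16)) for kk in keys)
g = jax.nn.log_sigmoid(g_raw)  # log-space gates, always <= 0

full, s_full = gla_reference(q, k, v, g)
# Process the same sequence in two chunks, carrying the state across.
first, s_mid = gla_reference(q[:, :5], k[:, :5], v[:, :5], g[:, :5])
second, s_end = gla_reference(q[:, 5:], k[:, 5:], v[:, 5:], g[:, 5:], initial_state=s_mid)
chunked = jnp.concatenate([first, second], axis=1)
```

In this sketch the state summarizes the entire prefix, so feeding the first chunk's final state into the second chunk reproduces the full-sequence output. That is the usage pattern that initial hidden states and a returned final state make possible.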

candidate_cfgs(inv: Invocation[GLAttentionConfig, Array])

Generate candidate configurations for autotuning.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of candidate configurations to benchmark during autotuning

Note

GLA performance depends on the gating mechanism effectiveness and sequence length. Candidates are chosen for typical configurations.
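To make "benchmark during autotuning" concrete, here is a minimal, generic sketch of picking the fastest candidate by timing. It does not reproduce ejkernel's autotuner. ToyConfig, its block_size field, pick_fastest, and the stand-in workload are all hypothetical names introduced for illustration.

```python
import time
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyConfig:
    """Hypothetical stand-in for a kernel config; not GLAttentionConfig."""

    block_size: int


def pick_fastest(candidates, run_kernel, repeats=3):
    """Time each candidate and return the one with the lowest best-of-N latency."""
    best_cfg, best_time = None, float("inf")
    for cfg in candidates:
        elapsed = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            # With real JAX kernels, call .block_until_ready() on the result
            # here so that asynchronous dispatch is included in the timing.
            run_kernel(cfg)
            elapsed = min(elapsed, time.perf_counter() - start)
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    return best_cfg


candidates = [ToyConfig(block_size=b) for b in (32, 64, 128)]
# Stand-in workload whose cost grows with block_size.
chosen = pick_fastest(candidates, run_kernel=lambda cfg: sum(range(cfg.block_size * 1000)))
```

Taking the minimum over a few repeats reduces the influence of one-off noise such as compilation or cache warm-up on the comparison.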

get_impl(cfg: GLAttentionConfig)

Get kernel implementation from registry.

Parameters

cfg – Configuration specifying platform and backend

Returns

Callable kernel implementation for gated linear attention

Raises

ValueError – If no matching implementation is found

heuristic_cfg(inv: Invocation[GLAttentionConfig, Array]) → GLAttentionConfig

Provide default configuration with block sizes.

Parameters

inv – Invocation object containing arguments and metadata

Returns

Default configuration with conservative block sizes suitable for typical gated linear attention workloads

run(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], g: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, g_gamma: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads'] | None = None, softmax_scale: float | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim'] | None = None, reverse: bool = False, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, return_state: bool = False, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: GLAttentionConfig) → jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim']]

Execute gated linear attention computation.

Parameters
  • query – Query tensor [batch, seq_len, num_heads, head_dim]

  • key – Key tensor [batch, seq_len, num_kv_heads, head_dim]

  • value – Value tensor [batch, seq_len, num_kv_heads, head_dim]

  • g – Token-level gating tensor [batch, seq_len, num_heads, head_dim]

  • g_gamma – Layer-level gating parameter [batch, num_heads]

  • softmax_scale – Optional scaling factor for attention scores

  • initial_state – Initial hidden state [batch, num_heads, head_dim, head_dim]

  • reverse – If True, process sequence in reverse order

  • cu_seqlens – Cumulative sequence lengths for variable-length sequences

  • return_state – If True, return tuple (output, final_state) instead of just output

  • platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Kernel configuration object

Returns

If return_state=False: gated attention output [batch, seq_len, num_heads, head_dim].

If return_state=True: tuple of (output, final_state), where final_state is [batch, num_heads, head_dim, head_dim].

Note

Both g and g_gamma are optional. When provided, they enable more expressive attention patterns through learned gating.
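The cu_seqlens argument is annotated with shape num_seqs_plus_one, which suggests the usual cumulative-offset convention. The sketch below builds such an array from per-sequence lengths. The packed layout (sequences concatenated back to back along the sequence axis) is an assumption based on that common convention, and the variable names are illustrative.

```python
import jax.numpy as jnp

# Three sequences of lengths 3, 5 and 2, packed back to back (assumed layout).
seq_lens = jnp.array([3, 5, 2], dtype=jnp.int32)

# Prepend 0 so that entry i is the start offset of sequence i and the last
# entry is the total token count: shape [num_seqs + 1].
cu_seqlens = jnp.concatenate(
    [jnp.zeros((1,), dtype=jnp.int32), jnp.cumsum(seq_lens, dtype=jnp.int32)]
)

# Recover the token range of sequence 1 from the packed axis.
tokens = jnp.arange(int(cu_seqlens[-1]))
start, end = int(cu_seqlens[1]), int(cu_seqlens[2])
second_seq = tokens[start:end]
```

Under this convention, sequence i occupies positions cu_seqlens[i] up to but not including cu_seqlens[i + 1], so no padding tokens are needed.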

ejkernel.modules.operations.gated_linear_attention.gla_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], g: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, g_gamma: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads'] | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim'] | None = None, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, /, *, softmax_scale: float | None = None, reverse: bool = False, return_state: bool = False, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.GLAttentionConfig | None = None) → jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim']]

Execute gated linear attention with automatic optimization.

Convenience function that uses a default executor and GLA module.

Parameters
  • query – Query tensor [batch, seq_len, num_heads, head_dim]

  • key – Key tensor [batch, seq_len, num_kv_heads, head_dim]

  • value – Value tensor [batch, seq_len, num_kv_heads, head_dim]

  • g – Gating tensor [batch, seq_len, num_heads, head_dim]

  • g_gamma – Gating gamma [batch, num_heads]

  • softmax_scale – Scaling factor for attention

  • initial_state – Initial state for recurrent computation

  • reverse – Whether to process sequence in reverse

  • cu_seqlens – Cumulative sequence lengths for variable-length sequences

  • return_state – If True, return tuple (output, final_state) instead of just output

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

Returns

If return_state=False: attention output with the same shape as query.

If return_state=True: tuple of (output, final_state).

Example

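>>> # Basic usage
>>> out = gla_attention(query, key, value)
>>>
>>> # With gates (g and g_gamma are positional-only in the signature above)
>>> out = gla_attention(query, key, value, gates, gamma)
>>>
>>> # Variable-length sequences (cu_seqlens is the seventh positional argument)
>>> out = gla_attention(query, key, value, None, None, None, cu_seqs)
>>>
>>> # Request a specific platform
>>> out = gla_attention(..., platform="triton")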