ejkernel.modules.operations.recurrent

Recurrent Attention module with automatic optimization.

This module implements recurrent-style attention mechanisms that maintain and update a hidden state across sequence positions. Unlike standard attention, which computes all positions in parallel, recurrent attention processes the sequence step by step and carries state from one position to the next.

Features:
  • Stateful attention with initial_state support

  • Separate gating for queries (g), keys (gk), and values (gv)

  • Layer-wise gating control via g_gamma

  • Bidirectional processing support (forward and reverse)

  • Variable-length sequence handling

This is particularly useful for:
  • Linear-time attention mechanisms

  • Models requiring sequential dependency modeling

  • Architectures with explicit state propagation

  • Efficient inference with incremental state updates

class ejkernel.modules.operations.recurrent.RecurrentAttention

Bases: Kernel[RecurrentAttentionConfig, Array]

Recurrent Attention with custom optimization logic.

Implements attention with recurrent state updates, enabling linear-time complexity for certain attention patterns. The mechanism maintains a hidden state that is updated at each sequence position.
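The state update described above can be illustrated with a small NumPy reference for a single head. This is a minimal sketch of the ungated linear-attention recurrence, not the kernel ejkernel dispatches to: the function names are hypothetical, gating and softmax_scale are left out, and the state layout [head_dim, head_dim] is assumed from the initial_state shape in the signature. It also checks that the step-by-step form matches the causal masked (quadratic) form.

```python
import numpy as np

def recurrent_linear_attention(q, k, v, initial_state=None):
    """Hypothetical single-head reference: q, k are [seq_len, d_k], v is [seq_len, d_v]."""
    seq_len, d_k = k.shape
    d_v = v.shape[1]
    if initial_state is None:
        state = np.zeros((d_k, d_v))
    else:
        state = np.array(initial_state, dtype=float)
    out = np.empty((seq_len, d_v))
    for t in range(seq_len):
        state = state + np.outer(k[t], v[t])  # rank-1 update of the hidden state
        out[t] = q[t] @ state                 # read the state with the current query
    return out, state

def causal_quadratic_attention(q, k, v):
    """Same result computed with an explicit lower-triangular score matrix."""
    scores = np.tril(q @ k.T)  # position t only sees positions s <= t
    return scores @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((6, 4)) for _ in range(3))
out_rec, final_state = recurrent_linear_attention(q, k, v)
out_quad = causal_quadratic_attention(q, k, v)
```

The loop costs time linear in seq_len with a fixed-size state, whereas the masked form builds a seq_len by seq_len score matrix.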

Features:
  • Stateful computation with hidden state propagation

  • Multiple gating mechanisms (g, gk, gv, g_gamma)

  • Forward and reverse processing modes

  • Support for initial states

  • Variable-length sequence handling via cu_seqlens

  • Multiple platform support (Triton/Pallas/CUDA/XLA)

The gating mechanisms provide fine-grained control:
  • g: Query-level gates

  • gk: Key-level gates

  • gv: Value-level gates

  • g_gamma: Layer-level gates

candidate_cfgs(inv: Invocation[RecurrentAttentionConfig, Array])

Generate candidate configurations for autotuning.

Parameters

inv – Invocation object containing arguments and metadata

Returns

List of candidate configurations to benchmark during autotuning

Note

Recurrent attention performance is sensitive to state update patterns. Candidates are chosen to balance sequential processing efficiency.

get_impl(cfg: RecurrentAttentionConfig)

Get kernel implementation from registry.

Parameters

cfg – Configuration specifying platform and backend

Returns

Callable kernel implementation for recurrent attention

Raises

ValueError – If no matching implementation is found

heuristic_cfg(inv: Invocation[RecurrentAttentionConfig, Array]) -> RecurrentAttentionConfig

Provide default configuration with block sizes.

Parameters

inv – Invocation object containing arguments and metadata

Returns

Default configuration with conservative block sizes suitable for typical recurrent attention workloads with stateful computation

run(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], g: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, g_gamma: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads'] | None = None, gk: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, gv: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, softmax_scale: float | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim'] | None = None, reverse: bool = False, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, return_state: bool = False, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: RecurrentAttentionConfig) -> jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim']]

Execute recurrent attention with stateful computation.

Parameters
  • query – Query tensor [batch, seq_len, num_heads, head_dim]

  • key – Key tensor [batch, seq_len, num_kv_heads, head_dim]

  • value – Value tensor [batch, seq_len, num_kv_heads, head_dim]

  • g – Query-level gating tensor [batch, seq_len, num_heads, head_dim]

  • g_gamma – Layer-level gating parameter [batch, num_heads]

  • gk – Key-level gating tensor [batch, seq_len, num_heads, head_dim]

  • gv – Value-level gating tensor [batch, seq_len, num_heads, head_dim]

  • softmax_scale – Optional scaling factor for attention scores

  • initial_state – Initial hidden state [batch, num_heads, head_dim, head_dim]

  • reverse – If True, process sequence in reverse order

  • cu_seqlens – Cumulative sequence lengths for variable-length sequences

  • return_state – If True, return tuple (output, final_state) instead of just output

  • platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

  • cfg – Kernel configuration object

Returns

If return_state=False: attention output [batch, seq_len, num_heads, head_dim].

If return_state=True: tuple of (output, final_state), where final_state is [batch, num_heads, head_dim, head_dim].

Note

All gating parameters (g, gk, gv, g_gamma) are optional. When provided, they enable more sophisticated gated recurrent mechanisms.
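The interaction between initial_state and return_state can be sketched with a NumPy reference: processing a sequence in two chunks and passing the first chunk's final state into the second should reproduce the single-pass result. This is an illustration under assumptions, not ejkernel code: `linear_attention_scan` is a hypothetical ungated, single-head stand-in for the kernel, and the chunk boundary of 5 is arbitrary.

```python
import numpy as np

def linear_attention_scan(q, k, v, initial_state=None):
    """Hypothetical ungated reference; returns (output, final_state)."""
    d_k, d_v = k.shape[1], v.shape[1]
    if initial_state is None:
        state = np.zeros((d_k, d_v))
    else:
        state = np.array(initial_state, dtype=float)
    out = np.empty((q.shape[0], d_v))
    for t in range(q.shape[0]):
        state = state + np.outer(k[t], v[t])  # accumulate key/value outer product
        out[t] = q[t] @ state
    return out, state

rng = np.random.default_rng(1)
q, k, v = (rng.standard_normal((8, 4)) for _ in range(3))

# Single pass over the whole sequence.
full_out, full_state = linear_attention_scan(q, k, v)

# Two passes: the first chunk's final state seeds the second chunk.
out_a, state_a = linear_attention_scan(q[:5], k[:5], v[:5])
out_b, state_b = linear_attention_scan(q[5:], k[5:], v[5:], initial_state=state_a)
chunked_out = np.concatenate([out_a, out_b])
```

The same pattern is what makes incremental decoding cheap: each new token only needs the previous state, not the earlier keys and values.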

ejkernel.modules.operations.recurrent.recurrent_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], g: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, g_gamma: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads'] | None = None, gk: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, gv: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim'] | None = None, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, /, *, softmax_scale: float | None = None, reverse: bool = False, return_state: bool = False, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.RecurrentAttentionConfig | None = None) -> jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim']]

Execute recurrent attention with automatic optimization.

Recurrent attention processes sequences with stateful computation, maintaining hidden states across timesteps for efficient sequential processing.

Parameters
  • query – Query tensor [batch, seq_len, num_heads, head_dim]

  • key – Key tensor [batch, seq_len, num_kv_heads, head_dim]

  • value – Value tensor [batch, seq_len, num_kv_heads, head_dim]

  • g – Gating tensor for query [batch, seq_len, num_heads, head_dim]

  • g_gamma – Layer-level gating parameter [batch, num_heads]

  • gk – Gating tensor for keys [batch, seq_len, num_heads, head_dim]

  • gv – Gating tensor for values [batch, seq_len, num_heads, head_dim]

  • softmax_scale – Optional scaling factor for attention scores

  • initial_state – Initial hidden state [batch, num_heads, head_dim, head_dim]

  • reverse – If True, process the sequence in reverse order

  • cu_seqlens – Cumulative sequence lengths for variable-length sequences

  • return_state – If True, return tuple (output, final_state) instead of just output

  • platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)

Returns

If return_state=False: attention output with the same shape as query.

If return_state=True: tuple of (output, final_state).

Example

>>> # Basic call without gating
>>> out = recurrent_attention(query, key, value)
>>>
>>> # With query, key, and value gates
>>> out = recurrent_attention(query, key, value, g=gates, gk=key_gates, gv=value_gates)
>>>
>>> # Force a specific platform
>>> out = recurrent_attention(query, key, value, platform="xla")
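For the cu_seqlens parameter, the following sketch shows one way to build cumulative sequence lengths from per-sequence lengths. It assumes the common packed variable-length convention, where sequences are concatenated along the sequence axis and the offsets start at 0, which is consistent with the 'num_seqs_plus_one' shape in the signature but is not confirmed by this page. The helper names `build_cu_seqlens` and `sequence_slices` are hypothetical.

```python
from itertools import accumulate

def build_cu_seqlens(seq_lens):
    """Hypothetical helper: prefix sums of the lengths, starting at 0."""
    return [0, *accumulate(seq_lens)]

def sequence_slices(cu_seqlens):
    """Recover each sequence's [start, end) range in the packed token axis."""
    return [slice(start, end) for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]

seq_lens = [3, 5, 2]                     # three sequences packed back to back
cu_seqlens = build_cu_seqlens(seq_lens)  # one more entry than there are sequences
slices = sequence_slices(cu_seqlens)
```

Before passing the offsets to a kernel, they would typically be converted to an integer JAX array, for example with jnp.asarray(cu_seqlens, dtype=jnp.int32).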