ejkernel.modules.operations.recurrent
Recurrent Attention module with automatic optimization.
This module implements recurrent-style attention mechanisms that maintain and update a hidden state across sequence positions. Unlike standard attention, which computes all positions in parallel from the full set of keys and values, recurrent attention walks the sequence position by position and carries the hidden state forward.
- Features:
Stateful attention with initial_state support
Separate gating for queries (g), keys (gk), and values (gv)
Layer-wise gating control via g_gamma
Bidirectional processing support (forward and reverse)
Variable-length sequence handling
- This is particularly useful for:
Linear-time attention mechanisms
Models requiring sequential dependency modeling
Architectures with explicit state propagation
Efficient inference with incremental state updates
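The stateful computation described above can be sketched in plain JAX. This is a minimal reference under stated assumptions, not the ejkernel kernel: the helper name `recurrent_attention_reference` is hypothetical, the update rule is the ungated linear-attention form (state plus the outer product of key and value), `num_kv_heads` is taken equal to `num_heads`, and the default scale of `head_dim ** -0.5` is an assumption.

```python
import jax
import jax.numpy as jnp


def recurrent_attention_reference(q, k, v, scale=None, reverse=False):
    """Ungated reference recurrence (hypothetical helper, not the library kernel).

    q, k, v: [batch, seq_len, num_heads, head_dim]
    Returns (output, final_state) with final_state [batch, num_heads, head_dim, head_dim].
    """
    b, _, h, d = q.shape
    scale = d ** -0.5 if scale is None else scale  # assumed default scale

    def step(state, inputs):
        q_t, k_t, v_t = inputs  # each [batch, num_heads, head_dim]
        # Fold the current key/value pair into the running state.
        state = state + jnp.einsum("bhk,bhv->bhkv", k_t, v_t)
        # Read the state with the current (scaled) query.
        out_t = jnp.einsum("bhk,bhkv->bhv", q_t * scale, state)
        return state, out_t

    init = jnp.zeros((b, h, d, v.shape[-1]), q.dtype)
    # lax.scan iterates over the leading axis, so move time to the front.
    xs = tuple(jnp.moveaxis(x, 1, 0) for x in (q, k, v))
    final_state, out = jax.lax.scan(step, init, xs, reverse=reverse)
    return jnp.moveaxis(out, 0, 1), final_state
```

Each step costs a fixed amount of work per position, which is where the linear-time behaviour comes from. With `reverse=True` the scan runs from the last position to the first, so each output sees only the current and later positions.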
- class ejkernel.modules.operations.recurrent.RecurrentAttention
Bases: Kernel[RecurrentAttentionConfig, Array]
Recurrent Attention with custom optimization logic.
Implements attention with recurrent state updates, enabling linear-time complexity for certain attention patterns. The mechanism maintains a hidden state that is updated at each sequence position.
- Features:
Stateful computation with hidden state propagation
Multiple gating mechanisms (g, gk, gv, g_gamma)
Forward and reverse processing modes
Support for initial states
Variable-length sequence handling via cu_seqlens
Multiple platform support (Triton/Pallas/CUDA/XLA)
- The gating mechanisms provide fine-grained control:
g: Query-level gates
gk: Key-level gates
gv: Value-level gates
g_gamma: Layer-level gates
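This page does not spell out how the gates enter the state update. As a rough illustration only, the sketch below assumes the convention common in gated linear attention: `gk` and `gv` are log-space decay factors (values at or below zero) that shrink the previous state along the key and value dimensions before the new key/value pair is added. The helper `gated_recurrence` is hypothetical, handles a single head, and leaves out `g` and `g_gamma` because their exact role is not documented here.

```python
import jax
import jax.numpy as jnp


def gated_recurrence(q, k, v, gk=None, gv=None):
    """Single-head gated recurrence (assumed semantics, not the library kernel).

    q, k, gk: [seq_len, d_k]; v, gv: [seq_len, d_v]. Gates are log-decays.
    """
    seq_len, d_k = q.shape
    d_v = v.shape[-1]
    # A zero log-gate means exp(0) = 1, i.e. no decay.
    gk = jnp.zeros((seq_len, d_k), q.dtype) if gk is None else gk
    gv = jnp.zeros((seq_len, d_v), v.dtype) if gv is None else gv

    def step(state, inputs):
        q_t, k_t, v_t, gk_t, gv_t = inputs
        # Decay rows by the key gate and columns by the value gate.
        state = state * jnp.exp(gk_t)[:, None] * jnp.exp(gv_t)[None, :]
        state = state + jnp.outer(k_t, v_t)
        return state, q_t @ state

    init = jnp.zeros((d_k, d_v), q.dtype)
    final_state, out = jax.lax.scan(step, init, (q, k, v, gk, gv))
    return out, final_state
```

Under these assumptions, zero gates reduce to the ungated recurrence, while strongly negative gates make the state forget everything except the current position.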
- candidate_cfgs(inv: Invocation[RecurrentAttentionConfig, Array])
Generate candidate configurations for autotuning.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
List of candidate configurations to benchmark during autotuning
Note
Recurrent attention performance is sensitive to state update patterns. Candidates are chosen to balance sequential processing efficiency.
- get_impl(cfg: RecurrentAttentionConfig)
Get kernel implementation from registry.
- Parameters
cfg – Configuration specifying platform and backend
- Returns
Callable kernel implementation for recurrent attention
- Raises
ValueError – If no matching implementation is found
- heuristic_cfg(inv: Invocation[RecurrentAttentionConfig, Array]) → RecurrentAttentionConfig
Provide default configuration with block sizes.
- Parameters
inv – Invocation object containing arguments and metadata
- Returns
Default configuration with conservative block sizes suitable for typical recurrent attention workloads with stateful computation
- run(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], g: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, g_gamma: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads'] | None = None, gk: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, gv: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, softmax_scale: float | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim'] | None = None, reverse: bool = False, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, return_state: bool = False, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, *, cfg: RecurrentAttentionConfig) → jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim']]
Execute recurrent attention with stateful computation.
- Parameters
query – Query tensor [batch, seq_len, num_heads, head_dim]
key – Key tensor [batch, seq_len, num_kv_heads, head_dim]
value – Value tensor [batch, seq_len, num_kv_heads, head_dim]
g – Query-level gating tensor [batch, seq_len, num_heads, head_dim]
g_gamma – Layer-level gating parameter [batch, num_heads]
gk – Key-level gating tensor [batch, seq_len, num_heads, head_dim]
gv – Value-level gating tensor [batch, seq_len, num_heads, head_dim]
softmax_scale – Optional scaling factor for attention scores
initial_state – Initial hidden state [batch, num_heads, head_dim, head_dim]
reverse – If True, process sequence in reverse order
cu_seqlens – Cumulative sequence lengths for variable-length sequences
return_state – If True, return tuple (output, final_state) instead of just output
platform – Optional platform override (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
cfg – Kernel configuration object
- Returns
If return_state=False: attention output [batch, seq_len, num_heads, head_dim]
If return_state=True: tuple of (output, final_state), where final_state is [batch, num_heads, head_dim, head_dim]
Note
All gating parameters (g, gk, gv, g_gamma) are optional. When provided, they enable more sophisticated gated recurrent mechanisms.
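The `cu_seqlens` argument has shape `num_seqs_plus_one`, which matches the usual cumulative-offset layout for sequences packed back to back along the sequence axis. A small sketch of building such an array from per-sequence lengths follows. The packing convention and the variable names are assumptions, since this page does not describe the expected input layout in detail.

```python
import jax.numpy as jnp

# Three sequences of different lengths, packed one after another (assumed layout).
lengths = jnp.array([3, 5, 2], dtype=jnp.int32)

# Prepend 0 to the running sum so that sequence i occupies
# packed positions cu_seqlens[i] : cu_seqlens[i + 1].
cu_seqlens = jnp.concatenate(
    [jnp.zeros(1, jnp.int32), jnp.cumsum(lengths, dtype=jnp.int32)]
)
# cu_seqlens -> [0, 3, 8, 10]

# Map every packed position back to the sequence it belongs to.
total = int(cu_seqlens[-1])
positions = jnp.arange(total)
segment_ids = jnp.searchsorted(cu_seqlens, positions, side="right") - 1
# segment_ids -> [0, 0, 0, 1, 1, 1, 1, 1, 2, 2]

# Positions where a new sequence begins; a recurrent kernel would
# presumably restart its state at these boundaries.
is_start = positions == cu_seqlens[segment_ids]
```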
- ejkernel.modules.operations.recurrent.recurrent_attention(query: Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], key: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], value: Float[jaxlib._jax.Array, 'batch seq_len num_kv_heads head_dim'], g: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, g_gamma: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads'] | None = None, gk: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, gv: jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | None = None, initial_state: jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim'] | None = None, cu_seqlens: jaxtyping.Int[jaxlib._jax.Array, 'num_seqs_plus_one'] | None = None, /, *, softmax_scale: float | None = None, reverse: bool = False, return_state: bool = False, platform: Optional[Literal['triton', 'pallas', 'cuda', 'xla', 'auto']] = None, cfg: ejkernel.modules.operations.configs.RecurrentAttentionConfig | None = None) → jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'] | tuple[jaxtyping.Float[jaxlib._jax.Array, 'batch seq_len num_heads head_dim'], jaxtyping.Float[jaxlib._jax.Array, 'batch num_heads head_dim head_dim']]
Execute recurrent attention with automatic optimization.
Recurrent attention processes sequences with stateful computation, maintaining hidden states across timesteps for efficient sequential processing.
- Parameters
query – Query tensor [batch, seq_len, num_heads, head_dim]
key – Key tensor [batch, seq_len, num_kv_heads, head_dim]
value – Value tensor [batch, seq_len, num_kv_heads, head_dim]
g – Query-level gating tensor [batch, seq_len, num_heads, head_dim]
g_gamma – Layer-level gating parameter [batch, num_heads]
gk – Key-level gating tensor [batch, seq_len, num_heads, head_dim]
gv – Value-level gating tensor [batch, seq_len, num_heads, head_dim]
softmax_scale – Optional scaling factor for attention scores
initial_state – Initial hidden state [batch, num_heads, head_dim, head_dim]
reverse – If True, process the sequence in reverse order
cu_seqlens – Cumulative sequence lengths for variable-length sequences
return_state – If True, return tuple (output, final_state) instead of just output
platform – Specific platform to use (“triton”, “pallas”, “cuda”, “xla”, or “auto”)
- Returns
If return_state=False: attention output with the same shape as query
If return_state=True: tuple of (output, final_state)
Example
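>>> # Basic call with default settings
>>> out = recurrent_attention(query, key, value)
>>>
>>> # With gates, passed positionally in the order g, g_gamma, gk, gv
>>> # (the signature above marks these parameters as positional-only)
>>> out = recurrent_attention(query, key, value, gates, None, key_gates, value_gates)
>>>
>>> # Force a specific platform
>>> out = recurrent_attention(query, key, value, platform="xla")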
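The `initial_state` and `return_state` parameters point to an incremental pattern: keep the final state from one call and feed it into the next. The sketch below illustrates the idea with a plain-JAX stand-in rather than ejkernel itself. `decode_step` is a hypothetical single-head helper using the ungated update, which is an assumption about the underlying recurrence.

```python
import jax
import jax.numpy as jnp


def decode_step(state, q_t, k_t, v_t):
    """Advance one head by one token (hypothetical helper).

    state: [head_dim, head_dim]; q_t, k_t, v_t: [head_dim].
    """
    state = state + jnp.outer(k_t, v_t)  # absorb the new token
    return q_t @ state, state            # output for this token, updated state


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (6, 4))
k = jax.random.normal(kk, (6, 4))
v = jax.random.normal(kv, (6, 4))

# Token-by-token processing: only the fixed-size state is carried forward.
state = jnp.zeros((4, 4))
outputs = []
for t in range(q.shape[0]):
    out_t, state = decode_step(state, q[t], k[t], v[t])
    outputs.append(out_t)
incremental = jnp.stack(outputs)

# Processing the whole sequence at once gives the same result.
full = jnp.tril(q @ k.T) @ v
```

With the library call, the equivalent would presumably be to pass the previous final state as `initial_state` and set `return_state=True` to obtain the next one. The memory carried between calls stays at `head_dim × head_dim` per head regardless of how many tokens have been processed.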