ejkernel.kernels._triton.flash_attention._triton_impl_bwd
- ejkernel.kernels._triton.flash_attention._triton_impl_bwd.config_prune_kernel(configs: list[triton.runtime.autotuner.Config], named_args: dict[str, Any], **kwargs: Any) → list[triton.runtime.autotuner.Config]
Prune autotuning configurations for the backward-pass kernel.
Filters out configurations whose block dimensions exceed the sequence lengths. Falls back to a small set of default configurations if every configuration is pruned.
- Parameters
configs – List of Triton autotuning configurations
named_args – Dictionary of kernel arguments, including QSeq and KSeq
**kwargs – Additional arguments (unused)
- Returns
Valid configurations for the given problem size
- Return type
list[Config]