ejkernel.kernels._triton.flash_attention._triton_impl_bwd

ejkernel.kernels._triton.flash_attention._triton_impl_bwd.config_prune_kernel(configs: list[triton.runtime.autotuner.Config], named_args: dict[str, Any], **kwargs: Any) → list[triton.runtime.autotuner.Config]

Prune autotuning configurations for the backward-pass kernel.

Filters out configurations whose block dimensions exceed the sequence lengths given by QSeq and KSeq in named_args. If every configuration is pruned, falls back to a small set of default configs.

Parameters
  • configs – List of Triton autotuning configurations

  • named_args – Dictionary with kernel arguments including QSeq and KSeq

  • **kwargs – Additional unused arguments

Returns

Valid configurations for the given problem size

Return type

list[Config]
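
Example

The pruning rule described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the `Config` class is a stand-in for `triton.Config` so the snippet runs without Triton installed, the meta-parameter names `BLOCK_M` and `BLOCK_N` are hypothetical, and the fallback values are placeholders, since the source does not state the actual defaults.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Config:
    """Stand-in for triton.Config: meta-parameters are stored in `kwargs`."""

    kwargs: dict[str, int]
    num_warps: int = 4
    num_stages: int = 2


def prune_sketch(
    configs: list[Config], named_args: dict[str, Any], **kwargs: Any
) -> list[Config]:
    """Keep configs whose blocks fit the problem; otherwise use a small default."""
    q_seq = named_args["QSeq"]
    k_seq = named_args["KSeq"]

    # Assumption: BLOCK_M tiles the query axis and BLOCK_N tiles the key axis.
    kept = [
        c
        for c in configs
        if c.kwargs["BLOCK_M"] <= q_seq and c.kwargs["BLOCK_N"] <= k_seq
    ]
    if kept:
        return kept

    # Hypothetical fallback; the real default configs are not documented here.
    return [Config({"BLOCK_M": 16, "BLOCK_N": 16}, num_warps=2, num_stages=1)]


candidates = [
    Config({"BLOCK_M": 128, "BLOCK_N": 64}),
    Config({"BLOCK_M": 64, "BLOCK_N": 64}),
    Config({"BLOCK_M": 32, "BLOCK_N": 32}),
]

# Sequence length 64: the 128-wide query block is dropped, two configs remain.
survivors = prune_sketch(candidates, {"QSeq": 64, "KSeq": 64})

# Sequence length 8: every candidate is too large, so the fallback is returned.
fallback = prune_sketch(candidates, {"QSeq": 8, "KSeq": 8})
```

The `(configs, named_args, **kwargs)` signature has the shape that Triton's `early_config_prune` hook expects when passed through `prune_configs_by` in `triton.autotune`. Whether this module registers `config_prune_kernel` that way is an assumption, not something this page states.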