diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 375abb24d131..6bf6af95f4f9 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -35,6 +35,7 @@
     is_aiter_available,
     is_aiter_version,
     is_flash_attn_3_available,
+    is_flash_attn_4_available,
     is_flash_attn_available,
     is_flash_attn_version,
     is_kernels_available,
@@ -67,6 +68,7 @@
 
 _CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
 _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
+_CAN_USE_FLASH_ATTN_4 = is_flash_attn_4_available()
 _CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
 _CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
 _CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
@@ -108,6 +110,16 @@
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
 
+if _CAN_USE_FLASH_ATTN_4:
+    try:
+        from flash_attn.cute import flash_attn_func as flash_attn_4_func
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"flash_attn_4 failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_FLASH_ATTN_4 = False
+        flash_attn_4_func = None
+else:
+    flash_attn_4_func = None
+
 if _CAN_USE_AITER_ATTN:
     try:
         from aiter import flash_attn_func as aiter_flash_attn_func
@@ -230,6 +242,7 @@ class AttentionBackendName(str, Enum):
     FLASH_VARLEN = "flash_varlen"
     FLASH_VARLEN_HUB = "flash_varlen_hub"
     FLASH_4_HUB = "flash_4_hub"
+    _FLASH_4 = "_flash_4"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
@@ -521,6 +534,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )
 
+    elif backend == AttentionBackendName._FLASH_4:
+        if not _CAN_USE_FLASH_ATTN_4:
+            raise RuntimeError(
+                f"Flash Attention 4 backend '{backend.value}' is not usable because the flash_attn.cute module is not available. Please install flash-attn>=4.0."
+            )
+
     elif backend in [
         AttentionBackendName.FLASH_HUB,
         AttentionBackendName.FLASH_VARLEN_HUB,
@@ -2719,6 +2738,36 @@ def _flash_attention_4_hub(
     return out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_4,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _flash_attention_4(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 4.")
+
+    out = flash_attn_4_func(
+        q=query,
+        k=key,
+        v=value,
+        softmax_scale=scale,
+        causal=is_causal,
+    )
+    if isinstance(out, tuple):
+        return (out[0], out[1]) if return_lse else out[0]
+    return out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_VARLEN_3,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 23d7ac7c6c2d..4cf64b0d69c1 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -74,6 +74,7 @@
     is_bs4_available,
     is_cosmos_guardrail_available,
     is_flash_attn_3_available,
+    is_flash_attn_4_available,
     is_flash_attn_available,
     is_flash_attn_version,
     is_flax_available,
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index 551fa358a28d..07a428b5a647 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -227,6 +227,15 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> tuple[b
 _sageattention_available, _sageattention_version = _is_package_available("sageattention")
 _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
 _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
+
+# FA4 uses flash_attn.cute module (part of flash-attn >= 4.0)
+_flash_attn_4_available = False
+try:
+    from flash_attn.cute import flash_attn_func as _fa4_test
+    _flash_attn_4_available = True
+    del _fa4_test
+except (ImportError, ModuleNotFoundError):
+    pass
 _aiter_available, _aiter_version = _is_package_available("aiter", get_dist_name=True)
 _kornia_available, _kornia_version = _is_package_available("kornia")
 _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
@@ -409,6 +418,10 @@ def is_flash_attn_3_available():
     return _flash_attn_3_available
 
 
+def is_flash_attn_4_available():
+    return _flash_attn_4_available
+
+
 def is_aiter_available():
     return _aiter_available
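
For anyone who wants to smoke-test the new backend locally, a minimal sketch is below. It is not part of the patch: it calls the registered `_flash_attention_4` wrapper directly with dummy tensors and assumes a CUDA device plus a flash-attn build new enough that `flash_attn.cute` imports. The pipeline-level selection call mentioned in the last comment (`set_attention_backend`) is an assumption about the existing dispatcher API and is unchanged by this diff.

```python
# Illustrative smoke test only (not part of the patch). Assumes CUDA plus
# flash-attn >= 4.0 so that `flash_attn.cute` imports successfully.
import torch

from diffusers.models.attention_dispatch import _flash_attention_4

# flash-attn expects (batch, seq_len, num_heads, head_dim) tensors in bf16/fp16.
query = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
key = torch.randn_like(query)
value = torch.randn_like(query)

# Calling the function directly bypasses the registry's constraint checks; it
# exercises only the flash_attn_4_func wrapper added in this diff.
out = _flash_attention_4(query=query, key=key, value=value, is_causal=False)
print(out.shape)  # expected: torch.Size([2, 1024, 8, 64])

# In normal use the backend would be selected by its enum value "_flash_4",
# e.g. model.set_attention_backend("_flash_4"); that selection API is assumed
# here and is not touched by this diff.
```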