49 changes: 49 additions & 0 deletions src/diffusers/models/attention_dispatch.py
@@ -35,6 +35,7 @@
    is_aiter_available,
    is_aiter_version,
    is_flash_attn_3_available,
    is_flash_attn_4_available,
    is_flash_attn_available,
    is_flash_attn_version,
    is_kernels_available,
@@ -67,6 +68,7 @@

_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_FLASH_ATTN_4 = is_flash_attn_4_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
@@ -108,6 +110,16 @@
    flash_attn_3_func = None
    flash_attn_3_varlen_func = None

if _CAN_USE_FLASH_ATTN_4:
    try:
        from flash_attn.cute import flash_attn_func as flash_attn_4_func
    except (ImportError, OSError, RuntimeError) as e:
        logger.warning(f"flash_attn_4 failed to import: {e}. Falling back to native attention.")
        _CAN_USE_FLASH_ATTN_4 = False
        flash_attn_4_func = None
else:
    flash_attn_4_func = None

if _CAN_USE_AITER_ATTN:
    try:
        from aiter import flash_attn_func as aiter_flash_attn_func
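For reference, the guarded import above targets the following call surface. A minimal sketch of invoking flash_attn_4_func directly, assuming an FA4-capable GPU; the tensor layout and the possibility of a tuple return are assumptions based on the usual flash-attn conventions, not verified here:

import torch
from flash_attn.cute import flash_attn_func as flash_attn_4_func

# Hypothetical smoke test; shapes follow flash-attn's usual
# (batch, seq_len, num_heads, head_dim) layout, and the kwargs mirror
# the ones this PR passes. FA4 may return either a tensor or an
# (out, lse) tuple; the dispatch wrapper added below handles both.
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
out = flash_attn_4_func(q=q, k=k, v=v, softmax_scale=None, causal=True)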
@@ -230,6 +242,7 @@ class AttentionBackendName(str, Enum):
FLASH_VARLEN = "flash_varlen"
FLASH_VARLEN_HUB = "flash_varlen_hub"
FLASH_4_HUB = "flash_4_hub"
_FLASH_4 = "_flash_4"
_FLASH_3 = "_flash_3"
_FLASH_VARLEN_3 = "_flash_varlen_3"
_FLASH_3_HUB = "_flash_3_hub"
@@ -521,6 +534,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
)

    elif backend == AttentionBackendName._FLASH_4:
        if not _CAN_USE_FLASH_ATTN_4:
            raise RuntimeError(
                f"Flash Attention 4 backend '{backend.value}' is not usable because the flash_attn.cute module is not available. Please install flash-attn>=4.0."
            )

    elif backend in [
        AttentionBackendName.FLASH_HUB,
        AttentionBackendName.FLASH_VARLEN_HUB,
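Once this check passes, the new backend can be enabled like any other. A hedged sketch, assuming diffusers' set_attention_backend helper on the model; the pipeline and checkpoint are placeholders:

import torch
from diffusers import FluxPipeline  # any transformer-based pipeline works

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
# "_flash_4" is the AttentionBackendName._FLASH_4 value registered in this PR.
pipe.transformer.set_attention_backend("_flash_4")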
@@ -2719,6 +2738,36 @@ def _flash_attention_4_hub(
    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName._FLASH_4,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_context_parallel=False,
)
def _flash_attention_4(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    scale: float | None = None,
    is_causal: bool = False,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig | None" = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for flash-attn 4.")

    out = flash_attn_4_func(
        q=query,
        k=key,
        v=value,
        softmax_scale=scale,
        causal=is_causal,
    )
    if isinstance(out, tuple):
        return (out[0], out[1]) if return_lse else out[0]
    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName._FLASH_VARLEN_3,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
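A quick way to exercise the _flash_attention_4 wrapper added above in isolation. A hypothetical smoke test, assuming CUDA and flash-attn 4 are installed; the function is private, so this is for local verification only:

import torch
from diffusers.models.attention_dispatch import _flash_attention_4

q = torch.randn(1, 256, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 256, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 256, 8, 64, device="cuda", dtype=torch.bfloat16)

out = _flash_attention_4(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 256, 8, 64])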
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -74,6 +74,7 @@
    is_bs4_available,
    is_cosmos_guardrail_available,
    is_flash_attn_3_available,
    is_flash_attn_4_available,
    is_flash_attn_available,
    is_flash_attn_version,
    is_flax_available,
13 changes: 13 additions & 0 deletions src/diffusers/utils/import_utils.py
Expand Up @@ -227,6 +227,15 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> tuple[b
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")

# FA4 uses the flash_attn.cute module (part of flash-attn >= 4.0)
_flash_attn_4_available = False
try:
    from flash_attn.cute import flash_attn_func as _fa4_test
    _flash_attn_4_available = True
    del _fa4_test
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    pass
_aiter_available, _aiter_version = _is_package_available("aiter", get_dist_name=True)
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
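Downstream code can branch on the new helper to opt in only when FA4 is importable. A small sketch; the "native" fallback assumes diffusers' default SDPA backend name:

from diffusers.utils import is_flash_attn_4_available

# Prefer FA4 when the flash_attn.cute probe above succeeded.
backend_name = "_flash_4" if is_flash_attn_4_available() else "native"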
@@ -409,6 +418,10 @@ def is_flash_attn_3_available():
    return _flash_attn_3_available


def is_flash_attn_4_available():
    return _flash_attn_4_available


def is_aiter_available():
    return _aiter_available
