Skip to content
23 changes: 16 additions & 7 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
is_flash_attn_available,
is_flash_attn_version,
is_kernels_available,
is_kernels_version,
is_sageattention_available,
is_sageattention_version,
is_torch_npu_available,
Expand Down Expand Up @@ -318,6 +319,7 @@ class _HubKernelConfig:
repo_id: str
function_attr: str
revision: str | None = None
version: int | None = None
kernel_fn: Callable | None = None
wrapped_forward_attr: str | None = None
wrapped_backward_attr: str | None = None
Expand All @@ -327,31 +329,34 @@ class _HubKernelConfig:

# Registry for hub-based attention kernels
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
function_attr="flash_attn_func",
revision="fake-ops-return-probs",
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
version=1,
),
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
function_attr="flash_attn_varlen_func",
# revision="fake-ops-return-probs",
version=1,
),
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_func",
revision=None,
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
version=1,
),
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_varlen_func",
version=1,
),
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
repo_id="kernels-community/sage-attention",
function_attr="sageattn",
version=1,
),
}

Expand Down Expand Up @@ -521,6 +526,10 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
raise RuntimeError(
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
if not is_kernels_version(">=", "0.12"):
raise RuntimeError(
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
)

elif backend == AttentionBackendName.AITER:
if not _CAN_USE_AITER_ATTN:
Expand Down Expand Up @@ -694,7 +703,7 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
try:
from kernels import get_kernel

kernel_module = get_kernel(config.repo_id, revision=config.revision)
kernel_module = get_kernel(config.repo_id, revision=config.revision, version=config.version)
if needs_kernel:
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)

Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
is_inflect_available,
is_invisible_watermark_available,
is_kernels_available,
is_kernels_version,
is_kornia_available,
is_librosa_available,
is_matplotlib_available,
Expand Down
16 changes: 16 additions & 0 deletions src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,22 @@ def is_transformers_version(operation: str, version: str):
return compare_versions(parse(_transformers_version), operation, version)


@cache
def is_kernels_version(operation: str, version: str):
"""
Compares the current Kernels version to a given reference with an operation.

Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _kernels_available:
return False
return compare_versions(parse(_kernels_version), operation, version)


@cache
def is_hf_hub_version(operation: str, version: str):
"""
Expand Down