Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions defuser/defuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from defuser.modeling.update_module import update_module
from defuser.utils.common import (
MIN_SUPPORTED_TRANSFORMERS_VERSION,
is_version_at_least,
is_supported_transformers_version,
warn_if_public_api_transformers_unsupported,
)
from packaging import version
import transformers
from logbar import LogBar

Expand Down Expand Up @@ -69,7 +69,7 @@ def replace_fused_blocks(model_type: str) -> bool:
custom_class = getattr(custom_module, custom_class_name)
setattr(orig_module, orig_class_name, custom_class)

if version.parse(transformers.__version__) >= version.parse(MIN_SUPPORTED_TRANSFORMERS_VERSION):
if is_version_at_least(transformers.__version__, MIN_SUPPORTED_TRANSFORMERS_VERSION):
from transformers import conversion_mapping

if not hasattr(conversion_mapping, "orig_get_checkpoint_conversion_mapping"):
Expand Down Expand Up @@ -102,8 +102,7 @@ def check_model_compatibility(model: nn.Module) -> bool:
return False

min_ver = MODEL_CONFIG[model_type].get("min_transformers_version")
current_ver = version.parse(transformers.__version__)
if min_ver and current_ver < version.parse(min_ver):
if min_ver and not is_version_at_least(transformers.__version__, min_ver):
logger.warn(
f"Skip conversion for model_type={model_type}: "
f"requires transformers>={min_ver}, current version is {transformers.__version__}."
Expand Down
33 changes: 31 additions & 2 deletions defuser/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,35 @@ class ModuleNameFilter:
negative: tuple[pcre.Pattern, ...]


def _parse_version(value: str | version.Version) -> version.Version:
"""Return a normalized packaging version object."""
if isinstance(value, version.Version):
return value
return version.parse(value)


def is_version_at_least(
installed_version: str | version.Version,
minimum_version: str | version.Version,
) -> bool:
"""Return whether a version meets a minimum, allowing same-release dev snapshots.

Hugging Face main-branch builds report versions like ``5.3.0-dev`` which
packaging normalizes to ``5.3.0.dev0`` and orders before ``5.3.0``. Defuser
treats those dev snapshots as satisfying the corresponding stable floor.
"""
installed = _parse_version(installed_version)
minimum = _parse_version(minimum_version)

if installed >= minimum:
return True

if installed.is_devrelease:
return version.parse(installed.base_version) >= minimum

return False


def env_flag(name: str, default: str | bool | None = "0") -> bool:
"""Return ``True`` when an env var is set to a truthy value."""

Expand All @@ -46,14 +75,14 @@ def is_transformers_version_greater_or_equal_5():
"""Cache the coarse ``transformers>=5`` capability check used by fast paths."""
import transformers

return version.parse(transformers.__version__) >= version.parse("5.0.0")
return is_version_at_least(transformers.__version__, "5.0.0")


def is_supported_transformers_version() -> bool:
"""Return whether the installed transformers version is supported by Defuser's public API."""
import transformers

return version.parse(transformers.__version__) >= version.parse(MIN_SUPPORTED_TRANSFORMERS_VERSION)
return is_version_at_least(transformers.__version__, MIN_SUPPORTED_TRANSFORMERS_VERSION)


def warn_if_public_api_transformers_unsupported(api_name: str, logger) -> bool:
Expand Down
6 changes: 2 additions & 4 deletions defuser/utils/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
import torch
import transformers
from logbar import LogBar
from packaging import version
from transformers import AutoConfig

from defuser.model_registry import MODEL_CONFIG
from defuser.utils.common import env_flag, warn_if_public_api_transformers_unsupported
from defuser.utils.common import env_flag, is_version_at_least, warn_if_public_api_transformers_unsupported

logger = LogBar(__name__)

Expand Down Expand Up @@ -77,8 +76,7 @@ def pre_check_config(model_name: str | torch.nn.Module):
cfg = MODEL_CONFIG[model_type]

min_ver = cfg.get("min_transformers_version")
tf_ver = version.parse(transformers.__version__)
if min_ver and tf_ver < version.parse(min_ver):
if min_ver and not is_version_at_least(transformers.__version__, min_ver):
return False
try:
file_path = get_file_path_via_model_name(model_name, "model.safetensors.index.json")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "Defuser"
version = "0.0.16"
version = "0.0.17"
description = "Model defuser helper for HF Transformers."
readme = "README.md"
requires-python = ">=3.9"
Expand Down