From d391fcec497293124efbeaac93f4aa9c7cc4ce69 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 23 Mar 2026 08:49:30 +0000 Subject: [PATCH] fix version check --- defuser/defuser.py | 7 +++---- defuser/utils/common.py | 33 +++++++++++++++++++++++++++++++-- defuser/utils/hf.py | 6 ++---- pyproject.toml | 2 +- 4 files changed, 37 insertions(+), 11 deletions(-) diff --git a/defuser/defuser.py b/defuser/defuser.py index 5bb49f6..8126a16 100644 --- a/defuser/defuser.py +++ b/defuser/defuser.py @@ -12,10 +12,10 @@ from defuser.modeling.update_module import update_module from defuser.utils.common import ( MIN_SUPPORTED_TRANSFORMERS_VERSION, + is_version_at_least, is_supported_transformers_version, warn_if_public_api_transformers_unsupported, ) -from packaging import version import transformers from logbar import LogBar @@ -69,7 +69,7 @@ def replace_fused_blocks(model_type: str) -> bool: custom_class = getattr(custom_module, custom_class_name) setattr(orig_module, orig_class_name, custom_class) - if version.parse(transformers.__version__) >= version.parse(MIN_SUPPORTED_TRANSFORMERS_VERSION): + if is_version_at_least(transformers.__version__, MIN_SUPPORTED_TRANSFORMERS_VERSION): from transformers import conversion_mapping if not hasattr(conversion_mapping, "orig_get_checkpoint_conversion_mapping"): @@ -102,8 +102,7 @@ def check_model_compatibility(model: nn.Module) -> bool: return False min_ver = MODEL_CONFIG[model_type].get("min_transformers_version") - current_ver = version.parse(transformers.__version__) - if min_ver and current_ver < version.parse(min_ver): + if min_ver and not is_version_at_least(transformers.__version__, min_ver): logger.warn( f"Skip conversion for model_type={model_type}: " f"requires transformers>={min_ver}, current version is {transformers.__version__}." diff --git a/defuser/utils/common.py b/defuser/utils/common.py index a1df6cc..55b50bb 100644 --- a/defuser/utils/common.py +++ b/defuser/utils/common.py @@ -28,6 +28,35 @@ class ModuleNameFilter: negative: tuple[pcre.Pattern, ...] +def _parse_version(value: str | version.Version) -> version.Version: + """Return a normalized packaging version object.""" + if isinstance(value, version.Version): + return value + return version.parse(value) + + +def is_version_at_least( + installed_version: str | version.Version, + minimum_version: str | version.Version, +) -> bool: + """Return whether a version meets a minimum, allowing same-release dev snapshots. + + Hugging Face main-branch builds report versions like ``5.3.0-dev`` which + packaging normalizes to ``5.3.0.dev0`` and orders before ``5.3.0``. Defuser + treats those dev snapshots as satisfying the corresponding stable floor. + """ + installed = _parse_version(installed_version) + minimum = _parse_version(minimum_version) + + if installed >= minimum: + return True + + if installed.is_devrelease: + return version.parse(installed.base_version) >= minimum + + return False + + def env_flag(name: str, default: str | bool | None = "0") -> bool: """Return ``True`` when an env var is set to a truthy value.""" @@ -46,14 +75,14 @@ def is_transformers_version_greater_or_equal_5(): """Cache the coarse ``transformers>=5`` capability check used by fast paths.""" import transformers - return version.parse(transformers.__version__) >= version.parse("5.0.0") + return is_version_at_least(transformers.__version__, "5.0.0") def is_supported_transformers_version() -> bool: """Return whether the installed transformers version is supported by Defuser's public API.""" import transformers - return version.parse(transformers.__version__) >= version.parse(MIN_SUPPORTED_TRANSFORMERS_VERSION) + return is_version_at_least(transformers.__version__, MIN_SUPPORTED_TRANSFORMERS_VERSION) def warn_if_public_api_transformers_unsupported(api_name: str, logger) -> bool: diff --git a/defuser/utils/hf.py b/defuser/utils/hf.py index 9846bd6..7997dd7 100644 --- a/defuser/utils/hf.py +++ b/defuser/utils/hf.py @@ -12,11 +12,10 @@ import torch import transformers from logbar import LogBar -from packaging import version from transformers import AutoConfig from defuser.model_registry import MODEL_CONFIG -from defuser.utils.common import env_flag, warn_if_public_api_transformers_unsupported +from defuser.utils.common import env_flag, is_version_at_least, warn_if_public_api_transformers_unsupported logger = LogBar(__name__) @@ -77,8 +76,7 @@ def pre_check_config(model_name: str | torch.nn.Module): cfg = MODEL_CONFIG[model_type] min_ver = cfg.get("min_transformers_version") - tf_ver = version.parse(transformers.__version__) - if min_ver and tf_ver < version.parse(min_ver): + if min_ver and not is_version_at_least(transformers.__version__, min_ver): return False try: file_path = get_file_path_via_model_name(model_name, "model.safetensors.index.json") diff --git a/pyproject.toml b/pyproject.toml index e03bf39..2723b49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "Defuser" -version = "0.0.16" +version = "0.0.17" description = "Model defuser helper for HF Transformers." readme = "README.md" requires-python = ">=3.9"