Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions src/bioemu/model_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from pathlib import Path

Expand All @@ -12,6 +13,8 @@
from bioemu.models import DiGConditionalScoreModel
from bioemu.sde_lib import SDE

logger = logging.getLogger(__name__)


def maybe_download_checkpoint(
*,
Expand Down Expand Up @@ -54,18 +57,43 @@ def maybe_download_checkpoint(
return str(ckpt_path), str(model_config_path)


def _is_legacy_checkpoint(path: str | Path) -> bool:
"""Check if *path* points to a legacy weight file (not a ``from_pretrained`` directory)."""
p = Path(path)
return p.is_file()


def load_model(ckpt_path: str | Path, model_config_path: str | Path) -> DiGConditionalScoreModel:
    """Load score model from checkpoint and config.

    Supports two formats:
        1. Legacy format: ``ckpt_path`` is a ``.ckpt`` file and ``model_config_path`` is
           a YAML file with Hydra ``_target_`` entries.
        2. Pretrained format: ``ckpt_path`` is anything that
           ``DiGConditionalScoreModel.from_pretrained`` accepts — a local
           directory (with ``config.json`` + ``model.safetensors``) **or** a
           Hub repo ID (e.g. ``"your-org/your-bioemu-model"``). Resolution of
           local-vs-Hub is handled by ``huggingface_hub`` itself.

    In case 2, ``model_config_path`` is only used by ``load_sdes``
    and is not needed for loading the model itself.

    Args:
        ckpt_path: Legacy weight file, or a ``from_pretrained`` source.
        model_config_path: YAML model config. Only read in the legacy branch.

    Returns:
        The loaded score model.

    Raises:
        AssertionError: If ``ckpt_path`` is a legacy weight file and
            ``model_config_path`` is not an existing file.
    """
    if _is_legacy_checkpoint(ckpt_path):
        logger.info("Loading model from legacy checkpoint: %s", ckpt_path)
        # NOTE(review): this check is skipped when Python runs with -O; kept as an
        # assert so the exception type seen by callers does not change.
        assert os.path.isfile(model_config_path), f"Model config {model_config_path} not found"

        with open(model_config_path) as f:
            model_config = yaml.safe_load(f)

        # Weights are loaded onto the CPU, with unpickling restricted by weights_only=True.
        model_state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        # The architecture is built from the "score_model" entry of the YAML config.
        score_model: DiGConditionalScoreModel = hydra.utils.instantiate(
            model_config["score_model"]
        )
        score_model.load_state_dict(model_state)
        return score_model

    # Anything that is not an existing regular file is handed to from_pretrained as-is.
    logger.info("Loading model via from_pretrained: %s", ckpt_path)
    return DiGConditionalScoreModel.from_pretrained(str(ckpt_path))


def load_sdes(
Expand Down
9 changes: 8 additions & 1 deletion src/bioemu/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torch_geometric.utils import to_dense_adj, to_dense_batch

Expand Down Expand Up @@ -326,7 +327,13 @@ def __str__(self) -> str:
return super().__str__() + f"\nTrainable parameters: {params}"


class DiGConditionalScoreModel(torch.nn.Module):
class DiGConditionalScoreModel(
torch.nn.Module,
PyTorchModelHubMixin,
library_name="bioemu",
repo_url="https://github.com/microsoft/bioemu",
tags=["protein-structure", "diffusion", "molecular-dynamics"],
):
"""Wrapper to convert the DiG nn.Module neural network that operates directly on position
and rotation tensors into a ScoreModel that operates on ChemGraph objects.
"""
Expand Down