Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 121 additions & 2 deletions src/quacc/recipes/mlp/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import lru_cache, wraps
from importlib.util import find_spec
from logging import getLogger
from pathlib import Path
from typing import TYPE_CHECKING

from ase.units import GPa as _GPa_to_eV_per_A3
Expand Down Expand Up @@ -51,12 +52,100 @@ def wrapped(*args, **kwargs):
return wrapped


@lru_cache
def _get_omat24_references() -> dict[str, float]:
"""
Fetch formation energy references for OMAT24-trained models from HuggingFace.

These references come from https://huggingface.co/facebook/UMA/blob/main/references/form_elem_refs.yaml

Returns
-------
dict[str, float]
Dictionary mapping element symbols to reference energies (eV/atom).
"""
import yaml
from huggingface_hub import hf_hub_download

LOGGER.info("Downloading OMAT24 formation energy references from HuggingFace...")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many HPC clusters do not have access to the external network. Is there a mechanism we can encourage users to call in advance (e.g. on a login node) so that this download happens ahead of time?


# Download the form_elem_refs.yaml file from HuggingFace
refs_file = hf_hub_download(
repo_id="facebook/UMA",
filename="references/form_elem_refs.yaml",
Copy link
Copy Markdown
Member

@Andrew-S-Rosen Andrew-S-Rosen Feb 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This stores it in the current working directory. It seems like this should instead be prefaced by Path(__file__).parent so it's stored in a fixed location.

Also, are we going to be redownloading the YAML every single time a job runs? That seems inefficient if so. It seems like we should be storing this once and checking whether it's already there.

repo_type="model",
)

# Load and extract the omat references
with Path.open(refs_file) as f:
refs_data = yaml.safe_load(f)

omat_refs = refs_data.get("refs", {}).get("omat", {})

if not omat_refs:
raise ValueError("Could not find 'refs.omat' in the downloaded reference file.")

LOGGER.info(f"Loaded OMAT24 references for {len(omat_refs)} elements.")
return omat_refs


@lru_cache
def _get_mp20_references() -> dict[str, float]:
"""
Load formation energy references for MP-20 compatible models.

These references come from matbench-discovery repository:
https://github.com/janosh/matbench-discovery

Returns
-------
dict[str, float]
Dictionary mapping element symbols to reference energies (eV/atom).
"""
import gzip
import json
from pathlib import Path
Comment on lines +105 to +107
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In here and the other function, do we need to keep these imports inside the function or can they be global?


LOGGER.info("Loading MP-20 formation energy references from local file...")

# Load from local gzipped JSON file
refs_file = (
Path(__file__).parent
/ "references"
/ "2023-02-07-mp-elemental-reference-entries.json.gz"
)

if not refs_file.exists():
raise FileNotFoundError(
f"MP-20 reference file not found at {refs_file}. "
"Please ensure the file is in src/quacc/recipes/mlp/references/"
)

# Load the gzipped JSON file
with gzip.open(refs_file, "rt") as f:
refs_data = json.load(f)

# Extract element references based on the expected structure
# The file should contain element references
if isinstance(refs_data, dict):
mp20_refs = refs_data
else:
raise ValueError(
f"Unexpected format in MP-20 reference file: {type(refs_data)}"
)

LOGGER.info(f"Loaded MP-20 references for {len(mp20_refs)} elements.")
return mp20_refs


@freezeargs
@lru_cache
def pick_calculator(
method: Literal[
"mace-mp", "m3gnet", "chgnet", "tensornet", "sevennet", "orb", "fairchem"
],
use_formation_energy: bool = False,
references: Literal["MP20", "OMAT24"] | None = None,
**calc_kwargs,
) -> BaseCalculator:
"""
Expand All @@ -71,14 +160,25 @@ def pick_calculator(
----------
method
Name of the calculator to use.
use_formation_energy
If True, wrap the calculator with FormationEnergyCalculator to compute
formation energies. Requires fairchem-core package to be installed.
Supported for all calculator types. Default is False.
references
Formation energy references to use. Only used if use_formation_energy=True.
Options:
- None: Use built-in references from FormationEnergyCalculator (FAIRChem models only)
- "OMAT24": Use OMAT24 references from https://huggingface.co/facebook/UMA
- "MP20": Use MP-20 references from matbench-discovery
Default is None.
**calc_kwargs
Custom kwargs for the underlying calculator. Set a value to
`quacc.Remove` to remove a pre-existing key entirely.

Returns
-------
BaseCalculator
The instantiated calculator
The instantiated calculator (optionally wrapped with FormationEnergyCalculator)
"""
import torch

Expand Down Expand Up @@ -125,7 +225,7 @@ def pick_calculator(
from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator

orb_model = calc_kwargs.get("model", "orb_v2")
orb_model = calc_kwargs.get("model", "orb_v3_conservative_inf_omat")
orbff = getattr(pretrained, orb_model)()
calc = ORBCalculator(model=orbff, **calc_kwargs)

Expand All @@ -139,4 +239,23 @@ def pick_calculator(

calc.parameters["version"] = __version__

# Wrap with FormationEnergyCalculator if requested
if use_formation_energy:
from fairchem.core.calculate.ase_calculator import FormationEnergyCalculator

# Determine which reference energies to use
fe_kwargs = {}

if references == "OMAT24":
# Use OMAT24 references from HuggingFace
fe_kwargs["references"] = _get_omat24_references()
elif references == "MP20":
# Use MP-20 references from local file
fe_kwargs["references"] = _get_mp20_references()
# If references is None, use built-in references from FormationEnergyCalculator
# (works for FAIRChem models with task_name specified)

# Wrap with FormationEnergyCalculator using provided kwargs
calc = FormationEnergyCalculator(calculator=calc, **fe_kwargs)

return calc
60 changes: 50 additions & 10 deletions src/quacc/recipes/mlp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from quacc import job
from quacc.recipes.mlp._base import pick_calculator
Expand All @@ -11,7 +11,7 @@
from quacc.utils.dicts import recursive_dict_merge

if TYPE_CHECKING:
from typing import Any, Literal
from typing import Literal

from ase.atoms import Atoms

Expand All @@ -23,6 +23,8 @@ def static_job(
atoms: Atoms,
method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet", "orb", "fairchem"],
additional_fields: dict[str, Any] | None = None,
use_formation_energy: bool = False,
references: Literal["MP20", "OMAT24"] | None = None,
**calc_kwargs,
) -> RunSchema:
"""
Expand All @@ -36,11 +38,24 @@ def static_job(
Universal ML interatomic potential method to use
additional_fields
Additional fields to add to the results dictionary.
use_formation_energy
If True, wrap the calculator with FormationEnergyCalculator to compute
formation energies. Requires fairchem-core package to be installed.
Supported for all methods. Default is False. The formation energy is
returned in eV per formula unit (not eV/atom).
references
Formation energy references to use. Only used if use_formation_energy=True.
Options:
- None: Use built-in references from FormationEnergyCalculator (FAIRChem models only)
- "OMAT24": Use OMAT24 references from https://huggingface.co/facebook/UMA
- "MP20": Use MP-20 references from matbench-discovery
Default is None.
**calc_kwargs
Custom kwargs for the underlying calculator. Set a value to
`quacc.Remove` to remove a pre-existing key entirely. For a list of available
keys, refer to the `mace.calculators.mace_mp`, `chgnet.model.dynamics.CHGNetCalculator`,
`matgl.ext.ase.M3GNetCalculator`, `sevenn.sevennet_calculator.SevenNetCalculator`,
`quacc.Remove` to remove a pre-existing key entirely. For a list of
available keys, refer to the `mace.calculators.mace_mp`,
`chgnet.model.dynamics.CHGNetCalculator`, `matgl.ext.ase.M3GNetCalculator`,
`sevenn.sevennet_calculator.SevenNetCalculator`,
`orb_models.forcefield.calculator.ORBCalculator`,
`fairchem.core.FAIRChemCalculator` calculators.

Expand All @@ -50,7 +65,12 @@ def static_job(
Dictionary of results from [quacc.schemas.ase.Summarize.run][].
See the type-hint for the data structure.
"""
calc = pick_calculator(method, **calc_kwargs)
calc = pick_calculator(
method,
use_formation_energy=use_formation_energy,
references=references,
**calc_kwargs,
)
final_atoms = Runner(atoms, calc).run_calc()
return Summarize(
additional_fields={"name": f"{method} Static"} | (additional_fields or {})
Expand All @@ -64,6 +84,8 @@ def relax_job(
relax_cell: bool = False,
opt_params: OptParams | None = None,
additional_fields: dict[str, Any] | None = None,
use_formation_energy: bool = False,
references: Literal["MP20", "OMAT24"] | None = None,
**calc_kwargs,
) -> OptSchema:
"""
Expand All @@ -82,11 +104,24 @@ def relax_job(
of available keys, refer to [quacc.runners.ase.Runner.run_opt][].
additional_fields
Additional fields to add to the results dictionary.
use_formation_energy
If True, wrap the calculator with FormationEnergyCalculator to compute
formation energies. Requires fairchem-core package to be installed.
Supported for all methods. Default is False. The formation energy is
returned in eV per formula unit (not eV/atom).
references
Formation energy references to use. Only used if use_formation_energy=True.
Options:
- None: Use built-in references from FormationEnergyCalculator (FAIRChem models only)
- "OMAT24": Use OMAT24 references from https://huggingface.co/facebook/UMA
- "MP20": Use MP-20 references from matbench-discovery
Default is None.
**calc_kwargs
Custom kwargs for the underlying calculator. Set a value to
`quacc.Remove` to remove a pre-existing key entirely. For a list of available
keys, refer to the `mace.calculators.mace_mp`, `chgnet.model.dynamics.CHGNetCalculator`,
`matgl.ext.ase.M3GNetCalculator`, `sevenn.sevennet_calculator.SevenNetCalculator`,
`quacc.Remove` to remove a pre-existing key entirely. For a list of
available keys, refer to the `mace.calculators.mace_mp`,
`chgnet.model.dynamics.CHGNetCalculator`, `matgl.ext.ase.M3GNetCalculator`,
`sevenn.sevennet_calculator.SevenNetCalculator`,
`orb_models.forcefield.calculator.ORBCalculator`,
`fairchem.core.FAIRChemCalculator` calculators.

Expand All @@ -99,7 +134,12 @@ def relax_job(
opt_defaults = {"fmax": 0.05}
opt_flags = recursive_dict_merge(opt_defaults, opt_params)

calc = pick_calculator(method, **calc_kwargs)
calc = pick_calculator(
method,
use_formation_energy=use_formation_energy,
references=references,
**calc_kwargs,
)

dyn = Runner(atoms, calc).run_opt(relax_cell=relax_cell, **opt_flags)

Expand Down
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this something that is shipped with fairchem that we can access? I assume not but just want to check.

Binary file not shown.
Loading
Loading