-
Notifications
You must be signed in to change notification settings - Fork 70
Formation energies in ML potentials #3122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0393e99
8ced228
944536e
9a34449
c868ebb
f0053a1
99208ff
8feff0a
a5607d2
39683be
09a4340
3ca5305
64c104e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
| from functools import lru_cache, wraps | ||
| from importlib.util import find_spec | ||
| from logging import getLogger | ||
| from pathlib import Path | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| from ase.units import GPa as _GPa_to_eV_per_A3 | ||
|
|
@@ -51,12 +52,100 @@ def wrapped(*args, **kwargs): | |
| return wrapped | ||
|
|
||
|
|
||
@lru_cache
def _get_omat24_references() -> dict[str, float]:
    """
    Fetch formation energy references for OMAT24-trained models from HuggingFace.

    These references come from https://huggingface.co/facebook/UMA/blob/main/references/form_elem_refs.yaml

    The result is memoized with `lru_cache`, so the file is resolved and parsed
    at most once per Python process. The same dictionary object is handed to
    every caller; treat it as read-only.

    Returns
    -------
    dict[str, float]
        Dictionary mapping element symbols to reference energies (eV/atom).

    Raises
    ------
    ValueError
        If the downloaded file does not contain a non-empty `refs.omat` mapping.
    """
    # Imported lazily, like the other optional dependencies in this module.
    import yaml
    from huggingface_hub import hf_hub_download

    LOGGER.info("Downloading OMAT24 formation energy references from HuggingFace...")

    # NOTE(review): the PR thread asked where this file is stored and whether it
    # is re-downloaded for every job. Per the huggingface_hub documentation,
    # hf_hub_download places the file in the Hugging Face hub cache (not the
    # current working directory) and reuses an already-cached copy. On machines
    # without network access, calling this function once beforehand from a
    # networked node should populate that cache -- TODO confirm on the target
    # cluster.
    refs_file = hf_hub_download(
        repo_id="facebook/UMA",
        filename="references/form_elem_refs.yaml",
        repo_type="model",
    )

    # hf_hub_download returns a plain string path, so build a Path instance
    # instead of calling the unbound Path.open on a str.
    with Path(refs_file).open(encoding="utf-8") as f:
        refs_data = yaml.safe_load(f)

    # yaml.safe_load yields None for an empty document, and either level of the
    # expected {"refs": {"omat": {...}}} layout may be missing or not a mapping.
    refs_section = refs_data.get("refs") if isinstance(refs_data, dict) else None
    omat_refs = refs_section.get("omat") if isinstance(refs_section, dict) else None

    if not isinstance(omat_refs, dict) or not omat_refs:
        raise ValueError("Could not find 'refs.omat' in the downloaded reference file.")

    LOGGER.info("Loaded OMAT24 references for %d elements.", len(omat_refs))
    return omat_refs
|
|
||
|
|
||
@lru_cache
def _get_mp20_references() -> dict[str, float]:
    """
    Load formation energy references for MP-20 compatible models.

    These references come from matbench-discovery repository:
    https://github.com/janosh/matbench-discovery

    The data is read from a gzipped JSON file in the `references` directory
    next to this module. The result is memoized with `lru_cache`, so the file
    is read at most once per Python process and the same dictionary object is
    handed to every caller; treat it as read-only.

    Returns
    -------
    dict[str, float]
        Dictionary mapping element symbols to reference energies (eV/atom).

    Raises
    ------
    FileNotFoundError
        If the bundled reference file is missing.
    ValueError
        If the top-level JSON value is not a dictionary.
    """
    # NOTE(review): the PR thread asked whether these imports need to stay
    # function-local. `Path` is already imported at module level, so only the
    # two stdlib modules used exclusively here remain local.
    import gzip
    import json

    LOGGER.info("Loading MP-20 formation energy references from local file...")

    refs_file = (
        Path(__file__).parent
        / "references"
        / "2023-02-07-mp-elemental-reference-entries.json.gz"
    )

    # Attempt the read directly and translate a missing file into a message
    # that tells the user where the file is expected to live.
    try:
        with gzip.open(refs_file, "rt", encoding="utf-8") as f:
            refs_data = json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"MP-20 reference file not found at {refs_file}. "
            "Please ensure the file is in src/quacc/recipes/mlp/references/"
        ) from err

    if not isinstance(refs_data, dict):
        raise ValueError(
            f"Unexpected format in MP-20 reference file: {type(refs_data)}"
        )

    # NOTE(review): the file name says "reference-entries"; if the values are
    # entry records rather than bare per-atom energies, they must be reduced to
    # floats before being used as references -- TODO confirm against the file.
    LOGGER.info("Loaded MP-20 references for %d elements.", len(refs_data))
    return refs_data
|
|
||
|
|
||
| @freezeargs | ||
| @lru_cache | ||
| def pick_calculator( | ||
| method: Literal[ | ||
| "mace-mp", "m3gnet", "chgnet", "tensornet", "sevennet", "orb", "fairchem" | ||
| ], | ||
| use_formation_energy: bool = False, | ||
| references: Literal["MP20", "OMAT24"] | None = None, | ||
| **calc_kwargs, | ||
| ) -> BaseCalculator: | ||
| """ | ||
|
|
@@ -71,14 +160,25 @@ def pick_calculator( | |
| ---------- | ||
| method | ||
| Name of the calculator to use. | ||
| use_formation_energy | ||
| If True, wrap the calculator with FormationEnergyCalculator to compute | ||
| formation energies. Requires fairchem-core package to be installed. | ||
| Supported for all calculator types. Default is False. | ||
| references | ||
| Formation energy references to use. Only used if use_formation_energy=True. | ||
| Options: | ||
| - None: Use built-in references from FormationEnergyCalculator (FAIRChem models only) | ||
| - "OMAT24": Use OMAT24 references from https://huggingface.co/facebook/UMA | ||
| - "MP20": Use MP-20 references from matbench-discovery | ||
| Default is None. | ||
| **calc_kwargs | ||
| Custom kwargs for the underlying calculator. Set a value to | ||
| `quacc.Remove` to remove a pre-existing key entirely. | ||
|
|
||
| Returns | ||
| ------- | ||
| BaseCalculator | ||
| The instantiated calculator | ||
| The instantiated calculator (optionally wrapped with FormationEnergyCalculator) | ||
| """ | ||
| import torch | ||
|
|
||
|
|
@@ -125,7 +225,7 @@ def pick_calculator( | |
| from orb_models.forcefield import pretrained | ||
| from orb_models.forcefield.calculator import ORBCalculator | ||
|
|
||
| orb_model = calc_kwargs.get("model", "orb_v2") | ||
| orb_model = calc_kwargs.get("model", "orb_v3_conservative_inf_omat") | ||
| orbff = getattr(pretrained, orb_model)() | ||
| calc = ORBCalculator(model=orbff, **calc_kwargs) | ||
|
|
||
|
|
@@ -139,4 +239,23 @@ def pick_calculator( | |
|
|
||
| calc.parameters["version"] = __version__ | ||
|
|
||
| # Wrap with FormationEnergyCalculator if requested | ||
| if use_formation_energy: | ||
| from fairchem.core.calculate.ase_calculator import FormationEnergyCalculator | ||
|
|
||
| # Determine which reference energies to use | ||
| fe_kwargs = {} | ||
|
|
||
| if references == "OMAT24": | ||
| # Use OMAT24 references from HuggingFace | ||
| fe_kwargs["references"] = _get_omat24_references() | ||
| elif references == "MP20": | ||
| # Use MP-20 references from local file | ||
| fe_kwargs["references"] = _get_mp20_references() | ||
| # If references is None, use built-in references from FormationEnergyCalculator | ||
| # (works for FAIRChem models with task_name specified) | ||
|
|
||
| # Wrap with FormationEnergyCalculator using provided kwargs | ||
| calc = FormationEnergyCalculator(calculator=calc, **fe_kwargs) | ||
|
|
||
| return calc | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this something that is shipped with fairchem that we can access? I assume not but just want to check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Many HPC clusters do not have access to the external network. Is there a mechanism we can encourage the users to call to run this in advance (e.g. on a login node)?