Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
abb9569
adding notebook for memory profiling
wiederm Oct 4, 2024
01164dd
notebook and module to profile GPU memory utilization using PyTorch p…
Oct 7, 2024
0bba855
Merge branch 'main' into memory-profiling
wiederm Oct 7, 2024
4a2e6ef
optimize memory allocation in the ANI network architecture
Oct 7, 2024
e543b0f
minor modifications to schnet, using in place operation where possible
Oct 7, 2024
511654c
small modifications for PaiNN
Oct 7, 2024
269f4b3
remove gradient calculation for parameters in inference mode
Oct 7, 2024
afe81bf
Merge branch 'main' into memory-profiling
wiederm Oct 10, 2024
98aa70e
Merge branch 'main' into memory-profiling
wiederm Oct 10, 2024
ef8ff17
fix bug
Oct 10, 2024
ea5cf91
upload profiling notebooks
Oct 10, 2024
64d81ed
please the linter
wiederm Oct 10, 2024
44f32f2
Merge branch 'main' into memory-profiling
wiederm Oct 10, 2024
3731723
add test for profiling functions
Oct 11, 2024
4ca1f84
add tests
wiederm Oct 14, 2024
9ef1fbb
Merge branch 'memory-profiling' of https://github.com/choderalab/mode…
wiederm Oct 14, 2024
d47835c
import openmmtools
wiederm Oct 14, 2024
e1af735
skip the profiling test if cuda is not available
wiederm Oct 14, 2024
08ca1bf
Merge branch 'main' into memory-profiling
wiederm Oct 14, 2024
279e4bf
Merge branch 'main' into memory-profiling
wiederm Oct 14, 2024
b8fdf38
update schnet
wiederm Oct 14, 2024
3e5b492
Merge branch 'memory-profiling' of https://github.com/choderalab/mode…
wiederm Oct 14, 2024
68687c4
bugfix (we need to retain graph if we want to calculate high order de…
wiederm Oct 15, 2024
df04e2b
move functions from notebook to package
wiederm Oct 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies:
- pydantic>=2
- ray-all
- graphviz
- openmmtools

# Testing
- pytest>=2.1
Expand All @@ -39,6 +40,5 @@ dependencies:
- pytorch2jax
- git+https://github.com/ArnNag/sake.git@nanometer
- flax
- torch
- pytest-xdist

4 changes: 2 additions & 2 deletions devtools/conda-envs/test_env_mac.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ dependencies:
- flax
- pydantic>=2.0
- graphviz
-

- openmmtools
# Testing
- pytest>=2.1
- pytest-cov
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The best way to get started is to read the :doc:`getting_started` guide, which o
inference
for_developer
tuning
profiling
api


Expand Down
7 changes: 7 additions & 0 deletions docs/profiling.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Profiling
================

Profiling: Overview
------------------------------------------

It is common to profile models to identify bottlenecks and optimize performance. *Modelforge* provides a simple interface to profile models using the `torch.profiler` module. The profiler can be used to profile the forward pass, backward pass, or both, and can be used to profile the model on a single batch or multiple batches.
4 changes: 1 addition & 3 deletions modelforge/potential/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
using a neural network model.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Tuple
from typing import Dict, Tuple, List

import torch
from loguru import logger as log
Expand Down
8 changes: 7 additions & 1 deletion modelforge/potential/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch.nn import Module
from modelforge.potential.neighbors import PairlistData

from modelforge.dataset.dataset import DatasetParameters, NNPInput, NNPInputTuple
from modelforge.dataset.dataset import DatasetParameters, NNPInputTuple
from modelforge.potential.parameters import (
AimNet2Parameters,
ANI2xParameters,
Expand Down Expand Up @@ -593,6 +593,12 @@ def generate_potential(
neighborlist_strategy=inference_neighborlist_strategy,
verlet_neighborlist_skin=verlet_neighborlist_skin,
)
# Disable gradients for model parameters
for param in model.parameters():
param.requires_grad = False
# Set model to eval
model.eval()

if simulation_environment == "JAX":
return PyTorch2JAXConverter().convert_to_jax_model(model)
else:
Expand Down
4 changes: 2 additions & 2 deletions modelforge/potential/painn.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,9 @@ def forward(

# featurize pairwise distances using radial basis functions (RBF)
f_ij = self.radial_symmetry_function_module(d_ij)
f_ij_cut = self.cutoff_module(d_ij)

# Apply the filter network and cutoff function
filters = torch.mul(self.filter_net(f_ij), f_ij_cut)
filters = torch.mul(self.filter_net(f_ij), self.cutoff_module(d_ij))

# depending on whether we share filters or not filters have different
# shape at dim=1 (dim=0 is always the number of atom pairs) if we share
Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/physnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
g = self.attention_mask(data["f_ij"])
# calculate the updated embedding for atom j
embedding_atom_j = self.activation_function(
self.interaction_j(data["atomic_embedding"][idx_j])
self.interaction_j(data["atomic_embedding"])[idx_j]
)
updated_embedding_atom_j = torch.mul(
g, embedding_atom_j
Expand Down
78 changes: 56 additions & 22 deletions modelforge/potential/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,37 +151,71 @@ def forward(self, r_ij: torch.Tensor) -> torch.Tensor:
return sub_aev

def compute_angular_sub_aev(self, vectors12: torch.Tensor) -> torch.Tensor:
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
"""
Compute the angular subAEV terms of the center atom given neighbor
pairs.

This corresponds to equation (4) in the ANI paper. This function only
computes the terms; the sum in the equation is not performed here.
The input tensor has shape (conformations, atoms, N), where N
is the number of neighbor atom pairs within the cutoff radius, and the
output tensor has shape
(conformations, atoms, ``self.angular_sublength()``)

Parameters
----------
vectors12: torch.Tensor
Pairwise distance vectors. Shape: [2, n_pairs, 3]

Returns
-------
torch.Tensor
Angular subAEV terms. Shape: [n_pairs, ShfZ_size * ShfA_size]

"""
vectors12 = vectors12.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
distances12 = vectors12.norm(2, dim=-5)
# vectors12: (2, n_pairs, 3)
distances12 = vectors12.norm(p=2, dim=-1) # Shape: (2, n_pairs)
distances_sum = distances12.sum(dim=0) / 2 # Shape: (n_pairs,)
fcj12 = self.cosine_cutoff(distances12) # Shape: (2, n_pairs)
fcj12_prod = fcj12.prod(dim=0) # Shape: (n_pairs,)

# 0.95 is multiplied to the cos values to prevent acos from
# returning NaN.
# cos_angles: (n_pairs,)
cos_angles = 0.95 * torch.nn.functional.cosine_similarity(
vectors12[0], vectors12[1], dim=-5
vectors12[0], vectors12[1], dim=-1
)
angles = torch.acos(cos_angles)
fcj12 = self.cosine_cutoff(distances12)
factor1 = ((1 + torch.cos(angles - self.ShfZ)) / 2) ** self.Zeta
angles = torch.acos(cos_angles) # Shape: (n_pairs,)

# Prepare shifts for broadcasting
angles = angles.unsqueeze(-1) # Shape: (n_pairs, 1)
distances_sum = distances_sum.unsqueeze(-1) # Shape: (n_pairs, 1)

# Compute factor1
delta_angles = angles - self.ShfZ.view(1, -1) # Shape: (n_pairs, ShfZ_size)
factor1 = (
(1 + torch.cos(delta_angles)) / 2
) ** self.Zeta # Shape: (n_pairs, ShfZ_size)

# Compute factor2
delta_distances = distances_sum - self.ShfA.view(
1, -1
) # Shape: (n_pairs, ShfA_size)
factor2 = torch.exp(
-self.EtaA * (distances12.sum(0) / 2 - self.ShfA) ** 2
).unsqueeze(-1)
factor2 = factor2.squeeze(4).squeeze(3)
ret = 2 * factor1 * factor2 * fcj12.prod(0)
# At this point, ret now have shape
# (conformations, atoms, N, ?, ?, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as one
# dimension vector
return ret.flatten(start_dim=-4)
-self.EtaA * delta_distances**2
) # Shape: (n_pairs, ShfA_size)

# Compute the outer product of factor1 and factor2 efficiently
# fcj12_prod: (n_pairs, 1, 1)
fcj12_prod = fcj12_prod.unsqueeze(-1).unsqueeze(-1) # Shape: (n_pairs, 1, 1)

# factor1: (n_pairs, ShfZ_size, 1)
factor1 = factor1.unsqueeze(-1)
# factor2: (n_pairs, 1, ShfA_size)
factor2 = factor2.unsqueeze(-2)

# Compute ret: (n_pairs, ShfZ_size, ShfA_size)
ret = 2 * fcj12_prod * factor1 * factor2

# Flatten the last two dimensions to get the final subAEV
# ret: (n_pairs, ShfZ_size * ShfA_size)
ret = ret.reshape(distances12.size(dim=1), -1)

return ret


import math
Expand Down
15 changes: 7 additions & 8 deletions modelforge/potential/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,16 @@ def compute_properties(
# Compute the atomic representation
representation = self.schnet_representation_module(data, pairlist_output)
atomic_embedding = representation["atomic_embedding"]

f_ij = representation["f_ij"]
f_cutoff = representation["f_cutoff"]
# Apply interaction modules to update the atomic embedding
for interaction in self.interaction_modules:
v = interaction(
atomic_embedding = atomic_embedding + interaction(
atomic_embedding,
pairlist_output,
representation["f_ij"],
representation["f_cutoff"],
f_ij,
f_cutoff,
)
atomic_embedding = atomic_embedding + v # Update atomic features

return {
"per_atom_scalar_representation": atomic_embedding,
Expand Down Expand Up @@ -293,14 +293,13 @@ def forward(

# Generate interaction filters based on radial basis functions
W_ij = self.filter_network(f_ij.squeeze(1))
W_ij = W_ij * f_ij_cutoff
W_ij = W_ij * f_ij_cutoff # Shape: [n_pairs, number_of_filters]

# Perform continuous-filter convolution
x_j = atomic_embedding[idx_j]
x_ij = x_j * W_ij # Element-wise multiplication

out = torch.zeros_like(atomic_embedding)
out.scatter_add_(
out = torch.zeros_like(atomic_embedding).scatter_add_(
0, idx_i.unsqueeze(-1).expand_as(x_ij), x_ij
) # Aggregate per-atom pair to per-atom

Expand Down
56 changes: 56 additions & 0 deletions modelforge/tests/test_profiling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch
import pytest


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_profiling_function():
    """Smoke-test the GPU memory-profiling helpers.

    Builds a TensorNet potential and a small waterbox test system on the
    GPU, records a CUDA memory-snapshot history around a short
    forward/backward loop, and exports the snapshot. Passes if no helper
    raises; the snapshot contents are not asserted.
    """
    from modelforge.tests.helper_functions import setup_potential_for_test
    from modelforge.utils.profiling import (
        start_record_memory_history,
        export_memory_snapshot,
        stop_record_memory_history,
        setup_waterbox_testsystem,
    )

    # define the potential, device and precision
    potential_name = "tensornet"
    precision = torch.float32
    device = "cuda"

    # setup the input and model
    nnp_input = setup_waterbox_testsystem(2.5, device=device, precision=precision)
    model = setup_potential_for_test(
        potential_name,
        "inference",
        potential_seed=42,
        use_training_mode_neighborlist=True,
        simulation_environment="PyTorch",
    ).to(device, precision)
    # Disable gradients for model parameters: only positions need grads,
    # which keeps the autograd graph (and its memory footprint) small.
    for param in model.parameters():
        param.requires_grad = False
    # Set model to eval
    model.eval()

    # this is the function that will be profiled
    def loop_to_record():
        for _ in range(5):
            # perform the forward pass through each of the models
            r = model(nnp_input)["per_molecule_energy"]
            # Compute the gradient (forces) from the predicted energies;
            # the graph is freed after each iteration (no retain/create).
            grad = torch.autograd.grad(
                r,
                nnp_input.positions,
                grad_outputs=torch.ones_like(r),
                create_graph=False,
                retain_graph=False,
            )[0]

    # Start recording memory snapshot history
    start_record_memory_history()
    loop_to_record()
    # Create the memory snapshot file
    export_memory_snapshot()
    # Stop recording memory snapshot history
    stop_record_memory_history()
12 changes: 12 additions & 0 deletions modelforge/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,18 @@

"""

MESSAGES[
"openmmtools"
] = """

A batteries-included toolkit for the GPU-accelerated OpenMM molecular simulation engine.

OpenMMTools can be installed via conda:

conda install conda-forge::openmmtools

"""


def import_(module: str):
"""Import a module or print a descriptive message and raise an ImportError
Expand Down
Loading