Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "SOAP"]
path = SOAP
url = https://github.com/nikhilvyas/SOAP
[submodule "optim_dualcone"]
path = optim_dualcone
url = https://github.com/youngsikhwang/Dual-Cone-Gradient-Descent
1 change: 1 addition & 0 deletions optim_dualcone
Submodule optim_dualcone added at 0d8935
37 changes: 36 additions & 1 deletion pinn/pinn_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@
import numpy as np
from enum import Enum
from utils import parse_args, get_activation, print_args, save_frame, make_video_from_frames, is_notebook, cleanfiles
import sys
sys.path.insert(0, '..')
from SOAP.soap import SOAP
from optim_dualcone.dcgd import DCGD

# torch.set_default_dtype(torch.float64)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -280,6 +283,8 @@ def __init__(self, loss_type, loss_func=nn.MSELoss(), bc_weight=1.0):
self.name = "Super Loss"
elif self.type == 0:
self.name = "PINN Loss"
elif self.type == 1:
self.name = "DRM Loss"
else:
raise ValueError(f"Unknown loss type: {self.type}")
self.bc_weight = bc_weight
Expand Down Expand Up @@ -311,11 +316,41 @@ def pinn_loss(self, model, mesh, loss_func):

return loss

def drm_loss(self, model, mesh: Mesh):
"""Deep Ritz Method loss"""
xs = mesh.x_train.requires_grad_(True)
u = model(xs)

grad_u_pred = torch.autograd.grad(u, xs,
grad_outputs=torch.ones_like(u),
create_graph=True)[0]

u_pred_sq = torch.sum(u**2, dim=1, keepdim=True)
grad_u_pred_sq = torch.sum(grad_u_pred**2, dim=1, keepdim=True)

f_val = mesh.pde.f(xs)
fu_prod = f_val * u

integrand_values = 0.5 * grad_u_pred_sq[1:-1] + 0.5 * mesh.pde.r * u_pred_sq[1:-1] - fu_prod[1:-1]
loss = torch.mean(integrand_values)

# Boundary loss
u_bc = u[[0,-1]]
u_ex_bc = mesh.u_ex[[0,-1]]
loss_b = self.loss_func(u_bc, u_ex_bc)
loss += self.bc_weight * loss_b


xs.requires_grad_(False) # Disable gradient tracking for x
return loss

def loss(self, model, mesh):
if self.type == -1:
loss_value = self.super_loss(model=model, mesh=mesh, loss_func=self.loss_func)
elif self.type == 0:
loss_value = self.pinn_loss(model=model, mesh=mesh, loss_func=self.loss_func)
elif self.type == 1:
loss_value = self.drm_loss(model=model, mesh=mesh)
else:
raise ValueError(f"Unknown loss type: {self.type}")
return loss_value
Expand Down Expand Up @@ -463,4 +498,4 @@ def main(args=None):
import sys
sys.exit(err)
except SystemExit:
pass # Prevent traceback in Jupyter or VS Code
pass # Prevent traceback in Jupyter or VS Code