[Enhancement] Hybrid training in ML2#38

Open
siuwuncheung wants to merge 27 commits into main from mix_drm_pinn
Collaborator

@siuwuncheung siuwuncheung commented Oct 6, 2025

This PR implements multilevel training in which the first level uses the Deep Ritz method and the second level uses a PINN.

It lets the two levels use different epoch counts, mesh resolutions, Chebyshev frequency ranges, and learning rates, and it adds some visualization tools.

python3 pinn/pinn_1d.py --levels 2 --epochs 100 400 --lr 1e-3 --activation tanh --sweeps 1 --plot --hidden_dims 512 512 512 --high_freq 8 --enforce_bc --use_chebyshev_basis --chebyshev_freq_min 1 3 --chebyshev_freq_max 4 4 --nx 256 --loss_type 2

@siuwuncheung siuwuncheung changed the title Hybrid training in ML2 [Enhancement] Hybrid training in ML2 Oct 6, 2025
@siuwuncheung siuwuncheung requested a review from liruipeng October 6, 2025 15:35
@siuwuncheung siuwuncheung added the RFR ready for review label Oct 6, 2025
Comment thread pinn/pinn_1d.py
gate = nn.Parameter(torch.tensor(float(val), dtype=torch.float32), requires_grad=False)
else:
# gate0 = 1.0, others = 0.0
val = 1.0
Owner

Is this consistent with the comment gate[0] = 1.0, gate[1:] = 0.0?

Collaborator Author

I need to uncomment Line 436.

Comment thread pinn/utils.py
parser = argparse.ArgumentParser(description="Train a PINN model.")

- parser.add_argument('--nx', type=int, default=128,
+ parser.add_argument('--nx', type=int, nargs='+', default=[128],
Owner

nx is not necessarily a scalar anymore. What is the meaning of nx now?
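
One possible reading, stated here as an assumption rather than as what the PR actually does: with `nargs='+'`, `--nx` carries one mesh resolution per level, and a single value is reused for every level. A minimal sketch of that interpretation follows; `per_level` is a hypothetical helper and is not part of `pinn/utils.py`.

```python
import argparse


def per_level(values, num_levels, name):
    """Expand a list-valued option to exactly one entry per level.

    Hypothetical helper: a single value is broadcast to all levels;
    otherwise the list length must equal num_levels.
    """
    if len(values) == 1:
        return list(values) * num_levels
    if len(values) != num_levels:
        raise ValueError(
            f"--{name} expects 1 or {num_levels} values, got {len(values)}"
        )
    return list(values)


parser = argparse.ArgumentParser(description="Train a PINN model.")
parser.add_argument("--levels", type=int, default=1)
parser.add_argument("--nx", type=int, nargs="+", default=[128])

# Parse an explicit argument list so the sketch runs without a real command line.
args = parser.parse_args(["--levels", "2", "--nx", "256"])
nx_per_level = per_level(args.nx, args.levels, "nx")
print(nx_per_level)  # [256, 256]
```

Under this reading, `--nx 256` with `--levels 2` means both levels use 256 points, while `--nx 128 256` would give each level its own resolution. Stating the chosen convention in the option's help text would answer the question for users as well.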

Comment thread pinn/pinn_1d.py
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# %%
Owner

I would like to move these helper functions into another (new) file. Just a minor suggestion.

Owner

@liruipeng liruipeng left a comment

Looks good to me modulo a few minor questions/comments. Sorry for taking so long. Thank you. @siuwuncheung

Comment thread pinn/pinn_1d.py
# Case 2: init_frozen = True
# level 0 → gate = 1.0
# other levels → gate = 0.0
init_value = 1.0 if level_idx == 0 else 0.0
Collaborator

This can be quite difficult for the user to know. I guess you may not want too many input parameters, but it seems worthwhile to make init_value an input. It would also give some flexibility.
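
As a rough illustration of this suggestion, the initial gate values could default to the rule in the quoted snippet (level 0 gets 1.0, every other level gets 0.0) while still accepting a user-supplied list. The function name `initial_gate_values` and its `init_values` argument are hypothetical and do not exist in the PR.

```python
def initial_gate_values(num_levels, init_values=None):
    """Return one initial gate value per level.

    The default mirrors the quoted snippet: level 0 -> 1.0, others -> 0.0.
    `init_values` is a hypothetical user override, not an existing option.
    """
    if init_values is None:
        return [1.0 if level_idx == 0 else 0.0 for level_idx in range(num_levels)]
    if len(init_values) != num_levels:
        raise ValueError(
            f"expected {num_levels} gate values, got {len(init_values)}"
        )
    return [float(v) for v in init_values]


print(initial_gate_values(3))              # [1.0, 0.0, 0.0]
print(initial_gate_values(2, [1.0, 0.5]))  # [1.0, 0.5]
```

Keeping the default identical to the current behavior means existing command lines would not change, and only users who need different starting gates would have to pass anything extra.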

Comment thread pinn/pinn_1d.py
# %%
# Define multilevel gates
class GatedLevel(nn.Module):
    def __init__(self, level_idx, init_frozen, device="cuda"):
Collaborator

For this class, the freeze_gate method makes the gate non-trainable, but init_frozen here seems to mean that the gate is initialized as trainable, which is a bit confusing to me. Do I understand correctly that init_frozen is equivalent to "the gate is trainable"? If so, consider simply calling the parameter trainable.

Collaborator

I guess a more accurate description is that init_frozen==False means "the gate is never trainable". But while that might be the intention, the current implementation does not seem to prevent the gate from becoming trainable. I mean, one could initialize the object with init_frozen==False and then call unfreeze_gate() later?

Also, is this class being used yet? Looks like MultilevelNN has its own gates there. Maybe this class is an attempt to extract the gate out in preparation for future refactoring?
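
To make the naming question concrete, here is a minimal sketch of how such a gate could be written with an explicit `trainable` flag. It is not the PR's implementation: the constructor signature, the `locked` option, and the `forward` behavior are all assumptions made for illustration, and the `level_idx` and `device` arguments of the original class are omitted.

```python
import torch
import torch.nn as nn


class GatedLevel(nn.Module):
    """Scalar gate for one level (illustrative sketch only)."""

    def __init__(self, init_value=1.0, trainable=False, locked=False):
        super().__init__()
        # `locked` is a hypothetical flag: a locked gate can never be unfrozen.
        if locked and trainable:
            raise ValueError("a locked gate cannot be trainable")
        self.locked = locked
        self.gate = nn.Parameter(
            torch.tensor(float(init_value), dtype=torch.float32),
            requires_grad=trainable,
        )

    def freeze_gate(self):
        # Stop gradient updates to the gate value.
        self.gate.requires_grad_(False)

    def unfreeze_gate(self):
        # Refuse to unfreeze a gate that was declared permanently fixed.
        if self.locked:
            raise RuntimeError("gate is locked and cannot be made trainable")
        self.gate.requires_grad_(True)

    def forward(self, level_output):
        # Scale this level's contribution by the gate value.
        return self.gate * level_output


level = GatedLevel(init_value=0.0, trainable=False)
level.unfreeze_gate()
print(level.gate.requires_grad)  # True
```

With a separate `locked` flag, "initially frozen" and "never trainable" become two distinct, explicitly named states, which is the ambiguity raised above.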

Collaborator

@chakshinglee chakshinglee left a comment

Overall, this looks good to me. I have some questions about the GatedLevel class, which does not seem to be used at the moment.
