[Enhancement] Hybrid training in ML2 #38
    gate = nn.Parameter(torch.tensor(float(val), dtype=torch.float32), requires_grad=False)
else:
    # gate0 = 1.0, others = 0.0
    val = 1.0
Is this consistent with the comment gate[0] = 1.0, gate[1:] = 0.0?
I need to uncomment Line 436.
parser = argparse.ArgumentParser(description="Train a PINN model.")

- parser.add_argument('--nx', type=int, default=128,
+ parser.add_argument('--nx', type=int, nargs='+', default=[128],
nx is no longer necessarily a scalar. What does nx mean now?
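One possible reading, sketched below, is that a list-valued --nx carries one mesh resolution per level, with a single value reused for every level. This is only an assumption drawn from the example command in the PR description (two levels, one --nx value); the helper per_level and the default for --levels are hypothetical and not taken from the PR.

```python
import argparse


def per_level(values, levels, name):
    """Expand a list-valued option to one entry per level (assumed semantics)."""
    if len(values) == 1:
        # A single value is reused for every level.
        return list(values) * levels
    if len(values) != levels:
        raise ValueError(
            f"--{name} expects 1 or {levels} values, got {len(values)}"
        )
    return list(values)


parser = argparse.ArgumentParser(description="Train a PINN model.")
parser.add_argument("--levels", type=int, default=1)  # default is a guess
parser.add_argument("--nx", type=int, nargs="+", default=[128])

# Parse an explicit argument list so the sketch runs without a command line.
args = parser.parse_args(["--levels", "2", "--nx", "256"])
nx_per_level = per_level(args.nx, args.levels, "nx")
print(nx_per_level)  # [256, 256]
```

If this is the intended meaning, stating it in the help string of --nx would answer the question for users as well.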
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# %%
I would suggest moving these helper functions into a separate (new) file. Just a minor suggestion.
liruipeng left a comment
Looks good to me modulo a few minor questions/comments. Sorry for taking so long. Thank you. @siuwuncheung
# Case 2: init_frozen = True
# level 0 → gate = 1.0
# other levels → gate = 0.0
init_value = 1.0 if level_idx == 0 else 0.0
This can be quite difficult for the user to know. I guess you may not want too many input parameters, but it seems worthwhile to make init_value an input. It would also give some flexibility.
# %%
# Define multilevel gates
class GatedLevel(nn.Module):
    def __init__(self, level_idx, init_frozen, device="cuda"):
For this class, the freeze_gate method makes the gate non-trainable, but init_frozen here means it is initialized as trainable, which is a bit confusing to me. Do I understand correctly that init_frozen is equivalent to "the gate being trainable"? If so, consider simply calling the parameter trainable.
I guess a more accurate description is that init_frozen==False means "the gate is never trainable". But while that might be the intention, the current implementation does not seem to enforce it. One could initialize the object with init_frozen==False and then call unfreeze_gate() later.
Also, is this class being used yet? Looks like MultilevelNN has its own gates there. Maybe this class is an attempt to extract the gate out in preparation for future refactoring?
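To make the suggestions in this thread concrete, here is a minimal sketch of what an explicit interface could look like. It is not the PR's implementation: the class name GatedLevelSketch, the trainable and init_value arguments, the allow_training attribute, and the forward pass that scales its input by the gate are all assumptions for illustration. Only the default values (1.0 for level 0, 0.0 otherwise) and the freeze_gate / unfreeze_gate method names come from the diff and the comments above.

```python
import torch
import torch.nn as nn


class GatedLevelSketch(nn.Module):
    """Hypothetical variant of GatedLevel with explicit arguments."""

    def __init__(self, level_idx, init_value=None, trainable=False, device="cpu"):
        super().__init__()
        if init_value is None:
            # Default from the diff: level 0 -> 1.0, other levels -> 0.0.
            init_value = 1.0 if level_idx == 0 else 0.0
        self.level_idx = level_idx
        # Assumption: trainable=False means the gate may never be unfrozen.
        self.allow_training = trainable
        self.gate = nn.Parameter(
            torch.tensor(float(init_value), dtype=torch.float32, device=device),
            requires_grad=trainable,
        )

    def freeze_gate(self):
        # Stop gradient updates for the gate.
        self.gate.requires_grad_(False)

    def unfreeze_gate(self):
        # Refuse to unfreeze a gate that was declared fixed at construction.
        if not self.allow_training:
            raise RuntimeError(f"gate for level {self.level_idx} is fixed")
        self.gate.requires_grad_(True)

    def forward(self, x):
        # Assumed usage: scale this level's output by its gate value.
        return self.gate * x
```

With this shape, the constructor states directly whether the gate can ever be trained, and calling unfreeze_gate() on a fixed gate fails loudly instead of silently making it trainable.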
chakshinglee left a comment
Overall looks good to me. I have some questions about the GatedLevel class, which does not seem to be used at the moment.
This PR implements multilevel training with Deep Ritz on the first level and PINN on the second.
It allows different epochs, mesh resolutions, Chebyshev frequencies, and learning rates for the two levels, and adds some visualization tools.
python3 pinn/pinn_1d.py --levels 2 --epochs 100 400 --lr 1e-3 --activation tanh --sweeps 1 --plot --hidden_dims 512 512 512 --high_freq 8 --enforce_bc --use_chebyshev_basis --chebyshev_freq_min 1 3 --chebyshev_freq_max 4 4 --nx 256 --loss_type 2