diff --git a/pinn/pinn_1d.py b/pinn/pinn_1d.py index a889101..1f06a3e 100644 --- a/pinn/pinn_1d.py +++ b/pinn/pinn_1d.py @@ -46,9 +46,12 @@ import torch.nn as nn import torch.optim as optim import numpy as np +import itertools from enum import Enum +from typing import Union, Tuple, Callable from utils import parse_args, get_activation, print_args, save_frame, make_video_from_frames -from utils import is_notebook, cleanfiles, fourier_analysis, get_scheduler_generator, scheduler_step +from utils import is_notebook, cleanfiles, get_scheduler_generator, scheduler_step +from utils import error_analysis, fourier_analysis, plot_error_evolution, plot_coefficient_evolution from cheby import generate_chebyshev_features from bc import get_d_func, get_g0_func from datetime import datetime @@ -57,6 +60,147 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# %% +# Helper functions from the new BC implementation +def _calculate_laplacian_1d(func: Callable[[torch.Tensor], torch.Tensor], x_val: float) -> torch.Tensor: + x_tensor = torch.tensor([[x_val]], dtype=torch.float32, requires_grad=True) + u = func(x_tensor) + grad_u = torch.autograd.grad(u, x_tensor, grad_outputs=torch.ones_like(u), create_graph=True, retain_graph=True)[0] + laplacian_u = torch.autograd.grad(grad_u, x_tensor, grad_outputs=torch.ones_like(grad_u), create_graph=False, retain_graph=False)[0] + return laplacian_u + +def get_g0_func( + u_exact_func: Callable[[torch.Tensor], torch.Tensor], + domain_dim: int, + domain_bounds: Union[Tuple[float, float], Tuple[Tuple[float, float], ...]], + g0_type: str = "multilinear" +) -> Callable[[torch.Tensor], torch.Tensor]: + domain_bounds_tuple = domain_bounds + if domain_dim == 1 and not isinstance(domain_bounds[0], (tuple, list)): + domain_bounds_tuple = (domain_bounds,) + min_bounds = torch.tensor([b[0] for b in domain_bounds_tuple], dtype=torch.float32) + max_bounds = torch.tensor([b[1] for b in domain_bounds_tuple], dtype=torch.float32) + + if g0_type == "hermite_cubic_2nd_deriv": + if domain_dim != 1: raise ValueError("Hermite cubic interpolation with 2nd derivatives is only supported for 1D problems.") + x0, x1 = min_bounds.item(), max_bounds.item() + h = x1 - x0 + u_x0 = u_exact_func(torch.tensor([[x0]], dtype=torch.float32)).item() + u_x1 = u_exact_func(torch.tensor([[x1]], dtype=torch.float32)).item() + u_prime_prime_x0 = _calculate_laplacian_1d(u_exact_func, x0).item() + u_prime_prime_x1 = _calculate_laplacian_1d(u_exact_func, x1).item() + a3 = (u_prime_prime_x1 - u_prime_prime_x0) / (6 * h) + a2 = u_prime_prime_x0 / 2 - 3 * a3 * x0 + a1 = (u_x1 - u_x0) / h - a2 * (x1 + x0) - a3 * (x1**2 + x1 * x0 + x0**2) + a0 = u_x0 - a1 * x0 - a2 * x0**2 - a3 * x0**3 + coeffs = torch.tensor([a0, a1, a2, a3], dtype=torch.float32) + + def g0_hermite_cubic_val(x: torch.Tensor) -> torch.Tensor: + x_flat = x[:, 0] + g0_vals = coeffs[0] + coeffs[1] * x_flat + coeffs[2] * (x_flat**2) + coeffs[3] * (x_flat**3) + return g0_vals.unsqueeze(1) + return g0_hermite_cubic_val + + elif g0_type == "multilinear": + boundary_values = {} + dim_ranges = [[min_bounds[d].item(), max_bounds[d].item()] for d in range(domain_dim)] + for corner_coords in itertools.product(*dim_ranges): + corner_coords_tensor = torch.tensor(corner_coords, dtype=torch.float32).unsqueeze(0) + with torch.no_grad(): + boundary_values[corner_coords] = u_exact_func(corner_coords_tensor).item() + + def g0_multilinear_val(x: torch.Tensor) -> torch.Tensor: + num_points = x.shape[0] + xi = (x - min_bounds.to(x.device)) / (max_bounds.to(x.device) - min_bounds.to(x.device)) + xi = torch.clamp(xi, 0.0, 1.0) + g0_vals = torch.zeros((num_points, 1), device=x.device) + for corner_label in itertools.product([0, 1], repeat=domain_dim): + current_corner_key_list = [] + weight_factors = torch.ones((num_points, 1), device=x.device) + for d in range(domain_dim): + if corner_label[d] == 0: + current_corner_key_list.append(min_bounds[d].item()) + weight_factors *= (1 - xi[:, d]).unsqueeze(1) + else: + current_corner_key_list.append(max_bounds[d].item()) + weight_factors *= xi[:, d].unsqueeze(1) + corner_key_tuple = tuple(current_corner_key_list) + corner_value = boundary_values[corner_key_tuple] + g0_vals += corner_value * weight_factors + return g0_vals + return g0_multilinear_val + + else: + raise ValueError(f"Unknown g0_type: {g0_type}. Choose 'multilinear' or 'hermite_cubic_2nd_deriv'.") + +def _psi_tensor(t: torch.Tensor) -> torch.Tensor: + return torch.where(t <= 0, torch.tensor(0.0, dtype=t.dtype, device=t.device), torch.exp(-1.0 / t)) + +def get_d_func(domain_dim: int, domain_bounds: Union[Tuple[float, float], Tuple[Tuple[float, float], ...]], + d_type: str = "sin_half_period") -> Callable[[torch.Tensor], torch.Tensor]: + domain_bounds_tuple = domain_bounds + if domain_dim == 1 and not isinstance(domain_bounds[0], (tuple, list)): + domain_bounds_tuple = (domain_bounds,) + min_bounds = torch.tensor([b[0] for b in domain_bounds_tuple], dtype=torch.float32) + max_bounds = torch.tensor([b[1] for b in domain_bounds_tuple], dtype=torch.float32) + domain_length = (max_bounds[0] - min_bounds[0]).item() if domain_dim == 1 else None + + if d_type == "quadratic_bubble": + def d_func_val(x: torch.Tensor) -> torch.Tensor: + d_vals = torch.ones_like(x[:, 0], dtype=torch.float32, device=x.device) + for i in range(domain_dim): + x_i = x[:, i] + min_val, max_val = domain_bounds_tuple[i] + d_vals *= (x_i - min_val) * (max_val - x_i) + return d_vals.unsqueeze(1) + return d_func_val + + elif d_type == "inf_smooth_bump": + def d_inf_smooth_bump_val(x: torch.Tensor) -> torch.Tensor: + product_terms = torch.ones((x.shape[0],), dtype=x.dtype, device=x.device) + for i in range(domain_dim): + x_i = x[:, i] + min_val_i = min_bounds[i] + max_val_i = max_bounds[i] + x_c_i = (min_val_i + max_val_i) / 2.0 + R_i = (max_val_i - min_val_i) / 2.0 + R_i_squared = R_i**2 + arg_for_psi = R_i_squared - (x_i - x_c_i)**2 + product_terms *= _psi_tensor(arg_for_psi) + return product_terms.unsqueeze(1) + return d_inf_smooth_bump_val + + elif d_type == "abs_dist_complement": + if domain_dim != 1: raise ValueError(f"d_type '{d_type}' is only supported for 1D problems.") + def d_abs_dist_complement_val(x: torch.Tensor) -> torch.Tensor: + x_val = x[:, 0] + x_norm = (x_val - min_bounds[0]) / domain_length + sqrt_term = torch.sqrt(x_norm**2 + (1.0 - x_norm)**2) + return (1.0 - sqrt_term).unsqueeze(1) + return d_abs_dist_complement_val + + elif d_type == "ratio_bubble_dist": + if domain_dim != 1: raise ValueError(f"d_type '{d_type}' is only supported for 1D problems.") + def d_ratio_bubble_dist_val(x: torch.Tensor) -> torch.Tensor: + x_val = x[:, 0] + x_norm = (x_val - min_bounds[0]) / domain_length + numerator = x_norm * (1.0 - x_norm) + denominator = torch.sqrt(x_norm**2 + (1.0 - x_norm)**2) + return (numerator / denominator).unsqueeze(1) + return d_ratio_bubble_dist_val + + elif d_type == "sin_half_period": + if domain_dim != 1: raise ValueError(f"d_type '{d_type}' is only supported for 1D problems.") + if domain_length is None: raise ValueError("Domain length must be defined for 'sin_half_period' d_type.") + def d_sin_half_period_val(x: torch.Tensor) -> torch.Tensor: + x_val = x[:, 0] + argument = (torch.pi / domain_length) * (x_val - min_bounds[0]) + return torch.sin(argument).unsqueeze(1) + return d_sin_half_period_val + + else: + raise ValueError(f"Unknown d_type: {d_type}. Choose from 'quadratic_bubble', 'inf_smooth_bump', 'abs_dist_complement', 'ratio_bubble_dist', or 'sin_half_period'.") + # %% # Define PDE class PDE: @@ -180,23 +324,65 @@ class LevelStatus(Enum): TRAIN = "train" FROZEN = "frozen" +# %% +# Define multievel gates +class GatedLevel(nn.Module): + def __init__(self, level_idx, init_frozen, device="cuda"): + super().__init__() + self.level = level_idx + + if not init_frozen: + # Case 1: all gates = 1.0 and never trainable + self.gate = nn.Parameter(torch.tensor(1.0, device=device), requires_grad=False) + + else: + # Case 2: init_frozen = True + # level 0 → gate = 1.0 + # other levels → gate = 0.0 + init_value = 1.0 if level_idx == 0 else 0.0 + self.gate = nn.Parameter(torch.tensor(init_value, device=device), requires_grad=True) + + def freeze_gate(self): + """Make this level’s gate NOT trainable.""" + self.gate.requires_grad = False + + def unfreeze_gate(self): + """Make this level’s gate trainable.""" + self.gate.requires_grad = True + + def forward(self, x): + return self.gate * x # %% # Define multilevel NN class MultiLevelNN(nn.Module): - def __init__(self, mesh: Mesh, num_levels: int, dim_inputs, dim_outputs, dim_hidden: list, + def __init__(self, mesh: Mesh, num_levels: int, + dim_inputs, dim_outputs, dim_hidden: list, act: nn.Module = nn.ReLU(), enforce_bc: bool = False, g0_type: str = "multilinear", d_type: str = "sin_half_period", use_chebyshev_basis: bool = False, - chebyshev_freq_min: int = 0, - chebyshev_freq_max: int = 0) -> None: + chebyshev_freq_min: np.ndarray = None, + chebyshev_freq_max: np.ndarray = None, + init_frozen: bool = False) -> None: + """ + Multilevel NN with per-level scalar gates. + + Gate logic: + - if init_frozen == False: + * all gates = 1.0 and always non-trainable (requires_grad=False) + - if init_frozen == True: + * gate[0] = 1.0, gate[1:] = 0.0 + * gates are trainable only when corresponding level status == TRAIN + """ super().__init__() self.mesh = mesh - # currently the same model on each level self.dim_inputs = dim_inputs self.dim_outputs = dim_outputs self.enforce_bc = enforce_bc + self.use_chebyshev_basis = use_chebyshev_basis + self.init_frozen = init_frozen + # BC helpers (unchanged) self.g0_func = None self.d_func = None if self.enforce_bc: @@ -213,41 +399,81 @@ def __init__(self, mesh: Mesh, num_levels: int, dim_inputs, dim_outputs, dim_hid ) print(f"BCs will be enforced using g0_type: {g0_type} and d_type: {d_type}") - self.use_chebyshev_basis = use_chebyshev_basis - self.chebyshev_freqs = np.round(np.linspace(chebyshev_freq_min, chebyshev_freq_max, num_levels + 1)).astype(int) + # Build level subnets (same Level class as before) + assert chebyshev_freq_min is not None and chebyshev_freq_max is not None, \ + "chebyshev_freq_min/max must be provided (use -1 if unused)" self.models = nn.ModuleList([ Level(dim_inputs=dim_inputs, dim_outputs=dim_outputs, dim_hidden=dim_hidden, act=act, use_chebyshev_basis=use_chebyshev_basis, - chebyshev_freq_min=self.chebyshev_freqs[i], - chebyshev_freq_max=self.chebyshev_freqs[i+1]) + chebyshev_freq_min=chebyshev_freq_min[i], + chebyshev_freq_max=chebyshev_freq_max[i]) for i in range(num_levels) - ]) + ]) - # All levels start as "off" - self.level_status = [LevelStatus.OFF] * num_levels + # Level status initialization (keeps your original convention) + if init_frozen: + # all levels start as FROZEN (so OFF is different semantic) + self.level_status = [LevelStatus.FROZEN] * num_levels + else: + self.level_status = [LevelStatus.OFF] * num_levels - # No gradients are tracked initially + # Initially disable grads for model params (will be enabled when set_status(..., TRAIN)) for model in self.models: for param in model.parameters(): param.requires_grad = False - # Scale factor + # Per-level scalar gates as Parameters + # Use nn.ParameterList so gates are included in state_dict/parameters() + gates = [] + for i in range(num_levels): + if not init_frozen: + # gates fixed at 1.0, never trainable + val = 1.0 + gate = nn.Parameter(torch.tensor(float(val), dtype=torch.float32), requires_grad=False) + else: + # gate0 = 1.0, others = 0.0 + val = 1.0 + #val = 1.0 if i == 0 else 0.0 + # Initially gates are NOT trainable; they become trainable only when level set to TRAIN + gate = nn.Parameter(torch.tensor(float(val), dtype=torch.float32), requires_grad=False) + gates.append(gate) + self.gates = nn.ParameterList(gates) + + # Keep per-level input scales like original self.scales = [1.0] * num_levels + # ---------- status helpers ---------- def get_status(self, level_idx: int): if level_idx < 0 or level_idx >= self.num_levels(): raise IndexError(f"Level index {level_idx} is out of range") return self.level_status[level_idx] def set_status(self, level_idx: int, status: LevelStatus): + """Set level status and toggle requires_grad for model params and gate as required. + + Gate training logic: + - if self.init_frozen == False -> gates are always non-trainable (remain requires_grad=False) + - else (init_frozen == True) -> gate.requires_grad = (status == LevelStatus.TRAIN) + """ assert isinstance(status, LevelStatus), f"Invalid status: {status}" if level_idx < 0 or level_idx >= self.num_levels(): raise IndexError(f"Level index {level_idx} is out of range") + self.level_status[level_idx] = status - requires_grad = status == LevelStatus.TRAIN + requires_grad = (status == LevelStatus.TRAIN) + + # toggle level model parameters for param in self.models[level_idx].parameters(): param.requires_grad = requires_grad + # toggle gate requires_grad only when init_frozen is True + if self.init_frozen: + # gates train only when the corresponding level is TRAIN + self.gates[level_idx].requires_grad = requires_grad + else: + # init_frozen == False -> gates must remain non-trainable + self.gates[level_idx].requires_grad = False + def set_all_status(self, status_list: list[LevelStatus]): assert len(status_list) == len(self.models), "Length mismatch in status list" for i, status in enumerate(status_list): @@ -255,7 +481,9 @@ def set_all_status(self, status_list: list[LevelStatus]): def print_status(self): for i, status in enumerate(self.level_status): - print(f"Level {i}: {status.name}") + gate_val = float(self.gates[i].detach().cpu().numpy()) + gate_trainable = bool(self.gates[i].requires_grad) + print(f"Level {i}: {status.name}, gate={gate_val:.6f}, gate_trainable={gate_trainable}") def num_levels(self): return len(self.models) @@ -274,6 +502,7 @@ def set_all_scales(self, scale_list: list[float]): for i, scale in enumerate(scale_list): self.set_scale(i, scale) + # ---------- forward / solution ---------- def forward(self, x: torch.Tensor) -> torch.Tensor: ys = [] for i, model in enumerate(self.models): @@ -282,27 +511,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x_scale = x else: x_scale = self.scales[i] * x - y = model.forward(x=x_scale) + y_raw = model.forward(x=x_scale) # shape: (batch, dim_outputs) + gate = self.gates[i] # scalar parameter + # Broadcast gate to y shape automatically + y = gate * y_raw ys.append(y) + if not ys: # No active levels, return zeros with correct shape return torch.zeros((x.shape[0], self.dim_outputs), device=x.device) - # Concatenate along the column (feature) dimension + out = torch.cat(ys, dim=1) assert out.shape[1] == self.num_active_levels() * self.dim_outputs return out - def get_solution(self, x: torch.Tensor) -> torch.Tensor: + def get_bubble(self, x: torch.Tensor) -> torch.Tensor: + return self.d_func(x) + + def get_bc_extension(self, x: torch.Tensor) -> torch.Tensor: + return self.g0_func(x) + + def get_model(self, x: torch.Tensor) -> torch.Tensor: y = self.forward(x) n_active = self.num_active_levels() - # reshape to [batch_size, num_levels, dim_outputs] - # and sum over levels + # reshape to [batch_size, num_levels, dim_outputs] and sum over levels if n_active > 1: y = y.view(-1, n_active, self.dim_outputs) y = y.sum(dim=1) # shape: (n, dim_outputs) - # + return y + + def get_solution(self, x: torch.Tensor) -> torch.Tensor: + y = self.get_model(x) if self.enforce_bc: g0_vals = self.g0_func(x) d_vals = self.d_func(x) @@ -314,13 +555,6 @@ def get_solution(self, x: torch.Tensor) -> torch.Tensor: return y - # def _init_weights(self, m): - # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): - # nn.init.ones_(m.weight) - # m.bias.data.fill_(0.01) - # if type(m) == nn.Linear: - # torch.nn.init.xavier_uniform(m.weight) # - # %% # Define Loss class Loss: @@ -395,7 +629,7 @@ def loss(self, model, mesh): # %% # Define the training loop def train(model, mesh, criterion, iterations, adam_iterations, learning_rate, num_check, num_plots, sweep_idx, - level_idx, frame_dir, scheduler_gen): + level_idx, frame_dir, scheduler_gen, track_freqs): optimizer = optim.Adam(model.parameters(), lr=learning_rate) # optimizer = SOAP(model.parameters(), lr = 3e-3, betas=(.95, .95), weight_decay=.01, # precondition_frequency=10) @@ -404,8 +638,62 @@ def train(model, mesh, criterion, iterations, adam_iterations, learning_rate, nu def to_np(t): return t.detach().cpu().numpy() + def closure(): + optimizer.zero_grad() + loss = criterion.loss(model=model, mesh=mesh) + loss.backward() + return loss + u_analytic = mesh.pde.u_ex(mesh.x_eval) - _, uf_analytic, _, _ = fourier_analysis(to_np(mesh.x_eval), to_np(u_analytic)) + xf_analytic, uf_analytic, _, _ = fourier_analysis(to_np(mesh.x_eval), to_np(u_analytic), model.enforce_bc) + if model.enforce_bc: + d_val = model.get_bubble(mesh.x_eval)[:, 0].unsqueeze(-1) + g_val = model.get_bc_extension(mesh.x_eval)[:, 0].unsqueeze(-1) + y_analytic = (u_analytic - g_val) / d_val + + tracked_data = [] + L2_err, H1_err, H2_err = error_analysis(mesh.x_eval, u_analytic, model) + model.eval() + with torch.no_grad(): + u_train = model.get_solution(mesh.x_train)[:, 0].unsqueeze(-1) + u_eval = model.get_solution(mesh.x_eval)[:, 0].unsqueeze(-1) + error = u_analytic - u_eval.to(u_analytic.device) + + xf_eval, uf_eval, uf_eval_real, uf_eval_imag = fourier_analysis(to_np(mesh.x_eval), to_np(u_eval), model.enforce_bc) + f_error = uf_analytic - uf_eval + + tracked_nn_coeffs = uf_eval[track_freqs] + tracked_true_coeffs = uf_analytic[track_freqs] + tracked_data.append({ + 'epoch': 0, + 'L2_error': L2_err, + 'H1_error': H1_err, + 'H2_error': H2_err, + 'nn_coeffs': tracked_nn_coeffs, + 'true_coeffs': tracked_true_coeffs + }) + + save_frame(x=xf_eval, t=uf_analytic, y=uf_eval, xs=None, ys=None, + iteration=[sweep_idx, level_idx, 0], title="Model_Frequencies", frame_dir=frame_dir) + save_frame(x=xf_eval, t=None, y=f_error, xs=None, ys=None, + iteration=[sweep_idx, level_idx, 0], title="Frequencies_Errors", frame_dir=frame_dir) + save_frame(x=to_np(mesh.x_eval), t=to_np(u_analytic), y=to_np(u_eval), + xs=to_np(mesh.x_train), ys=to_np(u_train), + iteration=[sweep_idx, level_idx, 0], title="Model_Outputs", frame_dir=frame_dir) + save_frame(x=to_np(mesh.x_eval), t=None, y=to_np(error), xs=None, ys=None, + iteration=[sweep_idx, level_idx, 0], title="Model_Errors", frame_dir=frame_dir) + if model.enforce_bc: + y_train = model.get_model(mesh.x_train)[:, 0].unsqueeze(-1) + y_eval = model.get_model(mesh.x_eval)[:, 0].unsqueeze(-1) + y_error = y_analytic - y_eval + save_frame(x=to_np(mesh.x_eval), t=to_np(y_analytic), y=to_np(y_eval), + xs=to_np(mesh.x_train), ys=to_np(y_train), + iteration=[sweep_idx, level_idx, 0], title="NN_Model_Outputs", frame_dir=frame_dir) + save_frame(x=to_np(mesh.x_eval), t=None, y=to_np(y_error), xs=None, ys=None, + iteration=[sweep_idx, level_idx, 0], title="NN_Model_Errors", frame_dir=frame_dir) + + model.train() + check_freq = (iterations + num_check - 1) // num_check plot_freq = (iterations + num_plots - 1) // num_plots if num_plots > 0 else 0 @@ -415,12 +703,6 @@ def to_np(t): return t.detach().cpu().numpy() optimizer = optim.LBFGS(model.parameters(), lr=learning_rate, max_iter=20, tolerance_grad=1e-7, history_size=100) - def closure(): - optimizer.zero_grad() - loss = criterion.loss(model=model, mesh=mesh) - loss.backward() - return loss - if use_lbfgs: loss = optimizer.step(closure) else: @@ -442,27 +724,68 @@ def closure(): with torch.no_grad(): u_eval = model.get_solution(mesh.x_eval)[:, 0].unsqueeze(-1) error = u_analytic - u_eval.to(u_analytic.device) - print(f"Iteration {i:6d}/{iterations:6d}, {criterion.name}: {loss.item():.4e}, " + print(f"Iteration {i+1:6d}/{iterations:6d}, {criterion.name}: {loss.item():.4e}, " f"Err 2-norm: {torch.norm(error): .4e}, " f"inf-norm: {torch.max(torch.abs(error)):.4e}") model.train() if plot_freq > 0 and (np.remainder(i + 1, plot_freq) == 0 or i == iterations - 1): + L2_err, H1_err, H2_err = error_analysis(mesh.x_eval, u_analytic, model) model.eval() with torch.no_grad(): u_train = model.get_solution(mesh.x_train)[:, 0].unsqueeze(-1) u_eval = model.get_solution(mesh.x_eval)[:, 0].unsqueeze(-1) error = u_analytic - u_eval.to(u_analytic.device) - xf_eval, uf_eval, _, _ = fourier_analysis(to_np(mesh.x_eval), to_np(u_eval)) + xf_eval, uf_eval, uf_eval_real, uf_eval_imag = fourier_analysis(to_np(mesh.x_eval), to_np(u_eval), model.enforce_bc) + f_error = uf_analytic - uf_eval + + tracked_nn_coeffs = uf_eval[track_freqs] + tracked_true_coeffs = uf_analytic[track_freqs] + tracked_data.append({ + 'epoch': i + 1, + 'L2_error': L2_err, + 'H1_error': H1_err, + 'H2_error': H2_err, + 'nn_coeffs': tracked_nn_coeffs, + 'true_coeffs': tracked_true_coeffs + }) + save_frame(x=xf_eval, t=uf_analytic, y=uf_eval, xs=None, ys=None, - iteration=[sweep_idx, level_idx, i], title="Model_Frequencies", frame_dir=frame_dir) + iteration=[sweep_idx, level_idx, i+1], title="Model_Frequencies", frame_dir=frame_dir) + save_frame(x=xf_eval, t=None, y=f_error, xs=None, ys=None, + iteration=[sweep_idx, level_idx, i+1], title="Frequencies_Errors", frame_dir=frame_dir) save_frame(x=to_np(mesh.x_eval), t=to_np(u_analytic), y=to_np(u_eval), xs=to_np(mesh.x_train), ys=to_np(u_train), - iteration=[sweep_idx, level_idx, i], title="Model_Outputs", frame_dir=frame_dir) + iteration=[sweep_idx, level_idx, i+1], title="Model_Outputs", frame_dir=frame_dir) save_frame(x=to_np(mesh.x_eval), t=None, y=to_np(error), xs=None, ys=None, - iteration=[sweep_idx, level_idx, i], title="Model_Errors", frame_dir=frame_dir) + iteration=[sweep_idx, level_idx, i+1], title="Model_Errors", frame_dir=frame_dir) + if model.enforce_bc: + y_train = model.get_model(mesh.x_train)[:, 0].unsqueeze(-1) + y_eval = model.get_model(mesh.x_eval)[:, 0].unsqueeze(-1) + y_error = y_analytic - y_eval + save_frame(x=to_np(mesh.x_eval), t=to_np(y_analytic), y=to_np(y_eval), + xs=to_np(mesh.x_train), ys=to_np(y_train), + iteration=[sweep_idx, level_idx, i+1], title="NN_Model_Outputs", frame_dir=frame_dir) + save_frame(x=to_np(mesh.x_eval), t=None, y=to_np(y_error), xs=None, ys=None, + iteration=[sweep_idx, level_idx, i+1], title="NN_Model_Errors", frame_dir=frame_dir) + model.train() + if track_freqs: + plot_error_evolution( + data=tracked_data, + sweep_idx=sweep_idx, + level_idx=level_idx, + frame_dir=frame_dir + ) + plot_coefficient_evolution( + data=tracked_data, + freqs=track_freqs, + sweep_idx=sweep_idx, + level_idx=level_idx, + frame_dir=frame_dir, + analytic_freqs=xf_analytic[track_freqs] + ) # %% # Define the main function @@ -475,25 +798,47 @@ def main(args=None): torch.manual_seed(0) # Parse args args = parse_args(args=args) - # Ensure chebyshev_freq_max is at least chebyshev_freq_min for range to be valid - if args.use_chebyshev_basis and args.chebyshev_freq_max < args.chebyshev_freq_min: - raise ValueError("chebyshev_freq_max must be >= chebyshev_freq_min when using Chebyshev basis.") print_args(args=args, output_file=f"results_pinn_1d_{ts}/args.txt") # PDE pde = PDE(high=args.high_freq, mu=args.mu, r=args.gamma, problem=args.problem_id) # Loss function [supervised with analytical solution (-1) or PINN loss (0)] - loss = Loss(loss_type=args.loss_type, bc_weight=args.bc_weight) - print(f"Using loss: {loss.name}") + losses = [] + losses.append(Loss(loss_type=-1, bc_weight=args.bc_weight)) + losses.append(Loss(loss_type=0, bc_weight=args.bc_weight)) + losses.append(Loss(loss_type=1, bc_weight=args.bc_weight)) + + if args.use_chebyshev_basis: + if len(args.chebyshev_freq_min) == 1: + chebyshev_freq_min = np.ones(args.levels, dtype=int) * args.chebyshev_freq_min + else: + chebyshev_freq_min = np.array(args.chebyshev_freq_min).astype(int) + print(f"Chebyshev frequencies lower bounds = {chebyshev_freq_min}") + + if len(args.chebyshev_freq_max) == 1: + chebyshev_freq_max = np.ones(args.levels, dtype=int) * args.chebyshev_freq_max + else: + chebyshev_freq_max = np.array(args.chebyshev_freq_max).astype(int) + print(f"Chebyshev frequencies upper bounds = {chebyshev_freq_max}") + else: + chebyshev_freq_min = np.ones(args.levels, dtype=int) * -1 + chebyshev_freq_max = np.ones(args.levels, dtype=int) * -1 + + if len(args.epochs) == 1: + epochs = np.ones(args.levels, dtype=int) * args.epochs + else: + epochs = np.array(args.epochs).astype(int) + # scheduler gen takes optimizer to return scheduler scheduler_gen = get_scheduler_generator(args) # 1-D mesh - mesh = Mesh(ntrain=args.nx, neval=args.nx_eval, ax=args.ax, bx=args.bx) + mesh = Mesh(ntrain=args.nx[-1], neval=args.nx_eval, ax=args.ax, bx=args.bx) mesh.set_pde(pde=pde) # Create an instance of multilevel model # Input and output dimension: x -> u(x) dim_inputs = 1 dim_outputs = 1 + model = MultiLevelNN(mesh=mesh, num_levels=args.levels, dim_inputs=dim_inputs, dim_outputs=dim_outputs, @@ -503,8 +848,9 @@ def main(args=None): g0_type=args.bc_extension, d_type=args.distance, use_chebyshev_basis=args.use_chebyshev_basis, - chebyshev_freq_min=args.chebyshev_freq_min, - chebyshev_freq_max=args.chebyshev_freq_max) + chebyshev_freq_min=chebyshev_freq_min, + chebyshev_freq_max=chebyshev_freq_max, + init_frozen=args.init_frozen) print(model) model.to(device) # Plotting @@ -530,10 +876,24 @@ def main(args=None): scale = lev + 1 model.set_scale(level_idx=lev, scale=scale) # Crank that !@#$ up - train(model=model, mesh=mesh, criterion=loss, iterations=args.epochs, + if args.loss_type < 2: + loss = losses[args.loss_type+1] + else: + if lev == 0: # DRM + loss = losses[2] + else: # PINN + loss = losses[1] + if len(args.lr) > 1: + lr = args.lr[lev] + else: + lr = args.lr[0] + if len(args.nx) > 1: + mesh = Mesh(ntrain=args.nx[lev], neval=args.nx_eval, ax=args.ax, bx=args.bx) + mesh.set_pde(pde=pde) + train(model=model, mesh=mesh, criterion=loss, iterations=epochs[lev], adam_iterations=args.adam_epochs, - learning_rate=args.lr, num_check=args.num_checks, num_plots=num_plots, - sweep_idx=i, level_idx=lev, frame_dir=frame_dir, scheduler_gen=scheduler_gen) + learning_rate=lr, num_check=args.num_checks, num_plots=num_plots, + sweep_idx=i, level_idx=lev, frame_dir=frame_dir, scheduler_gen=scheduler_gen, track_freqs=args.track_freqs) # Turn PNGs into a video using OpenCV if args.plot: make_video_from_frames(frame_dir=frame_dir, name_prefix="Model_Outputs", @@ -542,6 +902,13 @@ def main(args=None): output_file="Errors.mp4") make_video_from_frames(frame_dir=frame_dir, name_prefix="Model_Frequencies", output_file="Frequencies.mp4") + make_video_from_frames(frame_dir=frame_dir, name_prefix="Frequencies_Errors", + output_file="Frequencies_Errors.mp4") + if args.enforce_bc: + make_video_from_frames(frame_dir=frame_dir, name_prefix="NN_Model_Outputs", + output_file="NN_Solution.mp4") + make_video_from_frames(frame_dir=frame_dir, name_prefix="NN_Model_Errors", + output_file="NN_Errors.mp4") return 0 diff --git a/pinn/utils.py b/pinn/utils.py index 1828149..5fcd79f 100644 --- a/pinn/utils.py +++ b/pinn/utils.py @@ -18,7 +18,7 @@ import matplotlib.pyplot as plt import cv2 from pathlib import Path -from scipy.fft import rfft, rfftfreq +from scipy.fft import rfft, rfftfreq, dst import numpy as np import torch import ast @@ -48,7 +48,7 @@ def is_notebook(): def parse_args(args=None): parser = argparse.ArgumentParser(description="Train a PINN model.") - parser.add_argument('--nx', type=int, default=128, + parser.add_argument('--nx', type=int, nargs='+', default=[128], help="Number of training points in the 1D mesh.") parser.add_argument('--nx_eval', type=int, default=256, help="Number of evaluation points in the 1D mesh.") @@ -56,7 +56,7 @@ def parse_args(args=None): help="Number of evaluation checkpoints during training.") parser.add_argument('--num_plots', type=int, default=10, help="Number of plotting points during training.") - parser.add_argument('--epochs', type=int, default=10000, + parser.add_argument('--epochs', type=int, nargs='+', default=[10000], help="Number of training epochs per sweep.") parser.add_argument('--adam_epochs', type=int, default=None, help="Number of training epochs using Adam per sweep. Defaults to --epochs if not set.") @@ -74,33 +74,39 @@ def parse_args(args=None): help="Coefficient γ in the PDE: -uₓₓ + γ u = f.") parser.add_argument('--mu', type=float, default=70, help="Oscillation parameter in the solution (PDE 2).") - parser.add_argument('--lr', type=float, default=1e-3, + parser.add_argument('--lr', type=float, nargs='+', default=[1e-3], help="Learning rate for the optimizer.") parser.add_argument('--levels', type=int, default=4, help="Number of levels in multilevel training.") - parser.add_argument('--loss_type', type=int, default=0, choices=[-1, 0, 1], - help="Loss type: -1 for supervised (true solution), 0 for PINN loss.") + parser.add_argument('--init_frozen', action='store_true', + help="If set, use frozen in higher levels as initial.") + parser.add_argument('--loss_type', type=int, default=0, choices=[-1, 0, 1, 2], + help="Loss type: -1 for supervised (true solution), 0 for PINN loss, 1 for DRM loss, 2 for mixed.") parser.add_argument('--activation', type=str, default='tanh', choices=['tanh', 'silu', 'relu', 'gelu', 'softmax'], help="Activation function to use.") + parser.add_argument('--enforce_bc', action='store_true', + help="If set, enforce the BC in solution.") parser.add_argument('--bc_extension', type=str, default='hermite_cubic_2nd_deriv', choices=['multilinear', 'hermite_cubic_2nd_deriv'], help='Boundary value extension function.') parser.add_argument('--distance', type=str, default='sin_half_period', choices=['quadratic_bubble', 'inf_smooth_bump', 'abs_dist_complement', 'ratio_bubble_dist', 'sin_half_period'], help='Distance function.') - parser.add_argument('--chebyshev_freq_min', type=int, default=-1, + parser.add_argument('--use_chebyshev_basis', action='store_true', + help="If set, use Chebyshev features.") + parser.add_argument('--chebyshev_freq_min', type=int, nargs='+', help='Minimum frequency for Chebyshev polynomials.') - parser.add_argument('--chebyshev_freq_max', type=int, default=-1, + parser.add_argument('--chebyshev_freq_max', type=int, nargs='+', help='Maximum frequency for Chebyshev polynomials.') + parser.add_argument('--track_freqs', type=int, nargs='+', default=[0, 1, 2, 3, 4, 5, 6, 7], + help="Integer array of frequencies (modes) whose coefficients will be tracked and plotted over epochs.") parser.add_argument('--plot', action='store_true', help="If set, generate plots during or after training.") parser.add_argument('--no-clear', action='store_false', dest='clear', help="If set, do not remove plot files generated before.") parser.add_argument('--problem_id', type=int, default=1, choices=[1, 2], help="PDE problem to solve: 1 or 2.") - parser.add_argument('--enforce_bc', action='store_true', - help="If set, enforce the BC in solution.") parser.add_argument('--bc_weight', type=float, default=1.0, help="Weight for the loss of BC.") parser.add_argument("--scheduler", type=str, default="StepLR", @@ -118,12 +124,6 @@ def parse_args(args=None): if args.adam_epochs is None: args.adam_epochs = args.epochs - if (1 <= args.chebyshev_freq_min <= args.chebyshev_freq_max): - print(f"Chebyshev basis of frequency {args.chebyshev_freq_min} to {args.chebyshev_freq_max} are used") - args.use_chebyshev_basis = True - else: - args.use_chebyshev_basis = False - return args @@ -200,6 +200,65 @@ def get_activation(name: str): raise ValueError(f"Unknown activation function: {name}") return activations[name]() +def plot_coefficient_evolution(data: list, freqs: list, sweep_idx: int, level_idx: int, frame_dir: str, analytic_freqs: np.ndarray): + """ + Plots the evolution of specific Fourier/Sine coefficients over training epochs. + """ + if not data: + print("No coefficient data collected for plotting.") + return + + # Extract all data into a structured format + epochs = np.array([d['epoch'] for d in data]) + true_coeffs = np.array([d['true_coeffs'] for d in data]) + nn_coeffs = np.array([d['nn_coeffs'] for d in data]) + + num_freqs = len(freqs) + + fig1, ax1 = plt.subplots(figsize=(10, 6)) + ax1.set_title(f"Sweep {sweep_idx}, Level {level_idx}: Fourier Coefficient Evolution") + ax1.set_xlabel("Epoch") + ax1.set_ylabel("Coefficient Magnitude") + + for i in range(num_freqs): + # Plot true coefficient (should be constant) + ax1.plot(epochs, true_coeffs[:, i], + label=f"True (Freq = {freqs[i]} pi)", + linestyle='--', alpha=0.7) + # Plot NN coefficient evolution + ax1.plot(epochs, nn_coeffs[:, i], + label=f"NN (Freq = {freqs[i]} pi)", + linestyle='-') + + ax1.legend(loc='best') + + iters_str = f"Sweep{sweep_idx:02d}_Lvl{level_idx:02d}" + filename1 = os.path.join(frame_dir, f"Coeffs_Evolution_{iters_str}.png") + fig1.savefig(filename1) + plt.close(fig1) + + print(f" Coefficient evolution plot saved to {filename1}") + + fig2, ax2 = plt.subplots(figsize=(10, 6)) + ax2.set_title(f"Sweep {sweep_idx}, Level {level_idx}: Coefficient Error Evolution") + ax2.set_xlabel("Epoch") + ax2.set_ylabel("Absolute Error (|True - NN|)") + + error_coeffs = np.abs(true_coeffs - nn_coeffs) + + for i in range(num_freqs): + ax2.plot(epochs, error_coeffs[:, i], + label=f"Freq = {freqs[i]} pi", + linestyle='-') + + ax2.legend(loc='best') + ax2.set_yscale('log') + + filename2 = os.path.join(frame_dir, f"Coeffs_Error_Evolution_{iters_str}.png") + fig2.savefig(filename2) + plt.close(fig2) + + print(f" Coefficient error plot saved to {filename2}") # %% def save_frame(x, t, y, xs, ys, iteration, title, frame_dir): @@ -253,7 +312,7 @@ def make_video_from_frames(frame_dir, name_prefix, output_file, fps=10): # %% -def fourier_analysis(x, y): +def fourier_analysis(x, y, sine_series: bool = False): """ Compute the magnitude spectrum using the Fast Fourier Transform (FFT). Ref: https://docs.scipy.org/doc/scipy/tutorial/fft.html @@ -272,12 +331,141 @@ def fourier_analysis(x, y): N = len(x) # Sampling interval Ts = dx[0] - yf = rfft(y) - xf = rfftfreq(N, Ts) - yf *= 2.0 / N - # Correct scaling for DC and Nyquist (they should not be doubled) - yf[0] /= 2 - if N % 2 == 0: - yf[-1] /= 2 - - return xf, np.abs(yf), np.real(yf), -np.imag(yf) + + if sine_series: + yf = dst(y, type=1) + yf /= (N + 1) + yf = yf[:N-1] + L = N * Ts + xf = np.arange(1, N + 1) * (np.pi / L) + xf = xf[:N-1] + yf_imag = np.zeros_like(yf) + return xf, np.abs(yf), yf, yf_imag + else: + yf = rfft(y) + xf = rfftfreq(N, Ts) + yf *= 2.0 / N + # Correct scaling for DC and Nyquist (they should not be doubled) + yf[0] /= 2 + if N % 2 == 0: + yf[-1] /= 2 + return xf, np.abs(yf), np.real(yf), -np.imag(yf) + +# %% +def error_analysis(x: torch.Tensor, u_true: torch.Tensor, model: torch.nn.Module) -> dict: + """ + Calculates the L2, H1, and H2 relative errors of the NN solution and its derivatives + against the true solution and its derivatives. Derivatives are calculated within + the routine using finite differences for the true solution and automatic + differentiation for the NN solution. + + Args: + x (torch.Tensor): The 1D mesh points (must be uniformly spaced). + u_true (torch.Tensor): True solution values at x. + model (torch.nn.Module): Neural Network solution. + + Returns: + dict: Dictionary containing L2, H1, and H2 relative errors. + """ + + # Ensure all inputs are column vectors (N, 1) and on the same device/dtype + x = x.flatten().unsqueeze(-1).clone().detach().requires_grad_(True) + u_true = u_true.flatten().unsqueeze(-1) + u_nn = model.get_solution(x)[:, 0].unsqueeze(-1) + + # Compute first derivative (u'_nn) + u_prime_nn_and_rest = torch.autograd.grad( + outputs=u_nn, + inputs=x, + grad_outputs=torch.ones_like(u_nn), + create_graph=True, + retain_graph=True + ) + u_prime_nn = u_prime_nn_and_rest[0] + + # Compute second derivative (u''_nn) + u_double_prime_nn_and_rest = torch.autograd.grad( + outputs=u_prime_nn, + inputs=x, + grad_outputs=torch.ones_like(u_prime_nn), + create_graph=False + ) + u_double_prime_nn = u_double_prime_nn_and_rest[0] + + # Convert to NumPy for finite difference calculation + x_np = x.detach().cpu().numpy().flatten() + u_true_np = u_true.detach().cpu().numpy().flatten() + + # Use central finite difference (or second-order difference) + # The domain is assumed to be uniformly sampled based on the existing script's fourier_analysis. + + # First derivative (u'_true): gradient is a simple NumPy finite difference + u_prime_true_np = np.gradient(u_true_np, x_np, edge_order=2) + + # Second derivative (u''_true): gradient of the first derivative + u_double_prime_true_np = np.gradient(u_prime_true_np, x_np, edge_order=2) + + # Convert back to Torch Tensors + u_prime_true = torch.from_numpy(u_prime_true_np).float().to(u_true.device).unsqueeze(-1) + u_double_prime_true = torch.from_numpy(u_double_prime_true_np).float().to(u_true.device).unsqueeze(-1) + + # Relative L2 Error (u) + L2_error_num = torch.linalg.norm(u_true - u_nn, ord=2) + L2_error_den = torch.linalg.norm(u_true, ord=2) + L2_relative_error = (L2_error_num / L2_error_den).item() + + # Relative H1 Error (u and u') + H1_error_num = torch.linalg.norm(u_prime_true - u_prime_nn, ord=2) + H1_error_den = torch.linalg.norm(u_prime_true, ord=2) + H1_relative_error = (H1_error_num/ H1_error_den).item() + + # Relative H2 Error (u, u', and u'') + H2_error_den = torch.linalg.norm(u_double_prime_true, ord=2) + H2_error_num = torch.linalg.norm(u_double_prime_true - u_double_prime_nn, ord=2) + H2_relative_error = (H2_error_num / H2_error_den).item() + + return L2_relative_error, H1_relative_error, H2_relative_error + +# %% +def plot_error_evolution(data: list, sweep_idx: int, level_idx: int, frame_dir: str): + """ + Plots the evolution of L2, H1, and H2 relative errors over training epochs. + + Args: + data (list): List of dictionaries, each containing 'epoch' and error metrics. + sweep_idx (int): Current sweep index for file naming. + level_idx (int): Current level index for file naming. + frame_dir (str): Directory to save the plot. + """ + if not data: + print("No error data collected for plotting.") + return + + # Extract all data into structured numpy arrays + epochs = np.array([d['epoch'] for d in data]) + l2_errors = np.array([d['L2_error'] for d in data]) + h1_errors = np.array([d['H1_error'] for d in data]) + h2_errors = np.array([d['H2_error'] for d in data]) + + fig, ax = plt.subplots(figsize=(10, 6)) + + # Plot the three error metrics + ax.plot(epochs, l2_errors, label="L2 Relative Error", linestyle='-', marker='o') + ax.plot(epochs, h1_errors, label="H1 Relative Error", linestyle='--', marker='s') + ax.plot(epochs, h2_errors, label="H2 Relative Error", linestyle=':', marker='^') + + ax.set_title(f"Sweep {sweep_idx}, Level {level_idx}: Solution Error Evolution") + ax.set_xlabel("Epoch") + ax.set_ylabel("Relative Error (Log Scale)") + + # Set the y-axis to a logarithmic scale, as errors typically span several orders of magnitude + ax.set_yscale('log') + ax.legend(loc='best') + ax.grid(True, which="both", ls="--", linewidth=0.5) + + iters_str = f"Sweep{sweep_idx:02d}_Lvl{level_idx:02d}" + filename = os.path.join(frame_dir, f"Error_Evolution_{iters_str}.png") + fig.savefig(filename) + plt.close(fig) + + print(f" Error evolution plot saved to {filename}")