@@ -57,6 +57,51 @@ def _resolve_device(config: TorchTrainingConfig) -> torch.device:
5757 return torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
5858
5959
60+ _SMALL_VALIDATION_WARNING_THRESHOLD = 10
61+
62+
63+ def _warn_on_small_validation_set (
64+ * ,
65+ n_val : int ,
66+ use_scheduler : bool ,
67+ save_best : bool ,
68+ ) -> None :
69+ """
70+ Warn when validation-driven controls are enabled on a tiny split.
71+
72+ Parameters
73+ ----------
74+ n_val : int
75+ Number of validation structures.
76+ use_scheduler : bool
77+ Whether ReduceLROnPlateau monitoring is enabled for this run.
78+ save_best : bool
79+ Whether best-checkpoint selection is enabled for this run.
80+ """
81+ if n_val <= 0 or n_val >= _SMALL_VALIDATION_WARNING_THRESHOLD :
82+ return
83+
84+ noun = "structure" if n_val == 1 else "structures"
85+
86+ if use_scheduler :
87+ warnings .warn (
88+ "use_scheduler=True with a validation set of only "
89+ f"{ n_val } { noun } can make ReduceLROnPlateau react to noisy "
90+ "metrics. Consider use_scheduler=False, a larger validation "
91+ "split, or an explicit train/test split." ,
92+ UserWarning ,
93+ )
94+
95+ if save_best :
96+ warnings .warn (
97+ "save_best=True with a validation set of only "
98+ f"{ n_val } { noun } can select a checkpoint from a noisy "
99+ "validation loss. Consider save_best=False, a larger "
100+ "validation split, or an explicit train/test split." ,
101+ UserWarning ,
102+ )
103+
104+
60105def _iter_progress (iterable , enable : bool , desc : str ):
61106 """
62107 Wrap an iterable with tqdm progress bar if enabled and available.
@@ -788,6 +833,15 @@ def train(
788833 else None
789834 )
790835
836+ n_val = int (len (test_ds )) if test_ds is not None else 0
837+ _warn_on_small_validation_set (
838+ n_val = n_val ,
839+ use_scheduler = bool (config .use_scheduler ) and (test_loader is not None ),
840+ save_best = bool (config .save_best )
841+ and (config .checkpoint_dir is not None )
842+ and (test_loader is not None ),
843+ )
844+
791845 # Initialize normalization manager
792846 normalize_features = bool (getattr (config , "normalize_features" , True ))
793847 normalize_energy = bool (getattr (config , "normalize_energy" , True ))
0 commit comments