Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
e39c5c4
Clean up fgn example for public PR
kashif May 21, 2026
5e78ab2
Address reviewer feedback: rollout bug, dead amp field, _step_ensembl…
kashif May 21, 2026
12b6f6a
Contribution standards: ruff clean, CHANGELOG, FGNUNet docstring
kashif May 21, 2026
65e6339
Add energy score metric, Hydra config files, and fgn_arco training co…
kashif May 21, 2026
1b9a5ed
Add missing datasets package and utils/loss module
kashif May 21, 2026
db38462
Fix training loop and add bf16 AMP for memory-efficient training
kashif May 21, 2026
3ca1c1a
train_fgn.sh: set batch_size=2 (1 per GPU) for 2-GPU DDP run
kashif May 21, 2026
c141dea
FGN.md: update progress to 2026-05-21, tick completed tasks
kashif May 21, 2026
aff6648
Untrack FGN.md — local planning file, not for public repo
kashif May 21, 2026
443faea
fgn/README.md: rewrite to match examples/weather conventions
kashif May 21, 2026
3ec670f
fgn/README.md: add WeatherNext 2 references and production specs
kashif May 21, 2026
57e50ef
Merge branch 'main' into fgn
kashif May 22, 2026
da747c5
fgn/config.py: add EvalConfig and EvalMainConfig dataclasses
kashif May 22, 2026
dbd8ff3
fgn/metrics.py: add pooled_crps_per_lead (avg+max) and heatmap plot h…
kashif May 22, 2026
edd05d7
fgn/config: add eval_fgn.yaml for standalone §4 evaluation
kashif May 22, 2026
be93061
fgn: add eval.py — paper §4 standalone evaluation script
kashif May 22, 2026
41bb3fd
fgn/scripts: add eval_fgn.sh SLURM launcher for standalone eval
kashif May 22, 2026
9c4c0c3
fgn/.gitignore: stop tracking SLURM launcher scripts (cluster-specific)
kashif May 22, 2026
6d165c9
fgn/utils/metrics.py: fix plot layout and axis labels
kashif May 22, 2026
6116bcc
fgn/utils/metrics.py: fix subplot title overlap in multi-panel plots
kashif May 22, 2026
4bc26da
fgn: match paper Figure 3 plot layouts
kashif May 22, 2026
a834826
fgn: replace noisy line plots with CRPS scorecard heatmap
kashif May 22, 2026
797d9cd
fgn/eval: match paper Figure 2-3 plot layouts
kashif May 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for Batched radius search, which enables Domino
and GeoTransolver with local features and batch size > 1.
- Added the underfill recipe.
- Adds Functional Generative Networks (FGN) weather training example
(`examples/weather/fgn`). Implements the latent-conditioned U-Net
stochastic generator from
`arXiv:2506.10772 <https://arxiv.org/abs/2506.10772>`_ (WeatherNext 2)
as a PhysicsNeMo ``Module``, trained with fair-CRPS loss on ERA5 via the
earth2studio ARCO data source. Supports autoregressive rollout training
with per-channel normalization, FSDP + ShardTensor domain parallelism,
deep-ensemble inference (paper §2.2.1), and validation diagnostics
(CRPS, RMSE, spread-skill, rank histograms, power spectra).

### Changed

Expand Down
12 changes: 12 additions & 0 deletions examples/weather/fgn/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
*.mlus
*.png
*.pt
*.tfevents*
*wandb/
rundir/
logs/
*.npz
FGN.md
# SLURM launcher scripts are cluster-specific; keep locally, don't track.
scripts/train_fgn.sh
scripts/eval_fgn.sh
266 changes: 266 additions & 0 deletions examples/weather/fgn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
<!-- markdownlint-disable -->
# Functional Generative Networks for Weather Forecasting

A PhysicsNeMo implementation of Functional Generative Networks (FGN) for
probabilistic global weather forecasting, following the approach of:

> Alet et al., "Skillful joint probabilistic weather forecasting from marginals"
> ([arXiv:2506.10772](https://arxiv.org/abs/2506.10772))

FGN is the architecture behind the production
[WeatherNext 2](https://developers.google.com/weathernext/guides/models) model,
which delivers 64-member ensemble forecasts (4 independently trained seeds ×
16 trajectories each) at 0.25° global resolution. The full variable schema is
described in the
[WeatherNext 2 model specs](https://developers.google.com/weathernext/guides/model-specs-vmg).

FGN generates ensemble weather forecasts by perturbing a deterministic backbone
with a low-dimensional latent noise vector `z ~ N(0, I_32)` injected through
conditional layer normalization (CLN) at every layer, producing globally coherent
ensemble spread from a marginal (fair-CRPS) training loss. Multiple independently
trained model seeds form a deep ensemble (J=4 seeds, 16 trajectories each = 64
members in production) capturing both aleatoric and epistemic uncertainty.

## Problem Overview

FGN autoregressively predicts the next 6-hour atmospheric state from the two
previous states (`X_{t-2}`, `X_{t-1}`), sampled from ERA5 (pre-training) and
HRES-fc0 (fine-tuning) at 0.25° global resolution. Each forward pass is
non-diffusive: one pass per forecast step, with a fresh `z` drawn per step per
ensemble member. AR fine-tuning with rollouts up to 8 steps (Table A.2) improves
temporal coherence without requiring a diffusion sampler.

This example implements:

- Latent-conditioned `FGNUNet` backbone (`utils/nn.py`) with AdaGN modulation
- ARCO-backed real dataset using `earth2studio.data.ARCO` (`datasets/arco.py`)
- Fair-CRPS training loss with paper-faithful per-variable and area weights (`utils/loss.py`)
- Autoregressive rollout training with BPTT (`utils/trainer.py`)
- Multi-stage AR schedule runner (Table A.2: `8k·1AR → 4k·2AR → 1k·{3..8}AR`)
- Validation metrics and plots: CRPS, RMSE, spread-skill, rank histograms,
power spectra (`utils/metrics.py`)
- FSDP + ShardTensor distributed training via `ParallelHelper` (`utils/parallel.py`)
- Deep ensemble inference across multiple independently trained checkpoints
- Per-channel normalization stats with Welford online estimation

## Dataset

Training data is fetched live from the [ARCO ERA5](https://cloud.google.com/storage/docs/public-datasets/era5)
dataset via `earth2studio.data.ARCO`. No local download is required for training.

The dataset covers the full 83-channel Table A.1 schema: 78 atmospheric channels
(6 variables × 13 pressure levels: 50–1000 hPa) plus 5 input/predicted surface
channels (`t2m`, `u10m`, `v10m`, `msl`, `sst`) and `tp06` (6-h accumulated
precipitation, predicted-only). Static inputs (surface geopotential, land-sea mask)
and clock features (local time, year progress sin/cos) are added automatically.

All variables use compact Earth2Studio / PhysicsNeMo names: `u10m`, `v10m`, `t2m`,
`msl`, `sst`, `tp06`, `z{level}`, `q{level}`, `t{level}`, `u{level}`, `v{level}`,
`w{level}`.

> **Note:** The production WeatherNext 2 output also includes `u100m` / `v100m`
> (100 m wind components). ERA5 via ARCO does not provide 100 m winds, so they
> are omitted from this ERA5-based training example.

### Normalization Stats

Pre-compute per-channel mean and standard deviation before training:

```bash
python scripts/compute_arco_stats.py \
--start 2020-01-01 --end 2023-12-31 \
--output rundir/fgn_2024_val/stats_2024.npz
```

Pass the resulting `.npz` file to the trainer via `dataset.stats_path`.

## Getting Started

### Requirements

```bash
pip install -r requirements.txt
```

PyTorch 2.10 or higher is required for domain parallelism.

### Smoke Test

Run the self-contained synthetic test suite (no GPU, no network access):

```bash
pytest test_training.py
```

Multi-GPU tests require `torchrun`:

```bash
torchrun --standalone --nproc_per_node=2 --no-python pytest test_training.py
```

## Configuration

Training is configured with [Hydra](https://hydra.cc) and validated with Pydantic
(`utils/config.py`). Configs live under `config/`:

- `config/fgn.yaml` — base defaults (model, training, dataset structure)
- `config/fgn_arco.yaml` — ARCO real-data training config (inherits from `fgn.yaml`)
- `config/test_fgn.yaml` — fast synthetic smoke-test config

Key config knobs:

| Setting | Description |
|---|---|
| `model.hidden_channels` | U-Net channel width (64 for quick runs, 256+ for full scale) |
| `model.latent_dim` | Latent noise dimension (32, per paper) |
| `training.batch_size` | Global batch size; local per-GPU = `batch_size / data_parallel_size` |
| `training.ar_steps` | AR rollout length for loss (1 = single-step pre-training) |
| `training.loss.num_samples` | Ensemble members per training example (N=2 per paper) |
| `training.domain_parallel_size` | GPUs per sample for domain parallelism (1 = pure DDP) |
| `dataset.stats_path` | Path to `.npz` normalization stats |

Training outputs (checkpoints, logs, plots) are saved to:

```
rundir/{training.experiment_name}/{training.run_id}/
```

## Training

### Single GPU

```bash
python train.py --config-name fgn_arco \
dataset.stats_path=rundir/fgn_2024_val/stats_2024.npz \
training.experiment_name=fgn_run \
training.batch_size=1
```

### Multi-GPU (torchrun)

```bash
torchrun --standalone --nnodes=1 --nproc_per_node=2 \
train.py --config-name fgn_arco \
dataset.stats_path=rundir/fgn_2024_val/stats_2024.npz \
training.experiment_name=fgn_run \
training.batch_size=2
```

With 2 GPUs and `domain_parallel_size=1` (DDP), `batch_size` is the global batch
size — each GPU processes `batch_size / 2` samples.

### SLURM

```bash
sbatch scripts/train_fgn.sh
```

Override defaults via environment variables:

```bash
sbatch --export=ALL,EXP_NAME=fgn_2024,RUN_ID=1,STEPS=10000 scripts/train_fgn.sh
```

See `scripts/train_fgn.sh` for all overridable variables (`EXP_NAME`, `RUN_ID`,
`STEPS`, `CFG`, `STATS_PATH`, `NGPU`).

### Resuming

Set `training.resume_checkpoint=latest` (default) to automatically resume from
the most recent checkpoint in the run directory.

### Domain Parallelism

For models too large to fit one sample on a single GPU, enable domain parallelism:

```bash
torchrun --standalone --nproc_per_node=4 \
train.py --config-name fgn_arco \
training.domain_parallel_size=2 \
training.batch_size=2
```

With `domain_parallel_size=2` and 4 GPUs: 2 domain-parallel pairs, each handling
1 sample (`batch_size / data_parallel_size = 2 / 2 = 1`).

### AR Fine-Tuning Schedule (Table A.2)

The trainer implements the paper's multi-stage AR schedule automatically when
`training.ar_steps` increases across runs. Start with single-step pre-training,
then resume with progressively longer rollouts:

| Stage | `ar_steps` | Steps | Notes |
|---|---|---|---|
| 1 | 1 | 8000 | Single-step pre-train |
| 2 | 2 | 4000 | Resume from stage 1 |
| 3–8 | 3–8 | 1000 each | Resume from previous |

## Inference

Run stochastic ensemble inference from a trained checkpoint:

```bash
python inference.py --config-name inference_fgn \
inference.checkpoint=rundir/fgn_run/0/checkpoints/FGNUNet.mdlus
```

For deep ensemble inference across multiple independently trained seeds:

```bash
python inference.py --config-name inference_fgn \
"inference.checkpoints=[seed0/FGNUNet.mdlus, seed1/FGNUNet.mdlus, seed2/FGNUNet.mdlus, seed3/FGNUNet.mdlus]"
```

Trajectories are distributed across checkpoints following paper §2.2.1.

### Bad-Seed Detection

Before including a checkpoint in a deep ensemble, check its spectral properties:

```bash
python scripts/check_spectra.py \
--checkpoint rundir/fgn_run/0/checkpoints/FGNUNet.mdlus \
--stats rundir/fgn_2024_val/stats_2024.npz
```

## Adding Custom Datasets

Implement the `FGNDataset` interface from `datasets/dataset.py`:

```python
class MyDataset(FGNDataset):
def state_channels(self) -> list[str]: ...
def background_channels(self) -> list[str]: ...
def image_shape(self) -> tuple[int, int]: ...
def __len__(self) -> int: ...
def __getitem__(self, idx): ...
# Optional:
def get_invariants(self) -> np.ndarray | None: ...
def output_only_channels(self) -> list[int]: ...
```

`__getitem__` should return a dict with keys `history` (shape `(T, C, H, W)`),
`target` (shape `(K, C, H, W)`), and optionally `background`. Register your
dataset by placing it in `datasets/` — it is discovered automatically via
`pkgutil.iter_modules` at import time.

## Memory Management

At 0.25° (721×1440), each training sample is large. Recommended settings for an
80 GB H100:

- `training.batch_size=2` (1 per GPU) with 2 GPUs, `domain_parallel_size=1`
- bf16 AMP is enabled automatically (`torch.autocast(bfloat16)`)
- `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` (set in `train_fgn.sh`)

For larger models (hidden_channels ≥ 256), use `domain_parallel_size=2` with
4+ GPUs, or enable gradient checkpointing via `model.checkpoint_level`.

## References

- [Skillful joint probabilistic weather forecasting from marginals](https://arxiv.org/abs/2506.10772)
- [WeatherNext 2 model overview](https://developers.google.com/weathernext/guides/models)
- [WeatherNext 2 variable schema](https://developers.google.com/weathernext/guides/model-specs-vmg)
- [Generative Ensemble Downscaling with Diffusion Models (CorrDiff)](https://arxiv.org/abs/2308.14453)
- [Kilometer-Scale Convection Allowing Model Emulation (StormCast)](https://arxiv.org/abs/2408.10958)
- [GraphCast: Learning skillful medium-range global weather forecasting](https://arxiv.org/abs/2212.12794)
71 changes: 71 additions & 0 deletions examples/weather/fgn/config/eval_fgn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Standalone evaluation config for FGN (§4 of arXiv:2506.10772).
#
# Usage:
# python eval.py --config-name eval_fgn \
# dataset.stats_path=rundir/fgn_2024_val/stats_2024.npz \
# eval.checkpoint=rundir/fgn_2024_long/0/checkpoints/FGNUNet.0.5000.mdlus
#
# To evaluate a deep ensemble pass a list:
# eval.checkpoints=[seed0/FGNUNet.mdlus,seed1/FGNUNet.mdlus]

defaults:
- fgn # model + dataset schema defaults
- _self_

dataset:
name: arco.ArcoFGNDataset
state_variables: null
invariant_variables:
- z
- lsm
step_hours: 6
history_frames: 2
# future_frames is overridden by eval.future_steps at runtime
future_frames: ${eval.future_steps}
val_start: "2024-10-01"
val_end: "2025-01-01"
# train_{start,end} unused in eval but required by the dataset schema
train_start: "2024-01-01"
train_end: "2024-10-01"
spatial_stride: 1
static_date: "2016-01-01"
arco_cache: true
stats_path: ???
tp_accumulation_hours: null

model:
latent_dim: 16
hidden_channels: 64

# training section is required by TrainMainConfig but unused during eval;
# keep it minimal so Pydantic passes validation.
training:
experiment_name: fgn_eval
run_id: "0"
batch_size: 1
total_train_steps: 1
ar_steps: ${eval.future_steps}

eval:
# Path to a single .mdlus checkpoint, or "latest" to auto-detect.
checkpoint: "latest"
# For deep-ensemble eval, list multiple checkpoints; overrides checkpoint.
checkpoints: null
# Number of AR steps forward (each step = step_hours hours).
# 20 steps = 5 days at 6h. Max ~40 steps (10 days) fits in memory.
future_steps: 20
# Ensemble members per checkpoint. Paper uses 56 total (14×4 seeds).
ensemble_size: 8
# Batch size for the eval DataLoader. Keep at 1 for full-resolution.
batch_size: 1
# Number of DataLoader workers.
num_workers: 0
# Output directory for plots + eval_metrics.npz.
outdir: "rundir/fgn_2024_long/0/eval"
# Pooled-CRPS cell sizes (number of 0.25° grid cells per side).
# [4,8,16,32] ≈ [120, 240, 480, 960] km.
pool_sizes: [4, 8, 16, 32]
30 changes: 30 additions & 0 deletions examples/weather/fgn/config/fgn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Base FGN config — required by train.py's @hydra.main(config_name="fgn").
# Provides model defaults and a skeleton dataset section; override dataset
# and training fields on the command line or via a derived config such as
# fgn_arco.yaml.
#
# Minimal usage (all dataset fields required as overrides):
# python train.py \
# dataset.name=arco.ArcoFGNDataset \
# dataset.stats_path=/path/to/stats.npz \
# [dataset.train_start=... dataset.train_end=... ...]

defaults:
- training: default
- _self_

dataset:
name: ??? # required — e.g. arco.ArcoFGNDataset

model:
model_name: fgn
history_frames: 2
latent_dim: 16
hidden_channels: 32
background_channels: auto
invariant_channels: auto
group_norm_groups: 8
Loading