Skip to content

Commit 7601945

Browse files
authored
Merge pull request #1 from tcapelle/refactor
Refactor and simplify
2 parents a466028 + 5f984ad commit 7601945

11 files changed

Lines changed: 197 additions & 186 deletions

File tree

README.md

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,7 @@
44

55
# Cloud Diffusion Experiment
66

7-
This codebase contains an implementation of a deep diffusion model applied to cloud images. It was developed as part of a research project exploring the potential of diffusion models
8-
for image generation and forecasting.
7+
This codebase contains an implementation of a deep diffusion model applied to cloud images. It was developed as part of a research project exploring the potential of diffusion models for image generation and forecasting.
98

109
## Setup
1110

@@ -18,6 +17,7 @@ for image generation and forecasting.
1817
To train the model, run `python train.py`. You can play with the parameters on top of the file to change the model architecture, training parameters, etc.
1918

2019
You can also override the configuration parameters by passing them as command-line arguments, e.g.
20+
2121
```bash
2222
> python train.py --epochs=10 --batch_size=32
2323
```
@@ -27,7 +27,6 @@ You can also override the configuration parameters by passing them as command-li
2727
This training is based on a Transformer based Unet (UViT), you can train the default model by running:
2828

2929
```bash
30-
3130
> python train_uvit.py
3231
```
3332

cloud_diffusion/dataset.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,20 @@
1111
PROJECT_NAME = "ddpm_clouds"
1212
DATASET_ARTIFACT = 'capecape/gtc/np_dataset:v0'
1313

14+
class DummyNextFrameDataset:
15+
"Dataset that returns random images"
16+
def __init__(self, num_frames=4, img_size=64, N=1000):
17+
self.img_size = img_size
18+
self.num_frames = num_frames
19+
self.N = N
20+
21+
def __getitem__(self, idx):
22+
return torch.randn(self.num_frames, self.img_size, self.img_size)
23+
24+
def __len__(self):
25+
return self.N
26+
27+
1428
class CloudDataset:
1529
"""Dataset for cloud images
1630
It loads numpy files from wandb artifact and stacks them into a single array

cloud_diffusion/ddpm.py

Lines changed: 5 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,10 @@
1-
from pathlib import Path
21
from functools import partial
32

4-
import torch, wandb
5-
from torch.nn import init
6-
from torch.utils.data.dataloader import default_collate
7-
8-
import fastcore.all as fc
3+
import torch
94
from fastprogress import progress_bar
105

116
from diffusers.schedulers import DDIMScheduler
127

13-
from diffusers import UNet2DModel
14-
158

169
## DDPM params
1710
## From fastai V2 Course DDPM notebooks
@@ -21,52 +14,15 @@
2114
alphabar = alpha.cumprod(dim=0)
2215
sigma = beta.sqrt()
2316

24-
def noisify(x0, ):
25-
"Noise only the last frame"
26-
past_frames = x0[:,:-1]
27-
x0 = x0[:,-1:]
17+
def noisify_ddpm(x0):
18+
"Noise by ddpm"
2819
device = x0.device
2920
n = len(x0)
3021
t = torch.randint(0, n_steps, (n,), dtype=torch.long)
3122
ε = torch.randn(x0.shape, device=device)
32-
ᾱ_t = [t].reshape(-1, 1, 1, 1).to(device)
23+
ᾱ_t = alphabar[t].reshape(-1, 1, 1, 1).to(device)
3324
xt = ᾱ_t.sqrt()*x0 + (1-ᾱ_t).sqrt()*ε
34-
return torch.cat([past_frames, xt], dim=1), t.to(device), ε
35-
36-
def collate_ddpm(b):
37-
"Collate function that noisifies the last frame"
38-
return noisify(default_collate(b), alphabar)
39-
40-
def get_unet_params(model_name="unet_small", num_frames=4):
41-
"Return the parameters for the diffusers UNet2d model"
42-
if model_name == "unet_small":
43-
return dict(
44-
block_out_channels=(16, 32, 64, 128), # number of channels for each block
45-
norm_num_groups=8, # number of groups for the normalization layer
46-
in_channels=num_frames, # number of input channels
47-
out_channels=1, # number of output channels
48-
)
49-
elif model_name == "unet_big":
50-
return dict(
51-
block_out_channels=(32, 64, 128, 256), # number of channels for each block
52-
norm_num_groups=8, # number of groups for the normalization layer
53-
in_channels=num_frames, # number of input channels
54-
out_channels=1, # number of output channels
55-
)
56-
else:
57-
raise(f"Model name not found: {model_name}, choose between 'unet_small' or 'unet_big'")
58-
59-
def init_ddpm(model):
60-
"From Jeremy's bag of tricks on fastai V2 2023"
61-
for o in model.down_blocks:
62-
for p in o.resnets:
63-
p.conv2.weight.data.zero_()
64-
for p in fc.L(o.downsamplers): init.orthogonal_(p.conv.weight)
65-
66-
for o in model.up_blocks:
67-
for p in o.resnets: p.conv2.weight.data.zero_()
68-
69-
model.conv_out.weight.data.zero_()
25+
return xt, t.to(device), ε
7026

7127
@torch.no_grad()
7228
def diffusers_sampler(model, past_frames, sched, **kwargs):
@@ -88,23 +44,3 @@ def ddim_sampler(steps=350, eta=1.):
8844
ddim_sched = DDIMScheduler()
8945
ddim_sched.set_timesteps(steps)
9046
return partial(diffusers_sampler, sched=ddim_sched, eta=eta)
91-
92-
class UNet2D(UNet2DModel):
93-
def forward(self, *x, **kwargs):
94-
return super().forward(*x, **kwargs).sample ## Diffusers's UNet2DOutput class
95-
96-
@classmethod
97-
def from_checkpoint(cls, model_params, checkpoint_file):
98-
"Load a UNet2D model from a checkpoint file"
99-
model = cls(**model_params)
100-
model.load_state_dict(torch.load(checkpoint_file, map_location="cpu"))
101-
return model
102-
103-
104-
@classmethod
105-
def from_artifact(cls, model_params, artifact_name):
106-
"Load a UNet2D model from a wandb.Artifact, need to be run in a wandb run"
107-
artifact = wandb.use_artifact(artifact_name, type='model')
108-
artifact_dir = Path(artifact.download())
109-
chpt_file = list(artifact_dir.glob("*.pth"))[0]
110-
return cls.from_checkpoint(model_params, chpt_file)

cloud_diffusion/models.py

Lines changed: 99 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,99 @@
1+
from pathlib import Path
2+
3+
import wandb
4+
import fastcore.all as fc
5+
6+
import torch
7+
from torch import nn
8+
from diffusers import UNet2DModel
9+
10+
try:
11+
from denoising_diffusion_pytorch.simple_diffusion import UViT
12+
except:
13+
raise ImportError("Please install denoising_diffusion_pytorch with `pip install denoising_diffusion_pytorch`")
14+
15+
16+
def init_unet(model):
17+
"From Jeremy's bag of tricks on fastai V2 2023"
18+
for o in model.down_blocks:
19+
for p in o.resnets:
20+
p.conv2.weight.data.zero_()
21+
for p in fc.L(o.downsamplers): nn.init.orthogonal_(p.conv.weight)
22+
23+
for o in model.up_blocks:
24+
for p in o.resnets: p.conv2.weight.data.zero_()
25+
26+
model.conv_out.weight.data.zero_()
27+
28+
class WandbModel:
29+
"A model that can be saved to wandb"
30+
@classmethod
31+
def from_checkpoint(cls, model_params, checkpoint_file):
32+
"Load a UNet2D model from a checkpoint file"
33+
model = cls(**model_params)
34+
print(f"Loading model from: {checkpoint_file}")
35+
model.load_state_dict(torch.load(checkpoint_file))
36+
return model
37+
38+
@classmethod
39+
def from_artifact(cls, model_params, artifact_name):
40+
"Load a UNet2D model from a wandb.Artifact, need to be run in a wandb run"
41+
artifact = wandb.use_artifact(artifact_name, type='model')
42+
artifact_dir = Path(artifact.download())
43+
chpt_file = list(artifact_dir.glob("*.pth"))[0]
44+
return cls.from_checkpoint(model_params, chpt_file)
45+
46+
def get_unet_params(model_name="unet_small", num_frames=4):
47+
"Return the parameters for the diffusers UNet2d model"
48+
if model_name == "unet_small":
49+
return dict(
50+
block_out_channels=(16, 32, 64, 128), # number of channels for each block
51+
norm_num_groups=8, # number of groups for the normalization layer
52+
in_channels=num_frames, # number of input channels
53+
out_channels=1, # number of output channels
54+
)
55+
elif model_name == "unet_big":
56+
return dict(
57+
block_out_channels=(32, 64, 128, 256), # number of channels for each block
58+
norm_num_groups=8, # number of groups for the normalization layer
59+
in_channels=num_frames, # number of input channels
60+
out_channels=1, # number of output channels
61+
)
62+
else:
63+
raise(f"Model name not found: {model_name}, choose between 'unet_small' or 'unet_big'")
64+
65+
class UNet2D(UNet2DModel, WandbModel):
66+
def __init__(self, *x, **kwargs):
67+
super().__init__(*x, **kwargs)
68+
init_unet(self)
69+
70+
def forward(self, *x, **kwargs):
71+
return super().forward(*x, **kwargs).sample ## Diffusers's UNet2DOutput class
72+
73+
74+
## Simple Diffusion paper
75+
76+
def get_uvit_params(model_name="uvit_small", num_frames=4):
77+
"Return the parameters for the diffusers UViT model"
78+
if model_name == "uvit_small":
79+
return dict(
80+
dim=512,
81+
ff_mult=2,
82+
vit_depth=4,
83+
channels=4,
84+
patch_size=4,
85+
final_img_itransform=nn.Conv2d(num_frames,1,1)
86+
)
87+
elif model_name == "uvit_big":
88+
return dict(
89+
dim=1024,
90+
ff_mult=4,
91+
vit_depth=8,
92+
channels=4,
93+
patch_size=4,
94+
final_img_itransform=nn.Conv2d(num_frames,1,1)
95+
)
96+
else:
97+
raise(f"Model name not found: {model_name}, choose between 'uvit_small' or 'uvit_big'")
98+
99+
class UViTModel(UViT, WandbModel): pass

cloud_diffusion/simple_diffusion.py

Lines changed: 13 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,15 @@
11
from functools import partial
22

3-
import torch, math
4-
from torch import nn, sqrt
3+
import torch
4+
from torch import sqrt
55
from torch.special import expm1
6-
from torch.utils.data import DataLoader
7-
from torch.utils.data.dataloader import default_collate
86

97
from fastprogress import progress_bar
108

119
from einops import repeat
1210

1311
try:
14-
from denoising_diffusion_pytorch.simple_diffusion import (
15-
UViT, right_pad_dims_to, logsnr_schedule_cosine
16-
)
12+
from denoising_diffusion_pytorch.simple_diffusion import right_pad_dims_to, logsnr_schedule_cosine
1713
except:
1814
raise ImportError("Please install denoising_diffusion_pytorch with `pip install denoising_diffusion_pytorch`")
1915

@@ -26,51 +22,24 @@ def q_sample(x_start, times, noise):
2622

2723
return x_noised, log_snr
2824

29-
def noisify(frames, pred_objective="v"):
30-
past_frames = frames[:,:-1]
31-
last_frame = frames[:,-1:]
32-
device = frames.device
25+
def noisify_uvit(x0, pred_objective="v"):
26+
device = x0.device
3327

34-
noise = torch.randn_like(last_frame)
35-
times = torch.zeros((last_frame.shape[0],), device = device).float().uniform_(0, 1)
36-
x, log_snr = q_sample(last_frame, times, noise)
28+
noise = torch.randn_like(x0)
29+
times = torch.zeros((x0.shape[0],), device = device).float().uniform_(0, 1)
30+
x, log_snr = q_sample(x0, times, noise)
3731

3832
if pred_objective == 'v':
3933
padded_log_snr = right_pad_dims_to(x, log_snr)
4034
alpha, sigma = padded_log_snr.sigmoid().sqrt(), (-padded_log_snr).sigmoid().sqrt()
41-
target = alpha * noise - sigma * last_frame
35+
target = alpha * noise - sigma * x0
4236

4337
elif pred_objective == 'eps':
4438
target = noise
4539

46-
return torch.cat([past_frames, x], dim=1), log_snr, target
47-
48-
def collate_simple_diffusion(b):
49-
"Collate function that noisifies the last frame"
50-
return noisify(default_collate(b))
51-
52-
def get_uvit_params(model_name="uvit_small", num_frames=4):
53-
"Return the parameters for the diffusers UViT model"
54-
if model_name == "uvit_small":
55-
return dict(
56-
dim=512,
57-
ff_mult=2,
58-
vit_depth=4,
59-
channels=4,
60-
patch_size=4,
61-
final_img_itransform=nn.Conv2d(num_frames,1,1)
62-
)
63-
elif model_name == "uvit_big":
64-
return dict(
65-
dim=1024,
66-
ff_mult=4,
67-
vit_depth=8,
68-
channels=4,
69-
patch_size=4,
70-
final_img_itransform=nn.Conv2d(num_frames,1,1)
71-
)
72-
else:
73-
raise(f"Model name not found: {model_name}, choose between 'uvit_small' or 'uvit_big'")
40+
return x, log_snr, target
41+
42+
7443

7544
# Sampling functions
7645

@@ -138,4 +107,4 @@ def p_sample_loop(model, past_frames, steps=500):
138107
def simple_diffusion_sampler(steps=500):
139108
"""Returns a function that samples from the diffusion model using
140109
the simple diffusion sampling scheme"""
141-
return partial(p_sample_loop, steps=500)
110+
return partial(p_sample_loop, steps=steps)

0 commit comments

Comments (0)