Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions requirements.txt
Copy link
Copy Markdown
Collaborator

@hanaol hanaol Jan 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We merged the requirements.txt dependencies into pyproject.toml, so this file can be deleted.

Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
numpy==2.3.3
scikit-learn==1.7.2
torch==2.9.1
torch==2.9.0
torchvision==0.24.0
pyscf==2.10.0
lightning== 2.5.6
wandb==0.12.10
wandb==0.16.0
pyyaml==6.0.3
monty==2025.3.3
mp-pyrho==0.5.1
Expand Down
12 changes: 8 additions & 4 deletions src/electrai/dataloader/mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
data_path: path of input chgcar or elfcar files.
label_path: path of label chgcar or elfcar files.
map_path: path of json file mapping functional to list of task_ids.
functional: 'GGA', 'GG+U', 'PBEsol', 'SCAN', 'r2SCAN'.
functional: 'GGA', 'GG+U', 'PBEsol', 'SCAN', 'r2SCAN'. #TODO: Understand better
train_fraction: fraction of the data used for training (0 to 1).
"""
self.data_path = Path(data_path)
Expand Down Expand Up @@ -70,10 +70,11 @@ def __init__(
"""
Parameters
----------
#TODO: These parameters are not actually used in the code. Update these.
data: list of voxel data of length batch_size.
rho_type: chgcar or elfcar.
data_size: target size of data.
label_size: target size of label.
rho_type: chgcar or elfcar. #TODO: elfcar not actually supported yet
data_size: target size of data. #TODO: Is this the input grid size? (independent variables)
label_size: target size of label. #TODO: Is this the input grid size? (dependent variables)
pyrho_uf: pyrho upsampling factor
"""
self.data = data
Expand All @@ -85,6 +86,8 @@ def __init__(
def __len__(self):
return len(self.data)

#TODO: These seem to only rotate by 90 degrees. Can we do partial rotations?
#TODO: Can we do this in Pymatgen? (Via a library)
def rotate_x(self, data_in):
"""
rotate 90 by x axis
Expand Down Expand Up @@ -126,6 +129,7 @@ def __getitem__(self, idx: int):
data = self.read_data(self.data[idx][0])
label = self.read_data(self.data[idx][1])

#TODO: Need to normalize by volume (maybe via pymatgen?) There is a PR for this.
if self.rho_type == "chgcar":
data = data.data["total"] / np.prod(data.data["total"].shape)
label = label.data["total"] / np.prod(label.data["total"].shape)
Expand Down
7 changes: 4 additions & 3 deletions src/electrai/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def train(args):
else:
wandb_logger = None

#TODO: make checkpoints optional but default to True
checkpoint_cb = ModelCheckpoint(
monitor="val_loss",
save_top_k=2,
Expand All @@ -80,10 +81,10 @@ def train(args):
logger=wandb_logger,
callbacks=[checkpoint_cb, lr_monitor],
accelerator="gpu" if torch.cuda.is_available() else "cpu",
devices=1,
devices=1, #TODO: multiple GPUs and put it in the config
precision=cfg.model_precision,
log_every_n_steps=1,
gradient_clip_val=getattr(cfg, "gradient_clip_value", 1.0),
log_every_n_steps=1, #TODO: put it in the config
gradient_clip_val=getattr(cfg, "gradient_clip_value", 1.0),
)

# -----------------------------
Expand Down
4 changes: 2 additions & 2 deletions src/electrai/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def __init__(self, cfg):
n_residual_blocks=int(cfg.n_residual_blocks),
n_upscale_layers=int(cfg.n_upscale_layers),
C=int(cfg.n_channels),
K1=int(cfg.kernel_size1),
K2=int(cfg.kernel_size2),
K1=int(cfg.kernel_size1), #TODO: Understand better
K2=int(cfg.kernel_size2), #TODO: Understand better
normalize=cfg.normalize,
use_checkpoint=getattr(cfg, "use_checkpoint", True),
)
Expand Down
2 changes: 2 additions & 0 deletions src/electrai/model/srgan_layernorm_pbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,14 @@ def __init__(
)

def forward(self, x):
#TODO: variable names could be improved
out1 = self.conv1(x)
out = self.res_blocks(out1)
out2 = self.conv2(out)
out = torch.add(out1, out2)
out = self.upsampling(out)
out = self.conv3(out)
#TODO: Understand better -- is this the right normalization we want to do? What are the units?
if self.normalize:
upscale_factor = 8 ** (self.n_upscale_layers)
out = out / torch.sum(out, axis=(-3, -2, -1))[..., None, None, None]
Expand Down