From 953271f09f4eeb8af8068ecfc4c5d922ff6951d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Wed, 11 Mar 2026 00:38:30 +0100 Subject: [PATCH 1/7] add a load_string function --- rbms/bernoulli_gaussian/classes.py | 1 - rbms/custom_fn.py | 12 ++++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/rbms/bernoulli_gaussian/classes.py b/rbms/bernoulli_gaussian/classes.py index 4c375dc..e634f58 100644 --- a/rbms/bernoulli_gaussian/classes.py +++ b/rbms/bernoulli_gaussian/classes.py @@ -1,5 +1,4 @@ from __future__ import annotations -from botocore.vendored.six import u import numpy as np import torch diff --git a/rbms/custom_fn.py b/rbms/custom_fn.py index eabb39e..7bc1fcf 100644 --- a/rbms/custom_fn.py +++ b/rbms/custom_fn.py @@ -1,3 +1,5 @@ +import h5py +import numpy as np import torch from torch import Tensor @@ -47,3 +49,13 @@ def check_keys_dict(d: dict, names: list[str]): raise ValueError( f"""Dictionary params missing key '{k}'\n Provided keys : {d.keys()}\n Expected keys: {names}""" ) + + +def load_string(f: h5py.Dataset, k: str | bytes) -> str: + # Fix 1: Ensure key is a string + # key = k.decode("utf-8") if isinstance(k, bytes) else k + val = np.asarray(f[k]) + # Fix 2: Ensure string values (like 'Reservoir') are strings, not bytes + if val.dtype.kind in ["S", "V", "O"]: # Bytes, Void, or Object (StringDType) + val = val.astype(str) + return str(val) From 5a3e7e6e540a3d2aec4540381ee4b149818828cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Sat, 4 Apr 2026 23:40:31 +0200 Subject: [PATCH 2/7] log scale histogram pca --- rbms/plot.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/rbms/plot.py b/rbms/plot.py index 9dfc49b..2293154 100644 --- a/rbms/plot.py +++ b/rbms/plot.py @@ -36,7 +36,9 @@ def plot_scatter_labels(ax, data_proj, gen_data_proj, proj1, proj2, labels): ) -def plot_hist(ax, data_proj, gen_data_proj, color, proj, labels, orientation="vertical"): +def plot_hist( + ax, data_proj, gen_data_proj, color, proj, labels, orientation="vertical", log=False +): """Args: ax data_proj @@ -56,6 +58,7 @@ def plot_hist(ax, data_proj, gen_data_proj, color, proj, labels, orientation="ve density=True, orientation=orientation, lw=1, + log=log, ) ax.hist( gen_data_proj[:, proj], @@ -67,11 +70,12 @@ def plot_hist(ax, data_proj, gen_data_proj, color, proj, labels, orientation="ve density=True, orientation=orientation, lw=1.5, + log=log, ) ax.axis("off") -def plot_PCA(data1, data2, labels, dir1=0, dir2=1): +def plot_PCA(data1, data2, labels, dir1=0, dir2=1, log=False): """Args: data1 data2 @@ -87,9 +91,16 @@ def plot_PCA(data1, data2, labels, dir1=0, dir2=1): ax_hist_y = fig.add_subplot(gs[1:4, 3]) plot_scatter_labels(ax_scatter, data1, data2, dir1, dir2, labels=labels) - plot_hist(ax_hist_x, data1, data2, "red", dir1, labels=labels) + plot_hist(ax_hist_x, data1, data2, "red", dir1, labels=labels, log=log) plot_hist( - ax_hist_y, data1, data2, "red", dir2, orientation="horizontal", labels=labels + ax_hist_y, + data1, + data2, + "red", + dir2, + orientation="horizontal", + labels=labels, + log=log, ) ax_hist_x.legend(fontsize=12, bbox_to_anchor=(1, 1)) From e3dda0ceffba08e346cb995cfaebb6dae2ba1536 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Sat, 4 Apr 2026 23:40:56 +0200 Subject: [PATCH 3/7] fix eq --- rbms/classes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rbms/classes.py b/rbms/classes.py index 3e5d9ba..aaaeee2 100644 --- a/rbms/classes.py +++ b/rbms/classes.py @@ -36,7 +36,7 @@ def __eq__(self, other: object) -> bool: return False other_params = other.named_parameters() for k, v in self.named_parameters().items(): - if not np.equal(other_params[k], v): + if not np.equal(other_params[k], v).all(): return False return True From 3a5304a4c67cd935b834e220de596c1dce8fa43e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Sat, 4 Apr 2026 23:41:12 +0200 Subject: [PATCH 4/7] add Adam --- rbms/optim.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rbms/optim.py b/rbms/optim.py index fc39f34..e2d708e 100644 --- a/rbms/optim.py +++ b/rbms/optim.py @@ -3,7 +3,7 @@ # from ptt.optim.cossim import SGD_cossim from torch import Tensor -from torch.optim import SGD, Optimizer +from torch.optim import SGD, Adam, Optimizer from rbms.classes import EBM @@ -60,6 +60,8 @@ def setup_optim(optim: str, args: dict, params: EBM) -> list[Optimizer]: optim_class = SGD case "cossim": optim_class = SGD_cossim + case "adam": + optim_class = Adam case _: print(f"Unrecognized optimizer {args['optim']}, falling back to SGD.") optim_class = SGD From 5d3e21ab4f8b0dd19ca110c3b62cc175cf0c6330 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Sat, 4 Apr 2026 23:41:47 +0200 Subject: [PATCH 5/7] save binary data as bool or int to save space --- rbms/scripts/split_data.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/rbms/scripts/split_data.py b/rbms/scripts/split_data.py index b49dbda..3783702 100644 --- a/rbms/scripts/split_data.py +++ b/rbms/scripts/split_data.py @@ -90,14 +90,26 @@ def split_data_train_test( permutation_index = rng.permutation(num_samples) n_sample_train = int(train_size * num_samples) - data_train = data[permutation_index[:n_sample_train]].int().cpu().numpy() + data_train = data[permutation_index[:n_sample_train]].cpu().numpy() names_train = names[permutation_index[:n_sample_train]] labels_train = labels[permutation_index[:n_sample_train]].int().cpu().numpy() - data_test = data[permutation_index[n_sample_train:]].int().cpu().numpy() + data_test = data[permutation_index[n_sample_train:]].cpu().numpy() names_test = names[permutation_index[n_sample_train:]] labels_test = labels[permutation_index[n_sample_train:]].int().cpu().numpy() + match dataset.variable_type: + case "bernoulli": + print("Casting data to bool") + data_train = data_train.astype(bool) + data_test = data_test.astype(bool) + case "categorical" | "ising": + print("Casting data to int") + data_train = data_train.astype(int) + data_test = data_test.astype(int) + case _: + print("Not casting data") + print( f" train_size = {data_train.shape[0]} ({100 * data_train.shape[0] / data.shape[0]}%)" ) From 54a374d306c5f14c6235c59ec932a8db9c1ebde1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Sat, 4 Apr 2026 23:42:09 +0200 Subject: [PATCH 6/7] remove torch compile --- rbms/training/pcd.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/rbms/training/pcd.py b/rbms/training/pcd.py index 2cae466..e1c2bf9 100644 --- a/rbms/training/pcd.py +++ b/rbms/training/pcd.py @@ -8,11 +8,9 @@ from rbms.classes import EBM, Sampler from rbms.dataset.dataset_class import RBMDataset from rbms.io import save_model, save_sampler -from rbms.training.utils import EarlyStopper -@torch.compile(dynamic=True, disable=True) -@torch.no_grad +# @torch.no_grad def train( train_dataset: RBMDataset, test_dataset: RBMDataset, From f50025aa4be245bf0e77cc1520f0a8c466a95324 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Sat, 4 Apr 2026 23:42:27 +0200 Subject: [PATCH 7/7] dtype covariance --- rbms/dataset/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/rbms/dataset/utils.py b/rbms/dataset/utils.py index 98f4faf..8dee55f 100644 --- a/rbms/dataset/utils.py +++ b/rbms/dataset/utils.py @@ -1,4 +1,5 @@ from collections.abc import Callable + import numpy as np import torch from torch import Tensor @@ -97,10 +98,11 @@ def get_covariance_matrix( """ num_data = len(data) num_classes = int(data.max().item() + 1) + dtype = data.dtype if weights is None: weights = torch.ones(num_data) - weights = weights.to(device=device, dtype=torch.float32) + weights = weights.to(device=device, dtype=dtype) if num_extract is not None: idxs = np.random.choice(a=np.arange(num_data), size=(num_extract,), replace=False) @@ -112,7 +114,7 @@ def get_covariance_matrix( data = data.to(device=device, dtype=torch.int32) data_oh = one_hot(data, num_classes=num_classes).reshape(num_data, -1) else: - data_oh = data.to(device=device, dtype=torch.float32) + data_oh = data.to(device=device, dtype=dtype) norm_weights = weights.reshape(-1, 1) / weights.sum() data_mean = (data_oh * norm_weights).sum(0, keepdim=True)