Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion rbms/bernoulli_gaussian/classes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from __future__ import annotations
from botocore.vendored.six import u

import numpy as np
import torch
Expand Down
2 changes: 1 addition & 1 deletion rbms/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __eq__(self, other: object) -> bool:
return False
other_params = other.named_parameters()
for k, v in self.named_parameters().items():
if not np.equal(other_params[k], v):
if not np.equal(other_params[k], v).all():
return False
return True

Expand Down
12 changes: 12 additions & 0 deletions rbms/custom_fn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import h5py
import numpy as np
import torch
from torch import Tensor

Expand Down Expand Up @@ -47,3 +49,13 @@ def check_keys_dict(d: dict, names: list[str]):
raise ValueError(
f"""Dictionary params missing key '{k}'\n Provided keys : {d.keys()}\n Expected keys: {names}"""
)


def load_string(f: h5py.Dataset, k: str | bytes) -> str:
    """Read entry ``k`` of ``f`` and return it as a Python string.

    The value ``f[k]`` is converted to a NumPy array. Byte strings
    (dtype kind ``"S"``) and object arrays (kind ``"O"``) have their
    ``bytes`` elements decoded as UTF-8, so non-ASCII text is handled
    instead of failing in NumPy's ASCII-only bytes-to-str cast. Void data
    (kind ``"V"``) is cast with ``astype(str)``. Any other dtype is left
    untouched before the final ``str()`` call.

    Args:
        f: Container indexed with ``k``; only ``f[k]`` is used.
            NOTE(review): annotated as ``h5py.Dataset`` but presumably an
            ``h5py.File``/``Group`` in practice; confirm against callers.
        k: Key of the entry to read. It is passed to ``f[k]`` unchanged.

    Returns:
        ``str()`` of the (possibly decoded) array. For a scalar string this
        is the text itself; for a non-scalar array it is NumPy's printed
        representation of the array.

    Raises:
        UnicodeDecodeError: If a ``bytes`` element is not valid UTF-8.
    """
    val = np.asarray(f[k])
    kind = val.dtype.kind
    if kind in ("S", "O"):
        # Decode element by element: ``astype(str)`` would decode bytes as
        # ASCII and raise on any UTF-8 encoded non-ASCII character.
        decoded = np.empty(val.size, dtype=object)
        for i, item in enumerate(val.ravel()):
            decoded[i] = item.decode("utf-8") if isinstance(item, bytes) else item
        val = decoded.reshape(val.shape).astype(str)
    elif kind == "V":
        val = val.astype(str)
    return str(val)
6 changes: 4 additions & 2 deletions rbms/dataset/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Callable

import numpy as np
import torch
from torch import Tensor
Expand Down Expand Up @@ -97,10 +98,11 @@ def get_covariance_matrix(
"""
num_data = len(data)
num_classes = int(data.max().item() + 1)
dtype = data.dtype

if weights is None:
weights = torch.ones(num_data)
weights = weights.to(device=device, dtype=torch.float32)
weights = weights.to(device=device, dtype=dtype)

if num_extract is not None:
idxs = np.random.choice(a=np.arange(num_data), size=(num_extract,), replace=False)
Expand All @@ -112,7 +114,7 @@ def get_covariance_matrix(
data = data.to(device=device, dtype=torch.int32)
data_oh = one_hot(data, num_classes=num_classes).reshape(num_data, -1)
else:
data_oh = data.to(device=device, dtype=torch.float32)
data_oh = data.to(device=device, dtype=dtype)

norm_weights = weights.reshape(-1, 1) / weights.sum()
data_mean = (data_oh * norm_weights).sum(0, keepdim=True)
Expand Down
4 changes: 3 additions & 1 deletion rbms/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

# from ptt.optim.cossim import SGD_cossim
from torch import Tensor
from torch.optim import SGD, Optimizer
from torch.optim import SGD, Adam, Optimizer

from rbms.classes import EBM

Expand Down Expand Up @@ -60,6 +60,8 @@ def setup_optim(optim: str, args: dict, params: EBM) -> list[Optimizer]:
optim_class = SGD
case "cossim":
optim_class = SGD_cossim
case "adam":
optim_class = Adam
case _:
print(f"Unrecognized optimizer {args['optim']}, falling back to SGD.")
optim_class = SGD
Expand Down
19 changes: 15 additions & 4 deletions rbms/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def plot_scatter_labels(ax, data_proj, gen_data_proj, proj1, proj2, labels):
)


def plot_hist(ax, data_proj, gen_data_proj, color, proj, labels, orientation="vertical"):
def plot_hist(
ax, data_proj, gen_data_proj, color, proj, labels, orientation="vertical", log=False
):
"""Args:
ax
data_proj
Expand All @@ -56,6 +58,7 @@ def plot_hist(ax, data_proj, gen_data_proj, color, proj, labels, orientation="ve
density=True,
orientation=orientation,
lw=1,
log=log,
)
ax.hist(
gen_data_proj[:, proj],
Expand All @@ -67,11 +70,12 @@ def plot_hist(ax, data_proj, gen_data_proj, color, proj, labels, orientation="ve
density=True,
orientation=orientation,
lw=1.5,
log=log,
)
ax.axis("off")


def plot_PCA(data1, data2, labels, dir1=0, dir2=1):
def plot_PCA(data1, data2, labels, dir1=0, dir2=1, log=False):
"""Args:
data1
data2
Expand All @@ -87,9 +91,16 @@ def plot_PCA(data1, data2, labels, dir1=0, dir2=1):
ax_hist_y = fig.add_subplot(gs[1:4, 3])

plot_scatter_labels(ax_scatter, data1, data2, dir1, dir2, labels=labels)
plot_hist(ax_hist_x, data1, data2, "red", dir1, labels=labels)
plot_hist(ax_hist_x, data1, data2, "red", dir1, labels=labels, log=log)
plot_hist(
ax_hist_y, data1, data2, "red", dir2, orientation="horizontal", labels=labels
ax_hist_y,
data1,
data2,
"red",
dir2,
orientation="horizontal",
labels=labels,
log=log,
)

ax_hist_x.legend(fontsize=12, bbox_to_anchor=(1, 1))
Expand Down
16 changes: 14 additions & 2 deletions rbms/scripts/split_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,26 @@ def split_data_train_test(
permutation_index = rng.permutation(num_samples)
n_sample_train = int(train_size * num_samples)

data_train = data[permutation_index[:n_sample_train]].int().cpu().numpy()
data_train = data[permutation_index[:n_sample_train]].cpu().numpy()
names_train = names[permutation_index[:n_sample_train]]
labels_train = labels[permutation_index[:n_sample_train]].int().cpu().numpy()

data_test = data[permutation_index[n_sample_train:]].int().cpu().numpy()
data_test = data[permutation_index[n_sample_train:]].cpu().numpy()
names_test = names[permutation_index[n_sample_train:]]
labels_test = labels[permutation_index[n_sample_train:]].int().cpu().numpy()

match dataset.variable_type:
case "bernoulli":
print("Casting data to bool")
data_train = data_train.astype(bool)
data_test = data_test.astype(bool)
case "categorical" | "ising":
print("Casting data to int")
data_train = data_train.astype(int)
data_test = data_test.astype(int)
case _:
print("Not casting data")

print(
f" train_size = {data_train.shape[0]} ({100 * data_train.shape[0] / data.shape[0]}%)"
)
Expand Down
4 changes: 1 addition & 3 deletions rbms/training/pcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@
from rbms.classes import EBM, Sampler
from rbms.dataset.dataset_class import RBMDataset
from rbms.io import save_model, save_sampler
from rbms.training.utils import EarlyStopper


@torch.compile(dynamic=True, disable=True)
@torch.no_grad
# @torch.no_grad
def train(
train_dataset: RBMDataset,
test_dataset: RBMDataset,
Expand Down
Loading