From 3fb3cdb8f3966198fe43ea3f80193a7434a9074c Mon Sep 17 00:00:00 2001 From: jquetzalcoatl Date: Mon, 16 Mar 2026 14:13:01 -0700 Subject: [PATCH 01/15] gaussian grbm initialization --- dwave/plugins/torch/models/boltzmann_machine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dwave/plugins/torch/models/boltzmann_machine.py b/dwave/plugins/torch/models/boltzmann_machine.py index ff8739f..464495c 100644 --- a/dwave/plugins/torch/models/boltzmann_machine.py +++ b/dwave/plugins/torch/models/boltzmann_machine.py @@ -82,8 +82,8 @@ def __init__( self._idx_to_edge = {i: e for i, e in enumerate(self._edges)} self._edge_to_idx = {e: i for i, e in self._idx_to_edge.items()} - self._linear = torch.nn.Parameter(0.05 * (2 * torch.rand(self._n_nodes) - 1)) - self._quadratic = torch.nn.Parameter(5.0 * (2 * torch.rand(self._n_edges) - 1)) + self._linear = torch.nn.Parameter(torch.randn(self._n_nodes)/torch.tensor(self._n_nodes, dtype=torch.float).sqrt()) + self._quadratic = torch.nn.Parameter(torch.randn(self._n_edges)/torch.tensor(self._n_nodes, dtype=torch.float).sqrt()) edge_idx_i = torch.tensor([self._node_to_idx[i] for i, _ in self._edges]) edge_idx_j = torch.tensor([self._node_to_idx[j] for _, j in self._edges]) From 05cc617b3c0d1161ec0f39d54167433711ea87da Mon Sep 17 00:00:00 2001 From: Javier Toledo Marin <36744342+jquetzalcoatl@users.noreply.github.com> Date: Mon, 16 Mar 2026 16:04:13 -0700 Subject: [PATCH 02/15] Update dwave/plugins/torch/models/boltzmann_machine.py Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> --- dwave/plugins/torch/models/boltzmann_machine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dwave/plugins/torch/models/boltzmann_machine.py b/dwave/plugins/torch/models/boltzmann_machine.py index 464495c..310d174 100644 --- a/dwave/plugins/torch/models/boltzmann_machine.py +++ b/dwave/plugins/torch/models/boltzmann_machine.py @@ -82,8 +82,8 @@ def __init__( self._idx_to_edge = {i: 
e for i, e in enumerate(self._edges)} self._edge_to_idx = {e: i for i, e in self._idx_to_edge.items()} - self._linear = torch.nn.Parameter(torch.randn(self._n_nodes)/torch.tensor(self._n_nodes, dtype=torch.float).sqrt()) - self._quadratic = torch.nn.Parameter(torch.randn(self._n_edges)/torch.tensor(self._n_nodes, dtype=torch.float).sqrt()) + self._linear = torch.nn.Parameter(torch.randn(self._n_nodes)/self._n_nodes**0.5) + self._quadratic = torch.nn.Parameter(torch.randn(self._n_edges)/self._n_nodes**0.5) edge_idx_i = torch.tensor([self._node_to_idx[i] for i, _ in self._edges]) edge_idx_j = torch.tensor([self._node_to_idx[j] for _, j in self._edges]) From bd9fdab4346636acb448f58ce65f2e7f4d5ed015 Mon Sep 17 00:00:00 2001 From: jquetzalcoatl Date: Mon, 16 Mar 2026 18:23:55 -0700 Subject: [PATCH 03/15] added release note --- releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml diff --git a/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml b/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml new file mode 100644 index 0000000..24741d1 --- /dev/null +++ b/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + grbm weights and biases initialization set to Gaussian N(0,1/number of nodes) From 0c1ddeccdc3b80c5e95863ee886ef00a9a233f47 Mon Sep 17 00:00:00 2001 From: jquetzalcoatl Date: Mon, 16 Mar 2026 22:04:55 -0700 Subject: [PATCH 04/15] added release note --- .../notes/gaussian-rbm-init-28fd4d295ef86d77.yaml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml b/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml index 24741d1..fcd5e20 100644 --- a/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml +++ b/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml @@ -1,4 +1,16 @@ --- 
features: - | - grbm weights and biases initialization set to Gaussian N(0,1/number of nodes) + Initialize ``GraphRestrictedBoltzmannMachine`` weights using Gaussian + random variables with standard deviation equal to 1/sqrt(N), where N + denotes the number of nodes in the gRBM. +other: + - | + The initialization strategy is grounded in Hinton's practical guide for RBM training + ``_, which recommends sampling + weights from a Gaussian distribution with mean 0 and standard deviation 0.01. + The scaling factor of 1/sqrt(N) ensures that the energy functional remains extensive + and initializes the graph RBM in a paramagnetic regime, consistent with the Sherrington-Kirkpatrick model + ``_. + + From d4bdfbfd438ef376b46bba8e36c28b12f1159d96 Mon Sep 17 00:00:00 2001 From: jquetzalcoatl Date: Tue, 17 Mar 2026 10:07:02 -0700 Subject: [PATCH 05/15] added docstring explaining motivation for weight initialization in grbm --- dwave/plugins/torch/models/boltzmann_machine.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/dwave/plugins/torch/models/boltzmann_machine.py b/dwave/plugins/torch/models/boltzmann_machine.py index 310d174..831796f 100644 --- a/dwave/plugins/torch/models/boltzmann_machine.py +++ b/dwave/plugins/torch/models/boltzmann_machine.py @@ -49,6 +49,13 @@ class GraphRestrictedBoltzmannMachine(torch.nn.Module): """Creates a graph-restricted Boltzmann machine. + The initialization strategy is grounded in Hinton's practical guide for RBM training + ``_, which recommends sampling + weights from a Gaussian distribution with mean 0 and standard deviation 0.01. + The scaling factor of 1/sqrt(N) ensures that the energy functional remains extensive + and initializes the graph RBM in a paramagnetic regime, consistent with the Sherrington-Kirkpatrick model + ``_. + Args: nodes (Iterable[Hashable]): List of nodes. edges (Iterable[tuple[Hashable, Hashable]]): List of edges. 
From 60ee81a1ff8f465956a010e63819d2dd321da684 Mon Sep 17 00:00:00 2001 From: Javier Toledo Marin <36744342+jquetzalcoatl@users.noreply.github.com> Date: Tue, 17 Mar 2026 10:46:18 -0700 Subject: [PATCH 06/15] Update releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> --- .../notes/gaussian-rbm-init-28fd4d295ef86d77.yaml | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml b/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml index fcd5e20..8d1fa4b 100644 --- a/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml +++ b/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml @@ -3,14 +3,6 @@ features: - | Initialize ``GraphRestrictedBoltzmannMachine`` weights using Gaussian random variables with standard deviation equal to 1/sqrt(N), where N - denotes the number of nodes in the gRBM. -other: - - | - The initialization strategy is grounded in Hinton's practical guide for RBM training - ``_, which recommends sampling - weights from a Gaussian distribution with mean 0 and standard deviation 0.01. - The scaling factor of 1/sqrt(N) ensures that the energy functional remains extensive - and initializes the graph RBM in a paramagnetic regime, consistent with the Sherrington-Kirkpatrick model - ``_. + denotes the number of nodes in the GRBM. The weight-initialization strategy is grounded in `Hinton's practical guide for RBM training `_, which recommends sampling weights from a Gaussian distribution with mean 0 and standard deviation 0.01 (for zero-one-valued RBMs). The scaling factor of :math:`1/\sqrt(N)` ensures that the energy functional remains extensive and initializes the graph RBM in a paramagnetic regime, consistent with the `Sherrington-Kirkpatrick model`_. 
From eee250aa39bf7afeb3d7c0cbb1a400640778e9f5 Mon Sep 17 00:00:00 2001 From: Javier Toledo Marin <36744342+jquetzalcoatl@users.noreply.github.com> Date: Tue, 17 Mar 2026 10:46:44 -0700 Subject: [PATCH 07/15] Update dwave/plugins/torch/models/boltzmann_machine.py Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> --- dwave/plugins/torch/models/boltzmann_machine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dwave/plugins/torch/models/boltzmann_machine.py b/dwave/plugins/torch/models/boltzmann_machine.py index 831796f..72f13fd 100644 --- a/dwave/plugins/torch/models/boltzmann_machine.py +++ b/dwave/plugins/torch/models/boltzmann_machine.py @@ -49,9 +49,9 @@ class GraphRestrictedBoltzmannMachine(torch.nn.Module): """Creates a graph-restricted Boltzmann machine. - The initialization strategy is grounded in Hinton's practical guide for RBM training - ``_, which recommends sampling - weights from a Gaussian distribution with mean 0 and standard deviation 0.01. + The initialization-strategy is grounded in + `Hinton's practical guide for RBM training`_, which recommends sampling + weights from a Gaussian distribution with mean 0 and standard deviation 0.01 (for zero-one-valued RBMs). The scaling factor of 1/sqrt(N) ensures that the energy functional remains extensive and initializes the graph RBM in a paramagnetic regime, consistent with the Sherrington-Kirkpatrick model ``_. 
From bc551b8383e964909e362ae166d77eaeb9c298de Mon Sep 17 00:00:00 2001 From: Javier Toledo Marin <36744342+jquetzalcoatl@users.noreply.github.com> Date: Tue, 17 Mar 2026 10:46:53 -0700 Subject: [PATCH 08/15] Update dwave/plugins/torch/models/boltzmann_machine.py Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> --- dwave/plugins/torch/models/boltzmann_machine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dwave/plugins/torch/models/boltzmann_machine.py b/dwave/plugins/torch/models/boltzmann_machine.py index 72f13fd..864f9fe 100644 --- a/dwave/plugins/torch/models/boltzmann_machine.py +++ b/dwave/plugins/torch/models/boltzmann_machine.py @@ -52,7 +52,7 @@ class GraphRestrictedBoltzmannMachine(torch.nn.Module): The initialization-strategy is grounded in `Hinton's practical guide for RBM training`_, which recommends sampling weights from a Gaussian distribution with mean 0 and standard deviation 0.01 (for zero-one-valued RBMs). - The scaling factor of 1/sqrt(N) ensures that the energy functional remains extensive + The scaling factor of :math:`1/\sqrt(N)` ensures that the energy functional remains extensive and initializes the graph RBM in a paramagnetic regime, consistent with the Sherrington-Kirkpatrick model ``_. 
From 8ec902e7f93f7980279acf0c7df4786dbc3e24df Mon Sep 17 00:00:00 2001 From: Javier Toledo Marin <36744342+jquetzalcoatl@users.noreply.github.com> Date: Tue, 17 Mar 2026 10:47:06 -0700 Subject: [PATCH 09/15] Update dwave/plugins/torch/models/boltzmann_machine.py Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> --- dwave/plugins/torch/models/boltzmann_machine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dwave/plugins/torch/models/boltzmann_machine.py b/dwave/plugins/torch/models/boltzmann_machine.py index 864f9fe..35c81de 100644 --- a/dwave/plugins/torch/models/boltzmann_machine.py +++ b/dwave/plugins/torch/models/boltzmann_machine.py @@ -53,8 +53,7 @@ class GraphRestrictedBoltzmannMachine(torch.nn.Module): `Hinton's practical guide for RBM training`_, which recommends sampling weights from a Gaussian distribution with mean 0 and standard deviation 0.01 (for zero-one-valued RBMs). The scaling factor of :math:`1/\sqrt(N)` ensures that the energy functional remains extensive - and initializes the graph RBM in a paramagnetic regime, consistent with the Sherrington-Kirkpatrick model - ``_. + and initializes the graph RBM in a paramagnetic regime, consistent with the ` Sherrington-Kirkpatrick model`_. Args: nodes (Iterable[Hashable]): List of nodes. 
From 54b2862171a4b574be9365767768b1d5cf15b60e Mon Sep 17 00:00:00 2001 From: jquetzalcoatl Date: Tue, 17 Mar 2026 10:50:31 -0700 Subject: [PATCH 10/15] Fixed biases initialization to zero and added docstring explaining motivation --- dwave/plugins/torch/models/boltzmann_machine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dwave/plugins/torch/models/boltzmann_machine.py b/dwave/plugins/torch/models/boltzmann_machine.py index 35c81de..22e1e9e 100644 --- a/dwave/plugins/torch/models/boltzmann_machine.py +++ b/dwave/plugins/torch/models/boltzmann_machine.py @@ -54,6 +54,7 @@ class GraphRestrictedBoltzmannMachine(torch.nn.Module): weights from a Gaussian distribution with mean 0 and standard deviation 0.01 (for zero-one-valued RBMs). The scaling factor of :math:`1/\sqrt(N)` ensures that the energy functional remains extensive and initializes the graph RBM in a paramagnetic regime, consistent with the ` Sherrington-Kirkpatrick model`_. + The biases are initialized to zero to ensure extensiveness of the energy functional and to avoid introducing any initial preference for spin configurations. Args: nodes (Iterable[Hashable]): List of nodes. 
@@ -88,7 +89,7 @@ def __init__( self._idx_to_edge = {i: e for i, e in enumerate(self._edges)} self._edge_to_idx = {e: i for i, e in self._idx_to_edge.items()} - self._linear = torch.nn.Parameter(torch.randn(self._n_nodes)/self._n_nodes**0.5) + self._linear = torch.nn.Parameter(torch.zeros(self._n_nodes)) self._quadratic = torch.nn.Parameter(torch.randn(self._n_edges)/self._n_nodes**0.5) edge_idx_i = torch.tensor([self._node_to_idx[i] for i, _ in self._edges]) From 8cde43768e037bacaec6b5cd5f409a9eca01ae12 Mon Sep 17 00:00:00 2001 From: Javier Toledo Marin <36744342+jquetzalcoatl@users.noreply.github.com> Date: Tue, 17 Mar 2026 10:59:25 -0700 Subject: [PATCH 11/15] Update dwave/plugins/torch/models/boltzmann_machine.py Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> --- dwave/plugins/torch/models/boltzmann_machine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dwave/plugins/torch/models/boltzmann_machine.py b/dwave/plugins/torch/models/boltzmann_machine.py index 22e1e9e..00acb28 100644 --- a/dwave/plugins/torch/models/boltzmann_machine.py +++ b/dwave/plugins/torch/models/boltzmann_machine.py @@ -53,7 +53,7 @@ class GraphRestrictedBoltzmannMachine(torch.nn.Module): `Hinton's practical guide for RBM training`_, which recommends sampling weights from a Gaussian distribution with mean 0 and standard deviation 0.01 (for zero-one-valued RBMs). The scaling factor of :math:`1/\sqrt(N)` ensures that the energy functional remains extensive - and initializes the graph RBM in a paramagnetic regime, consistent with the ` Sherrington-Kirkpatrick model`_. + and initializes the GRBM in a paramagnetic regime, consistent with the `Sherrington-Kirkpatrick model`_. The biases are initialized to zero to ensure extensiveness of the energy functional and to avoid introducing any initial preference for spin configurations. 
Args: From 4e48419bed09834e903c1349bfb8612097412655 Mon Sep 17 00:00:00 2001 From: Javier Toledo Marin <36744342+jquetzalcoatl@users.noreply.github.com> Date: Tue, 17 Mar 2026 10:59:47 -0700 Subject: [PATCH 12/15] Update releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> --- releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml b/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml index 8d1fa4b..29b9dd8 100644 --- a/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml +++ b/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml @@ -2,7 +2,7 @@ features: - | Initialize ``GraphRestrictedBoltzmannMachine`` weights using Gaussian - random variables with standard deviation equal to 1/sqrt(N), where N + random variables with standard deviation equal to :math:`1/\sqrt(N)`, where N denotes the number of nodes in the GRBM. The weight-initialization strategy is grounded in `Hinton's practical guide for RBM training `_, which recommends sampling weights from a Gaussian distribution with mean 0 and standard deviation 0.01 (for zero-one-valued RBMs). The scaling factor of :math:`1/\sqrt(N)` ensures that the energy functional remains extensive and initializes the graph RBM in a paramagnetic regime, consistent with the `Sherrington-Kirkpatrick model`_. 
From d9a399caf3d91040142ab03410cc11d6cadea52d Mon Sep 17 00:00:00 2001 From: Javier Toledo Marin <36744342+jquetzalcoatl@users.noreply.github.com> Date: Tue, 17 Mar 2026 10:59:58 -0700 Subject: [PATCH 13/15] Update releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> --- releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml b/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml index 29b9dd8..ea450d5 100644 --- a/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml +++ b/releasenotes/notes/gaussian-rbm-init-28fd4d295ef86d77.yaml @@ -3,6 +3,6 @@ features: - | Initialize ``GraphRestrictedBoltzmannMachine`` weights using Gaussian random variables with standard deviation equal to :math:`1/\sqrt(N)`, where N - denotes the number of nodes in the GRBM. The weight-initialization strategy is grounded in `Hinton's practical guide for RBM training `_, which recommends sampling weights from a Gaussian distribution with mean 0 and standard deviation 0.01 (for zero-one-valued RBMs). The scaling factor of :math:`1/\sqrt(N)` ensures that the energy functional remains extensive and initializes the graph RBM in a paramagnetic regime, consistent with the `Sherrington-Kirkpatrick model`_. + denotes the number of nodes in the GRBM. The weight-initialization strategy is grounded in `Hinton's practical guide for RBM training `_, which recommends sampling weights from a Gaussian distribution with mean 0 and standard deviation 0.01 (for zero-one-valued RBMs). The scaling factor of :math:`1/\sqrt(N)` ensures that the energy functional remains extensive and initializes the GRBM in a paramagnetic regime, consistent with the `Sherrington-Kirkpatrick model`_. 
From 0e4761d79518be4013217a6a44bf656d3988b8fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladimir=20Vargas=20Calder=C3=B3n?= Date: Wed, 8 Apr 2026 14:11:07 -0700 Subject: [PATCH 14/15] Enforce deterministic latent mapping in tests for reproducibility of forward method unit test --- tests/test_dvae_winci2020.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index 38dfff7..3cf9730 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -78,16 +78,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # are the models themselves latent_dims_list = [1, 2] self.encoders = {i: Encoder(i) for i in latent_dims_list} - # self.decoders is independent of number of latent dims, but we also create a dict to separate - # them + # self.decoders is independent of number of latent dims, but we also create a dict to + # separate them self.decoders = {i: Decoder(latent_features, input_features) for i in latent_dims_list} - # self.dvaes is a dict whose keys are the numbers of latent dims and the values are the models - # themselves - - self.dvaes = {i: DVAE(self.encoders[i], self.decoders[i]) for i in latent_dims_list} - - # Now we also create a DVAE with a trainable Encoder def deterministic_latent_to_discrete(logits: torch.Tensor, n_samples: int) -> torch.Tensor: # straight-through estimator that maps positive logits to 1 and negative logits to -1 hard = torch.sign(logits) @@ -95,6 +89,14 @@ def deterministic_latent_to_discrete(logits: torch.Tensor, n_samples: int) -> to result = hard - soft.detach() + soft # Now we need to repeat the result n_samples times along a new dimension return repeat(result, "b ... 
-> b n ...", n=n_samples) + # self.dvaes is a dict whose keys are the numbers of latent dims and the values are the + # models themselves + + self.dvaes = {i: DVAE( + self.encoders[i], self.decoders[i], latent_to_discrete=deterministic_latent_to_discrete + ) for i in latent_dims_list} + + # Now we also create a DVAE with a trainable Encoder self.dvae_with_trainable_encoder = DVAE( encoder=torch.nn.Linear(input_features, latent_features), From c6e98c8cbcb77d8fc624c9ea761c82266222bf0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladimir=20Vargas=20Calder=C3=B3n?= Date: Wed, 8 Apr 2026 14:22:36 -0700 Subject: [PATCH 15/15] Ensure reproducibility in forward method --- tests/test_dvae_winci2020.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index 3cf9730..e22cd39 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -81,7 +81,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # self.decoders is independent of number of latent dims, but we also create a dict to # separate them self.decoders = {i: Decoder(latent_features, input_features) for i in latent_dims_list} + # self.dvaes is a dict whose keys are the numbers of latent dims and the values are the + # models themselves + + self.dvaes = {i: DVAE(self.encoders[i], self.decoders[i]) for i in latent_dims_list} + # Now we also create a DVAE with a trainable Encoder def deterministic_latent_to_discrete(logits: torch.Tensor, n_samples: int) -> torch.Tensor: # straight-through estimator that maps positive logits to 1 and negative logits to -1 hard = torch.sign(logits) @@ -89,14 +94,6 @@ def deterministic_latent_to_discrete(logits: torch.Tensor, n_samples: int) -> to result = hard - soft.detach() + soft # Now we need to repeat the result n_samples times along a new dimension return repeat(result, "b ... 
-> b n ...", n=n_samples) - # self.dvaes is a dict whose keys are the numbers of latent dims and the values are the - # models themselves - - self.dvaes = {i: DVAE( - self.encoders[i], self.decoders[i], latent_to_discrete=deterministic_latent_to_discrete - ) for i in latent_dims_list} - - # Now we also create a DVAE with a trainable Encoder self.dvae_with_trainable_encoder = DVAE( encoder=torch.nn.Linear(input_features, latent_features), @@ -250,19 +247,22 @@ def test_latent_to_discrete(self, n_samples, expected): @parameterized.expand([(i, j) for i in range(1, 3) for j in [0, 1, 5, 1000]]) def test_forward(self, n_latent_dims, n_samples): """Test the forward method.""" + torch.manual_seed(1234) # Set seed for reproducibility of latent_to_discrete sampling expected_latents = self.encoders[n_latent_dims](self.data) expected_discretes = self.dvaes[n_latent_dims].latent_to_discrete( expected_latents, n_samples ) expected_reconstructed_x = self.decoders[n_latent_dims](expected_discretes) + torch.manual_seed(1234) # Set seed again to ensure that the sampling in the forward method + # is the same as in the expected_discretes latents, discretes, reconstructed_x = self.dvaes[n_latent_dims].forward( x=self.data, n_samples=n_samples ) + torch.testing.assert_close(latents, expected_latents) + torch.testing.assert_close(discretes, expected_discretes) + torch.testing.assert_close(reconstructed_x, expected_reconstructed_x) - assert torch.equal(reconstructed_x, expected_reconstructed_x) - assert torch.equal(discretes, expected_discretes) - assert torch.equal(latents, expected_latents) if __name__ == "__main__":