# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Maximum mean discrepancy (MMD) loss and the kernels it is built on."""

from abc import abstractmethod
from typing import Optional

import torch
import torch.nn as nn

from dwave.plugins.torch.nn.modules.utils import store_config

__all__ = ["Kernel", "RadialBasisFunction", "mmd_loss", "MMDLoss"]


class Kernel(nn.Module):
    """Base class for kernels.

    Kernels compute a similarity measure between data points. Any ``Kernel`` subclass must
    implement the ``_kernel`` method, which maps an (n, f1, f2, ...) tensor of n items to
    the (n, n) matrix of pairwise kernel values.
    """

    @abstractmethod
    def _kernel(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the pairwise kernel matrix over a single batch of samples.

        Args:
            x (torch.Tensor): An (n, f1, f2, ...) tensor.

        Returns:
            torch.Tensor: An (n, n) tensor of pairwise kernel values.
        """

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute kernels for intra- and inter-set pairs in ``x`` and ``y``.

        Args:
            x (torch.Tensor): An (n_x, f1, f2, ...) tensor.
            y (torch.Tensor): An (n_y, f1, f2, ...) tensor.

        Returns:
            torch.Tensor: An (n_x + n_y, n_x + n_y) tensor of pairwise kernel values over
            the concatenation of ``x`` and ``y``.

        Raises:
            ValueError: If the feature (trailing) dimensions of ``x`` and ``y`` differ.
        """
        if x.shape[1:] != y.shape[1:]:
            raise ValueError(
                "Input dimensions must match. You are trying to compute "
                f"the kernel between tensors of shape {x.shape} and {y.shape}."
            )
        # Concatenate along the batch dimension so one _kernel call covers all pairs.
        xy = torch.cat([x, y], dim=0)
        return self._kernel(xy)


class RadialBasisFunction(Kernel):
    r"""Radial basis function kernel aggregated over several bandwidths.

    The kernel computed here is :math:`k(x, y) = \sum_i \exp(-\|x - y\| / \sigma_i)`, where
    the :math:`\sigma_i` are bandwidths derived from a base bandwidth.

    NOTE(review): the exponent uses the unsquared L2 distance with no factor of 2, which
    differs from the textbook Gaussian RBF :math:`\exp(-\|x - y\|^2 / (2\sigma))`; the
    docstring here describes what the implementation actually computes.

    The base bandwidth can be provided directly or, when ``bandwidth`` is None, is estimated
    from the data as the average pairwise distance between samples.

    Args:
        num_features (int): Number of kernel bandwidths to use.
        mul_factor (int | float): Multiplicative factor used to generate bandwidths,
            :math:`\sigma_i = \sigma \cdot mul\_factor^{i - num\_features // 2}` for
            :math:`i` in ``[0, num_features - 1]``. Defaults to 2.0.
        bandwidth (float | None): Base bandwidth parameter. If None, the bandwidth is
            estimated from the data. Defaults to None.
    """

    @store_config
    def __init__(
        self, num_features: int, mul_factor: int | float = 2.0, bandwidth: Optional[float] = None
    ):
        super().__init__()
        # Bandwidth multipliers mul_factor ** (i - num_features // 2), centred around 1.
        bandwidth_multipliers = mul_factor ** (torch.arange(num_features) - num_features // 2)
        self.register_buffer("bandwidth_multipliers", bandwidth_multipliers)
        self.bandwidth = bandwidth

    @torch.no_grad()
    def _get_bandwidth(self, l2_distance_matrix: torch.Tensor) -> torch.Tensor | float:
        """Return the base bandwidth, estimating it from the data when not provided.

        When ``self.bandwidth`` is None the base bandwidth is the mean off-diagonal pairwise
        distance; see https://arxiv.org/abs/1707.07269 for the motivation behind using the
        average distance.

        Args:
            l2_distance_matrix (torch.Tensor): An (n, n) tensor of pairwise L2 distances.

        Returns:
            torch.Tensor | float: The base bandwidth parameter.
        """
        if self.bandwidth is None:
            num_samples = l2_distance_matrix.shape[0]
            # Average over the n**2 - n off-diagonal entries (diagonal distances are zero).
            return l2_distance_matrix.sum() / (num_samples**2 - num_samples)
        return self.bandwidth

    def _kernel(self, x: torch.Tensor) -> torch.Tensor:
        r"""Compute the multi-bandwidth radial basis function kernel matrix as

        .. math::
            k(x, y) = \sum_{i=1}^{num\_features} \exp(-\|x - y\| / \sigma_i),

        where the :math:`\sigma_i` are the bandwidths.

        Args:
            x (torch.Tensor): An (n, f1, f2, ...) tensor.

        Returns:
            torch.Tensor: An (n, n) tensor representing the kernel matrix.
        """
        distance_matrix = torch.cdist(x, x, p=2)
        # Detach so the data-dependent bandwidth estimate is treated as a constant.
        bandwidth = self._get_bandwidth(distance_matrix.detach()) * self.bandwidth_multipliers
        return torch.exp(-distance_matrix.unsqueeze(0) / bandwidth.reshape(-1, 1, 1)).sum(dim=0)


def mmd_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor:
    r"""Compute the maximum mean discrepancy (MMD) loss between two sets of samples x and y.

    This is a two-sample test for the null hypothesis that the two samples are drawn from
    the same distribution (https://dl.acm.org/doi/abs/10.5555/2188385.2188410). The squared
    MMD is defined as

    .. math::
        MMD^2(X, Y) = \|E_{x\sim p}[\varphi(x)] - E_{y\sim q}[\varphi(y)]\|^2,

    where :math:`\varphi` is a feature map associated with the kernel function
    :math:`k(x, y) = \langle \varphi(x), \varphi(y) \rangle`, and :math:`p` and :math:`q`
    are the distributions of the samples. In terms of the kernel function, the squared MMD
    can be computed as

    .. math::
        E_{x, x'\sim p}[k(x, x')] + E_{y, y'\sim q}[k(y, y')] - 2E_{x\sim p, y\sim q}[k(x, y)].

    If :math:`p = q`, then :math:`MMD^2(X, Y) = 0`, which motivates the MMD as a loss that
    forces model-generated samples to match the data distribution.

    Args:
        x (torch.Tensor): An (n_x, f1, f2, ...) tensor of samples from distribution p.
        y (torch.Tensor): An (n_y, f1, f2, ...) tensor of samples from distribution q.
        kernel (Kernel): A kernel function object.

    Returns:
        torch.Tensor: The computed MMD loss. Requires n_x >= 2 and n_y >= 2 (the unbiased
        within-set estimates divide by n * (n - 1)).
    """
    num_x = x.shape[0]
    num_y = y.shape[0]
    kernel_matrix = kernel(x, y)
    kernel_xx = kernel_matrix[:num_x, :num_x]
    kernel_yy = kernel_matrix[num_x:, num_x:]
    kernel_xy = kernel_matrix[:num_x, num_x:]
    # Off-diagonal means give unbiased within-set estimates (self-similarities excluded).
    xx = (kernel_xx.sum() - kernel_xx.trace()) / (num_x * (num_x - 1))
    yy = (kernel_yy.sum() - kernel_yy.trace()) / (num_y * (num_y - 1))
    xy = kernel_xy.sum() / (num_x * num_y)
    return xx + yy - 2 * xy


class MMDLoss(nn.Module):
    """Module that computes the maximum mean discrepancy (MMD) loss between two sample sets.

    This wraps the ``mmd_loss`` function.

    Args:
        kernel (Kernel): A kernel function object.
    """

    @store_config
    def __init__(self, kernel: Kernel):
        super().__init__()
        self.kernel = kernel

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the MMD loss between two sets of samples x and y.

        Args:
            x (torch.Tensor): An (n_x, f1, f2, ...) tensor of samples from distribution p.
            y (torch.Tensor): An (n_y, f1, f2, ...) tensor of samples from distribution q.

        Returns:
            torch.Tensor: The computed MMD loss.
        """
        return mmd_loss(x, y, self.kernel)
+ + \ No newline at end of file diff --git a/tests/requirements.txt b/tests/requirements.txt index b2bd102..d7abc8f 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,3 +1,4 @@ coverage codecov parameterized +einops diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index 8cd5db8..d9b0bb0 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -15,6 +15,7 @@ import unittest import torch +from einops import repeat from parameterized import parameterized from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine @@ -22,6 +23,7 @@ DiscreteVariationalAutoencoder as DVAE, ) from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss +from dwave.plugins.torch.models.losses.mmd import MMDLoss, RadialBasisFunction, mmd_loss from dwave.samplers import SimulatedAnnealingSampler @@ -84,6 +86,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.dvaes = {i: DVAE(self.encoders[i], self.decoders[i]) for i in latent_dims_list} + # Now we also create a DVAE with a trainable Encoder + def deterministic_latent_to_discrete(logits: torch.Tensor, n_samples: int) -> torch.Tensor: + # straight-through estimator that maps positive logits to 1 and negative logits to -1 + hard = torch.sign(logits) + soft = logits + result = hard - soft.detach() + soft + # Now we need to repeat the result n_samples times along a new dimension + return repeat(result, "b ... 
@parameterized.expand([True, False])
def test_train_encoder_with_mmd(self, use_mmd_loss_class: bool = False):
    """Test training the DVAE encoder with the MMD loss against a fixed uniform GRBM prior."""
    dvae = self.dvae_with_trainable_encoder
    optimiser = torch.optim.SGD(dvae.encoder.parameters(), lr=0.01, momentum=0.9)
    kernel = RadialBasisFunction(num_features=5, mul_factor=2.0, bandwidth=None)
    # All four length-2 spin strings -- the support of the uniform prior.
    expected_set = {(1.0, 1.0), (1.0, -1.0), (-1.0, -1.0), (-1.0, 1.0)}

    def encode_latents() -> set:
        # Encode the data once and collect the spin strings the encoder produces.
        _, latents, _ = dvae(self.data, n_samples=1)
        return {tuple(row.tolist()) for row in latents.squeeze(1)}

    # Before training, the encoder should not yet cover all four spin strings.
    self.assertNotEqual(encode_latents(), expected_set)

    mmd_module = None
    # Train the encoder so the latent distribution matches the prior GRBM distribution.
    for _ in range(1000):
        optimiser.zero_grad()
        _, latents, _ = dvae(self.data, n_samples=1)
        latents = latents.reshape(latents.shape[0], -1)
        prior_samples = self.fixed_boltzmann_machine.sample(
            sampler=self.sampler_sa,
            as_tensor=True,
            device=latents.device,
            prefactor=1.0,
            linear_range=None,
            quadratic_range=None,
            sample_params=dict(num_sweeps=10, seed=1234, num_reads=100),
        )
        if use_mmd_loss_class:
            if mmd_module is None:
                # Instantiated lazily, exactly once, on the first iteration.
                mmd_module = MMDLoss(kernel)
            loss = mmd_module(latents, prior_samples)
        else:
            loss = mmd_loss(latents, prior_samples, kernel)
        loss.backward()
        optimiser.step()

    # The uniform prior over length-2 spin strings should pull the encoder onto all four
    # spin strings (in any order).
    self.assertEqual(encode_latents(), expected_set)
def maximum_mean_discrepancy(x: torch.Tensor, y: torch.Tensor, kernel: "Kernel") -> torch.Tensor:
    r"""Computes the maximum mean discrepancy (MMD) loss between two sample sets ``x`` and ``y``.

    This is a two-sample test for the null hypothesis that the two samples are drawn from
    the same distribution (https://dl.acm.org/doi/abs/10.5555/2188385.2188410). The squared
    MMD is defined as

    .. math::
        MMD^2(X, Y) = \|E_{x\sim p}[\varphi(x)] - E_{y\sim q}[\varphi(y)]\|^2,

    where :math:`\varphi` is a feature map associated with the kernel function
    :math:`k(x, y) = \langle \varphi(x), \varphi(y) \rangle`, and :math:`p` and :math:`q`
    are the distributions of the samples. In terms of the kernel function, the squared MMD
    can be computed as

    .. math::
        E_{x, x'\sim p}[k(x, x')] + E_{y, y'\sim q}[k(y, y')] - 2E_{x\sim p, y\sim q}[k(x, y)].

    If :math:`p = q`, then :math:`MMD^2(X, Y) = 0`, which motivates the squared MMD as a
    loss for minimizing the distance between model and data distributions.

    Args:
        x (torch.Tensor): An (n_x, f1, f2, ..., fk) tensor of samples from distribution p.
        y (torch.Tensor): An (n_y, f1, f2, ..., fk) tensor of samples from distribution q.
        kernel (Kernel): A kernel whose call with ``(x, y)`` returns the
            (n_x + n_y, n_x + n_y) matrix of pairwise kernel values over the concatenation
            of ``x`` and ``y``.

    Returns:
        torch.Tensor: The squared maximum mean discrepancy estimate. Requires n_x >= 2 and
        n_y >= 2 (the unbiased within-set estimates divide by n * (n - 1)).
    """
    num_x = x.shape[0]
    num_y = y.shape[0]
    kernel_matrix = kernel(x, y)
    kernel_xx = kernel_matrix[:num_x, :num_x]
    kernel_yy = kernel_matrix[num_x:, num_x:]
    kernel_xy = kernel_matrix[:num_x, num_x:]
    # Off-diagonal means give unbiased within-set estimates (self-similarities excluded).
    xx = (kernel_xx.sum() - kernel_xx.trace()) / (num_x * (num_x - 1))
    yy = (kernel_yy.sum() - kernel_yy.trace()) / (num_y * (num_y - 1))
    xy = kernel_xy.sum() / (num_x * num_y)
    return xx + yy - 2 * xy
"RadialBasisFunction", + "maximum_mean_discrepancy_loss", "MaximumMeanDiscrepancyLoss"] class Kernel(nn.Module): @@ -34,17 +35,18 @@ class Kernel(nn.Module): """ @abstractmethod - def _kernel(self, x: torch.Tensor) -> torch.Tensor: + def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Perform a pairwise kernel evaluation over samples. Computes the kernel matrix for an input of shape (n, f1, f2, ...), whose shape is (n, n) containing the pairwise kernel values. Args: - x (torch.Tensor): A (n, f1, f2, ..., fk) tensor. + x (torch.Tensor): A (nx, f1, f2, ..., fk) tensor. + y (torch.Tensor): A (ny, f1, f2, ..., fk) tensor. Returns: - torch.Tensor: A (n, n) tensor. + torch.Tensor: A (nx, ny) tensor. """ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -66,8 +68,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: "Input dimensions must match. You are trying to compute " f"the kernel between tensors of shape {x.shape} and {y.shape}." ) - xy = torch.cat([x, y], dim=0) - return self._kernel(xy) + return self._kernel(x, y) class RadialBasisFunction(Kernel): @@ -122,7 +123,7 @@ def _get_bandwidth(self, l2_distance_matrix: torch.Tensor) -> torch.Tensor | flo return l2_distance_matrix.sum() / (num_samples**2 - num_samples) return self.bandwidth - def _kernel(self, x: torch.Tensor) -> torch.Tensor: + def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ Computes the radial basis function kernel as @@ -132,22 +133,21 @@ def _kernel(self, x: torch.Tensor) -> torch.Tensor: where :math:`\sigma_i` are the bandwidths. Args: - x (torch.Tensor): A (n, f1, f2, ...) tensor. + x (torch.Tensor): A (nx, f1, f2, ..., fk) tensor. + y (torch.Tensor): A (ny, f1, f2, ..., fk) tensor. Returns: - torch.Tensor: A (n, n) tensor representing the kernel matrix. + torch.Tensor: A (nx, ny) tensor representing the kernel matrix. 
""" - distance_matrix = torch.cdist(x, x, p=2) + distance_matrix = torch.cdist(x, y, p=2) bandwidth = self._get_bandwidth(distance_matrix.detach()) * self.bandwidth_multipliers return torch.exp(-distance_matrix.unsqueeze(0) / bandwidth.reshape(-1, 1, 1)).sum(dim=0) -def maximum_mean_discrepancy(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: - """Computes the maximum mean discrepancy (MMD) loss between two sets of samples ``x`` and ``y``. +def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: + """Estimates the squared maximum mean discrepancy (MMD) given two samples ``x`` and ``y``. - This is a two-sample test to test the null hypothesis that the two samples are drawn from the - same distribution (https://dl.acm.org/doi/abs/10.5555/2188385.2188410). The squared MMD is - defined as + The `squared MMD `_ is defined as .. math:: MMD^2(X, Y) = \|E_{x\sim p}[\varphi(x)] - E_{y\sim q}[\varphi(y)] \|^2, @@ -160,22 +160,25 @@ def maximum_mean_discrepancy(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) - .. math:: E_{x, x'\sim p}[k(x, x')] + E_{y, y'\sim q}[k(y, y')] - 2E_{x\sim p, y\sim q}[k(x, y)]. - If :math:`p = q`, then :math:`MMD^2(X, Y) = 0`. In machine learning applications, the MMD can be - used as a loss function to compare the distribution of model-generated samples to the - distribution of real data samples to force model-generated samples to match the real data - distribution. + If :math:`p = q`, then :math:`MMD^2(X, Y) = 0`. This motivates the squared MMD as a loss + function for minimizing the distance between the model distribution and data distribution. + + For more information, see + Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., & Smola, A. (2012). + A kernel two-sample test. The journal of machine learning research, 13(1), 723-773. Args: - x (torch.Tensor): A (n_x, f1, f2, ...) tensor of samples from distribution p. - y (torch.Tensor): A (n_y, f1, f2, ...) 
tensor of samples from distribution q. + x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor of samples from distribution p. + y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor of samples from distribution q. kernel (Kernel): A kernel function object. Returns: - torch.Tensor: The computed MMD loss. + torch.Tensor: The squared maximum mean discrepancy estimate. """ num_x = x.shape[0] num_y = y.shape[0] - kernel_matrix = kernel(x, y) + xy = torch.cat([x, y], dim=0) + kernel_matrix = kernel(xy, xy) kernel_xx = kernel_matrix[:num_x, :num_x] kernel_yy = kernel_matrix[num_x:, num_x:] kernel_xy = kernel_matrix[:num_x, num_x:] @@ -185,11 +188,11 @@ def maximum_mean_discrepancy(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) - return xx + yy - 2 * xy -class MaximumMeanDiscrepancy(nn.Module): - """Creates a module that computes the maximum mean discrepancy (MMD) loss between two sets of - samples. +class MaximumMeanDiscrepancyLoss(nn.Module): + """An unbiased estimator for the squared maximum mean discrepancy. - This uses the `mmd_loss` function to compute the loss. + This uses the ``dwave.plugins.torch.nn.functional.maximum_mean_discrepancy_loss`` function to + compute the loss. Args: kernel (Kernel): A kernel function object. @@ -210,4 +213,4 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: The computed MMD loss. 
""" - return maximum_mean_discrepancy(x, y, self.kernel) + return maximum_mean_discrepancy_loss(x, y, self.kernel) diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index 40de4ab..47e7613 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -23,7 +23,7 @@ DiscreteVariationalAutoencoder as DVAE, ) from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss -from dwave.plugins.torch.models.losses.mmd import MaximumMeanDiscrepancy, RadialBasisFunction, maximum_mean_discrepancy +from dwave.plugins.torch.models.losses.mmd import MaximumMeanDiscrepancyLoss, RadialBasisFunction, maximum_mean_discrepancy_loss from dwave.samplers import SimulatedAnnealingSampler @@ -163,10 +163,10 @@ def test_train_encoder_with_mmd(self, use_mmd_loss_class: bool = False): ) if use_mmd_loss_class: if mmd_loss_module is None: - mmd_loss_module = MaximumMeanDiscrepancy(kernel) + mmd_loss_module = MaximumMeanDiscrepancyLoss(kernel) mmd = mmd_loss_module(discretes, prior_samples) else: - mmd = maximum_mean_discrepancy(discretes, prior_samples, kernel) + mmd = maximum_mean_discrepancy_loss(discretes, prior_samples, kernel) mmd.backward() optimiser.step() # After training, the encoder should map data points to spin strings that match the samples From e37dc2c164084502518ddee040741c02378a0a5c Mon Sep 17 00:00:00 2001 From: kchern Date: Tue, 16 Dec 2025 18:36:04 +0000 Subject: [PATCH 04/12] Refactor MMD into kernels, functional, and loss --- dwave/plugins/torch/nn/functional.py | 71 ++++++++++++++ .../losses/mmd.py => nn/modules/kernels.py} | 92 ++----------------- dwave/plugins/torch/nn/modules/loss.py | 56 +++++++++++ tests/test_dvae_winci2020.py | 15 +-- 4 files changed, 145 insertions(+), 89 deletions(-) create mode 100755 dwave/plugins/torch/nn/functional.py rename dwave/plugins/torch/{models/losses/mmd.py => nn/modules/kernels.py} (59%) create mode 100755 dwave/plugins/torch/nn/modules/loss.py diff --git 
a/dwave/plugins/torch/nn/functional.py b/dwave/plugins/torch/nn/functional.py new file mode 100755 index 0000000..3399632 --- /dev/null +++ b/dwave/plugins/torch/nn/functional.py @@ -0,0 +1,71 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Functional interface.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dwave.plugins.torch.nn.modules.kernels import Kernel + +import torch + +__all__ = ["maximum_mean_discrepancy_loss"] + + +def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: + """Estimates the squared maximum mean discrepancy (MMD) given two samples ``x`` and ``y``. + + The `squared MMD `_ is defined as + + .. math:: + MMD^2(X, Y) = \|E_{x\sim p}[\varphi(x)] - E_{y\sim q}[\varphi(y)] \|^2, + + where :math:`\varphi` is a feature map associated with the kernel function + :math:`k(x, y) = \langle \varphi(x), \varphi(y) \rangle`, and :math:`p` and :math:`q` are the + distributions of the samples. It follows that, in terms of the kernel function, the squared MMD + can be computed as + + .. math:: + E_{x, x'\sim p}[k(x, x')] + E_{y, y'\sim q}[k(y, y')] - 2E_{x\sim p, y\sim q}[k(x, y)]. + + If :math:`p = q`, then :math:`MMD^2(X, Y) = 0`. This motivates the squared MMD as a loss + function for minimizing the distance between the model distribution and data distribution. + + For more information, see + Gretton, A., Borgwardt, K. 
M., Rasch, M. J., Schölkopf, B., & Smola, A. (2012). + A kernel two-sample test. The journal of machine learning research, 13(1), 723-773. + + Args: + x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor of samples from distribution p. + y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor of samples from distribution q. + kernel (Kernel): A kernel function object. + + Returns: + torch.Tensor: The squared maximum mean discrepancy estimate. + """ + num_x = x.shape[0] + num_y = y.shape[0] + xy = torch.cat([x, y], dim=0) + kernel_matrix = kernel(xy, xy) + kernel_xx = kernel_matrix[:num_x, :num_x] + kernel_yy = kernel_matrix[num_x:, num_x:] + kernel_xy = kernel_matrix[:num_x, num_x:] + xx = (kernel_xx.sum() - kernel_xx.trace()) / (num_x * (num_x - 1)) + yy = (kernel_yy.sum() - kernel_yy.trace()) / (num_y * (num_y - 1)) + xy = kernel_xy.sum() / (num_x * num_y) + return xx + yy - 2 * xy + +torch.nn.MSELoss \ No newline at end of file diff --git a/dwave/plugins/torch/models/losses/mmd.py b/dwave/plugins/torch/nn/modules/kernels.py similarity index 59% rename from dwave/plugins/torch/models/losses/mmd.py rename to dwave/plugins/torch/nn/modules/kernels.py index 895a98f..e13ab1d 100755 --- a/dwave/plugins/torch/models/losses/mmd.py +++ b/dwave/plugins/torch/nn/modules/kernels.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Kernel functions.""" from abc import abstractmethod from typing import Optional @@ -20,18 +21,17 @@ from dwave.plugins.torch.nn.modules.utils import store_config -__all__ = ["Kernel", "RadialBasisFunction", - "maximum_mean_discrepancy_loss", "MaximumMeanDiscrepancyLoss"] +__all__ = ["Kernel", "RadialBasisFunction"] class Kernel(nn.Module): """Base class for kernels. - Kernels are functions that compute a similarity measure between data points. 
class Kernel(nn.Module):
    """Base class for kernels.

    Kernels are functions that compute a similarity measure between data points. Any
    ``Kernel`` subclass must implement the ``_kernel`` method, which receives two batches of
    items shaped (nx, f1, f2, ..., fk) and (ny, f1, f2, ..., fk) and returns the (nx, ny)
    matrix of pairwise kernel evaluations.
    """

    @abstractmethod
    def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Perform a pairwise kernel evaluation over samples.

        Computes the kernel value for every pair drawn from ``x`` and ``y``.

        Args:
            x (torch.Tensor): A (nx, f1, f2, ..., fk) tensor.
            y (torch.Tensor): A (ny, f1, f2, ..., fk) tensor.

        Returns:
            torch.Tensor: A (nx, ny) tensor.
        """

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Computes kernels for all pairs between ``x`` and ``y``.

        In general, ``x`` and ``y`` are (n_x, f1, f2, ..., fk)- and (n_y, f1, f2, ..., fk)-
        shaped tensors, and the output is the (n_x, n_y)-shaped tensor of pairwise kernel
        evaluations returned by ``_kernel``.

        Args:
            x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor.
            y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor.

        Returns:
            torch.Tensor: A (n_x, n_y) tensor.

        Raises:
            ValueError: If the feature (trailing) dimensions of ``x`` and ``y`` differ.
        """
        if x.shape[1:] != y.shape[1:]:
            raise ValueError(
                "Input dimensions must match. You are trying to compute "
                f"the kernel between tensors of shape {x.shape} and {y.shape}."
            )
        return self._kernel(x, y)
- """ - - @store_config - def __init__(self, kernel: Kernel): - super().__init__() - self.kernel = kernel - - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Computes the MMD loss between two sets of samples x and y. - - Args: - x (torch.Tensor): A (n_x, f1, f2, ...) tensor of samples from distribution p. - y (torch.Tensor): A (n_y, f1, f2, ...) tensor of samples from distribution q. - - Returns: - torch.Tensor: The computed MMD loss. - """ - return maximum_mean_discrepancy_loss(x, y, self.kernel) diff --git a/dwave/plugins/torch/nn/modules/loss.py b/dwave/plugins/torch/nn/modules/loss.py new file mode 100755 index 0000000..eed4b28 --- /dev/null +++ b/dwave/plugins/torch/nn/modules/loss.py @@ -0,0 +1,56 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +import torch.nn as nn + +from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss +from dwave.plugins.torch.nn.modules.utils import store_config + +if TYPE_CHECKING: + from dwave.plugins.torch.nn.modules.kernels import Kernel + +__all__ = ["MaximumMeanDiscrepancyLoss"] + + +class MaximumMeanDiscrepancyLoss(nn.Module): + """An unbiased estimator for the squared maximum mean discrepancy as a loss function. + + This uses the ``dwave.plugins.torch.nn.functional.maximum_mean_discrepancy_loss`` function to + compute the loss. 
+ + Args: + kernel (Kernel): A kernel function object. + """ + + @store_config + def __init__(self, kernel: Kernel): + super().__init__() + self.kernel = kernel + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Computes the MMD loss between two sets of samples x and y. + + Args: + x (torch.Tensor): A (n_x, f1, f2, ...) tensor of samples from distribution p. + y (torch.Tensor): A (n_y, f1, f2, ...) tensor of samples from distribution q. + + Returns: + torch.Tensor: The computed MMD loss. + """ + return mmd_loss(x, y, self.kernel) diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index 47e7613..cc11090 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -19,11 +19,12 @@ from parameterized import parameterized from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine -from dwave.plugins.torch.models.discrete_variational_autoencoder import ( - DiscreteVariationalAutoencoder as DVAE, -) +from dwave.plugins.torch.models.discrete_variational_autoencoder import \ + DiscreteVariationalAutoencoder as DVAE from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss -from dwave.plugins.torch.models.losses.mmd import MaximumMeanDiscrepancyLoss, RadialBasisFunction, maximum_mean_discrepancy_loss +from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss +from dwave.plugins.torch.nn.modules.kernels import RadialBasisFunction as RBF +from dwave.plugins.torch.nn.modules.loss import MaximumMeanDiscrepancyLoss as MMDLoss from dwave.samplers import SimulatedAnnealingSampler @@ -139,7 +140,7 @@ def test_train_encoder_with_mmd(self, use_mmd_loss_class: bool = False): """Test training the encoder of the DVAE with MMD loss and fixed decoder and GRBM prior.""" dvae = self.dvae_with_trainable_encoder optimiser = torch.optim.SGD(dvae.encoder.parameters(), lr=0.01, momentum=0.9) - kernel = RadialBasisFunction(num_features=5, 
mul_factor=2.0, bandwidth=None) + kernel = RBF(num_features=5, mul_factor=2.0, bandwidth=None) # Before training, the encoder will not map data points to the correct spin strings: expected_set = {(1.0, 1.0), (1.0, -1.0), (-1.0, -1.0), (-1.0, 1.0)} _, discretes, _ = dvae(self.data, n_samples=1) @@ -163,10 +164,10 @@ def test_train_encoder_with_mmd(self, use_mmd_loss_class: bool = False): ) if use_mmd_loss_class: if mmd_loss_module is None: - mmd_loss_module = MaximumMeanDiscrepancyLoss(kernel) + mmd_loss_module = MMDLoss(kernel) mmd = mmd_loss_module(discretes, prior_samples) else: - mmd = maximum_mean_discrepancy_loss(discretes, prior_samples, kernel) + mmd = mmd_loss(discretes, prior_samples, kernel) mmd.backward() optimiser.step() # After training, the encoder should map data points to spin strings that match the samples From 506829d285532c1d120203a972a23c4c3869b5fd Mon Sep 17 00:00:00 2001 From: kchern Date: Tue, 16 Dec 2025 19:39:43 +0000 Subject: [PATCH 05/12] Add unit tests for kernels --- dwave/plugins/torch/nn/modules/kernels.py | 38 ++++---- tests/test_dvae_winci2020.py | 2 +- tests/test_kernels.py | 102 ++++++++++++++++++++++ 3 files changed, 121 insertions(+), 21 deletions(-) create mode 100755 tests/test_kernels.py diff --git a/dwave/plugins/torch/nn/modules/kernels.py b/dwave/plugins/torch/nn/modules/kernels.py index e13ab1d..aed869a 100755 --- a/dwave/plugins/torch/nn/modules/kernels.py +++ b/dwave/plugins/torch/nn/modules/kernels.py @@ -14,7 +14,6 @@ """Kernel functions.""" from abc import abstractmethod -from typing import Optional import torch import torch.nn as nn @@ -84,25 +83,25 @@ class RadialBasisFunction(Kernel): average distance between samples. Args: - num_features (int): Number of kernel bandwidths to use. - mul_factor (int | float): Multiplicative factor to generate bandwidths. The bandwidths are - computed as :math:`\sigma_i = \sigma * mul\_factor^{i - num\_features // 2}` for - :math:`i` in ``[0, num_features - 1]``. Defaults to 2.0. 
- bandwidth (float | None): Base bandwidth parameter. If None, the bandwidth is estimated - from the data. Defaults to None. + n_kernels (int): Number of kernel bandwidths to use. + factor (int | float): Multiplicative factor to generate bandwidths. The bandwidths are + computed as :math:`\sigma_i = \sigma * factor^{i - n\_kernels // 2}` for + :math:`i` in ``[0, n\_kernels - 1]``. Defaults to 2.0. + bandwidth (float | None): Base bandwidth parameter. If ``None``, the bandwidth is computed + from the data (without gradients). Defaults to ``None``. """ @store_config def __init__( - self, num_features: int, mul_factor: int | float = 2.0, bandwidth: Optional[float] = None + self, n_kernels: int, factor: int | float = 2.0, bandwidth: float | None = None ): super().__init__() - bandwidth_multipliers = mul_factor ** (torch.arange(num_features) - num_features // 2) - self.register_buffer("bandwidth_multipliers", bandwidth_multipliers) + factors = factor ** (torch.arange(n_kernels) - n_kernels // 2) + self.register_buffer("factors", factors) self.bandwidth = bandwidth @torch.no_grad() - def _get_bandwidth(self, l2_distance_matrix: torch.Tensor) -> torch.Tensor | float: + def _get_bandwidth(self, distance_matrix: torch.Tensor) -> torch.Tensor | float: """Heuristically determine a bandwidth parameter as the average distance between samples. Computes the base bandwidth parameter as the average distance between samples if the @@ -111,21 +110,20 @@ def _get_bandwidth(self, l2_distance_matrix: torch.Tensor) -> torch.Tensor | flo the average distance as the bandwidth. Args: - l2_distance_matrix (torch.Tensor): A (n, n) tensor representing the pairwise - L2 distances between samples. If it is None and the bandwidth is not provided, an - error will be raised. Defaults to None. + distance_matrix (torch.Tensor): A (n, n) tensor representing the pairwise + L2 distances between samples. If it is ``None`` and the bandwidth is not provided, + an error will be raised. Defaults to ``None``. 
Returns: torch.Tensor | float: The base bandwidth parameter. """ if self.bandwidth is None: - num_samples = l2_distance_matrix.shape[0] - return l2_distance_matrix.sum() / (num_samples**2 - num_samples) + num_samples = distance_matrix.shape[0] + return distance_matrix.sum() / (num_samples**2 - num_samples) return self.bandwidth def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """ - Computes the radial basis function kernel as + """Compute the radial basis function kernel between ``x`` and ``y``. .. math:: k(x, y) = \sum_{i=1}^{num\_features} exp(-||x-y||^2 / (2 * \sigma_i)), @@ -139,6 +137,6 @@ def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: A (nx, ny) tensor representing the kernel matrix. """ - distance_matrix = torch.cdist(x, y, p=2) - bandwidth = self._get_bandwidth(distance_matrix.detach()) * self.bandwidth_multipliers + distance_matrix = torch.cdist(x.flatten(1), y.flatten(1), p=2) + bandwidth = self._get_bandwidth(distance_matrix.detach()) * self.factors return torch.exp(-distance_matrix.unsqueeze(0) / bandwidth.reshape(-1, 1, 1)).sum(dim=0) diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index cc11090..eced7dd 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -140,7 +140,7 @@ def test_train_encoder_with_mmd(self, use_mmd_loss_class: bool = False): """Test training the encoder of the DVAE with MMD loss and fixed decoder and GRBM prior.""" dvae = self.dvae_with_trainable_encoder optimiser = torch.optim.SGD(dvae.encoder.parameters(), lr=0.01, momentum=0.9) - kernel = RBF(num_features=5, mul_factor=2.0, bandwidth=None) + kernel = RBF(n_kernels=5, factor=2.0, bandwidth=None) # Before training, the encoder will not map data points to the correct spin strings: expected_set = {(1.0, 1.0), (1.0, -1.0), (-1.0, -1.0), (-1.0, 1.0)} _, discretes, _ = dvae(self.data, n_samples=1) diff --git a/tests/test_kernels.py b/tests/test_kernels.py new file mode 
100755 index 0000000..a507885 --- /dev/null +++ b/tests/test_kernels.py @@ -0,0 +1,102 @@ +import unittest + +import torch +from parameterized import parameterized + +from dwave.plugins.torch.nn.modules.kernels import Kernel +from dwave.plugins.torch.nn.modules.kernels import RadialBasisFunction as RBF + + +class TestKernel(unittest.TestCase): + def test_forward(self): + class One(Kernel): + def _kernel(self, x, y): + return 1 + one = One() + x = torch.rand((5, 3)) + y = torch.randn((9, 3)) + self.assertEqual(1, one(x, y)) + + def test_shape_mismatch(self): + class One(Kernel): + def _kernel(self, x, y): + return 1 + one = One() + x = torch.rand((5, 4)) + y = torch.randn((9, 3)) + self.assertRaises(ValueError, one, x, y) + + +class TestRadialBasisFunction(unittest.TestCase): + + def test_has_config(self): + rbf = RBF(5, 2.1, 0.1) + self.assertDictEqual(dict(rbf.config), dict(module_name="RadialBasisFunction", + n_kernels=5, factor=2.1, bandwidth=0.1)) + + @parameterized.expand([ + (torch.randn((5, 12)), torch.rand((7, 12))), + (torch.randn((5, 12, 34)), torch.rand((7, 12, 34))), + ]) + def test_shape(self, x, y): + rbf = RBF(2, 2.1, 0.1) + k = rbf(x, y) + self.assertEqual(tuple(k.shape), (x.shape[0], y.shape[0])) + + def test_get_bandwidth_default(self): + rbf = RBF(2, 2.1, 0.1) + d = torch.tensor(123) + self.assertEqual(0.1, rbf._get_bandwidth(d)) + + def test_get_bandwidth(self): + rbf = RBF(2, 2.1, None) + d = torch.tensor([[0.0, 3.4,], [3.4, 0.0]]) + self.assertEqual(3.4, rbf._get_bandwidth(d)) + + def test_get_bandwidth_no_grad(self): + rbf = RBF(2, 2.1, None) + d = torch.tensor([[0.0, 3.4,], [3.4, 0.0]], requires_grad=True) + self.assertEqual(3.4, rbf._get_bandwidth(d)) + self.assertIsNone(rbf._get_bandwidth(d).grad) + + def test_single_factors(self): + rbf = RBF(1, 2.1, None) + self.assertListEqual(rbf.factors.tolist(), [1.0]) + + def test_two_factors(self): + rbf = RBF(2, 2.1, None) + torch.testing.assert_close(torch.tensor([2.1**-1, 1]), rbf.factors) + + 
def test_three_factors(self): + rbf = RBF(3, 2.1, None) + torch.testing.assert_close(torch.tensor([2.1**-1, 1, 2.1]), rbf.factors) + + def test_kernel(self): + x = torch.tensor([[1.0, 1.0], + [2.0, 3.0]], requires_grad=True) + y = torch.tensor([[0.0, 1.0], + [-3.0, 5.0], + [1.2, 9.0]], requires_grad=True) + dist = torch.cdist(x, y) + + with self.subTest("Adaptive bandwidth"): + rbf = RBF(1, 2.1, None) + bandwidths = rbf._get_bandwidth(dist) * rbf.factors + manual = torch.exp(-dist/bandwidths) + torch.testing.assert_close(manual, rbf(x, y)) + + with self.subTest("Simple bandwidth"): + rbf = RBF(1, 2.1, 12.34) + bandwidths = 12.34 * rbf.factors + manual = torch.exp(-dist/bandwidths) + torch.testing.assert_close(manual, rbf(x, y)) + + with self.subTest("Multiple kernels"): + rbf = RBF(3, 2.1, 123) + bandwidths = rbf._get_bandwidth(dist) * rbf.factors + manual = torch.exp(-dist/bandwidths.reshape(-1, 1, 1)).sum(0) + torch.testing.assert_close(manual, rbf(x, y)) + + +if __name__ == "__main__": + unittest.main() From cd815a079e8de0603f9b679b807b1ad6d38fd8c5 Mon Sep 17 00:00:00 2001 From: kchern Date: Tue, 16 Dec 2025 23:19:33 +0000 Subject: [PATCH 06/12] Add tests for functional and loss add errors --- dwave/plugins/torch/nn/functional.py | 26 +++++++- dwave/plugins/torch/nn/modules/kernels.py | 7 +- tests/test_functional.py | 80 +++++++++++++++++++++++ tests/test_kernels.py | 4 +- tests/test_loss.py | 47 +++++++++++++ tests/test_nn.py | 13 ++++ 6 files changed, 170 insertions(+), 7 deletions(-) create mode 100755 tests/test_functional.py create mode 100755 tests/test_loss.py diff --git a/dwave/plugins/torch/nn/functional.py b/dwave/plugins/torch/nn/functional.py index 3399632..b8b250b 100755 --- a/dwave/plugins/torch/nn/functional.py +++ b/dwave/plugins/torch/nn/functional.py @@ -25,13 +25,21 @@ __all__ = ["maximum_mean_discrepancy_loss"] +class SampleSizeError(ValueError): + pass + + +class DimensionMismatchError(ValueError): + pass + + def 
maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: """Estimates the squared maximum mean discrepancy (MMD) given two samples ``x`` and ``y``. The `squared MMD `_ is defined as .. math:: - MMD^2(X, Y) = \|E_{x\sim p}[\varphi(x)] - E_{y\sim q}[\varphi(y)] \|^2, + MMD^2(X, Y) = |E_{x\sim p}[\varphi(x)] - E_{y\sim q}[\varphi(y)] |^2, where :math:`\varphi` is a feature map associated with the kernel function :math:`k(x, y) = \langle \varphi(x), \varphi(y) \rangle`, and :math:`p` and :math:`q` are the @@ -53,11 +61,25 @@ def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kern y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor of samples from distribution q. kernel (Kernel): A kernel function object. + Raises: + SampleSizeError: If the sample size of ``x`` or ``y`` is less than two. + DimensionMismatchError: If shape of ``x`` and ``y`` mismatch (excluding batch size) + Returns: torch.Tensor: The squared maximum mean discrepancy estimate. """ num_x = x.shape[0] num_y = y.shape[0] + if num_x < 2 or num_y < 2: + raise SampleSizeError( + "Sample size of ``x`` and ``y`` must be at least two. " + f"Got, respectively, {x.shape} and {y.shape}." + ) + if x.shape[1:] != y.shape[1:]: + raise DimensionMismatchError( + "Input dimensions must match. You are trying to compute " + f"the kernel between tensors of shape {x.shape} and {y.shape}." 
+ ) xy = torch.cat([x, y], dim=0) kernel_matrix = kernel(xy, xy) kernel_xx = kernel_matrix[:num_x, :num_x] @@ -67,5 +89,3 @@ def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kern yy = (kernel_yy.sum() - kernel_yy.trace()) / (num_y * (num_y - 1)) xy = kernel_xy.sum() / (num_x * num_y) return xx + yy - 2 * xy - -torch.nn.MSELoss \ No newline at end of file diff --git a/dwave/plugins/torch/nn/modules/kernels.py b/dwave/plugins/torch/nn/modules/kernels.py index aed869a..61e9c21 100755 --- a/dwave/plugins/torch/nn/modules/kernels.py +++ b/dwave/plugins/torch/nn/modules/kernels.py @@ -18,6 +18,7 @@ import torch import torch.nn as nn +from dwave.plugins.torch.nn.functional import DimensionMismatchError from dwave.plugins.torch.nn.modules.utils import store_config __all__ = ["Kernel", "RadialBasisFunction"] @@ -32,7 +33,6 @@ class Kernel(nn.Module): (n, f1, f2, ...), where n is the number of items and f1, f2, ... are feature dimensions, so that the output is a tensor of shape (n, n) containing the pairwise kernel values. """ - @abstractmethod def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Perform a pairwise kernel evaluation over samples. @@ -59,11 +59,14 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor. y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor. + Raises: + DimensionMismatchError: If shape of ``x`` and ``y`` mismatch (excluding batch size) + Returns: torch.Tensor: A (n_x + n_y, n_x + n_y) tensor. """ if x.shape[1:] != y.shape[1:]: - raise ValueError( + raise DimensionMismatchError( "Input dimensions must match. You are trying to compute " f"the kernel between tensors of shape {x.shape} and {y.shape}." 
) diff --git a/tests/test_functional.py b/tests/test_functional.py new file mode 100755 index 0000000..da85182 --- /dev/null +++ b/tests/test_functional.py @@ -0,0 +1,80 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch + +from dwave.plugins.torch.nn.functional import SampleSizeError +from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss +from dwave.plugins.torch.nn.modules.kernels import DimensionMismatchError, Kernel + + +class TestMaximumMeanDiscrepancyLoss(unittest.TestCase): + def test_mmd_loss_constant(self): + x = torch.tensor([[1.2], [4.1]]) + y = torch.tensor([[0.3], [0.5]]) + + class Constant(Kernel): + def __init__(self): + super().__init__() + self.k = torch.tensor([[10, 4, 0, 1], + [4, 10, 4, 2], + [0, 4, 10, 3], + [1, 2, 3, 10]]).float() + + def _kernel(self, x, y): + return self.k + # The resulting kernel matrix will be constant, so (averages) KXX = KYY = 2KXY + kernel = Constant() + # kxx = (4 + 4)/2 + # kyy = (3 + 3)/2 + # kxy = (0 + 1 + 4 + 2)/4 + # kxx + kyy -2kxy = 4 + 3 - 3.5 = 3.5 + self.assertEqual(3.5, mmd_loss(x, y, kernel)) + + def test_sample_size_error(self): + x = torch.tensor([[1.2], [4.1]]) + y = torch.tensor([[0.3]]) + self.assertRaises(SampleSizeError, mmd_loss, x, y, None) + + def test_mmd_loss_dim_mismatch(self): + x = torch.tensor([[1], [4]], dtype=torch.float32) + y = torch.tensor([[0.1, 0.2, 0.3], + [0.4, 0.5, 0.6]]) + 
self.assertRaises(DimensionMismatchError, mmd_loss, x, y, None) + + def test_mmd_loss_arange(self): + x = torch.tensor([[1.0], [4.0], [5.0]]) + y = torch.tensor([[0.3], [0.4]]) + + class Constant(Kernel): + def _kernel(self, x, y): + return torch.tensor([[150, 22, 39, 34, 28], + [22, 630, 98, 56, 44], + [39, 98, 560, 78, 33], + [-99, -99, -99, 299, 13], + [-99, -99, -99, 13, 970]], dtype=torch.float32) + + mmd_loss(x, y, Constant()) + # NOTE: calculation takes kxy = upper-right corner; no PSD assumption + # kxx = (22+39+98)/3 + # kyy = 13 + # kxy = (34+28+56+44+78+33)/6 + # kxx + kyy - 2*kxy + # kxx + kyy - 2*kxy = -25.0 + self.assertEqual(-25, mmd_loss(x, y, Constant())) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_kernels.py b/tests/test_kernels.py index a507885..7b29d72 100755 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -3,7 +3,7 @@ import torch from parameterized import parameterized -from dwave.plugins.torch.nn.modules.kernels import Kernel +from dwave.plugins.torch.nn.modules.kernels import DimensionMismatchError, Kernel from dwave.plugins.torch.nn.modules.kernels import RadialBasisFunction as RBF @@ -24,7 +24,7 @@ def _kernel(self, x, y): one = One() x = torch.rand((5, 4)) y = torch.randn((9, 3)) - self.assertRaises(ValueError, one, x, y) + self.assertRaises(DimensionMismatchError, one, x, y) class TestRadialBasisFunction(unittest.TestCase): diff --git a/tests/test_loss.py b/tests/test_loss.py new file mode 100755 index 0000000..e59dbe3 --- /dev/null +++ b/tests/test_loss.py @@ -0,0 +1,47 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch +from parameterized import parameterized + +from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss +from dwave.plugins.torch.nn.modules.kernels import Kernel +from dwave.plugins.torch.nn.modules.loss import MaximumMeanDiscrepancyLoss as MMDLoss + + +class TestMaximumMeanDiscrepancyLoss(unittest.TestCase): + @parameterized.expand([ + (torch.tensor([[1.2], [4.1]]), torch.tensor([[0.3], [0.5]])), + (torch.randn((123, 4, 3, 2)), torch.rand(100, 4, 3, 2)), + ]) + def test_mmd_loss(self, x, y): + class Constant(Kernel): + def __init__(self): + super().__init__() + self.k = torch.tensor([[10, 4, 0, 1], + [4, 10, 4, 2], + [0, 4, 10, 3], + [1, 2, 3, 10]]).float() + + def _kernel(self, x, y): + return self.k + # The resulting kernel matrix will be constant, so (averages) KXX = KYY = 2KXY + kernel = Constant() + compute_mmd = MMDLoss(kernel) + torch.testing.assert_close(mmd_loss(x, y, kernel), compute_mmd(x, y)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_nn.py b/tests/test_nn.py index c84929d..bac40c9 100755 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -1,3 +1,16 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import unittest import torch From 901f3339e3e3d2f087018161bf799eca7b673461 Mon Sep 17 00:00:00 2001 From: kchern Date: Tue, 16 Dec 2025 23:45:26 +0000 Subject: [PATCH 07/12] Update release note --- .../add-mmd-loss-function-3fa9e9a2cb452391.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml b/releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml index a431a7d..46ea631 100644 --- a/releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml +++ b/releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml @@ -1,10 +1,10 @@ --- features: - | - MMD loss is available in ``dwave.plugins.torch.models.losses.mmd.mmd_loss``, - which computes the MMD loss using a ``dwave.plugins.torch.models.losses.mmd.Kernel`` - (specialized to the ``dwave.plugins.torch.models.losses.mmd.RBFKernel``). This - enables training encoders in discrete variational autoencoders to match the - distribution of the prior model. + Add a ``MaximumMeanDiscrepancyLoss`` in ``dwave.plugins.torch.nn.loss`` for estimating the + squared maximum mean discrepancy (MMD) for a given kernel and two samples. + Its functional counterpart ``maximum_mean_discrepancy_loss`` is in + ``dwave.plugins.torch.nn.functional``. + Kernels reside in ``dwave.plugins.torch.nn.modules.kernels``. This enables, for example, + training discrete autoencoders to match the distribution of a target distribution (e.g., prior). 
- \ No newline at end of file From 5d4d473eb7d4eb6ed41db226eb92b6c8c6cad8b8 Mon Sep 17 00:00:00 2001 From: kchern Date: Wed, 17 Dec 2025 18:49:33 +0000 Subject: [PATCH 08/12] Rename RBF to GaussianKernel --- dwave/plugins/torch/nn/modules/kernels.py | 10 ++++---- tests/test_dvae_winci2020.py | 2 +- tests/test_kernels.py | 29 +++++++++++------------ 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/dwave/plugins/torch/nn/modules/kernels.py b/dwave/plugins/torch/nn/modules/kernels.py index 61e9c21..2d614d5 100755 --- a/dwave/plugins/torch/nn/modules/kernels.py +++ b/dwave/plugins/torch/nn/modules/kernels.py @@ -21,7 +21,7 @@ from dwave.plugins.torch.nn.functional import DimensionMismatchError from dwave.plugins.torch.nn.modules.utils import store_config -__all__ = ["Kernel", "RadialBasisFunction"] +__all__ = ["Kernel", "GaussianKernel"] class Kernel(nn.Module): @@ -73,14 +73,14 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return self._kernel(x, y) -class RadialBasisFunction(Kernel): - """The radial basis function kernel. +class GaussianKernel(Kernel): + """The Gaussian kernel. This kernel between two data points x and y is defined as :math:`k(x, y) = exp(-||x-y||^2 / (2 * \sigma))`, where :math:`\sigma` is the bandwidth parameter. - This implementation considers aggregating multiple radial basis function kernels with different + This implementation considers aggregating multiple Gaussian kernels with different bandwidths. The bandwidths are determined by multiplying a base bandwidth with a set of multipliers. The base bandwidth can be provided directly or estimated from the data using the average distance between samples. @@ -126,7 +126,7 @@ def _get_bandwidth(self, distance_matrix: torch.Tensor) -> torch.Tensor | float: return self.bandwidth def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Compute the radial basis function kernel between ``x`` and ``y``. 
+ """Compute the Gaussian kernel between ``x`` and ``y``. .. math:: k(x, y) = \sum_{i=1}^{num\_features} exp(-||x-y||^2 / (2 * \sigma_i)), diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index eced7dd..70b6543 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -23,7 +23,7 @@ DiscreteVariationalAutoencoder as DVAE from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss -from dwave.plugins.torch.nn.modules.kernels import RadialBasisFunction as RBF +from dwave.plugins.torch.nn.modules.kernels import GaussianKernel as RBF from dwave.plugins.torch.nn.modules.loss import MaximumMeanDiscrepancyLoss as MMDLoss from dwave.samplers import SimulatedAnnealingSampler diff --git a/tests/test_kernels.py b/tests/test_kernels.py index 7b29d72..fd106af 100755 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -3,8 +3,7 @@ import torch from parameterized import parameterized -from dwave.plugins.torch.nn.modules.kernels import DimensionMismatchError, Kernel -from dwave.plugins.torch.nn.modules.kernels import RadialBasisFunction as RBF +from dwave.plugins.torch.nn.modules.kernels import DimensionMismatchError, GaussianKernel, Kernel class TestKernel(unittest.TestCase): @@ -27,11 +26,11 @@ def _kernel(self, x, y): self.assertRaises(DimensionMismatchError, one, x, y) -class TestRadialBasisFunction(unittest.TestCase): +class TestGaussianKernel(unittest.TestCase): def test_has_config(self): - rbf = RBF(5, 2.1, 0.1) - self.assertDictEqual(dict(rbf.config), dict(module_name="RadialBasisFunction", + rbf = GaussianKernel(5, 2.1, 0.1) + self.assertDictEqual(dict(rbf.config), dict(module_name="GaussianKernel", n_kernels=5, factor=2.1, bandwidth=0.1)) @parameterized.expand([ @@ -39,36 +38,36 @@ def test_has_config(self): (torch.randn((5, 12, 34)), torch.rand((7, 12, 34))), ]) def test_shape(self, x, y): - rbf = RBF(2, 
2.1, 0.1) + rbf = GaussianKernel(2, 2.1, 0.1) k = rbf(x, y) self.assertEqual(tuple(k.shape), (x.shape[0], y.shape[0])) def test_get_bandwidth_default(self): - rbf = RBF(2, 2.1, 0.1) + rbf = GaussianKernel(2, 2.1, 0.1) d = torch.tensor(123) self.assertEqual(0.1, rbf._get_bandwidth(d)) def test_get_bandwidth(self): - rbf = RBF(2, 2.1, None) + rbf = GaussianKernel(2, 2.1, None) d = torch.tensor([[0.0, 3.4,], [3.4, 0.0]]) self.assertEqual(3.4, rbf._get_bandwidth(d)) def test_get_bandwidth_no_grad(self): - rbf = RBF(2, 2.1, None) + rbf = GaussianKernel(2, 2.1, None) d = torch.tensor([[0.0, 3.4,], [3.4, 0.0]], requires_grad=True) self.assertEqual(3.4, rbf._get_bandwidth(d)) self.assertIsNone(rbf._get_bandwidth(d).grad) def test_single_factors(self): - rbf = RBF(1, 2.1, None) + rbf = GaussianKernel(1, 2.1, None) self.assertListEqual(rbf.factors.tolist(), [1.0]) def test_two_factors(self): - rbf = RBF(2, 2.1, None) + rbf = GaussianKernel(2, 2.1, None) torch.testing.assert_close(torch.tensor([2.1**-1, 1]), rbf.factors) def test_three_factors(self): - rbf = RBF(3, 2.1, None) + rbf = GaussianKernel(3, 2.1, None) torch.testing.assert_close(torch.tensor([2.1**-1, 1, 2.1]), rbf.factors) def test_kernel(self): @@ -80,19 +79,19 @@ def test_kernel(self): dist = torch.cdist(x, y) with self.subTest("Adaptive bandwidth"): - rbf = RBF(1, 2.1, None) + rbf = GaussianKernel(1, 2.1, None) bandwidths = rbf._get_bandwidth(dist) * rbf.factors manual = torch.exp(-dist/bandwidths) torch.testing.assert_close(manual, rbf(x, y)) with self.subTest("Simple bandwidth"): - rbf = RBF(1, 2.1, 12.34) + rbf = GaussianKernel(1, 2.1, 12.34) bandwidths = 12.34 * rbf.factors manual = torch.exp(-dist/bandwidths) torch.testing.assert_close(manual, rbf(x, y)) with self.subTest("Multiple kernels"): - rbf = RBF(3, 2.1, 123) + rbf = GaussianKernel(3, 2.1, 123) bandwidths = rbf._get_bandwidth(dist) * rbf.factors manual = torch.exp(-dist/bandwidths.reshape(-1, 1, 1)).sum(0) torch.testing.assert_close(manual, rbf(x, 
y)) From 4c0d233bc82cbf83f467b5082d158cb6b49e69f5 Mon Sep 17 00:00:00 2001 From: kchern Date: Wed, 17 Dec 2025 18:52:08 +0000 Subject: [PATCH 09/12] Rename RBF to GaussianKernel --- tests/test_dvae_winci2020.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index 70b6543..01825b9 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -23,7 +23,7 @@ DiscreteVariationalAutoencoder as DVAE from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss -from dwave.plugins.torch.nn.modules.kernels import GaussianKernel as RBF +from dwave.plugins.torch.nn.modules.kernels import GaussianKernel from dwave.plugins.torch.nn.modules.loss import MaximumMeanDiscrepancyLoss as MMDLoss from dwave.samplers import SimulatedAnnealingSampler @@ -140,7 +140,7 @@ def test_train_encoder_with_mmd(self, use_mmd_loss_class: bool = False): """Test training the encoder of the DVAE with MMD loss and fixed decoder and GRBM prior.""" dvae = self.dvae_with_trainable_encoder optimiser = torch.optim.SGD(dvae.encoder.parameters(), lr=0.01, momentum=0.9) - kernel = RBF(n_kernels=5, factor=2.0, bandwidth=None) + kernel = GaussianKernel(n_kernels=5, factor=2.0, bandwidth=None) # Before training, the encoder will not map data points to the correct spin strings: expected_set = {(1.0, 1.0), (1.0, -1.0), (-1.0, -1.0), (-1.0, 1.0)} _, discretes, _ = dvae(self.data, n_samples=1) From 285b5c1a418108600763d6f01c9a47d5533939f1 Mon Sep 17 00:00:00 2001 From: kchern Date: Mon, 5 Jan 2026 18:09:50 +0000 Subject: [PATCH 10/12] Remove custom errors and fix docstrings Co-Authored-By: Theodor Isacsson --- dwave/plugins/torch/nn/functional.py | 15 ++++----------- dwave/plugins/torch/nn/modules/kernels.py | 14 +++++++++----- dwave/plugins/torch/nn/modules/loss.py | 4 ++-- tests/test_functional.py | 7 
+++---- tests/test_kernels.py | 15 ++++++++++++--- 5 files changed, 30 insertions(+), 25 deletions(-) diff --git a/dwave/plugins/torch/nn/functional.py b/dwave/plugins/torch/nn/functional.py index b8b250b..8327f7a 100755 --- a/dwave/plugins/torch/nn/functional.py +++ b/dwave/plugins/torch/nn/functional.py @@ -25,13 +25,6 @@ __all__ = ["maximum_mean_discrepancy_loss"] -class SampleSizeError(ValueError): - pass - - -class DimensionMismatchError(ValueError): - pass - def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: """Estimates the squared maximum mean discrepancy (MMD) given two samples ``x`` and ``y``. @@ -62,8 +55,8 @@ def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kern kernel (Kernel): A kernel function object. Raises: - SampleSizeError: If the sample size of ``x`` or ``y`` is less than two. - DimensionMismatchError: If shape of ``x`` and ``y`` mismatch (excluding batch size) + ValueError: If the sample size of ``x`` or ``y`` is less than two. + ValueError: If shape of ``x`` and ``y`` mismatch (excluding batch size) Returns: torch.Tensor: The squared maximum mean discrepancy estimate. @@ -71,12 +64,12 @@ def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kern num_x = x.shape[0] num_y = y.shape[0] if num_x < 2 or num_y < 2: - raise SampleSizeError( + raise ValueError( "Sample size of ``x`` and ``y`` must be at least two. " f"Got, respectively, {x.shape} and {y.shape}." ) if x.shape[1:] != y.shape[1:]: - raise DimensionMismatchError( + raise ValueError( "Input dimensions must match. You are trying to compute " f"the kernel between tensors of shape {x.shape} and {y.shape}." ) diff --git a/dwave/plugins/torch/nn/modules/kernels.py b/dwave/plugins/torch/nn/modules/kernels.py index 2d614d5..2aa6ced 100755 --- a/dwave/plugins/torch/nn/modules/kernels.py +++ b/dwave/plugins/torch/nn/modules/kernels.py @@ -13,18 +13,17 @@ # limitations under the License. 
"""Kernel functions.""" -from abc import abstractmethod +from abc import ABC, abstractmethod import torch import torch.nn as nn -from dwave.plugins.torch.nn.functional import DimensionMismatchError from dwave.plugins.torch.nn.modules.utils import store_config __all__ = ["Kernel", "GaussianKernel"] -class Kernel(nn.Module): +class Kernel(ABC, nn.Module): """Base class for kernels. `Kernels `_ are functions that compute a similarity @@ -60,16 +59,21 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor. Raises: - DimensionMismatchError: If shape of ``x`` and ``y`` mismatch (excluding batch size) + ValueError: If shape of ``x`` and ``y`` mismatch (excluding batch size) Returns: torch.Tensor: A (n_x + n_y, n_x + n_y) tensor. """ if x.shape[1:] != y.shape[1:]: - raise DimensionMismatchError( + raise ValueError( "Input dimensions must match. You are trying to compute " f"the kernel between tensors of shape {x.shape} and {y.shape}." ) + if x.shape[0] < 2 or y.shape[0] < 2: + raise ValueError( + "Sample size of ``x`` and ``y`` must be at least two. " + f"Got, respectively, {x.shape} and {y.shape}." + ) return self._kernel(x, y) diff --git a/dwave/plugins/torch/nn/modules/loss.py b/dwave/plugins/torch/nn/modules/loss.py index eed4b28..5da41f5 100755 --- a/dwave/plugins/torch/nn/modules/loss.py +++ b/dwave/plugins/torch/nn/modules/loss.py @@ -29,7 +29,7 @@ class MaximumMeanDiscrepancyLoss(nn.Module): - """An unbiased estimator for the squared maximum mean discrepancy as a loss function. + """An unbiased estimator for the squared maximum mean discrepancy (MMD) as a loss function. This uses the ``dwave.plugins.torch.nn.functional.maximum_mean_discrepancy_loss`` function to compute the loss. @@ -44,7 +44,7 @@ def __init__(self, kernel: Kernel): self.kernel = kernel def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Computes the MMD loss between two sets of samples x and y. 
+ """Computes the MMD loss between two sets of samples ``x`` and ``y``. Args: x (torch.Tensor): A (n_x, f1, f2, ...) tensor of samples from distribution p. diff --git a/tests/test_functional.py b/tests/test_functional.py index da85182..17f81f4 100755 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -15,9 +15,8 @@ import torch -from dwave.plugins.torch.nn.functional import SampleSizeError from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss -from dwave.plugins.torch.nn.modules.kernels import DimensionMismatchError, Kernel +from dwave.plugins.torch.nn.modules.kernels import Kernel class TestMaximumMeanDiscrepancyLoss(unittest.TestCase): @@ -46,13 +45,13 @@ def _kernel(self, x, y): def test_sample_size_error(self): x = torch.tensor([[1.2], [4.1]]) y = torch.tensor([[0.3]]) - self.assertRaises(SampleSizeError, mmd_loss, x, y, None) + self.assertRaisesRegex(ValueError, "must be at least two", mmd_loss, x, y, None) def test_mmd_loss_dim_mismatch(self): x = torch.tensor([[1], [4]], dtype=torch.float32) y = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) - self.assertRaises(DimensionMismatchError, mmd_loss, x, y, None) + self.assertRaisesRegex(ValueError, "Input dimensions must match. 
You are trying to compute ", mmd_loss, x, y, None) def test_mmd_loss_arange(self): x = torch.tensor([[1.0], [4.0], [5.0]]) diff --git a/tests/test_kernels.py b/tests/test_kernels.py index fd106af..278ca53 100755 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -3,7 +3,7 @@ import torch from parameterized import parameterized -from dwave.plugins.torch.nn.modules.kernels import DimensionMismatchError, GaussianKernel, Kernel +from dwave.plugins.torch.nn.modules.kernels import Kernel, GaussianKernel class TestKernel(unittest.TestCase): @@ -16,6 +16,16 @@ def _kernel(self, x, y): y = torch.randn((9, 3)) self.assertEqual(1, one(x, y)) + @parameterized.expand([(1, 2), (2, 1)]) + def test_sample_size(self, nx, ny): + class One(Kernel): + def _kernel(self, x, y): + return 1 + one = One() + x = torch.rand((nx, 5)) + y = torch.randn((ny, 5)) + self.assertRaisesRegex(ValueError, "must be at least two", one, x, y) + def test_shape_mismatch(self): class One(Kernel): def _kernel(self, x, y): @@ -23,8 +33,7 @@ def _kernel(self, x, y): one = One() x = torch.rand((5, 4)) y = torch.randn((9, 3)) - self.assertRaises(DimensionMismatchError, one, x, y) - + self.assertRaisesRegex(ValueError, "Input dimensions must match", one, x, y) class TestGaussianKernel(unittest.TestCase): From 261cd4da177fcb381dba767ad94f995cb2d27b4d Mon Sep 17 00:00:00 2001 From: kchern Date: Mon, 5 Jan 2026 18:20:11 +0000 Subject: [PATCH 11/12] Fix a docstring --- dwave/plugins/torch/nn/modules/kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dwave/plugins/torch/nn/modules/kernels.py b/dwave/plugins/torch/nn/modules/kernels.py index 2aa6ced..5dfaf4e 100755 --- a/dwave/plugins/torch/nn/modules/kernels.py +++ b/dwave/plugins/torch/nn/modules/kernels.py @@ -36,7 +36,7 @@ class Kernel(ABC, nn.Module): def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Perform a pairwise kernel evaluation over samples. 
- Computes the kernel matrix for an input of shape (n, f1, f2, ...), whose shape is (n, n) + Computes the kernel matrix for inputs of shape (nx, f1, f2, ..., fk) and (ny, f1, f2, ..., fk), whose shape is (nx, ny) containing the pairwise kernel values. Args: From 975275836ebc04adb6fea41f3193d072ed122ba6 Mon Sep 17 00:00:00 2001 From: Kevin Chern <32395608+kevinchern@users.noreply.github.com> Date: Tue, 6 Jan 2026 10:23:11 -0800 Subject: [PATCH 12/12] Fix minor code aesthetics Co-authored-by: Theodor Isacsson --- dwave/plugins/torch/nn/modules/kernels.py | 4 +++- dwave/plugins/torch/nn/modules/loss.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/dwave/plugins/torch/nn/modules/kernels.py b/dwave/plugins/torch/nn/modules/kernels.py index 5dfaf4e..6119cb5 100755 --- a/dwave/plugins/torch/nn/modules/kernels.py +++ b/dwave/plugins/torch/nn/modules/kernels.py @@ -32,11 +32,13 @@ class Kernel(ABC, nn.Module): (n, f1, f2, ...), where n is the number of items and f1, f2, ... are feature dimensions, so that the output is a tensor of shape (n, n) containing the pairwise kernel values. """ + @abstractmethod def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Perform a pairwise kernel evaluation over samples. - Computes the kernel matrix for inputs of shape (nx, f1, f2, ..., fk) and (ny, f1, f2, ..., fk), whose shape is (nx, ny) + Computes the kernel matrix for inputs of shape (nx, f1, f2, ..., fk) and + (ny, f1, f2, ..., fk), whose shape is (nx, ny) containing the pairwise kernel values. Args: diff --git a/dwave/plugins/torch/nn/modules/loss.py b/dwave/plugins/torch/nn/modules/loss.py index 5da41f5..45488ec 100755 --- a/dwave/plugins/torch/nn/modules/loss.py +++ b/dwave/plugins/torch/nn/modules/loss.py @@ -39,7 +39,7 @@ class MaximumMeanDiscrepancyLoss(nn.Module): """ @store_config - def __init__(self, kernel: Kernel): + def __init__(self, kernel: Kernel) -> None: super().__init__() self.kernel = kernel