From 21bc015a32424011801e4ed7fbf6a61107505397 Mon Sep 17 00:00:00 2001 From: amansouribigvand Date: Tue, 7 Apr 2026 16:10:37 +0000 Subject: [PATCH 1/2] Updating Dimod Sampler --- dwave/plugins/torch/samplers/dimod_sampler.py | 88 ++++++++++++++++--- tests/test_samplers/test_dimod_sampler.py | 47 +++++++++- 2 files changed, 123 insertions(+), 12 deletions(-) diff --git a/dwave/plugins/torch/samplers/dimod_sampler.py b/dwave/plugins/torch/samplers/dimod_sampler.py index 1e2c992..ebc394b 100644 --- a/dwave/plugins/torch/samplers/dimod_sampler.py +++ b/dwave/plugins/torch/samplers/dimod_sampler.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any import torch -from dimod import Sampler +import dimod from hybrid.composers import AggregatedSamples from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine @@ -93,18 +93,84 @@ def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: interpreted as a batch of partially-observed spins. Entries marked with ``torch.nan`` will be sampled; entries with +/-1 values will remain constant. """ - if x is not None: - raise NotImplementedError("Support for conditional sampling has not been implemented.") - - h, J = self._grbm.to_ising(self._prefactor, self._linear_range, self._quadratic_range) - self._sample_set = AggregatedSamples.spread( - self._sampler.sample_ising(h, J, **self._sampler_params) - ) - - # use same device as modules linear device = self._grbm._linear.device - return sampleset_to_tensor(self._grbm.nodes, self._sample_set, device) + nodes = self._grbm.nodes + n_nodes = self._grbm.n_nodes + h, J = self._grbm.to_ising(self._prefactor, self._linear_range, self._quadratic_range) + + # Unconditional sampling + if x is None: + self._sample_set = AggregatedSamples.spread( + self._sampler.sample_ising(h, J, **self._sampler_params) + ) + return sampleset_to_tensor(nodes, self._sample_set, device) + + # Conditional sampling + if x.shape[1] != n_nodes: + raise ValueError(f"x must have shape (batch_size, {n_nodes})") + + mask = ~torch.isnan(x) + if not torch.all(torch.isin(x[mask], torch.tensor([-1, 1], device=x.device))): + raise ValueError("x must contain only ±1 or NaN") + + results = [] + for i in range(x.shape[0]): + # Fresh BQM + bqm = dimod.BinaryQuadraticModel.from_ising(h, J) + + # Build conditioning dict + conditioned = {node: int(x[i, j].item()) + for j, node in enumerate(nodes) if mask[i, j]} + + # Apply conditioning + if conditioned: + bqm.fix_variables(conditioned) + + # Clip linear biases for remaining free variables + if self._linear_range is not None: + lb, ub = self._linear_range + for v in bqm.linear: + if bqm.linear[v] > ub: + bqm.set_linear(v, ub) + elif bqm.linear[v] < lb: + bqm.set_linear(v, lb) + + # Clip quadratic biases + if self._quadratic_range is not None: + lb, ub = self._quadratic_range + for u, v, bias in bqm.iter_quadratic(): + if bias > ub: + bqm.set_quadratic(u, v, ub) + elif bias < lb: + bqm.set_quadratic(u, v, lb) + + # Handle fully clamped case + if bqm.num_variables == 0: + full = torch.tensor([conditioned[node] for node in nodes], + device=device, dtype=torch.float) + results.append(full) + continue + + # Sample one configuration per input + sample_kwargs = dict(self._sampler_params) + sample_kwargs["num_reads"] = 1 + sample_set = AggregatedSamples.spread( + self._sampler.sample(bqm, **sample_kwargs) + ) + + # Extract sampled values + sample = sample_set.first.sample + + # Reconstruct full sample + full = torch.empty(n_nodes, device=device) + for j, node in enumerate(nodes): + full[j] = conditioned[node] if node in conditioned else float(sample[node]) + results.append(full) + + # Stack to get (batch_size, n_nodes) + return torch.stack(results, dim=0) + @property def sample_set(self) -> dimod.SampleSet: """The sample set returned from the latest sample call.""" diff --git a/tests/test_samplers/test_dimod_sampler.py b/tests/test_samplers/test_dimod_sampler.py index 59e7b88..4c365bd 100644 --- a/tests/test_samplers/test_dimod_sampler.py +++ b/tests/test_samplers/test_dimod_sampler.py @@ -22,7 +22,7 @@ from dwave.plugins.torch.samplers.dimod_sampler import DimodSampler from dwave.samplers import SteepestDescentSampler from dwave.system.temperatures import maximum_pseudolikelihood_temperature as mple - +from dwave.samplers import SimulatedAnnealingSampler class TestDimodSampler(unittest.TestCase): def setUp(self) -> None: @@ -125,6 +125,51 @@ def test_sample(self): torch.tensor(list(tracker.input['J'].values())), torch.tensor([0, 0, 0, 0.0]) ) + + with self.subTest("Conditional sampling preserves clamped variables"): + sampler = DimodSampler( + self.bm, + SimulatedAnnealingSampler(), + prefactor=1, + sample_kwargs=dict(num_reads=1) + ) + + x = torch.tensor([ + [1.0, float("nan"), -1.0, float("nan")], + [float("nan"), -1.0, float("nan"), 1.0], + ]) + + samples = sampler.sample(x) + + # Shape check + self.assertTupleEqual(samples.shape, x.shape) + + # Check clamped values unchanged + mask = ~torch.isnan(x) + self.assertTrue(torch.all(samples[mask] == x[mask])) + + # Check free variables are ±1 + free_mask = torch.isnan(x) + free_values = samples[free_mask] + self.assertTrue(torch.all(torch.isin(free_values, torch.tensor([-1.0, 1.0]))), + "Free variables should be sampled as ±1") + + with self.subTest("Conditional sampling with all variables clamped returns input unchanged."): + sampler = DimodSampler( + self.bm, + SimulatedAnnealingSampler(), + prefactor=1, + sample_kwargs=dict(num_reads=1) + ) + + x = torch.tensor([ + [+1.0, -1.0, -1.0, +1.0], + [-1.0, +1.0, -1.0, -1.0], + ]) + + samples = sampler.sample(x) + # All spins clamped, should return identical tensor + torch.testing.assert_close(samples, x) def test_sample_set(self): grbm = GRBM(list("abcd"), [("a", "b")]) From 0cf8a8977f828f5ca1f59cab7cd5f9d7a2f94867 Mon Sep 17 00:00:00 2001 From: amansouribigvand Date: Tue, 14 Apr 2026 21:31:15 +0000 Subject: [PATCH 2/2] add test, release note and improve sample --- dwave/plugins/torch/samplers/dimod_sampler.py | 65 +++++++++++++------ ...conditional-sampling-3ce98d4847eedb83.yaml | 5 ++ tests/test_samplers/test_dimod_sampler.py | 30 ++++++++- 3 files changed, 76 insertions(+), 24 deletions(-) create mode 100644 releasenotes/notes/dimod-conditional-sampling-3ce98d4847eedb83.yaml diff --git a/dwave/plugins/torch/samplers/dimod_sampler.py b/dwave/plugins/torch/samplers/dimod_sampler.py index ebc394b..688a6d7 100644 --- a/dwave/plugins/torch/samplers/dimod_sampler.py +++ b/dwave/plugins/torch/samplers/dimod_sampler.py @@ -17,7 +17,8 @@ import torch import dimod -from hybrid.composers import AggregatedSamples +import warnings +from hybrid.composers import AggregatedSamples from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine from dwave.plugins.torch.samplers.base import TorchSampler @@ -86,12 +87,20 @@ def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: """Sample from the dimod sampler and return the corresponding tensor. The sample set returned from the latest sample call is stored in :func:`DimodSampler.sample_set` - which is overwritten by subsequent calls. + which is overwritten by subsequent calls. When ``x`` is provided (conditional sampling), exactly + one sample is drawn per input row. Any user-specified ``num_reads`` in ``sample_kwargs`` will be + ignored and overridden to 1. Args: x (torch.Tensor): A tensor of shape (``batch_size``, ``dim``) or (``batch_size``, ``n_nodes``) interpreted as a batch of partially-observed spins. Entries marked with ``torch.nan`` will be sampled; entries with +/-1 values will remain constant. + Raises: + ValueError: If ``x`` has an invalid shape or contains values other than ±1 or NaN. + + Returns: + torch.Tensor: A tensor of shape (``batch_size``, ``n_nodes``) containing + sampled spin configurations with values in ``{-1, +1}``. """ device = self._grbm._linear.device nodes = self._grbm.nodes @@ -115,18 +124,26 @@ def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: raise ValueError("x must contain only ±1 or NaN") results = [] - for i in range(x.shape[0]): + for row, row_mask in zip(x, mask): # Fresh BQM bqm = dimod.BinaryQuadraticModel.from_ising(h, J) - + # Build conditioning dict - conditioned = {node: int(x[i, j].item()) - for j, node in enumerate(nodes) if mask[i, j]} - + conditioned = {node: int(val.item()) + for node, val, m in zip(nodes, row, row_mask) if m + } + # Apply conditioning if conditioned: bqm.fix_variables(conditioned) + # Handle fully clamped case + if bqm.num_variables == 0: + full = torch.tensor([conditioned[node] for node in nodes], + device=device, dtype=torch.float) + results.append(full) + continue + # Clip linear biases for remaining free variables if self._linear_range is not None: lb, ub = self._linear_range @@ -145,20 +162,18 @@ def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: elif bias < lb: bqm.set_quadratic(u, v, lb) - # Handle fully clamped case - if bqm.num_variables == 0: - full = torch.tensor([conditioned[node] for node in nodes], - device=device, dtype=torch.float) - results.append(full) - continue - # Sample one configuration per input sample_kwargs = dict(self._sampler_params) - sample_kwargs["num_reads"] = 1 - sample_set = AggregatedSamples.spread( - self._sampler.sample(bqm, **sample_kwargs) - ) - + if "num_reads" in sample_kwargs and sample_kwargs["num_reads"] != 1: + warnings.warn( + "`num_reads` is ignored during conditional sampling and set to 1 " + "(one sample per input row).", + UserWarning, + ) + sample_kwargs["num_reads"] = 1 + + sample_set = self._sampler.sample(bqm, **sample_kwargs) + # Extract sampled values sample = sample_set.first.sample @@ -167,9 +182,17 @@ def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: for j, node in enumerate(nodes): full[j] = conditioned[node] if node in conditioned else float(sample[node]) results.append(full) - + # Stack to get (batch_size, n_nodes) - return torch.stack(results, dim=0) + samples = torch.stack(results, dim=0) + energies = self._grbm(samples).detach().cpu().numpy() + + self._sample_set = dimod.SampleSet.from_samples( + (samples.cpu().numpy(), self._grbm.nodes), + vartype=dimod.SPIN, + energy=energies, + ) + return samples @property def sample_set(self) -> dimod.SampleSet: diff --git a/releasenotes/notes/dimod-conditional-sampling-3ce98d4847eedb83.yaml b/releasenotes/notes/dimod-conditional-sampling-3ce98d4847eedb83.yaml new file mode 100644 index 0000000..347dc9c --- /dev/null +++ b/releasenotes/notes/dimod-conditional-sampling-3ce98d4847eedb83.yaml @@ -0,0 +1,5 @@ +--- +upgrade: + - | + Add conditional sampling functionality for the ``DimodSampler``. The sampler + enforces one sample per input row during conditional sampling. diff --git a/tests/test_samplers/test_dimod_sampler.py b/tests/test_samplers/test_dimod_sampler.py index 4c365bd..8adec34 100644 --- a/tests/test_samplers/test_dimod_sampler.py +++ b/tests/test_samplers/test_dimod_sampler.py @@ -15,14 +15,13 @@ import unittest import torch -from dimod import SPIN, BinaryQuadraticModel, IdentitySampler, SampleSet, TrackingComposite -from parameterized import parameterized +from dimod import IdentitySampler, SampleSet, TrackingComposite from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM from dwave.plugins.torch.samplers.dimod_sampler import DimodSampler from dwave.samplers import SteepestDescentSampler -from dwave.system.temperatures import maximum_pseudolikelihood_temperature as mple from dwave.samplers import SimulatedAnnealingSampler +from dwave.plugins.torch.utils import sampleset_to_tensor class TestDimodSampler(unittest.TestCase): def setUp(self) -> None: @@ -187,6 +186,31 @@ def test_sample_set(self): with self.subTest("The `sample_set` attribute should be of type `dimod.SampleSet`."): self.assertTrue(isinstance(sampler.sample_set, SampleSet)) + def test_order_sample(self): + + nodes = ["b", "a", "c"] + edges = [("a", "b"), ("b", "c")] + + grbm = GRBM(nodes, edges) + + sampler = DimodSampler( + grbm, + SimulatedAnnealingSampler(), + prefactor=1.0, + sample_kwargs=dict(num_reads=1) + ) + x = torch.tensor([ + [float("nan"), float("nan"), -1.0], + ]) + + out = sampler.sample(x) + ss = sampler.sample_set + + # Check alignment with sampleset_to_tensor + expected_samples = sampleset_to_tensor(nodes, ss, device=out.device) + + assert torch.allclose(out, expected_samples) + if __name__ == "__main__": unittest.main()