diff --git a/dwave/plugins/torch/samplers/dimod_sampler.py b/dwave/plugins/torch/samplers/dimod_sampler.py index 1e2c992..688a6d7 100644 --- a/dwave/plugins/torch/samplers/dimod_sampler.py +++ b/dwave/plugins/torch/samplers/dimod_sampler.py @@ -16,8 +16,9 @@ from typing import TYPE_CHECKING, Any import torch -from dimod import Sampler -from hybrid.composers import AggregatedSamples +import dimod +import warnings +from hybrid.composers import AggregatedSamples from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine from dwave.plugins.torch.samplers.base import TorchSampler @@ -86,25 +87,113 @@ def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: """Sample from the dimod sampler and return the corresponding tensor. The sample set returned from the latest sample call is stored in :func:`DimodSampler.sample_set` - which is overwritten by subsequent calls. + which is overwritten by subsequent calls. When ``x`` is provided (conditional sampling), exactly + one sample is drawn per input row. Any user-specified ``num_reads`` in ``sample_kwargs`` will be + ignored and overridden to 1. Args: x (torch.Tensor): A tensor of shape (``batch_size``, ``dim``) or (``batch_size``, ``n_nodes``) interpreted as a batch of partially-observed spins. Entries marked with ``torch.nan`` will be sampled; entries with +/-1 values will remain constant. + Raises: + ValueError: If ``x`` has an invalid shape or contains values other than ±1 or NaN. + + Returns: + torch.Tensor: A tensor of shape (``batch_size``, ``n_nodes``) containing + sampled spin configurations with values in ``{-1, +1}``. 
""" - if x is not None: - raise NotImplementedError("Support for conditional sampling has not been implemented.") + device = self._grbm._linear.device + nodes = self._grbm.nodes + n_nodes = self._grbm.n_nodes h, J = self._grbm.to_ising(self._prefactor, self._linear_range, self._quadratic_range) - self._sample_set = AggregatedSamples.spread( - self._sampler.sample_ising(h, J, **self._sampler_params) + + # Unconditional sampling + if x is None: + self._sample_set = AggregatedSamples.spread( + self._sampler.sample_ising(h, J, **self._sampler_params) + ) + return sampleset_to_tensor(nodes, self._sample_set, device) + + # Conditional sampling + if x.shape[1] != n_nodes: + raise ValueError(f"x must have shape (batch_size, {n_nodes})") + + mask = ~torch.isnan(x) + if not torch.all(torch.isin(x[mask], torch.tensor([-1, 1], device=x.device))): + raise ValueError("x must contain only ±1 or NaN") + + results = [] + for row, row_mask in zip(x, mask): + # Fresh BQM + bqm = dimod.BinaryQuadraticModel.from_ising(h, J) + + # Build conditioning dict + conditioned = {node: int(val.item()) + for node, val, m in zip(nodes, row, row_mask) if m + } + + # Apply conditioning + if conditioned: + bqm.fix_variables(conditioned) + + # Handle fully clamped case + if bqm.num_variables == 0: + full = torch.tensor([conditioned[node] for node in nodes], + device=device, dtype=torch.float) + results.append(full) + continue + + # Clip linear biases for remaining free variables + if self._linear_range is not None: + lb, ub = self._linear_range + for v in bqm.linear: + if bqm.linear[v] > ub: + bqm.set_linear(v, ub) + elif bqm.linear[v] < lb: + bqm.set_linear(v, lb) + + # Clip quadratic biases + if self._quadratic_range is not None: + lb, ub = self._quadratic_range + for u, v, bias in bqm.iter_quadratic(): + if bias > ub: + bqm.set_quadratic(u, v, ub) + elif bias < lb: + bqm.set_quadratic(u, v, lb) + + # Sample one configuration per input + sample_kwargs = dict(self._sampler_params) + if "num_reads" in 
sample_kwargs and sample_kwargs["num_reads"] != 1: + warnings.warn( + "`num_reads` is ignored during conditional sampling and set to 1 " + "(one sample per input row).", + UserWarning, + ) + sample_kwargs["num_reads"] = 1 + + sample_set = self._sampler.sample(bqm, **sample_kwargs) + + # Extract sampled values + sample = sample_set.first.sample + + # Reconstruct full sample + full = torch.empty(n_nodes, device=device) + for j, node in enumerate(nodes): + full[j] = conditioned[node] if node in conditioned else float(sample[node]) + results.append(full) + + # Stack to get (batch_size, n_nodes) + samples = torch.stack(results, dim=0) + energies = self._grbm(samples).detach().cpu().numpy() + + self._sample_set = dimod.SampleSet.from_samples( + (samples.cpu().numpy(), self._grbm.nodes), + vartype=dimod.SPIN, + energy=energies, ) - - # use same device as modules linear - device = self._grbm._linear.device - return sampleset_to_tensor(self._grbm.nodes, self._sample_set, device) - + return samples + @property def sample_set(self) -> dimod.SampleSet: """The sample set returned from the latest sample call.""" diff --git a/releasenotes/notes/dimod-conditional-sampling-3ce98d4847eedb83.yaml b/releasenotes/notes/dimod-conditional-sampling-3ce98d4847eedb83.yaml new file mode 100644 index 0000000..347dc9c --- /dev/null +++ b/releasenotes/notes/dimod-conditional-sampling-3ce98d4847eedb83.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Added conditional sampling functionality for the ``DimodSampler``. The sampler + enforces one sample per input row during conditional sampling. 
diff --git a/tests/test_samplers/test_dimod_sampler.py b/tests/test_samplers/test_dimod_sampler.py index 59e7b88..8adec34 100644 --- a/tests/test_samplers/test_dimod_sampler.py +++ b/tests/test_samplers/test_dimod_sampler.py @@ -15,14 +15,13 @@ import unittest import torch -from dimod import SPIN, BinaryQuadraticModel, IdentitySampler, SampleSet, TrackingComposite -from parameterized import parameterized +from dimod import IdentitySampler, SampleSet, TrackingComposite from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM from dwave.plugins.torch.samplers.dimod_sampler import DimodSampler from dwave.samplers import SteepestDescentSampler -from dwave.system.temperatures import maximum_pseudolikelihood_temperature as mple - +from dwave.samplers import SimulatedAnnealingSampler +from dwave.plugins.torch.utils import sampleset_to_tensor class TestDimodSampler(unittest.TestCase): def setUp(self) -> None: @@ -125,6 +124,51 @@ def test_sample(self): torch.tensor(list(tracker.input['J'].values())), torch.tensor([0, 0, 0, 0.0]) ) + + with self.subTest("Conditional sampling preserves clamped variables"): + sampler = DimodSampler( + self.bm, + SimulatedAnnealingSampler(), + prefactor=1, + sample_kwargs=dict(num_reads=1) + ) + + x = torch.tensor([ + [1.0, float("nan"), -1.0, float("nan")], + [float("nan"), -1.0, float("nan"), 1.0], + ]) + + samples = sampler.sample(x) + + # Shape check + self.assertTupleEqual(samples.shape, x.shape) + + # Check clamped values unchanged + mask = ~torch.isnan(x) + self.assertTrue(torch.all(samples[mask] == x[mask])) + + # Check free variables are ±1 + free_mask = torch.isnan(x) + free_values = samples[free_mask] + self.assertTrue(torch.all(torch.isin(free_values, torch.tensor([-1.0, 1.0]))), + "Free variables should be sampled as ±1") + + with self.subTest("Conditional sampling with all variables clamped returns input unchanged."): + sampler = DimodSampler( + self.bm, + SimulatedAnnealingSampler(), + 
prefactor=1, + sample_kwargs=dict(num_reads=1) + ) + + x = torch.tensor([ + [+1.0, -1.0, -1.0, +1.0], + [-1.0, +1.0, -1.0, -1.0], + ]) + + samples = sampler.sample(x) + # All spins clamped, should return identical tensor + torch.testing.assert_close(samples, x) def test_sample_set(self): grbm = GRBM(list("abcd"), [("a", "b")]) @@ -142,6 +186,31 @@ def test_sample_set(self): with self.subTest("The `sample_set` attribute should be of type `dimod.SampleSet`."): self.assertTrue(isinstance(sampler.sample_set, SampleSet)) - + def test_order_sample(self): + + nodes = ["b", "a", "c"] + edges = [("a", "b"), ("b", "c")] + + grbm = GRBM(nodes, edges) + + sampler = DimodSampler( + grbm, + SimulatedAnnealingSampler(), + prefactor=1.0, + sample_kwargs=dict(num_reads=1) + ) + + x = torch.tensor([ + [float("nan"), float("nan"), -1.0], + ]) + + out = sampler.sample(x) + ss = sampler.sample_set + + # Check alignment with sampleset_to_tensor + expected_samples = sampleset_to_tensor(nodes, ss, device=out.device) + + assert torch.allclose(out, expected_samples) + if __name__ == "__main__": unittest.main()