-
Notifications
You must be signed in to change notification settings - Fork 18
Updating Dimod Sampler #73
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,7 @@ | |
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| import torch | ||
| from dimod import Sampler | ||
| import dimod | ||
| from hybrid.composers import AggregatedSamples | ||
|
|
||
| from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine | ||
|
|
@@ -93,18 +93,84 @@ def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: | |
| interpreted as a batch of partially-observed spins. Entries marked with ``torch.nan`` will | ||
| be sampled; entries with +/-1 values will remain constant. | ||
| """ | ||
| if x is not None: | ||
| raise NotImplementedError("Support for conditional sampling has not been implemented.") | ||
|
|
||
| h, J = self._grbm.to_ising(self._prefactor, self._linear_range, self._quadratic_range) | ||
| self._sample_set = AggregatedSamples.spread( | ||
| self._sampler.sample_ising(h, J, **self._sampler_params) | ||
| ) | ||
|
|
||
| # use same device as modules linear | ||
| device = self._grbm._linear.device | ||
| return sampleset_to_tensor(self._grbm.nodes, self._sample_set, device) | ||
| nodes = self._grbm.nodes | ||
| n_nodes = self._grbm.n_nodes | ||
|
|
||
| h, J = self._grbm.to_ising(self._prefactor, self._linear_range, self._quadratic_range) | ||
|
|
||
| # Unconditional sampling | ||
| if x is None: | ||
| self._sample_set = AggregatedSamples.spread( | ||
| self._sampler.sample_ising(h, J, **self._sampler_params) | ||
| ) | ||
| return sampleset_to_tensor(nodes, self._sample_set, device) | ||
|
|
||
| # Conditional sampling | ||
| if x.shape[1] != n_nodes: | ||
| raise ValueError(f"x must have shape (batch_size, {n_nodes})") | ||
|
|
||
| mask = ~torch.isnan(x) | ||
| if not torch.all(torch.isin(x[mask], torch.tensor([-1, 1], device=x.device))): | ||
| raise ValueError("x must contain only ±1 or NaN") | ||
|
|
||
| results = [] | ||
| for i in range(x.shape[0]): | ||
| # Fresh BQM | ||
| bqm = dimod.BinaryQuadraticModel.from_ising(h, J) | ||
|
|
||
| # Build conditioning dict | ||
| conditioned = {node: int(x[i, j].item()) | ||
| for j, node in enumerate(nodes) if mask[i, j]} | ||
|
|
||
| # Apply conditioning | ||
| if conditioned: | ||
| bqm.fix_variables(conditioned) | ||
|
|
||
| # Clip linear biases for remaining free variables | ||
| if self._linear_range is not None: | ||
| lb, ub = self._linear_range | ||
| for v in bqm.linear: | ||
| if bqm.linear[v] > ub: | ||
| bqm.set_linear(v, ub) | ||
| elif bqm.linear[v] < lb: | ||
| bqm.set_linear(v, lb) | ||
|
|
||
| # Clip quadratic biases | ||
| if self._quadratic_range is not None: | ||
| lb, ub = self._quadratic_range | ||
| for u, v, bias in bqm.iter_quadratic(): | ||
| if bias > ub: | ||
| bqm.set_quadratic(u, v, ub) | ||
| elif bias < lb: | ||
| bqm.set_quadratic(u, v, lb) | ||
|
|
||
| # Handle fully clamped case | ||
| if bqm.num_variables == 0: | ||
| full = torch.tensor([conditioned[node] for node in nodes], | ||
| device=device, dtype=torch.float) | ||
| results.append(full) | ||
| continue | ||
|
Comment on lines
+149
to
+153
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'd move this up in the for-loop to avoid unnecessary computations; continue sooner when possible |
||
|
|
||
| # Sample one configuration per input | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This needs to be documented in the docstring. It should also raise a warning if |
||
| sample_kwargs = dict(self._sampler_params) | ||
| sample_kwargs["num_reads"] = 1 | ||
| sample_set = AggregatedSamples.spread( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What's the motivation for this line?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oh you mean calling the |
||
| self._sampler.sample(bqm, **sample_kwargs) | ||
| ) | ||
|
|
||
| # Extract sampled values | ||
| sample = sample_set.first.sample | ||
|
|
||
| # Reconstruct full sample | ||
| full = torch.empty(n_nodes, device=device) | ||
| for j, node in enumerate(nodes): | ||
| full[j] = conditioned[node] if node in conditioned else float(sample[node]) | ||
| results.append(full) | ||
|
Comment on lines
+163
to
+169
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| # Stack to get (batch_size, n_nodes) | ||
| return torch.stack(results, dim=0) | ||
|
|
||
| @property | ||
| def sample_set(self) -> dimod.SampleSet: | ||
| """The sample set returned from the latest sample call.""" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,7 @@ | |
| from dwave.plugins.torch.samplers.dimod_sampler import DimodSampler | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| from dwave.samplers import SteepestDescentSampler | ||
| from dwave.system.temperatures import maximum_pseudolikelihood_temperature as mple | ||
|
|
||
| from dwave.samplers import SimulatedAnnealingSampler | ||
|
|
||
| class TestDimodSampler(unittest.TestCase): | ||
| def setUp(self) -> None: | ||
|
|
@@ -125,6 +125,51 @@ def test_sample(self): | |
| torch.tensor(list(tracker.input['J'].values())), | ||
| torch.tensor([0, 0, 0, 0.0]) | ||
| ) | ||
|
|
||
| with self.subTest("Conditional sampling preserves clamped variables"): | ||
| sampler = DimodSampler( | ||
| self.bm, | ||
| SimulatedAnnealingSampler(), | ||
| prefactor=1, | ||
| sample_kwargs=dict(num_reads=1) | ||
| ) | ||
|
|
||
| x = torch.tensor([ | ||
| [1.0, float("nan"), -1.0, float("nan")], | ||
| [float("nan"), -1.0, float("nan"), 1.0], | ||
| ]) | ||
|
|
||
| samples = sampler.sample(x) | ||
|
|
||
| # Shape check | ||
| self.assertTupleEqual(samples.shape, x.shape) | ||
|
|
||
| # Check clamped values unchanged | ||
| mask = ~torch.isnan(x) | ||
| self.assertTrue(torch.all(samples[mask] == x[mask])) | ||
|
|
||
| # Check free variables are ±1 | ||
| free_mask = torch.isnan(x) | ||
| free_values = samples[free_mask] | ||
| self.assertTrue(torch.all(torch.isin(free_values, torch.tensor([-1.0, 1.0]))), | ||
| "Free variables should be sampled as ±1") | ||
|
|
||
| with self.subTest("Conditional sampling with all variables clamped returns input unchanged."): | ||
| sampler = DimodSampler( | ||
| self.bm, | ||
| SimulatedAnnealingSampler(), | ||
| prefactor=1, | ||
| sample_kwargs=dict(num_reads=1) | ||
| ) | ||
|
|
||
| x = torch.tensor([ | ||
| [+1.0, -1.0, -1.0, +1.0], | ||
| [-1.0, +1.0, -1.0, -1.0], | ||
| ]) | ||
|
|
||
| samples = sampler.sample(x) | ||
| # All spins clamped, should return identical tensor | ||
| torch.testing.assert_close(samples, x) | ||
|
|
||
| def test_sample_set(self): | ||
| grbm = GRBM(list("abcd"), [("a", "b")]) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we iterate over
xinstead ofrange(x.shape[0])? Seems more natural