Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 77 additions & 11 deletions dwave/plugins/torch/samplers/dimod_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import TYPE_CHECKING, Any

import torch
from dimod import Sampler
import dimod
from hybrid.composers import AggregatedSamples

from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine
Expand Down Expand Up @@ -93,18 +93,84 @@ def sample(self, x: torch.Tensor | None = None) -> torch.Tensor:
interpreted as a batch of partially-observed spins. Entries marked with ``torch.nan`` will
be sampled; entries with +/-1 values will remain constant.
"""
if x is not None:
raise NotImplementedError("Support for conditional sampling has not been implemented.")

h, J = self._grbm.to_ising(self._prefactor, self._linear_range, self._quadratic_range)
self._sample_set = AggregatedSamples.spread(
self._sampler.sample_ising(h, J, **self._sampler_params)
)

# use same device as modules linear
device = self._grbm._linear.device
return sampleset_to_tensor(self._grbm.nodes, self._sample_set, device)
nodes = self._grbm.nodes
n_nodes = self._grbm.n_nodes

h, J = self._grbm.to_ising(self._prefactor, self._linear_range, self._quadratic_range)

# Unconditional sampling
if x is None:
self._sample_set = AggregatedSamples.spread(
self._sampler.sample_ising(h, J, **self._sampler_params)
)
return sampleset_to_tensor(nodes, self._sample_set, device)

# Conditional sampling
if x.shape[1] != n_nodes:
raise ValueError(f"x must have shape (batch_size, {n_nodes})")

mask = ~torch.isnan(x)
if not torch.all(torch.isin(x[mask], torch.tensor([-1, 1], device=x.device))):
raise ValueError("x must contain only ±1 or NaN")

results = []
for i in range(x.shape[0]):
# Fresh BQM
bqm = dimod.BinaryQuadraticModel.from_ising(h, J)

# Build conditioning dict
conditioned = {node: int(x[i, j].item())
for j, node in enumerate(nodes) if mask[i, j]}
Comment on lines +118 to +124
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we iterate over x instead of range(x.shape[0])? Seems more natural


# Apply conditioning
if conditioned:
bqm.fix_variables(conditioned)

# Clip linear biases for remaining free variables
if self._linear_range is not None:
lb, ub = self._linear_range
for v in bqm.linear:
if bqm.linear[v] > ub:
bqm.set_linear(v, ub)
elif bqm.linear[v] < lb:
bqm.set_linear(v, lb)

# Clip quadratic biases
if self._quadratic_range is not None:
lb, ub = self._quadratic_range
for u, v, bias in bqm.iter_quadratic():
if bias > ub:
bqm.set_quadratic(u, v, ub)
elif bias < lb:
bqm.set_quadratic(u, v, lb)

# Handle fully clamped case
if bqm.num_variables == 0:
full = torch.tensor([conditioned[node] for node in nodes],
device=device, dtype=torch.float)
results.append(full)
continue
Comment on lines +149 to +153
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd move this up the for-loop to avoid unnecessary computations; continue sooner when possible


# Sample one configuration per input
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be documented in the docstring. It should also raise a warning if num_reads is supplied and overwritten.

sample_kwargs = dict(self._sampler_params)
sample_kwargs["num_reads"] = 1
sample_set = AggregatedSamples.spread(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the motivation for this line?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh you mean calling the spread? I was following the earlier use case. You're right, it's redundant. I'll remove it then.

self._sampler.sample(bqm, **sample_kwargs)
)

# Extract sampled values
sample = sample_set.first.sample

# Reconstruct full sample
full = torch.empty(n_nodes, device=device)
for j, node in enumerate(nodes):
full[j] = conditioned[node] if node in conditioned else float(sample[node])
results.append(full)
Comment on lines +163 to +169
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sample's variable ordering may not be identical to that of grbm.nodes.
I'd add a test to verify correct ordering first.


# Stack to get (batch_size, n_nodes)
return torch.stack(results, dim=0)

@property
def sample_set(self) -> dimod.SampleSet:
"""The sample set returned from the latest sample call."""
Expand Down
47 changes: 46 additions & 1 deletion tests/test_samplers/test_dimod_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from dwave.plugins.torch.samplers.dimod_sampler import DimodSampler
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from dwave.samplers import SteepestDescentSampler
from dwave.system.temperatures import maximum_pseudolikelihood_temperature as mple

from dwave.samplers import SimulatedAnnealingSampler

class TestDimodSampler(unittest.TestCase):
def setUp(self) -> None:
Expand Down Expand Up @@ -125,6 +125,51 @@ def test_sample(self):
torch.tensor(list(tracker.input['J'].values())),
torch.tensor([0, 0, 0, 0.0])
)

with self.subTest("Conditional sampling preserves clamped variables"):
sampler = DimodSampler(
self.bm,
SimulatedAnnealingSampler(),
prefactor=1,
sample_kwargs=dict(num_reads=1)
)

x = torch.tensor([
[1.0, float("nan"), -1.0, float("nan")],
[float("nan"), -1.0, float("nan"), 1.0],
])

samples = sampler.sample(x)

# Shape check
self.assertTupleEqual(samples.shape, x.shape)

# Check clamped values unchanged
mask = ~torch.isnan(x)
self.assertTrue(torch.all(samples[mask] == x[mask]))

# Check free variables are ±1
free_mask = torch.isnan(x)
free_values = samples[free_mask]
self.assertTrue(torch.all(torch.isin(free_values, torch.tensor([-1.0, 1.0]))),
"Free variables should be sampled as ±1")

with self.subTest("Conditional sampling with all variables clamped returns input unchanged."):
sampler = DimodSampler(
self.bm,
SimulatedAnnealingSampler(),
prefactor=1,
sample_kwargs=dict(num_reads=1)
)

x = torch.tensor([
[+1.0, -1.0, -1.0, +1.0],
[-1.0, +1.0, -1.0, -1.0],
])

samples = sampler.sample(x)
# All spins clamped, should return identical tensor
torch.testing.assert_close(samples, x)

def test_sample_set(self):
grbm = GRBM(list("abcd"), [("a", "b")])
Expand Down