Conversation
kevinchern left a comment
Did a quick first pass over the main function.
```python
super().__init__()

def sample(self, x: torch.Tensor | None = None) -> torch.Tensor:
    """Sample from the dimod sampler and return the corresponding tensor.
```
Update the docstring to reflect the behaviour of conditional sampling.
```python
for i in range(x.shape[0]):
    # Fresh BQM
    bqm = dimod.BinaryQuadraticModel.from_ising(h, J)

    # Build conditioning dict
    conditioned = {node: int(x[i, j].item())
                   for j, node in enumerate(nodes) if mask[i, j]}
```
Can we iterate over x directly instead of over range(x.shape[0])? It seems more natural.
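The suggestion could look like the sketch below, iterating over rows of x and mask in lockstep. The tensors and node names here are hypothetical stand-ins for the ones in the quoted snippet:

```python
import torch

# Hypothetical stand-ins for the tensors in the snippet above
x = torch.tensor([[1, -1], [-1, 1]])
mask = torch.tensor([[True, False], [False, True]])
nodes = ["a", "b"]

# Iterate over rows of x and mask directly instead of indexing
# with range(x.shape[0])
for xi, mi in zip(x, mask):
    conditioned = {node: int(v.item())
                   for node, v, m in zip(nodes, xi, mi) if m}
```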
```python
if bqm.num_variables == 0:
    full = torch.tensor([conditioned[node] for node in nodes],
                        device=device, dtype=torch.float)
    results.append(full)
    continue
```
I'd move this up in the for-loop to avoid unnecessary computation; continue as early as possible.
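One way to continue earlier: build the conditioning dict first and skip BQM construction entirely when every variable is clamped. This is a sketch with hypothetical inputs; the real code fixes variables into the BQM and checks bqm.num_variables, which should be equivalent when each masked node is fixed:

```python
import torch

nodes = ["a", "b"]
x = torch.tensor([[1, -1]])          # hypothetical input row
mask = torch.tensor([[True, True]])  # every variable clamped
results = []

for xi, mi in zip(x, mask):
    conditioned = {node: int(v.item())
                   for node, v, m in zip(nodes, xi, mi) if m}
    # Continue as early as possible: nothing is left to sample,
    # so skip building the BQM altogether.
    if len(conditioned) == len(nodes):
        results.append(torch.tensor([conditioned[n] for n in nodes],
                                    dtype=torch.float))
        continue
    # ... otherwise build the BQM and sample the unclamped variables ...
```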
```python
results.append(full)
continue

# Sample one configuration per input
```
This needs to be documented in the docstring. It should also raise a warning if num_reads is supplied and overwritten.
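A possible shape for that warning; sampler_params here is a hypothetical stand-in for self._sampler_params:

```python
import warnings

sampler_params = {"num_reads": 100}  # hypothetical user-supplied parameters

sample_kwargs = dict(sampler_params)
if sample_kwargs.get("num_reads", 1) != 1:
    # Warn the user that their num_reads setting has no effect here
    warnings.warn(
        f"num_reads={sample_kwargs['num_reads']} is ignored; conditional "
        "sampling draws one configuration per input (num_reads=1)."
    )
sample_kwargs["num_reads"] = 1  # always overwritten
```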
```python
# Sample one configuration per input
sample_kwargs = dict(self._sampler_params)
sample_kwargs["num_reads"] = 1
sample_set = AggregatedSamples.spread(
```
What's the motivation for this line?
Oh, you mean calling spread? I was following the earlier use case. You're right, it's redundant; I'll remove it.
```python
sample = sample_set.first.sample

# Reconstruct full sample
full = torch.empty(n_nodes, device=device)
for j, node in enumerate(nodes):
    full[j] = conditioned[node] if node in conditioned else float(sample[node])
results.append(full)
```
sample's variable ordering may not be identical to that of grbm.nodes.
I'd first add a test verifying the ordering is handled correctly.
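A minimal test could isolate the reconstruction loop; since it looks variables up by name rather than by position, it should be insensitive to the sampler's ordering. The reconstruct helper below is hypothetical, mirroring the quoted loop:

```python
import torch

def reconstruct(nodes, conditioned, sample, device="cpu"):
    # Hypothetical helper mirroring the reconstruction loop above
    full = torch.empty(len(nodes), device=device)
    for j, node in enumerate(nodes):
        full[j] = conditioned[node] if node in conditioned else float(sample[node])
    return full

def test_ordering_independent():
    nodes = ["a", "b", "c"]
    conditioned = {"a": 1}
    # The sampler's dict may come back in a different order than nodes
    sample = {"c": -1.0, "b": 1.0}
    full = reconstruct(nodes, conditioned, sample)
    assert full.tolist() == [1.0, 1.0, -1.0]

test_ordering_independent()
```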
```diff
@@ -22,7 +22,7 @@
 from dwave.plugins.torch.samplers.dimod_sampler import DimodSampler
```
This PR adds: