
Add conditional sampling#68

Open
anahitamansouri wants to merge 5 commits into dwavesystems:main from anahitamansouri:feature/conditional-sampling

Conversation

@anahitamansouri
Collaborator

This PR adds:

  1. Conditional sampling feature for block spin sampling.
  2. BipartiteSampler for sampling bipartite GRBMs.
  3. An example of using the BipartiteSampler.
  4. Tests for the new functionalities.

@anahitamansouri anahitamansouri self-assigned this Mar 9, 2026
@anahitamansouri anahitamansouri added the enhancement New feature or request label Mar 9, 2026
@anahitamansouri anahitamansouri marked this pull request as ready for review March 9, 2026 21:18
Collaborator

@kevinchern kevinchern left a comment


Excellent!!!! Did a quick first pass with minor requests. The implementation is clean and efficient, documentation is well written, and tests are thorough.
The implementation for DimodSampler is missing, but let's add that in a separate PR.

Collaborator

@kevinchern kevinchern left a comment


Added a couple minor requests

torch.Tensor: A tensor of shape (num_chains, n_nodes) of +/-1 values sampled from the model.
"""
if x is not None:
mask = self._validate_input_and_generate_mask(x)

Suggested change
- mask = self._validate_input_and_generate_mask(x)
+ self._validate_input(x)
+ mask = ~torch.isnan(x)

h = self._grbm.hidden_idx
self._x[:, h] = torch.where(mask[:, h], x[:, h], self._x[:, h])

def _validate_input_and_generate_mask(self, x: torch.Tensor) -> torch.Tensor:

Suggested change
- def _validate_input_and_generate_mask(self, x: torch.Tensor) -> torch.Tensor:
+ def _validate_input(self, x: torch.Tensor) -> None:

self._x[:, h] = torch.where(mask[:, h], x[:, h], self._x[:, h])

def _validate_input_and_generate_mask(self, x: torch.Tensor) -> torch.Tensor:
"""Validate conditional sampling input and construct a boolean mask.

Suggested change
- """Validate conditional sampling input and construct a boolean mask.
+ """Validate conditional sampling input.

Comment on lines +281 to +285

Returns:
torch.Tensor: Boolean mask of shape ``(num_chains, n_nodes)`` where
``True`` indicates clamped variables (observed in ``x``) and
``False`` indicates variables that should be sampled (``NaN`` in x).

Suggested change
- Returns:
-     torch.Tensor: Boolean mask of shape ``(num_chains, n_nodes)`` where
-     ``True`` indicates clamped variables (observed in ``x``) and
-     ``False`` indicates variables that should be sampled (``NaN`` in x).

"The input must be unclamped for visible or hidden but not both."
)

return mask

Suggested change
- return mask


Args:
x (torch.Tensor): A tensor of shape (``num_chains``, ``dim``) or (``num_chains``, ``n_nodes``)
interpreted as a batch of partially-observed spins. Entries marked with ``torch.nan`` will

Suggested change
- interpreted as a batch of partially-observed spins. Entries marked with ``torch.nan`` will
+ interpreted as a batch of partially observed spins. Entries marked with ``torch.nan`` will

if mask is not None:
self._x[:, block] = torch.where(mask[:, block], x[:, block], self._x[:, block])

def _validate_input_and_generate_mask(self, x: torch.Tensor) -> torch.Tensor:

Same suggestion here as in bipartite sampler (docstring, type hints, returns, and defining mask outside)

Contributor

@thisac thisac left a comment


I recall we talked about this, but it seems like BipartiteGibbsSampler and BlockSampler could share a lot of methods and do with some deduplication. If they're not general enough to fit into TorchSampler, there should either be a hierarchy between them or another common class that they inherit from, or, especially if you foresee some of these methods being used in other samplers, you could create one (or several) mixin classes.
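One way to realize the mixin idea is sketched below. The class and method names are hypothetical (not code from the PR), and the body of the shared helper is simplified to a plain-Python shape check; the point is only that both samplers inherit the conditional-sampling helpers from one place:

```python
class ConditionalInputMixin:
    """Hypothetical mixin collecting the conditional-sampling helpers that
    BipartiteGibbsSampler and BlockSampler would otherwise duplicate."""

    def _validate_input(self, x):
        # Shared validation; both samplers call this before sampling.
        if len(x) == 0 or any(len(row) != len(x[0]) for row in x):
            raise ValueError("x must be a non-empty rectangular batch")


class BlockSampler(ConditionalInputMixin):
    # Sampler-specific logic would go here; validation is inherited.
    pass


class BipartiteGibbsSampler(ConditionalInputMixin):
    pass
```

Because both classes resolve `_validate_input` to the same mixin function, a fix to the validation logic lands in both samplers at once, which is the deduplication benefit being suggested.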

Collaborator

@kevinchern kevinchern left a comment


Almost there!

grbm = GRBM(nodes, edges, hidden_nodes=["h1", "h2"])

def crayon(n):
return 0 if n in ["v1", "v2"] else 1

Suggested change
- return 0 if n in ["v1", "v2"] else 1
+ return n in ["v1", "v2"]


sampler = BlockSampler(grbm, crayon, num_chains=2, schedule=[1.0])

# Case 1: Valid single block unclamped

Can each of these test cases be wrapped in individual subtests?
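The subTest pattern being asked for can be sketched as follows (the case names and toy validation logic are illustrative, not the PR's actual tests); each case reports its own pass/fail without stopping the loop:

```python
import unittest


class TestConditionalInput(unittest.TestCase):
    """Sketch of wrapping each validation case in its own subTest, so one
    failing case does not hide the results of the others."""

    def test_validation_cases(self):
        # Illustrative (case name -> (input batch, expected validity)) table.
        cases = {
            "valid single block unclamped": ([[1.0, float("nan")]], True),
            "empty batch rejected": ([], False),
        }
        for name, (batch, is_valid) in cases.items():
            with self.subTest(name):
                # Toy check standing in for the sampler's real validation.
                if is_valid:
                    self.assertGreater(len(batch), 0)
                else:
                    self.assertEqual(len(batch), 0)


suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestConditionalInput)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

With subtests, a failure in one case is reported with its label (e.g. "valid single block unclamped") while the remaining cases still run.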

if mask is not None:
self._x[:, block] = torch.where(mask[:, block], x[:, block], self._x[:, block])

def _validate_input_and_generate_mask(self, x: torch.Tensor) -> torch.Tensor:

Bumping this suggestion to separate validation and mask generation

Comment on lines +289 to +295
self.assertEqual(mask.shape, x_valid.shape)

# Chain 0: visible unclamped
self.assertTrue(mask[0, 2:].all()) # First chain: hidden spins are clamped

# Chain 1: hidden unclamped
self.assertTrue(mask[1, :2].all()) # Second chain: visible spins are clamped

IF we keep the signature of validate_and_generate..., THEN can we combine these tests into one where the mask is hard-coded, like expected_mask = torch.tensor([[False, ...], [...]])?

e.g., torch.testing.assert_close(mask, expected_mask) or self.assertListEqual(mask.tolist(), expected_mask.tolist())
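The hard-coded-mask style can be sketched in plain Python (lists stand in for tensors; with torch one would compare via `assert_close` or `assertListEqual` as above). The input pattern follows the two chains described in the test: chain 0 leaves the visible spins unclamped, chain 1 the hidden ones:

```python
import math

nan = math.nan
x_valid = [
    [nan, nan, 1.0, -1.0],   # chain 0: visible entries NaN (to be sampled)
    [1.0, -1.0, nan, nan],   # chain 1: hidden entries NaN
]

# Mask is True where the entry is observed, the analogue of ~torch.isnan(x).
mask = [[not math.isnan(v) for v in row] for row in x_valid]

# Hard-coding the expectation makes the test self-documenting: one equality
# check replaces several per-slice assertions.
expected_mask = [
    [False, False, True, True],
    [True, True, False, False],
]
assert mask == expected_mask
```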

# Gibbs update for hidden block (block=1)
with self.subTest("hidden block Gibbs update"):
sampler._gibbs_update(0.0, hidden_block, ones*zero_field)
torch.testing.assert_close(torch.tensor(0.0), sampler._x.mean(), atol=1e-2, rtol=1e-2)

Why does this one have a looser tolerance 1e-2 than the previous 1e-3?

Collaborator Author

@anahitamansouri anahitamansouri Mar 24, 2026


Yeah, I noticed that there are fewer random variables in this test compared to the above one, so the estimate has higher variance, and I needed a looser tolerance (1e-2). I could avoid this by setting sampler._x.data[:] = 1.0 just like the earlier example. I can update the test if you think so.
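For intuition on the variance point: the standard error of the mean of N independent fair ±1 spins scales as 1/√N, so averaging over fewer free variables does widen the spread of the estimate. A quick empirical check (pure Python; the spin counts are illustrative, not taken from the PR's tests):

```python
import random
import statistics

random.seed(0)

def spread_of_mean(n_spins, trials=2000):
    """Empirical standard deviation of the mean of n_spins fair +/-1 spins,
    estimated over many independent trials."""
    means = [
        sum(random.choice((-1, 1)) for _ in range(n_spins)) / n_spins
        for _ in range(trials)
    ]
    return statistics.pstdev(means)

spread_16 = spread_of_mean(16)   # theory predicts about 1/sqrt(16) = 0.25
spread_64 = spread_of_mean(64)   # theory predicts about 1/sqrt(64) = 0.125
```

So if one test really does average over fewer free spins, a looser tolerance is the expected consequence rather than a bug.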


I might be missing something, but don't both tests use sampler._x.mean(), so the sample size should be the same(?)

def test_sample_conditional(self):
nodes = ["v1", "v2", "h1", "h2"]
edges = [["v1", "h1"], ["v1", "h2"], ["v2", "h1"], ["v2", "h2"]]
grbm = GRBM(nodes, edges, hidden_nodes=["h1", "h2"])

consider setting the linear fields to be very large so the result is ~= deterministic.
Then, in the test cases, hard-code the expected results per conditional sampling step.

e.g.,
grbm.linear.data[:] = 99999999999999
and sampler._x.data[:] = 1
in one conditional step, everything but the clamped states should become -1.

@kevinchern kevinchern self-requested a review March 26, 2026 18:40
Add conditional sampling functionality for the ``BlockSampler``.
- |
Add ``.clone()`` to the return of ``BlockSampler.sample`` to prevent
unintended in-place modification of the sampler’s internal state due to

Suggested change
- unintended in-place modification of the samplers internal state due to
+ unintended in-place modification of the sampler's internal state due to

@anahitamansouri anahitamansouri requested a review from thisac April 2, 2026 22:46

3 participants