Conversation
kevinchern
left a comment
Added a couple minor requests
```python
        torch.Tensor: A tensor of shape (num_chains, n_nodes) of +/-1 values sampled from the model.
    """
    if x is not None:
        mask = self._validate_input_and_generate_mask(x)
```
Suggested change:
```diff
-        mask = self._validate_input_and_generate_mask(x)
+        self._validate_input(x)
+        mask = ~torch.isnan(x)
```
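For reference, a minimal pure-Python sketch of the separation this suggestion proposes, with a nested list standing in for the tensor and `math.isnan` for `torch.isnan` (the function name and the validation rule used here are illustrative assumptions, not the package's actual behavior):

```python
import math

def validate_input(x):
    # Validation only raises; it returns nothing. Mask construction
    # happens separately at the call site (illustrative rule: every
    # chain must have at least one observed, i.e. non-NaN, entry).
    for row in x:
        if all(math.isnan(v) for v in row):
            raise ValueError("each chain needs at least one observed entry")

x = [[1.0, float("nan")], [float("nan"), -1.0]]
validate_input(x)
# True = clamped (observed in x), False = to be sampled (NaN in x),
# mirroring mask = ~torch.isnan(x) in the suggestion above.
mask = [[not math.isnan(v) for v in row] for row in x]
print(mask)  # [[True, False], [False, True]]
```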
```python
    h = self._grbm.hidden_idx
    self._x[:, h] = torch.where(mask[:, h], x[:, h], self._x[:, h])

def _validate_input_and_generate_mask(self, x: torch.Tensor) -> torch.Tensor:
```
Suggested change:
```diff
-    def _validate_input_and_generate_mask(self, x: torch.Tensor) -> torch.Tensor:
+    def _validate_input(self, x: torch.Tensor) -> None:
```
```python
    self._x[:, h] = torch.where(mask[:, h], x[:, h], self._x[:, h])

def _validate_input_and_generate_mask(self, x: torch.Tensor) -> torch.Tensor:
    """Validate conditional sampling input and construct a boolean mask.
```
| """Validate conditional sampling input and construct a boolean mask. | |
| """Validate conditional sampling input. |
```python
    Returns:
        torch.Tensor: Boolean mask of shape ``(num_chains, n_nodes)`` where
            ``True`` indicates clamped variables (observed in ``x``) and
            ``False`` indicates variables that should be sampled (``NaN`` in x).
```
Suggested change (remove the ``Returns`` section):
```diff
-    Returns:
-        torch.Tensor: Boolean mask of shape ``(num_chains, n_nodes)`` where
-            ``True`` indicates clamped variables (observed in ``x``) and
-            ``False`` indicates variables that should be sampled (``NaN`` in x).
```
| "The input must be unclamped for visible or hidden but not both." | ||
| ) | ||
|
|
||
| return mask |
Suggested change (drop the return):
```diff
-    return mask
```
```python
    Args:
        x (torch.Tensor): A tensor of shape (``num_chains``, ``dim``) or (``num_chains``, ``n_nodes``)
            interpreted as a batch of partially-observed spins. Entries marked with ``torch.nan`` will
```
Suggested change:
```diff
-            interpreted as a batch of partially-observed spins. Entries marked with ``torch.nan`` will
+            interpreted as a batch of partially observed spins. Entries marked with ``torch.nan`` will
```
```python
    if mask is not None:
        self._x[:, block] = torch.where(mask[:, block], x[:, block], self._x[:, block])

def _validate_input_and_generate_mask(self, x: torch.Tensor) -> torch.Tensor:
```
Same suggestion here as in bipartite sampler (docstring, type hints, returns, and defining mask outside)
thisac
left a comment
I recall we talked about this, but it seems like BipartiteGibbsSampler and BlockSampler could share a lot of methods and do with some deduplication. If they're not general enough to fit into TorchSampler, there should either be a hierarchy between them or another common class that they inherit from, or, especially if you foresee some of these methods being used in other samplers, you could create one (or several) mixin classes.
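A rough sketch of the mixin option, with plain-Python stand-ins for the tensor machinery; every class and method name here is hypothetical, not the package's actual API:

```python
import math

class ConditionalInputMixin:
    """Shared conditional-sampling helpers (hypothetical names)."""

    def _validate_input(self, x):
        # x: nested list of floats standing in for a (num_chains,
        # n_nodes) tensor; NaN marks entries to be sampled.
        for row in x:
            if all(math.isnan(v) for v in row):
                raise ValueError("each chain needs at least one observed entry")

    def _generate_mask(self, x):
        # True = clamped (observed), False = to be sampled.
        return [[not math.isnan(v) for v in row] for row in x]

# Both samplers pick up the shared helpers without duplicating them.
class BipartiteGibbsSampler(ConditionalInputMixin):
    pass

class BlockSampler(ConditionalInputMixin):
    pass

mask = BlockSampler()._generate_mask([[1.0, float("nan")]])
print(mask)  # [[True, False]]
```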
tests/test_block_sampler.py
Outdated
```python
    grbm = GRBM(nodes, edges, hidden_nodes=["h1", "h2"])

    def crayon(n):
        return 0 if n in ["v1", "v2"] else 1
```
Suggested change:
```diff
-        return 0 if n in ["v1", "v2"] else 1
+        return n in ["v1", "v2"]
```
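One side observation on this suggestion (not from the review itself): since `bool` subclasses `int` in Python, the boolean works anywhere the block label is used as 0/1, but it inverts the original mapping, because `n in ["v1", "v2"]` is `True` (1) exactly where the original returned 0. If the two block labels only need to be distinct, the swap is harmless:

```python
def crayon_old(n):
    return 0 if n in ["v1", "v2"] else 1

def crayon_new(n):
    return n in ["v1", "v2"]

# bool is an int subclass, so it compares and indexes like 0 or 1 ...
assert crayon_new("v1") == 1 and isinstance(crayon_new("v1"), int)
# ... but the labels are flipped relative to the original mapping.
assert crayon_old("v1") == 0
assert crayon_old("h1") == 1 and crayon_new("h1") == 0
```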
```python
    sampler = BlockSampler(grbm, crayon, num_chains=2, schedule=[1.0])

    # Case 1: Valid single block unclamped
```
Can each of these test cases be wrapped in individual subtests?
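For illustration, the subTest pattern being requested might look like the following (the case names and the placeholder assertion are made up; the real checks would come from the test under review):

```python
import unittest

class TestConditionalSampling(unittest.TestCase):
    def test_cases(self):
        # Each named case runs inside its own subTest, so a failing
        # case is reported individually instead of aborting the rest.
        cases = {
            "valid single block unclamped": [1.0, float("nan")],
            "fully observed input": [1.0, -1.0],
        }
        for name, x in cases.items():
            with self.subTest(name):
                self.assertEqual(len(x), 2)  # stand-in for the real checks

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestConditionalSampling)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print(result.wasSuccessful())  # True
```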
```python
    if mask is not None:
        self._x[:, block] = torch.where(mask[:, block], x[:, block], self._x[:, block])

def _validate_input_and_generate_mask(self, x: torch.Tensor) -> torch.Tensor:
```
Bumping this suggestion to separate validation and mask generation
tests/test_block_sampler.py
Outdated
```python
    self.assertEqual(mask.shape, x_valid.shape)

    # Chain 0: visible unclamped
    self.assertTrue(mask[0, 2:].all())  # First chain: hidden spins are clamped

    # Chain 1: hidden unclamped
    self.assertTrue(mask[1, :2].all())  # Second chain: visible spins are clamped
```
If we keep the signature of validate_and_generate..., can we combine these tests into one where the mask is hard-coded, like `expected_mask = torch.tensor([[False, ...], [...]])`?
e.g., `torch.testing.assert_close(mask, expected_mask)` or `self.assertListEqual(mask.tolist(), expected_mask.tolist())`
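Concretely, the hard-coded comparison could look like this sketch, with plain lists standing in for tensors (shapes and values here are illustrative only); with real tensors, `torch.testing.assert_close(mask, expected_mask)` or `self.assertListEqual(mask.tolist(), expected_mask.tolist())` would play the same role:

```python
import math

x = [
    [float("nan"), float("nan"), 1.0, -1.0],  # chain 0: visible unclamped
    [1.0, -1.0, float("nan"), float("nan")],  # chain 1: hidden unclamped
]
# True = clamped (observed), False = to be sampled.
mask = [[not math.isnan(v) for v in row] for row in x]
expected_mask = [
    [False, False, True, True],
    [True, True, False, False],
]
assert mask == expected_mask  # one comparison covers both chains at once
```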
tests/test_bipartite_sampler.py
Outdated
```python
    # Gibbs update for hidden block (block=1)
    with self.subTest("hidden block Gibbs update"):
        sampler._gibbs_update(0.0, hidden_block, ones*zero_field)
        torch.testing.assert_close(torch.tensor(0.0), sampler._x.mean(), atol=1e-2, rtol=1e-2)
```
Why does this one have a looser tolerance 1e-2 than the previous 1e-3?
Yeah, I noticed that there are fewer random variables in this test compared to the one above, so the estimate has higher variance and I needed a looser tolerance (1e-2). I could avoid this by setting `sampler._x.data[:] = 1.0` just like in the earlier example; I can update the test if you'd prefer.
I might be missing something, but don't both tests use `sampler._x.mean()`, so the sample size should be the same(?)
tests/test_bipartite_sampler.py
Outdated
```python
    def test_sample_conditional(self):
        nodes = ["v1", "v2", "h1", "h2"]
        edges = [["v1", "h1"], ["v1", "h2"], ["v2", "h1"], ["v2", "h2"]]
        grbm = GRBM(nodes, edges, hidden_nodes=["h1", "h2"])
```
Consider setting the linear fields to be very large so the result is ~deterministic, then hard-coding the expected results per conditional sampling step in the test cases, e.g.

```python
grbm.linear.data[:] = 99999999999999
sampler._x.data[:] = 1
```

In one conditional step, everything but the clamped states should become -1.
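The large-field trick works because the single-spin conditional saturates. A quick sketch, assuming the usual +/-1-spin Gibbs form p(s = +1) = 1 / (1 + exp(-2·beta·field)); the actual sign convention (whether a large positive field pins spins to +1 or -1) depends on the model's energy definition:

```python
import math

def p_plus(beta, field):
    # Probability that a +/-1 spin takes value +1 under its local
    # effective field, for the conventional Gibbs conditional.
    return 1.0 / (1.0 + math.exp(-2.0 * beta * field))

# With |field| huge the update is effectively deterministic, so the
# test's expected states can be hard-coded per conditional step.
assert p_plus(1.0, 50.0) > 1.0 - 1e-12   # spin pinned to +1
assert p_plus(1.0, -50.0) < 1e-12        # spin pinned to -1
```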
```
Add conditional sampling functionality for the ``BlockSampler``.
-
Add ``.clone()`` to the return of ``BlockSampler.sample`` to prevent
unintended in-place modification of the sampler’s internal state due to
```
Suggested change:
```diff
-unintended in-place modification of the sampler’s internal state due to
+unintended in-place modification of the sampler's internal state due to
```
This PR adds: