Add maximum mean discrepancy and radial basis #35
kevinchern merged 12 commits into dwavesystems:main from
Conversation
@VolodyaCO what's the motivation for the following changes?
Addressed in meeting. For posterity:
I will review ASAP
kevinchern
left a comment
Thanks Vlad!! Nicely implemented and documented.
Let's add the MaximumMeanDiscrepancy as a module and this should be good to go.
thisac
left a comment
Thanks @VolodyaCO and @kevinchern! Tests fail due to Python 3.9 tests being run (removing 3.9 support in #49, which should fix it).
Unit tests should be expanded. There should at least be test classes for kernels (RBF) and more unit tests for the mmd function. Otherwise, looks good. Just a few, mostly minor, comments.
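To make the suggestion concrete, here is a sketch of the kind of unit tests that could be added for the RBF kernel; the inline `rbf` helper is a stand-in for the package's `RBFKernel`, so the names here are illustrative, not the merged implementation:

```python
import torch


def rbf(x, y, sigma=1.0):
    # Stand-in for RBFKernel: k(x, y) = exp(-||x - y||^2 / (2 sigma^2))
    return torch.exp(-torch.cdist(x, y) ** 2 / (2 * sigma**2))


def test_rbf_symmetry():
    # k(x, y) should equal k(y, x) transposed
    x, y = torch.randn(8, 4), torch.randn(8, 4)
    assert torch.allclose(rbf(x, y), rbf(y, x).T)


def test_rbf_self_similarity_is_one():
    # k(x_i, x_i) = exp(0) = 1 on the diagonal
    x = torch.randn(8, 4)
    assert torch.allclose(torch.diagonal(rbf(x, x)), torch.ones(8))


test_rbf_symmetry()
test_rbf_self_similarity_is_one()
```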
| __all__ = ["Kernel", "RBFKernel", "mmd_loss"] | ||
|
|
||
|
|
||
| class Kernel(nn.Module): |
Should Kernels (and RBF) be in a kernels.py instead of mmd.py?
For now, all kernels are for computing MMD losses. I don't know if that'll change in the future.
If we expect it might change, it seems better to put kernels somewhere else. Alternatively remove Kernel and RBFKernel from __all__ since they're only (?) used within this module.
It seems to me that a kernel is a more general concept and thus, if ever used outside of the mmd module, should live in, e.g., torch.model.kernels or torch.kernels, even if kernels are currently only used for calculating MMD losses. Even just putting them in a kernels.py, separating them from the mmd_loss() function, makes more sense to me.
Good point Theo. I am in favour of organizing them in torch.kernels for reasons you mentioned, i.e., a kernel is a more general object that isn't limited to applications in MMD.
In this case, I will move the kernels to torch.kernels. Should we have both torch.kernels and torch.functional.kernels, to store the functions and the modules separately?
I can add this^
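To make the proposed split concrete, here is a hypothetical sketch of what a standalone kernels.py could contain; the class names mirror the __all__ quoted above, but this is illustrative, not the merged implementation:

```python
import torch
import torch.nn as nn


class Kernel(nn.Module):
    """Base class for kernels; subclasses map (x, y) to a pairwise kernel matrix."""

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class RBFKernel(Kernel):
    """Gaussian (RBF) kernel: k(x, y) = exp(-||x - y||^2 / (2 sigma^2))."""

    def __init__(self, sigma: float = 1.0):
        super().__init__()
        self.sigma = sigma

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Pairwise squared Euclidean distances, shape (x.shape[0], y.shape[0])
        d2 = torch.cdist(x, y) ** 2
        return torch.exp(-d2 / (2 * self.sigma**2))
```

With this layout, mmd.py would only need to import the kernel classes, keeping mmd_loss() free of kernel-specific code.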
soft = logits
result = hard - soft.detach() + soft
# Now we need to repeat the result n_samples times along a new dimension
return repeat(result, "b ... -> b n ...", n=n_samples)
Do we absolutely need repeat here? Seems a bit cumbersome to add einops as a test dependency just for this test. 🤔
No, but it's more readable than the pytorch-only version, which requires unsqueezing, inferring the number of feature dimensions and then repeating.
chiming in to say einops is generally useful, e.g., #39
Co-authored-by: Vladimir Vargas Calderón <vvargasc@dwavesys.com>
Just did another pass. Main changes are:
Food for thought (will add an issue): the Kernel base class does not enforce that a kernel is PSD. Edit RE the PSD guarantee: had a brief exchange with Vlad and figured the onus is on the developer to correctly define a kernel.
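Since the base class can't enforce PSD-ness at definition time, a numerical sanity check on the Gram matrix is about the best a test can do. A minimal sketch (the helper names are hypothetical, not part of the package) checks that the smallest eigenvalue is non-negative up to a tolerance:

```python
import torch


def gram_is_psd(kernel_fn, x, tol=1e-6):
    # A valid kernel's Gram matrix k(x, x) is symmetric PSD;
    # check its eigenvalues are non-negative up to numerical tolerance.
    gram = kernel_fn(x, x)
    eigvals = torch.linalg.eigvalsh(gram)  # symmetric matrix -> real eigenvalues
    return bool((eigvals >= -tol).all())


def rbf(x, y, sigma=1.0):
    # Inline stand-in for the package's RBF kernel
    return torch.exp(-torch.cdist(x, y) ** 2 / (2 * sigma**2))


x = torch.randn(16, 3, dtype=torch.float64)
assert gram_is_psd(rbf, x)  # RBF Gram matrices are PSD
```

This can't prove a kernel is valid for all inputs, but it catches gross mistakes (e.g., a sign error making the Gram matrix indefinite) on random data.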
thisac
left a comment
Overall, looks good! Just a few comments/suggestions.
Co-Authored-By: Theodor Isacsson <theodor@isacsson.ca>
TODOs:
store_config