[WIP] Spectral-Grassmann OT#792
Conversation
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #792 +/- ##
==========================================
- Coverage 96.77% 95.77% -1.00%
==========================================
Files 107 108 +1
Lines 22342 22621 +279
==========================================
+ Hits 21622 21666 +44
- Misses 720 955 +235 🚀 New features to boost your workflow:
|
rflamary
left a comment
There was a problem hiding this comment.
Hello @osheasienna and @thibaut-germain this is a nice first step.
Here are below a few comments that we can discuss together
| return C | ||
|
|
||
|
|
||
| def metric( |
There was a problem hiding this comment.
| def metric( | |
| def sgot_metric( |
| return prod ** (q / 2) | ||
|
|
||
|
|
||
| def ot_plan(C, Ws=None, Wt=None, nx=None): |
There was a problem hiding this comment.
this function is not needed, this is two lines and the ormalization wrt ws and wt are not oK because it rcan retrun very weird things
| ### SPECTRAL-GRASSMANNIAN WASSERSTEIN METRIC ### | ||
| ##################################################################################################################################### | ||
| ##################################################################################################################################### | ||
| def cost( |
There was a problem hiding this comment.
| def cost( | |
| def sgot_cost_matrix( |
| imag_scale=1.0, | ||
| nx=None, | ||
| ): | ||
| """Compute the SGOT cost matrix between two spectral decompositions. |
There was a problem hiding this comment.
recall here the equation with eta and define with math teh different acceptable metrics
| raise ValueError(f"cost() expects Dt to be 1D (n,), got shape {Dt.shape}") | ||
| lam2 = Dt | ||
|
|
||
| lam1 = nx.astype(lam1, "complex128") |
There was a problem hiding this comment.
is that necessary? seems overkill to add a function to the backend for that . When and why does it fails?
| logits_s = rng.randn(r) | ||
| logits_t = rng.randn(r) | ||
|
|
||
| Ws = np.exp(logits_s) |
There was a problem hiding this comment.
simpler and return only positive values
| Ws = np.exp(logits_s) | |
| Ws = rng.rand(r) |
| """Create test_cost for each trial: sweep over HPs and run cost().""" | ||
| grassmann_types = ["geodesic", "chordal", "procrustes", "martin"] | ||
| n_trials = 10 | ||
| for _ in range(n_trials): |
| def test_hyperparameter_sweep(): | ||
| grassmann_types = ["geodesic", "chordal", "procrustes", "martin"] | ||
|
|
||
| for _ in range(10): |
| This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation. | ||
|
|
||
| #### New features | ||
| - Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788) |
| ## Upcomming 0.9.7.post1 | ||
|
|
||
| #### New features | ||
| The next release will add cost functions between linear operators following [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920). |
There was a problem hiding this comment.
move this text to the new feature of 0.9.7.dev0 this is what we are working on. Also add a line in the Itemize with the PR number
rflamary
left a comment
There was a problem hiding this comment.
A few comments from talking together
| if grassman_metric == "procrustes": | ||
| return 2.0 * (1.0 - delta) | ||
| if grassman_metric == "martin": | ||
| return -nx.log(nx.clip(delta**2, eps, 1e300)) |
| C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) | ||
|
|
||
| C2 = eta * C_lambda + (1.0 - eta) * C_grass | ||
| C = C2 ** (p / 2.0) |
There was a problem hiding this comment.
| C = C2 ** (p / 2.0) | |
| C = nx.real(C2) ** (p / 2.0) |
| q=1, | ||
| r=2, | ||
| grassman_metric="chordal", | ||
| real_scale=1.0, |
There was a problem hiding this comment.
lets call this eigen_scaling and set it to None by default
| nx=None, | ||
| ): | ||
| """Compute the SGOT metric between two spectral decompositions. | ||
There was a problem hiding this comment.
add equation that illustrate p q and r
| import numpy as np | ||
| import pytest | ||
|
|
||
| from ot.backend import get_backend |
There was a problem hiding this comment.
| from ot.backend import get_backend | |
| from ot.backend import get_backend, torch, jax |
| rng = np.random.RandomState(0) | ||
|
|
||
|
|
||
| def rand_complex(shape): |
There was a problem hiding this comment.
| def rand_complex(shape): | |
| def rand_complex(shape,rng): |
| return real + 1j * imag | ||
|
|
||
|
|
||
| def random_atoms(d=8, r=4): |
There was a problem hiding this comment.
| def random_atoms(d=8, r=4): | |
| def random_atoms(d=8, r=4,seed=42): |
|
|
||
|
|
||
| @pytest.mark.parametrize("backend_name", ["numpy", "torch", "jax"]) | ||
| def test_cost_backend_consistency(backend_name): |
There was a problem hiding this comment.
| def test_cost_backend_consistency(backend_name): | |
| def test_cost_backend_consistency(nx): |
| # --------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def test_hyperparameter_sweep_cost(nx): |
There was a problem hiding this comment.
| def test_hyperparameter_sweep_cost(nx): | |
| def test_hyperparameter_sweep_cost(nx,grassmann_types,p,q,r,eta): |
Types of changes
Adding sgot file in the ot folder.
Motivation and context / Related issue
Keep track of SGOT implementation in POT.
How has this been tested (if it applies)
Not tested yet.
PR checklist