From f395e5819b5975f24d842a966084640951a44337 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sun, 14 Jan 2024 18:05:37 +0100 Subject: [PATCH 01/30] [MRG] fix gpu compatibility of srGW solvers (#596) * fix gpu compatibility of srgw solvers * update release and pep8 --- README.md | 2 +- RELEASES.md | 4 ++-- ot/gromov/_semirelaxed.py | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index d3385eb3d..3be8c1b07 100644 --- a/README.md +++ b/README.md @@ -354,4 +354,4 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil [64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems. -[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf). \ No newline at end of file +[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf). diff --git a/RELEASES.md b/RELEASES.md index 488366ae3..7777e9011 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,7 +4,7 @@ #### Closed issues - Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593) - +- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596) ## 0.9.2 *December 2023* @@ -671,4 +671,4 @@ It provides the following solvers: * Optimal transport for domain adaptation with group lasso regularization * Conditional gradient and Generalized conditional gradient for regularized OT. -Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. 
\ No newline at end of file +Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index d064dc669..fb9d2b3ca 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -114,7 +114,7 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme else: q = nx.sum(G0, 0) # Check first marginal of G0 - np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08) + assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08) constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx) @@ -363,8 +363,8 @@ def semirelaxed_fused_gromov_wasserstein( G0 = nx.outer(p, q) else: q = nx.sum(G0, 0) - # Check marginals of G0 - np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08) + # Check first marginal of G0 + assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08) constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx) @@ -703,7 +703,7 @@ def entropic_semirelaxed_gromov_wasserstein( else: q = nx.sum(G0, 0) # Check first marginal of G0 - np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08) + assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08) constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx) @@ -951,7 +951,7 @@ def entropic_semirelaxed_fused_gromov_wasserstein( else: q = nx.sum(G0, 0) # Check first marginal of G0 - np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08) + assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08) constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx) From 64c837426d9586906ec7f8e3a1237de0bc53f7b3 Mon Sep 17 00:00:00 2001 From: chattershuts <10526185@polimi.it> Date: Wed, 17 Jan 2024 10:30:03 +0100 Subject: [PATCH 02/30] Update README.md (#595) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Rémi Flamary --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md 
b/README.md index 3be8c1b07..5fa4a6faa 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ [![Anaconda downloads](https://anaconda.org/conda-forge/pot/badges/downloads.svg)](https://anaconda.org/conda-forge/pot) [![License](https://anaconda.org/conda-forge/pot/badges/license.svg)](https://github.com/PythonOT/POT/blob/master/LICENSE) -This open source Python library provide several solvers for optimization +This open source Python library provides several solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning. From c84ef332ea50a5f94a72f435a7828989c15fa4b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Thu, 18 Jan 2024 13:27:19 +0100 Subject: [PATCH 03/30] [MRG] fix doc+example lowrank sinkhorn (#601) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix doc+example lowrank sinkhorn * fix autosummary for lowrank doc * update release --------- Co-authored-by: Rémi Flamary --- RELEASES.md | 1 + docs/source/all.rst | 1 + examples/others/plot_lowrank_sinkhorn.py | 19 +++++++------------ ot/__init__.py | 5 +++-- ot/lowrank.py | 17 +++++++++-------- 5 files changed, 21 insertions(+), 22 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 7777e9011..998d56836 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,7 @@ #### Closed issues - Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. 
This fix addresses the original issue introduced PR #587 (PR #593) - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596) +- Fix doc and example for lowrank sinkhorn (PR #601) ## 0.9.2 *December 2023* diff --git a/docs/source/all.rst b/docs/source/all.rst index 872a48528..91bb36361 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -24,6 +24,7 @@ API and modules gaussian gnn gromov + lowrank lp mapping optim diff --git a/examples/others/plot_lowrank_sinkhorn.py b/examples/others/plot_lowrank_sinkhorn.py index ece35b295..0664a829c 100644 --- a/examples/others/plot_lowrank_sinkhorn.py +++ b/examples/others/plot_lowrank_sinkhorn.py @@ -88,40 +88,35 @@ #%% # Plot sinkhorn vs low rank sinkhorn -pl.figure(1, figsize=(10, 4)) +pl.figure(1, figsize=(10, 8)) -pl.subplot(1, 3, 1) +pl.subplot(2, 3, 1) pl.imshow(list_P_Sin[0], interpolation='nearest') pl.axis('off') pl.title('Sinkhorn (reg=0.05)') -pl.subplot(1, 3, 2) +pl.subplot(2, 3, 2) pl.imshow(list_P_Sin[1], interpolation='nearest') pl.axis('off') pl.title('Sinkhorn (reg=0.005)') -pl.subplot(1, 3, 3) +pl.subplot(2, 3, 3) pl.imshow(list_P_Sin[2], interpolation='nearest') pl.axis('off') pl.title('Sinkhorn (reg=0.001)') pl.show() - -#%% - -pl.figure(2, figsize=(10, 4)) - -pl.subplot(1, 3, 1) +pl.subplot(2, 3, 4) pl.imshow(list_P_LR[0], interpolation='nearest') pl.axis('off') pl.title('Low rank (rank=3)') -pl.subplot(1, 3, 2) +pl.subplot(2, 3, 5) pl.imshow(list_P_LR[1], interpolation='nearest') pl.axis('off') pl.title('Low rank (rank=10)') -pl.subplot(1, 3, 3) +pl.subplot(2, 3, 6) pl.imshow(list_P_LR[2], interpolation='nearest') pl.axis('off') pl.title('Low rank (rank=50)') diff --git a/ot/__init__.py b/ot/__init__.py index f364357f2..db49d6c34 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -5,7 +5,7 @@ :py:mod:`ot.utils`, :py:mod:`ot.datasets`, :py:mod:`ot.gromov`, :py:mod:`ot.smooth` :py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath` - , :py:mod:`ot.unbalanced`, 
:py:mod`ot.mapping`. + , :py:mod:`ot.unbalanced`, :py:mod:`ot.mapping` . The following sub-modules are not imported due to additional dependencies: - :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`. - :any:`ot.plot` : depends on :code:`matplotlib` @@ -71,4 +71,5 @@ 'factored_optimal_transport', 'solve', 'solve_gromov','solve_sample', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', - 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn'] + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', + 'lowrank_sinkhorn'] diff --git a/ot/lowrank.py b/ot/lowrank.py index f6c1469bd..a06c1aaa1 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -319,17 +319,18 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, re The function solves the following optimization problem: .. math:: - \mathop{\inf_{(Q,R,g) \in \mathcal{C(a,b,r)}}} \langle C, Q\mathrm{diag}(1/g)R^T \rangle - - \mathrm{reg} \cdot H((Q,R,g)) + \mathop{\inf_{(\mathbf{Q},\mathbf{R},\mathbf{g}) \in \mathcal{C}(\mathbf{a},\mathbf{b},r)}} \langle \mathbf{C}, \mathbf{Q}\mathrm{diag}(1/\mathbf{g})\mathbf{R}^\top \rangle - + \mathrm{reg} \cdot H((\mathbf{Q}, \mathbf{R}, \mathbf{g})) where : - - :math:`C` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`H((Q,R,g))` is the values of the three respective entropies evaluated for each term. - - :math: `Q` and `R` are the low-rank matrix decomposition of the OT plan - - :math: `g` is the weight vector for the low-rank decomposition of the OT plan + + - :math:`\mathbf{C}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`H((\mathbf{Q}, \mathbf{R}, \mathbf{g}))` is the values of the three respective entropies evaluated for each term. 
+ - :math:`\mathbf{Q}` and :math:`\mathbf{R}` are the low-rank matrix decomposition of the OT plan + - :math:`\mathbf{g}` is the weight vector for the low-rank decomposition of the OT plan - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) - - :math: `r` is the rank of the OT plan - - :math: `\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem + - :math:`r` is the rank of the OT plan + - :math:`\mathcal{C}(\mathbf{a}, \mathbf{b}, r)` are the low-rank couplings of the OT problem Parameters From 6f358042f56662eb69c7b282a567790ee604c698 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 20 Feb 2024 06:46:11 -0600 Subject: [PATCH 04/30] [MRG] Add implicit Sinkhorn gradients (#605) * add detach function to backend * debug function * better detach * new implementation * add test for gradient * better default * update documentation --- ot/backend.py | 62 +++++++++++++++++++++----------------------- ot/solvers.py | 48 +++++++++++++++++++++++++++++++--- test/test_backend.py | 17 +++++++++--- test/test_solvers.py | 43 ++++++++++++++++++++++++++++++ 4 files changed, 130 insertions(+), 40 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 7645c4237..9cc6446bf 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -281,6 +281,19 @@ def set_gradients(self, val, inputs, grads): """Define the gradients for the value val wrt the inputs """ raise NotImplementedError() + def detach(self, *arrays): + """Detach the tensors from the computation graph + + See: https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html""" + if len(arrays) == 1: + return self._detach(arrays[0]) + else: + return [self._detach(array) for array in arrays] + + def _detach(self, a): + """Detach the tensor from the computation graph""" + raise NotImplementedError() + def zeros(self, shape, type_as=None): r""" Creates a tensor full of zeros. 
@@ -1027,14 +1040,6 @@ def transpose(self, a, axes=None): """ raise NotImplementedError() - def detach(self, *args): - r""" - Detach tensors in arguments from the current graph. - - See: https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html - """ - raise NotImplementedError() - def matmul(self, a, b): r""" Matrix product of two arrays. @@ -1082,6 +1087,10 @@ def set_gradients(self, val, inputs, grads): # No gradients for numpy return val + def _detach(self, a): + # No gradients for numpy + return a + def zeros(self, shape, type_as=None): if type_as is None: return np.zeros(shape) @@ -1392,11 +1401,6 @@ def atan2(self, a, b): def transpose(self, a, axes=None): return np.transpose(a, axes) - def detach(self, *args): - if len(args) == 1: - return args[0] - return args - def matmul(self, a, b): return np.matmul(a, b) @@ -1462,6 +1466,9 @@ def set_gradients(self, val, inputs, grads): val, = jax.tree_map(lambda z: z + aux, (val,)) return val + def _detach(self, a): + return jax.lax.stop_gradient(a) + def zeros(self, shape, type_as=None): if type_as is None: return jnp.zeros(shape) @@ -1765,11 +1772,6 @@ def atan2(self, a, b): def transpose(self, a, axes=None): return jnp.transpose(a, axes) - def detach(self, *args): - if len(args) == 1: - return jax.lax.stop_gradient((args[0],))[0] - return [jax.lax.stop_gradient((a,))[0] for a in args] - def matmul(self, a, b): return jnp.matmul(a, b) @@ -1851,6 +1853,9 @@ def set_gradients(self, val, inputs, grads): return res + def _detach(self, a): + return a.detach() + def zeros(self, shape, type_as=None): if isinstance(shape, int): shape = (shape,) @@ -2256,11 +2261,6 @@ def transpose(self, a, axes=None): axes = tuple(range(a.ndim)[::-1]) return a.permute(axes) - def detach(self, *args): - if len(args) == 1: - return args[0].detach() - return [a.detach() for a in args] - def matmul(self, a, b): return torch.matmul(a, b) @@ -2312,6 +2312,9 @@ def set_gradients(self, val, inputs, grads): # No gradients for cupy return 
val + def _detach(self, a): + return a + def zeros(self, shape, type_as=None): if isinstance(shape, (list, tuple)): shape = tuple(int(i) for i in shape) @@ -2657,11 +2660,6 @@ def atan2(self, a, b): def transpose(self, a, axes=None): return cp.transpose(a, axes) - def detach(self, *args): - if len(args) == 1: - return args[0] - return args - def matmul(self, a, b): return cp.matmul(a, b) @@ -2729,6 +2727,9 @@ def grad(upstream): return val, grad return tmp(inputs) + def _detach(self, a): + return tf.stop_gradient(a) + def zeros(self, shape, type_as=None): if type_as is None: return tnp.zeros(shape) @@ -3083,11 +3084,6 @@ def atan2(self, a, b): def transpose(self, a, axes=None): return tf.transpose(a, perm=axes) - def detach(self, *args): - if len(args) == 1: - return tf.stop_gradient(args[0]) - return [tf.stop_gradient(a) for a in args] - def matmul(self, a, b): return tnp.matmul(a, b) diff --git a/ot/solvers.py b/ot/solvers.py index e4eca9575..de817d7f7 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -29,7 +29,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None, - potentials_init=None, tol=None, verbose=False): + potentials_init=None, tol=None, verbose=False, grad='autodiff'): r"""Solve the discrete optimal transport problem and return :any:`OTResult` object The function solves the following general optimal transport problem @@ -79,6 +79,12 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, Tolerance for solution precision, by default None (default values in each solvers) verbose : bool, optional Print information in the solver, by default False + grad : str, optional + Type of gradient computation, either or 'autodiff' or 'implicit' used only for + Sinkhorn solver. By default 'autodiff' provides gradients wrt all + outputs (`plan, value, value_linear`) but with important memory cost. 
+ 'implicit' provides gradients only for `value` and and other outputs are + detached. This is useful for memory saving when only the value is needed. Returns ------- @@ -134,6 +140,16 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, # or for original Sinkhorn paper formulation [2] res = ot.solve(M, a, b, reg=1.0, reg_type='entropy') + # Use implicit differentiation for memory saving + res = ot.solve(M, a, b, reg=1.0, grad='implicit') # M, a, b are torch tensors + res.value.backward() # only the value is differentiable + + Note that by default the Sinkhorn solver uses automatic differentiation to + compute the gradients of the values and plan. This can be changed with the + `grad` parameter. The `implicit` mode computes the implicit gradients only + for the value and the other outputs are detached. This is useful for + memory saving when only the gradient of value is needed. + - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): .. math:: @@ -297,6 +313,10 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, if reg_type.lower() in ['entropy', 'kl']: + if grad == 'implicit': # if implicit then detach the input + M0, a0, b0 = M, a, b + M, a, b = nx.detach(M, a, b) + # default values for sinkhorn if max_iter is None: max_iter = 1000 @@ -316,6 +336,11 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, potentials = (log['log_u'], log['log_v']) + if grad == 'implicit': # set the gradient at convergence + + value = nx.set_gradients(value, (M0, a0, b0), + (plan, reg * (potentials[0] - potentials[0].mean()), reg * (potentials[1] - potentials[1].mean()))) + elif reg_type.lower() == 'l2': if max_iter is None: @@ -869,7 +894,8 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, unbalanced_type='KL', lazy=False, batch_size=None, method=None, 
n_threads=1, max_iter=None, plan_init=None, rank=100, scaling=0.95, - potentials_init=None, X_init=None, tol=None, verbose=False): + potentials_init=None, X_init=None, tol=None, verbose=False, + grad='autodiff'): r"""Solve the discrete optimal transport problem using the samples in the source and target domains. The function solves the following general optimal transport problem @@ -935,6 +961,12 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t Tolerance for solution precision, by default None (default values in each solvers) verbose : bool, optional Print information in the solver, by default False + grad : str, optional + Type of gradient computation, either or 'autodiff' or 'implicit' used only for + Sinkhorn solver. By default 'autodiff' provides gradients wrt all + outputs (`plan, value, value_linear`) but with important memory cost. + 'implicit' provides gradients only for `value` and and other outputs are + detached. This is useful for memory saving when only the value is needed. Returns ------- @@ -1002,6 +1034,16 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t # lazy OT plan lazy_plan = res.lazy_plan + # Use implicit differentiation for memory saving + res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='implicit') + res.value.backward() # only the value is differentiable + + Note that by default the Sinkhorn solver uses automatic differentiation to + compute the gradients of the values and plan. This can be changed with the + `grad` parameter. The `implicit` mode computes the implicit gradients only + for the value and the other outputs are detached. This is useful for + memory saving when only the gradient of value is needed. 
+ We also have a very efficient solver with compiled CPU/CUDA code using geomloss/PyKeOps that can be used with the following code: @@ -1189,7 +1231,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t # compute cost matrix M and use solve function M = dist(X_a, X_b, metric) - res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose) + res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose, grad) return res diff --git a/test/test_backend.py b/test/test_backend.py index 3bc1e5480..da7293821 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -266,6 +266,14 @@ def test_empty_backend(): nx.matmul(M, M.T) with pytest.raises(NotImplementedError): nx.nan_to_num(M) + with pytest.raises(NotImplementedError): + nx.sign(M) + with pytest.raises(NotImplementedError): + nx.dtype_device(M) + with pytest.raises(NotImplementedError): + nx.assert_same_dtype_device(M, M) + with pytest.raises(NotImplementedError): + nx.eigh(M) def test_func_backends(nx): @@ -311,6 +319,11 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('set_gradients') + A = nx.detach(Mb) + A, B = nx.detach(Mb, Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('detach') + A = nx.zeros((10, 3)) A = nx.zeros((10, 3), type_as=Mb) lst_b.append(nx.to_numpy(A)) @@ -652,10 +665,6 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("transpose") - A = nx.detach(Mb) - lst_b.append(nx.to_numpy(A)) - lst_name.append("detach") - A, B = nx.detach(Mb, Mb) lst_b.append(nx.to_numpy(A)) lst_name.append("detach A") diff --git a/test/test_solvers.py b/test/test_solvers.py index 164989811..168b111e4 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -12,6 +12,7 @@ import ot from ot.bregman import geomloss +from ot.backend import torch lst_reg = [None, 1] lst_reg_type = 
['KL', 'entropy', 'L2'] @@ -107,6 +108,48 @@ def test_solve(nx): sol0 = ot.solve(M, reg=1, reg_type='cryptic divergence') +@pytest.mark.skipif(not torch, reason="torch no installed") +def test_solve_implicit(): + + n_samples_s = 10 + n_samples_t = 7 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + M = ot.dist(x, y) + + a = torch.tensor(a, requires_grad=True) + b = torch.tensor(b, requires_grad=True) + M = torch.tensor(M, requires_grad=True) + + sol0 = ot.solve(M, a, b, reg=10, grad='implicit') + sol0.value.backward() + + gM0 = M.grad.clone() + ga0 = a.grad.clone() + gb0 = b.grad.clone() + + a = torch.tensor(a, requires_grad=True) + b = torch.tensor(b, requires_grad=True) + M = torch.tensor(M, requires_grad=True) + + sol = ot.solve(M, a, b, reg=10, grad='autodiff') + sol.value.backward() + + gM = M.grad.clone() + ga = a.grad.clone() + gb = b.grad.clone() + + # Note, gradients aer invariant to change in constant so we center them + assert torch.allclose(gM0, gM) + assert torch.allclose(ga0 - ga0.mean(), ga - ga.mean()) + assert torch.allclose(gb0 - gb0.mean(), gb - gb.mean()) + + @pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type): n_samples_s = 10 From 737b20db2f8db31eecdb1e78b610db92a21fdf12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 29 Feb 2024 16:40:44 +0100 Subject: [PATCH 05/30] [MRG] Temporary limit on jax version (#608) * Update requirements.txt Limit the max version of jax to allow for old config because of pymanopt using deprecated version * Update requirements.txt --- requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 
6af50f127..bffaf892f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,11 +6,11 @@ pymanopt cvxopt scikit-learn torch -jax -jaxlib +jax<=0.4.24 +jaxlib<=0.4.24 tensorflow pytest torch_geometric cvxpy geomloss -pykeops \ No newline at end of file +pykeops From f1fe5932c26317d968f53cc6ab490f702a41d598 Mon Sep 17 00:00:00 2001 From: KrzakalaPaul <62666596+KrzakalaPaul@users.noreply.github.com> Date: Thu, 29 Feb 2024 23:37:14 +0100 Subject: [PATCH 06/30] [MRG] Faster gromov-wasserstein linesearch for symmetric matrices (#607) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * improved gw linsearch for symmetric case * add a demo in examples * remove example * releases.md updated --------- Co-authored-by: Rémi Flamary --- RELEASES.md | 3 +++ ot/gromov/_gw.py | 15 +++++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 998d56836..a695d6a70 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -2,6 +2,9 @@ ## 0.9.3 +#### New features ++ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specifify if the matrices are symmetric in which case the computation can be done faster. + #### Closed issues - Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. 
This fix addresses the original issue introduced PR #587 (PR #593) - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596) diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 69dd3df0c..3d7a47480 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -171,7 +171,7 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) else: def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): - return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs) + return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, symmetric=symmetric, **kwargs) if not nx.is_floating_point(C10): warnings.warn( @@ -479,7 +479,7 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) else: def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): - return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs) + return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, symmetric=symmetric, **kwargs) if not nx.is_floating_point(M0): warnings.warn( "Input feature matrix consists of integer. The transport plan will be " @@ -647,7 +647,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, - alpha_min=None, alpha_max=None, nx=None, **kwargs): + alpha_min=None, alpha_max=None, nx=None, symmetric=False, **kwargs): """ Solve the linesearch in the FW iterations for any inner loss that decomposes as in Proposition 1 in :ref:`[12] `. @@ -676,6 +676,10 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, Maximum value for alpha nx : backend, optional If let to its default value None, a backend test will be conducted. 
+ symmetric : bool, optional + Either structures are to be assumed symmetric or not. Default value is False. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + Returns ------- alpha : float @@ -708,7 +712,10 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, dot = nx.dot(nx.dot(C1, deltaG), C2.T) a = - reg * nx.sum(dot * deltaG) - b = nx.sum(M * deltaG) - reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG)) + if symmetric: + b = nx.sum(M * deltaG) - 2 * reg * nx.sum(dot * G) + else: + b = nx.sum(M * deltaG) - reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG)) alpha = solve_1d_linesearch_quad(a, b) if alpha_min is not None or alpha_max is not None: From 0573ebab57c85ed7787f079f8d443ce0d2f87379 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 1 Mar 2024 16:08:47 +0100 Subject: [PATCH 07/30] [MRG] Fix bug in emd2 with empty weighs on backends (#606) * fix buf emd2 for empty inputs * update release file * debug problems in optimization hen using list_to_arry by removing it everywhere * update jax config in tests * hopefully final fix --- RELEASES.md | 3 ++- ot/__init__.py | 2 +- ot/gromov/_gw.py | 2 -- ot/gromov/_semirelaxed.py | 2 -- ot/lp/__init__.py | 50 +++++++++++++++++++++++---------------- ot/lp/solver_1d.py | 6 ++++- ot/optim.py | 4 +--- ot/utils.py | 24 ++++++++++++++++--- test/conftest.py | 2 +- test/test_ot.py | 4 ++++ test/test_utils.py | 12 ++++++++++ 11 files changed, 77 insertions(+), 34 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index a695d6a70..3b8513dbd 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,6 +1,6 @@ # Releases -## 0.9.3 +## 0.9.3dev #### New features + `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specifify if the matrices are symmetric in which case the computation can be done faster. 
@@ -9,6 +9,7 @@ - Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593) - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596) - Fix doc and example for lowrank sinkhorn (PR #601) +- Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534) ## 0.9.2 *December 2023* diff --git a/ot/__init__.py b/ot/__init__.py index db49d6c34..9a63b5f6f 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -58,7 +58,7 @@ # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.9.3" +__version__ = "0.9.3dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 3d7a47480..281ed5f0b 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -703,8 +703,6 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, """ if nx is None: - G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M) - if isinstance(M, int) or isinstance(M, float): nx = get_backend(G, deltaG, C1, C2) else: diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index fb9d2b3ca..c37ba2bf4 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -583,8 +583,6 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, Gromov-Wasserstein". NeurIPS 2023 Workshop OTML. 
""" if nx is None: - G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M) - if isinstance(M, int) or isinstance(M, float): nx = get_backend(G, deltaG, C1, C2) else: diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 545d1d8cd..93316a6c1 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -302,17 +302,24 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c ot.optim.cg : General regularized OT """ - # convert to numpy if list a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) - a0, b0, M0 = a, b, M - if len(a0) != 0: - type_as = a0 - elif len(b0) != 0: - type_as = b0 + if len(a) != 0: + type_as = a + elif len(b) != 0: + type_as = b else: - type_as = M0 - nx = get_backend(M0, a0, b0) + type_as = M + + # if empty array given then use uniform distributions + if len(a) == 0: + a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + if len(b) == 0: + b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + + # store original tensors + a0, b0, M0 = a, b, M # convert to numpy M, a, b = nx.to_numpy(M, a, b) @@ -474,15 +481,23 @@ def emd2(a, b, M, processes=1, """ a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) - a0, b0, M0 = a, b, M - if len(a0) != 0: - type_as = a0 - elif len(b0) != 0: - type_as = b0 + if len(a) != 0: + type_as = a + elif len(b) != 0: + type_as = b else: - type_as = M0 - nx = get_backend(M0, a0, b0) + type_as = M + + # if empty array given then use uniform distributions + if len(a) == 0: + a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + if len(b) == 0: + b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + + # store original tensors + a0, b0, M0 = a, b, M # convert to numpy M, a, b = nx.to_numpy(M, a, b) @@ -491,11 +506,6 @@ def emd2(a, b, M, processes=1, b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64, order='C') - # if empty array given then use uniform distributions - if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / 
M.shape[0] - if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index e792db904..d9395c8d4 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -223,8 +223,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the transportation matrix) """ - a, b, x_a, x_b = list_to_array(a, b, x_a, x_b) + x_a, x_b = list_to_array(x_a, x_b) nx = get_backend(x_a, x_b) + if a is not None: + a = list_to_array(a, nx=nx) + if b is not None: + b = list_to_array(b, nx=nx) assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ "emd_1d should only be used with monodimensional data" diff --git a/ot/optim.py b/ot/optim.py index 8700f75d1..dcdef6a88 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -12,7 +12,6 @@ import warnings from .lp import emd from .bregman import sinkhorn -from .utils import list_to_array from .backend import get_backend with warnings.catch_warnings(): @@ -73,7 +72,6 @@ def line_search_armijo( """ if nx is None: - xk, pk, gfk = list_to_array(xk, pk, gfk) xk0, pk0 = xk, pk nx = get_backend(xk0, pk0) else: @@ -236,7 +234,7 @@ def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_sea ot.lp.emd : Unregularized optimal transport ot.bregman.sinkhorn : Entropic regularized optimal transport """ - a, b, M, G0 = list_to_array(a, b, M, G0) + if isinstance(M, int) or isinstance(M, float): nx = get_backend(a, b) else: diff --git a/ot/utils.py b/ot/utils.py index 19e61f1fe..404a9f2db 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -56,12 +56,30 @@ def laplacian(x): return L -def list_to_array(*lst): +def list_to_array(*lst, nx=None): r""" Convert a list if in numpy format """ + lst_not_empty = [a for a in lst if len(a) > 0 and not isinstance(a, list)] + if nx 
is None: # find backend + + if len(lst_not_empty) == 0: + type_as = np.zeros(0) + nx = get_backend(type_as) + else: + nx = get_backend(*lst_not_empty) + type_as = lst_not_empty[0] + else: + if len(lst_not_empty) == 0: + type_as = None + else: + type_as = lst_not_empty[0] if len(lst) > 1: - return [np.array(a) if isinstance(a, list) else a for a in lst] + return [nx.from_numpy(np.array(a), type_as=type_as) + if isinstance(a, list) else a for a in lst] else: - return np.array(lst[0]) if isinstance(lst[0], list) else lst[0] + if isinstance(lst[0], list): + return nx.from_numpy(np.array(lst[0]), type_as=type_as) + else: + return lst[0] def proj_simplex(v, z=1): diff --git a/test/conftest.py b/test/conftest.py index 0303ed9f2..043c8ca70 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -13,7 +13,7 @@ if jax: os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' - from jax.config import config + from jax import config config.update("jax_enable_x64", True) if tf: diff --git a/test/test_ot.py b/test/test_ot.py index 5c6e6732b..a90321d5f 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -74,7 +74,11 @@ def test_emd2_backends(nx): valb = ot.emd2(ab, ab, Mb) + # check with empty inputs + valb2 = ot.emd2([], [], Mb) + np.allclose(val, nx.to_numpy(valb)) + np.allclose(val, nx.to_numpy(valb2)) def test_emd_emd2_types_devices(nx): diff --git a/test/test_utils.py b/test/test_utils.py index 6cdb7ead7..966cef989 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -322,6 +322,18 @@ def test_cost_normalization(nx): ot.utils.cost_normalization(C1, 'error') +def test_list_to_array(nx): + + lst = [np.array([1, 2, 3]), np.array([4, 5, 6])] + + a1, a2 = ot.utils.list_to_array(*lst) + + assert a1.shape == (3,) + assert a2.shape == (3,) + + a, b, M = ot.utils.list_to_array([], [], [[1.0, 2.0], [3.0, 4.0]]) + + def test_check_params(): res1 = ot.utils.check_params(first='OK', second=20) From 3e05385e2ac3b9e84b22b1f78b463ea851f8635a Mon Sep 17 00:00:00 2001 From: KrzakalaPaul 
<62666596+KrzakalaPaul@users.noreply.github.com> Date: Mon, 4 Mar 2024 18:41:17 +0100 Subject: [PATCH 08/30] [MRG] fix the sign of gradient for kl gromov (#610) * fix the sign of gradient for kl gromov * releases updated * add PR ref --- RELEASES.md | 1 + ot/gromov/_gw.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 3b8513dbd..9b9d2f597 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,6 +10,7 @@ - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596) - Fix doc and example for lowrank sinkhorn (PR #601) - Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534) +- Fix a sign error regarding the gradient of `ot.gromov._gw.fused_gromov_wasserstein2` and `ot.gromov._gw.gromov_wasserstein2` for the kl loss (PR #610) ## 0.9.2 *December 2023* diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 281ed5f0b..46e1ddfe8 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -315,7 +315,7 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) elif loss_fun == 'kl_loss': gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) - gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) gw = nx.set_gradients(gw, (p, q, C1, C2), (log_gw['u'] - nx.mean(log_gw['u']), @@ -627,7 +627,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) elif loss_fun == 'kl_loss': gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) - gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) if isinstance(alpha, int) or isinstance(alpha, float): fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M), (log_fgw['u'] - 
nx.mean(log_fgw['u']), From 63e44e5dfc51acf208ee088d65c980945c7da8b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 4 Mar 2024 21:23:58 +0100 Subject: [PATCH 09/30] [MRG] fix grad sign sr(F)GW with KL loss (#611) * fix grad sign * mrg --- RELEASES.md | 2 ++ ot/gromov/_semirelaxed.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 9b9d2f597..9734ab21a 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -11,6 +11,8 @@ - Fix doc and example for lowrank sinkhorn (PR #601) - Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534) - Fix a sign error regarding the gradient of `ot.gromov._gw.fused_gromov_wasserstein2` and `ot.gromov._gw.gromov_wasserstein2` for the kl loss (PR #610) +- Fix same sign error for sr(F)GW conditional gradient solvers (PR #611) + ## 0.9.2 *December 2023* diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index c37ba2bf4..0137a8ed8 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -250,7 +250,7 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm elif loss_fun == 'kl_loss': gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) - gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2)) @@ -509,7 +509,7 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo elif loss_fun == 'kl_loss': gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) - gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) if isinstance(alpha, int) or isinstance(alpha, float): srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M), From ab12dd6606122dc7804f69b18eaec19adfca9c71 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 29 Mar 2024 17:24:31 +0100 Subject: [PATCH 10/30] [MRG] Implement the continuous entropic mapping (#613) * implemente pooladian mappng * update stuff * have corect normalization for entropic mapping * implement batches * improve coverage * comments cedric + pep8 * move it back --- README.md | 2 + benchmarks/__init__.py | 2 +- benchmarks/emd.py | 2 +- benchmarks/sinkhorn_knopp.py | 2 +- docs/nb_run_conv | 63 ++++++++------- docs/rtd/conf.py | 2 +- docs/source/conf.py | 20 ++--- ot/__init__.py | 2 +- ot/da.py | 150 +++++++++++++++++++++++++++++++++-- ot/gnn/__init__.py | 6 +- ot/lp/__init__.py | 21 +++-- ot/utils.py | 15 +++- setup.py | 1 + test/test_da.py | 20 +++++ 14 files changed, 235 insertions(+), 73 deletions(-) diff --git a/README.md b/README.md index 5fa4a6faa..88dce689a 100644 --- a/README.md +++ b/README.md @@ -355,3 +355,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil [64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf). + +[66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. [Entropic estimation of optimal transport maps](https://arxiv.org/pdf/2109.12004.pdf). arXiv preprint arXiv:2109.12004 (2021). diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py index 37f5e569a..9dc687db4 100644 --- a/benchmarks/__init__.py +++ b/benchmarks/__init__.py @@ -2,4 +2,4 @@ from . import sinkhorn_knopp from . 
import emd -__all__= ["benchmark", "sinkhorn_knopp", "emd"] +__all__ = ["benchmark", "sinkhorn_knopp", "emd"] diff --git a/benchmarks/emd.py b/benchmarks/emd.py index 9f6486300..861dab332 100644 --- a/benchmarks/emd.py +++ b/benchmarks/emd.py @@ -34,7 +34,7 @@ def setup(n_samples): warmup_runs=warmup_runs ) print(convert_to_html_table( - results, + results, param_name="Sample size", main_title=f"EMD - Averaged on {n_runs} runs" )) diff --git a/benchmarks/sinkhorn_knopp.py b/benchmarks/sinkhorn_knopp.py index 3a1ef3f37..ef0f22b90 100644 --- a/benchmarks/sinkhorn_knopp.py +++ b/benchmarks/sinkhorn_knopp.py @@ -36,7 +36,7 @@ def setup(n_samples): warmup_runs=warmup_runs ) print(convert_to_html_table( - results, + results, param_name="Sample size", main_title=f"Sinkhorn Knopp - Averaged on {n_runs} runs" )) diff --git a/docs/nb_run_conv b/docs/nb_run_conv index ad5e432d3..adb47ace0 100755 --- a/docs/nb_run_conv +++ b/docs/nb_run_conv @@ -17,22 +17,24 @@ import subprocess import os -cache_file='cache_nbrun' +cache_file = 'cache_nbrun' + +path_doc = 'source/auto_examples/' +path_nb = '../notebooks/' -path_doc='source/auto_examples/' -path_nb='../notebooks/' def load_json(fname): try: - f=open(fname) - nb=json.load(f) + f = open(fname) + nb = json.load(f) f.close() - except (OSError, IOError) : - nb={} + except (OSError, IOError): + nb = {} return nb -def save_json(fname,nb): - f=open(fname,'w') + +def save_json(fname, nb): + f = open(fname, 'w') f.write(json.dumps(nb)) f.close() @@ -44,39 +46,36 @@ def md5(fname): hash_md5.update(chunk) return hash_md5.hexdigest() -def to_update(fname,cache): + +def to_update(fname, cache): if fname in cache: - if md5(path_doc+fname)==cache[fname]: - res=False + if md5(path_doc + fname) == cache[fname]: + res = False else: - res=True + res = True else: - res=True - + res = True + return res -def update(fname,cache): - + +def update(fname, cache): + # jupyter nbconvert --to notebook --execute mynotebook.ipynb --output targte - 
subprocess.check_call(['cp',path_doc+fname,path_nb]) - print(' '.join(['jupyter','nbconvert','--to','notebook','--ExecutePreprocessor.timeout=600','--execute',path_nb+fname,'--inplace'])) - subprocess.check_call(['jupyter','nbconvert','--to','notebook','--ExecutePreprocessor.timeout=600','--execute',path_nb+fname,'--inplace']) - cache[fname]=md5(path_doc+fname) - + subprocess.check_call(['cp', path_doc + fname, path_nb]) + print(' '.join(['jupyter', 'nbconvert', '--to', 'notebook', '--ExecutePreprocessor.timeout=600', '--execute', path_nb + fname, '--inplace'])) + subprocess.check_call(['jupyter', 'nbconvert', '--to', 'notebook', '--ExecutePreprocessor.timeout=600', '--execute', path_nb + fname, '--inplace']) + cache[fname] = md5(path_doc + fname) -cache=load_json(cache_file) +cache = load_json(cache_file) -lst_file=glob.glob(path_doc+'*.ipynb') +lst_file = glob.glob(path_doc + '*.ipynb') -lst_file=[os.path.basename(name) for name in lst_file] +lst_file = [os.path.basename(name) for name in lst_file] for fname in lst_file: - if to_update(fname,cache): + if to_update(fname, cache): print('Updating file: {}'.format(fname)) - update(fname,cache) - save_json(cache_file,cache) - - - - + update(fname, cache) + save_json(cache_file, cache) diff --git a/docs/rtd/conf.py b/docs/rtd/conf.py index 814db75a8..cf6479bf5 100644 --- a/docs/rtd/conf.py +++ b/docs/rtd/conf.py @@ -3,4 +3,4 @@ source_parsers = {'.md': CommonMarkParser} source_suffix = ['.md'] -master_doc = 'index' \ No newline at end of file +master_doc = 'index' diff --git a/docs/source/conf.py b/docs/source/conf.py index 6452cf857..c51b96ec4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,16 +22,12 @@ print("warning sphinx-gallery not installed") - - - - # !!!! allow readthedoc compilation try: from unittest.mock import MagicMock except ImportError: from mock import Mock as MagicMock - ## check whether in the source directory... + # check whether in the source directory... 
# @@ -42,7 +38,7 @@ def __getattr__(cls, name): return MagicMock() -MOCK_MODULES = [ 'cupy'] +MOCK_MODULES = ['cupy'] # 'autograd.numpy','pymanopt.manifolds','pymanopt.solvers', sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) # !!!! @@ -357,12 +353,12 @@ def __getattr__(cls, name): sphinx_gallery_conf = { 'examples_dirs': ['../../examples', '../../examples/da'], 'gallery_dirs': 'auto_examples', - 'filename_pattern': 'plot_', #(?!barycenter_fgw) - 'nested_sections' : False, - 'backreferences_dir': 'gen_modules/backreferences', - 'inspect_global_variables' : True, - 'doc_module' : ('ot','numpy','scipy','pylab'), + 'filename_pattern': 'plot_', # (?!barycenter_fgw) + 'nested_sections': False, + 'backreferences_dir': 'gen_modules/backreferences', + 'inspect_global_variables': True, + 'doc_module': ('ot', 'numpy', 'scipy', 'pylab'), 'matplotlib_animations': True, 'reference_url': { - 'ot': None} + 'ot': None} } diff --git a/ot/__init__.py b/ot/__init__.py index 9a63b5f6f..1c10efafd 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -68,7 +68,7 @@ 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', - 'factored_optimal_transport', 'solve', 'solve_gromov','solve_sample', + 'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', diff --git a/ot/da.py b/ot/da.py index 4f3d3bb96..e4adaa546 100644 --- a/ot/da.py +++ b/ot/da.py @@ -493,7 +493,7 @@ class label # pairwise distance self.cost_ = dist(Xs, Xt, metric=self.metric) - self.cost_ = cost_normalization(self.cost_, self.norm) + self.cost_, self.norm_cost_ = cost_normalization(self.cost_, 
self.norm, return_value=True) if (ys is not None) and (yt is not None): @@ -1055,13 +1055,18 @@ class SinkhornTransport(BaseTransport): The ground metric for the Wasserstein problem norm : string, optional (default=None) If given, normalize the ground metric to avoid numerical errors that - can occur with large metric values. + can occur with large metric values. Accepted values are 'median', + 'max', 'log' and 'loglog'. distribution_estimation : callable, optional (defaults to the uniform) The kind of distribution estimation to employ - out_of_sample_map : string, optional (default="ferradans") + out_of_sample_map : string, optional (default="continuous") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in :ref:`[6] `. + "ferradans" which uses the nearest neighbor method proposed in :ref:`[6] + ` while "continuous" use the out of sample + method from :ref:`[66] + ` and :ref:`[19] + `. limit_max: float, optional (default=np.infty) Controls the semi supervised mode. Transport between labeled source and target samples of different classes will exhibit an cost defined @@ -1089,13 +1094,26 @@ class SinkhornTransport(BaseTransport): .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. + + .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. + & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. + International Conference on Learning Representation (2018) + + .. [66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. "Entropic + estimation of optimal transport maps." arXiv preprint + arXiv:2109.12004 (2021). 
+ """ - def __init__(self, reg_e=1., method="sinkhorn", max_iter=1000, + def __init__(self, reg_e=1., method="sinkhorn_log", max_iter=1000, tol=10e-9, verbose=False, log=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='ferradans', limit_max=np.infty): + out_of_sample_map='continuous', limit_max=np.infty): + + if out_of_sample_map not in ['ferradans', 'continuous']: + raise ValueError('Unknown out_of_sample_map method') + self.reg_e = reg_e self.method = method self.max_iter = max_iter @@ -1135,6 +1153,12 @@ class label super(SinkhornTransport, self).fit(Xs, ys, Xt, yt) + if self.out_of_sample_map == 'continuous': + self.log = True + if not self.method == 'sinkhorn_log': + self.method = 'sinkhorn_log' + warnings.warn("The method has been set to 'sinkhorn_log' as it is the only method available for out_of_sample_map='continuous'") + # coupling estimation returned_ = sinkhorn( a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e, @@ -1150,6 +1174,120 @@ class label return self + def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): + r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The source input samples. + ys : array-like, shape (n_source_samples,) + The class labels for source samples + Xt : array-like, shape (n_target_samples, n_features) + The target input samples. + yt : array-like, shape (n_target_samples,) + The class labels for target. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + + Returns + ------- + transp_Xs : array-like, shape (n_source_samples, n_features) + The transport source samples. 
+ """ + nx = self.nx + + if self.out_of_sample_map == 'ferradans': + return super(SinkhornTransport, self).transform(Xs, ys, Xt, yt, batch_size) + + else: # self.out_of_sample_map == 'continuous': + + # check the necessary inputs parameters are here + g = self.log_['log_v'] + + indices = nx.arange(Xs.shape[0]) + batch_ind = [ + indices[i:i + batch_size] + for i in range(0, len(indices), batch_size)] + + transp_Xs = [] + for bi in batch_ind: + # get the nearest neighbor in the source domain + M = dist(Xs[bi], self.xt_, metric=self.metric) + + M = cost_normalization(M, self.norm, value=self.norm_cost_) + + K = nx.exp(-M / self.reg_e + g[None, :]) + + transp_Xs_ = nx.dot(K, self.xt_) / nx.sum(K, axis=1)[:, None] + + transp_Xs.append(transp_Xs_) + + transp_Xs = nx.concatenate(transp_Xs, axis=0) + + return transp_Xs + + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): + r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The source input samples. + ys : array-like, shape (n_source_samples,) + The class labels for source samples + Xt : array-like, shape (n_target_samples, n_features) + The target input samples. + yt : array-like, shape (n_target_samples,) + The class labels for target. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + + Returns + ------- + transp_Xt : array-like, shape (n_source_samples, n_features) + The transport target samples. 
+ """ + + nx = self.nx + + if self.out_of_sample_map == 'ferradans': + return super(SinkhornTransport, self).inverse_transform(Xs, ys, Xt, yt, batch_size) + + else: # self.out_of_sample_map == 'continuous': + + f = self.log_['log_u'] + + indices = nx.arange(Xt.shape[0]) + batch_ind = [ + indices[i:i + batch_size] + for i in range(0, len(indices), batch_size + )] + + transp_Xt = [] + for bi in batch_ind: + + M = dist(Xt[bi], self.xs_, metric=self.metric) + M = cost_normalization(M, self.norm, value=self.norm_cost_) + + K = nx.exp(-M / self.reg_e + f[None, :]) + + transp_Xt_ = nx.dot(K, self.xs_) / nx.sum(K, axis=1)[:, None] + + transp_Xt.append(transp_Xt_) + + transp_Xt = nx.concatenate(transp_Xt, axis=0) + + return transp_Xt + class EMDTransport(BaseTransport): diff --git a/ot/gnn/__init__.py b/ot/gnn/__init__.py index 6a84100a1..af39db6d2 100644 --- a/ot/gnn/__init__.py +++ b/ot/gnn/__init__.py @@ -17,8 +17,8 @@ # All submodules and packages -from ._utils import (FGW_distance_to_templates,wasserstein_distance_to_templates) +from ._utils import (FGW_distance_to_templates, wasserstein_distance_to_templates) -from ._layers import (TFGWPooling,TWPooling) +from ._layers import (TFGWPooling, TWPooling) -__all__ = [ 'FGW_distance_to_templates', 'wasserstein_distance_to_templates','TFGWPooling','TWPooling'] \ No newline at end of file +__all__ = ['FGW_distance_to_templates', 'wasserstein_distance_to_templates', 'TFGWPooling', 'TWPooling'] diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 93316a6c1..752c5d2d7 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -21,8 +21,8 @@ # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d, - binary_search_circle, wasserstein_circle, +from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d, + binary_search_circle, wasserstein_circle, semidiscrete_wasserstein2_unif_circle) from ..utils import dist, list_to_array @@ -262,7 +262,7 @@ def emd(a, 
b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c check_marginals: bool, optional (default=True) If True, checks that the marginals mass are equal. If False, skips the check. - + Returns ------- @@ -341,8 +341,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c # ensure that same mass if check_marginals: np.testing.assert_almost_equal(a.sum(0), - b.sum(0), err_msg='a and b vector must have the same sum', - decimal=6) + b.sum(0), err_msg='a and b vector must have the same sum', + decimal=6) b = b * a.sum() / b.sum() asel = a != 0 @@ -440,8 +440,8 @@ def emd2(a, b, M, processes=1, check_marginals: bool, optional (default=True) If True, checks that the marginals mass are equal. If False, skips the check. - - + + Returns ------- W: float, array-like @@ -506,16 +506,15 @@ def emd2(a, b, M, processes=1, b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64, order='C') - assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" # ensure that same mass if check_marginals: np.testing.assert_almost_equal(a.sum(0), - b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum', - decimal=6) - b = b * a.sum(0) / b.sum(0,keepdims=True) + b.sum(0, keepdims=True), err_msg='a and b vector must have the same sum', + decimal=6) + b = b * a.sum(0) / b.sum(0, keepdims=True) asel = a != 0 diff --git a/ot/utils.py b/ot/utils.py index 404a9f2db..04c0e550e 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -360,7 +360,7 @@ def dist0(n, method='lin_square'): return res -def cost_normalization(C, norm=None): +def cost_normalization(C, norm=None, return_value=False, value=None): r""" Apply normalization to the loss matrix Parameters @@ -382,9 +382,13 @@ def cost_normalization(C, norm=None): if norm is None: pass elif norm == "median": - C /= float(nx.median(C)) + if value is None: + value = nx.median(C) + C /= value elif norm == "max": - C /= 
float(nx.max(C)) + if value is None: + value = nx.max(C) + C /= float(value) elif norm == "log": C = nx.log(1 + C) elif norm == "loglog": @@ -393,7 +397,10 @@ def cost_normalization(C, norm=None): raise ValueError('Norm %s is not a valid option.\n' 'Valid options are:\n' 'median, max, log, loglog' % norm) - return C + if return_value: + return C, value + else: + return C def dots(*args): diff --git a/setup.py b/setup.py index 201e89c65..72b1488b2 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ #!/usr/bin/env python + import os import re import subprocess diff --git a/test/test_da.py b/test/test_da.py index 37b709473..0e51bda22 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -346,6 +346,26 @@ def test_sinkhorn_transport_class(nx): otda.fit(Xs=Xs, ys=ys, Xt=Xt) assert len(otda.log_.keys()) != 0 + # test diffeernt transform and inverse transform + otda = ot.da.SinkhornTransport(out_of_sample_map='ferradans') + transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt) + assert_equal(transp_Xs.shape, Xs.shape) + transp_Xt = otda.inverse_transform(Xt=Xt) + assert_equal(transp_Xt.shape, Xt.shape) + + # test diffeernt transform + otda = ot.da.SinkhornTransport(out_of_sample_map='continuous', method='sinkhorn') + transp_Xs2 = otda.fit_transform(Xs=Xs, Xt=Xt) + assert_equal(transp_Xs2.shape, Xs.shape) + transp_Xt2 = otda.inverse_transform(Xt=Xt) + assert_equal(transp_Xt2.shape, Xt.shape) + + np.testing.assert_almost_equal(nx.to_numpy(transp_Xs), nx.to_numpy(transp_Xs2), decimal=5) + np.testing.assert_almost_equal(nx.to_numpy(transp_Xt), nx.to_numpy(transp_Xt2), decimal=5) + + with pytest.raises(ValueError): + otda = ot.da.SinkhornTransport(out_of_sample_map='unknown') + @pytest.skip_backend("jax") @pytest.skip_backend("tf") From e75c9af61fa15f519c1d95f4d5b2115a3c808c13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 17 Apr 2024 08:01:37 +0200 Subject: [PATCH 11/30] [MRG] split gromov test file (#619) * split gromov test file * fix pep8 * remove 
old test file * updates --- RELEASES.md | 2 +- test/gromov/__init__.py | 0 test/gromov/test_bregman.py | 919 ++++++++++ test/gromov/test_dictionary.py | 529 ++++++ test/gromov/test_estimators.py | 110 ++ test/gromov/test_gw.py | 834 +++++++++ test/gromov/test_semirelaxed.py | 615 +++++++ test/test_gromov.py | 2963 ------------------------------- 8 files changed, 3008 insertions(+), 2964 deletions(-) create mode 100644 test/gromov/__init__.py create mode 100644 test/gromov/test_bregman.py create mode 100644 test/gromov/test_dictionary.py create mode 100644 test/gromov/test_estimators.py create mode 100644 test/gromov/test_gw.py create mode 100644 test/gromov/test_semirelaxed.py delete mode 100644 test/test_gromov.py diff --git a/RELEASES.md b/RELEASES.md index 9734ab21a..c7e3f598b 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -12,7 +12,7 @@ - Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534) - Fix a sign error regarding the gradient of `ot.gromov._gw.fused_gromov_wasserstein2` and `ot.gromov._gw.gromov_wasserstein2` for the kl loss (PR #610) - Fix same sign error for sr(F)GW conditional gradient solvers (PR #611) - +- Split `test/test_gromov.py` into `test/gromov/` (PR #619) ## 0.9.2 *December 2023* diff --git a/test/gromov/__init__.py b/test/gromov/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/gromov/test_bregman.py b/test/gromov/test_bregman.py new file mode 100644 index 000000000..4baf3ce10 --- /dev/null +++ b/test/gromov/test_bregman.py @@ -0,0 +1,919 @@ +""" Tests for gromov._bregman.py """ + +# Author: Rémi Flamary +# Titouan Vayer +# Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np +import pytest + +import ot + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +@pytest.mark.parametrize('loss_fun', [ + 'square_loss', + 'kl_loss', + pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), 
+]) +def test_entropic_gromov(nx, loss_fun): + n_samples = 10 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + + G, log = ot.gromov.entropic_gromov_wasserstein( + C1, C2, None, q, loss_fun, symmetric=None, G0=G0, + epsilon=1e-2, max_iter=10, verbose=True, log=True) + Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, None, loss_fun, symmetric=True, G0=None, + epsilon=1e-2, max_iter=10, verbose=True, log=False + )) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +@pytest.mark.parametrize('loss_fun', [ + 'square_loss', + 'kl_loss', + pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), +]) +def test_entropic_gromov2(nx, loss_fun): + n_samples = 10 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + + gw, log = ot.gromov.entropic_gromov_wasserstein2( + C1, C2, p, None, loss_fun, symmetric=True, G0=None, + max_iter=10, epsilon=1e-2, log=True) + gwb, logb = 
ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, None, qb, loss_fun, symmetric=None, G0=G0b, + max_iter=10, epsilon=1e-2, log=True) + gwb = nx.to_numpy(gwb) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + + np.testing.assert_allclose(gw, gwb, atol=1e-06) + np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_proximal_gromov(nx): + n_samples = 10 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + + with pytest.raises(ValueError): + loss_fun = 'weird_loss_fun' + G, log = ot.gromov.entropic_gromov_wasserstein( + C1, C2, None, q, loss_fun, symmetric=None, G0=G0, + epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=True, numItermax=1) + + G, log = ot.gromov.entropic_gromov_wasserstein( + C1, C2, None, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=True, numItermax=1) + Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None, + epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=False, numItermax=1 + )) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-02) # cf convergence gromov + + gw, log = 
ot.gromov.entropic_gromov_wasserstein2( + C1, C2, p, q, 'kl_loss', symmetric=True, G0=None, + max_iter=10, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + gwb, logb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=10, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + gwb = nx.to_numpy(gwb) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + + np.testing.assert_allclose(gw, gwb, atol=1e-06) + np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-02) # cf convergence gromov + + +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_asymmetric_entropic_gromov(nx): + n_samples = 10 # nb samples + rng = np.random.RandomState(0) + C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) + idx = np.arange(n_samples) + rng.shuffle(idx) + C2 = C1[idx, :][:, idx] + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + G = ot.gromov.entropic_gromov_wasserstein( + C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-1, max_iter=5, verbose=True, log=False) + Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None, + epsilon=1e-1, max_iter=5, verbose=True, log=False + )) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + gw = ot.gromov.entropic_gromov_wasserstein2( + C1, C2, None, None, 'kl_loss', symmetric=False, G0=None, + max_iter=5, epsilon=1e-1, log=False) + gwb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, 
'kl_loss', symmetric=None, G0=G0b, + max_iter=5, epsilon=1e-1, log=False) + gwb = nx.to_numpy(gwb) + + np.testing.assert_allclose(gw, gwb, atol=1e-06) + np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_gromov_dtype_device(nx): + # setup + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q, type_as=tp) + + for solver in ['PGD', 'PPA', 'BAPG']: + if solver == 'BAPG': + Gb = ot.gromov.BAPG_gromov_wasserstein( + C1b, C2b, pb, qb, max_iter=2, verbose=True) + gw_valb = ot.gromov.BAPG_gromov_wasserstein2( + C1b, C2b, pb, qb, max_iter=2, verbose=True) + else: + Gb = ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, qb, max_iter=2, solver=solver, verbose=True) + gw_valb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, max_iter=2, solver=solver, verbose=True) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) + + +def test_BAPG_gromov(nx): + n_samples = 10 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + + # complete test with marginal loss = True + marginal_loss = True + with pytest.raises(ValueError): + 
loss_fun = 'weird_loss_fun' + G, log = ot.gromov.BAPG_gromov_wasserstein( + C1, C2, None, q, loss_fun, symmetric=None, G0=G0, + epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, + verbose=True, log=True) + + G, log = ot.gromov.BAPG_gromov_wasserstein( + C1, C2, None, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, + verbose=True, log=True) + Gb = nx.to_numpy(ot.gromov.BAPG_gromov_wasserstein( + C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None, + epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True, + log=False + )) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-02) # cf convergence gromov + + with pytest.warns(UserWarning): + + gw = ot.gromov.BAPG_gromov_wasserstein2( + C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, + max_iter=10, epsilon=1e-2, marginal_loss=marginal_loss, log=False) + + gw, log = ot.gromov.BAPG_gromov_wasserstein2( + C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, + max_iter=10, epsilon=1., marginal_loss=marginal_loss, log=True) + gwb, logb = ot.gromov.BAPG_gromov_wasserstein2( + C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=10, epsilon=1., marginal_loss=marginal_loss, log=True) + gwb = nx.to_numpy(gwb) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + + np.testing.assert_allclose(gw, gwb, atol=1e-06) + np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-02) # cf convergence gromov + + marginal_loss = False + G, log = ot.gromov.BAPG_gromov_wasserstein( + C1, C2, None, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, + 
verbose=True, log=True) + Gb = nx.to_numpy(ot.gromov.BAPG_gromov_wasserstein( + C1b, C2b, pb, None, 'square_loss', symmetric=False, G0=None, + epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True, + log=False + )) + + +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_fgw(nx): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + rng = np.random.RandomState(42) + ys = rng.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + + with pytest.raises(ValueError): + loss_fun = 'weird_loss_fun' + G, log = ot.gromov.entropic_fused_gromov_wasserstein( + M, C1, C2, None, None, loss_fun, symmetric=None, G0=G0, + epsilon=1e-1, max_iter=10, verbose=True, log=True) + + G, log = ot.gromov.entropic_fused_gromov_wasserstein( + M, C1, C2, None, None, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-1, max_iter=10, verbose=True, log=True) + Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, + epsilon=1e-1, max_iter=10, verbose=True, log=False + )) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + fgw, log = ot.gromov.entropic_fused_gromov_wasserstein2( + M, C1, C2, p, q, 'kl_loss', symmetric=True, G0=None, + max_iter=10, epsilon=1e-1, log=True) + fgwb, logb = ot.gromov.entropic_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + 
max_iter=10, epsilon=1e-1, log=True) + fgwb = nx.to_numpy(fgwb) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + + np.testing.assert_allclose(fgw, fgwb, atol=1e-06) + np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_proximal_fgw(nx): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + rng = np.random.RandomState(42) + ys = rng.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + + G, log = ot.gromov.entropic_fused_gromov_wasserstein( + M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=True, numItermax=1) + Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, + epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=False, numItermax=1 + )) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + fgw, log = ot.gromov.entropic_fused_gromov_wasserstein2( + M, C1, C2, p, None, 'kl_loss', symmetric=True, G0=None, + max_iter=5, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + fgwb, logb = 
ot.gromov.entropic_fused_gromov_wasserstein2( + Mb, C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=5, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + fgwb = nx.to_numpy(fgwb) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + + np.testing.assert_allclose(fgw, fgwb, atol=1e-06) + np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + +def test_BAPG_fgw(nx): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + rng = np.random.RandomState(42) + ys = rng.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + + with pytest.raises(ValueError): + loss_fun = 'weird_loss_fun' + G, log = ot.gromov.BAPG_fused_gromov_wasserstein( + M, C1, C2, p, q, loss_fun=loss_fun, max_iter=1, log=True) + + # complete test with marginal loss = True + marginal_loss = True + + G, log = ot.gromov.BAPG_fused_gromov_wasserstein( + M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, log=True) + Gb = nx.to_numpy(ot.gromov.BAPG_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, + epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True)) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose( + q, 
Gb.sum(0), atol=1e-02) # cf convergence gromov + + with pytest.warns(UserWarning): + + fgw = ot.gromov.BAPG_fused_gromov_wasserstein2( + M, C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, + max_iter=10, epsilon=1e-3, marginal_loss=marginal_loss, log=False) + + fgw, log = ot.gromov.BAPG_fused_gromov_wasserstein2( + M, C1, C2, p, None, 'kl_loss', symmetric=True, G0=None, + max_iter=5, epsilon=1, marginal_loss=marginal_loss, log=True) + fgwb, logb = ot.gromov.BAPG_fused_gromov_wasserstein2( + Mb, C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=5, epsilon=1, marginal_loss=marginal_loss, log=True) + fgwb = nx.to_numpy(fgwb) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + + np.testing.assert_allclose(fgw, fgwb, atol=1e-06) + np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-02) # cf convergence gromov + + # Tests with marginal_loss = False + marginal_loss = False + G, log = ot.gromov.BAPG_fused_gromov_wasserstein( + M, C1, C2, p, q, 'square_loss', symmetric=False, G0=G0, + epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, log=True) + Gb = nx.to_numpy(ot.gromov.BAPG_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=None, G0=None, + epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True)) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-02) # cf convergence gromov + + +def test_asymmetric_entropic_fgw(nx): + n_samples = 5 # nb samples + rng = np.random.RandomState(0) + C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) + idx = np.arange(n_samples) + rng.shuffle(idx) + C2 = C1[idx, :][:, idx] + + ys = rng.randn(n_samples, 2) + yt = 
ys[idx, :] + M = ot.dist(ys, yt) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + G = ot.gromov.entropic_fused_gromov_wasserstein( + M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, + max_iter=5, epsilon=1e-1, verbose=True, log=False) + Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None, + max_iter=5, epsilon=1e-1, verbose=True, log=False + )) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + fgw = ot.gromov.entropic_fused_gromov_wasserstein2( + M, C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, + max_iter=5, epsilon=1e-1, log=False) + fgwb = ot.gromov.entropic_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=5, epsilon=1e-1, log=False) + fgwb = nx.to_numpy(fgwb) + + np.testing.assert_allclose(fgw, fgwb, atol=1e-06) + np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_fgw_dtype_device(nx): + # setup + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + rng = np.random.RandomState(42) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=rng) + + xt = xs[::-1].copy() + + ys = rng.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + Mb, C1b, C2b, pb, qb = nx.from_numpy(M, C1, C2, p, q, type_as=tp) + + for 
solver in ['PGD', 'PPA', 'BAPG']: + if solver == 'BAPG': + Gb = ot.gromov.BAPG_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, max_iter=2) + fgw_valb = ot.gromov.BAPG_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, qb, max_iter=2) + + else: + Gb = ot.gromov.entropic_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, max_iter=2, solver=solver) + fgw_valb = ot.gromov.entropic_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, qb, max_iter=2, solver=solver) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, fgw_valb) + + +def test_entropic_fgw_barycenter(nx): + ns = 5 + nt = 10 + + rng = np.random.RandomState(42) + Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + + ys = rng.randn(Xs.shape[0], 2) + yt = rng.randn(Xt.shape[0], 2) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + p1 = ot.unif(ns) + p2 = ot.unif(nt) + n_samples = 3 + p = ot.unif(n_samples) + + ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p) + + with pytest.raises(ValueError): + loss_fun = 'weird_loss_fun' + X, C, log = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ys, yt], [C1, C2], None, p, [.5, .5], loss_fun, 0.1, + max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42, + solver='PPA', numItermax=10, log=True, symmetric=True, + ) + with pytest.raises(ValueError): + stop_criterion = 'unknown stop criterion' + X, C, log = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ys, yt], [C1, C2], None, p, [.5, .5], 'square_loss', + 0.1, max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + verbose=True, warmstartT=True, random_state=42, + solver='PPA', numItermax=10, log=True, symmetric=True, + ) + + for stop_criterion in ['barycenter', 'loss']: + X, C, log = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ys, yt], [C1, C2], None, p, [.5, .5], 'square_loss', + epsilon=0.1, max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + 
verbose=True, warmstartT=True, random_state=42, solver='PPA', + numItermax=10, log=True, symmetric=True + ) + Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], None, [.5, .5], + 'square_loss', epsilon=0.1, max_iter=10, tol=1e-3, + stop_criterion=stop_criterion, verbose=False, warmstartT=True, + random_state=42, solver='PPA', numItermax=10, log=False, symmetric=True) + Xb, Cb = nx.to_numpy(Xb, Cb) + + np.testing.assert_allclose(C, Cb, atol=1e-06) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X, Xb, atol=1e-06) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) + + # test with 'kl_loss' and log=True + # providing init_C, init_Y + generator = ot.utils.check_random_state(42) + xalea = generator.randn(n_samples, 2) + init_C = ot.utils.dist(xalea, xalea) + init_C /= init_C.max() + init_Cb = nx.from_numpy(init_C) + + init_Y = np.zeros((n_samples, ys.shape[1]), dtype=ys.dtype) + init_Yb = nx.from_numpy(init_Y) + + X, C, log = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], p, None, 'kl_loss', 0.1, True, + max_iter=10, tol=1e-3, verbose=False, warmstartT=False, random_state=42, + solver='PPA', numItermax=1, init_C=init_C, init_Y=init_Y, log=True + ) + Xb, Cb, logb = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'kl_loss', + 0.1, True, max_iter=10, tol=1e-3, verbose=False, warmstartT=False, + random_state=42, solver='PPA', numItermax=1, init_C=init_Cb, + init_Y=init_Yb, log=True) + Xb, Cb = nx.to_numpy(Xb, Cb) + + np.testing.assert_allclose(C, Cb, atol=1e-06) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X, Xb, atol=1e-06) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) + np.testing.assert_array_almost_equal(log['err_feature'], nx.to_numpy(*logb['err_feature'])) + 
np.testing.assert_array_almost_equal(log['err_structure'], nx.to_numpy(*logb['err_structure'])) + + # add tests with fixed_structures or fixed_features + init_C = ot.utils.dist(xalea, xalea) + init_C /= init_C.max() + init_Cb = nx.from_numpy(init_C) + + init_Y = np.zeros((n_samples, ys.shape[1]), dtype=ys.dtype) + init_Yb = nx.from_numpy(init_Y) + + fixed_structure, fixed_features = True, False + with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_structure=True`and `init_C=None` + Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, + fixed_structure=fixed_structure, init_C=None, + fixed_features=fixed_features, p=None, max_iter=10, tol=1e-3 + ) + + Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, + fixed_structure=fixed_structure, init_C=init_Cb, + fixed_features=fixed_features, max_iter=10, tol=1e-3 + ) + Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(Cb, init_Cb) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) + + fixed_structure, fixed_features = False, True + with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_features=True`and `init_X=None` + Xb, Cb, logb = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], lambdas=[.5, .5], + fixed_structure=fixed_structure, fixed_features=fixed_features, + init_Y=None, p=pb, max_iter=10, tol=1e-3, + warmstartT=True, log=True, random_state=98765, verbose=True + ) + Xb, Cb, logb = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], lambdas=[.5, .5], + fixed_structure=fixed_structure, fixed_features=fixed_features, + init_Y=init_Yb, p=pb, max_iter=10, tol=1e-3, + warmstartT=True, log=True, random_state=98765, verbose=True + ) + + X, C = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(C.shape, 
(n_samples, n_samples)) + np.testing.assert_allclose(Xb, init_Yb) + + +@pytest.mark.filterwarnings("ignore:divide") +def test_gromov_entropic_barycenter(nx): + ns = 5 + nt = 10 + + Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + p1 = ot.unif(ns) + p2 = ot.unif(nt) + n_samples = 2 + p = ot.unif(n_samples) + + C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p) + + with pytest.raises(ValueError): + loss_fun = 'weird_loss_fun' + Cb = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], None, p, [.5, .5], loss_fun, 1e-3, + max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42 + ) + with pytest.raises(ValueError): + stop_criterion = 'unknown stop criterion' + Cb = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', 1e-3, + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + verbose=True, warmstartT=True, random_state=42 + ) + + Cb = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', 1e-3, + max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42 + ) + Cbb = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', 1e-3, + max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42 + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + # test of entropic_gromov_barycenters with `log` on + for stop_criterion in ['barycenter', 'loss']: + Cb_, err_ = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, None, 'square_loss', 1e-3, + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, verbose=True, + random_state=42, log=True + ) + Cbb_, errb_ = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], 
'square_loss', + 1e-3, max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + verbose=True, random_state=42, log=True + ) + Cbb_ = nx.to_numpy(Cbb_) + np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) + np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) + np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) + + Cb2 = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'kl_loss', 1e-3, max_iter=10, tol=1e-3, random_state=42 + ) + Cb2b = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'kl_loss', 1e-3, max_iter=10, tol=1e-3, random_state=42 + )) + np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) + np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) + + # test of entropic_gromov_barycenters with `log` on + # providing init_C + generator = ot.utils.check_random_state(42) + xalea = generator.randn(n_samples, 2) + init_C = ot.utils.dist(xalea, xalea) + init_C /= init_C.max() + init_Cb = nx.from_numpy(init_C) + + Cb2_, err2_ = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], 'kl_loss', 1e-3, + max_iter=10, tol=1e-3, warmstartT=True, verbose=True, random_state=42, + init_C=init_C, log=True + ) + Cb2b_, err2b_ = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'kl_loss', 1e-3, max_iter=10, tol=1e-3, warmstartT=True, verbose=True, + random_state=42, init_Cb=init_Cb, log=True + ) + Cb2b_ = nx.to_numpy(Cb2b_) + np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) + np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) + np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) + + +def test_not_implemented_solver(): + # test sinkhorn + n_samples = 5 # nb samples + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + rng = np.random.RandomState(42) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, 
random_state=rng) + xt = xs[::-1].copy() + ys = rng.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + M = ot.dist(ys, yt) + + solver = 'not_implemented' + # entropic gw and fgw + with pytest.raises(ValueError): + ot.gromov.entropic_gromov_wasserstein( + C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver) + with pytest.raises(ValueError): + ot.gromov.entropic_fused_gromov_wasserstein( + M, C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver) diff --git a/test/gromov/test_dictionary.py b/test/gromov/test_dictionary.py new file mode 100644 index 000000000..5b73b0d07 --- /dev/null +++ b/test/gromov/test_dictionary.py @@ -0,0 +1,529 @@ +""" Tests for gromov._dictionary.py """ + +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np + +import ot + + +def test_gromov_wasserstein_linear_unmixing(nx): + n = 4 + + X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) + + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cdict = np.stack([C1, C2]) + p = ot.unif(n) + + C1b, C2b, Cdictb, pb = nx.from_numpy(C1, C2, Cdict, p) + + tol = 10**(-5) + # Tests without regularization + reg = 0. 
+ unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing( + C1, Cdict, reg=reg, p=p, q=p, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing( + C1b, Cdictb, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing( + C2, Cdict, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing( + C2b, Cdictb, reg=reg, p=pb, q=pb, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=5e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=5e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=5e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=5e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) + np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) + np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + # Tests with regularization + + reg = 0.001 + unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing( + C1, Cdict, reg=reg, p=p, q=p, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing( + C1b, Cdictb, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, OT, 
reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing( + C2, Cdict, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing( + C2b, Cdictb, reg=reg, p=pb, q=pb, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) + np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) + np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + +def test_gromov_wasserstein_dictionary_learning(nx): + + # create dataset composed from 2 structures which are repeated 5 times + shape = 4 + n_samples = 2 + n_atoms = 2 + projection = 'nonnegative_symmetric' + X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42) + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)] + ps = [ot.unif(shape) for _ in range(n_samples)] + q = ot.unif(shape) + + # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape) + # following the same procedure than implemented in gromov_wasserstein_dictionary_learning. 
+ dataset_means = [C.mean() for C in Cs] + rng = np.random.RandomState(0) + Cdict_init = rng.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(n_atoms, shape, shape)) + + if projection == 'nonnegative_symmetric': + Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) + Cdict_init[Cdict_init < 0.] = 0. + + Csb = nx.from_numpy(*Cs) + psb = nx.from_numpy(*ps) + qb, Cdict_initb = nx.from_numpy(q, Cdict_init) + + # Test: compare reconstruction error using initial dictionary and dictionary learned using this initialization + # > Compute initial reconstruction of samples on this random dictionary without backend + use_adam_optimizer = True + verbose = False + tol = 10**(-5) + epochs = 1 + + initial_total_reconstruction = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict_init, p=ps[i], q=q, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + ) + initial_total_reconstruction += reconstruction + + # > Learn the dictionary using this init + Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, + epochs=epochs, batch_size=2 * n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary without backend + total_reconstruction = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict, p=None, q=None, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + ) + total_reconstruction += reconstruction + + np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction) + + # Test: Perform same experiments after going through backend + + Cdictb, log = 
ot.gromov.gromov_wasserstein_dictionary_learning( + Csb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, + epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # Compute reconstruction of samples on learned dictionary + total_reconstruction_b = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Csb[i], Cdictb, p=psb[i], q=qb, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + ) + total_reconstruction_b += reconstruction + + total_reconstruction_b = nx.to_numpy(total_reconstruction_b) + np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) + np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) + np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) + np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03) + + # Test: Perform same comparison without providing the initial dictionary being an optional input + # knowing than the initialization scheme is the same than implemented to set the benchmarked initialization. 
+ Cdict_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Cs, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose, + random_state=0 + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict_bis, p=ps[i], q=q, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + ) + total_reconstruction_bis += reconstruction + + np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05) + + # Test: Same after going through backend + Cdictb_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, + verbose=verbose, random_state=0 + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Csb[i], Cdictb_bis, p=None, q=None, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + ) + total_reconstruction_b_bis += reconstruction + + total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis) + np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) + np.testing.assert_allclose(Cdict_bis, nx.to_numpy(Cdictb_bis), atol=1e-03) + + # Test: Perform same comparison without providing the initial dictionary being an optional input + # and testing other optimization 
settings untested until now. + # We pass previously estimated dictionaries to speed up the process. + use_adam_optimizer = False + verbose = True + use_log = True + + Cdict_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, + epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, + verbose=verbose, random_state=0, + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis2 = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict_bis2, p=ps[i], q=q, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + ) + total_reconstruction_bis2 += reconstruction + + np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction) + + # Test: Same after going through backend + Cdictb_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=Cdictb, + epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, + verbose=verbose, random_state=0, + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis2 = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Csb[i], Cdictb_bis2, p=psb[i], q=qb, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + ) + total_reconstruction_b_bis2 += reconstruction + + total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2) + np.testing.assert_allclose(total_reconstruction_b_bis2, total_reconstruction_bis2, atol=1e-05) + + +def 
test_fused_gromov_wasserstein_linear_unmixing(nx): + + n = 4 + X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) + F, y = ot.datasets.make_data_classif('3gauss', n, random_state=42) + + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cdict = np.stack([C1, C2]) + Ydict = np.stack([F, F]) + p = ot.unif(n) + + C1b, C2b, Fb, Cdictb, Ydictb, pb = nx.from_numpy(C1, C2, F, Cdict, Ydict, p) + + # Tests without regularization + reg = 0. + + unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + ) + + unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + ) + + unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + ) + + unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=4e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=4e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=4e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=4e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) + np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), 
atol=1e-03) + np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) + np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) + np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + # Tests with regularization + reg = 0.001 + + unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + ) + + unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + ) + + unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + ) + + unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) + np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) + np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) + 
np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) + np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + +def test_fused_gromov_wasserstein_dictionary_learning(nx): + + # create dataset composed from 2 structures which are repeated 5 times + shape = 4 + n_samples = 2 + n_atoms = 2 + projection = 'nonnegative_symmetric' + X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42) + F, y = ot.datasets.make_data_classif('3gauss', shape, random_state=42) + + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)] + Ys = [F.copy() for _ in range(n_samples)] + ps = [ot.unif(shape) for _ in range(n_samples)] + q = ot.unif(shape) + + # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape) + # following the same procedure than implemented in gromov_wasserstein_dictionary_learning. + dataset_structure_means = [C.mean() for C in Cs] + rng = np.random.RandomState(0) + Cdict_init = rng.normal(loc=np.mean(dataset_structure_means), scale=np.std(dataset_structure_means), size=(n_atoms, shape, shape)) + if projection == 'nonnegative_symmetric': + Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) + Cdict_init[Cdict_init < 0.] = 0. 
+ dataset_feature_means = np.stack([Y.mean(axis=0) for Y in Ys]) + Ydict_init = rng.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(n_atoms, shape, 2)) + + Csb = nx.from_numpy(*Cs) + Ysb = nx.from_numpy(*Ys) + psb = nx.from_numpy(*ps) + qb, Cdict_initb, Ydict_initb = nx.from_numpy(q, Cdict_init, Ydict_init) + + # Test: Compute initial reconstruction of samples on this random dictionary + alpha = 0.5 + use_adam_optimizer = True + verbose = False + tol = 1e-05 + epochs = 1 + + initial_total_reconstruction = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict_init, Ydict_init, p=ps[i], q=q, + alpha=alpha, reg=0., tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + ) + initial_total_reconstruction += reconstruction + + # > Learn a dictionary using this given initialization and check that the reconstruction loss + # on the learned dictionary is lower than the one using its initialization. 
+ Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, Ydict_init=Ydict_init, + epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict, Ydict, p=None, q=None, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + ) + total_reconstruction += reconstruction + # Compare both + np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction) + + # Test: Perform same experiments after going through backend + Cdictb, Ydictb, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, Ydict_init=Ydict_initb, + epochs=epochs, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose, + random_state=0 + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Csb[i], Ysb[i], Cdictb, Ydictb, p=psb[i], q=qb, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + ) + total_reconstruction_b += reconstruction + + total_reconstruction_b = nx.to_numpy(total_reconstruction_b) + np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) + 
np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) + np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03) + np.testing.assert_allclose(Ydict, nx.to_numpy(Ydictb), atol=1e-03) + + # Test: Perform similar experiment without providing the initial dictionary being an optional input + Cdict_bis, Ydict_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, Ys, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose, + random_state=0 + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict_bis, Ydict_bis, p=ps[i], q=q, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + ) + total_reconstruction_bis += reconstruction + + np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05) + + # > Same after going through backend + Cdictb_bis, Ydictb_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose, + random_state=0, + ) + + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Csb[i], Ysb[i], 
Cdictb_bis, Ydictb_bis, p=psb[i], q=qb, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + ) + total_reconstruction_b_bis += reconstruction + + total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis) + np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) + + # Test: without using adam optimizer, with log and verbose set to True + use_adam_optimizer = False + verbose = True + use_log = True + + # > Experiment providing previously estimated dictionary to speed up the test compared to providing initial random init. + Cdict_bis2, Ydict_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, Ydict_init=Ydict, + epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, + verbose=verbose, random_state=0, + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis2 = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict_bis2, Ydict_bis2, p=ps[i], q=q, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + ) + total_reconstruction_bis2 += reconstruction + + np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction) + + # > Same after going through backend + Cdictb_bis2, Ydictb_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdictb, Ydict_init=Ydictb, + epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, + projection=projection, use_log=use_log, 
use_adam_optimizer=use_adam_optimizer, verbose=verbose, + random_state=0, + ) + + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis2 = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Csb[i], Ysb[i], Cdictb_bis2, Ydictb_bis2, p=None, q=None, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + ) + total_reconstruction_b_bis2 += reconstruction + + # > Compare results with/without backend + total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2) + np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05) diff --git a/test/gromov/test_estimators.py b/test/gromov/test_estimators.py new file mode 100644 index 000000000..ead427204 --- /dev/null +++ b/test/gromov/test_estimators.py @@ -0,0 +1,110 @@ +""" Tests for gromov._estimators.py """ + +# Author: Rémi Flamary +# Tanguy Kerdoncuff +# Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np +import pytest + +import ot +from ot.backend import NumpyBackend + + +def test_pointwise_gromov(nx): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) + + def loss(x, y): + return np.abs(x - y) + + def lossb(x, y): + return nx.abs(x - y) + + G, log = ot.gromov.pointwise_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42) + G = NumpyBackend().todense(G) + Gb, logb = ot.gromov.pointwise_gromov_wasserstein( + C1b, C2b, pb, qb, lossb, max_iter=100, log=True, verbose=True, random_state=42) + Gb = nx.to_numpy(nx.todense(Gb)) + + # check constraints + 
np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.0, atol=1e-08) + np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0, atol=1e-08) + + G, log = ot.gromov.pointwise_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + G = NumpyBackend().todense(G) + Gb, logb = ot.gromov.pointwise_gromov_wasserstein( + C1b, C2b, pb, qb, lossb, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + Gb = nx.to_numpy(nx.todense(Gb)) + + np.testing.assert_allclose(G, Gb, atol=1e-06) + + +@pytest.skip_backend("tf", reason="test very slow with tf backend") +@pytest.skip_backend("jax", reason="test very slow with jax backend") +def test_sampled_gromov(nx): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0], dtype=np.float64) + cov_s = np.array([[1, 0], [0, 1]], dtype=np.float64) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) + + def loss(x, y): + return np.abs(x - y) + + def lossb(x, y): + return nx.abs(x - y) + + G, log = ot.gromov.sampled_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=20, nb_samples_grad=2, epsilon=1, log=True, verbose=True, random_state=42) + Gb, logb = ot.gromov.sampled_gromov_wasserstein( + C1b, C2b, pb, qb, lossb, max_iter=20, nb_samples_grad=2, epsilon=1, log=True, verbose=True, random_state=42) + Gb = nx.to_numpy(Gb) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), 
atol=1e-04) # cf convergence gromov diff --git a/test/gromov/test_gw.py b/test/gromov/test_gw.py new file mode 100644 index 000000000..0008cebce --- /dev/null +++ b/test/gromov/test_gw.py @@ -0,0 +1,834 @@ +""" Tests for gromov._gw.py """ + +# Author: Erwan Vautier +# Nicolas Courty +# Titouan Vayer +# Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np +import pytest +import warnings + +import ot +from ot.backend import torch, tf + + +def test_gromov(nx): + n_samples = 20 # nb samples + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1) + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + + G = ot.gromov.gromov_wasserstein( + C1, C2, None, q, 'square_loss', G0=G0, verbose=True, + alpha_min=0., alpha_max=1.) 
+ Gb = nx.to_numpy(ot.gromov.gromov_wasserstein( + C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=G0b, verbose=True)) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) + + np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04) + for armijo in [False, True]: + gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=armijo, log=True) + gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=armijo, log=True) + gwb = nx.to_numpy(gwb) + + gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=armijo, G0=G0, log=False) + gw_valb = nx.to_numpy( + ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=armijo, G0=G0b, log=False) + ) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + + np.testing.assert_allclose(gw, gwb, atol=1e-06) + np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1) + + np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06) + np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + +def test_asymmetric_gromov(nx): + n_samples = 20 # nb samples + rng = np.random.RandomState(0) + C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) + idx = np.arange(n_samples) + rng.shuffle(idx) + C2 = C1[idx, :][:, idx] + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + + G, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True) + 
Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True) + Gb = nx.to_numpy(Gb) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04) + + gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True) + gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04) + + +def test_gromov_integer_warnings(nx): + n_samples = 10 # nb samples + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1) + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + C1 = C1.astype(np.int32) + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + + G = ot.gromov.gromov_wasserstein( + C1, C2, None, q, 'square_loss', G0=G0, verbose=True, + alpha_min=0., alpha_max=1.) 
+ Gb = nx.to_numpy(ot.gromov.gromov_wasserstein( + C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=G0b, verbose=True)) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(G, 0., atol=1e-09) + + +def test_gromov_dtype_device(nx): + # setup + n_samples = 20 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp) + + with warnings.catch_warnings(): + warnings.filterwarnings('error') + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) + + +@pytest.mark.skipif(not tf, reason="tf not installed") +def test_gromov_device_tf(): + nx = ot.backend.TensorflowBackend() + n_samples = 20 # nb samples + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + C1 /= C1.max() + C2 /= C2.max() + + # Check that everything stays on the CPU + with tf.device("/CPU:0"): + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) + nx.assert_same_dtype_device(C1b, Gb) + 
nx.assert_same_dtype_device(C1b, gw_valb) + + if len(tf.config.list_physical_devices('GPU')) > 0: + # Check that everything happens on the GPU + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=False) + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) + assert nx.dtype_device(Gb)[1].startswith("GPU") + + +def test_gromov2_gradients(): + n_samples = 20 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + + # classical gradients + p1 = torch.tensor(p, requires_grad=True, device=device) + q1 = torch.tensor(q, requires_grad=True, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + + # Test with exact line-search + val = ot.gromov_wasserstein2(C11, C12, p1, q1) + + val.backward() + + assert val.device == p1.device + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + + # Test with armijo line-search + # classical gradients + p1 = torch.tensor(p, requires_grad=True, device=device) + q1 = torch.tensor(q, requires_grad=True, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + + q1.grad = None + p1.grad = None + C11.grad = None + C12.grad = None 
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1, armijo=True) + + val.backward() + + assert val.device == p1.device + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + + +def test_gw_helper_backend(nx): + n_samples = 10 # nb samples + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', armijo=False, symmetric=True, G0=G0b, log=True) + + # calls with nx=None + constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss') + + def f(G): + return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None) + + def df(G): + return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=0., reg=1., nx=None) + # feed the precomputed local optimum Gb to cg + res, log = ot.optim.cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + # check constraints + np.testing.assert_allclose(res, Gb, atol=1e-06) + + +@pytest.mark.parametrize('loss_fun', [ + 'square_loss', + 'kl_loss', + pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), +]) +def test_gw_helper_validation(loss_fun): + n_samples = 10 # nb samples + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + p = ot.unif(n_samples) + 
q = ot.unif(n_samples) + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + ot.gromov.init_matrix(C1, C2, p, q, loss_fun=loss_fun) + + +def test_gromov_barycenter(nx): + ns = 5 + nt = 8 + + Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + p1 = ot.unif(ns) + p2 = ot.unif(nt) + n_samples = 3 + p = ot.unif(n_samples) + + C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p) + with pytest.raises(ValueError): + stop_criterion = 'unknown stop criterion' + Cb = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', max_iter=10, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42 + ) + + for stop_criterion in ['barycenter', 'loss']: + Cb = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', max_iter=10, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42 + ) + Cbb = nx.to_numpy(ot.gromov.gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, random_state=42 + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + # test of gromov_barycenters with `log` on + Cb_, err_ = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, None, 'square_loss', max_iter=10, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + warmstartT=True, random_state=42, log=True + ) + Cbb_, errb_ = ot.gromov.gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, warmstartT=True, random_state=42, log=True + ) + Cbb_ = nx.to_numpy(Cbb_) + np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) + np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) + 
np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) + + Cb2 = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'kl_loss', max_iter=10, tol=1e-3, warmstartT=True, random_state=42 + ) + Cb2b = nx.to_numpy(ot.gromov.gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'kl_loss', max_iter=10, tol=1e-3, warmstartT=True, random_state=42 + )) + np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) + np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) + + # test of gromov_barycenters with `log` on + # providing init_C + generator = ot.utils.check_random_state(42) + xalea = generator.randn(n_samples, 2) + init_C = ot.utils.dist(xalea, xalea) + init_C /= init_C.max() + init_Cb = nx.from_numpy(init_C) + + Cb2_, err2_ = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], 'kl_loss', max_iter=10, + tol=1e-3, verbose=False, random_state=42, log=True, init_C=init_C + ) + Cb2b_, err2b_ = ot.gromov.gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'kl_loss', + max_iter=10, tol=1e-3, verbose=True, random_state=42, + init_C=init_Cb, log=True + ) + Cb2b_ = nx.to_numpy(Cb2b_) + np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) + np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) + np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) + + +def test_fgw(nx): + n_samples = 20 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + rng = np.random.RandomState(42) + ys = rng.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + M /= M.max() + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + + G, log = 
ot.gromov.fused_gromov_wasserstein(M, C1, C2, None, q, 'square_loss', alpha=0.5, armijo=True, symmetric=None, G0=G0, log=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, None, 'square_loss', alpha=0.5, armijo=True, symmetric=True, G0=G0b, log=True) + Gb = nx.to_numpy(Gb) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence fgw + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence fgw + + Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) + + np.testing.assert_allclose( + Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov + + fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, None, 'square_loss', armijo=True, symmetric=True, G0=None, alpha=0.5, log=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, None, qb, 'square_loss', armijo=True, symmetric=None, G0=G0b, alpha=0.5, log=True) + fgwb = nx.to_numpy(fgwb) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + + np.testing.assert_allclose(fgw, fgwb, atol=1e-08) + np.testing.assert_allclose(fgwb, 0, atol=1e-1, rtol=1e-1) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + +def test_asymmetric_fgw(nx): + n_samples = 20 # nb samples + rng = np.random.RandomState(0) + C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) + idx = np.arange(n_samples) + rng.shuffle(idx) + C2 = C1[idx, :][:, idx] + + # add features + F1 = rng.uniform(low=0., high=10, size=(n_samples, 1)) + F2 = F1[idx, :] + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + M = ot.dist(F1, F2) + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + + G, log = ot.gromov.fused_gromov_wasserstein( + M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, 
log=True, + symmetric=False, verbose=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, + symmetric=None, G0=G0b, verbose=True) + Gb = nx.to_numpy(Gb) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + + fgw, log = ot.gromov.fused_gromov_wasserstein2( + M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, + symmetric=None, verbose=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, + symmetric=False, G0=G0b, verbose=True) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + + # Tests with kl-loss: + for armijo in [False, True]: + G, log = ot.gromov.fused_gromov_wasserstein( + M, C1, C2, p, q, 'kl_loss', alpha=0.5, armijo=armijo, G0=G0, + log=True, symmetric=False, verbose=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, armijo=armijo, + log=True, symmetric=None, G0=G0b, verbose=True) + Gb = nx.to_numpy(Gb) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) + 
np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + + fgw, log = ot.gromov.fused_gromov_wasserstein2( + M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, + symmetric=None, verbose=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, + symmetric=False, G0=G0b, verbose=True) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + + +def test_fgw_integer_warnings(nx): + n_samples = 20 # nb samples + rng = np.random.RandomState(0) + C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) + idx = np.arange(n_samples) + rng.shuffle(idx) + C2 = C1[idx, :][:, idx] + + # add features + F1 = rng.uniform(low=0., high=10, size=(n_samples, 1)) + F2 = F1[idx, :] + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + M = ot.dist(F1, F2).astype(np.int32) + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + + G, log = ot.gromov.fused_gromov_wasserstein( + M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, + symmetric=False, verbose=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, + symmetric=None, G0=G0b, verbose=True) + Gb = nx.to_numpy(Gb) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(G, 0., atol=1e-06) + + +def test_fgw2_gradients(): + n_samples = 20 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, 
random_state=5) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + M = ot.dist(xs, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + p1 = torch.tensor(p, requires_grad=True, device=device) + q1 = torch.tensor(q, requires_grad=True, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) + + val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1) + + val.backward() + + assert val.device == p1.device + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == M1.grad.shape + + # full gradients with alpha + p1 = torch.tensor(p, requires_grad=True, device=device) + q1 = torch.tensor(q, requires_grad=True, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) + alpha = torch.tensor(0.5, requires_grad=True, device=device) + + val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1, alpha=alpha) + + val.backward() + + assert val.device == p1.device + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert alpha.shape == alpha.grad.shape + + +def test_fgw_helper_backend(nx): + n_samples = 20 # nb samples + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + + rng = np.random.RandomState(42) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + ys = rng.randn(xs.shape[0], 2) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + yt = rng.randn(xt.shape[0], 2) + + 
p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + M /= M.max() + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + alpha = 0.5 + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True) + + # calls with nx=None + constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss') + + def f(G): + return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None) + + def df(G): + return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=(1 - alpha) * Mb, reg=alpha, nx=None) + # feed the precomputed local optimum Gb to cg + res, log = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=None) + # feed the precomputed local optimum Gb to cg + res_armijo, log_armijo = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + # check constraints + np.testing.assert_allclose(res, Gb, atol=1e-06) + np.testing.assert_allclose(res_armijo, Gb, atol=1e-06) + + +def test_fgw_barycenter(nx): + ns = 10 + nt = 20 + + Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + + rng = np.random.RandomState(42) + ys = rng.randn(Xs.shape[0], 2) + yt = rng.randn(Xt.shape[0], 2) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + C1 /= C1.max() + C2 /= C2.max() + + p1, p2 = ot.unif(ns), ot.unif(nt) + n_samples = 3 + p = ot.unif(n_samples) + + ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, 
C1, C2, p1, p2, p) + lambdas = [.5, .5] + Csb = [C1b, C2b] + Ysb = [ysb, ytb] + Xb, Cb, logb = ot.gromov.fgw_barycenters( + n_samples, Ysb, Csb, None, lambdas, 0.5, fixed_structure=False, + fixed_features=False, p=pb, loss_fun='square_loss', max_iter=10, tol=1e-3, + random_state=12345, log=True + ) + # test correspondance with utils function + recovered_Cb = ot.gromov.update_square_loss(pb, lambdas, logb['Ts_iter'][-1], Csb) + recovered_Xb = ot.gromov.update_feature_matrix(lambdas, [y.T for y in Ysb], logb['Ts_iter'][-1], pb).T + + np.testing.assert_allclose(Cb, recovered_Cb) + np.testing.assert_allclose(Xb, recovered_Xb) + + xalea = rng.randn(n_samples, 2) + init_C = ot.dist(xalea, xalea) + init_C /= init_C.max() + init_Cb = nx.from_numpy(init_C) + + with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_structure=True`and `init_C=None` + Xb, Cb = ot.gromov.fgw_barycenters( + n_samples, Ysb, Csb, ps=[p1b, p2b], lambdas=None, + alpha=0.5, fixed_structure=True, init_C=None, fixed_features=False, + p=None, loss_fun='square_loss', max_iter=10, tol=1e-3 + ) + + Xb, Cb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, + alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False, + p=None, loss_fun='square_loss', max_iter=10, tol=1e-3 + ) + Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) + + init_X = rng.randn(n_samples, ys.shape[1]) + init_Xb = nx.from_numpy(init_X) + + with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_features=True`and `init_X=None` + Xb, Cb, logb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=None, + p=pb, loss_fun='square_loss', max_iter=10, tol=1e-3, + warmstartT=True, log=True, random_state=98765, verbose=True + ) + Xb, 
Cb, logb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_Xb, + p=pb, loss_fun='square_loss', max_iter=10, tol=1e-3, + warmstartT=True, log=True, random_state=98765, verbose=True + ) + + X, C = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + + # add test with 'kl_loss' + with pytest.raises(ValueError): + stop_criterion = 'unknown stop criterion' + X, C, log = ot.gromov.fgw_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', + max_iter=100, tol=1e-3, stop_criterion=stop_criterion, init_C=C, + init_X=X, warmstartT=True, random_state=12345, log=True + ) + + for stop_criterion in ['barycenter', 'loss']: + X, C, log = ot.gromov.fgw_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', + max_iter=100, tol=1e-3, stop_criterion=stop_criterion, init_C=C, + init_X=X, warmstartT=True, random_state=12345, log=True, verbose=True + ) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + + # test correspondance with utils function + recovered_C = ot.gromov.update_kl_loss(p, lambdas, log['T'], [C1, C2]) + np.testing.assert_allclose(C, recovered_C) diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py new file mode 100644 index 000000000..6f23a6b62 --- /dev/null +++ b/test/gromov/test_semirelaxed.py @@ -0,0 +1,615 @@ +""" Tests for gromov._semirelaxed.py """ + +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np +import pytest + +import ot +from ot.backend import torch + + +def test_semirelaxed_gromov(nx): + rng = np.random.RandomState(0) + # unbalanced proportions + 
list_n = [30, 15] + nt = 2 + ns = np.sum(list_n) + # create directed sbm with C2 as connectivity matrix + C1 = np.zeros((ns, ns), dtype=np.float64) + C2 = np.array([[0.8, 0.05], + [0.05, 1.]], dtype=np.float64) + for i in range(nt): + for j in range(nt): + ni, nj = list_n[i], list_n[j] + xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) + C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + p = ot.unif(ns, type_as=C1) + q0 = ot.unif(C2.shape[0], type_as=C1) + G0 = p[:, None] * q0[None, :] + # asymmetric + C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) + + for loss_fun in ['square_loss', 'kl_loss']: + G, log = ot.gromov.semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0) + Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein( + C1b, C2b, None, loss_fun='square_loss', symmetric=False, log=True, + G0=None, alpha_min=0., alpha_max=1.) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) + np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + + srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2( + C1, C2, None, loss_fun='square_loss', symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2( + C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + + # symmetric + C1 = 0.5 * (C1 + C1.T) + C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, 
q0, G0) + + G, log = ot.gromov.semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None) + Gb = ot.gromov.semirelaxed_gromov_wasserstein( + C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2( + C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2( + C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) + + srgw_ = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=False, G0=G0) + + G = log2['T'] + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + + np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + + +def test_semirelaxed_gromov2_gradients(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in 
devices: + for loss_fun in ['square_loss', 'kl_loss']: + # semirelaxed solvers do not support gradients over masses yet. + p1 = torch.tensor(p, requires_grad=False, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + + val = ot.gromov.semirelaxed_gromov_wasserstein2(C11, C12, p1, loss_fun=loss_fun) + + val.backward() + + assert val.device == p1.device + assert p1.grad is None + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + + +def test_srgw_helper_backend(nx): + n_samples = 20 # nb samples + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + for loss_fun in ['square_loss', 'kl_loss']: + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) + Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun, armijo=False, symmetric=True, G0=None, log=True) + + # calls with nx=None + constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun) + ones_pb = nx.ones(pb.shape[0], type_as=pb) + + def f(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def df(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.gromov.solve_semirelaxed_gromov_linesearch( + G, deltaG, cost_G, hC1b, hC2b, ones_pb, 0., 1., fC2t=fC2tb, nx=None) + # feed the precomputed local optimum Gb to semirelaxed_cg + res, log = ot.optim.semirelaxed_cg(pb, qb, 0., 1., f, df, Gb, line_search, 
log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + # check constraints + np.testing.assert_allclose(res, Gb, atol=1e-06) + + +@pytest.mark.parametrize('loss_fun', [ + 'square_loss', 'kl_loss', + pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), +]) +def test_gw_semirelaxed_helper_validation(loss_fun): + n_samples = 20 # nb samples + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + p = ot.unif(n_samples) + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + ot.gromov.init_matrix_semirelaxed(C1, C2, p, loss_fun=loss_fun) + + +def test_semirelaxed_fgw(nx): + rng = np.random.RandomState(0) + list_n = [16, 8] + nt = 2 + ns = 24 + # create directed sbm with C2 as connectivity matrix + C1 = np.zeros((ns, ns)) + C2 = np.array([[0.7, 0.05], + [0.05, 0.9]]) + for i in range(nt): + for j in range(nt): + ni, nj = list_n[i], list_n[j] + xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) + C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + F1 = np.zeros((ns, 1)) + F1[:16] = rng.normal(loc=0., scale=0.01, size=(16, 1)) + F1[16:] = rng.normal(loc=1., scale=0.01, size=(8, 1)) + F2 = np.zeros((2, 1)) + F2[1, :] = 1. 
+ M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T) + + p = ot.unif(ns) + q0 = ot.unif(C2.shape[0]) + G0 = p[:, None] * q0[None, :] + + # asymmetric + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + + # symmetric + for loss_fun in ['square_loss', 'kl_loss']: + C1 = 0.5 * (C1 + C1.T) + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + + G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=None) + Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=G0b) + 
+ # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=None) + + srgw_ = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=G0) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + + +def test_semirelaxed_fgw2_gradients(): + n_samples = 20 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + M = ot.dist(xs, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + # semirelaxed solvers do not support gradients over masses yet. 
+ for loss_fun in ['square_loss', 'kl_loss']: + p1 = torch.tensor(p, requires_grad=False, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) + + val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, loss_fun=loss_fun) + + val.backward() + + assert val.device == p1.device + assert p1.grad is None + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == M1.grad.shape + + # full gradients with alpha + p1 = torch.tensor(p, requires_grad=False, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) + alpha = torch.tensor(0.5, requires_grad=True, device=device) + + val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, loss_fun=loss_fun, alpha=alpha) + + val.backward() + + assert val.device == p1.device + assert p1.grad is None + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert alpha.shape == alpha.grad.shape + + +def test_srfgw_helper_backend(nx): + n_samples = 20 # nb samples + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 1]]) + + rng = np.random.RandomState(42) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) + ys = rng.randn(xs.shape[0], 2) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) + yt = rng.randn(xt.shape[0], 2) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + M /= M.max() + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + alpha = 0.5 + Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, 'square_loss', alpha=0.5, armijo=False, 
symmetric=True, G0=G0b, log=True) + + # calls with nx=None + constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss') + ones_pb = nx.ones(pb.shape[0], type_as=pb) + + def f(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def df(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) + return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None) + + def line_search(cost, G, deltaG, Mi, cost_G): + return ot.gromov.solve_semirelaxed_gromov_linesearch( + G, deltaG, cost_G, C1b, C2b, ones_pb, M=(1 - alpha) * Mb, reg=alpha, nx=None) + # feed the precomputed local optimum Gb to semirelaxed_cg + res, log = ot.optim.semirelaxed_cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + # check constraints + np.testing.assert_allclose(res, Gb, atol=1e-06) + + +def test_entropic_semirelaxed_gromov(nx): + # unbalanced proportions + list_n = [30, 15] + nt = 2 + ns = np.sum(list_n) + # create directed sbm with C2 as connectivity matrix + C1 = np.zeros((ns, ns), dtype=np.float64) + C2 = np.array([[0.8, 0.05], + [0.05, 1.]], dtype=np.float64) + rng = np.random.RandomState(0) + for i in range(nt): + for j in range(nt): + ni, nj = list_n[i], list_n[j] + xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) + C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + p = ot.unif(ns, type_as=C1) + q0 = ot.unif(C2.shape[0], type_as=C1) + G0 = p[:, None] * q0[None, :] + # asymmetric + C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) + epsilon = 0.1 + for loss_fun in ['square_loss', 'kl_loss']: + G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun=loss_fun, epsilon=epsilon, symmetric=None, log=True, G0=G0) + Gb, logb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun=loss_fun, 
epsilon=epsilon, symmetric=False, log=True, G0=None) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) + np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + + srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, None, loss_fun=loss_fun, epsilon=epsilon, symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun=loss_fun, epsilon=epsilon, symmetric=None, log=True, G0=None) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + + # symmetric + C1 = 0.5 * (C1 + C1.T) + C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) + + G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) + Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, 
loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) + + srgw_ = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0) + + G = log2['T'] + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + + np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_semirelaxed_gromov_dtype_device(nx): + # setup + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + for tp in nx.__type_list__: + + print(nx.dtype_device(tp)) + for loss_fun in ['square_loss', 'kl_loss']: + C1b, C2b, pb = nx.from_numpy(C1, C2, p, type_as=tp) + + Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein( + C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True + ) + gw_valb = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( + C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True + ) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) + + +def test_entropic_semirelaxed_fgw(nx): + rng = np.random.RandomState(0) + list_n = [16, 8] + nt = 2 + ns = 24 + # create directed sbm with C2 as connectivity matrix + C1 = np.zeros((ns, ns)) + C2 = np.array([[0.7, 0.05], + [0.05, 
0.9]]) + for i in range(nt): + for j in range(nt): + ni, nj = list_n[i], list_n[j] + xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) + C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + F1 = np.zeros((ns, 1)) + F1[:16] = rng.normal(loc=0., scale=0.01, size=(16, 1)) + F1[16:] = rng.normal(loc=1., scale=0.01, size=(8, 1)) + F2 = np.zeros((2, 1)) + F2[1, :] = 1. + M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T) + + p = ot.unif(ns) + q0 = ot.unif(C2.shape[0]) + G0 = p[:, None] * q0[None, :] + + # asymmetric + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + + G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + Gb, logb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], 
log['srfgw_dist'], atol=1e-07) + + # symmetric + C1 = 0.5 * (C1 + C1.T) + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + + for loss_fun in ['square_loss', 'kl_loss']: + G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + + srgw_ = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + + +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_semirelaxed_fgw_dtype_device(nx): + # setup + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + 
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + rng = np.random.RandomState(42) + ys = rng.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + Mb, C1b, C2b, pb = nx.from_numpy(M, C1, C2, p, type_as=tp) + + for loss_fun in ['square_loss', 'kl_loss']: + Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True + ) + fgw_valb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True + ) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, fgw_valb) diff --git a/test/test_gromov.py b/test/test_gromov.py deleted file mode 100644 index 83d65306b..000000000 --- a/test/test_gromov.py +++ /dev/null @@ -1,2963 +0,0 @@ -"""Tests for module gromov """ - -# Author: Erwan Vautier -# Nicolas Courty -# Titouan Vayer -# Cédric Vincent-Cuaz -# -# License: MIT License - -import numpy as np -import pytest -import warnings - -import ot -from ot.backend import NumpyBackend -from ot.backend import torch, tf - - -def test_gromov(nx): - n_samples = 20 # nb samples - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1) - xt = xs[::-1].copy() - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - - G = ot.gromov.gromov_wasserstein( - C1, C2, None, q, 'square_loss', G0=G0, verbose=True, - alpha_min=0., alpha_max=1.) 
- Gb = nx.to_numpy(ot.gromov.gromov_wasserstein( - C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=G0b, verbose=True)) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) - - np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04) - for armijo in [False, True]: - gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=armijo, log=True) - gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=armijo, log=True) - gwb = nx.to_numpy(gwb) - - gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=armijo, G0=G0, log=False) - gw_valb = nx.to_numpy( - ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=armijo, G0=G0b, log=False) - ) - - G = log['T'] - Gb = nx.to_numpy(logb['T']) - - np.testing.assert_allclose(gw, gwb, atol=1e-06) - np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1) - - np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06) - np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - -def test_asymmetric_gromov(nx): - n_samples = 20 # nb samples - rng = np.random.RandomState(0) - C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) - idx = np.arange(n_samples) - rng.shuffle(idx) - C2 = C1[idx, :][:, idx] - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - - C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - - G, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True) - 
Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True) - Gb = nx.to_numpy(Gb) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04) - np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04) - - gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True) - gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True) - - G = log['T'] - Gb = nx.to_numpy(logb['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04) - np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04) - - -def test_gromov_integer_warnings(nx): - n_samples = 10 # nb samples - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1) - xt = xs[::-1].copy() - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - C1 = C1.astype(np.int32) - C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - - G = ot.gromov.gromov_wasserstein( - C1, C2, None, q, 'square_loss', G0=G0, verbose=True, - alpha_min=0., alpha_max=1.) 
- Gb = nx.to_numpy(ot.gromov.gromov_wasserstein( - C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=G0b, verbose=True)) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(G, 0., atol=1e-09) - - -def test_gromov_dtype_device(nx): - # setup - n_samples = 20 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) - - xt = xs[::-1].copy() - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - for tp in nx.__type_list__: - print(nx.dtype_device(tp)) - - C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp) - - with warnings.catch_warnings(): - warnings.filterwarnings('error') - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) - - nx.assert_same_dtype_device(C1b, Gb) - nx.assert_same_dtype_device(C1b, gw_valb) - - -@pytest.mark.skipif(not tf, reason="tf not installed") -def test_gromov_device_tf(): - nx = ot.backend.TensorflowBackend() - n_samples = 20 # nb samples - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) - xt = xs[::-1].copy() - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - C1 /= C1.max() - C2 /= C2.max() - - # Check that everything stays on the CPU - with tf.device("/CPU:0"): - C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) - nx.assert_same_dtype_device(C1b, Gb) - 
nx.assert_same_dtype_device(C1b, gw_valb) - - if len(tf.config.list_physical_devices('GPU')) > 0: - # Check that everything happens on the GPU - C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=False) - nx.assert_same_dtype_device(C1b, Gb) - nx.assert_same_dtype_device(C1b, gw_valb) - assert nx.dtype_device(Gb)[1].startswith("GPU") - - -def test_gromov2_gradients(): - n_samples = 20 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) - - xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - if torch: - - devices = [torch.device("cpu")] - if torch.cuda.is_available(): - devices.append(torch.device("cuda")) - for device in devices: - - # classical gradients - p1 = torch.tensor(p, requires_grad=True, device=device) - q1 = torch.tensor(q, requires_grad=True, device=device) - C11 = torch.tensor(C1, requires_grad=True, device=device) - C12 = torch.tensor(C2, requires_grad=True, device=device) - - # Test with exact line-search - val = ot.gromov_wasserstein2(C11, C12, p1, q1) - - val.backward() - - assert val.device == p1.device - assert q1.shape == q1.grad.shape - assert p1.shape == p1.grad.shape - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape - - # Test with armijo line-search - # classical gradients - p1 = torch.tensor(p, requires_grad=True, device=device) - q1 = torch.tensor(q, requires_grad=True, device=device) - C11 = torch.tensor(C1, requires_grad=True, device=device) - C12 = torch.tensor(C2, requires_grad=True, device=device) - - q1.grad = None - p1.grad = None - C11.grad = None - C12.grad = None 
- val = ot.gromov_wasserstein2(C11, C12, p1, q1, armijo=True) - - val.backward() - - assert val.device == p1.device - assert q1.shape == q1.grad.shape - assert p1.shape == p1.grad.shape - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape - - -def test_gw_helper_backend(nx): - n_samples = 10 # nb samples - - mu = np.array([0, 0]) - cov = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) - xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', armijo=False, symmetric=True, G0=G0b, log=True) - - # calls with nx=None - constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss') - - def f(G): - return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None) - - def df(G): - return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None) - - def line_search(cost, G, deltaG, Mi, cost_G): - return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=0., reg=1., nx=None) - # feed the precomputed local optimum Gb to cg - res, log = ot.optim.cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) - # check constraints - np.testing.assert_allclose(res, Gb, atol=1e-06) - - -@pytest.mark.parametrize('loss_fun', [ - 'square_loss', - 'kl_loss', - pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), -]) -def test_gw_helper_validation(loss_fun): - n_samples = 10 # nb samples - mu = np.array([0, 0]) - cov = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) - xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) - p = ot.unif(n_samples) - 
q = ot.unif(n_samples) - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - ot.gromov.init_matrix(C1, C2, p, q, loss_fun=loss_fun) - - -@pytest.skip_backend("jax", reason="test very slow with jax backend") -@pytest.skip_backend("tf", reason="test very slow with tf backend") -@pytest.mark.parametrize('loss_fun', [ - 'square_loss', - 'kl_loss', - pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), -]) -def test_entropic_gromov(nx, loss_fun): - n_samples = 10 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) - - xt = xs[::-1].copy() - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - - G, log = ot.gromov.entropic_gromov_wasserstein( - C1, C2, None, q, loss_fun, symmetric=None, G0=G0, - epsilon=1e-2, max_iter=10, verbose=True, log=True) - Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( - C1b, C2b, pb, None, loss_fun, symmetric=True, G0=None, - epsilon=1e-2, max_iter=10, verbose=True, log=False - )) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - -@pytest.skip_backend("jax", reason="test very slow with jax backend") -@pytest.skip_backend("tf", reason="test very slow with tf backend") -@pytest.mark.parametrize('loss_fun', [ - 'square_loss', - 'kl_loss', - pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), -]) -def test_entropic_gromov2(nx, loss_fun): - n_samples = 10 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) - - xt = xs[::-1].copy() - - 
p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - - gw, log = ot.gromov.entropic_gromov_wasserstein2( - C1, C2, p, None, loss_fun, symmetric=True, G0=None, - max_iter=10, epsilon=1e-2, log=True) - gwb, logb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, None, qb, loss_fun, symmetric=None, G0=G0b, - max_iter=10, epsilon=1e-2, log=True) - gwb = nx.to_numpy(gwb) - - G = log['T'] - Gb = nx.to_numpy(logb['T']) - - np.testing.assert_allclose(gw, gwb, atol=1e-06) - np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - -@pytest.skip_backend("tf", reason="test very slow with tf backend") -def test_entropic_proximal_gromov(nx): - n_samples = 10 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) - - xt = xs[::-1].copy() - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - - with pytest.raises(ValueError): - loss_fun = 'weird_loss_fun' - G, log = ot.gromov.entropic_gromov_wasserstein( - C1, C2, None, q, loss_fun, symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=True, numItermax=1) - - G, log = ot.gromov.entropic_gromov_wasserstein( - C1, C2, None, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=True, numItermax=1) - Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( - C1b, C2b, pb, None, 
'square_loss', symmetric=True, G0=None, - epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=False, numItermax=1 - )) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-02) # cf convergence gromov - - gw, log = ot.gromov.entropic_gromov_wasserstein2( - C1, C2, p, q, 'kl_loss', symmetric=True, G0=None, - max_iter=10, epsilon=1e-1, solver='PPA', warmstart=True, log=True) - gwb, logb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=10, epsilon=1e-1, solver='PPA', warmstart=True, log=True) - gwb = nx.to_numpy(gwb) - - G = log['T'] - Gb = nx.to_numpy(logb['T']) - - np.testing.assert_allclose(gw, gwb, atol=1e-06) - np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-02) # cf convergence gromov - - -@pytest.skip_backend("tf", reason="test very slow with tf backend") -def test_asymmetric_entropic_gromov(nx): - n_samples = 10 # nb samples - rng = np.random.RandomState(0) - C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) - idx = np.arange(n_samples) - rng.shuffle(idx) - C2 = C1[idx, :][:, idx] - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - - C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - G = ot.gromov.entropic_gromov_wasserstein( - C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, max_iter=5, verbose=True, log=False) - Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( - C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None, - epsilon=1e-1, max_iter=5, verbose=True, log=False - )) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - 
np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - gw = ot.gromov.entropic_gromov_wasserstein2( - C1, C2, None, None, 'kl_loss', symmetric=False, G0=None, - max_iter=5, epsilon=1e-1, log=False) - gwb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=5, epsilon=1e-1, log=False) - gwb = nx.to_numpy(gwb) - - np.testing.assert_allclose(gw, gwb, atol=1e-06) - np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) - - -@pytest.skip_backend("jax", reason="test very slow with jax backend") -@pytest.skip_backend("tf", reason="test very slow with tf backend") -def test_entropic_gromov_dtype_device(nx): - # setup - n_samples = 5 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) - - xt = xs[::-1].copy() - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - for tp in nx.__type_list__: - print(nx.dtype_device(tp)) - - C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q, type_as=tp) - - for solver in ['PGD', 'PPA', 'BAPG']: - if solver == 'BAPG': - Gb = ot.gromov.BAPG_gromov_wasserstein( - C1b, C2b, pb, qb, max_iter=2, verbose=True) - gw_valb = ot.gromov.BAPG_gromov_wasserstein2( - C1b, C2b, pb, qb, max_iter=2, verbose=True) - else: - Gb = ot.gromov.entropic_gromov_wasserstein( - C1b, C2b, pb, qb, max_iter=2, solver=solver, verbose=True) - gw_valb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, pb, qb, max_iter=2, solver=solver, verbose=True) - - nx.assert_same_dtype_device(C1b, Gb) - nx.assert_same_dtype_device(C1b, gw_valb) - - -def test_BAPG_gromov(nx): - n_samples = 10 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, 
random_state=42) - - xt = xs[::-1].copy() - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - - # complete test with marginal loss = True - marginal_loss = True - with pytest.raises(ValueError): - loss_fun = 'weird_loss_fun' - G, log = ot.gromov.BAPG_gromov_wasserstein( - C1, C2, None, q, loss_fun, symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, - verbose=True, log=True) - - G, log = ot.gromov.BAPG_gromov_wasserstein( - C1, C2, None, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, - verbose=True, log=True) - Gb = nx.to_numpy(ot.gromov.BAPG_gromov_wasserstein( - C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True, - log=False - )) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-02) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-02) # cf convergence gromov - - with pytest.warns(UserWarning): - - gw = ot.gromov.BAPG_gromov_wasserstein2( - C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, - max_iter=10, epsilon=1e-2, marginal_loss=marginal_loss, log=False) - - gw, log = ot.gromov.BAPG_gromov_wasserstein2( - C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, - max_iter=10, epsilon=1., marginal_loss=marginal_loss, log=True) - gwb, logb = ot.gromov.BAPG_gromov_wasserstein2( - C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=10, epsilon=1., marginal_loss=marginal_loss, log=True) - gwb = nx.to_numpy(gwb) - - G = log['T'] - Gb = nx.to_numpy(logb['T']) - - np.testing.assert_allclose(gw, gwb, atol=1e-06) - np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) - - # check constraints - np.testing.assert_allclose(G, Gb, 
atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-02) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-02) # cf convergence gromov - - marginal_loss = False - G, log = ot.gromov.BAPG_gromov_wasserstein( - C1, C2, None, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, - verbose=True, log=True) - Gb = nx.to_numpy(ot.gromov.BAPG_gromov_wasserstein( - C1b, C2b, pb, None, 'square_loss', symmetric=False, G0=None, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True, - log=False - )) - - -@pytest.skip_backend("tf", reason="test very slow with tf backend") -def test_entropic_fgw(nx): - n_samples = 5 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) - - xt = xs[::-1].copy() - - rng = np.random.RandomState(42) - ys = rng.randn(xs.shape[0], 2) - yt = ys[::-1].copy() - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - M = ot.dist(ys, yt) - - Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) - - with pytest.raises(ValueError): - loss_fun = 'weird_loss_fun' - G, log = ot.gromov.entropic_fused_gromov_wasserstein( - M, C1, C2, None, None, loss_fun, symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, verbose=True, log=True) - - G, log = ot.gromov.entropic_fused_gromov_wasserstein( - M, C1, C2, None, None, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, verbose=True, log=True) - Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, - epsilon=1e-1, max_iter=10, verbose=True, log=False - )) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - 
np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - fgw, log = ot.gromov.entropic_fused_gromov_wasserstein2( - M, C1, C2, p, q, 'kl_loss', symmetric=True, G0=None, - max_iter=10, epsilon=1e-1, log=True) - fgwb, logb = ot.gromov.entropic_fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=10, epsilon=1e-1, log=True) - fgwb = nx.to_numpy(fgwb) - - G = log['T'] - Gb = nx.to_numpy(logb['T']) - - np.testing.assert_allclose(fgw, fgwb, atol=1e-06) - np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - -@pytest.skip_backend("tf", reason="test very slow with tf backend") -def test_entropic_proximal_fgw(nx): - n_samples = 5 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) - - xt = xs[::-1].copy() - - rng = np.random.RandomState(42) - ys = rng.randn(xs.shape[0], 2) - yt = ys[::-1].copy() - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - M = ot.dist(ys, yt) - - Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) - - G, log = ot.gromov.entropic_fused_gromov_wasserstein( - M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=True, numItermax=1) - Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, - epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=False, numItermax=1 - )) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, 
Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - fgw, log = ot.gromov.entropic_fused_gromov_wasserstein2( - M, C1, C2, p, None, 'kl_loss', symmetric=True, G0=None, - max_iter=5, epsilon=1e-1, solver='PPA', warmstart=True, log=True) - fgwb, logb = ot.gromov.entropic_fused_gromov_wasserstein2( - Mb, C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=5, epsilon=1e-1, solver='PPA', warmstart=True, log=True) - fgwb = nx.to_numpy(fgwb) - - G = log['T'] - Gb = nx.to_numpy(logb['T']) - - np.testing.assert_allclose(fgw, fgwb, atol=1e-06) - np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - -def test_BAPG_fgw(nx): - n_samples = 5 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) - - xt = xs[::-1].copy() - - rng = np.random.RandomState(42) - ys = rng.randn(xs.shape[0], 2) - yt = ys[::-1].copy() - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - M = ot.dist(ys, yt) - - Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) - - with pytest.raises(ValueError): - loss_fun = 'weird_loss_fun' - G, log = ot.gromov.BAPG_fused_gromov_wasserstein( - M, C1, C2, p, q, loss_fun=loss_fun, max_iter=1, log=True) - - # complete test with marginal loss = True - marginal_loss = True - - G, log = ot.gromov.BAPG_fused_gromov_wasserstein( - M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, log=True) - Gb = nx.to_numpy(ot.gromov.BAPG_fused_gromov_wasserstein( - 
Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True)) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-02) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-02) # cf convergence gromov - - with pytest.warns(UserWarning): - - fgw = ot.gromov.BAPG_fused_gromov_wasserstein2( - M, C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, - max_iter=10, epsilon=1e-3, marginal_loss=marginal_loss, log=False) - - fgw, log = ot.gromov.BAPG_fused_gromov_wasserstein2( - M, C1, C2, p, None, 'kl_loss', symmetric=True, G0=None, - max_iter=5, epsilon=1, marginal_loss=marginal_loss, log=True) - fgwb, logb = ot.gromov.BAPG_fused_gromov_wasserstein2( - Mb, C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=5, epsilon=1, marginal_loss=marginal_loss, log=True) - fgwb = nx.to_numpy(fgwb) - - G = log['T'] - Gb = nx.to_numpy(logb['T']) - - np.testing.assert_allclose(fgw, fgwb, atol=1e-06) - np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-02) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-02) # cf convergence gromov - - # Tests with marginal_loss = False - marginal_loss = False - G, log = ot.gromov.BAPG_fused_gromov_wasserstein( - M, C1, C2, p, q, 'square_loss', symmetric=False, G0=G0, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, log=True) - Gb = nx.to_numpy(ot.gromov.BAPG_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=None, G0=None, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True)) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-02) # cf convergence gromov - np.testing.assert_allclose( - q, 
Gb.sum(0), atol=1e-02) # cf convergence gromov - - -def test_asymmetric_entropic_fgw(nx): - n_samples = 5 # nb samples - rng = np.random.RandomState(0) - C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) - idx = np.arange(n_samples) - rng.shuffle(idx) - C2 = C1[idx, :][:, idx] - - ys = rng.randn(n_samples, 2) - yt = ys[idx, :] - M = ot.dist(ys, yt) - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - - Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) - G = ot.gromov.entropic_fused_gromov_wasserstein( - M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, - max_iter=5, epsilon=1e-1, verbose=True, log=False) - Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None, - max_iter=5, epsilon=1e-1, verbose=True, log=False - )) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - fgw = ot.gromov.entropic_fused_gromov_wasserstein2( - M, C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, - max_iter=5, epsilon=1e-1, log=False) - fgwb = ot.gromov.entropic_fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=5, epsilon=1e-1, log=False) - fgwb = nx.to_numpy(fgwb) - - np.testing.assert_allclose(fgw, fgwb, atol=1e-06) - np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) - - -@pytest.skip_backend("jax", reason="test very slow with jax backend") -@pytest.skip_backend("tf", reason="test very slow with tf backend") -def test_entropic_fgw_dtype_device(nx): - # setup - n_samples = 5 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - rng = np.random.RandomState(42) - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=rng) - - xt = xs[::-1].copy() - - ys = 
rng.randn(xs.shape[0], 2) - yt = ys[::-1].copy() - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - M = ot.dist(ys, yt) - for tp in nx.__type_list__: - print(nx.dtype_device(tp)) - - Mb, C1b, C2b, pb, qb = nx.from_numpy(M, C1, C2, p, q, type_as=tp) - - for solver in ['PGD', 'PPA', 'BAPG']: - if solver == 'BAPG': - Gb = ot.gromov.BAPG_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, max_iter=2) - fgw_valb = ot.gromov.BAPG_fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, qb, max_iter=2) - - else: - Gb = ot.gromov.entropic_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, max_iter=2, solver=solver) - fgw_valb = ot.gromov.entropic_fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, qb, max_iter=2, solver=solver) - - nx.assert_same_dtype_device(C1b, Gb) - nx.assert_same_dtype_device(C1b, fgw_valb) - - -def test_entropic_fgw_barycenter(nx): - ns = 5 - nt = 10 - - rng = np.random.RandomState(42) - Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) - Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) - - ys = rng.randn(Xs.shape[0], 2) - yt = rng.randn(Xt.shape[0], 2) - - C1 = ot.dist(Xs) - C2 = ot.dist(Xt) - p1 = ot.unif(ns) - p2 = ot.unif(nt) - n_samples = 3 - p = ot.unif(n_samples) - - ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p) - - with pytest.raises(ValueError): - loss_fun = 'weird_loss_fun' - X, C, log = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ys, yt], [C1, C2], None, p, [.5, .5], loss_fun, 0.1, - max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42, - solver='PPA', numItermax=10, log=True, symmetric=True, - ) - with pytest.raises(ValueError): - stop_criterion = 'unknown stop criterion' - X, C, log = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ys, yt], [C1, C2], None, p, [.5, .5], 'square_loss', - 0.1, max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - 
verbose=True, warmstartT=True, random_state=42, - solver='PPA', numItermax=10, log=True, symmetric=True, - ) - - for stop_criterion in ['barycenter', 'loss']: - X, C, log = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ys, yt], [C1, C2], None, p, [.5, .5], 'square_loss', - epsilon=0.1, max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=True, warmstartT=True, random_state=42, solver='PPA', - numItermax=10, log=True, symmetric=True - ) - Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], None, [.5, .5], - 'square_loss', epsilon=0.1, max_iter=10, tol=1e-3, - stop_criterion=stop_criterion, verbose=False, warmstartT=True, - random_state=42, solver='PPA', numItermax=10, log=False, symmetric=True) - Xb, Cb = nx.to_numpy(Xb, Cb) - - np.testing.assert_allclose(C, Cb, atol=1e-06) - np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X, Xb, atol=1e-06) - np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) - - # test with 'kl_loss' and log=True - # providing init_C, init_Y - generator = ot.utils.check_random_state(42) - xalea = generator.randn(n_samples, 2) - init_C = ot.utils.dist(xalea, xalea) - init_C /= init_C.max() - init_Cb = nx.from_numpy(init_C) - - init_Y = np.zeros((n_samples, ys.shape[1]), dtype=ys.dtype) - init_Yb = nx.from_numpy(init_Y) - - X, C, log = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ys, yt], [C1, C2], [p1, p2], p, None, 'kl_loss', 0.1, True, - max_iter=10, tol=1e-3, verbose=False, warmstartT=False, random_state=42, - solver='PPA', numItermax=1, init_C=init_C, init_Y=init_Y, log=True - ) - Xb, Cb, logb = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'kl_loss', - 0.1, True, max_iter=10, tol=1e-3, verbose=False, warmstartT=False, - random_state=42, solver='PPA', numItermax=1, init_C=init_Cb, - init_Y=init_Yb, log=True) - Xb, Cb = nx.to_numpy(Xb, Cb) - - 
np.testing.assert_allclose(C, Cb, atol=1e-06) - np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X, Xb, atol=1e-06) - np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) - np.testing.assert_array_almost_equal(log['err_feature'], nx.to_numpy(*logb['err_feature'])) - np.testing.assert_array_almost_equal(log['err_structure'], nx.to_numpy(*logb['err_structure'])) - - # add tests with fixed_structures or fixed_features - init_C = ot.utils.dist(xalea, xalea) - init_C /= init_C.max() - init_Cb = nx.from_numpy(init_C) - - init_Y = np.zeros((n_samples, ys.shape[1]), dtype=ys.dtype) - init_Yb = nx.from_numpy(init_Y) - - fixed_structure, fixed_features = True, False - with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_structure=True`and `init_C=None` - Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, - fixed_structure=fixed_structure, init_C=None, - fixed_features=fixed_features, p=None, max_iter=10, tol=1e-3 - ) - - Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, - fixed_structure=fixed_structure, init_C=init_Cb, - fixed_features=fixed_features, max_iter=10, tol=1e-3 - ) - Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) - np.testing.assert_allclose(Cb, init_Cb) - np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) - - fixed_structure, fixed_features = False, True - with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_features=True`and `init_X=None` - Xb, Cb, logb = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], lambdas=[.5, .5], - fixed_structure=fixed_structure, fixed_features=fixed_features, - init_Y=None, p=pb, max_iter=10, tol=1e-3, - warmstartT=True, log=True, random_state=98765, verbose=True - ) - Xb, Cb, logb = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, 
[ysb, ytb], [C1b, C2b], [p1b, p2b], lambdas=[.5, .5], - fixed_structure=fixed_structure, fixed_features=fixed_features, - init_Y=init_Yb, p=pb, max_iter=10, tol=1e-3, - warmstartT=True, log=True, random_state=98765, verbose=True - ) - - X, C = nx.to_numpy(Xb), nx.to_numpy(Cb) - np.testing.assert_allclose(C.shape, (n_samples, n_samples)) - np.testing.assert_allclose(Xb, init_Yb) - - -def test_pointwise_gromov(nx): - n_samples = 5 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) - - xt = xs[::-1].copy() - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) - - def loss(x, y): - return np.abs(x - y) - - def lossb(x, y): - return nx.abs(x - y) - - G, log = ot.gromov.pointwise_gromov_wasserstein( - C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42) - G = NumpyBackend().todense(G) - Gb, logb = ot.gromov.pointwise_gromov_wasserstein( - C1b, C2b, pb, qb, lossb, max_iter=100, log=True, verbose=True, random_state=42) - Gb = nx.to_numpy(nx.todense(Gb)) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.0, atol=1e-08) - np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0, atol=1e-08) - - G, log = ot.gromov.pointwise_gromov_wasserstein( - C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) - G = NumpyBackend().todense(G) - Gb, logb = ot.gromov.pointwise_gromov_wasserstein( - C1b, C2b, pb, qb, lossb, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) - Gb = nx.to_numpy(nx.todense(Gb)) - - 
np.testing.assert_allclose(G, Gb, atol=1e-06) - - -@pytest.skip_backend("tf", reason="test very slow with tf backend") -@pytest.skip_backend("jax", reason="test very slow with jax backend") -def test_sampled_gromov(nx): - n_samples = 5 # nb samples - - mu_s = np.array([0, 0], dtype=np.float64) - cov_s = np.array([[1, 0], [0, 1]], dtype=np.float64) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) - - xt = xs[::-1].copy() - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) - - def loss(x, y): - return np.abs(x - y) - - def lossb(x, y): - return nx.abs(x - y) - - G, log = ot.gromov.sampled_gromov_wasserstein( - C1, C2, p, q, loss, max_iter=20, nb_samples_grad=2, epsilon=1, log=True, verbose=True, random_state=42) - Gb, logb = ot.gromov.sampled_gromov_wasserstein( - C1b, C2b, pb, qb, lossb, max_iter=20, nb_samples_grad=2, epsilon=1, log=True, verbose=True, random_state=42) - Gb = nx.to_numpy(Gb) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - -def test_gromov_barycenter(nx): - ns = 5 - nt = 8 - - Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) - Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) - - C1 = ot.dist(Xs) - C2 = ot.dist(Xt) - p1 = ot.unif(ns) - p2 = ot.unif(nt) - n_samples = 3 - p = ot.unif(n_samples) - - C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p) - with pytest.raises(ValueError): - stop_criterion = 'unknown stop criterion' - Cb = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', max_iter=10, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - random_state=42 - ) - - for stop_criterion in ['barycenter', 
'loss']: - Cb = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', max_iter=10, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - random_state=42 - ) - Cbb = nx.to_numpy(ot.gromov.gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, random_state=42 - )) - np.testing.assert_allclose(Cb, Cbb, atol=1e-06) - np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) - - # test of gromov_barycenters with `log` on - Cb_, err_ = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, None, 'square_loss', max_iter=10, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - warmstartT=True, random_state=42, log=True - ) - Cbb_, errb_ = ot.gromov.gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'square_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, warmstartT=True, random_state=42, log=True - ) - Cbb_ = nx.to_numpy(Cbb_) - np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) - np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) - np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) - - Cb2 = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'kl_loss', max_iter=10, tol=1e-3, warmstartT=True, random_state=42 - ) - Cb2b = nx.to_numpy(ot.gromov.gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'kl_loss', max_iter=10, tol=1e-3, warmstartT=True, random_state=42 - )) - np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) - np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) - - # test of gromov_barycenters with `log` on - # providing init_C - generator = ot.utils.check_random_state(42) - xalea = generator.randn(n_samples, 2) - init_C = ot.utils.dist(xalea, xalea) - init_C /= init_C.max() - init_Cb = nx.from_numpy(init_C) - - Cb2_, err2_ = ot.gromov.gromov_barycenters( - 
n_samples, [C1, C2], [p1, p2], p, [.5, .5], 'kl_loss', max_iter=10, - tol=1e-3, verbose=False, random_state=42, log=True, init_C=init_C - ) - Cb2b_, err2b_ = ot.gromov.gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'kl_loss', - max_iter=10, tol=1e-3, verbose=True, random_state=42, - init_C=init_Cb, log=True - ) - Cb2b_ = nx.to_numpy(Cb2b_) - np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) - np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) - np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) - - -@pytest.mark.filterwarnings("ignore:divide") -def test_gromov_entropic_barycenter(nx): - ns = 5 - nt = 10 - - Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) - Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) - - C1 = ot.dist(Xs) - C2 = ot.dist(Xt) - p1 = ot.unif(ns) - p2 = ot.unif(nt) - n_samples = 2 - p = ot.unif(n_samples) - - C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p) - - with pytest.raises(ValueError): - loss_fun = 'weird_loss_fun' - Cb = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], None, p, [.5, .5], loss_fun, 1e-3, - max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42 - ) - with pytest.raises(ValueError): - stop_criterion = 'unknown stop criterion' - Cb = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', 1e-3, - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=True, warmstartT=True, random_state=42 - ) - - Cb = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', 1e-3, - max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42 - ) - Cbb = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', 1e-3, - max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42 - )) - np.testing.assert_allclose(Cb, Cbb, atol=1e-06) - 
np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) - - # test of entropic_gromov_barycenters with `log` on - for stop_criterion in ['barycenter', 'loss']: - Cb_, err_ = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, None, 'square_loss', 1e-3, - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, verbose=True, - random_state=42, log=True - ) - Cbb_, errb_ = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'square_loss', - 1e-3, max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=True, random_state=42, log=True - ) - Cbb_ = nx.to_numpy(Cbb_) - np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) - np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) - np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) - - Cb2 = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'kl_loss', 1e-3, max_iter=10, tol=1e-3, random_state=42 - ) - Cb2b = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'kl_loss', 1e-3, max_iter=10, tol=1e-3, random_state=42 - )) - np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) - np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) - - # test of entropic_gromov_barycenters with `log` on - # providing init_C - generator = ot.utils.check_random_state(42) - xalea = generator.randn(n_samples, 2) - init_C = ot.utils.dist(xalea, xalea) - init_C /= init_C.max() - init_Cb = nx.from_numpy(init_C) - - Cb2_, err2_ = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], 'kl_loss', 1e-3, - max_iter=10, tol=1e-3, warmstartT=True, verbose=True, random_state=42, - init_C=init_C, log=True - ) - Cb2b_, err2b_ = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'kl_loss', 1e-3, max_iter=10, tol=1e-3, warmstartT=True, verbose=True, - random_state=42, init_Cb=init_Cb, log=True - ) 
- Cb2b_ = nx.to_numpy(Cb2b_) - np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) - np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) - np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) - - -def test_fgw(nx): - n_samples = 20 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) - - xt = xs[::-1].copy() - - rng = np.random.RandomState(42) - ys = rng.randn(xs.shape[0], 2) - yt = ys[::-1].copy() - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - M = ot.dist(ys, yt) - M /= M.max() - - Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) - - G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, None, q, 'square_loss', alpha=0.5, armijo=True, symmetric=None, G0=G0, log=True) - Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, None, 'square_loss', alpha=0.5, armijo=True, symmetric=True, G0=G0b, log=True) - Gb = nx.to_numpy(Gb) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence fgw - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence fgw - - Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) - - np.testing.assert_allclose( - Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov - - fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, None, 'square_loss', armijo=True, symmetric=True, G0=None, alpha=0.5, log=True) - fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, None, qb, 'square_loss', armijo=True, symmetric=None, G0=G0b, alpha=0.5, log=True) - fgwb = nx.to_numpy(fgwb) - - G = log['T'] - Gb = nx.to_numpy(logb['T']) - - np.testing.assert_allclose(fgw, fgwb, atol=1e-08) - np.testing.assert_allclose(fgwb, 0, atol=1e-1, rtol=1e-1) - - 
# check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - -def test_asymmetric_fgw(nx): - n_samples = 20 # nb samples - rng = np.random.RandomState(0) - C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) - idx = np.arange(n_samples) - rng.shuffle(idx) - C2 = C1[idx, :][:, idx] - - # add features - F1 = rng.uniform(low=0., high=10, size=(n_samples, 1)) - F2 = F1[idx, :] - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - - M = ot.dist(F1, F2) - Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) - - G, log = ot.gromov.fused_gromov_wasserstein( - M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, - symmetric=False, verbose=True) - Gb, logb = ot.gromov.fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, - symmetric=None, G0=G0b, verbose=True) - Gb = nx.to_numpy(Gb) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) - np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) - - fgw, log = ot.gromov.fused_gromov_wasserstein2( - M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, - symmetric=None, verbose=True) - fgwb, logb = ot.gromov.fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, - symmetric=False, G0=G0b, verbose=True) - - G = log['T'] - Gb = nx.to_numpy(logb['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - 
np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) - np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) - - # Tests with kl-loss: - for armijo in [False, True]: - G, log = ot.gromov.fused_gromov_wasserstein( - M, C1, C2, p, q, 'kl_loss', alpha=0.5, armijo=armijo, G0=G0, - log=True, symmetric=False, verbose=True) - Gb, logb = ot.gromov.fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, armijo=armijo, - log=True, symmetric=None, G0=G0b, verbose=True) - Gb = nx.to_numpy(Gb) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) - np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) - - fgw, log = ot.gromov.fused_gromov_wasserstein2( - M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, - symmetric=None, verbose=True) - fgwb, logb = ot.gromov.fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, - symmetric=False, G0=G0b, verbose=True) - - G = log['T'] - Gb = nx.to_numpy(logb['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov - - np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) - np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) - - -def test_fgw_integer_warnings(nx): - n_samples = 20 # nb samples - rng = np.random.RandomState(0) - C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) - idx = np.arange(n_samples) - rng.shuffle(idx) - C2 = C1[idx, :][:, idx] - - # add features - F1 = rng.uniform(low=0., high=10, size=(n_samples, 1)) - F2 = F1[idx, :] - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - - M 
= ot.dist(F1, F2).astype(np.int32) - Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) - - G, log = ot.gromov.fused_gromov_wasserstein( - M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, - symmetric=False, verbose=True) - Gb, logb = ot.gromov.fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, - symmetric=None, G0=G0b, verbose=True) - Gb = nx.to_numpy(Gb) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(G, 0., atol=1e-06) - - -def test_fgw2_gradients(): - n_samples = 20 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) - - xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - M = ot.dist(xs, xt) - - C1 /= C1.max() - C2 /= C2.max() - - if torch: - - devices = [torch.device("cpu")] - if torch.cuda.is_available(): - devices.append(torch.device("cuda")) - for device in devices: - p1 = torch.tensor(p, requires_grad=True, device=device) - q1 = torch.tensor(q, requires_grad=True, device=device) - C11 = torch.tensor(C1, requires_grad=True, device=device) - C12 = torch.tensor(C2, requires_grad=True, device=device) - M1 = torch.tensor(M, requires_grad=True, device=device) - - val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1) - - val.backward() - - assert val.device == p1.device - assert q1.shape == q1.grad.shape - assert p1.shape == p1.grad.shape - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape - assert M1.shape == M1.grad.shape - - # full gradients with alpha - p1 = torch.tensor(p, requires_grad=True, device=device) - q1 = torch.tensor(q, requires_grad=True, device=device) - C11 = torch.tensor(C1, requires_grad=True, device=device) - C12 = torch.tensor(C2, requires_grad=True, device=device) - M1 = 
torch.tensor(M, requires_grad=True, device=device) - alpha = torch.tensor(0.5, requires_grad=True, device=device) - - val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1, alpha=alpha) - - val.backward() - - assert val.device == p1.device - assert q1.shape == q1.grad.shape - assert p1.shape == p1.grad.shape - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape - assert alpha.shape == alpha.grad.shape - - -def test_fgw_helper_backend(nx): - n_samples = 20 # nb samples - - mu = np.array([0, 0]) - cov = np.array([[1, 0], [0, 1]]) - - rng = np.random.RandomState(42) - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) - ys = rng.randn(xs.shape[0], 2) - xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) - yt = rng.randn(xt.shape[0], 2) - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - M = ot.dist(ys, yt) - M /= M.max() - - Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) - alpha = 0.5 - Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True) - - # calls with nx=None - constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss') - - def f(G): - return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None) - - def df(G): - return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None) - - def line_search(cost, G, deltaG, Mi, cost_G): - return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=(1 - alpha) * Mb, reg=alpha, nx=None) - # feed the precomputed local optimum Gb to cg - res, log = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) - - def line_search(cost, G, deltaG, Mi, cost_G): - return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=None) - # feed the precomputed 
local optimum Gb to cg - res_armijo, log_armijo = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) - # check constraints - np.testing.assert_allclose(res, Gb, atol=1e-06) - np.testing.assert_allclose(res_armijo, Gb, atol=1e-06) - - -def test_fgw_barycenter(nx): - ns = 10 - nt = 20 - - Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) - Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) - - rng = np.random.RandomState(42) - ys = rng.randn(Xs.shape[0], 2) - yt = rng.randn(Xt.shape[0], 2) - - C1 = ot.dist(Xs) - C2 = ot.dist(Xt) - C1 /= C1.max() - C2 /= C2.max() - - p1, p2 = ot.unif(ns), ot.unif(nt) - n_samples = 3 - p = ot.unif(n_samples) - - ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p) - lambdas = [.5, .5] - Csb = [C1b, C2b] - Ysb = [ysb, ytb] - Xb, Cb, logb = ot.gromov.fgw_barycenters( - n_samples, Ysb, Csb, None, lambdas, 0.5, fixed_structure=False, - fixed_features=False, p=pb, loss_fun='square_loss', max_iter=10, tol=1e-3, - random_state=12345, log=True - ) - # test correspondance with utils function - recovered_Cb = ot.gromov.update_square_loss(pb, lambdas, logb['Ts_iter'][-1], Csb) - recovered_Xb = ot.gromov.update_feature_matrix(lambdas, [y.T for y in Ysb], logb['Ts_iter'][-1], pb).T - - np.testing.assert_allclose(Cb, recovered_Cb) - np.testing.assert_allclose(Xb, recovered_Xb) - - xalea = rng.randn(n_samples, 2) - init_C = ot.dist(xalea, xalea) - init_C /= init_C.max() - init_Cb = nx.from_numpy(init_C) - - with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_structure=True`and `init_C=None` - Xb, Cb = ot.gromov.fgw_barycenters( - n_samples, Ysb, Csb, ps=[p1b, p2b], lambdas=None, - alpha=0.5, fixed_structure=True, init_C=None, fixed_features=False, - p=None, loss_fun='square_loss', max_iter=10, tol=1e-3 - ) - - Xb, Cb = ot.gromov.fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], 
ps=[p1b, p2b], lambdas=None, - alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False, - p=None, loss_fun='square_loss', max_iter=10, tol=1e-3 - ) - Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) - np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) - np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) - - init_X = rng.randn(n_samples, ys.shape[1]) - init_Xb = nx.from_numpy(init_X) - - with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_features=True`and `init_X=None` - Xb, Cb, logb = ot.gromov.fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=None, - p=pb, loss_fun='square_loss', max_iter=10, tol=1e-3, - warmstartT=True, log=True, random_state=98765, verbose=True - ) - Xb, Cb, logb = ot.gromov.fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=init_Xb, - p=pb, loss_fun='square_loss', max_iter=10, tol=1e-3, - warmstartT=True, log=True, random_state=98765, verbose=True - ) - - X, C = nx.to_numpy(Xb), nx.to_numpy(Cb) - np.testing.assert_allclose(C.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) - - # add test with 'kl_loss' - with pytest.raises(ValueError): - stop_criterion = 'unknown stop criterion' - X, C, log = ot.gromov.fgw_barycenters( - n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', - max_iter=100, tol=1e-3, stop_criterion=stop_criterion, init_C=C, - init_X=X, warmstartT=True, random_state=12345, log=True - ) - - for stop_criterion in ['barycenter', 'loss']: - X, C, log = ot.gromov.fgw_barycenters( - n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', - max_iter=100, tol=1e-3, stop_criterion=stop_criterion, init_C=C, - init_X=X, 
warmstartT=True, random_state=12345, log=True, verbose=True - ) - np.testing.assert_allclose(C.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) - - # test correspondance with utils function - recovered_C = ot.gromov.update_kl_loss(p, lambdas, log['T'], [C1, C2]) - np.testing.assert_allclose(C, recovered_C) - - -def test_gromov_wasserstein_linear_unmixing(nx): - n = 4 - - X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) - X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) - - C1 = ot.dist(X1) - C2 = ot.dist(X2) - Cdict = np.stack([C1, C2]) - p = ot.unif(n) - - C1b, C2b, Cdictb, pb = nx.from_numpy(C1, C2, Cdict, p) - - tol = 10**(-5) - # Tests without regularization - reg = 0. - unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing( - C1, Cdict, reg=reg, p=p, q=p, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 - ) - - unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing( - C1b, Cdictb, reg=reg, p=None, q=None, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 - ) - - unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing( - C2, Cdict, reg=reg, p=None, q=None, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 - ) - - unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing( - C2b, Cdictb, reg=reg, p=pb, q=pb, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 - ) - - np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=5e-06) - np.testing.assert_allclose(unmixing1, [1., 0.], atol=5e-01) - np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=5e-06) - np.testing.assert_allclose(unmixing2, [0., 1.], atol=5e-01) - np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) - np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), 
atol=1e-06) - np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) - np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) - np.testing.assert_allclose(C1b_emb.shape, (n, n)) - np.testing.assert_allclose(C2b_emb.shape, (n, n)) - - # Tests with regularization - - reg = 0.001 - unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing( - C1, Cdict, reg=reg, p=p, q=p, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 - ) - - unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing( - C1b, Cdictb, reg=reg, p=None, q=None, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 - ) - - unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing( - C2, Cdict, reg=reg, p=None, q=None, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 - ) - - unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing( - C2b, Cdictb, reg=reg, p=pb, q=pb, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 - ) - - np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) - np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) - np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) - np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) - np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) - np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) - np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) - np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) - np.testing.assert_allclose(C1b_emb.shape, (n, n)) - np.testing.assert_allclose(C2b_emb.shape, (n, n)) - - -def test_gromov_wasserstein_dictionary_learning(nx): - - # create dataset composed from 2 structures which are repeated 5 times - shape 
= 4 - n_samples = 2 - n_atoms = 2 - projection = 'nonnegative_symmetric' - X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42) - X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42) - C1 = ot.dist(X1) - C2 = ot.dist(X2) - Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)] - ps = [ot.unif(shape) for _ in range(n_samples)] - q = ot.unif(shape) - - # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape) - # following the same procedure than implemented in gromov_wasserstein_dictionary_learning. - dataset_means = [C.mean() for C in Cs] - rng = np.random.RandomState(0) - Cdict_init = rng.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(n_atoms, shape, shape)) - - if projection == 'nonnegative_symmetric': - Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) - Cdict_init[Cdict_init < 0.] = 0. - - Csb = nx.from_numpy(*Cs) - psb = nx.from_numpy(*ps) - qb, Cdict_initb = nx.from_numpy(q, Cdict_init) - - # Test: compare reconstruction error using initial dictionary and dictionary learned using this initialization - # > Compute initial reconstruction of samples on this random dictionary without backend - use_adam_optimizer = True - verbose = False - tol = 10**(-5) - epochs = 1 - - initial_total_reconstruction = 0 - for i in range(n_samples): - _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( - Cs[i], Cdict_init, p=ps[i], q=q, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 - ) - initial_total_reconstruction += reconstruction - - # > Learn the dictionary using this init - Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning( - Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, - epochs=epochs, batch_size=2 * n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, 
use_adam_optimizer=use_adam_optimizer, verbose=verbose - ) - # > Compute reconstruction of samples on learned dictionary without backend - total_reconstruction = 0 - for i in range(n_samples): - _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( - Cs[i], Cdict, p=None, q=None, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 - ) - total_reconstruction += reconstruction - - np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction) - - # Test: Perform same experiments after going through backend - - Cdictb, log = ot.gromov.gromov_wasserstein_dictionary_learning( - Csb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, - epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose - ) - # Compute reconstruction of samples on learned dictionary - total_reconstruction_b = 0 - for i in range(n_samples): - _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( - Csb[i], Cdictb, p=psb[i], q=qb, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 - ) - total_reconstruction_b += reconstruction - - total_reconstruction_b = nx.to_numpy(total_reconstruction_b) - np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) - np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) - np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) - np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03) - - # Test: Perform same comparison without providing the initial dictionary being an optional input - # knowing than the initialization scheme is the same than implemented to set the benchmarked initialization. 
- Cdict_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( - Cs, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, - epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose, - random_state=0 - ) - # > Compute reconstruction of samples on learned dictionary - total_reconstruction_bis = 0 - for i in range(n_samples): - _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( - Cs[i], Cdict_bis, p=ps[i], q=q, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 - ) - total_reconstruction_bis += reconstruction - - np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05) - - # Test: Same after going through backend - Cdictb_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( - Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None, - epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, - verbose=verbose, random_state=0 - ) - # > Compute reconstruction of samples on learned dictionary - total_reconstruction_b_bis = 0 - for i in range(n_samples): - _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( - Csb[i], Cdictb_bis, p=None, q=None, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 - ) - total_reconstruction_b_bis += reconstruction - - total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis) - np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) - np.testing.assert_allclose(Cdict_bis, nx.to_numpy(Cdictb_bis), atol=1e-03) - - # Test: Perform same comparison without providing the initial dictionary being an optional input - # and testing other optimization 
settings untested until now. - # We pass previously estimated dictionaries to speed up the process. - use_adam_optimizer = False - verbose = True - use_log = True - - Cdict_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( - Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, - epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, - verbose=verbose, random_state=0, - ) - # > Compute reconstruction of samples on learned dictionary - total_reconstruction_bis2 = 0 - for i in range(n_samples): - _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( - Cs[i], Cdict_bis2, p=ps[i], q=q, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 - ) - total_reconstruction_bis2 += reconstruction - - np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction) - - # Test: Same after going through backend - Cdictb_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( - Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=Cdictb, - epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, - verbose=verbose, random_state=0, - ) - # > Compute reconstruction of samples on learned dictionary - total_reconstruction_b_bis2 = 0 - for i in range(n_samples): - _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( - Csb[i], Cdictb_bis2, p=psb[i], q=qb, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 - ) - total_reconstruction_b_bis2 += reconstruction - - total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2) - np.testing.assert_allclose(total_reconstruction_b_bis2, total_reconstruction_bis2, atol=1e-05) - - -def 
test_fused_gromov_wasserstein_linear_unmixing(nx): - - n = 4 - X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) - X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) - F, y = ot.datasets.make_data_classif('3gauss', n, random_state=42) - - C1 = ot.dist(X1) - C2 = ot.dist(X2) - Cdict = np.stack([C1, C2]) - Ydict = np.stack([F, F]) - p = ot.unif(n) - - C1b, C2b, Fb, Cdictb, Ydictb, pb = nx.from_numpy(C1, C2, F, Cdict, Ydict, p) - - # Tests without regularization - reg = 0. - - unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 - ) - - unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 - ) - - unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 - ) - - unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 - ) - - np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=4e-06) - np.testing.assert_allclose(unmixing1, [1., 0.], atol=4e-01) - np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=4e-06) - np.testing.assert_allclose(unmixing2, [0., 1.], atol=4e-01) - np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) - np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) - np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), 
atol=1e-03) - np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) - np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) - np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) - np.testing.assert_allclose(C1b_emb.shape, (n, n)) - np.testing.assert_allclose(C2b_emb.shape, (n, n)) - - # Tests with regularization - reg = 0.001 - - unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 - ) - - unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 - ) - - unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 - ) - - unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 - ) - - np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) - np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) - np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) - np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) - np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) - np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) - np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) - np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) - 
np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) - np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) - np.testing.assert_allclose(C1b_emb.shape, (n, n)) - np.testing.assert_allclose(C2b_emb.shape, (n, n)) - - -def test_fused_gromov_wasserstein_dictionary_learning(nx): - - # create dataset composed from 2 structures which are repeated 5 times - shape = 4 - n_samples = 2 - n_atoms = 2 - projection = 'nonnegative_symmetric' - X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42) - X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42) - F, y = ot.datasets.make_data_classif('3gauss', shape, random_state=42) - - C1 = ot.dist(X1) - C2 = ot.dist(X2) - Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)] - Ys = [F.copy() for _ in range(n_samples)] - ps = [ot.unif(shape) for _ in range(n_samples)] - q = ot.unif(shape) - - # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape) - # following the same procedure than implemented in gromov_wasserstein_dictionary_learning. - dataset_structure_means = [C.mean() for C in Cs] - rng = np.random.RandomState(0) - Cdict_init = rng.normal(loc=np.mean(dataset_structure_means), scale=np.std(dataset_structure_means), size=(n_atoms, shape, shape)) - if projection == 'nonnegative_symmetric': - Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) - Cdict_init[Cdict_init < 0.] = 0. 
- dataset_feature_means = np.stack([Y.mean(axis=0) for Y in Ys]) - Ydict_init = rng.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(n_atoms, shape, 2)) - - Csb = nx.from_numpy(*Cs) - Ysb = nx.from_numpy(*Ys) - psb = nx.from_numpy(*ps) - qb, Cdict_initb, Ydict_initb = nx.from_numpy(q, Cdict_init, Ydict_init) - - # Test: Compute initial reconstruction of samples on this random dictionary - alpha = 0.5 - use_adam_optimizer = True - verbose = False - tol = 1e-05 - epochs = 1 - - initial_total_reconstruction = 0 - for i in range(n_samples): - _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - Cs[i], Ys[i], Cdict_init, Ydict_init, p=ps[i], q=q, - alpha=alpha, reg=0., tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 - ) - initial_total_reconstruction += reconstruction - - # > Learn a dictionary using this given initialization and check that the reconstruction loss - # on the learned dictionary is lower than the one using its initialization. 
- Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( - Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, Ydict_init=Ydict_init, - epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose - ) - # > Compute reconstruction of samples on learned dictionary - total_reconstruction = 0 - for i in range(n_samples): - _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - Cs[i], Ys[i], Cdict, Ydict, p=None, q=None, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 - ) - total_reconstruction += reconstruction - # Compare both - np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction) - - # Test: Perform same experiments after going through backend - Cdictb, Ydictb, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( - Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, Ydict_init=Ydict_initb, - epochs=epochs, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose, - random_state=0 - ) - # > Compute reconstruction of samples on learned dictionary - total_reconstruction_b = 0 - for i in range(n_samples): - _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - Csb[i], Ysb[i], Cdictb, Ydictb, p=psb[i], q=qb, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 - ) - total_reconstruction_b += reconstruction - - total_reconstruction_b = nx.to_numpy(total_reconstruction_b) - np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) - 
np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) - np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03) - np.testing.assert_allclose(Ydict, nx.to_numpy(Ydictb), atol=1e-03) - - # Test: Perform similar experiment without providing the initial dictionary being an optional input - Cdict_bis, Ydict_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( - Cs, Ys, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, - epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose, - random_state=0 - ) - # > Compute reconstruction of samples on learned dictionary - total_reconstruction_bis = 0 - for i in range(n_samples): - _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - Cs[i], Ys[i], Cdict_bis, Ydict_bis, p=ps[i], q=q, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 - ) - total_reconstruction_bis += reconstruction - - np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05) - - # > Same after going through backend - Cdictb_bis, Ydictb_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( - Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, - epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose, - random_state=0, - ) - - # > Compute reconstruction of samples on learned dictionary - total_reconstruction_b_bis = 0 - for i in range(n_samples): - _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - Csb[i], Ysb[i], 
Cdictb_bis, Ydictb_bis, p=psb[i], q=qb, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 - ) - total_reconstruction_b_bis += reconstruction - - total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis) - np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) - - # Test: without using adam optimizer, with log and verbose set to True - use_adam_optimizer = False - verbose = True - use_log = True - - # > Experiment providing previously estimated dictionary to speed up the test compared to providing initial random init. - Cdict_bis2, Ydict_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( - Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, Ydict_init=Ydict, - epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, - verbose=verbose, random_state=0, - ) - # > Compute reconstruction of samples on learned dictionary - total_reconstruction_bis2 = 0 - for i in range(n_samples): - _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - Cs[i], Ys[i], Cdict_bis2, Ydict_bis2, p=ps[i], q=q, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 - ) - total_reconstruction_bis2 += reconstruction - - np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction) - - # > Same after going through backend - Cdictb_bis2, Ydictb_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( - Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdictb, Ydict_init=Ydictb, - epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=use_log, 
use_adam_optimizer=use_adam_optimizer, verbose=verbose, - random_state=0, - ) - - # > Compute reconstruction of samples on learned dictionary - total_reconstruction_b_bis2 = 0 - for i in range(n_samples): - _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - Csb[i], Ysb[i], Cdictb_bis2, Ydictb_bis2, p=None, q=None, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 - ) - total_reconstruction_b_bis2 += reconstruction - - # > Compare results with/without backend - total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2) - np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05) - - -def test_semirelaxed_gromov(nx): - rng = np.random.RandomState(0) - # unbalanced proportions - list_n = [30, 15] - nt = 2 - ns = np.sum(list_n) - # create directed sbm with C2 as connectivity matrix - C1 = np.zeros((ns, ns), dtype=np.float64) - C2 = np.array([[0.8, 0.05], - [0.05, 1.]], dtype=np.float64) - for i in range(nt): - for j in range(nt): - ni, nj = list_n[i], list_n[j] - xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) - C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij - p = ot.unif(ns, type_as=C1) - q0 = ot.unif(C2.shape[0], type_as=C1) - G0 = p[:, None] * q0[None, :] - # asymmetric - C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - - for loss_fun in ['square_loss', 'kl_loss']: - G, log = ot.gromov.semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0) - Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein( - C1b, C2b, None, loss_fun='square_loss', symmetric=False, log=True, - G0=None, alpha_min=0., alpha_max=1.) 
- - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) - np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) - - srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2( - C1, C2, None, loss_fun='square_loss', symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2( - C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) - - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov - - np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) - - # symmetric - C1 = 0.5 * (C1 + C1.T) - C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - - G, log = ot.gromov.semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None) - Gb = ot.gromov.semirelaxed_gromov_wasserstein( - C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - - srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2( - C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2( - C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) - - srgw_ = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=False, G0=G0) - 
- G = log2['T'] - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) - - np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) - np.testing.assert_allclose(srgw, srgw_, atol=1e-07) - - -def test_semirelaxed_gromov2_gradients(): - n_samples = 50 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) - - xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) - - p = ot.unif(n_samples) - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - if torch: - - devices = [torch.device("cpu")] - if torch.cuda.is_available(): - devices.append(torch.device("cuda")) - for device in devices: - for loss_fun in ['square_loss', 'kl_loss']: - # semirelaxed solvers do not support gradients over masses yet. 
- p1 = torch.tensor(p, requires_grad=False, device=device) - C11 = torch.tensor(C1, requires_grad=True, device=device) - C12 = torch.tensor(C2, requires_grad=True, device=device) - - val = ot.gromov.semirelaxed_gromov_wasserstein2(C11, C12, p1, loss_fun=loss_fun) - - val.backward() - - assert val.device == p1.device - assert p1.grad is None - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape - - -def test_srgw_helper_backend(nx): - n_samples = 20 # nb samples - - mu = np.array([0, 0]) - cov = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) - xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - for loss_fun in ['square_loss', 'kl_loss']: - C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) - Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun, armijo=False, symmetric=True, G0=None, log=True) - - # calls with nx=None - constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun) - ones_pb = nx.ones(pb.shape[0], type_as=pb) - - def f(G): - qG = nx.sum(G, 0) - marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) - return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None) - - def df(G): - qG = nx.sum(G, 0) - marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) - return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None) - - def line_search(cost, G, deltaG, Mi, cost_G): - return ot.gromov.solve_semirelaxed_gromov_linesearch( - G, deltaG, cost_G, hC1b, hC2b, ones_pb, 0., 1., fC2t=fC2tb, nx=None) - # feed the precomputed local optimum Gb to semirelaxed_cg - res, log = ot.optim.semirelaxed_cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) - # check constraints - np.testing.assert_allclose(res, Gb, 
atol=1e-06) - - -@pytest.mark.parametrize('loss_fun', [ - 'square_loss', 'kl_loss', - pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), -]) -def test_gw_semirelaxed_helper_validation(loss_fun): - n_samples = 20 # nb samples - mu = np.array([0, 0]) - cov = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) - xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) - p = ot.unif(n_samples) - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - ot.gromov.init_matrix_semirelaxed(C1, C2, p, loss_fun=loss_fun) - - -def test_semirelaxed_fgw(nx): - rng = np.random.RandomState(0) - list_n = [16, 8] - nt = 2 - ns = 24 - # create directed sbm with C2 as connectivity matrix - C1 = np.zeros((ns, ns)) - C2 = np.array([[0.7, 0.05], - [0.05, 0.9]]) - for i in range(nt): - for j in range(nt): - ni, nj = list_n[i], list_n[j] - xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) - C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij - F1 = np.zeros((ns, 1)) - F1[:16] = rng.normal(loc=0., scale=0.01, size=(16, 1)) - F1[16:] = rng.normal(loc=1., scale=0.01, size=(8, 1)) - F2 = np.zeros((2, 1)) - F2[1, :] = 1. 
- M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T) - - p = ot.unif(ns) - q0 = ot.unif(C2.shape[0]) - G0 = p[:, None] * q0[None, :] - - # asymmetric - Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) - G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) - Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - - srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) - - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov - - np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - - # symmetric - for loss_fun in ['square_loss', 'kl_loss']: - C1 = 0.5 * (C1 + C1.T) - Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) - - G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=None) - Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=G0b) - 
- # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - - srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=None) - - srgw_ = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=G0) - - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov - - np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(srgw, srgw_, atol=1e-07) - - -def test_semirelaxed_fgw2_gradients(): - n_samples = 20 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) - - xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) - - p = ot.unif(n_samples) - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - M = ot.dist(xs, xt) - - C1 /= C1.max() - C2 /= C2.max() - - if torch: - - devices = [torch.device("cpu")] - if torch.cuda.is_available(): - devices.append(torch.device("cuda")) - for device in devices: - # semirelaxed solvers do not support gradients over masses yet. 
- for loss_fun in ['square_loss', 'kl_loss']: - p1 = torch.tensor(p, requires_grad=False, device=device) - C11 = torch.tensor(C1, requires_grad=True, device=device) - C12 = torch.tensor(C2, requires_grad=True, device=device) - M1 = torch.tensor(M, requires_grad=True, device=device) - - val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, loss_fun=loss_fun) - - val.backward() - - assert val.device == p1.device - assert p1.grad is None - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape - assert M1.shape == M1.grad.shape - - # full gradients with alpha - p1 = torch.tensor(p, requires_grad=False, device=device) - C11 = torch.tensor(C1, requires_grad=True, device=device) - C12 = torch.tensor(C2, requires_grad=True, device=device) - M1 = torch.tensor(M, requires_grad=True, device=device) - alpha = torch.tensor(0.5, requires_grad=True, device=device) - - val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, loss_fun=loss_fun, alpha=alpha) - - val.backward() - - assert val.device == p1.device - assert p1.grad is None - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape - assert alpha.shape == alpha.grad.shape - - -def test_srfgw_helper_backend(nx): - n_samples = 20 # nb samples - - mu = np.array([0, 0]) - cov = np.array([[1, 0], [0, 1]]) - - rng = np.random.RandomState(42) - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) - ys = rng.randn(xs.shape[0], 2) - xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1) - yt = rng.randn(xt.shape[0], 2) - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - G0 = p[:, None] * q[None, :] - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - M = ot.dist(ys, yt) - M /= M.max() - - Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) - alpha = 0.5 - Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, 'square_loss', alpha=0.5, armijo=False, 
symmetric=True, G0=G0b, log=True) - - # calls with nx=None - constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss') - ones_pb = nx.ones(pb.shape[0], type_as=pb) - - def f(G): - qG = nx.sum(G, 0) - marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) - return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None) - - def df(G): - qG = nx.sum(G, 0) - marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) - return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None) - - def line_search(cost, G, deltaG, Mi, cost_G): - return ot.gromov.solve_semirelaxed_gromov_linesearch( - G, deltaG, cost_G, C1b, C2b, ones_pb, M=(1 - alpha) * Mb, reg=alpha, nx=None) - # feed the precomputed local optimum Gb to semirelaxed_cg - res, log = ot.optim.semirelaxed_cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) - # check constraints - np.testing.assert_allclose(res, Gb, atol=1e-06) - - -def test_entropic_semirelaxed_gromov(nx): - # unbalanced proportions - list_n = [30, 15] - nt = 2 - ns = np.sum(list_n) - # create directed sbm with C2 as connectivity matrix - C1 = np.zeros((ns, ns), dtype=np.float64) - C2 = np.array([[0.8, 0.05], - [0.05, 1.]], dtype=np.float64) - rng = np.random.RandomState(0) - for i in range(nt): - for j in range(nt): - ni, nj = list_n[i], list_n[j] - xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) - C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij - p = ot.unif(ns, type_as=C1) - q0 = ot.unif(C2.shape[0], type_as=C1) - G0 = p[:, None] * q0[None, :] - # asymmetric - C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - epsilon = 0.1 - for loss_fun in ['square_loss', 'kl_loss']: - G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun=loss_fun, epsilon=epsilon, symmetric=None, log=True, G0=G0) - Gb, logb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun=loss_fun, 
epsilon=epsilon, symmetric=False, log=True, G0=None) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) - np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) - - srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, None, loss_fun=loss_fun, epsilon=epsilon, symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun=loss_fun, epsilon=epsilon, symmetric=None, log=True, G0=None) - - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov - - np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) - - # symmetric - C1 = 0.5 * (C1 + C1.T) - C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - - G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) - Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0b) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - - srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=True, G0=G0) - srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, 
loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) - - srgw_ = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0) - - G = log2['T'] - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) - - np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) - np.testing.assert_allclose(srgw, srgw_, atol=1e-07) - - -@pytest.skip_backend("jax", reason="test very slow with jax backend") -@pytest.skip_backend("tf", reason="test very slow with tf backend") -def test_entropic_semirelaxed_gromov_dtype_device(nx): - # setup - n_samples = 5 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) - - xt = xs[::-1].copy() - - p = ot.unif(n_samples) - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - for tp in nx.__type_list__: - - print(nx.dtype_device(tp)) - for loss_fun in ['square_loss', 'kl_loss']: - C1b, C2b, pb = nx.from_numpy(C1, C2, p, type_as=tp) - - Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein( - C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True - ) - gw_valb = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( - C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True - ) - - nx.assert_same_dtype_device(C1b, Gb) - nx.assert_same_dtype_device(C1b, gw_valb) - - -def test_entropic_semirelaxed_fgw(nx): - rng = np.random.RandomState(0) - list_n = [16, 8] - nt = 2 - ns = 24 - # create directed sbm with C2 as connectivity matrix - C1 = np.zeros((ns, ns)) - C2 = np.array([[0.7, 0.05], - [0.05, 
0.9]]) - for i in range(nt): - for j in range(nt): - ni, nj = list_n[i], list_n[j] - xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) - C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij - F1 = np.zeros((ns, 1)) - F1[:16] = rng.normal(loc=0., scale=0.01, size=(16, 1)) - F1[16:] = rng.normal(loc=1., scale=0.01, size=(8, 1)) - F2 = np.zeros((2, 1)) - F2[1, :] = 1. - M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T) - - p = ot.unif(ns) - q0 = ot.unif(C2.shape[0]) - G0 = p[:, None] * q0[None, :] - - # asymmetric - Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) - - G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) - Gb, logb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0b) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - - srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) - - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov - - np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srfgw_dist'], 
log['srfgw_dist'], atol=1e-07) - - # symmetric - C1 = 0.5 * (C1 + C1.T) - Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) - - for loss_fun in ['square_loss', 'kl_loss']: - G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) - Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0b) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - - srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=True, G0=G0) - srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) - - srgw_ = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0) - - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov - - np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(srgw, srgw_, atol=1e-07) - - -@pytest.skip_backend("tf", reason="test very slow with tf backend") -def test_entropic_semirelaxed_fgw_dtype_device(nx): - # setup - n_samples = 5 # nb samples - - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - 
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) - - xt = xs[::-1].copy() - - rng = np.random.RandomState(42) - ys = rng.randn(xs.shape[0], 2) - yt = ys[::-1].copy() - - p = ot.unif(n_samples) - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - - M = ot.dist(ys, yt) - for tp in nx.__type_list__: - print(nx.dtype_device(tp)) - - Mb, C1b, C2b, pb = nx.from_numpy(M, C1, C2, p, type_as=tp) - - for loss_fun in ['square_loss', 'kl_loss']: - Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True - ) - fgw_valb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True - ) - - nx.assert_same_dtype_device(C1b, Gb) - nx.assert_same_dtype_device(C1b, fgw_valb) - - -def test_not_implemented_solver(): - # test sinkhorn - n_samples = 5 # nb samples - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) - - rng = np.random.RandomState(42) - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=rng) - xt = xs[::-1].copy() - ys = rng.randn(xs.shape[0], 2) - yt = ys[::-1].copy() - - p = ot.unif(n_samples) - q = ot.unif(n_samples) - - C1 = ot.dist(xs, xs) - C2 = ot.dist(xt, xt) - - C1 /= C1.max() - C2 /= C2.max() - M = ot.dist(ys, yt) - - solver = 'not_implemented' - # entropic gw and fgw - with pytest.raises(ValueError): - ot.gromov.entropic_gromov_wasserstein( - C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver) - with pytest.raises(ValueError): - ot.gromov.entropic_fused_gromov_wasserstein( - M, C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver) From 81d76316eca975f0f9f1d4ade167021764405cef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 26 Apr 2024 08:22:18 +0200 Subject: [PATCH 12/30] [MRG] More general solvers for ``ot.solve`and examples of different variants. 
(#620) * add exaple and allow for functional regularizers * fix test since ow all is implemented * manuel regularizer available for exact and unbalanecd ot * exmaple with banaced manuel regularizer * upate documenation * pep8 * clenaup envelope instedaof implicit * big release file update --- RELEASES.md | 16 +++- examples/plot_solve_variants.py | 150 ++++++++++++++++++++++++++++++++ ot/__init__.py | 2 +- ot/solvers.py | 54 ++++++++---- ot/unbalanced.py | 36 +++++++- test/test_solvers.py | 33 +++++-- 6 files changed, 260 insertions(+), 31 deletions(-) create mode 100644 examples/plot_solve_variants.py diff --git a/RELEASES.md b/RELEASES.md index c7e3f598b..106042af2 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,12 +1,14 @@ # Releases -## 0.9.3dev +## 0.9.4dev #### New features -+ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specifify if the matrices are symmetric in which case the computation can be done faster. ++ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specify if the matrices are symmetric in which case the computation can be done faster (PR #607). ++ Continuous entropic mapping (PR #613) ++ New general unbalanced solvers for `ot.solve` and BFGS solver and illustrative example (PR #620) ++ Add gradient computation with envelope theorem to sinkhorn solver of `ot.solve` with `grad='envelope'` (PR #605). #### Closed issues -- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. 
This fix addresses the original issue introduced PR #587 (PR #593) - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596) - Fix doc and example for lowrank sinkhorn (PR #601) - Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534) @@ -14,6 +16,14 @@ - Fix same sign error for sr(F)GW conditional gradient solvers (PR #611) - Split `test/test_gromov.py` into `test/gromov/` (PR #619) +## 0.9.3 +*January 2024* + + +#### Closed issues +- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593) + + ## 0.9.2 *December 2023* diff --git a/examples/plot_solve_variants.py b/examples/plot_solve_variants.py new file mode 100644 index 000000000..82f892a52 --- /dev/null +++ b/examples/plot_solve_variants.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +""" +====================================== +Optimal Transport solvers comparison +====================================== + +This example illustrates the solutions returns for diffrent variants of exact, +regularized and unbalanced OT solvers. 
+""" + +# Author: Remi Flamary +# +# License: MIT License +# sphinx_gallery_thumbnail_number = 3 + +#%% + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot +from ot.datasets import make_1D_gauss as gauss + +############################################################################## +# Generate data +# ------------- + + +#%% parameters + +n = 50 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = 0.6 * gauss(n, m=15, s=5) + 0.4 * gauss(n, m=35, s=5) # m= mean, s= std +b = gauss(n, m=25, s=5) + +# loss matrix +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +M /= M.max() + + +############################################################################## +# Plot distributions and loss matrix +# ---------------------------------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.legend() + +#%% plot distributions and loss matrix + +pl.figure(2, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + +############################################################################## +# Define Group lasso regularization and gradient +# ------------------------------------------------ +# The groups are the first and second half of the columns of G + + +def reg_gl(G): # group lasso + small l2 reg + G1 = G[:n // 2, :]**2 + G2 = G[n // 2:, :]**2 + gl1 = np.sum(np.sqrt(np.sum(G1, 0))) + gl2 = np.sum(np.sqrt(np.sum(G2, 0))) + return gl1 + gl2 + 0.1 * np.sum(G**2) + + +def grad_gl(G): # gradient of group lasso + small l2 reg + G1 = G[:n // 2, :] + G2 = G[n // 2:, :] + gl1 = G1 / np.sqrt(np.sum(G1**2, 0, keepdims=True) + 1e-8) + gl2 = G2 / np.sqrt(np.sum(G2**2, 0, keepdims=True) + 1e-8) + return np.concatenate((gl1, gl2), axis=0) + 0.2 * G + + +reg_type_gl = (reg_gl, grad_gl) + +# %% +# Set up parameters for solvers and solve +# --------------------------------------- + +lst_regs = 
["No Reg.", "Entropic", "L2", "Group Lasso + L2"] +lst_unbalanced = ["Balanced", "Unbalanced KL", 'Unbalanced L2', 'Unb. TV (Partial)'] # ["Balanced", "Unb. KL", "Unb. L2", "Unb L1 (partial)"] + +lst_solvers = [ # name, param for ot.solve function + # balanced OT + ('Exact OT', dict()), + ('Entropic Reg. OT', dict(reg=0.005)), + ('L2 Reg OT', dict(reg=1, reg_type='l2')), + ('Group Lasso Reg. OT', dict(reg=0.1, reg_type=reg_type_gl)), + + + # unbalanced OT KL + ('Unbalanced KL No Reg.', dict(unbalanced=0.005)), + ('Unbalanced KL wit KL Reg.', dict(reg=0.0005, unbalanced=0.005, unbalanced_type='kl', reg_type='kl')), + ('Unbalanced KL with L2 Reg.', dict(reg=0.5, reg_type='l2', unbalanced=0.005, unbalanced_type='kl')), + ('Unbalanced KL with Group Lasso Reg.', dict(reg=0.1, reg_type=reg_type_gl, unbalanced=0.05, unbalanced_type='kl')), + + # unbalanced OT L2 + ('Unbalanced L2 No Reg.', dict(unbalanced=0.5, unbalanced_type='l2')), + ('Unbalanced L2 with KL Reg.', dict(reg=0.001, unbalanced=0.2, unbalanced_type='l2')), + ('Unbalanced L2 with L2 Reg.', dict(reg=0.1, reg_type='l2', unbalanced=0.2, unbalanced_type='l2')), + ('Unbalanced L2 with Group Lasso Reg.', dict(reg=0.05, reg_type=reg_type_gl, unbalanced=0.7, unbalanced_type='l2')), + + # unbalanced OT TV + ('Unbalanced TV No Reg.', dict(unbalanced=0.1, unbalanced_type='tv')), + ('Unbalanced TV with KL Reg.', dict(reg=0.001, unbalanced=0.01, unbalanced_type='tv')), + ('Unbalanced TV with L2 Reg.', dict(reg=0.1, reg_type='l2', unbalanced=0.01, unbalanced_type='tv')), + ('Unbalanced TV with Group Lasso Reg.', dict(reg=0.02, reg_type=reg_type_gl, unbalanced=0.01, unbalanced_type='tv')), + +] + +lst_plans = [] +for (name, param) in lst_solvers: + G = ot.solve(M, a, b, **param).plan + lst_plans.append(G) + +############################################################################## +# Plot plans +# ---------- + +pl.figure(3, figsize=(9, 9)) + +for i, bname in enumerate(lst_unbalanced): + for j, rname in 
enumerate(lst_regs): + pl.subplot(len(lst_unbalanced), len(lst_regs), i * len(lst_regs) + j + 1) + + plan = lst_plans[i * len(lst_regs) + j] + m2 = plan.sum(0) + m1 = plan.sum(1) + m1, m2 = m1 / a.max(), m2 / b.max() + pl.imshow(plan, cmap='Greys') + pl.plot(x, m2 * 10, 'r') + pl.plot(m1 * 10, x, 'b') + pl.plot(x, b / b.max() * 10, 'r', alpha=0.3) + pl.plot(a / a.max() * 10, x, 'b', alpha=0.3) + #pl.axis('off') + pl.tick_params(left=False, right=False, labelleft=False, + labelbottom=False, bottom=False) + if i == 0: + pl.title(rname) + if j == 0: + pl.ylabel(bname, fontsize=14) diff --git a/ot/__init__.py b/ot/__init__.py index 1c10efafd..609f9ff37 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -58,7 +58,7 @@ # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.9.3dev" +__version__ = "0.9.4dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', diff --git a/ot/solvers.py b/ot/solvers.py index de817d7f7..95165ea11 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -23,6 +23,7 @@ from .gaussian import empirical_bures_wasserstein_distance from .factored import factored_optimal_transport from .lowrank import lowrank_sinkhorn +from .optim import cg lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale'] @@ -57,13 +58,15 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, Regularization weight :math:`\lambda_r`, by default None (no reg., exact OT) reg_type : str, optional - Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + Type of regularization :math:`R` either "KL", "L2", "entropy", + by default "KL". a tuple of functions can be provided for general + solver (see :any:`cg`). This is only used when ``reg!=None``. 
unbalanced : float, optional Unbalanced penalization weight :math:`\lambda_u`, by default None (balanced OT) unbalanced_type : str, optional Type of unbalanced penalization function :math:`U` either "KL", "L2", - "TV", by default "KL" + "TV", by default "KL". method : str, optional Method for solving the problem when multiple algorithms are available, default None for automatic selection. @@ -80,10 +83,10 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, verbose : bool, optional Print information in the solver, by default False grad : str, optional - Type of gradient computation, either or 'autodiff' or 'implicit' used only for + Type of gradient computation, either or 'autodiff' or 'envelope' used only for Sinkhorn solver. By default 'autodiff' provides gradients wrt all outputs (`plan, value, value_linear`) but with important memory cost. - 'implicit' provides gradients only for `value` and and other outputs are + 'envelope' provides gradients only for `value` and and other outputs are detached. This is useful for memory saving when only the value is needed. Returns @@ -140,13 +143,13 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, # or for original Sinkhorn paper formulation [2] res = ot.solve(M, a, b, reg=1.0, reg_type='entropy') - # Use implicit differentiation for memory saving - res = ot.solve(M, a, b, reg=1.0, grad='implicit') # M, a, b are torch tensors + # Use envelope theorem differentiation for memory saving + res = ot.solve(M, a, b, reg=1.0, grad='envelope') # M, a, b are torch tensors res.value.backward() # only the value is differentiable Note that by default the Sinkhorn solver uses automatic differentiation to compute the gradients of the values and plan. This can be changed with the - `grad` parameter. The `implicit` mode computes the implicit gradients only + `grad` parameter. The `envelope` mode computes the gradients only for the value and the other outputs are detached. 
This is useful for memory saving when only the gradient of value is needed. @@ -311,9 +314,22 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, if unbalanced is None: # Balanced regularized OT - if reg_type.lower() in ['entropy', 'kl']: + if isinstance(reg_type, tuple): # general solver + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + plan, log = cg(a, b, M, reg=reg, f=reg_type[0], df=reg_type[1], numItermax=max_iter, stopThr=tol, log=True, verbose=verbose, G0=plan_init) + + value_linear = nx.sum(M * plan) + value = log['loss'][-1] + potentials = (log['u'], log['v']) + + elif reg_type.lower() in ['entropy', 'kl']: - if grad == 'implicit': # if implicit then detach the input + if grad == 'envelope': # if envelope then detach the input M0, a0, b0 = M, a, b M, a, b = nx.detach(M, a, b) @@ -336,7 +352,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, potentials = (log['log_u'], log['log_v']) - if grad == 'implicit': # set the gradient at convergence + if grad == 'envelope': # set the gradient at convergence value = nx.set_gradients(value, (M0, a0, b0), (plan, reg * (potentials[0] - potentials[0].mean()), reg * (potentials[1] - potentials[1].mean()))) @@ -359,7 +375,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, else: # unbalanced AND regularized OT - if reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl': + if not isinstance(reg_type, tuple) and reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl': if max_iter is None: max_iter = 1000 @@ -374,14 +390,16 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, potentials = (log['logu'], log['logv']) - elif reg_type.lower() in ['kl', 'l2', 'entropy'] and unbalanced_type.lower() in ['kl', 'l2']: + elif (isinstance(reg_type, tuple) or reg_type.lower() in ['kl', 'l2', 'entropy']) and unbalanced_type.lower() in ['kl', 'l2', 'tv']: if max_iter is None: max_iter = 1000 if tol is 
None: tol = 1e-12 + if isinstance(reg_type, str): + reg_type = reg_type.lower() - plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type.lower(), regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True) + plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type, regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True, G0=plan_init) value_linear = nx.sum(M * plan) @@ -962,10 +980,10 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t verbose : bool, optional Print information in the solver, by default False grad : str, optional - Type of gradient computation, either or 'autodiff' or 'implicit' used only for + Type of gradient computation, either or 'autodiff' or 'envelope' used only for Sinkhorn solver. By default 'autodiff' provides gradients wrt all outputs (`plan, value, value_linear`) but with important memory cost. - 'implicit' provides gradients only for `value` and and other outputs are + 'envelope' provides gradients only for `value` and and other outputs are detached. This is useful for memory saving when only the value is needed. Returns @@ -1034,13 +1052,13 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t # lazy OT plan lazy_plan = res.lazy_plan - # Use implicit differentiation for memory saving - res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='implicit') + # Use envelope theorem differentiation for memory saving + res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='envelope') res.value.backward() # only the value is differentiable Note that by default the Sinkhorn solver uses automatic differentiation to compute the gradients of the values and plan. This can be changed with the - `grad` parameter. The `implicit` mode computes the implicit gradients only + `grad` parameter. 
The `envelope` mode computes the gradients only for the value and the other outputs are detached. This is useful for memory saving when only the gradient of value is needed. diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 73667b324..c39888a31 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -1432,6 +1432,9 @@ def grad_entropy(G): elif reg_div == 'entropy': reg_fun = reg_entropy grad_reg_fun = grad_entropy + elif isinstance(reg_div, tuple): + reg_fun = reg_div[0] + grad_reg_fun = reg_div[1] else: reg_fun = reg_l2 grad_reg_fun = grad_l2 @@ -1451,9 +1454,20 @@ def grad_marg_kl(G): return reg_m1 * np.outer(np.log(G.sum(1) / a + 1e-16), np.ones(n)) + \ reg_m2 * np.outer(np.ones(m), np.log(G.sum(0) / b + 1e-16)) + def marg_tv(G): + return reg_m1 * np.sum(np.abs(G.sum(1) - a)) + \ + reg_m2 * np.sum(np.abs(G.sum(0) - b)) + + def grad_marg_tv(G): + return reg_m1 * np.outer(np.sign(G.sum(1) - a), np.ones(n)) + \ + reg_m2 * np.outer(np.ones(m), np.sign(G.sum(0) - b)) + if regm_div == 'kl': regm_fun = marg_kl grad_regm_fun = grad_marg_kl + elif regm_div == 'tv': + regm_fun = marg_tv + grad_regm_fun = grad_marg_tv else: regm_fun = marg_l2 grad_regm_fun = grad_marg_l2 @@ -1518,7 +1532,10 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', reg_div: string, optional Divergence used for regularization. Can take three values: 'entropy' (negative entropy), or - 'kl' (Kullback-Leibler) or 'l2' (quadratic). + 'kl' (Kullback-Leibler) or 'l2' (quadratic) or a tuple + of two calable functions returning the reg term and its derivative. + Note that the callable functions should be able to handle numpy arrays + and not tesors from the backend regm_div: string, optional Divergence to quantify the difference between the marginals. 
Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) @@ -1574,6 +1591,23 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) + # wrap the callable function to handle numpy arrays + if isinstance(reg_div, tuple): + f0, df0 = reg_div + try: + f0(G0) + df0(G0) + except BaseException: + warnings.warn("The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead") + + def f(x): + return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0))) + + def df(x): + return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0))) + + reg_div = (f, df) + reg_m1, reg_m2 = get_parameter_pair(reg_m) _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) diff --git a/test/test_solvers.py b/test/test_solvers.py index 168b111e4..16e6df295 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -14,8 +14,9 @@ from ot.bregman import geomloss from ot.backend import torch + lst_reg = [None, 1] -lst_reg_type = ['KL', 'entropy', 'L2'] +lst_reg_type = ['KL', 'entropy', 'L2', 'tuple'] lst_unbalanced = [None, 0.9] lst_unbalanced_type = ['KL', 'L2', 'TV'] @@ -109,7 +110,7 @@ def test_solve(nx): @pytest.mark.skipif(not torch, reason="torch no installed") -def test_solve_implicit(): +def test_solve_envelope(): n_samples_s = 10 n_samples_t = 7 @@ -126,7 +127,7 @@ def test_solve_implicit(): b = torch.tensor(b, requires_grad=True) M = torch.tensor(M, requires_grad=True) - sol0 = ot.solve(M, a, b, reg=10, grad='implicit') + sol0 = ot.solve(M, a, b, reg=10, grad='envelope') sol0.value.backward() gM0 = M.grad.clone() @@ -166,6 +167,15 @@ def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type): try: + if reg_type == 'tuple': + def f(G): + return np.sum(G**2) + + def df(G): + return 2 * G + + reg_type = (f, df) + # solve unif weights sol0 = ot.solve(M, reg=reg, 
reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) @@ -176,9 +186,20 @@ def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type): # solve in backend ab, bb, Mb = nx.from_numpy(a, b, M) - solb = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) + + if isinstance(reg_type, tuple): + def f(G): + return nx.sum(G**2) + + def df(G): + return 2 * G + + reg_type = (f, df) + + solb = ot.solve(Mb, ab, bb, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) assert_allclose_sol(sol, solb) + except NotImplementedError: pytest.skip("Not implemented") @@ -201,10 +222,6 @@ def test_solve_not_implemented(nx): with pytest.raises(NotImplementedError): ot.solve(M, unbalanced=1.0, unbalanced_type='cryptic divergence') - # pairs of incompatible divergences - with pytest.raises(NotImplementedError): - ot.solve(M, reg=1.0, reg_type='kl', unbalanced=1.0, unbalanced_type='tv') - def test_solve_gromov(nx): From e01c4e6ac68f2ad476d29f236d9607ea5d058362 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 27 May 2024 21:58:35 +0200 Subject: [PATCH 13/30] Update build_tests.yml (#621) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update build_tests.yml * Update build_tests.yml * Update build_tests.yml * Update build_tests.yml * Update build_tests.yml * Update build_tests.yml * Update build_tests.yml * update all versions of checkout and setup-python * Update build_wheels.yml * update versions for all builds * Update build_doc.yml * Update plot_UOT_barycenter_1D.py with deprecated version of pylab * Update plot_barycenter_1D.py --------- Co-authored-by: Cédric Vincent-Cuaz --- .github/workflows/build_doc.yml | 10 ++++---- .github/workflows/build_tests.yml | 24 +++++++++---------- .github/workflows/build_tests_cuda.yml | 2 +- .github/workflows/build_wheels.yml | 10 ++++---- 
.github/workflows/build_wheels_weekly.yml | 4 ++-- examples/barycenters/plot_barycenter_1D.py | 4 ++-- .../plot_UOT_barycenter_1D.py | 4 ++-- 7 files changed, 29 insertions(+), 29 deletions(-) diff --git a/.github/workflows/build_doc.yml b/.github/workflows/build_doc.yml index 93bd11333..3af2d301f 100644 --- a/.github/workflows/build_doc.yml +++ b/.github/workflows/build_doc.yml @@ -13,13 +13,13 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v4 # Standard drop-in approach that should work for most people. - - name: Set up Python 3.8 - uses: actions/setup-python@v1 + - name: Set up Python 3.10 + uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: "3.10" - name: Get Python running run: | @@ -41,4 +41,4 @@ jobs: - uses: actions/upload-artifact@v1 with: name: Documentation - path: docs/build/html/ \ No newline at end of file + path: docs/build/html/ diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 392bce4ec..9bdd337c0 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -25,9 +25,9 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install POT @@ -48,9 +48,9 @@ jobs: runs-on: ubuntu-latest if: "!contains(github.event.head_commit.message, 'no pep8')" steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install dependencies @@ -69,9 +69,9 @@ jobs: runs-on: ubuntu-latest if: "!contains(github.event.head_commit.message, 'no ci')" steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v1 + 
uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install dependencies @@ -93,12 +93,12 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.10"] + python-version: ["3.11"] steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install POT @@ -120,12 +120,12 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.10"] + python-version: ["3.11"] steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: RC.exe diff --git a/.github/workflows/build_tests_cuda.yml b/.github/workflows/build_tests_cuda.yml index e614aefa8..be8e47c8b 100644 --- a/.github/workflows/build_tests_cuda.yml +++ b/.github/workflows/build_tests_cuda.yml @@ -12,7 +12,7 @@ jobs: if: github.event.review.state == 'approved' || github.event_name == 'workflow_dispatch' || (github.event_name == 'push' && github.ref == 'refs/heads/master') steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v4 - name: Install POT run: | python3.10 -m pip install --ignore-installed -e . 
diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index c6c70251b..c60babff6 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -18,9 +18,9 @@ jobs: os: [ubuntu-latest, macos-latest, windows-latest] steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: "3.10" @@ -53,9 +53,9 @@ jobs: os: [ubuntu-latest, macos-latest, windows-latest] steps: - - uses: actions/checkout@v1 - - name: Set up Python 3.8 - uses: actions/setup-python@v1 + - uses: actions/checkout@v4 + - name: Set up Python 3.10 + uses: actions/setup-python@v5 with: python-version: "3.10" diff --git a/.github/workflows/build_wheels_weekly.yml b/.github/workflows/build_wheels_weekly.yml index 5245d41cc..6b2f124fa 100644 --- a/.github/workflows/build_wheels_weekly.yml +++ b/.github/workflows/build_wheels_weekly.yml @@ -17,9 +17,9 @@ jobs: os: [ubuntu-latest, macos-latest, windows-latest] steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: "3.10" diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py index 40dc44476..7c17c9b22 100644 --- a/examples/barycenters/plot_barycenter_1D.py +++ b/examples/barycenters/plot_barycenter_1D.py @@ -100,7 +100,7 @@ #%% plot interpolation plt.figure(2) -cmap = plt.cm.get_cmap('viridis') +cmap = plt.get_cmap('viridis') verts = [] zs = alpha_list for i, z in enumerate(zs): @@ -122,7 +122,7 @@ plt.tight_layout() plt.figure(3) -cmap = plt.cm.get_cmap('viridis') +cmap = plt.get_cmap('viridis') verts = [] zs = alpha_list for i, z in enumerate(zs): diff --git a/examples/unbalanced-partial/plot_UOT_barycenter_1D.py b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py index f747055cb..de1a3b3d5 100644 --- 
a/examples/unbalanced-partial/plot_UOT_barycenter_1D.py +++ b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py @@ -120,7 +120,7 @@ pl.figure(3) -cmap = pl.cm.get_cmap('viridis') +cmap = pl.get_cmap('viridis') verts = [] zs = weight_list for i, z in enumerate(zs): @@ -142,7 +142,7 @@ pl.tight_layout() pl.figure(4) -cmap = pl.cm.get_cmap('viridis') +cmap = pl.get_cmap('viridis') verts = [] zs = weight_list for i, z in enumerate(zs): From 2472dd462a86153da5e3864be1c102ac4046455b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laur=C3=A8ne?= <92529964+laudavid@users.noreply.github.com> Date: Wed, 29 May 2024 15:22:39 +0200 Subject: [PATCH 14/30] Implementation of Low Rank Gromov-Wasserstein (#614) * new file for lr sinkhorn * lr sinkhorn, solve_sample, OTResultLazy * add test functions + small modif lr_sin/solve_sample * add import to __init__ * modify low rank, remove solve_sample,OTResultLazy * new file for lr sinkhorn * lr sinkhorn, solve_sample, OTResultLazy * add test functions + small modif lr_sin/solve_sample * add import to __init__ * remove test solve_sample * add value, value_linear, lazy_plan * add comments to lr algorithm * modify test functions + add comments to lowrank * modify __init__ with lowrank * debug lowrank + test * debug test function low_rank * error test * final debug of lowrank + add new test functions * Debug tests + add lowrank to solve_sample * fix torch backend for lowrank * fix jax backend and skip tf * fix pep 8 tests * add lowrank init + test functions * Add init strategies in lowrank + example (PythonOT#588) * modified lowrank * changes from code review * fix error test pep8 * fix linux-minimal-deps + code review * Implementation of LR GW + add method in __init__ * add LR gw paper in README.md * add tests for low rank GW * add examples for Low Rank GW * fix __init__ * change atol of lr backends * fix pep8 errors * modif for code review --- CONTRIBUTORS.md | 2 +- README.md | 2 + RELEASES.md | 1 + examples/others/plot_lowrank_GW.py | 173 
++++++++++++++++ ot/__init__.py | 7 +- ot/gromov/__init__.py | 4 +- ot/gromov/_lowrank.py | 313 +++++++++++++++++++++++++++++ test/gromov/test_lowrank.py | 125 ++++++++++++ 8 files changed, 622 insertions(+), 5 deletions(-) create mode 100644 examples/others/plot_lowrank_GW.py create mode 100644 ot/gromov/_lowrank.py create mode 100644 test/gromov/test_lowrank.py diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 89c5be433..c185e18a7 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -50,7 +50,7 @@ The contributors to this library are: * [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization) * [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization) * [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers) -* [Laurène David](https://github.com/laudavid) (Low rank sinkhorn) +* [Laurène David](https://github.com/laudavid) (Low rank sinkhorn, Low rank Gromov-Wasserstein samples) ## Acknowledgments diff --git a/README.md b/README.md index 88dce689a..f1149a008 100644 --- a/README.md +++ b/README.md @@ -357,3 +357,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf). [66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. [Entropic estimation of optimal transport maps](https://arxiv.org/pdf/2109.12004.pdf). arXiv preprint arXiv:2109.12004 (2021). + +[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). [Linear-Time GromovWasserstein Distances using Low Rank Couplings and Costs](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf). In International Conference on Machine Learning (ICML), 2022. 
\ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 106042af2..c31081451 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -7,6 +7,7 @@ + Continuous entropic mapping (PR #613) + New general unbalanced solvers for `ot.solve` and BFGS solver and illustrative example (PR #620) + Add gradient computation with envelope theorem to sinkhorn solver of `ot.solve` with `grad='envelope'` (PR #605). ++ Added support for [Low rank Gromov-Wasserstein](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf) with `ot.gromov.lowrank_gromov_wasserstein_samples` (PR #614) #### Closed issues - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596) diff --git a/examples/others/plot_lowrank_GW.py b/examples/others/plot_lowrank_GW.py new file mode 100644 index 000000000..02fef6ded --- /dev/null +++ b/examples/others/plot_lowrank_GW.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +""" +======================================== +Low rank Gromov-Wasterstein between samples +======================================== + +Comparaison between entropic Gromov-Wasserstein and Low Rank Gromov Wasserstein [67] +on two curves in 2D and 3D, both sampled with 200 points. + +The squared Euclidean distance is considered as the ground cost for both samples. + +[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). +"Linear-Time GromovWasserstein Distances using Low Rank Couplings and Costs". +In International Conference on Machine Learning (ICML), 2022. 
+""" + +# Author: Laurène David +# +# License: MIT License +# +# sphinx_gallery_thumbnail_number = 3 + +#%% +import numpy as np +import matplotlib.pylab as pl +import ot.plot +import time + +############################################################################## +# Generate data +# ------------- + +#%% parameters +n_samples = 200 + +# Generate 2D and 3D curves +theta = np.linspace(-4 * np.pi, 4 * np.pi, n_samples) +z = np.linspace(1, 2, n_samples) +r = z**2 + 1 +x = r * np.sin(theta) +y = r * np.cos(theta) + +# Source and target distribution +X = np.concatenate([x.reshape(-1, 1), z.reshape(-1, 1)], axis=1) +Y = np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], axis=1) + + +############################################################################## +# Plot data +# ------------ + +#%% +# Plot the source and target samples +fig = pl.figure(1, figsize=(10, 4)) + +ax = fig.add_subplot(121) +ax.plot(X[:, 0], X[:, 1], color="blue", linewidth=6) +ax.tick_params(left=False, right=False, labelleft=False, + labelbottom=False, bottom=False) +ax.set_title("2D curve (source)") + +ax2 = fig.add_subplot(122, projection="3d") +ax2.plot(Y[:, 0], Y[:, 1], Y[:, 2], c='red', linewidth=6) +ax2.tick_params(left=False, right=False, labelleft=False, + labelbottom=False, bottom=False) +ax2.view_init(15, -50) +ax2.set_title("3D curve (target)") + +pl.tight_layout() +pl.show() + + +############################################################################## +# Entropic Gromov-Wasserstein +# ------------ + +#%% + +# Compute cost matrices +C1 = ot.dist(X, X, metric="sqeuclidean") +C2 = ot.dist(Y, Y, metric="sqeuclidean") + +# Scale cost matrices +r1 = C1.max() +r2 = C2.max() + +C1 = C1 / r1 +C2 = C2 / r2 + + +# Solve entropic gw +reg = 5 * 1e-3 + +start = time.time() +gw, log = ot.gromov.entropic_gromov_wasserstein( + C1, C2, tol=1e-3, epsilon=reg, + log=True, verbose=False) + +end = time.time() +time_entropic = end - start + +entropic_gw_loss = 
np.round(log['gw_dist'], 3) + +# Plot entropic gw +pl.figure(2) +pl.imshow(gw, interpolation="nearest", aspect="auto") +pl.title("Entropic Gromov-Wasserstein (loss={})".format(entropic_gw_loss)) +pl.show() + + +############################################################################## +# Low rank squared euclidean cost matrices +# ------------ +# %% + +# Compute the low rank sqeuclidean cost decompositions +A1, A2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X, X, rescale_cost=False) +B1, B2 = ot.lowrank.compute_lr_sqeuclidean_matrix(Y, Y, rescale_cost=False) + +# Scale the low rank cost matrices +A1, A2 = A1 / np.sqrt(r1), A2 / np.sqrt(r1) +B1, B2 = B1 / np.sqrt(r2), B2 / np.sqrt(r2) + + +############################################################################## +# Low rank Gromov-Wasserstein +# ------------ +# %% + +# Solve low rank gromov-wasserstein with different ranks +list_rank = [10, 50] +list_P_GW = [] +list_loss_GW = [] +list_time_GW = [] + +for rank in list_rank: + start = time.time() + + Q, R, g, log = ot.lowrank_gromov_wasserstein_samples( + X, Y, reg=0, rank=rank, rescale_cost=False, cost_factorized_Xs=(A1, A2), + cost_factorized_Xt=(B1, B2), seed_init=49, numItermax=1000, log=True, stopThr=1e-6, + ) + end = time.time() + + P = log["lazy_plan"][:] + loss = log["value"] + + list_P_GW.append(P) + list_loss_GW.append(np.round(loss, 3)) + list_time_GW.append(end - start) + + +# %% +# Plot low rank GW with different ranks +pl.figure(3, figsize=(10, 4)) + +pl.subplot(1, 2, 1) +pl.imshow(list_P_GW[0], interpolation="nearest", aspect="auto") +pl.title('Low rank GW (rank=10, loss={})'.format(list_loss_GW[0])) + +pl.subplot(1, 2, 2) +pl.imshow(list_P_GW[1], interpolation="nearest", aspect="auto") +pl.title('Low rank GW (rank=50, loss={})'.format(list_loss_GW[1])) + +pl.tight_layout() +pl.show() + + +# %% +# Compare computation time between entropic GW and low rank GW +print("Entropic GW: {:.2f}s".format(time_entropic)) +print("Low rank GW (rank=10): 
{:.2f}s".format(list_time_GW[0])) +print("Low rank GW (rank=50): {:.2f}s".format(list_time_GW[1])) diff --git a/ot/__init__.py b/ot/__init__.py index 609f9ff37..d8ac5ac28 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -49,7 +49,8 @@ from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance, sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif) from .gromov import (gromov_wasserstein, gromov_wasserstein2, - gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) + gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2, + lowrank_gromov_wasserstein_samples) from .weak import weak_optimal_transport from .factored import factored_optimal_transport from .solvers import solve, solve_gromov, solve_sample @@ -71,5 +72,5 @@ 'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', - 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', - 'lowrank_sinkhorn'] + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn', + 'lowrank_gromov_wasserstein_samples'] diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py index 4d77fc57a..b33dafd32 100644 --- a/ot/gromov/__init__.py +++ b/ot/gromov/__init__.py @@ -47,6 +47,8 @@ fused_gromov_wasserstein_dictionary_learning, fused_gromov_wasserstein_linear_unmixing) +from ._lowrank import (_flat_product_operator, lowrank_gromov_wasserstein_samples) + __all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'update_square_loss', 'update_kl_loss', 'update_feature_matrix', 'init_matrix_semirelaxed', @@ -64,4 +66,4 @@ 'entropic_semirelaxed_gromov_wasserstein2', 'entropic_semirelaxed_fused_gromov_wasserstein', 'entropic_semirelaxed_fused_gromov_wasserstein2', 'gromov_wasserstein_dictionary_learning', 'gromov_wasserstein_linear_unmixing', 
'fused_gromov_wasserstein_dictionary_learning', - 'fused_gromov_wasserstein_linear_unmixing'] + 'fused_gromov_wasserstein_linear_unmixing', 'lowrank_gromov_wasserstein_samples'] diff --git a/ot/gromov/_lowrank.py b/ot/gromov/_lowrank.py new file mode 100644 index 000000000..5bab15edc --- /dev/null +++ b/ot/gromov/_lowrank.py @@ -0,0 +1,313 @@ +""" +Low rank Gromov-Wasserstein solver +""" + +# Author: Laurène David +# +# License: MIT License + + +import warnings +from ..utils import unif, get_lowrank_lazytensor +from ..backend import get_backend +from ..lowrank import compute_lr_sqeuclidean_matrix, _init_lr_sinkhorn, _LR_Dysktra + + +def _flat_product_operator(X, nx=None): + r""" + Implementation of the flattened out-product operator. + + This function is used in low rank gromov wasserstein to compute the low rank decomposition of + a cost matrix's squared hadamard product (page 6 in paper). + + Parameters + ---------- + X: array-like, shape (n_samples, n_col) + Input matrix for operator + + nx: default None + POT backend + + Returns + ---------- + X_flat: array-like, shape (n_samples, n_col**2) + Matrix with flattened out-product operator applied on each row + + References + ---------- + .. [67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). + "Linear-Time GromovWasserstein Distances using Low Rank Couplings and Costs". + In International Conference on Machine Learning (ICML), 2022. 
+ + """ + + if nx is None: + nx = get_backend(X) + + n = X.shape[0] + x1 = X[0, :][:, None] + X_flat = nx.dot(x1, x1.T).flatten()[:, None] + + for i in range(1, n): + x = X[i, :][:, None] + x_out = nx.dot(x, x.T).flatten()[:, None] + X_flat = nx.concatenate((X_flat, x_out), axis=1) + + X_flat = X_flat.T + + return X_flat + + +def lowrank_gromov_wasserstein_samples(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, gamma_init="rescale", + rescale_cost=True, cost_factorized_Xs=None, cost_factorized_Xt=None, stopThr=1e-4, numItermax=1000, + stopThr_dykstra=1e-3, numItermax_dykstra=10000, seed_init=49, warn=True, warn_dykstra=False, log=False): + + r""" + Solve the entropic regularization Gromov-Wasserstein transport problem under low-nonnegative rank constraints + on the couplings and cost matrices. + + Squared euclidean distance matrices are considered for the target and source distributions. + + The function solves the following optimization problem: + + .. math:: + \mathop{\min_{(Q,R,g) \in \mathcal{C(a,b,r)}}} \mathcal{Q}_{A,B}(Q\mathrm{diag}(1/g)R^T) - + \epsilon \cdot H((Q,R,g)) + + where : + + - :math: `A` is the (`dim_a`, `dim_a`) square pairwise cost matrix of the source domain. + - :math: `B` is the (`dim_a`, `dim_a`) square pairwise cost matrix of the target domain. + - :math: `\mathcal{Q}_{A,B}` is quadratic objective function of the Gromov Wasserstein plan. + - :math: `Q` and `R` are the low-rank matrix decomposition of the Gromov-Wasserstein plan. + - :math: `g` is the weight vector for the low-rank decomposition of the Gromov-Wasserstein plan. + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1). + - :math: `r` is the rank of the Gromov-Wasserstein plan. + - :math: `\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem. + - :math:`H((Q,R,g))` is the values of the three respective entropies evaluated for each term. 
+ + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim_Xs) + Samples in the source domain + X_t : array-like, shape (n_samples_b, dim_Xt) + Samples in the target domain + a : array-like, shape (n_samples_a,), optional + Samples weights in the source domain + If let to its default value None, uniform distribution is taken. + b : array-like, shape (n_samples_b,), optional + Samples weights in the target domain + If let to its default value None, uniform distribution is taken. + reg : float, optional + Regularization term >=0 + rank : int, optional. Default is None. (>0) + Nonnegative rank of the OT plan. If None, min(ns, nt) is considered. + alpha : int, optional. Default is 1e-10. (>0 and <1/r) + Lower bound for the weight vector g. + rescale_cost : bool, optional. Default is False + Rescale the low rank factorization of the sqeuclidean cost matrix + seed_init : int, optional. Default is 49. (>0) + Random state for the 'random' initialization of low rank couplings + gamma_init : str, optional. Default is "rescale". + Initialization strategy for gamma. 'rescale', or 'theory' + Gamma is a constant that scales the convergence criterion of the Mirror Descent + optimization scheme used to compute the low-rank couplings (Q, R and g) + numItermax : int, optional. Default is 1000. + Max number of iterations for Low Rank GW + stopThr : float, optional. Default is 1e-4. + Stop threshold on error (>0) for Low Rank GW + The error is the sum of Kullback Divergences computed for each low rank + coupling (Q, R and g) and scaled using gamma. + numItermax_dykstra : int, optional. Default is 2000. + Max number of iterations for the Dykstra algorithm + stopThr_dykstra : float, optional. Default is 1e-7. + Stop threshold on error (>0) in Dykstra + cost_factorized_Xs: tuple, optional. Default is None + Tuple with two pre-computed low rank decompositions (A1, A2) of the source cost + matrix. Both matrices should have a shape of (n_samples_a, dim_Xs + 2). 
+ If None, the low rank cost matrices will be computed as sqeuclidean cost matrices. + cost_factorized_Xt: tuple, optional. Default is None + Tuple with two pre-computed low rank decompositions (B1, B2) of the target cost + matrix. Both matrices should have a shape of (n_samples_b, dim_Xt + 2). + If None, the low rank cost matrices will be computed as sqeuclidean cost matrices. + warn : bool, optional. Default is True + if True, raises a warning if the low rank GW algorithm doesn't converge. + warn_dykstra: bool, optional. Default is False + if True, raises a warning if the Dykstra algorithm doesn't converge. + log : bool, optional. Default is False + record log if True + + + Returns + --------- + Q : array-like, shape (n_samples_a, r) + First low-rank matrix decomposition of the OT plan + R: array-like, shape (n_samples_b, r) + Second low-rank matrix decomposition of the OT plan + g : array-like, shape (r, ) + Weight vector for the low-rank decomposition of the OT plan + log : dict (lazy_plan, value and value_quad) + log dictionary returned only if log==True in parameters + + + References + ---------- + .. [67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). + "Linear-Time Gromov-Wasserstein Distances using Low Rank Couplings and Costs". + In International Conference on Machine Learning (ICML), 2022.
+ + """ + + # POT backend + nx = get_backend(X_s, X_t) + ns, nt = X_s.shape[0], X_t.shape[0] + + # Initialize weights a, b + if a is None: + a = unif(ns, type_as=X_s) + if b is None: + b = unif(nt, type_as=X_t) + + # Compute rank (see Section 3.1, def 1) + r = rank + if rank is None: + r = min(ns, nt) + else: + r = min(ns, nt, rank) + + if r <= 0: + raise ValueError("The rank parameter cannot have a negative value") + + # Dykstra won't converge if 1/rank < alpha (see Section 3.2) + if 1 / r < alpha: + raise ValueError("alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format( + a=alpha, r=1 / rank)) + + if cost_factorized_Xs is not None: + A1, A2 = cost_factorized_Xs + else: + A1, A2 = compute_lr_sqeuclidean_matrix(X_s, X_s, rescale_cost, nx=nx) + + if cost_factorized_Xt is not None: + B1, B2 = cost_factorized_Xt + else: + B1, B2 = compute_lr_sqeuclidean_matrix(X_t, X_t, rescale_cost, nx=nx) + + # Initial values for LR couplings (Q, R, g) with LOT + Q, R, g = _init_lr_sinkhorn( + X_s, X_t, a, b, r, init="random", random_state=seed_init, reg_init=None, nx=nx + ) + + # Gamma initialization + if gamma_init == "theory": + L = (27 * nx.norm(A1) * nx.norm(A2)) / alpha**4 + gamma = 1 / (2 * L) + + if gamma_init not in ["rescale", "theory"]: + raise (NotImplementedError('Not implemented gamma_init="{}"'.format(gamma_init))) + + # initial value of error + err = 1 + + for ii in range(numItermax): + Q_prev = Q + R_prev = R + g_prev = g + + if err > stopThr: + # Compute cost matrices + C1 = nx.dot(A2.T, Q * (1 / g)[None, :]) + C1 = - 4 * nx.dot(A1, C1) + C2 = nx.dot(R.T, B1) + C2 = nx.dot(C2, B2.T) + diag_g = (1 / g)[None, :] + + # Compute C*R dot using the lr decomposition of C + CR = nx.dot(C2, R) + CR = nx.dot(C1, CR) + CR_g = CR * diag_g + + # Compute C.T * Q using the lr decomposition of C + CQ = nx.dot(C1.T, Q) + CQ = nx.dot(C2.T, CQ) + CQ_g = CQ * diag_g + + # Compute omega + omega = nx.diag(nx.dot(Q.T, CR)) + + # Rescale gamma at 
each iteration + if gamma_init == "rescale": + norm_1 = nx.max(nx.abs(CR_g + reg * nx.log(Q))) ** 2 + norm_2 = nx.max(nx.abs(CQ_g + reg * nx.log(R))) ** 2 + norm_3 = nx.max(nx.abs(-omega * (diag_g**2))) ** 2 + gamma = 10 / max(norm_1, norm_2, norm_3) + + K1 = nx.exp(-gamma * CR_g - ((gamma * reg) - 1) * nx.log(Q)) + K2 = nx.exp(-gamma * CQ_g - ((gamma * reg) - 1) * nx.log(R)) + K3 = nx.exp((gamma * omega / (g**2)) - (gamma * reg - 1) * nx.log(g)) + + # Update couplings with LR Dykstra algorithm + Q, R, g = _LR_Dysktra( + K1, K2, K3, a, b, alpha, stopThr_dykstra, numItermax_dykstra, warn_dykstra, nx + ) + + # Update error with kullback-divergence + err_1 = ((1 / gamma) ** 2) * (nx.kl_div(Q, Q_prev) + nx.kl_div(Q_prev, Q)) + err_2 = ((1 / gamma) ** 2) * (nx.kl_div(R, R_prev) + nx.kl_div(R_prev, R)) + err_3 = ((1 / gamma) ** 2) * (nx.kl_div(g, g_prev) + nx.kl_div(g_prev, g)) + err = err_1 + err_2 + err_3 + + # fix divide by zero + Q = Q + 1e-16 + R = R + 1e-16 + g = g + 1e-16 + + else: + break + + else: + if warn: + warnings.warn( + "Low Rank GW did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + ) + + # Update low rank costs + C1 = nx.dot(A2.T, Q * (1 / g)[None, :]) + C1 = - 4 * nx.dot(A1, C1) + C2 = nx.dot(R.T, B1) + C2 = nx.dot(C2, B2.T) + + # Compute lazy plan (using LazyTensor class) + lazy_plan = get_lowrank_lazytensor(Q, R, 1 / g) + + # Compute value_quad + A1_, A2_ = _flat_product_operator(A1, nx), _flat_product_operator(A2, nx) + B1_, B2_ = _flat_product_operator(B1, nx), _flat_product_operator(B2, nx) + + x_ = nx.dot(A1_, nx.dot(A2_.T, a)) + y_ = nx.dot(B1_, nx.dot(B2_.T, b)) + c1 = nx.dot(x_, a) + nx.dot(y_, b) + + G = nx.dot(C1, nx.dot(C2, R)) + G = nx.dot(Q.T, G * diag_g) + value_quad = c1 + nx.trace(G) / 2 + + if reg != 0: + reg_Q = nx.sum(Q * nx.log(Q + 1e-16)) # entropy for Q + reg_g = nx.sum(g * nx.log(g + 1e-16)) # entropy for g + reg_R = nx.sum(R * nx.log(R + 1e-16)) # entropy for R + value = value_quad + reg * (reg_Q + reg_g + reg_R) + else: + value = value_quad + + if log: + dict_log = dict() + dict_log["value"] = value + dict_log["value_quad"] = value_quad + dict_log["lazy_plan"] = lazy_plan + + return Q, R, g, dict_log + + return Q, R, g diff --git a/test/gromov/test_lowrank.py b/test/gromov/test_lowrank.py new file mode 100644 index 000000000..befc5c835 --- /dev/null +++ b/test/gromov/test_lowrank.py @@ -0,0 +1,125 @@ +""" Tests for gromov._lowrank.py """ + +# Author: Laurène DAVID +# +# License: MIT License + +import ot +import numpy as np +import pytest + + +def test__flat_product_operator(): + # test flat product operator + n, d = 100, 2 + X = np.reshape(1.0 * np.arange(2 * n), (n, d)) + A1, A2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X, X, rescale_cost=False) + + A1_ = ot.gromov._flat_product_operator(A1) + A2_ = ot.gromov._flat_product_operator(A2) + cost = ot.dist(X, X) + + # test value + np.testing.assert_allclose(cost**2, np.dot(A1_, A2_.T), atol=1e-05) + + +def test_lowrank_gromov_wasserstein_samples(): + # test low rank gromov wasserstein + 
n_samples = 20 # nb samples + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + X_s = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1) + X_t = X_s[::-1].copy() + + a = ot.unif(n_samples) + b = ot.unif(n_samples) + + Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, a, b, reg=0.1, log=True, rescale_cost=False) + P = log["lazy_plan"][:] + + # check constraints for P + np.testing.assert_allclose(a, P.sum(1), atol=1e-04) + np.testing.assert_allclose(b, P.sum(0), atol=1e-04) + + # check if lazy_plan is equal to the fully computed plan + P_true = np.dot(Q, np.dot(np.diag(1 / g), R.T)) + np.testing.assert_allclose(P, P_true, atol=1e-05) + + # check warn parameter when low rank GW algorithm doesn't converge + with pytest.warns(UserWarning): + ot.gromov.lowrank_gromov_wasserstein_samples( + X_s, X_t, a, b, reg=0.1, stopThr=0, numItermax=1, warn=True, warn_dykstra=False + ) + + # check warn parameter when Dykstra algorithm doesn't converge + with pytest.warns(UserWarning): + ot.gromov.lowrank_gromov_wasserstein_samples( + X_s, X_t, a, b, reg=0.1, stopThr_dykstra=0, numItermax_dykstra=1, warn=False, warn_dykstra=True + ) + + +@pytest.mark.parametrize(("alpha, rank"), ((0.8, 2), (0.5, 3), (0.2, 6), (0.1, -1))) +def test_lowrank_gromov_wasserstein_samples_alpha_error(alpha, rank): + # Test warning for value of alpha and rank + n_samples = 20 # nb samples + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + X_s = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1) + X_t = X_s[::-1].copy() + + a = ot.unif(n_samples) + b = ot.unif(n_samples) + + with pytest.raises(ValueError): + ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, a, b, reg=0.1, rank=rank, alpha=alpha, warn=False) + + +@pytest.mark.parametrize(("gamma_init"), ("rescale", "theory", "other")) +def test_lowrank_wasserstein_samples_gamma_init(gamma_init): + # Test lr sinkhorn with different init strategies + n_samples = 
20 # nb samples + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + X_s = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1) + X_t = X_s[::-1].copy() + + a = ot.unif(n_samples) + b = ot.unif(n_samples) + + if gamma_init not in ["rescale", "theory"]: + with pytest.raises(NotImplementedError): + ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, a, b, reg=0.1, gamma_init=gamma_init, log=True) + + else: + Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, a, b, reg=0.1, gamma_init=gamma_init, log=True) + P = log["lazy_plan"][:] + + # check constraints for P + np.testing.assert_allclose(a, P.sum(1), atol=1e-04) + np.testing.assert_allclose(b, P.sum(0), atol=1e-04) + + +@pytest.skip_backend('tf') +def test_lowrank_gromov_wasserstein_samples_backends(nx): + # Test low rank sinkhorn for different backends + n_samples = 20 # nb samples + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + X_s = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1) + X_t = X_s[::-1].copy() + + a = ot.unif(n_samples) + b = ot.unif(n_samples) + + ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) + + Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples(X_sb, X_tb, ab, bb, reg=0.1, log=True) + lazy_plan = log["lazy_plan"] + P = lazy_plan[:] + + np.testing.assert_allclose(ab, P.sum(1), atol=1e-04) + np.testing.assert_allclose(bb, P.sum(0), atol=1e-04) From cba9c7b7d27b59edf49979c746e480dbce787bc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Thu, 30 May 2024 01:09:22 +0200 Subject: [PATCH 15/30] [WIP] quantized gromov wasserstein solver (#603) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * first commit : quantized gromov wasserstein solver * start setting up tests * fix build OT for all backends - nb: concatenation procedure is less efficient for numpy and torch * dealing with edge cases * fix pep8 * updates + start 
setting exemple * updates + start setting exemple * updating code + exemple + test + docs * fix sklearn imports * fix * setting up new API for qGW * fix pep8 * tests * update qFGW plots * update qFGW plots * up tests * update example * merge master * complete tests --------- Co-authored-by: Rémi Flamary --- CONTRIBUTORS.md | 2 +- README.md | 5 +- RELEASES.md | 1 + .../plot_quantized_gromov_wasserstein.py | 515 ++++++++ ot/gromov/__init__.py | 16 +- ot/gromov/_quantized.py | 1147 +++++++++++++++++ test/gromov/test_quantized.py | 377 ++++++ 7 files changed, 2060 insertions(+), 3 deletions(-) create mode 100644 examples/gromov/plot_quantized_gromov_wasserstein.py create mode 100644 ot/gromov/_quantized.py create mode 100644 test/gromov/test_quantized.py diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index c185e18a7..e982cd5b6 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -41,7 +41,7 @@ The contributors to this library are: * [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein) * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) -* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW, semi-relaxed FGW) +* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW, semi-relaxed FGW, quantized FGW) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) diff --git a/README.md b/README.md index f1149a008..1cd9fb59b 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,7 @@ POT provides the following generic OT solvers (links to examples): * [Spherical Sliced 
Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] * [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. * [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]). +* [Quantized (Fused) Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) [68]. * [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. * Smooth Strongly Convex Nearest Brenier Potentials [58], with an extension to bounding potentials using [59]. @@ -358,4 +359,6 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil [66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. [Entropic estimation of optimal transport maps](https://arxiv.org/pdf/2109.12004.pdf). arXiv preprint arXiv:2109.12004 (2021). -[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). [Linear-Time GromovWasserstein Distances using Low Rank Couplings and Costs](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf). In International Conference on Machine Learning (ICML), 2022. \ No newline at end of file +[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). [Linear-Time Gromov-Wasserstein Distances using Low Rank Couplings and Costs](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf). In International Conference on Machine Learning (ICML), 2022. + +[68] Chowdhury, S., Miller, D., & Needham, T. (2021). 
[Quantized gromov-wasserstein](https://link.springer.com/chapter/10.1007/978-3-030-86523-8_49). ECML PKDD 2021. Springer International Publishing. diff --git a/RELEASES.md b/RELEASES.md index c31081451..51075c973 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -3,6 +3,7 @@ ## 0.9.4dev #### New features ++ New quantized FGW solvers `ot.gromov.quantized_fused_gromov_wasserstein`, `ot.gromov.quantized_fused_gromov_wasserstein_samples` and `ot.gromov.quantized_fused_gromov_wasserstein_partitioned` (PR #603) + `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specify if the matrices are symmetric in which case the computation can be done faster (PR #607). + Continuous entropic mapping (PR #613) + New general unbalanced solvers for `ot.solve` and BFGS solver and illustrative example (PR #620) diff --git a/examples/gromov/plot_quantized_gromov_wasserstein.py b/examples/gromov/plot_quantized_gromov_wasserstein.py new file mode 100644 index 000000000..02d777c71 --- /dev/null +++ b/examples/gromov/plot_quantized_gromov_wasserstein.py @@ -0,0 +1,515 @@ +# -*- coding: utf-8 -*- +""" +=============================================== +Quantized Fused Gromov-Wasserstein examples +=============================================== + +These examples show how to use the quantized (Fused) Gromov-Wasserstein +solvers (qFGW) [68]. POT provides a generic solver `quantized_fused_gromov_wasserstein_partitioned` +that takes as inputs partitioned graphs potentially endowed with node features, +which have to be built by the user. On top of that, POT provides two wrappers: + i) `quantized_fused_gromov_wasserstein` operating over generic graphs, whose + partitioning is performed via `get_graph_partition` using e.g the Louvain algorithm, + and representant for each partition can be selected via `get_graph_representants` + using e.g the PageRank algorithm. 
+ + ii) `quantized_fused_gromov_wasserstein_samples` operating over point clouds, + e.g. :math:`X_1 \in R^{n_1 * d_1}` and :math:`X_2 \in R^{n_2 * d_2}` + endowed with their respective euclidean geometry, whose partitioning and + representant selection is performed jointly using e.g. the K-means algorithm + via the function `get_partition_and_representants_samples`. + + + We illustrate next how to compute the qGW distance on both types of data by: + + i) Generating two graphs following Stochastic Block Models encoded as shortest + path matrices, since qGW solvers tend to require dense structures to achieve a good + approximation of the GW distance (as qGW is an upper bound of GW). Along the way, + we illustrate an optional feature of our solvers, namely the use of auxiliary + structures, e.g. adjacency matrices, to perform the graph partitioning. + + ii) Generating two point clouds representing curves in 2D and 3D respectively. + We augment these point clouds by considering additional features of the same + dimensionality :math:`F_1 \in R^{n_1 * d}` and :math:`F_2 \in R^{n_2 * d}`, + representing the color intensity associated with each sample of both distributions. + Then we compute the qFGW distance between these attributed point clouds. + + + [68] Chowdhury, S., Miller, D., & Needham, T. (2021). Quantized Gromov-Wasserstein. + ECML PKDD 2021. Springer International Publishing.
+""" + +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +import matplotlib.pyplot as plt +import networkx +from networkx.generators.community import stochastic_block_model as sbm +from scipy.sparse.csgraph import shortest_path + +from ot.gromov import ( + quantized_fused_gromov_wasserstein_partitioned, quantized_fused_gromov_wasserstein, + get_graph_partition, get_graph_representants, format_partitioned_graph, + quantized_fused_gromov_wasserstein_samples, + get_partition_and_representants_samples) + +############################################################################# +# +# Generate graphs +# -------------------------------------------------------------------------- +# +# Create two graphs following Stochastic Block models of 2 and 3 clusters. + +N1 = 30 # 2 communities +N2 = 45 # 3 communities +p1 = [[0.8, 0.1], + [0.1, 0.7]] +p2 = [[0.8, 0.1, 0.], + [0.1, 0.75, 0.1], + [0., 0.1, 0.7]] +G1 = sbm(seed=0, sizes=[N1 // 2, N1 // 2], p=p1) +G2 = sbm(seed=0, sizes=[N2 // 3, N2 // 3, N2 // 3], p=p2) + + +C1 = networkx.to_numpy_array(G1) +C2 = networkx.to_numpy_array(G2) + +spC1 = shortest_path(C1) +spC2 = shortest_path(C2) + +h1 = np.ones(C1.shape[0]) / C1.shape[0] +h2 = np.ones(C2.shape[0]) / C2.shape[0] + +# Add weights on the edges for visualization later on +weight_intra_G1 = 5 +weight_inter_G1 = 0.5 +weight_intra_G2 = 1. 
+weight_inter_G2 = 1.5 + +weightedG1 = networkx.Graph() +part_G1 = [G1.nodes[i]['block'] for i in range(N1)] + +for node in G1.nodes(): + weightedG1.add_node(node) +for i, j in G1.edges(): + if part_G1[i] == part_G1[j]: + weightedG1.add_edge(i, j, weight=weight_intra_G1) + else: + weightedG1.add_edge(i, j, weight=weight_inter_G1) + +weightedG2 = networkx.Graph() +part_G2 = [G2.nodes[i]['block'] for i in range(N2)] + +for node in G2.nodes(): + weightedG2.add_node(node) +for i, j in G2.edges(): + if part_G2[i] == part_G2[j]: + weightedG2.add_edge(i, j, weight=weight_intra_G2) + else: + weightedG2.add_edge(i, j, weight=weight_inter_G2) + + +# setup for graph visualization + +def node_coloring(part, starting_color=0): + + # get graphs partition and their coloring + unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part)] + nodes_color_part = [] + for cluster in part: + nodes_color_part.append(unique_colors[cluster]) + + return nodes_color_part + + +def draw_graph(G, C, nodes_color_part, rep_indices, node_alphas=None, pos=None, + edge_color='black', alpha_edge=0.7, node_size=None, + shiftx=0, seed=0, highlight_rep=False): + + if (pos is None): + pos = networkx.spring_layout(G, scale=1., seed=seed) + + if shiftx != 0: + for k, v in pos.items(): + v[0] = v[0] + shiftx + + width_edge = 1.5 + + if not highlight_rep: + networkx.draw_networkx_edges( + G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color) + else: + for edge in G.edges: + if (edge[0] in rep_indices) and (edge[1] in rep_indices): + networkx.draw_networkx_edges( + G, pos, edgelist=[edge], width=width_edge, alpha=alpha_edge, + edge_color=edge_color) + else: + networkx.draw_networkx_edges( + G, pos, edgelist=[edge], width=width_edge, alpha=0.2, + edge_color=edge_color) + + for node, node_color in enumerate(nodes_color_part): + local_node_shape, local_node_size = 'o', node_size + + if highlight_rep: + if node in rep_indices: + local_node_shape, local_node_size = '*', 6 * node_size + + if 
node_alphas is None: + alpha = 0.9 + if highlight_rep: + alpha = 0.9 if node in rep_indices else 0.1 + + else: + alpha = node_alphas[node] + + networkx.draw_networkx_nodes(G, pos, nodelist=[node], alpha=alpha, + node_shape=local_node_shape, + node_size=local_node_size, + node_color=node_color) + + return pos + + +############################################################################# +# +# Compute their quantized Gromov-Wasserstein distance without using the wrapper +# --------------------------------------------------------- +# +# We detail next the steps implemented within the wrapper that preprocess graphs +# to form partitioned graphs, which are then passed as input to the generic qFGW solver. + +# 1-a) Partition C1 and C2 in 2 and 3 clusters respectively using Louvain +# algorithm from Networkx. Then encode these partitions via vectors of assignments. + +part_method = 'louvain' +rep_method = 'pagerank' + +npart_1 = 2 # 2 clusters used to describe C1 +npart_2 = 3 # 3 clusters used to describe C2 + +part1 = get_graph_partition( + C1, npart=npart_1, part_method=part_method, F=None, alpha=1.) +part2 = get_graph_partition( + C2, npart=npart_2, part_method=part_method, F=None, alpha=1.) + +# 1-b) Select representant in each partition using the Pagerank algorithm +# implementation from networkx. + +rep_indices1 = get_graph_representants(C1, part1, rep_method=rep_method) +rep_indices2 = get_graph_representants(C2, part2, rep_method=rep_method) + +# 1-c) Formate partitions such that: +# CR contains relations between representants in each space. +# list_R contains relations between samples and representants within each partition. +# list_h contains samples relative importance within each partition. + +CR1, list_R1, list_h1 = format_partitioned_graph( + spC1, h1, part1, rep_indices1, F=None, M=None, alpha=1.) + +CR2, list_R2, list_h2 = format_partitioned_graph( + spC2, h2, part2, rep_indices2, F=None, M=None, alpha=1.) 
+ +# 1-d) call to partitioned quantized gromov-wasserstein solver + +OT_global_, OTs_local_, OT_, log_ = quantized_fused_gromov_wasserstein_partitioned( + CR1, CR2, list_R1, list_R2, list_h1, list_h2, MR=None, + alpha=1., build_OT=True, log=True) + + +# Visualization of the graph pre-processing + +node_size = 40 +fontsize = 10 +seed_G1 = 0 +seed_G2 = 3 + +part1_ = part1.astype(np.int32) +part2_ = part2.astype(np.int32) + + +nodes_color_part1 = node_coloring(part1_, starting_color=0) +nodes_color_part2 = node_coloring(part2_, starting_color=np.unique(nodes_color_part1).shape[0]) + + +pl.figure(1, figsize=(6, 5)) +pl.clf() +pl.axis('off') +pl.subplot(2, 3, 1) +pl.title(r'Input graph: $\mathbf{spC_1}$', fontsize=fontsize) + +pos1 = draw_graph( + G1, C1, ['C0' for _ in part1_], rep_indices1, node_size=node_size, seed=seed_G1) + +pl.subplot(2, 3, 2) +pl.title('Partitioning', fontsize=fontsize) + +_ = draw_graph( + G1, C1, nodes_color_part1, rep_indices1, pos=pos1, node_size=node_size, seed=seed_G1) + +pl.subplot(2, 3, 3) +pl.title('Representant selection', fontsize=fontsize) + +_ = draw_graph( + G1, C1, nodes_color_part1, rep_indices1, pos=pos1, node_size=node_size, + seed=seed_G1, highlight_rep=True) + +pl.subplot(2, 3, 4) +pl.title(r'Input graph: $\mathbf{spC_2}$', fontsize=fontsize) + +pos2 = draw_graph( + G2, C2, ['C0' for _ in part2_], rep_indices2, node_size=node_size, seed=seed_G2) + +pl.subplot(2, 3, 5) +pl.title(r'Partitioning', fontsize=fontsize) + +_ = draw_graph( + G2, C2, nodes_color_part2, rep_indices2, pos=pos2, node_size=node_size, seed=seed_G2) + +pl.subplot(2, 3, 6) +pl.title(r'Representant selection', fontsize=fontsize) + +_ = draw_graph( + G2, C2, nodes_color_part2, rep_indices2, pos=pos2, node_size=node_size, + seed=seed_G2, highlight_rep=True) +pl.tight_layout() + +############################################################################# +# +# Compute the quantized Gromov-Wasserstein distance using the wrapper +# 
--------------------------------------------------------- +# +# Compute qGW(spC1, h1, spC2, h2). We also illustrate the use of auxiliary matrices +# such that the adjacency matrices `C1_aux=C1` and `C2_aux=C2` to partition the graph using +# Louvain algorithm, and the Pagerank algorithm for selecting representant within +# each partition. Notice that `C1_aux` and `C2_aux` are optional, if they are not +# specified these pre-processing algorithms will be applied to spC2 and spC3. + + +# no node features are considered on this synthetic dataset. Hence we simply +# let F1, F2 = None and set alpha = 1. +OT_global, OTs_local, OT, log = quantized_fused_gromov_wasserstein( + spC1, spC2, npart_1, npart_2, h1, h2, C1_aux=C1, C2_aux=C2, F1=None, F2=None, + alpha=1., part_method=part_method, rep_method=rep_method, log=True) + +qGW_dist = log['qFGW_dist'] + + +############################################################################# +# +# Visualization of the quantized Gromov-Wasserstein matching +# -------------------------------------------------------------- +# +# We color nodes of the graph based on the respective partition of each graph. +# On the first plot we illustrate the qGW matching between both shortest path matrices. +# While the GW matching across representants of each space is illustrated on the right. 
+ + +def draw_transp_colored_qGW( + G1, C1, G2, C2, part1, part2, rep_indices1, rep_indices2, T, + pos1=None, pos2=None, shiftx=4, switchx=False, node_size=70, + seed_G1=0, seed_G2=0, highlight_rep=False): + starting_color = 0 + # get graphs partition and their coloring + unique_colors1 = ['C%s' % (starting_color + i) for i in np.unique(part1)] + nodes_color_part1 = [] + for cluster in part1: + nodes_color_part1.append(unique_colors1[cluster]) + + starting_color = len(unique_colors1) + 1 + unique_colors2 = ['C%s' % (starting_color + i) for i in np.unique(part2)] + nodes_color_part2 = [] + for cluster in part2: + nodes_color_part2.append(unique_colors2[cluster]) + + pos1 = draw_graph( + G1, C1, nodes_color_part1, rep_indices1, pos=pos1, node_size=node_size, + shiftx=0, seed=seed_G1, highlight_rep=highlight_rep) + pos2 = draw_graph( + G2, C2, nodes_color_part2, rep_indices2, pos=pos2, node_size=node_size, + shiftx=shiftx, seed=seed_G1, highlight_rep=highlight_rep) + + if not highlight_rep: + for k1, v1 in pos1.items(): + max_Tk1 = np.max(T[k1, :]) + for k2, v2 in pos2.items(): + if (T[k1, k2] > 0): + pl.plot([pos1[k1][0], pos2[k2][0]], + [pos1[k1][1], pos2[k2][1]], + '-', lw=0.7, alpha=T[k1, k2] / max_Tk1, + color=nodes_color_part1[k1]) + + else: # OT is only between representants + for id1, node_id1 in enumerate(rep_indices1): + max_Tk1 = np.max(T[id1, :]) + for id2, node_id2 in enumerate(rep_indices2): + if (T[id1, id2] > 0): + pl.plot([pos1[node_id1][0], pos2[node_id2][0]], + [pos1[node_id1][1], pos2[node_id2][1]], + '-', lw=0.8, alpha=T[id1, id2] / max_Tk1, + color=nodes_color_part1[node_id1]) + return pos1, pos2 + + +pl.figure(2, figsize=(5, 2.5)) +pl.clf() +pl.axis('off') +pl.subplot(1, 2, 1) +pl.title(r'qGW$(\mathbf{spC_1}, \mathbf{spC_1}) =%s$' % (np.round(qGW_dist, 3)), fontsize=fontsize) + +pos1, pos2 = draw_transp_colored_qGW( + weightedG1, C1, weightedG2, C2, part1_, part2_, rep_indices1, rep_indices2, + T=OT_, shiftx=1.5, node_size=node_size, 
seed_G1=seed_G1, seed_G2=seed_G2) + +pl.tight_layout() + +pl.subplot(1, 2, 2) +pl.title(r' GW$(\mathbf{CR_1}, \mathbf{CR_2}) =%s$' % (np.round(log_['global dist'], 3)), fontsize=fontsize) + +pos1, pos2 = draw_transp_colored_qGW( + weightedG1, C1, weightedG2, C2, part1_, part2_, rep_indices1, rep_indices2, + T=OT_global, shiftx=1.5, node_size=node_size, seed_G1=seed_G1, seed_G2=seed_G2, + highlight_rep=True) + +pl.tight_layout() +pl.show() + +############################################################################# +# +# Generate attributed point clouds +# -------------------------------------------------------------------------- +# +# Create two attributed point clouds representing curves in 2D and 3D respectively, +# whose samples are further associated to various color intensities. + +n_samples = 100 + +# Generate 2D and 3D curves +theta = np.linspace(-4 * np.pi, 4 * np.pi, n_samples) +z = np.linspace(1, 2, n_samples) +r = z**2 + 1 +x = r * np.sin(theta) +y = r * np.cos(theta) + +# Source and target distribution across spaces encoded respectively via their +# squared euclidean distance matrices. + +X = np.concatenate([x.reshape(-1, 1), z.reshape(-1, 1)], axis=1) +Y = np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], axis=1) + +# Further associated to color intensity features derived from z + +FX = z - z.min() / (z.max() - z.min()) +FX = np.clip(0.8 * FX + 0.2, a_min=0.2, a_max=1.) # for numerical issues +FY = FX + + +############################################################################# +# +# Visualize partitioned attributed point clouds +# -------------------------------------------------------------------------- +# +# Compute the partitioning and representant selection further used within +# qFGW wrapper, both provided by a K-means algorithm. Then visualize partitioned spaces. 
+ +part1, rep_indices1 = get_partition_and_representants_samples( + X, 4, 'kmeans', 0) +part2, rep_indices2 = get_partition_and_representants_samples( + Y, 4, 'kmeans', 0) + +upart1 = np.unique(part1) +upart2 = np.unique(part2) + +# Plot the source and target samples as distributions +s = 20 +fig = plt.figure(3, figsize=(6, 3)) + +ax1 = fig.add_subplot(1, 3, 1) +ax1.set_title("2D curve") +ax1.scatter(X[:, 0], X[:, 1], color="C0", alpha=FX, s=s) +plt.axis('off') + + +ax2 = fig.add_subplot(1, 3, 2) +ax2.set_title("Partitioning") +for i, elem in enumerate(upart1): + idx = np.argwhere(part1 == elem)[:, 0] + ax2.scatter(X[idx, 0], X[idx, 1], color="C%s" % i, alpha=FX[idx], s=s) +plt.axis('off') + +ax3 = fig.add_subplot(1, 3, 3) +ax3.set_title("Representant selection") +for i, elem in enumerate(upart1): + idx = np.argwhere(part1 == elem)[:, 0] + ax3.scatter(X[idx, 0], X[idx, 1], color="C%s" % i, alpha=FX[idx], s=10) + rep_idx = rep_indices1[i] + ax3.scatter([X[rep_idx, 0]], [X[rep_idx, 1]], color="C%s" % i, alpha=1, s=6 * s, marker='*') +plt.axis('off') +plt.tight_layout() +plt.show() + +start_color = upart1.shape[0] + 1 + +fig = plt.figure(4, figsize=(6, 5)) + +ax4 = fig.add_subplot(1, 3, 1, projection="3d") +ax4.set_title("3D curve") +ax4.scatter(Y[:, 0], Y[:, 1], Y[:, 2], c='C0', alpha=FY, s=s) +plt.axis('off') + +ax5 = fig.add_subplot(1, 3, 2, projection="3d") +ax5.set_title("Partitioning") +for i, elem in enumerate(upart2): + idx = np.argwhere(part2 == elem)[:, 0] + color = 'C%s' % (start_color + i) + ax5.scatter(Y[idx, 0], Y[idx, 1], Y[idx, 2], c=color, alpha=FY[idx], s=s) +plt.axis('off') + +ax6 = fig.add_subplot(1, 3, 3, projection="3d") +ax6.set_title("Representant selection") +for i, elem in enumerate(upart2): + idx = np.argwhere(part2 == elem)[:, 0] + color = 'C%s' % (start_color + i) + rep_idx = rep_indices2[i] + ax6.scatter(Y[idx, 0], Y[idx, 1], Y[idx, 2], c=color, alpha=FY[idx], s=s) + ax6.scatter([Y[rep_idx, 0]], [Y[rep_idx, 1]], [Y[rep_idx, 2]], c=color, 
alpha=1, s=6 * s, marker='*') +plt.axis('off') +plt.tight_layout() +plt.show() + +############################################################################# +# +# Compute the quantized Fused Gromov-Wasserstein distance between samples using the wrapper +# --------------------------------------------------------- +# +# Compute qFGW(X, FX, hX, Y, FY, HY), setting the trade-off parameter between +# structures and features `alpha=0.5`. This solver considers a squared euclidean structure +# for each distribution X and Y, and partition each of them into 4 clusters using +# the K-means algorithm before computing qFGW. + +T_global, Ts_local, T, log = quantized_fused_gromov_wasserstein_samples( + X, Y, 4, 4, p=None, q=None, F1=FX[:, None], F2=FY[:, None], alpha=0.5, + method='kmeans', log=True) + +# Plot low rank GW with different ranks +pl.figure(5, figsize=(6, 3)) +pl.subplot(1, 2, 1) +pl.title('OT between distributions') +pl.imshow(T, interpolation="nearest", aspect="auto") +pl.colorbar() +pl.axis('off') + +pl.subplot(1, 2, 2) +pl.title('OT between representants') +pl.imshow(T_global, interpolation="nearest", aspect="auto") +pl.axis('off') +pl.colorbar() + +pl.tight_layout() +pl.show() diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py index b33dafd32..03663dab4 100644 --- a/ot/gromov/__init__.py +++ b/ot/gromov/__init__.py @@ -50,6 +50,16 @@ from ._lowrank import (_flat_product_operator, lowrank_gromov_wasserstein_samples) +from ._quantized import (quantized_fused_gromov_wasserstein_partitioned, + get_graph_partition, + get_graph_representants, + format_partitioned_graph, + quantized_fused_gromov_wasserstein, + get_partition_and_representants_samples, + format_partitioned_samples, + quantized_fused_gromov_wasserstein_samples + ) + __all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'update_square_loss', 'update_kl_loss', 'update_feature_matrix', 'init_matrix_semirelaxed', 'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein', @@ 
-66,4 +76,8 @@ 'entropic_semirelaxed_gromov_wasserstein2', 'entropic_semirelaxed_fused_gromov_wasserstein', 'entropic_semirelaxed_fused_gromov_wasserstein2', 'gromov_wasserstein_dictionary_learning', 'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning', - 'fused_gromov_wasserstein_linear_unmixing', 'lowrank_gromov_wasserstein_samples'] + 'fused_gromov_wasserstein_linear_unmixing', 'lowrank_gromov_wasserstein_samples', + 'quantized_fused_gromov_wasserstein_partitioned', 'get_graph_partition', + 'get_graph_representants', 'format_partitioned_graph', + 'quantized_fused_gromov_wasserstein', 'get_partition_and_representants_samples', + 'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples'] diff --git a/ot/gromov/_quantized.py b/ot/gromov/_quantized.py new file mode 100644 index 000000000..147f4b221 --- /dev/null +++ b/ot/gromov/_quantized.py @@ -0,0 +1,1147 @@ +""" +Quantized (Fused) Gromov-Wasserstein solvers. +""" + +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np +import warnings + +try: + from networkx.algorithms.community import asyn_fluidc, louvain_communities + from networkx import from_numpy_array, pagerank + networkx_import = True +except ImportError: + networkx_import = False + +try: + from sklearn.cluster import SpectralClustering, KMeans + sklearn_import = True +except ImportError: + sklearn_import = False + +import random + +from ..utils import list_to_array, unif, dist +from ..backend import get_backend +from ..lp import emd_1d +from ._gw import gromov_wasserstein, fused_gromov_wasserstein +from ._utils import init_matrix, gwloss + + +def quantized_fused_gromov_wasserstein_partitioned( + CR1, CR2, list_R1, list_R2, list_p1, list_p2, MR=None, + alpha=1., build_OT=False, log=False, armijo=False, max_iter=1e4, + tol_rel=1e-9, tol_abs=1e-9, nx=None, **kwargs): + r""" + Returns the quantized Fused Gromov-Wasserstein transport between + :math:`(\mathbf{C_1}, \mathbf{F_1}, 
\mathbf{p})` and :math:`(\mathbf{C_2}, + \mathbf{F_2}, \mathbf{q})`, whose samples are assigned to partitions and representants + :math:`\mathcal{P_1} = \{(\mathbf{P_{1, i}}, \mathbf{r_{1, i}})\}_{i \leq npart1}` + and :math:`\mathcal{P_2} = \{(\mathbf{P_{2, j}}, \mathbf{r_{2, j}})\}_{j \leq npart2}`. + The latter must be precomputed and encoded e.g for the source as: :math:`\mathbf{CR_1}` + structure matrix between representants; `list_R1` a list of relations between + representants and their associated samples; `list_p1` a list of nodes + distribution within each partition; :math:`\mathbf{FR_1}` feature matrix + of representants. + + The function estimates the following optimization problem: + + .. math:: + \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad \alpha \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + (1-\alpha) \langle \mathbf{T}, M\rangle_F + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + \mathbf{T}_{|\mathbf{P_{1, i}}, \mathbf{P_{2, j}}} &= T^{g}_{ij} \mathbf{T}^{(i,j)} + + using a two-step strategy computing: i) a global alignment :math:`\mathbf{T}^{g}` + between representants joint structure and feature spaces; ii) local alignments + :math:`\mathbf{T}^{(i, j)}` between partitions :math:`\mathbf{P_{1, i}}` + and :math:`\mathbf{P_{2, j}}` seen as 1D measures. + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{F_1}`: Feature matrix in the source space + - :math:`\mathbf{F_2}`: Feature matrix in the target space + - :math:`\mathbf{M}`: Pairwise similarity matrix between features + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - :math:`L`: quadratic loss function to account for the misfit between the similarity matrices + + .. 
note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + .. note:: All computations in the Gromov-Wasserstein conjugate gradient solver + are done with numpy to limit memory overhead. + + Parameters + ---------- + CR1 : array-like, shape (npart1, npart1) + Structure matrix between partition representants in the source space. + CR2 : array-like, shape (npart2, npart2) + Structure matrix between partition representants in the target space. + list_R1 : list of npart1 arrays, + List of relations between representants and their associated samples in the source space. + list_R2 : list of npart2 arrays, + List of relations between representants and their associated samples in the target space. + list_p1 : list of npart1 arrays, + List of node distributions within each partition of the source space. + list_p : list of npart2 arrays, + List of node distributions within each partition of the target space. + MR : array-like, shape (npart1, npart2), optional. (Default is None) + Metric cost matrix between features of representants across spaces. + alpha: float, optional. Default is None. + FGW trade-off parameter in :math:`]0, 1]` between structure and features. + If `alpha = 1` features are ignored hence computing qGW. + build_OT: bool, optional. Default is False + Either to build or not the OT between non-partitioned structures. + log : bool, optional. Default is False + record log if True + armijo : bool, optional + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. 
+ max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + nx : backend, optional + POT backend + + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + T_global: array-like, shape (`npart1`, `npart2`) + Gromov-Wasserstein alignment :math:`\mathbf{T}^{g}` between representants. + Ts_local: dict of local OT matrices. + Dictionary with keys :math:`(i, j)` corresponding to 1D OT between + :math:`\mathbf{P_{1, i}}` and :math:`\mathbf{P_{2, j}}` if :math:`T^{g}_{ij} \neq 0`. + T: array-like, shape `(ns, nt)` + Coupling between the two spaces if `build_OT=True` else None. + log : dict, if `log=True`. + Convergence information and losses of inner OT problems. + + References + ---------- + .. [68] Chowdhury, S., Miller, D., & Needham, T. (2021). + Quantized gromov-wasserstein. ECML PKDD 2021. Springer International Publishing. + + """ + if nx is None: + arr = [CR1, CR2, *list_R1, *list_R2, *list_p1, *list_p2] + + if MR is not None: + arr.append(MR) + + nx = get_backend(*arr) + + npart1 = len(list_R1) + npart2 = len(list_R2) + + # compute marginals for global alignment + pR1 = nx.from_numpy(list_to_array([nx.sum(p) for p in list_p1])) + pR2 = nx.from_numpy(list_to_array([nx.sum(q) for q in list_p2])) + + # compute global alignment + if alpha == 1.: + res_global = gromov_wasserstein( + CR1, CR2, pR1, pR2, loss_fun='square_loss', log=log, + armijo=armijo, max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs) + + if log: + T_global, dist_global = res_global[0], res_global[1]['gw_dist'] + else: + T_global = res_global + + elif (alpha < 1.) 
and (alpha > 0.): + + res_global = fused_gromov_wasserstein( + MR, CR1, CR2, pR1, pR2, 'square_loss', alpha=alpha, log=log, + armijo=armijo, max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs) + + if log: + T_global, dist_global = res_global[0], res_global[1]['fgw_dist'] + else: + T_global = res_global + + else: + raise ValueError( + f""" + `alpha='{alpha}'` should be in ]0, 1]. + """) + + if log: + log_ = {} + log_['global dist'] = dist_global + + # compute local alignments + Ts_local = {} + list_p1_norm = [p / nx.sum(p) for p in list_p1] + list_p2_norm = [q / nx.sum(q) for q in list_p2] + + for i in range(npart1): + for j in range(npart2): + if T_global[i, j] != 0.: + res_1d = emd_1d(list_R1[i], list_R2[j], list_p1_norm[i], list_p2_norm[j], + metric='sqeuclidean', p=1., log=log) + if log: + T_local, log_local = res_1d + Ts_local[(i, j)] = T_local + log_[f'local dist ({i},{j})'] = log_local['cost'] + else: + Ts_local[(i, j)] = res_1d + + if build_OT: + T_rows = [] + for i in range(npart1): + list_Ti = [] + for j in range(npart2): + if T_global[i, j] == 0.: + T_local = nx.zeros((list_R1[i].shape[0], list_R2[j].shape[0]), type_as=T_global) + else: + T_local = T_global[i, j] * Ts_local[(i, j)] + list_Ti.append(T_local) + + Ti = nx.concatenate(list_Ti, axis=1) + T_rows.append(Ti) + T = nx.concatenate(T_rows, axis=0) + + else: + T = None + + if log: + return T_global, Ts_local, T, log_ + + else: + return T_global, Ts_local, T + + +def get_graph_partition(C, npart, part_method='random', F=None, alpha=1., + random_state=0, nx=None): + """ + Partitioning a given graph with structure matrix :math:`\mathbf{C} \in R^{n \times n}` + into `npart` partitions either 'random', or using one of {'louvain', 'fluid'} + algorithms from networkx, or 'spectral' clustering from scikit-learn, + or (Fused) Gromov-Wasserstein projections from POT. + + Parameters + ---------- + C : array-like, shape (n, n) + Structure matrix. 
+ npart : int, + number of partitions/clusters smaller than the number of nodes in + :math:`\mathbf{C}`. + part_method : str, optional. Default is 'random'. + Partitioning algorithm to use among {'random', 'louvain', 'fluid', 'spectral', 'GW', 'FGW'}. + 'random' for random sampling of points; 'louvain' and 'fluid' for graph + partitioning algorithm that works well on adjacency matrix, If the + louvain algorithm is used, `npart` is ignored; 'spectral' for spectral + clustering; '(F)GW' for (F)GW projection using sr(F)GW solvers. + F : array-like, shape (n, d), optional. (Default is None) + Optional feature matrix aligned with the graph structure. Only used if + `part_method="FGW"`. + alpha : float, optional. (Default is 1.) + Trade-off parameter between feature and structure matrices, taking + values in [0, 1] and only used if `F != None` and `part_method="FGW"`. + random_state: int, optional + Random seed for the partitioning algorithm. + nx : backend, optional + POT backend. + + Returns + ------- + part : array-like, shape (npart,) + Array of partition assignment for each node. + + References + ---------- + .. [68] Chowdhury, S., Miller, D., & Needham, T. (2021). + Quantized gromov-wasserstein. ECML PKDD 2021. Springer International Publishing. + + """ + if nx is None: + nx = get_backend(C) + + n = C.shape[0] + C0 = C + + if (alpha != 1.) 
and (F is None): + raise ValueError("`alpha != 1` but node features are not provided.") + + if npart >= n: + warnings.warn( + "Requested number of partitions higher than the number of nodes" + "hence we enforce each node to be a partition.", + stacklevel=2 + ) + + part = np.arange(n) + + elif npart == 1: + part = np.zeros(n) + + elif part_method == 'random': + # randomly partition the space + random.seed(random_state) + part = list_to_array(random.choices(np.arange(npart), k=C.shape[0])) + + elif part_method == 'louvain': + C = nx.to_numpy(C0) + graph = from_numpy_array(C) + part_sets = louvain_communities(graph, seed=random_state) + part = np.zeros(n) + for iset_, set_ in enumerate(part_sets): + set_ = list(set_) + part[set_] = iset_ + + elif part_method == 'fluid': + C = nx.to_numpy(C0) + graph = from_numpy_array(C) + part_sets = asyn_fluidc(graph, npart, seed=random_state) + part = np.zeros(n) + for iset_, set_ in enumerate(part_sets): + set_ = list(set_) + part[set_] = iset_ + + elif part_method == 'spectral': + C = nx.to_numpy(C0) + sc = SpectralClustering(n_clusters=npart, + random_state=random_state, + affinity='precomputed').fit(C) + part = sc.labels_ + + elif part_method in ['GW', 'FGW']: + raise ValueError(f"`part_method == {part_method}` not implemented yet.") + + else: + raise ValueError( + f""" + Unknown `part_method='{part_method}'`. Use one of: + {'random', 'louvain', 'fluid', 'spectral', 'GW', 'FGW'}. + """) + return nx.from_numpy(part, type_as=C0) + + +def get_graph_representants(C, part, rep_method='pagerank', random_state=0, nx=None): + """ + Get representative node for each partition given by :math:`\mathbf{part} \in R^{n}` + of a graph with structure matrix :math:`\mathbf{C} \in R^{n \times n}`. + Selection is either done randomly or using 'pagerank' algorithm from networkx. + + Parameters + ---------- + C : array-like, shape (n, n) + structure matrix. + part : array-like, shape (n,) + Array of partition assignment for each node. 
+ rep_method : str, optional. Default is 'pagerank'. + Selection method for representant in each partition. Can be either 'random' + i.e random sampling within each partition, or 'pagerank' to select a + node with maximal pagerank. + random_state: int, optional + Random seed for the partitioning algorithm + nx : backend, optional + POT backend + + Returns + ------- + rep_indices : list, shape (npart,) + indices for representative node of each partition sorted + according to partition identifiers. + + References + ---------- + .. [68] Chowdhury, S., Miller, D., & Needham, T. (2021). + Quantized gromov-wasserstein. ECML PKDD 2021. Springer International Publishing. + + """ + if nx is None: + nx = get_backend(C, part) + + rep_indices = [] + part_ids = nx.unique(part) + n_part_ids = part_ids.shape[0] + if n_part_ids == C.shape[0]: + rep_indices = nx.arange(n_part_ids) + + elif rep_method == 'random': + random.seed(random_state) + for id_, part_id in enumerate(part_ids): + indices = nx.where(part == part_id)[0] + rep_indices.append(random.choice(indices)) + + elif rep_method == 'pagerank': + C0, part0 = C, part + C = nx.to_numpy(C0) + part = nx.to_numpy(part0) + part_ids = np.unique(part) + + for id_ in part_ids: + indices = np.where(part == id_)[0] + C_id = C[indices, :][:, indices] + graph = from_numpy_array(C_id) + pagerank_values = list(pagerank(graph).values()) + rep_idx = np.argmax(pagerank_values) + rep_indices.append(indices[rep_idx]) + + else: + raise ValueError( + f""" + Unknown `rep_method='{rep_method}'`. Use one of: + {'random', 'pagerank'}. 
+ """) + + return rep_indices + + +def format_partitioned_graph(C, p, part, rep_indices, F=None, M=None, + alpha=1., nx=None): + """ + Format an attributed graph :math:`(\mathbf{C}, \mathbf{F}, \mathbf{p})` + with structure matrix :math:`(\mathbf{C} \in R^{n \times n}`, feature matrix + :math:`(\mathbf{F} \in R^{n \times d}` and node relative importance + :math:`(\mathbf{p} \in \Sigma_n`, into a partitioned attributed graph + taking into account partitions and representants :math:`\mathcal{P} = \left{(\mathbf{P_{i}}, \mathbf{r_{i}})\right}_i`. + + Parameters + ---------- + C : array-like, shape (n, n) + Structure matrix. + p : array-like, shape (n,), + Node distribution. + part : array-like, shape (n,) + Array of partition assignment for each node. + rep_indices : list of array-like of ints, shape (npart,) + indices for representative node of each partition sorted according to + partition identifiers. + F : array-like, shape (n, d), optional. (Default is None) + Optional feature matrix aligned with the graph structure. + M : array-like, shape (n, n), optional. (Default is None) + Optional pairwise similarity matrix between features. + alpha: float, optional. Default is 1. + Trade-off parameter in :math:`]0, 1]` between structure and features. + If `alpha = 1` features are ignored. This trade-off is taken into account + into the outputted relations between nodes and representants. + nx : backend, optional + POT backend + + Returns + ------- + CR : array-like, shape (npart, npart) + Structure matrix between partition representants. + list_R : list of npart arrays, + List of relations between a representant and nodes in its partition, + for each partition. + list_p : list of npart arrays, + List of node distributions within each partition. + FR : array-like, shape (npart, d), if `F != None`. + Feature matrix of representants. + + References + ---------- + .. [68] Chowdhury, S., Miller, D., & Needham, T. (2021). + Quantized gromov-wasserstein. ECML PKDD 2021. 
Springer International Publishing. + + """ + if nx is None: + arr = [C, p, part] + if F is not None: + arr.append(F) + if M is not None: + arr.append(M) + + nx = get_backend(*arr) + + if alpha != 1.: + if (M is None) or (F is None): + raise ValueError( + f""" + `alpha == {alpha} != 1` but features information is not properly provided. + """) + + CR = C[rep_indices, :][:, rep_indices] + + if alpha != 1.: + C_new = alpha * C + (1 - alpha) * M + else: + C_new = C + + list_R, list_p = [], [] + + part_ids = nx.unique(part) + + for id_, part_id in enumerate(part_ids): + indices = nx.where(part == part_id)[0] + list_R.append(C_new[rep_indices[id_], indices]) + list_p.append(p[indices]) + + if F is None: + + return CR, list_R, list_p + else: + FR = F[rep_indices, :] + + return CR, list_R, list_p, FR + + +def quantized_fused_gromov_wasserstein( + C1, C2, npart1, npart2, p=None, q=None, C1_aux=None, C2_aux=None, + F1=None, F2=None, alpha=1., part_method='fluid', + rep_method='random', log=False, armijo=False, max_iter=1e4, + tol_rel=1e-9, tol_abs=1e-9, random_state=0, **kwargs): + r""" + Returns the quantized Fused Gromov-Wasserstein transport between + :math:`(\mathbf{C_1}, \mathbf{F_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, + \mathbf{F_2}, \mathbf{q})`, whose samples are assigned to partitions and + representants :math:`\mathcal{P_1} = \{(\mathbf{P_{1, i}}, \mathbf{r_{1, i}})\}_{i \leq npart1}` + and :math:`\mathcal{P_2} = \{(\mathbf{P_{2, j}}, \mathbf{r_{2, j}})\}_{j \leq npart2}`. + + The function estimates the following optimization problem: + + .. math:: + \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad \alpha \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + (1-\alpha) \langle \mathbf{T}, \mathbf{D}(\mathbf{F_1}, \mathbf{F}_2) \rangle_F + s.t. 
\ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + \mathbf{T}_{|\mathbf{P_{1, i}}, \mathbf{P_{2, j}}} &= T^{g}_{ij} \mathbf{T}^{(i,j)} + + using a two-step strategy computing: i) a global alignment :math:`\mathbf{T}^{g}` + between representants across joint structure and feature spaces; + ii) local alignments :math:`\mathbf{T}^{(i, j)}` between partitions + :math:`\mathbf{P_{1, i}}` and :math:`\mathbf{P_{2, j}}` seen as 1D measures. + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{F_1}`: Feature matrix in the source space + - :math:`\mathbf{F_2}`: Feature matrix in the target space + - :math:`\mathbf{D}(\mathbf{F_1}, \mathbf{F_2})`: Pairwise euclidean distance matrix between features + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - :math:`L`: quadratic loss function to account for the misfit between the similarity matrices + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + .. note:: All computations in the conjugate gradient solver are done with + numpy to limit memory overhead. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Structure matrix in the source space. + C2 : array-like, shape (nt, nt) + Structure matrix in the target space. + npart1 : int, + number of partition in the source space. + npart2 : int, + number of partition in the target space. + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If let to its default value None, uniform distribution is taken. 
+ C1_aux : array-like, shape (ns, ns), optional. Default is None. + Auxiliary structure matrix in the source space to perform the partitioning + and representant selection. + C2_aux : array-like, shape (nt, nt), optional. Default is None. + Auxiliary structure matrix in the target space to perform the partitioning + and representant selection. + F1 : array-like, shape (ns, d), optional. Default is None. + Feature matrix in the source space. + F2 : array-like, shape (nt, d), optional. Default is None. + Feature matrix in the target space + alpha: float, optional. Default is 1. + FGW trade-off parameter in :math:`]0, 1]` between structure and features. + If `alpha = 1` features are ignored hence computing qGW, if `alpha=0` + structures are ignored and we compute the quantized Wasserstein transport. + part_method : str, optional. Default is 'spectral'. + Partitioning algorithm to use among {'random', 'louvain', 'fluid', + 'spectral', 'louvain_fused', 'fluid_fused', 'spectral_fused', 'GW', 'FGW'}. + If part_method in {'louvain_fused', 'fluid_fused', 'spectral_fused'}, + corresponding graph partitioning algorithm {'louvain', 'fluid', 'spectral'} + will be used on the modified structure matrix + :math:`\alpha \mathbf{C} + (1 - \alpha) \mathbf{D}(\mathbf{F})` where + :math:`\mathbf{D}(\mathbf{F})` is the pairwise euclidean matrix between features. + If part_method in {'GW', 'FGW'}, a (F)GW projection is used. + If the louvain algorithm is used, the requested number of partitions is + ignored. + rep_method : str, optional. Default is 'pagerank'. + Selection method for node representant in each partition. + Can be either 'random' i.e random sampling within each partition, + {'pagerank', 'pagerank_fused'} to select a node with maximal pagerank w.r.t + :math:`\mathbf{C}` or :math:`\alpha \mathbf{C} + (1 - \alpha) \mathbf{D}(\mathbf{F})`. 
+ verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + armijo : bool, optional + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + random_state: int, optional + Random seed for the partitioning algorithm + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + T_global: array-like, shape (`npart1`, `npart2`) + Fused Gromov-Wasserstein alignment :math:`\mathbf{T}^{g}` between representants. + Ts_local: dict of local OT matrices. + Dictionary with keys :math:`(i, j)` corresponding to 1D OT between + :math:`\mathbf{P_{1, i}}` and :math:`\mathbf{P_{2, j}}` if :math:`T^{g}_{ij} \neq 0`. + T: array-like, shape `(ns, nt)` + Coupling between the two spaces. + log : dict + Convergence information for inner problems and qGW loss. + + References + ---------- + .. [68] Chowdhury, S., Miller, D., & Needham, T. (2021). + Quantized gromov-wasserstein. ECML PKDD 2021. Springer International Publishing. + + """ + if (part_method in ['fluid', 'louvain', 'fluid_fused', 'louvain_fused'] or (rep_method in ['pagerank', 'pagerank_fused'])): + if not networkx_import: + warnings.warn( + f""" + Networkx is not installed, so part_method={part_method} and/or + rep_method={rep_method} cannot be used and are set to `random` + default methods. Consider installing Networkx to fix this. + """ + ) + part_method = 'random' + rep_method = 'random' + + if (part_method in ['spectral', 'spectral_fused']) and (not sklearn_import): + warnings.warn( + f""" + Scikit-learn is not installed, so part_method={part_method} and/or + rep_method={rep_method} cannot be used and are set to `random` + default methods. 
Consider installing Scikit-learn to fix this. + """ + ) + part_method = 'random' + rep_method = 'random' + + if (('fused' in part_method) or ('fused' in rep_method) or (part_method == 'FGW')): + if (F1 is None) or (F2 is None): + raise ValueError( + f""" + `part_method='{part_method}'` and/or `rep_method='{rep_method}'` + require feature matrices which are not provided as inputs. + """) + + arr = [C1, C2] + if C1_aux is not None: + arr.append(C1_aux) + else: + C1_aux = C1 + if C2_aux is not None: + arr.append(C2_aux) + else: + C2_aux = C2 + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(C1.shape[0], type_as=C1) + if q is not None: + arr.append(list_to_array(q)) + else: + q = unif(C2.shape[0], type_as=C1) + if F1 is not None: + arr.append(F1) + if F2 is not None: + arr.append(F1) + + nx = get_backend(*arr) + + DF1 = None + DF2 = None + # compute attributed graph partitions potentially using the auxiliary structure + if 'fused' in part_method: + + DF1 = dist(F1, F1) + DF2 = dist(F2, F2) + C1_new = alpha * C1_aux + (1 - alpha) * DF1 + C2_new = alpha * C2_aux + (1 - alpha) * DF2 + + part_method_ = part_method[:-6] + part1 = get_graph_partition(C1_new, npart1, part_method_, random_state=random_state, nx=nx) + part2 = get_graph_partition(C2_new, npart2, part_method_, random_state=random_state, nx=nx) + + else: + part1 = get_graph_partition(C1_aux, npart1, part_method, F1, alpha, random_state, nx) + part2 = get_graph_partition(C2_aux, npart2, part_method, F2, alpha, random_state, nx) + + if 'fused' in rep_method: + if DF1 is None: + DF1 = dist(F1, F1) + DF2 = dist(F2, F2) + C1_new = alpha * C1_aux + (1 - alpha) * DF1 + C2_new = alpha * C2_aux + (1 - alpha) * DF2 + + rep_method_ = rep_method[:-6] + + rep_indices1 = get_graph_representants(C1_new, part1, rep_method_, random_state, nx) + rep_indices2 = get_graph_representants(C2_new, part2, rep_method_, random_state, nx) + + else: + rep_indices1 = get_graph_representants(C1_aux, part1, rep_method, 
random_state, nx) + rep_indices2 = get_graph_representants(C2_aux, part2, rep_method, random_state, nx) + + # format partitions over (C1, F1) and (C2, F2) + if (F1 is None) and (F2 is None): + CR1, list_R1, list_p1 = format_partitioned_graph(C1, p, part1, rep_indices1, nx=nx) + CR2, list_R2, list_p2 = format_partitioned_graph(C2, q, part2, rep_indices2, nx=nx) + + MR = None + else: + if DF1 is None: + DF1 = dist(F1, F1) + DF2 = dist(F2, F2) + + CR1, list_R1, list_p1, FR1 = format_partitioned_graph(C1, p, part1, rep_indices1, F1, DF1, alpha, nx) + CR2, list_R2, list_p2, FR2 = format_partitioned_graph(C2, q, part2, rep_indices2, F2, DF2, alpha, nx) + + MR = dist(FR1, FR2) + # call to partitioned quantized fused gromov-wasserstein solver + + res = quantized_fused_gromov_wasserstein_partitioned( + CR1, CR2, list_R1, list_R2, list_p1, list_p2, MR, alpha, build_OT=True, + log=log, armijo=armijo, max_iter=max_iter, tol_rel=tol_rel, + tol_abs=tol_abs, nx=nx, **kwargs) + + if log: + T_global, Ts_local, T, log_ = res + + # compute the transport cost on structures + constC, hC1, hC2 = init_matrix(C1, C2, p, q, 'square_loss', nx) + structure_cost = gwloss(constC, hC1, hC2, T, nx) + + if alpha != 1.: + M = dist(F1, F2) + feature_cost = nx.sum(M * T) + else: + feature_cost = 0. + + log_['qFGW_dist'] = alpha * structure_cost + (1 - alpha) * feature_cost + return T_global, Ts_local, T, log_ + + else: + T_global, Ts_local, T = res + + return T_global, Ts_local, T + + +def get_partition_and_representants_samples( + X, npart, method='kmeans', random_state=0, nx=None): + """ + Compute `npart` partitions and representants over samples :math:`\mathbf{X} \in R^{n \times d}` + using either a random or a kmeans algorithm. + + Parameters + ---------- + X : array-like, shape (n, d) + Samples endowed with an euclidean geometry. + npart : int, + number of partitions smaller than the number of samples in + :math:`\mathbf{X}`. + method : str, optional. Default is 'kmeans'. 
+ Partitioning and representant selection algorithms to use among + {'random', 'kmeans'}. 'random' for random sampling of points; 'kmeans' + for k-means clustering using scikit-learn implementation where closest + points to centroids are considered as representants. + random_state: int, optional + Random seed for the partitioning algorithm. + nx : backend, optional + POT backend. + + Returns + ------- + part : array-like, shape (npart,) + Array of partition assignment for each node. + + rep_indices : list, shape (npart,) + indices for representative node of each partition sorted + according to partition identifiers. + + References + ---------- + .. [68] Chowdhury, S., Miller, D., & Needham, T. (2021). + Quantized gromov-wasserstein. ECML PKDD 2021. Springer International Publishing. + + """ + if nx is None: + nx = get_backend(X) + + n = X.shape[0] + X0 = X + + if npart >= n: + warnings.warn( + "Requested number of partitions higher than the number of nodes" + "hence we enforce each node to be a partition.", + stacklevel=2 + ) + + part = nx.arange(n) + rep_indices = nx.arange(n) + + elif npart == 1: + random.seed(random_state) + part = nx.zeros(n) + rep_indices = [random.choice(nx.arange(n))] + + elif method == 'random': + # randomly partition the space + random.seed(random_state) + part = list_to_array(random.choices(np.arange(npart), k=X.shape[0])) + part = nx.from_numpy(part, type_as=X0) + + # randomly select representant in each partition + rep_indices = [] + part_ids = nx.unique(part) + for id_, part_id in enumerate(part_ids): + indices = nx.where(part == part_id)[0] + rep_indices.append(random.choice(indices)) + + elif method == 'kmeans': + X = nx.to_numpy(X0) + km = KMeans(n_clusters=npart, random_state=random_state).fit(X) + part = nx.from_numpy(km.labels_, type_as=X0) + + rep_indices = [] + for part_id in range(npart): + indices = nx.where(part == part_id)[0] + dists = dist(X[indices], km.cluster_centers_[part_id][None, :]) + best_idx = 
indices[dists.argmin()] + rep_indices.append(best_idx) + + else: + raise ValueError( + f""" + Unknown `method='{method}'`. Use one of: {'random', 'kmeans'} + """) + + return part, rep_indices + + +def format_partitioned_samples( + X, p, part, rep_indices, F=None, alpha=1., nx=None): + """ + Format an attributed graph :math:`(\mathbf{D}(\mathbf{X}), \mathbf{F}, \mathbf{p})` + with euclidean structure matrix :math:`(\mathbf{D}(\mathbf{X}) \in R^{n \times n}`, + feature matrix :math:`(\mathbf{F} \in R^{n \times d}` and node relative importance + :math:`(\mathbf{p} \in \Sigma_n`, into a partitioned attributed graph + taking into account partitions and representants :math:`\mathcal{P} = \left{(\mathbf{P_{i}}, \mathbf{r_{i}})\right}_i`. + + Parameters + ---------- + X : array-like, shape (n, d) + Structure matrix. + p : array-like, shape (n,), + Node distribution. + part : array-like, shape (n,) + Array of partition assignment for each node. + rep_indices : list of array-like of ints, shape (npart,) + indices for representative node of each partition sorted according to + partition identifiers. + F : array-like, shape (n, p), optional. (Default is None) + Optional feature matrix aligned with the samples. + alpha: float, optional. Default is 1. + Trade-off parameter in :math:`]0, 1]` between structure and features. + If `alpha = 1` features are ignored. This trade-off is taken into account + into the outputted relations between nodes and representants. + nx : backend, optional + POT backend + + Returns + ------- + CR : array-like, shape (npart, npart) + Structure matrix between partition representants. + list_R : list of npart arrays, + List of relations between a representant and nodes in its partition, + for each partition. + list_p : list of npart arrays, + List of node distributions within each partition. + FR : array-like, shape (npart, d), if `F != None`. + Feature matrix of representants. + + References + ---------- + .. 
[68] Chowdhury, S., Miller, D., & Needham, T. (2021). + Quantized gromov-wasserstein. ECML PKDD 2021. Springer International Publishing. + + """ + if nx is None: + arr = [X, p, part] + if F is not None: + arr.append(F) + + nx = get_backend(*arr) + + if alpha != 1.: + if F is None: + raise ValueError( + f""" + `alpha == {alpha} != 1` but features information is not properly provided. + """) + + XR = X[rep_indices, :] + CR = dist(XR, XR) + + list_R, list_p = [], [] + + part_ids = nx.unique(part) + + for id_, part_id in enumerate(part_ids): + indices = nx.where(part == part_id)[0] + structure_R = dist(X[indices], X[rep_indices[id_]][None, :]) + + if alpha != 1: + features_R = dist(F[indices], F[rep_indices[id_]][None, :]) + else: + features_R = 0. + + list_R.append(alpha * structure_R + (1 - alpha) * features_R) + list_p.append(p[indices]) + + if F is None: + + return CR, list_R, list_p + else: + FR = F[rep_indices, :] + + return CR, list_R, list_p, FR + + +def quantized_fused_gromov_wasserstein_samples( + X1, X2, npart1, npart2, p=None, q=None, F1=None, F2=None, alpha=1., + method='kmeans', log=False, armijo=False, max_iter=1e4, + tol_rel=1e-9, tol_abs=1e-9, random_state=0, **kwargs): + r""" + Returns the quantized Fused Gromov-Wasserstein transport between samples + endowed with their respective euclidean geometry :math:`(\mathbf{D}(\mathbf{X_1}), \mathbf{F_1}, \mathbf{p})` + and :math:`(\mathbf{D}(\mathbf{X_1}), \mathbf{F_2}, \mathbf{q})`, whose samples are assigned to partitions and + representants :math:`\mathcal{P_1} = \{(\mathbf{P_{1, i}}, \mathbf{r_{1, i}})\}_{i \leq npart1}` + and :math:`\mathcal{P_2} = \{(\mathbf{P_{2, j}}, \mathbf{r_{2, j}})\}_{j \leq npart2}`. + + The function estimates the following optimization problem: + + .. 
math:: + \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad \alpha \sum_{i,j,k,l} + L(\mathbf{D}(\mathbf{X_1})_{i,k}, \mathbf{D}(\mathbf{X_2})_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + (1-\alpha) \langle \mathbf{T}, \mathbf{D}(\mathbf{F_1}, \mathbf{F}_2) \rangle_F + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + \mathbf{T}_{|\mathbf{P_{1, i}}, \mathbf{P_{2, j}}} &= T^{g}_{ij} \mathbf{T}^{(i,j)} + + using a two-step strategy computing: i) a global alignment :math:`\mathbf{T}^{g}` + between representants across joint structure and feature spaces; + ii) local alignments :math:`\mathbf{T}^{(i, j)}` between partitions + :math:`\mathbf{P_{1, i}}` and :math:`\mathbf{P_{2, j}}` seen as 1D measures. + + Where : + + - :math:`\mathbf{X_1}`: Samples in the source space + - :math:`\mathbf{X_2}`: Samples in the target space + - :math:`\mathbf{F_1}`: Feature matrix in the source space + - :math:`\mathbf{F_2}`: Feature matrix in the target space + - :math:`\mathbf{D}(\mathbf{F_1}, \mathbf{F_2})`: Pairwise euclidean distance matrix between features + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - :math:`L`: quadratic loss function to account for the misfit between the similarity matrices + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + .. note:: All computations in the conjugate gradient solver are done with + numpy to limit memory overhead. + + Parameters + ---------- + X1 : array-like, shape (ns, ds) + Samples in the source space. + X2 : array-like, shape (nt, dt) + Samples in the target space. + npart1 : int, + number of partition in the source space. + npart2 : int, + number of partition in the target space. + p : array-like, shape (ns,), optional + Distribution in the source space. 
+ If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If let to its default value None, uniform distribution is taken. + F1 : array-like, shape (ns, d), optional. Default is None. + Feature matrix in the source space. + F2 : array-like, shape (nt, d), optional. Default is None. + Feature matrix in the target space + alpha: float, optional. Default is 1. + FGW trade-off parameter in :math:`]0, 1]` between structure and features. + If `alpha = 1` features are ignored hence computing qGW, if `alpha=0` + structures are ignored and we compute the quantized Wasserstein transport. + method : str, optional. Default is 'kmeans'. + Partitioning and representant selection algorithms to use among + {'random', 'kmeans', 'kmeans_fused'}. + If `part_method == 'kmeans_fused'`, kmeans is performed on augmented + samples :math:`[\alpha \mathbf{X}; (1 - \alpha) \mathbf{F}]`. + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + armijo : bool, optional + If True the step of the line-search is found via an armijo research. Else closed form is used. + If there are convergence issues use False. + max_iter : int, optional + Max number of iterations + tol_rel : float, optional + Stop threshold on relative error (>0) + tol_abs : float, optional + Stop threshold on absolute error (>0) + random_state: int, optional + Random seed for the partitioning algorithm + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + T_global: array-like, shape (`npart1`, `npart2`) + Fused Gromov-Wasserstein alignment :math:`\mathbf{T}^{g}` between representants. + Ts_local: dict of local OT matrices. + Dictionary with keys :math:`(i, j)` corresponding to 1D OT between + :math:`\mathbf{P_{1, i}}` and :math:`\mathbf{P_{2, j}}` if :math:`T^{g}_{ij} \neq 0`. 
+ T: array-like, shape `(ns, nt)` + Coupling between the two spaces. + log : dict + Convergence information for inner problems and qGW loss. + + References + ---------- + .. [68] Chowdhury, S., Miller, D., & Needham, T. (2021). + Quantized gromov-wasserstein. ECML PKDD 2021. Springer International Publishing. + + """ + + if (method in ['kmeans', 'kmeans_fused']) and (not sklearn_import): + warnings.warn( + f""" + Scikit-learn is not installed, so method={method} cannot be used + and is set to `random` default methods. Consider installing + Scikit-learn to fix this. + """ + ) + method = 'random' + + if ('fused' in method) and ((F1 is None) or (F2 is None)): + raise ValueError( + f""" + `method='{method}'` requires feature matrices which are not provided as inputs. + """) + + arr = [X1, X2] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(X1.shape[0], type_as=X1) + if q is not None: + arr.append(list_to_array(q)) + else: + q = unif(X2.shape[0], type_as=X1) + if F1 is not None: + arr.append(F1) + if F2 is not None: + arr.append(F1) + + nx = get_backend(*arr) + + # compute attributed partitions and representants + if ('fused' in method) and (alpha != 1.): + X1_new = nx.concatenate([alpha * X1, (1 - alpha) * F1], axis=1) + X2_new = nx.concatenate([alpha * X2, (1 - alpha) * F2], axis=1) + method_ = method[:-6] + else: + X1_new, X2_new = X1, X2 + method_ = method + part1, rep_indices1 = get_partition_and_representants_samples( + X1_new, npart1, method_, random_state, nx) + part2, rep_indices2 = get_partition_and_representants_samples( + X2_new, npart2, method_, random_state, nx) + # format partitions over (C1, F1) and (C2, F2) + + if (F1 is None) and (F2 is None): + CR1, list_R1, list_p1 = format_partitioned_samples( + X1, p, part1, rep_indices1, nx=nx) + CR2, list_R2, list_p2 = format_partitioned_samples( + X2, q, part2, rep_indices2, nx=nx) + + MR = None + else: + CR1, list_R1, list_p1, FR1 = format_partitioned_samples( + X1, p, part1, rep_indices1, 
F1, alpha, nx) + CR2, list_R2, list_p2, FR2 = format_partitioned_samples( + X2, q, part2, rep_indices2, F2, alpha, nx) + + MR = dist(FR1, FR2) + + # call to partitioned quantized fused gromov-wasserstein solver + + res = quantized_fused_gromov_wasserstein_partitioned( + CR1, CR2, list_R1, list_R2, list_p1, list_p2, MR, alpha, build_OT=True, + log=log, armijo=armijo, max_iter=max_iter, tol_rel=tol_rel, + tol_abs=tol_abs, nx=nx, **kwargs) + + if log: + T_global, Ts_local, T, log_ = res + + C1 = dist(X1, X1) + C2 = dist(X2, X2) + + # compute the transport cost on structures + constC, hC1, hC2 = init_matrix(C1, C2, p, q, 'square_loss', nx) + structure_cost = gwloss(constC, hC1, hC2, T, nx) + + if alpha != 1.: + M = dist(F1, F2) + feature_cost = nx.sum(M * T) + else: + feature_cost = 0. + + log_['qFGW_dist'] = alpha * structure_cost + (1 - alpha) * feature_cost + return T_global, Ts_local, T, log_ + + else: + T_global, Ts_local, T = res + + return T_global, Ts_local, T diff --git a/test/gromov/test_quantized.py b/test/gromov/test_quantized.py new file mode 100644 index 000000000..a864a8a46 --- /dev/null +++ b/test/gromov/test_quantized.py @@ -0,0 +1,377 @@ +"""Tests for gromov._quantized.py """ + +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np +import pytest + +import ot + +from ot.gromov._quantized import ( + networkx_import, sklearn_import) + + +def test_quantized_gw(nx): + n_samples = 30 # nb samples + + rng = np.random.RandomState(0) + C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) + C1 = (C1 + C1.T) / 2. + + C2 = rng.uniform(low=10., high=20., size=(n_samples, n_samples)) + C2 = (C2 + C2.T) / 2. 
+ + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + npart2 = 3 + + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) + + for npart1 in [1, n_samples + 1, 2]: + log_tests = [True, False, False, True, True, False] + + pairs_part_rep = [('random', 'random')] + if networkx_import: + pairs_part_rep += [('louvain', 'random'), ('fluid', 'pagerank')] + if sklearn_import: + pairs_part_rep += [('spectral', 'random')] + + count_mode = 0 + + for part_method, rep_method in pairs_part_rep: + log_ = log_tests[count_mode] + count_mode += 1 + + res = ot.gromov.quantized_fused_gromov_wasserstein( + C1, C2, npart1, npart2, p, None, C1, None, part_method=part_method, + rep_method=rep_method, log=log_) + + resb = ot.gromov.quantized_fused_gromov_wasserstein( + C1b, C2b, npart1, npart2, None, qb, None, C2b, part_method=part_method, + rep_method=rep_method, log=log_) + + if log_: + T_global, Ts_local, T, log = res + T_globalb, Ts_localb, Tb, logb = resb + else: + T_global, Ts_local, T = res + T_globalb, Ts_localb, Tb = resb + + Tb = nx.to_numpy(Tb) + # check constraints + np.testing.assert_allclose(T, Tb, atol=1e-06) + np.testing.assert_allclose( + p, Tb.sum(1), atol=1e-06) # cf convergence gromov + np.testing.assert_allclose( + q, Tb.sum(0), atol=1e-06) # cf convergence gromov + + if log_: + for key in log.keys(): + # The inner test T_global[i, j] != 0. can lead to different + # computation of 1D OT computations between partition depending + # on the different float errors across backend + if key in logb.keys(): + np.testing.assert_allclose(log[key], logb[key], atol=1e-06) + + +def test_quantized_fgw(nx): + n_samples = 30 # nb samples + + rng = np.random.RandomState(0) + C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) + C1 = (C1 + C1.T) / 2. + + F1 = rng.uniform(low=0., high=10, size=(n_samples, 1)) + + C2 = rng.uniform(low=10., high=20., size=(n_samples, n_samples)) + C2 = (C2 + C2.T) / 2. 
+ + F2 = rng.uniform(low=0., high=10, size=(n_samples, 1)) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + npart1 = 2 + npart2 = 3 + + C1b, C2b, F1b, F2b, pb, qb = nx.from_numpy(C1, C2, F1, F2, p, q) + + log_tests = [True, False, False, True, True, False] + + pairs_part_rep = [] + if networkx_import: + pairs_part_rep += [('louvain_fused', 'pagerank'), + ('louvain', 'pagerank_fused'), + ('fluid_fused', 'pagerank_fused')] + if sklearn_import: + pairs_part_rep += [('spectral_fused', 'random')] + + pairs_part_rep += [('random', 'random')] + count_mode = 0 + + alpha = 0.5 + + for part_method, rep_method in pairs_part_rep: + log_ = log_tests[count_mode] + count_mode += 1 + + res = ot.gromov.quantized_fused_gromov_wasserstein( + C1, C2, npart1, npart2, p, None, C1, None, F1, F2, alpha, + part_method, rep_method, log_) + + resb = ot.gromov.quantized_fused_gromov_wasserstein( + C1b, C2b, npart1, npart2, None, qb, None, C2b, F1b, F2b, alpha, + part_method, rep_method, log_) + + if log_: + T_global, Ts_local, T, log = res + T_globalb, Ts_localb, Tb, logb = resb + else: + T_global, Ts_local, T = res + T_globalb, Ts_localb, Tb = resb + + Tb = nx.to_numpy(Tb) + # check constraints + np.testing.assert_allclose(T, Tb, atol=1e-06) + np.testing.assert_allclose( + p, Tb.sum(1), atol=1e-06) # cf convergence gromov + np.testing.assert_allclose( + q, Tb.sum(0), atol=1e-06) # cf convergence gromov + + if log_: + for key in log.keys(): + # The inner test T_global[i, j] != 0. 
can lead to different + # computation of 1D OT computations between partition depending + # on the different float errors across backend + if key in logb.keys(): + np.testing.assert_allclose(log[key], logb[key], atol=1e-06) + + # complementary tests for utils functions + DF1b = ot.dist(F1b, F1b) + DF2b = ot.dist(F2b, F2b) + C1b_new = alpha * C1b + (1 - alpha) * DF1b + C2b_new = alpha * C2b + (1 - alpha) * DF2b + + part1b = ot.gromov.get_graph_partition( + C1b_new, npart1, part_method=pairs_part_rep[-1][0], random_state=0) + part2b = ot.gromov._quantized.get_graph_partition( + C2b_new, npart2, part_method=pairs_part_rep[-1][0], random_state=0) + + rep_indices1b = ot.gromov.get_graph_representants( + C1b, part1b, rep_method=pairs_part_rep[-1][1], random_state=0) + rep_indices2b = ot.gromov.get_graph_representants( + C2b, part2b, rep_method=pairs_part_rep[-1][1], random_state=0) + + CR1b, list_R1b, list_p1b, FR1b = ot.gromov.format_partitioned_graph( + C1b, pb, part1b, rep_indices1b, F1b, DF1b, alpha) + CR2b, list_R2b, list_p2b, FR2b = ot.gromov.format_partitioned_graph( + C2b, qb, part2b, rep_indices2b, F2b, DF2b, alpha) + + MRb = ot.dist(FR1b, FR2b) + + T_globalb, Ts_localb, _ = ot.gromov.quantized_fused_gromov_wasserstein_partitioned( + CR1b, CR2b, list_R1b, list_R2b, list_p1b, list_p2b, MRb, alpha, build_OT=False) + + T_globalb = nx.to_numpy(T_globalb) + np.testing.assert_allclose(T_global, T_globalb, atol=1e-06) + + for key in Ts_localb.keys(): + T_localb = nx.to_numpy(Ts_localb[key]) + np.testing.assert_allclose(Ts_local[key], T_localb, atol=1e-06) + + # tests for edge cases of the graph partitioning + for method in ['unknown_method', 'GW', 'FGW']: + with pytest.raises(ValueError): + ot.gromov.get_graph_partition( + C1b, npart1, part_method=method, random_state=0) + + with pytest.raises(ValueError): + ot.gromov.get_graph_partition( + C1b, npart1, part_method=method, alpha=0.5, F=None, random_state=0) + + # tests for edge cases of the representant selection + 
with pytest.raises(ValueError): + ot.gromov.get_graph_representants( + C1b, part1b, rep_method='unknown_method', random_state=0) + + # tests for edge cases of the format_partitioned_graph function + with pytest.raises(ValueError): + CR1b, list_R1b, list_p1b, FR1b = ot.gromov.format_partitioned_graph( + C1b, pb, part1b, rep_indices1b, F1b, None, alpha) + + # Tests in qFGW solvers + # for non admissible values of alpha + with pytest.raises(ValueError): + ot.gromov.quantized_fused_gromov_wasserstein_partitioned( + CR1b, CR2b, list_R1b, list_R2b, list_p1b, list_p2b, MRb, 0, build_OT=False) + + # for non-consistent feature information provided + with pytest.raises(ValueError): + ot.gromov.quantized_fused_gromov_wasserstein( + C1, C2, npart1, npart2, p, q, None, None, F1, None, 0.5, + 'spectral_fused', 'random', log_) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +def test_quantized_gw_samples(nx): + n_samples_1 = 15 # nb samples + n_samples_2 = 20 # nb samples + + rng = np.random.RandomState(0) + X1 = rng.uniform(low=0., high=10, size=(n_samples_1, 2)) + X2 = rng.uniform(low=0., high=10, size=(n_samples_2, 4)) + + p = ot.unif(n_samples_1) + q = ot.unif(n_samples_2) + + npart1 = 2 + npart2 = 3 + + X1b, X2b, pb, qb = nx.from_numpy(X1, X2, p, q) + + log_tests = [True, False, True] + methods = ['random'] + if sklearn_import: + methods += ['kmeans'] + + count_mode = 0 + alpha = 1. 
+ + for method in methods: + log_ = log_tests[count_mode] + count_mode += 1 + + res = ot.gromov.quantized_fused_gromov_wasserstein_samples( + X1, X2, npart1, npart2, p, None, None, None, alpha, method, log_) + + resb = ot.gromov.quantized_fused_gromov_wasserstein_samples( + X1b, X2b, npart1, npart2, None, qb, None, None, alpha, method, log_) + + if log_: + T_global, Ts_local, T, log = res + T_globalb, Ts_localb, Tb, logb = resb + else: + T_global, Ts_local, T = res + T_globalb, Ts_localb, Tb = resb + + Tb = nx.to_numpy(Tb) + # check constraints + np.testing.assert_allclose(T, Tb, atol=1e-06) + np.testing.assert_allclose( + p, Tb.sum(1), atol=1e-06) # cf convergence gromov + np.testing.assert_allclose( + q, Tb.sum(0), atol=1e-06) # cf convergence gromov + + if log_: + for key in log.keys(): + # The inner test T_global[i, j] != 0. can lead to different + # computation of 1D OT computations between partition depending + # on the different float errors across backend + if key in logb.keys(): + np.testing.assert_allclose(log[key], logb[key], atol=1e-06) + + # tests for edge cases of the representant selection + with pytest.raises(ValueError): + ot.gromov.get_partition_and_representants_samples( + X1, npart1, method='unknown_method', random_state=0) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +def test_quantized_fgw_samples(nx): + n_samples_1 = 20 # nb samples + n_samples_2 = 30 # nb samples + + rng = np.random.RandomState(0) + X1 = rng.uniform(low=0., high=10, size=(n_samples_1, 2)) + X2 = rng.uniform(low=0., high=10, size=(n_samples_2, 4)) + + F1 = rng.uniform(low=0., high=10, size=(n_samples_1, 3)) + F2 = rng.uniform(low=0., high=10, size=(n_samples_2, 3)) + + p = ot.unif(n_samples_1) + q = ot.unif(n_samples_2) + + npart1 = 2 + npart2 = 3 + + X1b, X2b, F1b, F2b, pb, qb = nx.from_numpy(X1, X2, F1, F2, p, q) + + methods = [] + if sklearn_import: + methods += ['kmeans', 'kmeans_fused'] + methods += ['random'] + + alpha = 0.5 + + for npart1 
in [1, n_samples_1 + 1, 2]: + log_tests = [True, False, True] + count_mode = 0 + + for method in methods: + log_ = log_tests[count_mode] + count_mode += 1 + + res = ot.gromov.quantized_fused_gromov_wasserstein_samples( + X1, X2, npart1, npart2, p, None, F1, F2, alpha, method, log_) + + resb = ot.gromov.quantized_fused_gromov_wasserstein_samples( + X1b, X2b, npart1, npart2, None, qb, F1b, F2b, alpha, method, log_) + + if log_: + T_global, Ts_local, T, log = res + T_globalb, Ts_localb, Tb, logb = resb + else: + T_global, Ts_local, T = res + T_globalb, Ts_localb, Tb = resb + + Tb = nx.to_numpy(Tb) + # check constraints + np.testing.assert_allclose(T, Tb, atol=1e-06) + np.testing.assert_allclose( + p, Tb.sum(1), atol=1e-06) # cf convergence gromov + np.testing.assert_allclose( + q, Tb.sum(0), atol=1e-06) # cf convergence gromov + + if log_: + for key in log.keys(): + # The inner test T_global[i, j] != 0. can lead to different + # computation of 1D OT computations between partition depending + # on the different float errors across backend + if key in logb.keys(): + np.testing.assert_allclose(log[key], logb[key], atol=1e-06) + + # complementary tests for utils functions + part1b, rep_indices1 = ot.gromov.get_partition_and_representants_samples( + X1b, npart1, method=method, random_state=0) + part2b, rep_indices2 = ot.gromov.get_partition_and_representants_samples( + X2b, npart2, method=method, random_state=0) + + CR1b, list_R1b, list_p1b, FR1b = ot.gromov.format_partitioned_samples( + X1b, pb, part1b, rep_indices1, F1b, alpha) + CR2b, list_R2b, list_p2b, FR2b = ot.gromov.format_partitioned_samples( + X2b, qb, part2b, rep_indices2, F2b, alpha) + + MRb = ot.dist(FR1b, FR2b) + + T_globalb, Ts_localb, _ = ot.gromov.quantized_fused_gromov_wasserstein_partitioned( + CR1b, CR2b, list_R1b, list_R2b, list_p1b, list_p2b, MRb, alpha, build_OT=False) + + T_globalb = nx.to_numpy(T_globalb) + np.testing.assert_allclose(T_global, T_globalb, atol=1e-06) + + for key in Ts_localb.keys(): 
+ T_localb = nx.to_numpy(Ts_localb[key]) + np.testing.assert_allclose(Ts_local[key], T_localb, atol=1e-06) + + # tests for edge cases of the format_partitioned_graph function + with pytest.raises(ValueError): + CR1b, list_R1b, list_p1b, FR1b = ot.gromov.format_partitioned_samples( + X1b, pb, part1b, rep_indices1, None, alpha) + + # for non-consistent feature information provided + with pytest.raises(ValueError): + ot.gromov.quantized_fused_gromov_wasserstein_samples( + X1, X2, npart1, npart2, p, None, None, F2, alpha, 'fused_spectral', log_) From 61d751b0bac9625e5ceb7e3a320f6cf2b6800052 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Wed, 19 Jun 2024 02:38:55 -0500 Subject: [PATCH 16/30] [BLD] Add support for NumPy 2.0 wheels (#629) * build: Add support for NumPy 2.0 wheels [build wheels] * Add numpy>=2.0.0 to build-system requires to build NumPy 1.x and 2.x compatible wheels. - c.f. https://numpy.org/doc/stable/dev/depending_on_numpy.html#numpy-2-0-specific-advice - As NumPy 2.0 is Python 3.9+, also need to conditionally support 'oldest-supported-numpy' for Python 3.8. * Remove wheel from build-system requires as it is never required and injected automatically by setuptools only when needed. - c.f. https://learn.scientific-python.org/development/guides/packaging-classic/ * build: Remove setup_requires metadata * Remove setup_requires as it is deprecated and should defer to PEP 517/518. - c.f. https://learn.scientific-python.org/development/guides/packaging-classic/#pep-517518-support-high-priority * fix: Use np.inf to avoid AttributeError * Avoids: AttributeError: `np.infty` was removed in the NumPy 2.0 release. Use `np.inf` instead. AttributeError: `np.Inf` was removed in the NumPy 2.0 release. Use `np.inf` instead. 
* docs: Add NumPy 2.0 support to release notes --- RELEASES.md | 1 + ot/da.py | 12 ++++++------ ot/regpath.py | 2 +- pyproject.toml | 9 +++++++-- setup.py | 1 - 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 51075c973..3908d079c 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -3,6 +3,7 @@ ## 0.9.4dev #### New features ++ NumPy 2.0 support is added (PR #629) + New quantized FGW solvers `ot.gromov.quantized_fused_gromov_wasserstein`, `ot.gromov.quantized_fused_gromov_wasserstein_samples` and `ot.gromov.quantized_fused_gromov_wasserstein_partitioned` (PR #603) + `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specify if the matrices are symmetric in which case the computation can be done faster (PR #607). + Continuous entropic mapping (PR #613) diff --git a/ot/da.py b/ot/da.py index e4adaa546..2b28260ef 100644 --- a/ot/da.py +++ b/ot/da.py @@ -497,7 +497,7 @@ class label if (ys is not None) and (yt is not None): - if self.limit_max != np.infty: + if self.limit_max != np.inf: self.limit_max = self.limit_max * nx.max(self.cost_) # missing_labels is a (ns, nt) matrix of {0, 1} such that @@ -519,7 +519,7 @@ class label cost_correction = label_match * missing_labels * self.limit_max # this operation is necessary because 0 * Inf = NAN # thus is irrelevant when limit_max is finite - cost_correction = nx.nan_to_num(cost_correction, -np.infty) + cost_correction = nx.nan_to_num(cost_correction, -np.inf) self.cost_ = nx.maximum(self.cost_, cost_correction) # distribution estimation @@ -1067,7 +1067,7 @@ class SinkhornTransport(BaseTransport): method from :ref:`[66] ` and :ref:`[19] `. - limit_max: float, optional (default=np.infty) + limit_max: float, optional (default=np.inf) Controls the semi supervised mode. 
Transport between labeled source and target samples of different classes will exhibit an cost defined by this variable @@ -1109,7 +1109,7 @@ def __init__(self, reg_e=1., method="sinkhorn_log", max_iter=1000, tol=10e-9, verbose=False, log=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='continuous', limit_max=np.infty): + out_of_sample_map='continuous', limit_max=np.inf): if out_of_sample_map not in ['ferradans', 'continuous']: raise ValueError('Unknown out_of_sample_map method') @@ -1417,7 +1417,7 @@ class SinkhornLpl1Transport(BaseTransport): The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is "ferradans" which uses the method proposed in :ref:`[6] `. - limit_max: float, optional (default=np.infty) + limit_max: float, optional (default=np.inf) Controls the semi supervised mode. Transport between labeled source and target samples of different classes will exhibit a cost defined by limit_max. 
@@ -1450,7 +1450,7 @@ def __init__(self, reg_e=1., reg_cl=0.1, tol=10e-9, verbose=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='ferradans', limit_max=np.infty): + out_of_sample_map='ferradans', limit_max=np.inf): self.reg_e = reg_e self.reg_cl = reg_cl self.max_iter = max_iter diff --git a/ot/regpath.py b/ot/regpath.py index 8a9b6d886..5e32e4fd1 100644 --- a/ot/regpath.py +++ b/ot/regpath.py @@ -762,7 +762,7 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, active_index.append(i * m + j) gamma_list = [] t_list = [] - current_gamma = np.Inf + current_gamma = np.inf augmented_H0 = construct_augmented_H(active_index, m, Hc, HrHr) add_col = np.array([]) id_pop = -1 diff --git a/pyproject.toml b/pyproject.toml index 378920623..a83a59332 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,8 @@ [build-system] -requires = ["setuptools", "wheel", "oldest-supported-numpy", "cython>=0.23"] -build-backend = "setuptools.build_meta" \ No newline at end of file +requires = [ + "setuptools>=42", + "oldest-supported-numpy; python_version < '3.9'", + "numpy>=2.0.0; python_version >= '3.9'", + "cython>=0.23" +] +build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 72b1488b2..1d0b6fceb 100644 --- a/setup.py +++ b/setup.py @@ -69,7 +69,6 @@ license='MIT', scripts=[], data_files=[], - setup_requires=["oldest-supported-numpy", "cython>=0.23"], install_requires=["numpy>=1.16", "scipy>=1.6"], python_requires=">=3.6", classifiers=[ From 36b4c0ac397f5391d1eefc92446ebc217c0467e0 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Wed, 19 Jun 2024 11:44:15 -0500 Subject: [PATCH 17/30] [CI] Remove redundant CIBW_BEFORE_BUILD (#631) * The build system should define all required build components and make them available at wheel build time. As this pip install step is not required, remove CIBW_BEFORE_BUILD. 
From the docs https://cibuildwheel.pypa.io/en/stable/options/#before-build > If dependencies are required to build your wheel (for example if you > include a header from a Python module), instead of using this command, > we recommend adding requirements to a pyproject.toml file's > build-system.requires array instead. --- .github/workflows/build_wheels_weekly.yml | 3 +-- Makefile | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build_wheels_weekly.yml b/.github/workflows/build_wheels_weekly.yml index 6b2f124fa..4746a3676 100644 --- a/.github/workflows/build_wheels_weekly.yml +++ b/.github/workflows/build_wheels_weekly.yml @@ -7,7 +7,7 @@ on: push: branches: - "master" - + jobs: build_wheels: name: ${{ matrix.os }} @@ -40,7 +40,6 @@ jobs: - name: Build wheels env: CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp*musl* cp36* *i686" # remove pypy on mac and win (wrong version) - CIBW_BEFORE_BUILD: "pip install numpy cython" CIBW_ARCHS_LINUX: auto aarch64 # force aarch64 with QEMU CIBW_ARCHS_MACOS: x86_64 universal2 arm64 run: | diff --git a/Makefile b/Makefile index 7a5cbe1be..bef487a52 100644 --- a/Makefile +++ b/Makefile @@ -85,7 +85,7 @@ aautopep8 : autopep8 -air test ot examples --jobs -1 wheels : - CIBW_BEFORE_BUILD="pip install numpy cython" cibuildwheel --platform linux --output-dir dist + cibuildwheel --platform linux --output-dir dist dist : wheels $(PYTHON) setup.py sdist From 246de6bc97a8fb6bb884ea5ebbc0a9e409db73d4 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Thu, 20 Jun 2024 01:53:06 -0500 Subject: [PATCH 18/30] [BLD] Update requires-python metadata to Python 3.7 (#630) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [BLD] Update requires-python metadata to Python 3.7 * Update requires-python metadata through setuptools's python_requires to reflect that only Python 3.7+ is distributed on PyPI and only Python 3.8+ is tested in CI and so Python 3.6 is not supported. 
- c.f. https://peps.python.org/pep-0621/#requires-python - The use of requires-python is to provide guards to keep older CPython versions from installing releases that could contain unrunnable code. * Python 3.6 is also EOL ┌───────┬────────────┬─────────┬────────────────┬────────────┬────────────┐ │ cycle │ release │ latest │ latest release │ support │ eol │ ├───────┼────────────┼─────────┼────────────────┼────────────┼────────────┤ │ 3.12 │ 2023-10-02 │ 3.12.4 │ 2024-06-06 │ 2025-04-02 │ 2028-10-31 │ │ 3.11 │ 2022-10-24 │ 3.11.9 │ 2024-04-02 │ 2024-04-01 │ 2027-10-31 │ │ 3.10 │ 2021-10-04 │ 3.10.14 │ 2024-03-19 │ 2023-04-05 │ 2026-10-31 │ │ 3.9 │ 2020-10-05 │ 3.9.19 │ 2024-03-19 │ 2022-05-17 │ 2025-10-31 │ │ 3.8 │ 2019-10-14 │ 3.8.19 │ 2024-03-19 │ 2021-05-03 │ 2024-10-31 │ │ 3.7 │ 2018-06-26 │ 3.7.17 │ 2023-06-05 │ 2020-06-27 │ 2023-06-27 │ │ 3.6 │ 2016-12-22 │ 3.6.15 │ 2021-09-03 │ 2018-12-24 │ 2021-12-23 │ │ 3.5 │ 2015-09-12 │ 3.5.10 │ 2020-09-05 │ False │ 2020-09-30 │ │ 3.4 │ 2014-03-15 │ 3.4.10 │ 2019-03-18 │ False │ 2019-03-18 │ │ 3.3 │ 2012-09-29 │ 3.3.7 │ 2017-09-19 │ False │ 2017-09-29 │ │ 3.2 │ 2011-02-20 │ 3.2.6 │ 2014-10-12 │ False │ 2016-02-20 │ │ 3.1 │ 2009-06-26 │ 3.1.5 │ 2012-04-06 │ False │ 2012-04-09 │ │ 3.0 │ 2008-12-03 │ 3.0.1 │ 2009-02-12 │ False │ 2009-06-27 │ │ 2.7 │ 2010-07-03 │ 2.7.18 │ 2020-04-19 │ False │ 2020-01-01 │ │ 2.6 │ 2008-10-01 │ 2.6.9 │ 2013-10-29 │ False │ 2013-10-29 │ └───────┴────────────┴─────────┴────────────────┴────────────┴────────────┘ * [DOC] Add Python 3.11 and Python 3.12 PyPI classifier metadata. * Add Python 3.11 and Python 3.12 as PyPI trove classifiers. 
--- setup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 1d0b6fceb..313d25863 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ scripts=[], data_files=[], install_requires=["numpy>=1.16", "scipy>=1.6"], - python_requires=">=3.6", + python_requires=">=3.7", classifiers=[ 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', @@ -92,10 +92,11 @@ 'Topic :: Scientific/Engineering :: Mathematics', 'Topic :: Scientific/Engineering :: Information Analysis', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', ] ) From 1f7d5476c48e04de32b6ec2a4ac94a72ba0ac2de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Thu, 20 Jun 2024 15:07:41 +0200 Subject: [PATCH 19/30] [MRG] correct bugs with gw barycenter on 1 input (#628) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * correct bugs with gw barycenter on 1 input * merge --------- Co-authored-by: Rémi Flamary --- RELEASES.md | 1 + ot/gromov/_bregman.py | 19 ++++++--- ot/gromov/_gw.py | 20 +++++++--- test/gromov/test_bregman.py | 75 ++++++++++++++++++++++++++++++++++++ test/gromov/test_gw.py | 77 ++++++++++++++++++++++++++++++++++++- 5 files changed, 180 insertions(+), 12 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 3908d079c..e6e6ff4d4 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -18,6 +18,7 @@ - Fix a sign error regarding the gradient of `ot.gromov._gw.fused_gromov_wasserstein2` and `ot.gromov._gw.gromov_wasserstein2` for the kl loss (PR #610) - Fix same sign error for sr(F)GW conditional gradient solvers (PR #611) - Split `test/test_gromov.py` into `test/gromov/` (PR #619) +- Fix (F)GW 
barycenter functions to support computing barycenter on 1 input + deprecate structures as lists (PR #628) ## 0.9.3 *January 2024* diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index df4ba0ae3..6bb7a675a 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -735,10 +735,15 @@ def entropic_gromov_barycenters( if stop_criterion not in ['barycenter', 'loss']: raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.") - Cs = list_to_array(*Cs) + if isinstance(Cs[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: structures Cs[i] are lists and should be arrays from a supported backend (e.g numpy).") + arr = [*Cs] if ps is not None: - arr += list_to_array(*ps) + if isinstance(ps[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy).") + + arr += [*ps] else: ps = [unif(C.shape[0], type_as=C) for C in Cs] if p is not None: @@ -1620,11 +1625,15 @@ def entropic_fused_gromov_barycenters( if stop_criterion not in ['barycenter', 'loss']: raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. 
Use one of: {'barycenter', 'loss'}.") - Cs = list_to_array(*Cs) - Ys = list_to_array(*Ys) + if isinstance(Cs[0], list) or isinstance(Ys[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: structures Cs[i] and/or features Ys[i] are lists and should be arrays from a supported backend (e.g numpy).") + arr = [*Cs, *Ys] if ps is not None: - arr += list_to_array(*ps) + if isinstance(ps[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy).") + + arr += [*ps] else: ps = [unif(C.shape[0], type_as=C) for C in Cs] if p is not None: diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 46e1ddfe8..9dbc6b19e 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -808,13 +808,19 @@ def gromov_barycenters( if stop_criterion not in ['barycenter', 'loss']: raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.") - Cs = list_to_array(*Cs) + if isinstance(Cs[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: structures Cs[i] are lists and should be arrays from a supported backend (e.g numpy).") + arr = [*Cs] if ps is not None: - arr += list_to_array(*ps) + if isinstance(ps[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy).") + + arr += [*ps] else: ps = [unif(C.shape[0], type_as=C) for C in Cs] if p is not None: + arr.append(list_to_array(p)) else: p = unif(N, type_as=Cs[0]) @@ -1014,11 +1020,15 @@ def fgw_barycenters( if stop_criterion not in ['barycenter', 'loss']: raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. 
Use one of: {'barycenter', 'loss'}.") - Cs = list_to_array(*Cs) - Ys = list_to_array(*Ys) + if isinstance(Cs[0], list) or isinstance(Ys[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: structures Cs[i] and/or features Ys[i] are lists and should be arrays from a supported backend (e.g numpy).") + arr = [*Cs, *Ys] if ps is not None: - arr += list_to_array(*ps) + if isinstance(ps[0], list): + raise ValueError("Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy).") + + arr += [*ps] else: ps = [unif(C.shape[0], type_as=C) for C in Cs] if p is not None: diff --git a/test/gromov/test_bregman.py b/test/gromov/test_bregman.py index 4baf3ce10..71e55b1ce 100644 --- a/test/gromov/test_bregman.py +++ b/test/gromov/test_bregman.py @@ -792,6 +792,46 @@ def test_entropic_fgw_barycenter(nx): np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(Xb, init_Yb) + # test edge cases for fgw barycenters: + # C1 as list + with pytest.raises(ValueError): + C1_list = [list(c) for c in C1b] + _, _, _ = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb], [C1_list], [p1b], lambdas=None, + fixed_structure=False, fixed_features=False, + init_Y=None, p=pb, max_iter=10, tol=1e-3, + warmstartT=True, log=True, random_state=98765, verbose=True + ) + + # p1, p2 as lists + with pytest.raises(ValueError): + p1_list = list(p1b) + p2_list = list(p2b) + _, _, _ = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1_list, p2_list], lambdas=[0.5, 0.5], + fixed_structure=False, fixed_features=False, + init_Y=None, p=pb, max_iter=10, tol=1e-3, + warmstartT=True, log=True, random_state=98765, verbose=True + ) + + # unique input structure + X, C = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ys], [C1], [p1], lambdas=None, + fixed_structure=False, fixed_features=False, + init_Y=init_Y, p=p, max_iter=10, tol=1e-3, + warmstartT=True, log=False, 
random_state=98765, verbose=True + ) + + Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb], [C1b], [p1b], lambdas=None, + fixed_structure=False, fixed_features=False, + init_Y=init_Yb, p=pb, max_iter=10, tol=1e-3, + warmstartT=True, log=False, random_state=98765, verbose=True + ) + + np.testing.assert_allclose(C, Cb, atol=1e-06) + np.testing.assert_allclose(X, Xb, atol=1e-06) + @pytest.mark.filterwarnings("ignore:divide") def test_gromov_entropic_barycenter(nx): @@ -886,6 +926,41 @@ def test_gromov_entropic_barycenter(nx): np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) + # test edge cases for gw barycenters: + # C1 as list + with pytest.raises(ValueError): + C1_list = [list(c) for c in C1b] + _, _ = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1_list], [p1b], pb, None, 'square_loss', 1e-3, + max_iter=10, tol=1e-3, warmstartT=True, verbose=True, + random_state=42, init_C=None, log=True + ) + + # p1, p2 as lists + with pytest.raises(ValueError): + p1_list = list(p1b) + p2_list = list(p2b) + _, _ = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1_list, p2_list], pb, None, + 'kl_loss', 1e-3, max_iter=10, tol=1e-3, warmstartT=True, + verbose=True, random_state=42, init_Cb=None, log=True + ) + + # unique input structure + Cb = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1], [p1], p, None, 'square_loss', 1e-3, + max_iter=10, tol=1e-3, warmstartT=True, verbose=True, random_state=42, + init_C=None, log=False) + + Cbb = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b], [p1b], pb, [1.], 'square_loss', 1e-3, + max_iter=10, tol=1e-3, warmstartT=True, verbose=True, + random_state=42, init_Cb=None, log=False + ) + + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + def test_not_implemented_solver(): # test sinkhorn diff --git a/test/gromov/test_gw.py 
b/test/gromov/test_gw.py index 0008cebce..e76a33dcf 100644 --- a/test/gromov/test_gw.py +++ b/test/gromov/test_gw.py @@ -429,6 +429,40 @@ def test_gromov_barycenter(nx): np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) + # test edge cases for gw barycenters: + # C1 as list + with pytest.raises(ValueError): + C1_list = [list(c) for c in C1] + _ = ot.gromov.gromov_barycenters( + n_samples, [C1_list], None, p, None, 'square_loss', max_iter=10, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42 + ) + + # p1, p2 as lists + with pytest.raises(ValueError): + p1_list = list(p1) + p2_list = list(p2) + _ = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], [p1_list, p2_list], p, None, 'square_loss', max_iter=10, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42 + ) + + # unique input structure + Cb = ot.gromov.gromov_barycenters( + n_samples, [C1], None, p, None, 'square_loss', max_iter=10, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42 + ) + Cbb = nx.to_numpy(ot.gromov.gromov_barycenters( + n_samples, [C1b], None, None, [1.], 'square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, random_state=42 + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + def test_fgw(nx): n_samples = 20 # nb samples @@ -815,7 +849,7 @@ def test_fgw_barycenter(nx): X, C, log = ot.gromov.fgw_barycenters( n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', - max_iter=100, tol=1e-3, stop_criterion=stop_criterion, init_C=C, + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=C, init_X=X, warmstartT=True, random_state=12345, log=True ) @@ -823,7 +857,7 @@ def test_fgw_barycenter(nx): X, C, log = ot.gromov.fgw_barycenters( n_samples, [ys, yt], [C1, 
C2], [p1, p2], [.5, .5], 0.5, fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', - max_iter=100, tol=1e-3, stop_criterion=stop_criterion, init_C=C, + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=C, init_X=X, warmstartT=True, random_state=12345, log=True, verbose=True ) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) @@ -832,3 +866,42 @@ def test_fgw_barycenter(nx): # test correspondance with utils function recovered_C = ot.gromov.update_kl_loss(p, lambdas, log['T'], [C1, C2]) np.testing.assert_allclose(C, recovered_C) + + # test edge cases for fgw barycenters: + # C1 as list + with pytest.raises(ValueError): + C1b_list = [list(c) for c in C1b] + _, _, _ = ot.gromov.fgw_barycenters( + n_samples, [ysb], [C1b_list], [p1b], None, 0.5, + fixed_structure=False, fixed_features=False, p=pb, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=Cb, + init_X=Xb, warmstartT=True, random_state=12345, log=True, verbose=True + ) + + # p1, p2 as lists + with pytest.raises(ValueError): + p1_list = list(p1) + p2_list = list(p2) + _, _, _ = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1_list, p2_list], None, 0.5, + fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=Cb, + init_X=Xb, warmstartT=True, random_state=12345, log=True, verbose=True + ) + + # unique input structure + X, C = ot.gromov.fgw_barycenters( + n_samples, [ys], [C1], [p1], None, 0.5, + fixed_structure=False, fixed_features=False, p=p, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + warmstartT=True, random_state=12345, log=False, verbose=False + ) + Xb, Cb = ot.gromov.fgw_barycenters( + n_samples, [ysb], [C1b], [p1b], [1.], 0.5, + fixed_structure=False, fixed_features=False, p=pb, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + warmstartT=True, random_state=12345, 
log=False, verbose=False + ) + + np.testing.assert_allclose(C, Cb, atol=1e-06) + np.testing.assert_allclose(X, Xb, atol=1e-06) From 628a089da619a0af9f71f78faf762e344075a487 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Fri, 21 Jun 2024 03:25:24 -0400 Subject: [PATCH 20/30] [CI] Add x86 based macOS to testing (#636) * Use the 'macos-latest' runs-on option which now defaults to 'macos-14' which are Apple silicon runners. - c.f. https://github.com/actions/runner-images/blob/d31241d2c8c646cd0ba18579836636e62c38f1c9/images/macos/macos-14-arm64-Readme.md * Keep a 'macos-13' runner to continue to test x86 based macOS for the latest Python version. - c.f. https://github.com/actions/runner-images/blob/d31241d2c8c646cd0ba18579836636e62c38f1c9/images/macos/macos-13-Readme.md --- .github/workflows/build_tests.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 9bdd337c0..4805d02d0 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -88,11 +88,12 @@ jobs: uses: codecov/codecov-action@v3 macos: - runs-on: macos-latest + runs-on: ${{ matrix.os }} if: "!contains(github.event.head_commit.message, 'no ci')" strategy: max-parallel: 4 matrix: + os: [macos-latest, macos-13] python-version: ["3.11"] steps: From 14c08ba4859650146e8b621560c43ff777daa895 Mon Sep 17 00:00:00 2001 From: Yikun Bai <31041246+yikun-baio@users.noreply.github.com> Date: Fri, 21 Jun 2024 04:55:26 -0500 Subject: [PATCH 21/30] [MRG] Fix Gradient scaling in Partial GW solver (#602) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * new file: ot/partial_gw.py * remove partial_gw.py to update existing file partial.py * fix pep8 --------- Co-authored-by: Rémi Flamary Co-authored-by: Cédric Vincent-Cuaz --- RELEASES.md | 1 + ot/partial.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/RELEASES.md 
b/RELEASES.md index e6e6ff4d4..77863b640 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -19,6 +19,7 @@ - Fix same sign error for sr(F)GW conditional gradient solvers (PR #611) - Split `test/test_gromov.py` into `test/gromov/` (PR #619) - Fix (F)GW barycenter functions to support computing barycenter on 1 input + deprecate structures as lists (PR #628) +- Fix line-search in partial GW and change default init to the interior of partial transport plans (PR #602) ## 0.9.3 *January 2024* diff --git a/ot/partial.py b/ot/partial.py index 85635c9ba..a3b25a856 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -4,12 +4,15 @@ """ # Author: Laetitia Chapel -# License: MIT License +# Yikun Bai < yikun.bai@vanderbilt.edu > +# Cédric Vincent-Cuaz -import numpy as np -from .lp import emd -from .backend import get_backend from .utils import list_to_array +from .backend import get_backend +from .lp import emd +import numpy as np + +# License: MIT License def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, @@ -581,7 +584,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, " equal than min(|a|_1, |b|_1).") if G0 is None: - G0 = np.outer(p, q) + G0 = np.outer(p, q) * m / (np.sum(p) * np.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. 
dim_G_extended = (len(p) + nb_dummies, len(q) + nb_dummies) q_extended = np.append(q, [(np.sum(p) - m) / nb_dummies] * nb_dummies) @@ -597,7 +600,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, Gprev = np.copy(G0) - M = gwgrad_partial(C1, C2, G0) + M = 0.5 * gwgrad_partial(C1, C2, G0) # rescaling the gradient with 0.5 for line-search while not changing Gc M_emd = np.zeros(dim_G_extended) M_emd[:len(p), :len(q)] = M M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2 From aa4f370c603c27778e8a2edd38ed3591e0ed4396 Mon Sep 17 00:00:00 2001 From: Sarah Gammon <91751417+SarahG-579462@users.noreply.github.com> Date: Fri, 21 Jun 2024 11:57:51 -0400 Subject: [PATCH 22/30] Add optional install options with pip (#627) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add extra optional dependency options * check for dependencies, more informative errors: da, dr * update RELEASES.md * update README.md * Update setup.py Co-authored-by: Rémi Flamary * change filename to requirements_all in all files * woops some of those were for docs --------- Co-authored-by: Rémi Flamary --- .circleci/config.yml | 2 +- .github/workflows/build_doc.yml | 2 +- .github/workflows/build_tests.yml | 4 ++-- README.md | 6 ++++++ RELEASES.md | 1 + ot/da.py | 6 ++++-- ot/dr.py | 15 +++++++++------ requirements.txt => requirements_all.txt | 0 setup.py | 14 ++++++++++++++ 9 files changed, 38 insertions(+), 12 deletions(-) rename requirements.txt => requirements_all.txt (100%) diff --git a/.circleci/config.yml b/.circleci/config.yml index 00cd876cf..9a835953b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -46,7 +46,7 @@ jobs: command: | python -m pip install --user --upgrade --progress-bar off pip python -m pip install --user -e . 
- python -m pip install --user --upgrade --no-cache-dir --progress-bar off -r requirements.txt + python -m pip install --user --upgrade --no-cache-dir --progress-bar off -r requirements_all.txt python -m pip install --user --upgrade --progress-bar off -r docs/requirements.txt python -m pip install --user --upgrade --progress-bar off ipython sphinx-gallery memory_profiler # python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler diff --git a/.github/workflows/build_doc.yml b/.github/workflows/build_doc.yml index 3af2d301f..eed5643c5 100644 --- a/.github/workflows/build_doc.yml +++ b/.github/workflows/build_doc.yml @@ -24,7 +24,7 @@ jobs: - name: Get Python running run: | python -m pip install --user --upgrade --progress-bar off pip - python -m pip install --user --upgrade --progress-bar off -r requirements.txt + python -m pip install --user --upgrade --progress-bar off -r requirements_all.txt python -m pip install --user --upgrade --progress-bar off -r docs/requirements.txt python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler python -m pip install --user -e . 
diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 4805d02d0..b42dd2b0f 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -36,7 +36,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt + pip install -r requirements_all.txt pip install pytest pytest-cov - name: Run tests run: | @@ -108,7 +108,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt + pip install -r requirements_all.txt pip install pytest - name: Run tests run: | diff --git a/README.md b/README.md index 1cd9fb59b..7f2ce3ee3 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,12 @@ or get the very latest version by running: pip install -U https://github.com/PythonOT/POT/archive/master.zip # with --user for user install (no root) ``` +Optional dependencies may be installed with +```console +pip install POT[all] +``` +Note that this installs `cvxopt`, which is licensed under GPL 3.0. Alternatively, if you cannot use GPL-licensed software, the specific optional dependencies may be installed individually, or per-submodule. The available optional installations are `backend-jax, backend-tf, backend-torch, cvxopt, dr, gnn, all`. + #### Anaconda installation with conda-forge If you use the Anaconda python distribution, POT is available in [conda-forge](https://conda-forge.org). To install it and the required dependencies: diff --git a/RELEASES.md b/RELEASES.md index 77863b640..ebfa07a06 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,6 +10,7 @@ + New general unbalanced solvers for `ot.solve` and BFGS solver and illustrative example (PR #620) + Add gradient computation with envelope theorem to sinkhorn solver of `ot.solve` with `grad='envelope'` (PR #605). 
+ Added support for [Low rank Gromov-Wasserstein](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf) with `ot.gromov.lowrank_gromov_wasserstein_samples` (PR #614) ++ Optional dependencies may now be installed with `pip install POT[all]` The specific backends or submodules' dependencies may also be installed individually. The pip options are: `backend-jax, backend-tf, backend-torch, cvxopt, dr, gnn, all`. The installation of the `cupy` backend should be done with conda. #### Closed issues - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596) diff --git a/ot/da.py b/ot/da.py index 2b28260ef..d6c55b6c2 100644 --- a/ot/da.py +++ b/ot/da.py @@ -376,8 +376,10 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al elif sim == 'knn': if sim_param is None: sim_param = 3 - - from sklearn.neighbors import kneighbors_graph + try: + from sklearn.neighbors import kneighbors_graph + except ImportError: + raise ValueError('scikit-learn must be installed to use knn similarity. Install with `$pip install scikit-learn`.') sS = nx.from_numpy(kneighbors_graph( X=nx.to_numpy(xs), n_neighbors=int(sim_param) diff --git a/ot/dr.py b/ot/dr.py index c56170209..e410ff837 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -17,12 +17,15 @@ # License: MIT License from scipy import linalg -import autograd.numpy as np -from sklearn.decomposition import PCA - -import pymanopt -import pymanopt.manifolds -import pymanopt.optimizers +try: + import autograd.numpy as np + from sklearn.decomposition import PCA + + import pymanopt + import pymanopt.manifolds + import pymanopt.optimizers +except ImportError: + raise ImportError("Missing dependency for ot.dr. Requires autograd, pymanopt, scikit-learn. 
You can install with install with 'pip install POT[dr]', or 'conda install autograd pymanopt scikit-learn'") from .bregman import sinkhorn as sinkhorn_bregman from .utils import dist as dist_utils, check_random_state diff --git a/requirements.txt b/requirements_all.txt similarity index 100% rename from requirements.txt rename to requirements_all.txt diff --git a/setup.py b/setup.py index 313d25863..dea4da670 100644 --- a/setup.py +++ b/setup.py @@ -46,6 +46,9 @@ sdk_path = subprocess.check_output(['xcrun', '--show-sdk-path']) os.environ['CFLAGS'] = '-isysroot "{}"'.format(sdk_path.rstrip().decode("utf-8")) +with open('requirements_all.txt') as f: + optional_requirements = f.read().splitlines() + setup( name='POT', version=__version__, @@ -70,6 +73,17 @@ scripts=[], data_files=[], install_requires=["numpy>=1.16", "scipy>=1.6"], + extras_require={ + 'backend-numpy': [], # in requirements. + 'backend-jax': ['jax<=0.4.24', 'jaxlib<=0.4.24'], + 'backend-cupy': [], # should be installed with conda, not pip, or figure out what CUDA version above. 
+ 'backend-tf': ['tensorflow'], + 'backend-torch': ['torch'], + 'cvxopt': ['cvxopt'], # on it's own to prevent accidental GPL violations + 'dr': ['scikit-learn', 'pymanopt', 'autograd'], + 'gnn': ['torch', 'torch_geometric'], + 'all': optional_requirements + }, python_requires=">=3.7", classifiers=[ 'Development Status :: 5 - Production/Stable', From f8f298d0b826a0532d65e716ef040719d62f0caa Mon Sep 17 00:00:00 2001 From: Simon Forbat <54510453+simon-forb@users.noreply.github.com> Date: Sat, 22 Jun 2024 20:56:24 +0200 Subject: [PATCH 23/30] Update _gw.py (#637) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit doc typos Co-authored-by: Cédric Vincent-Cuaz --- ot/gromov/_gw.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 9dbc6b19e..86ff566ea 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -43,11 +43,11 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric Where : - - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{p}`: distribution in the source space - - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity matrices + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space. + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space. + - :math:`\mathbf{p}`: Distribution in the source space. + - :math:`\mathbf{q}`: Distribution in the target space. + - `L`: Loss function to account for the misfit between the similarity matrices. .. note:: This function is backend-compatible and will work on arrays from all compatible backends. 
But the algorithm uses the C++ CPU backend @@ -62,9 +62,9 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric Parameters ---------- C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space + Metric cost matrix in the source space. C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space + Metric cost matrix in the target space. p : array-like, shape (ns,), optional Distribution in the source space. If let to its default value None, uniform distribution is taken. @@ -72,29 +72,29 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric Distribution in the target space. If let to its default value None, uniform distribution is taken. loss_fun : str, optional - loss function used for the solver either 'square_loss' or 'kl_loss' + Loss function used for the solver either 'square_loss' or 'kl_loss'. symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). verbose : bool, optional - Print information along iterations + Print information along iterations. log : bool, optional - record log if True + Record log if True. armijo : bool, optional - If True the step of the line-search is found via an armijo research. Else closed form is used. - If there are convergence issues use False. + If True, the step of the line-search is found via an armijo search. Else closed form is used. + If there are convergence issues, use False. G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. + If None, the initial transport plan of the solver is pq^T. Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. max_iter : int, optional - Max number of iterations + Max number of iterations. 
tol_rel : float, optional - Stop threshold on relative error (>0) + Stop threshold on relative error (>0). tol_abs : float, optional - Stop threshold on absolute error (>0) + Stop threshold on absolute error (>0). **kwargs : dict - parameters can be directly passed to the ot.optim.cg solver + Parameters can be directly passed to the ot.optim.cg solver. Returns ------- @@ -175,7 +175,7 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): if not nx.is_floating_point(C10): warnings.warn( - "Input structure matrix consists of integer. The transport plan will be " + "Input structure matrix consists of integers. The transport plan will be " "casted accordingly, possibly resulting in a loss of precision. " "If this behaviour is unwanted, please make sure your input " "structure matrix consists of floating point elements.", From a8f0ed5b2b883472f9770f63189f11d03d06a346 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Mon, 24 Jun 2024 09:17:08 -0400 Subject: [PATCH 24/30] [MAINT] Deprecate distutils in favor of setuptools (#635) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [MAINT] Deprecate distutils in favor of setuptools * In CPython 3.10 distutils was formally marked as deprecated and was removed in Python 3.12 (https://peps.python.org/pep-0632/). Switch to using setuptools as a replacement, which offers a direct port of 'distutils.errors' as 'setuptools.errors'. * [CI] Ensure setuptools installed for tests * To ensure that setuptools is available for setuptools.errors install it for all the tests in the CI. - setuptools.errors was added in v60.0.0 - c.f. 
https://setuptools.pypa.io/en/latest/deprecated/distutils-legacy.html --------- Co-authored-by: Rémi Flamary --- .github/workflows/build_tests.yml | 8 ++++---- .github/workflows/build_tests_cuda.yml | 3 ++- .github/workflows/build_wheels.yml | 4 ++-- .github/workflows/build_wheels_weekly.yml | 2 +- ot/helpers/openmp_helpers.py | 2 +- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index b42dd2b0f..cd45633de 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -35,7 +35,7 @@ jobs: pip install -e . - name: Install dependencies run: | - python -m pip install --upgrade pip + python -m pip install --upgrade pip setuptools pip install -r requirements_all.txt pip install pytest pytest-cov - name: Run tests @@ -55,7 +55,7 @@ jobs: python-version: "3.10" - name: Install dependencies run: | - python -m pip install --upgrade pip + python -m pip install --upgrade pip setuptools pip install flake8 - name: Lint with flake8 run: | @@ -76,7 +76,7 @@ jobs: python-version: "3.10" - name: Install dependencies run: | - python -m pip install --upgrade pip + python -m pip install --upgrade pip setuptools pip install pytest pytest-cov - name: Install POT run: | @@ -107,7 +107,7 @@ jobs: pip install -e . - name: Install dependencies run: | - python -m pip install --upgrade pip + python -m pip install --upgrade pip setuptools pip install -r requirements_all.txt pip install pytest - name: Run tests diff --git a/.github/workflows/build_tests_cuda.yml b/.github/workflows/build_tests_cuda.yml index be8e47c8b..2d4e452c5 100644 --- a/.github/workflows/build_tests_cuda.yml +++ b/.github/workflows/build_tests_cuda.yml @@ -15,7 +15,8 @@ jobs: - uses: actions/checkout@v4 - name: Install POT run: | - python3.10 -m pip install --ignore-installed -e . + python3.10 -m pip install --upgrade pip setuptools + python3.10 -m pip install --ignore-installed -e . 
- name: Run tests run: | python3.10 -m pytest --durations=20 -v test/ ot/ --doctest-modules --color=yes --ignore=test/test_dr.py --ignore=ot.dr --ignore=ot.plot diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index c60babff6..ca3824975 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -26,7 +26,7 @@ jobs: - name: Install dependencies run: | - python -m pip install --upgrade pip + python -m pip install --upgrade pip setuptools - name: Install cibuildwheel run: | @@ -61,7 +61,7 @@ jobs: - name: Install dependencies run: | - python -m pip install --upgrade pip + python -m pip install --upgrade pip setuptools - name: Install cibuildwheel run: | diff --git a/.github/workflows/build_wheels_weekly.yml b/.github/workflows/build_wheels_weekly.yml index 4746a3676..f528c2dce 100644 --- a/.github/workflows/build_wheels_weekly.yml +++ b/.github/workflows/build_wheels_weekly.yml @@ -25,7 +25,7 @@ jobs: - name: Install dependencies run: | - python -m pip install --upgrade pip + python -m pip install --upgrade pip setuptools - name: Install cibuildwheel run: | diff --git a/ot/helpers/openmp_helpers.py b/ot/helpers/openmp_helpers.py index a6ad38b06..90a2918da 100644 --- a/ot/helpers/openmp_helpers.py +++ b/ot/helpers/openmp_helpers.py @@ -9,7 +9,7 @@ import textwrap import subprocess -from distutils.errors import CompileError, LinkError +from setuptools.errors import CompileError, LinkError from pre_build_helpers import compile_test_program From 2941ed376a926aeec7ad19fc9161a649a94f0e87 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Mon, 24 Jun 2024 17:15:30 +0200 Subject: [PATCH 25/30] [DA] Sinkhorn LpL1 transport to work on JAX (#592) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Draft implementation for per-class regularization in lpl1 * Do not use assignment to replace non finite elements * Make vectorize version of lpl1 work * Proper lpl1 
vectorization * Remove type error test for JAX (should work now) * Update test, coupling still has nans * Explicitly check for nans in the coupling return from sinkhorn * fix small comments --------- Co-authored-by: Rémi Flamary Co-authored-by: Cédric Vincent-Cuaz --- RELEASES.md | 1 + ot/da.py | 33 ++++++++++++-------------- test/test_da.py | 62 +++++++++++++++++++++++++++++++++++++++++-------- 3 files changed, 68 insertions(+), 28 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index ebfa07a06..55f5b0b17 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -21,6 +21,7 @@ - Split `test/test_gromov.py` into `test/gromov/` (PR #619) - Fix (F)GW barycenter functions to support computing barycenter on 1 input + deprecate structures as lists (PR #628) - Fix line-search in partial GW and change default init to the interior of partial transport plans (PR #602) +- Fix `ot.da.sinkhorn_lpl1_mm` compatibility with JAX (PR #592) ## 0.9.3 *January 2024* diff --git a/ot/da.py b/ot/da.py index d6c55b6c2..b51b08b3a 100644 --- a/ot/da.py +++ b/ot/da.py @@ -122,14 +122,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, p = 0.5 epsilon = 1e-3 - indices_labels = [] - classes = nx.unique(labels_a) - for c in classes: - idxc, = nx.where(labels_a == c) - indices_labels.append(idxc) + labels_u, labels_idx = nx.unique(labels_a, return_inverse=True) + n_labels = labels_u.shape[0] + unroll_labels_idx = nx.eye(n_labels, type_as=M)[labels_idx] W = nx.zeros(M.shape, type_as=M) - for cpt in range(numItermax): + for _ in range(numItermax): Mreg = M + eta * W if log: transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, @@ -137,13 +135,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, else: transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, stopThr=stopInnerThr) - # the transport has been computed. 
Check if classes are really - # separated - W = nx.ones(M.shape, type_as=M) - for (i, c) in enumerate(classes): - majs = nx.sum(transp[indices_labels[i]], axis=0) - majs = p * ((majs + epsilon) ** (p - 1)) - W[indices_labels[i]] = majs + # the transport has been computed + # check if classes are really separated + W = nx.repeat(transp.T[:, :, None], n_labels, axis=2) * unroll_labels_idx[None, :, :] + W = nx.sum(W, axis=1) + W = nx.dot(W, unroll_labels_idx.T) + W = p * ((W.T + epsilon) ** (p - 1)) if log: return transp, log @@ -1925,7 +1922,7 @@ def transform(self, Xs): transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None] # set nans to 0 - transp[~ nx.isfinite(transp)] = 0 + transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) # compute transported samples transp_Xs = nx.dot(transp, self.xt_) @@ -2214,7 +2211,7 @@ class label transp = coupling / nx.sum(coupling, 1)[:, None] # set nans to 0 - transp[~ nx.isfinite(transp)] = 0 + transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) # compute transported samples transp_Xs.append(nx.dot(transp, self.xt_)) @@ -2238,7 +2235,7 @@ class label # transport the source samples for coupling in self.coupling_: transp = coupling / nx.sum(coupling, 1)[:, None] - transp[~ nx.isfinite(transp)] = 0 + transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) transp_Xs_.append(nx.dot(transp, self.xt_)) transp_Xs_ = nx.concatenate(transp_Xs_, axis=0) @@ -2291,7 +2288,7 @@ def transform_labels(self, ys=None): transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None] # set nans to 0 - transp[~ nx.isfinite(transp)] = 0 + transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) if self.log: D1 = self.log_['D1'][i] @@ -2339,7 +2336,7 @@ def inverse_transform_labels(self, yt=None): transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None] # set nans to 0 - transp[~ nx.isfinite(transp)] = 0 + transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) # compute propagated labels transp_ys.append(nx.dot(D1, 
transp.T).T) diff --git a/test/test_da.py b/test/test_da.py index 0e51bda22..d3c343242 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -28,10 +28,9 @@ def test_class_jax_tf(): + from ot.backend import tf + backends = [] - from ot.backend import jax, tf - if jax: - backends.append(ot.backend.JaxBackend()) if tf: backends.append(ot.backend.TensorflowBackend()) @@ -70,7 +69,6 @@ def test_log_da(nx, class_to_test): assert hasattr(otda, "log_") -@pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_sinkhorn_lpl1_transport_class(nx): """test_sinkhorn_transport @@ -79,10 +77,13 @@ def test_sinkhorn_lpl1_transport_class(nx): ns = 50 nt = 50 - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys = make_data_classif('3gauss', ns, random_state=42) + Xt, yt = make_data_classif('3gauss2', nt, random_state=43) + # prepare semi-supervised labels + yt_semi = np.copy(yt) + yt_semi[np.arange(0, nt, 2)] = -1 - Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + Xs, ys, Xt, yt, yt_semi = nx.from_numpy(Xs, ys, Xt, yt, yt_semi) otda = ot.da.SinkhornLpl1Transport() @@ -109,7 +110,7 @@ def test_sinkhorn_lpl1_transport_class(nx): transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) - Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1, random_state=44)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -119,7 +120,7 @@ def test_sinkhorn_lpl1_transport_class(nx): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) - Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1, random_state=45)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -142,10 +143,12 @@ def test_sinkhorn_lpl1_transport_class(nx): # test unsupervised vs semi-supervised mode otda_unsup = 
ot.da.SinkhornLpl1Transport() otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt) + assert np.all(np.isfinite(nx.to_numpy(otda_unsup.coupling_))), "unsup coupling is finite" n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.SinkhornLpl1Transport() - otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt_semi) + assert np.all(np.isfinite(nx.to_numpy(otda_semi.coupling_))), "semi coupling is finite" assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) n_semisup = nx.sum(otda_semi.cost_) @@ -944,3 +947,42 @@ def df2(G): assert np.allclose(f(G), f2(G)) assert np.allclose(df(G), df2(G)) + + +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_sinkhorn_lpl1_vectorization(nx): + n_samples, n_labels = 150, 3 + rng = np.random.RandomState(42) + M = rng.rand(n_samples, n_samples) + labels_a = rng.randint(n_labels, size=(n_samples,)) + M, labels_a = nx.from_numpy(M), nx.from_numpy(labels_a) + + # hard-coded params from the original code + p, epsilon = 0.5, 1e-3 + T = nx.from_numpy(rng.rand(n_samples, n_samples)) + + def unvectorized(transp): + indices_labels = [] + classes = nx.unique(labels_a) + for c in classes: + idxc, = nx.where(labels_a == c) + indices_labels.append(idxc) + W = nx.ones(M.shape, type_as=M) + for (i, c) in enumerate(classes): + majs = nx.sum(transp[indices_labels[i]], axis=0) + majs = p * ((majs + epsilon) ** (p - 1)) + W[indices_labels[i]] = majs + return W + + def vectorized(transp): + labels_u, labels_idx = nx.unique(labels_a, return_inverse=True) + n_labels = labels_u.shape[0] + unroll_labels_idx = nx.eye(n_labels, type_as=transp)[labels_idx] + W = nx.repeat(transp.T[:, :, None], n_labels, axis=2) * unroll_labels_idx[None, :, :] + W = nx.sum(W, axis=1) + W = p * ((W + epsilon) ** (p - 1)) + W = nx.dot(W, unroll_labels_idx.T) + return W.T + + assert np.allclose(unvectorized(T), vectorized(T)) From a19566013192bae54e91b4f6403629800258ed4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 
25 Jun 2024 10:30:11 +0200 Subject: [PATCH 26/30] [Fix] import linesearch for modern scipy (#642) * update import for scalar armijo * add release line in file --- RELEASES.md | 1 + ot/optim.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 55f5b0b17..56cb6fd4b 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -22,6 +22,7 @@ - Fix (F)GW barycenter functions to support computing barycenter on 1 input + deprecate structures as lists (PR #628) - Fix line-search in partial GW and change default init to the interior of partial transport plans (PR #602) - Fix `ot.da.sinkhorn_lpl1_mm` compatibility with JAX (PR #592) +- Fiw linesearch import error on Scipy 1.14 (PR #642, Issue #641) ## 0.9.3 *January 2024* diff --git a/ot/optim.py b/ot/optim.py index dcdef6a88..fe5eda821 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -19,7 +19,7 @@ try: from scipy.optimize import scalar_search_armijo except ImportError: - from scipy.optimize.linesearch import scalar_search_armijo + from scipy.optimize._linesearch import scalar_search_armijo # The corresponding scipy function does not work for matrices From 3ea24a36639b168bab045b899767b1a82ddbdce7 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Tue, 25 Jun 2024 07:35:20 -0500 Subject: [PATCH 27/30] [CI] Use Python 3.12 as default for tests (#632) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [CI] Use Python 3.12 as default for tests * Add Python 3.12 to the Linux test matrix. * Use Python 3.12 as the default Python for all other CI tests. 
* [CI] Remove jax upper bounds for testing * reset caps on jax * Update requirements_all.txt * Update setup.py * Update setup.py * Update requirements_all.txt * Update requirements_all.txt * Update requirements_all.txt --------- Co-authored-by: Cédric Vincent-Cuaz --- .github/workflows/build_tests.yml | 22 +++++++++++----------- requirements_all.txt | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index cd45633de..a265c79ec 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -4,10 +4,10 @@ on: workflow_dispatch: pull_request: branches: - - 'master' + - 'master' push: branches: - - 'master' + - 'master' create: branches: - 'master' @@ -22,7 +22,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 @@ -49,14 +49,14 @@ jobs: if: "!contains(github.event.head_commit.message, 'no pep8')" steps: - uses: actions/checkout@v4 - - name: Set up Python + - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: "3.x" - name: Install dependencies run: | python -m pip install --upgrade pip setuptools - pip install flake8 + pip install flake8 - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names @@ -70,10 +70,10 @@ jobs: if: "!contains(github.event.head_commit.message, 'no ci')" steps: - uses: actions/checkout@v4 - - name: Set up Python + - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: "3.12" - name: Install dependencies run: | python -m pip install --upgrade pip setuptools @@ -94,7 +94,7 @@ jobs: max-parallel: 4 matrix: os: [macos-latest, macos-13] - python-version: ["3.11"] + python-version: ["3.12"] steps: - uses: actions/checkout@v4 @@ -109,7 +109,7 @@ jobs: run: | python -m pip 
install --upgrade pip setuptools pip install -r requirements_all.txt - pip install pytest + pip install pytest - name: Run tests run: | python -m pytest --durations=20 -v test/ ot/ --color=yes @@ -121,7 +121,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.11"] + python-version: ["3.12"] steps: - uses: actions/checkout@v4 diff --git a/requirements_all.txt b/requirements_all.txt index bffaf892f..66a7c2dfc 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2,7 +2,7 @@ numpy>=1.20 scipy>=1.6 matplotlib autograd -pymanopt +pymanopt @ git+https://github.com/pymanopt/pymanopt.git@master cvxopt scikit-learn torch From 14bbebf1b6c5a3acef91e07247388945fd7a9014 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Tue, 25 Jun 2024 17:09:03 +0200 Subject: [PATCH 28/30] [WIP] Upgrade JAX support from 0.4.24 to 0.4.30 (#643) * fix jax device * fix jax device * update requirements * up * fix? * fix? * fix * complete fix * complete fix --- RELEASES.md | 1 + ot/backend.py | 16 ++++++++++++++-- requirements_all.txt | 4 ++-- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 56cb6fd4b..294614114 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -23,6 +23,7 @@ - Fix line-search in partial GW and change default init to the interior of partial transport plans (PR #602) - Fix `ot.da.sinkhorn_lpl1_mm` compatibility with JAX (PR #592) - Fiw linesearch import error on Scipy 1.14 (PR #642, Issue #641) +- Upgrade supported JAX versions from jax<=0.4.24 to jax<=0.4.30 (PR #643) ## 0.9.3 *January 2024* diff --git a/ot/backend.py b/ot/backend.py index 9cc6446bf..534c03293 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -120,6 +120,7 @@ import jax.scipy.special as jspecial from jax.lib import xla_bridge jax_type = jax.numpy.ndarray + jax_new_version = float('.'.join(jax.__version__.split('.')[1:])) > 4.24 except ImportError: jax = False jax_type = float @@ -1439,11 +1440,19 @@ def __init__(self): 
jax.device_put(jnp.array(1, dtype=jnp.float64), d) ] + self.jax_new_version = jax_new_version + def _to_numpy(self, a): return np.array(a) + def _get_device(self, a): + if self.jax_new_version: + return list(a.devices())[0] + else: + return a.device_buffer.device() + def _change_device(self, a, type_as): - return jax.device_put(a, type_as.device_buffer.device()) + return jax.device_put(a, self._get_device(type_as)) def _from_numpy(self, a, type_as=None): if isinstance(a, float): @@ -1688,7 +1697,10 @@ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) def dtype_device(self, a): - return a.dtype, a.device_buffer.device() + if self.jax_new_version: + return a.dtype, list(a.devices())[0] + else: + return a.dtype, a.device_buffer.device() def assert_same_dtype_device(self, a, b): a_dtype, a_device = self.dtype_device(a) diff --git a/requirements_all.txt b/requirements_all.txt index 66a7c2dfc..a015855f6 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -6,8 +6,8 @@ pymanopt @ git+https://github.com/pymanopt/pymanopt.git@master cvxopt scikit-learn torch -jax<=0.4.24 -jaxlib<=0.4.24 +jax +jaxlib tensorflow pytest torch_geometric From 5c9c70a6b5d2c4831d186e0702f66c569b7f5b2e Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Wed, 26 Jun 2024 01:43:50 -0500 Subject: [PATCH 29/30] [FIX] Default to using scipy v1.8.0+ API (#646) * As most users will be installing the latest scipy release when they install POT, default to using the scipy v1.8.0+ API and catch the ModuleNotFoundError in the rare case in which scipy v1.7.3 or older is required, such as Python 3.7. * Always import from the scipy.optimize.{_}linesearch API as the newest version of scipy that supports Python 3.7 is v1.7.3 which does not support 'from scipy.optimize import scalar_search_armijo', just as scipy v1.14.0+ does not. 
--- ot/optim.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ot/optim.py b/ot/optim.py index fe5eda821..bde0fc814 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -17,9 +17,10 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") try: - from scipy.optimize import scalar_search_armijo - except ImportError: from scipy.optimize._linesearch import scalar_search_armijo + except ModuleNotFoundError: + # scipy<1.8.0 + from scipy.optimize.linesearch import scalar_search_armijo # The corresponding scipy function does not work for matrices From 2987765a2d520c53a2976ff6c35968eeb09cbee4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 26 Jun 2024 11:36:33 +0200 Subject: [PATCH 30/30] [MRG] Release 0.9.4 (#640) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * proper number + fisrts shit release file * Update RELEASES.md Co-authored-by: Matthew Feickert * remove double sentence * remove jax constraints in setup.py * ad requiremet_all to manifest for a working source distibution * update setup.py with proiper optional install * small update reelase notes --------- Co-authored-by: Matthew Feickert Co-authored-by: Cédric Vincent-Cuaz --- MANIFEST.in | 1 + RELEASES.md | 8 +++++++- ot/__init__.py | 2 +- ot/gromov/_lowrank.py | 1 - setup.py | 16 +++++++++------- 5 files changed, 18 insertions(+), 10 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index da67c77fe..da2820cdc 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,7 @@ include README.md include RELEASES.md include LICENSE +include requirements_all.txt include ot/lp/core.h include ot/lp/EMD.h include ot/lp/EMD_wrapper.cpp diff --git a/RELEASES.md b/RELEASES.md index 294614114..56de9a179 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,6 +1,12 @@ # Releases -## 0.9.4dev +## 0.9.4 +*June 2024* + +This new release contains several new features and bug fixes. 
Among the new features +we have novel [Quantized FGW](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) solvers that can be used to speed up the computation of the FGW loss on large datasets or to promote a structure on the pairwise matrices. We also updated the continuous entropic mapping to provide efficient out-of-sample continuous mapping thanks to entropic regularization. We also have a new general unbalanced solvers for `ot.solve` and BFGS solver and illustrative example. Finally we have a new solver for the [Low Rank Gromov-Wasserstein](https://pythonot.github.io/auto_examples/others/plot_lowrank_GW.html) that can be used to compute the GW distance between two large scale datasets with a low rank approximation. + +From a maintenance point of view, we now have a new option to install optional dependencies with `pip install POT[all]` and the specific backends or submodules' dependencies may also be installed individually. The pip options are: `backend-jax, backend-tf, backend-torch, cvxopt, dr, gnn, plot, all`. We also provide with this release support for NumPy 2.0 (the wheels should now be compatible with NumPy 2.0 and below). We also fixed several issues such as gradient sign errors for FGW solvers, empty weights for `ot.emd2`, and line-search in partial GW. We also split the `test/test_gromov.py` into `test/gromov/` to make the tests more manageable. 
#### New features + NumPy 2.0 support is added (PR #629) diff --git a/ot/__init__.py b/ot/__init__.py index d8ac5ac28..5eb3977aa 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -59,7 +59,7 @@ # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.9.4dev" +__version__ = "0.9.4" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', diff --git a/ot/gromov/_lowrank.py b/ot/gromov/_lowrank.py index 5bab15edc..9aa3faab5 100644 --- a/ot/gromov/_lowrank.py +++ b/ot/gromov/_lowrank.py @@ -61,7 +61,6 @@ def _flat_product_operator(X, nx=None): def lowrank_gromov_wasserstein_samples(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, gamma_init="rescale", rescale_cost=True, cost_factorized_Xs=None, cost_factorized_Xt=None, stopThr=1e-4, numItermax=1000, stopThr_dykstra=1e-3, numItermax_dykstra=10000, seed_init=49, warn=True, warn_dykstra=False, log=False): - r""" Solve the entropic regularization Gromov-Wasserstein transport problem under low-nonnegative rank constraints on the couplings and cost matrices. diff --git a/setup.py b/setup.py index dea4da670..c0f75ad0f 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,8 @@ #!/usr/bin/env python +# Author: Remi Flamary +# +# License: MIT License import os import re @@ -46,8 +49,6 @@ sdk_path = subprocess.check_output(['xcrun', '--show-sdk-path']) os.environ['CFLAGS'] = '-isysroot "{}"'.format(sdk_path.rstrip().decode("utf-8")) -with open('requirements_all.txt') as f: - optional_requirements = f.read().splitlines() setup( name='POT', @@ -74,15 +75,16 @@ data_files=[], install_requires=["numpy>=1.16", "scipy>=1.6"], extras_require={ - 'backend-numpy': [], # in requirements. - 'backend-jax': ['jax<=0.4.24', 'jaxlib<=0.4.24'], - 'backend-cupy': [], # should be installed with conda, not pip, or figure out what CUDA version above. + 'backend-numpy': [], # in requirements. 
+ 'backend-jax': ['jax', 'jaxlib'], + 'backend-cupy': [], # should be installed with conda, not pip 'backend-tf': ['tensorflow'], 'backend-torch': ['torch'], - 'cvxopt': ['cvxopt'], # on it's own to prevent accidental GPL violations + 'cvxopt': ['cvxopt'], # on it's own to prevent accidental GPL violations 'dr': ['scikit-learn', 'pymanopt', 'autograd'], 'gnn': ['torch', 'torch_geometric'], - 'all': optional_requirements + 'plot': ['matplotlib'], + 'all': ['jax', 'jaxlib', 'tensorflow', 'torch', 'cvxopt', 'scikit-learn', 'pymanopt', 'autograd', 'torch_geometric', 'matplotlib'] }, python_requires=">=3.7", classifiers=[