@@ -492,6 +492,121 @@ def get_coordinate_circle(x):
     return x_t
 
 
+def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100):
+    """Reduce a LazyTensor along an axis with function func using batches.
+
+    When axis=None, reduce the LazyTensor to a scalar as a sum of func over
+    batches taken along axis 0.
+
+    .. warning::
+        This function works for tensors of any order, but the reduction can
+        only be done along the first two axes (or globally). It also requires
+        that each slice of size `batch_size` along the axis to reduce (or
+        along axis 0 if `axis=None`) can be computed and fits in memory.
+
+    Parameters
+    ----------
+    a : LazyTensor
+        LazyTensor to reduce
+    func : callable
+        Function to apply to the LazyTensor
+    axis : int, optional
+        Axis along which to reduce the LazyTensor. If None, reduce the
+        LazyTensor to a scalar as a sum of func over batches taken along
+        axis 0. If 0 or 1, reduce the LazyTensor to a vector/matrix as a
+        sum of func over batches taken along that axis.
+    nx : Backend, optional
+        Backend to use for the reduction
+    batch_size : int, optional
+        Size of the batches to use for the reduction (default=100)
+
+    Returns
+    -------
+    res : array-like
+        Result of the reduction
+    """
+
+    if nx is None:
+        nx = get_backend(a[0])
+
+    if axis is None:
+        # global reduction: accumulate func over batches taken along axis 0
+        res = 0.0
+        for i in range(0, a.shape[0], batch_size):
+            res += func(a[i:i + batch_size])
+        return res
+    elif axis == 0:
+        res = nx.zeros(a.shape[1:], type_as=a[0])
+        if nx.__name__ in ["jax", "tf"]:
+            # jax and tf arrays do not support in-place assignment, so
+            # collect the reduced batches and concatenate them instead
+            lst = []
+            for j in range(0, a.shape[1], batch_size):
+                lst.append(func(a[:, j:j + batch_size], 0))
+            return nx.concatenate(lst, axis=0)
+        else:
+            for j in range(0, a.shape[1], batch_size):
+                res[j:j + batch_size] = func(a[:, j:j + batch_size], axis=0)
+            return res
+    elif axis == 1:
+        if len(a.shape) == 2:
+            shape = (a.shape[0],)
+        else:
+            shape = (a.shape[0], *a.shape[2:])
+        res = nx.zeros(shape, type_as=a[0])
+        if nx.__name__ in ["jax", "tf"]:
+            lst = []
+            for i in range(0, a.shape[0], batch_size):
+                lst.append(func(a[i:i + batch_size], 1))
+            return nx.concatenate(lst, axis=0)
+        else:
+            for i in range(0, a.shape[0], batch_size):
+                res[i:i + batch_size] = func(a[i:i + batch_size], axis=1)
+            return res
+    else:
+        raise NotImplementedError("Only axis=None, 0 or 1 is implemented for now.")
+
+
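
A minimal usage sketch for `reduce_lazytensor` with the NumPy backend (hypothetical snippet, assuming this patch is merged so the helpers are importable from `ot.utils`):

```python
import numpy as np
from ot.utils import LazyTensor, reduce_lazytensor

v = np.arange(5, dtype=float)
# T[i, j] = v[i] + v[j], never materialized as a full 5 x 5 array
T = LazyTensor((5, 5), lambda i, j, v: v[i, None] + v[None, j], v=v)

total = reduce_lazytensor(T, np.sum, axis=None, batch_size=2)  # scalar sum of all entries
rows = reduce_lazytensor(T, np.sum, axis=1, batch_size=2)      # shape (5,), row sums
```
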
+def get_lowrank_lazytensor(Q, R, d=None, nx=None):
+    """Get a low-rank LazyTensor T=Q@R^T or T=Q@diag(d)@R^T
+
+    Parameters
+    ----------
+    Q : ndarray, shape (n, r)
+        First factor of the low-rank tensor
+    R : ndarray, shape (m, r)
+        Second factor of the low-rank tensor
+    d : ndarray, shape (r,), optional
+        Diagonal of the low-rank tensor
+    nx : Backend, optional
+        Backend to use for the computation
+
+    Returns
+    -------
+    T : LazyTensor
+        Low-rank tensor T=Q@R^T or T=Q@diag(d)@R^T
+    """
+
+    if nx is None:
+        nx = get_backend(Q, R, d)
+
+    shape = (Q.shape[0], R.shape[0])
+
+    if d is None:
+
+        def func(i, j, Q, R):
+            return nx.dot(Q[i], R[j].T)
+
+        T = LazyTensor(shape, func, Q=Q, R=R)
+
+    else:
+
+        def func(i, j, Q, R, d):
+            return nx.dot(Q[i] * d[None, :], R[j].T)
+
+        T = LazyTensor(shape, func, Q=Q, R=R, d=d)
+
+    return T
+
+
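
A similar sketch for `get_lowrank_lazytensor` (hypothetical, same import assumption as above): only the factors are stored, and any requested block of T = Q @ R.T is computed on demand.

```python
import numpy as np
from ot.utils import get_lowrank_lazytensor

Q = np.random.rand(1000, 2)
R = np.random.rand(2000, 2)
T = get_lowrank_lazytensor(Q, R)  # T[i, j] = Q[i] @ R[j].T, never stored in full

block = T[:100, :100]  # only this 100 x 100 block is formed
np.testing.assert_allclose(block, Q[:100] @ R[:100].T)
```
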
 def get_parameter_pair(parameter):
     r"""Extract a pair of parameters from a given parameter
     Used in unbalanced OT and COOT solvers
@@ -761,7 +876,76 @@ class UndefinedParameter(Exception):
 
 
 class OTResult:
-    def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None):
879+ """ Base class for OT results.
880+
881+ Parameters
882+ ----------
883+
884+ potentials : tuple of array-like, shape (`n1`, `n2`)
885+ Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
886+ This pair of arrays has the same shape, numerical type
887+ and properties as the input weights "a" and "b".
888+ value : float, array-like
889+ Full transport cost, including possible regularization terms and
890+ quadratic term for Gromov Wasserstein solutions.
891+ value_linear : float, array-like
892+ The linear part of the transport cost, i.e. the product between the
893+ transport plan and the cost.
894+ value_quad : float, array-like
895+ The quadratic part of the transport cost for Gromov-Wasserstein
896+ solutions.
897+ plan : array-like, shape (`n1`, `n2`)
898+ Transport plan, encoded as a dense array.
899+ log : dict
900+ Dictionary containing potential information about the solver.
901+ backend : Backend
902+ Backend used to compute the results.
903+ sparse_plan : array-like, shape (`n1`, `n2`)
904+ Transport plan, encoded as a sparse array.
905+ lazy_plan : LazyTensor
906+ Transport plan, encoded as a symbolic POT or KeOps LazyTensor.
907+ status : int or str
908+ Status of the solver.
909+ batch_size : int
910+ Batch size used to compute the results/marginals for LazyTensor.
911+
912+ Attributes
913+ ----------
914+
915+ potentials : tuple of array-like, shape (`n1`, `n2`)
916+ Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
917+ This pair of arrays has the same shape, numerical type
918+ and properties as the input weights "a" and "b".
919+ potential_a : array-like, shape (`n1`,)
920+ First dual potential, associated to the "source" measure "a".
921+ potential_b : array-like, shape (`n2`,)
922+ Second dual potential, associated to the "target" measure "b".
923+ value : float, array-like
924+ Full transport cost, including possible regularization terms and
925+ quadratic term for Gromov Wasserstein solutions.
926+ value_linear : float, array-like
927+ The linear part of the transport cost, i.e. the product between the
928+ transport plan and the cost.
929+ value_quad : float, array-like
930+ The quadratic part of the transport cost for Gromov-Wasserstein
931+ solutions.
932+ plan : array-like, shape (`n1`, `n2`)
933+ Transport plan, encoded as a dense array.
934+ sparse_plan : array-like, shape (`n1`, `n2`)
935+ Transport plan, encoded as a sparse array.
936+ lazy_plan : LazyTensor
937+ Transport plan, encoded as a symbolic POT or KeOps LazyTensor.
938+ marginals : tuple of array-like, shape (`n1`,), (`n2`,)
939+ Marginals of the transport plan: should be very close to "a" and "b"
940+ for balanced OT.
941+ marginal_a : array-like, shape (`n1`,)
942+ Marginal of the transport plan for the "source" measure "a".
943+ marginal_b : array-like, shape (`n2`,)
944+ Marginal of the transport plan for the "target" measure "b".
945+
946+ """
+
+    def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None, batch_size=100):
 
         self._potentials = potentials
         self._value = value
@@ -773,6 +957,7 @@ def __init__(self, potentials=None, value=None, value_linear=None, value_quad=No
         self._lazy_plan = lazy_plan
         self._backend = backend if backend is not None else NumpyBackend()
         self._status = status
+        self._batch_size = batch_size
 
         # I assume that other solvers may return directly
         # some primal objects?
@@ -793,7 +978,8 @@ def __repr__(self):
             s += 'value_linear={},'.format(self._value_linear)
         if self._plan is not None:
             s += 'plan={}(shape={}),'.format(self._plan.__class__.__name__, self._plan.shape)
-
+        if self._lazy_plan is not None:
+            s += 'lazy_plan={}(shape={}),'.format(self._lazy_plan.__class__.__name__, self._lazy_plan.shape)
         if s[-1] != '(':
             s = s[:-1] + ')'
         else:
@@ -853,7 +1039,10 @@ def sparse_plan(self):
     @property
     def lazy_plan(self):
         """Transport plan, encoded as a symbolic KeOps LazyTensor."""
-        raise NotImplementedError()
+        if self._lazy_plan is not None:
+            return self._lazy_plan
+        else:
+            raise NotImplementedError()
 
     # Loss values --------------------------------
 
@@ -897,6 +1086,11 @@ def marginal_a(self):
         """First marginal of the transport plan, with the same shape as "a"."""
         if self._plan is not None:
             return self._backend.sum(self._plan, 1)
+        elif self._lazy_plan is not None:
+            lp = self._lazy_plan
+            bs = self._batch_size
+            nx = self._backend
+            return reduce_lazytensor(lp, nx.sum, axis=1, nx=nx, batch_size=bs)
         else:
             raise NotImplementedError()
 
@@ -905,6 +1099,11 @@ def marginal_b(self):
         """Second marginal of the transport plan, with the same shape as "b"."""
         if self._plan is not None:
             return self._backend.sum(self._plan, 0)
+        elif self._lazy_plan is not None:
+            lp = self._lazy_plan
+            bs = self._batch_size
+            nx = self._backend
+            return reduce_lazytensor(lp, nx.sum, axis=0, nx=nx, batch_size=bs)
         else:
             raise NotImplementedError()
 
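
A sketch of the new lazy-marginal path (hypothetical snippet; `OTResult` is normally built by the solvers, but constructing one by hand shows the mechanics):

```python
import numpy as np
from ot.utils import LazyTensor, OTResult

G = np.random.rand(50, 40)
plan = LazyTensor(G.shape, lambda i, j, G: G[i, j], G=G)

res = OTResult(lazy_plan=plan, batch_size=10)
np.testing.assert_allclose(res.marginal_a, G.sum(axis=1))  # reduced in batches along axis 1
np.testing.assert_allclose(res.marginal_b, G.sum(axis=0))  # reduced in batches along axis 0
```
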
@@ -968,3 +1167,70 @@ def citation(self):
             url = {http://jmlr.org/papers/v22/20-451.html}
             }
             """
+
+
+class LazyTensor(object):
+    """A lazy tensor is a tensor that is not stored in memory. Instead, it is
+    defined by a function that computes its values on the fly from slices.
+
+    Parameters
+    ----------
+    shape : tuple
+        shape of the tensor
+    getitem : callable
+        function that computes the values of the tensor; it takes as
+        arguments the requested indices/slices followed by the named
+        arrays given in `kwargs`
+    kwargs : dict
+        named arguments for the function; these names will also be used as
+        attributes of the LazyTensor object
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> v = np.arange(5)
+    >>> def getitem(i, j, v):
+    ...     return v[i, None] + v[None, j]
+    >>> T = LazyTensor((5, 5), getitem, v=v)
+    >>> T[1, 2]
+    array([3])
+    >>> T[1, :]
+    array([[1, 2, 3, 4, 5]])
+    >>> T[:]
+    array([[0, 1, 2, 3, 4],
+           [1, 2, 3, 4, 5],
+           [2, 3, 4, 5, 6],
+           [3, 4, 5, 6, 7],
+           [4, 5, 6, 7, 8]])
+    """
+
+    def __init__(self, shape, getitem, **kwargs):
+
+        self._getitem = getitem
+        self.shape = shape
+        self.ndim = len(shape)
+        self.kwargs = kwargs
+
+        # set attributes for named arguments/arrays
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+    def __getitem__(self, key):
+        # normalize the key to one index/slice per dimension, padding the
+        # missing trailing dimensions with full slices
+        k = []
+        if isinstance(key, (int, slice)):
+            k.append(key)
+            for i in range(self.ndim - 1):
+                k.append(slice(None))
+        elif isinstance(key, tuple):
+            k = list(key)
+            for i in range(self.ndim - len(key)):
+                k.append(slice(None))
+        else:
+            raise NotImplementedError("Only integer, slice, and tuple indexing is supported")
+
+        return self._getitem(*k, **self.kwargs)
+
+    def __repr__(self):
+        return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys()))
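
As a last sketch, the same mechanism can express a lazy pairwise cost matrix without allocating the full n x m array (hypothetical example; the `cost` helper below is not part of the patch):

```python
import numpy as np
from ot.utils import LazyTensor

x = np.random.rand(1000, 3)
y = np.random.rand(2000, 3)

def cost(i, j, x, y):
    # squared Euclidean distances between the selected rows of x and y
    return np.sum((x[i, None, :] - y[None, j, :]) ** 2, axis=-1)

C = LazyTensor((1000, 2000), cost, x=x, y=y)
C[:5, :5]  # a 5 x 5 block, computed on the fly
```
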