diff --git a/umap/parametric_umap.py b/umap/parametric_umap.py index 59a1f3da..f8768a12 100644 --- a/umap/parametric_umap.py +++ b/umap/parametric_umap.py @@ -461,7 +461,7 @@ def __getstate__(self): and k not in ("optimizer", "encoder", "decoder", "parametric_model") ) - def save(self, save_location, verbose=True): + def save(self, save_location, verbose=True, exclude_raw_data=False): # save encoder if self.encoder is not None: @@ -486,6 +486,18 @@ def save(self, save_location, verbose=True): if verbose: print("Keras full model saved to {}".format(parametric_model_output)) + # Temporarily delete the raw data in the object, before saving it, + # backing it up in raw_data + raw_data = {} + if exclude_raw_data: + if hasattr(self, "_raw_data"): + raw_data['root'] = self._raw_data + del self._raw_data + if hasattr(self, "knn_search_index") and hasattr(self.knn_search_index, + "_raw_data"): + raw_data['knn'] = self.knn_search_index._raw_data + del self.knn_search_index._raw_data + # # save model.pkl (ignoring unpickleable warnings) with catch_warnings(): filterwarnings("ignore") @@ -495,6 +507,13 @@ def save(self, save_location, verbose=True): if verbose: print("Pickle of ParametricUMAP model saved to {}".format(model_output)) + # Restore the original raw data to the object in memory + if exclude_raw_data: + if 'root' in raw_data: + self._raw_data = raw_data['root'] + if 'knn' in raw_data: + self.knn_search_index._raw_data = raw_data['knn'] + def add_landmarks( self, X,