diff --git a/.github/workflows/testing_ci.yml b/.github/workflows/testing_ci.yml index ad5d023a..eebf29e4 100644 --- a/.github/workflows/testing_ci.yml +++ b/.github/workflows/testing_ci.yml @@ -20,7 +20,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macOS-latest] - python-version: ["3.7", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/testing_daily.yml b/.github/workflows/testing_daily.yml index 87bfc73c..35e05fb2 100644 --- a/.github/workflows/testing_daily.yml +++ b/.github/workflows/testing_daily.yml @@ -52,7 +52,8 @@ jobs: - name: Test with pytest run: | - coverage run --source=pypots -m pytest + coverage run --source=pypots -m pytest --ignore tests/test_training_on_multi_gpus.py + # ignore the test_training_on_multi_gpus.py because it requires multiple GPUs which are not available on GitHub Actions - name: Generate the LCOV report run: | diff --git a/docs/install.rst b/docs/install.rst index 93abd721..ee1967bc 100644 --- a/docs/install.rst +++ b/docs/install.rst @@ -68,7 +68,7 @@ GPU Acceleration Neural-network models in PyPOTS are implemented in PyTorch. So far we only support CUDA-enabled GPUs for GPU acceleration. If you have a CUDA device, you can install PyTorch with GPU support to accelerate the training and inference of neural-network models. After that, you can set the ``device`` argument to ``"cuda"`` when initializing the model to enable GPU acceleration. -If you don't specify ``device``, PyPOTS will automatically detect and use the first CUDA device (i.e. ``cuda:0``) if multiple CUDA devices are available. +If you don't specify ``device``, PyPOTS will automatically detect and use the default CUDA device if multiple CUDA devices are available. CPU Acceleration **************** diff --git a/pypots/base.py b/pypots/base.py index 8b208b8c..1d97d741 100644 --- a/pypots/base.py +++ b/pypots/base.py @@ -7,6 +7,7 @@ import os from abc import ABC +from datetime import datetime from typing import Optional, Union import torch @@ -22,9 +23,11 @@ class BaseModel(ABC): Parameters ---------- device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -56,7 +59,7 @@ class BaseModel(ABC): def __init__( self, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): @@ -73,28 +76,63 @@ def __init__( self.summary_writer = None # set up the device for model running below + self._setup_device(device) + + # set up saving_path to save the trained model and training logs + self._setup_path(saving_path) + + def _setup_device(self, device): if device is None: - # if it is None, then - self.device = torch.device( - "cuda" - if torch.cuda.is_available() and torch.cuda.device_count() > 0 - else "cpu" - ) + # if it is None, then use the first cuda device if cuda is available, otherwise use cpu + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") logger.info(f"No given device, using default device: {self.device}") else: if isinstance(device, str): - self.device = torch.device(device) + self.device = torch.device(device.lower()) elif isinstance(device, torch.device): self.device = device + elif isinstance(device, list): + # parallely training on multiple CUDA devices + device_list = [] + for idx, d in enumerate(device): + if isinstance(d, str): + d = d.lower() + assert ( + "cuda" in d + ), "The feature of training on multiple devices currently only support CUDA devices." + device_list.append(torch.device(d)) + elif isinstance(d, torch.device): + assert ( + "cuda" in d.type + ), "The feature of training on multiple devices currently only support CUDA devices." + device_list.append(d) + else: + raise TypeError( + f"Devices in the list should be str or torch.device, " + f"but the device with index {idx} is {type(d)}." + ) + if len(device_list) > 1: + self.device = device_list + else: + self.device = device_list[0] else: raise TypeError( - f"device should be str or torch.device, but got {type(device)}" + f"device should be str/torch.device/a list containing str or torch.device, but got {type(device)}" ) - # set up saving_path to save the trained model and training logs - if isinstance(saving_path, str): - from datetime import datetime + # check CUDA availability if using CUDA + if (isinstance(self.device, list) and "cuda" in self.device[0].type) or ( + isinstance(self.device, torch.device) and "cuda" in self.device.type + ): + assert ( + torch.cuda.is_available() and torch.cuda.device_count() > 0 + ), "You are trying to use CUDA for model training, but CUDA is not available in your environment." + def _setup_path(self, saving_path): + if isinstance(saving_path, str): # get the current time to append to saving_path, # so you can use the same saving_path to run multiple times # and also be aware of when they were run @@ -109,9 +147,35 @@ def __init__( tb_saving_path, filename_suffix=".pypots", ) + logger.info(f"Model files will be saved to {self.saving_path}") + logger.info(f"Tensorboard file will be saved to {tb_saving_path}") + else: + logger.info( + "saving_path not given. Model files and tensorboard file will not be saved." + ) + + def _send_model_to_given_device(self): + if isinstance(self.device, list): + # parallely training on multiple devices + self.model = torch.nn.DataParallel(self.model, device_ids=self.device) + self.model = self.model.cuda() + logger.info( + f"Model has been allocated to the given multiple devices: {self.device}" + ) + else: + self.model = self.model.to(self.device) + + def _send_data_to_given_device(self, data): + if isinstance(self.device, torch.device): # single device + data = map(lambda x: x.to(self.device), data) + else: # parallely training on multiple devices - logger.info(f"the trained model will be saved to {self.saving_path}") - logger.info(f"the tensorboard file will be saved to {tb_saving_path}") + # randomly choose one device to balance the workload + # device = np.random.choice(self.device) + + data = map(lambda x: x.cuda(), data) + + return data def _save_log_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None: """Saving training logs into the tensorboard file specified by the given path `tb_file_saving_path`. @@ -135,7 +199,7 @@ def _save_log_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None # save all items containing "loss" or "error" in the name # WDU: may enable customization keywords in the future if ("loss" in item_name) or ("error" in item_name): - self.summary_writer.add_scalar(f"{stage}/{item_name}", loss, step) + self.summary_writer.add_scalar(f"{stage}/{item_name}", loss.sum(), step) def save_model( self, @@ -175,7 +239,11 @@ def save_model( logger.error(f"File {saving_path} exists. Saving operation aborted.") try: create_dir_if_not_exist(saving_dir) - torch.save(self.model, saving_path) + if isinstance(self.device, list): + # to save a DataParallel model generically, save the model.module.state_dict() + torch.save(self.model.module, saving_path) + else: + torch.save(self.model, saving_path) logger.info(f"Saved the model to {saving_path}.") except Exception as e: raise RuntimeError( @@ -226,9 +294,15 @@ def load_model(self, model_path: str) -> None: assert os.path.exists(model_path), f"Model file {model_path} does not exist." try: - loaded_model = torch.load(model_path, map_location=self.device) + if isinstance(self.device, torch.device): + loaded_model = torch.load(model_path, map_location=self.device) + else: + loaded_model = torch.load(model_path) if isinstance(loaded_model, torch.nn.Module): - self.model.load_state_dict(loaded_model.state_dict()) + if isinstance(self.device, torch.device): + self.model.load_state_dict(loaded_model.state_dict()) + else: + self.model.module.load_state_dict(loaded_model.state_dict()) else: self.model = loaded_model.model except Exception as e: @@ -257,9 +331,11 @@ class BaseNNModel(BaseModel): `0` means data loading will be in the main process, i.e. there won't be subprocesses. device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -301,12 +377,11 @@ def __init__( epochs: int, patience: int, num_workers: int = 0, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): - BaseModel.__init__( - self, + super().__init__( device, saving_path, model_saving_strategy, diff --git a/pypots/classification/base.py b/pypots/classification/base.py index 82d67cd4..a30fd698 100644 --- a/pypots/classification/base.py +++ b/pypots/classification/base.py @@ -26,9 +26,11 @@ class BaseClassifier(BaseModel): The number of classes in the classification task. device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -47,12 +49,11 @@ class BaseClassifier(BaseModel): def __init__( self, n_classes: int, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): - BaseModel.__init__( - self, + super().__init__( device, saving_path, model_saving_strategy, @@ -119,7 +120,7 @@ def classify( raise NotImplementedError -class BaseNNClassifier(BaseNNModel, BaseClassifier): +class BaseNNClassifier(BaseNNModel): """The abstract class for all neural-network classification models in PyPOTS. Parameters @@ -143,9 +144,11 @@ class BaseNNClassifier(BaseNNModel, BaseClassifier): `0` means data loading will be in the main process, i.e. there won't be subprocesses. device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -177,12 +180,11 @@ def __init__( epochs: int, patience: int, num_workers: int = 0, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): - BaseNNModel.__init__( - self, + super().__init__( batch_size, epochs, patience, @@ -191,13 +193,7 @@ def __init__( saving_path, model_saving_strategy, ) - BaseClassifier.__init__( - self, - n_classes, - device, - saving_path, - model_saving_strategy, - ) + self.n_classes = n_classes @abstractmethod def _assemble_input_for_training(self, data) -> dict: @@ -275,9 +271,9 @@ def _train_model( inputs = self._assemble_input_for_training(data) self.optimizer.zero_grad() results = self.model.forward(inputs) - results["loss"].backward() + results["loss"].sum().backward() self.optimizer.step() - epoch_train_loss_collector.append(results["loss"].item()) + epoch_train_loss_collector.append(results["loss"].sum().item()) # save training loss logs into the tensorboard file for every step if in need if self.summary_writer is not None: @@ -293,7 +289,9 @@ def _train_model( for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) results = self.model.forward(inputs) - epoch_val_loss_collector.append(results["loss"].item()) + epoch_val_loss_collector.append( + results["loss"].sum().item() + ) mean_val_loss = np.mean(epoch_val_loss_collector) @@ -331,15 +329,15 @@ def _train_model( ) break except Exception as e: - logger.info(f"Exception: {e}") + logger.error(f"Exception: {e}") if self.best_model_dict is None: raise RuntimeError( - "Training got interrupted. Model was not get trained. Please try fit() again." + "Training got interrupted. Model was not trained. Please investigate the error printed above." ) else: RuntimeWarning( - "Training got interrupted. " - "Model will load the best parameters so far for testing. " + "Training got interrupted. Please investigate the error printed above.\n" + "Model got trained and will load the best checkpoint so far for testing.\n" "If you don't want it, please try fit() again." ) @@ -347,3 +345,62 @@ def _train_model( raise ValueError("Something is wrong. best_loss is Nan after training.") logger.info("Finished training.") + + @abstractmethod + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + """Train the classifier on the given data. + + Parameters + ---------- + train_set : + The dataset for model training, should be a dictionary including keys as 'X' and 'y', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for training, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + val_set : + The dataset for model validating, should be a dictionary including keys as 'X' and 'y', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : + The type of the given file if train_set and val_set are path strings. + + """ + raise NotImplementedError + + @abstractmethod + def classify( + self, + X: Union[dict, str], + file_type: str = "h5py", + ) -> np.ndarray: + """Classify the input data with the trained model. + + Parameters + ---------- + X : + The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples], + Classification results of the given samples. + """ + raise NotImplementedError diff --git a/pypots/classification/brits/model.py b/pypots/classification/brits/model.py index 19ef8405..0fa4f10e 100644 --- a/pypots/classification/brits/model.py +++ b/pypots/classification/brits/model.py @@ -57,20 +57,17 @@ def __init__( def impute(self, inputs: dict) -> torch.Tensor: return super().impute(inputs) - def classify(self, inputs: dict) -> torch.Tensor: - ret_f = self.rits_f(inputs, "forward") - ret_b = self._reverse(self.rits_b(inputs, "backward")) - classification_pred = (ret_f["prediction"] + ret_b["prediction"]) / 2 - return classification_pred - - def forward(self, inputs: dict) -> dict: + def forward(self, inputs: dict, training: bool = True) -> dict: """Forward processing of BRITS. Parameters ---------- - inputs : dict, + inputs : The input data. + training : + Whether in training mode. + Returns ------- dict, A dictionary includes all results. @@ -78,6 +75,11 @@ def forward(self, inputs: dict) -> dict: ret_f = self.rits_f(inputs, "forward") ret_b = self._reverse(self.rits_b(inputs, "backward")) + classification_pred = (ret_f["prediction"] + ret_b["prediction"]) / 2 + if not training: + # if not in training mode, return the classification result only + return {"classification_pred": classification_pred} + ret_f["classification_loss"] = F.nll_loss( torch.log(ret_f["prediction"]), inputs["label"] ) @@ -101,6 +103,7 @@ def forward(self, inputs: dict) -> dict: ) results = { + "classification_pred": classification_pred, "consistency_loss": consistency_loss, "classification_loss": classification_loss, "reconstruction_loss": reconstruction_loss, @@ -152,9 +155,11 @@ class BRITS(BaseNNClassifier): `0` means data loading will be in the main process, i.e. there won't be subprocesses. device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -191,7 +196,7 @@ def __init__( patience: int = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): @@ -222,7 +227,7 @@ def __init__( self.reconstruction_weight, self.device, ) - self.model = self.model.to(self.device) + self._send_model_to_given_device() self._print_model_size() # set up the optimizer @@ -240,7 +245,7 @@ def _assemble_input_for_training(self, data: dict) -> dict: back_missing_mask, back_deltas, label, - ) = map(lambda x: x.to(self.device), data) + ) = self._send_data_to_given_device(data) # assemble input data inputs = { @@ -272,7 +277,7 @@ def _assemble_input_for_testing(self, data: dict) -> dict: back_X, back_missing_mask, back_deltas, - ) = map(lambda x: x.to(self.device), data) + ) = self._send_data_to_given_device(data) # assemble input data inputs = { @@ -336,7 +341,8 @@ def classify(self, X: Union[dict, str], file_type: str = "h5py"): with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - classification_pred = self.model.classify(inputs) + results = self.model.forward(inputs, training=False) + classification_pred = results["classification_pred"] prediction_collector.append(classification_pred) predictions = torch.cat(prediction_collector) diff --git a/pypots/classification/grud/model.py b/pypots/classification/grud/model.py index ac40e09b..1e13e10b 100644 --- a/pypots/classification/grud/model.py +++ b/pypots/classification/grud/model.py @@ -23,6 +23,7 @@ from ...imputation.brits.modules import TemporalDecay from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger class _GRUD(nn.Module): @@ -53,7 +54,22 @@ def __init__( ) self.classifier = nn.Linear(self.rnn_hidden_size, self.n_classes) - def classify(self, inputs: dict) -> torch.Tensor: + def forward(self, inputs: dict, training: bool = True) -> dict: + """Forward processing of GRU-D. + + Parameters + ---------- + inputs : + The input data. + + training : + Whether in training mode. + + Returns + ------- + dict, + A dictionary includes all results. + """ values = inputs["X"] masks = inputs["missing_mask"] deltas = inputs["deltas"] @@ -61,7 +77,7 @@ def classify(self, inputs: dict) -> torch.Tensor: X_filledLOCF = inputs["X_filledLOCF"] hidden_state = torch.zeros( - (values.size()[0], self.rnn_hidden_size), device=self.device + (values.size()[0], self.rnn_hidden_size), device=values.device ) for t in range(self.n_steps): @@ -77,29 +93,26 @@ def classify(self, inputs: dict) -> torch.Tensor: x_h = gamma_x * x_filledLOCF + (1 - gamma_x) * empirical_mean x_replaced = m * x + (1 - m) * x_h - inputs = torch.cat([x_replaced, hidden_state, m], dim=1) - hidden_state = self.rnn_cell(inputs, hidden_state) + data_input = torch.cat([x_replaced, hidden_state, m], dim=1) + hidden_state = self.rnn_cell(data_input, hidden_state) logits = self.classifier(hidden_state) - prediction = torch.softmax(logits, dim=1) - return prediction + classification_pred = torch.softmax(logits, dim=1) - def forward(self, inputs: dict) -> dict: - """Forward processing of GRU-D. + if not training: + # if not in training mode, return the classification result only + return {"classification_pred": classification_pred} - Parameters - ---------- - inputs : - The input data. + torch.log(classification_pred) + logger.error(f"ZShape {classification_pred.shape}") + classification_loss = F.nll_loss( + torch.log(classification_pred), inputs["label"] + ) - Returns - ------- - dict, - A dictionary includes all results. - """ - prediction = self.classify(inputs) - classification_loss = F.nll_loss(torch.log(prediction), inputs["label"]) - results = {"prediction": prediction, "loss": classification_loss} + results = { + "classification_pred": classification_pred, + "loss": classification_loss, + } return results @@ -140,9 +153,11 @@ class GRUD(BaseNNClassifier): `0` means data loading will be in the main process, i.e. there won't be subprocesses. device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -176,7 +191,7 @@ def __init__( patience: int = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): @@ -203,7 +218,7 @@ def __init__( self.n_classes, self.device, ) - self.model = self.model.to(self.device) + self._send_model_to_given_device() self._print_model_size() # set up the optimizer @@ -212,9 +227,15 @@ def __init__( def _assemble_input_for_training(self, data: dict) -> dict: # fetch data - indices, X, X_filledLOCF, missing_mask, deltas, empirical_mean, label = map( - lambda x: x.to(self.device), data - ) + ( + indices, + X, + X_filledLOCF, + missing_mask, + deltas, + empirical_mean, + label, + ) = self._send_data_to_given_device(data) # assemble input data inputs = { @@ -232,9 +253,14 @@ def _assemble_input_for_validating(self, data: dict) -> dict: return self._assemble_input_for_training(data) def _assemble_input_for_testing(self, data: dict) -> dict: - indices, X, X_filledLOCF, missing_mask, deltas, empirical_mean = map( - lambda x: x.to(self.device), data - ) + ( + indices, + X, + X_filledLOCF, + missing_mask, + deltas, + empirical_mean, + ) = self._send_data_to_given_device(data) inputs = { "indices": indices, @@ -293,7 +319,8 @@ def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - prediction = self.model.classify(inputs) + results = self.model.forward(inputs, training=False) + prediction = results["classification_pred"] prediction_collector.append(prediction) predictions = torch.cat(prediction_collector) diff --git a/pypots/classification/raindrop/model.py b/pypots/classification/raindrop/model.py index 770aeb7b..7e3b3aa1 100644 --- a/pypots/classification/raindrop/model.py +++ b/pypots/classification/raindrop/model.py @@ -270,12 +270,18 @@ def classify(self, inputs: dict) -> torch.Tensor: return prediction - def forward(self, inputs): - prediction = self.classify(inputs) - classification_loss = F.nll_loss(torch.log(prediction), inputs["label"]) + def forward(self, inputs, training=True): + classification_pred = self.classify(inputs) + if not training: + # if not in training mode, return the classification result only + return {"classification_pred": classification_pred} + + classification_loss = F.nll_loss( + torch.log(classification_pred), inputs["label"] + ) results = { - "prediction": prediction, + "prediction": classification_pred, "loss": classification_loss # 'distance': distance, } @@ -345,9 +351,11 @@ class Raindrop(BaseNNClassifier): `0` means data loading will be in the main process, i.e. there won't be subprocesses. device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -390,7 +398,7 @@ def __init__( patience: int = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): @@ -424,7 +432,7 @@ def __init__( static=static, device=self.device, ) - self.model = self.model.to(self.device) + self._send_model_to_given_device() self._print_model_size() # set up the optimizer @@ -433,9 +441,15 @@ def __init__( def _assemble_input_for_training(self, data: dict) -> dict: # fetch data - indices, X, X_filledLOCF, missing_mask, deltas, empirical_mean, label = map( - lambda x: x.to(self.device), data - ) + ( + indices, + X, + X_filledLOCF, + missing_mask, + deltas, + empirical_mean, + label, + ) = self._send_data_to_given_device(data) bz, n_steps, n_features = X.shape lengths = torch.tensor([n_steps] * bz, dtype=torch.float) @@ -459,9 +473,14 @@ def _assemble_input_for_validating(self, data: dict) -> dict: return self._assemble_input_for_training(data) def _assemble_input_for_testing(self, data: dict) -> dict: - indices, X, X_filledLOCF, missing_mask, deltas, empirical_mean = map( - lambda x: x.to(self.device), data - ) + ( + indices, + X, + X_filledLOCF, + missing_mask, + deltas, + empirical_mean, + ) = self._send_data_to_given_device(data) bz, n_steps, n_features = X.shape lengths = torch.tensor([n_steps] * bz, dtype=torch.float) times = torch.tensor(range(n_steps), dtype=torch.float).repeat(bz, 1) @@ -526,7 +545,8 @@ def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - prediction = self.model.classify(inputs) + results = self.model.forward(inputs, training=False) + prediction = results["classification_pred"] prediction_collector.append(prediction) predictions = torch.cat(prediction_collector) diff --git a/pypots/classification/template/model.py b/pypots/classification/template/model.py index 6063d0fc..f7e2d15a 100644 --- a/pypots/classification/template/model.py +++ b/pypots/classification/template/model.py @@ -58,7 +58,7 @@ def __init__( patience: int, num_workers: int = 0, optimizer: Optional[Optimizer] = Adam(), - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): diff --git a/pypots/clustering/base.py b/pypots/clustering/base.py index fc3c678a..324e6718 100644 --- a/pypots/clustering/base.py +++ b/pypots/clustering/base.py @@ -26,9 +26,11 @@ class BaseClusterer(BaseModel): The number of clusters in the clustering task. device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -47,7 +49,7 @@ class BaseClusterer(BaseModel): def __init__( self, n_clusters: int, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): @@ -106,7 +108,7 @@ def cluster( raise NotImplementedError -class BaseNNClusterer(BaseNNModel, BaseClusterer): +class BaseNNClusterer(BaseNNModel): """The abstract class for all neural-network clustering models in PyPOTS. Parameters @@ -130,9 +132,11 @@ class BaseNNClusterer(BaseNNModel, BaseClusterer): ``0`` means data loading will be in the main process, i.e. there won't be subprocesses. device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -164,12 +168,11 @@ def __init__( epochs: int, patience: int, num_workers: int = 0, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): - BaseNNModel.__init__( - self, + super().__init__( batch_size, epochs, patience, @@ -178,13 +181,7 @@ def __init__( saving_path, model_saving_strategy, ) - BaseClusterer.__init__( - self, - n_clusters, - device, - saving_path, - model_saving_strategy, - ) + self.n_clusters = n_clusters @abstractmethod def _assemble_input_for_training(self, data: list) -> dict: @@ -274,9 +271,9 @@ def _train_model( inputs = self._assemble_input_for_training(data) self.optimizer.zero_grad() results = self.model.forward(inputs) - results["loss"].backward() + results["loss"].sum().backward() self.optimizer.step() - epoch_train_loss_collector.append(results["loss"].item()) + epoch_train_loss_collector.append(results["loss"].sum().item()) # mean training loss of the current epoch mean_train_loss = np.mean(epoch_train_loss_collector) @@ -288,7 +285,9 @@ def _train_model( for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) results = self.model.forward(inputs) - epoch_val_loss_collector.append(results["loss"].item()) + epoch_val_loss_collector.append( + results["loss"].sum().item() + ) mean_val_loss = np.mean(epoch_val_loss_collector) logger.info( @@ -313,15 +312,15 @@ def _train_model( ) break except Exception as e: - logger.info(f"Exception: {e}") + logger.error(f"Exception: {e}") if self.best_model_dict is None: raise RuntimeError( - "Training got interrupted. Model was not get trained. Please try fit() again." + "Training got interrupted. Model was not trained. Please investigate the error printed above." ) else: RuntimeWarning( - "Training got interrupted. " - "Model will load the best parameters so far for testing. " + "Training got interrupted. Please investigate the error printed above.\n" + "Model got trained and will load the best checkpoint so far for testing.\n" "If you don't want it, please try fit() again." ) @@ -329,3 +328,50 @@ def _train_model( raise ValueError("Something is wrong. best_loss is Nan after training.") logger.info("Finished training.") + + @abstractmethod + def fit( + self, + train_set: Union[dict, str], + file_type: str = "h5py", + ) -> None: + """Train the cluster. + + Parameters + ---------- + train_set : + The dataset for model training, should be a dictionary including the key 'X', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for training, can contain missing values. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include the key 'X'. + + file_type : + The type of the given file if train_set is a path string. + """ + raise NotImplementedError + + @abstractmethod + def cluster( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + """Cluster the input with the trained model. + + Parameters + ---------- + X : + The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, + Clustering results. + """ + raise NotImplementedError diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index bf13e9e9..2c83d05d 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -73,7 +73,12 @@ def cluster(self, inputs: dict, training_object: str = "generator") -> dict: inputs["fcn_latent"] = fcn_latent return inputs - def forward(self, inputs: dict, training_object: str = "generator") -> dict: + def forward( + self, + inputs: dict, + training_object: str = "generator", + mode: str = "training", + ) -> dict: assert training_object in [ "generator", "discriminator", @@ -84,6 +89,10 @@ def forward(self, inputs: dict, training_object: str = "generator") -> dict: batch_size, n_steps, n_features = X.shape losses = {} inputs = self.cluster(inputs, training_object) + if mode == "clustering": + # if only run clustering, then no need to calculate loss + return inputs + if training_object == "discriminator": l_D = F.binary_cross_entropy_with_logits( inputs["discrimination"], missing_mask @@ -167,9 +176,11 @@ class CRLI(BaseNNClusterer): `0` means data loading will be in the main process, i.e. there won't be subprocesses. device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -211,7 +222,7 @@ def __init__( G_optimizer: Optional[Optimizer] = Adam(), D_optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: Optional[str] = None, model_saving_strategy: Optional[str] = "best", ): @@ -245,7 +256,7 @@ def __init__( rnn_cell_type, self.device, ) - self.model = self.model.to(self.device) + self._send_model_to_given_device() self._print_model_size() # set up the optimizer @@ -261,7 +272,7 @@ def __init__( def _assemble_input_for_training(self, data: list) -> dict: # fetch data - indices, X, missing_mask = map(lambda x: x.to(self.device), data) + indices, X, missing_mask = self._send_data_to_given_device(data) inputs = { "X": X, @@ -362,15 +373,15 @@ def _train_model( ) break except Exception as e: - logger.info(f"Exception: {e}") + logger.error(f"Exception: {e}") if self.best_model_dict is None: raise RuntimeError( - "Training got interrupted. Model was not get trained. Please try fit() again." + "Training got interrupted. Model was not trained. Please investigate the error printed above." ) else: RuntimeWarning( - "Training got interrupted. " - "Model will load the best parameters so far for testing. " + "Training got interrupted. Please investigate the error printed above.\n" + "Model got trained and will load the best checkpoint so far for testing.\n" "If you don't want it, please try fit() again." ) diff --git a/pypots/clustering/template/model.py b/pypots/clustering/template/model.py index 498e8d10..d9b40cff 100644 --- a/pypots/clustering/template/model.py +++ b/pypots/clustering/template/model.py @@ -58,7 +58,7 @@ def __init__( patience: int, num_workers: int = 0, optimizer: Optional[Optimizer] = Adam(), - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index 6f757635..baf9d527 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -163,7 +163,12 @@ def get_results( mu_c, var_c, phi_c = self.gmm_layer() return X_reconstructed, mu_c, var_c, phi_c, z, mu_tilde, stddev_tilde - def cluster(self, inputs: dict) -> np.ndarray: + def forward( + self, + inputs: dict, + pretrain: bool = False, + mode: str = "training", + ) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] ( X_reconstructed, @@ -175,39 +180,34 @@ def cluster(self, inputs: dict) -> np.ndarray: stddev_tilde, ) = self.get_results(X, missing_mask) - def func_to_apply( - mu_t_: np.ndarray, mu_: np.ndarray, stddev_: np.ndarray, phi_: np.ndarray - ) -> np.ndarray: - # the covariance matrix is diagonal, so we can just take the product - return np.log(self.eps + phi_) + np.log( - self.eps - + multivariate_normal.pdf(mu_t_, mean=mu_, cov=np.diag(stddev_)) - ) - - mu_tilde = mu_tilde.detach().cpu().numpy() - mu = mu_c.detach().cpu().numpy() - var = var_c.detach().cpu().numpy() - phi = phi_c.detach().cpu().numpy() - p = np.array( - [ - func_to_apply(mu_tilde, mu[i], var[i], phi[i]) - for i in np.arange(mu.shape[0]) - ] - ) - clustering_results = np.argmax(p, axis=0) - return clustering_results + if mode == "clustering": + + def func_to_apply( + mu_t_: np.ndarray, + mu_: np.ndarray, + stddev_: np.ndarray, + phi_: np.ndarray, + ) -> np.ndarray: + # the covariance matrix is diagonal, so we can just take the product + return np.log(self.eps + phi_) + np.log( + self.eps + + multivariate_normal.pdf(mu_t_, mean=mu_, cov=np.diag(stddev_)) + ) - def forward(self, inputs: dict, pretrain: bool = False) -> dict: - X, missing_mask = inputs["X"], inputs["missing_mask"] - ( - X_reconstructed, - mu_c, - var_c, - phi_c, - z, - mu_tilde, - stddev_tilde, - ) = self.get_results(X, missing_mask) + mu_tilde = mu_tilde.detach().cpu().numpy() + mu = mu_c.detach().cpu().numpy() + var = var_c.detach().cpu().numpy() + phi = phi_c.detach().cpu().numpy() + p = np.array( + [ + func_to_apply(mu_tilde, mu[i], var[i], phi[i]) + for i in np.arange(mu.shape[0]) + ] + ) + clustering_results = np.argmax(p, axis=0) + results = {"clustering_pred": clustering_results} + # if only run clustering, then no need to calculate loss + return results device = X.device @@ -275,7 +275,10 @@ def forward(self, inputs: dict, pretrain: bool = False) -> dict: latent_loss3 = latent_loss3.mean() latent_loss = latent_loss1 + latent_loss2 + latent_loss3 - results = {"loss": reconstruction_loss + self.alpha * latent_loss, "z": z} + results = { + "loss": reconstruction_loss + self.alpha * latent_loss, + "z": z, + } return results @@ -367,7 +370,7 @@ def __init__( patience: int = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): @@ -389,7 +392,7 @@ def __init__( self.model = _VaDER( n_steps, n_features, n_clusters, rnn_hidden_size, d_mu_stddev ) - self.model = self.model.to(self.device) + self._send_model_to_given_device() self._print_model_size() # set up the optimizer @@ -398,7 +401,7 @@ def __init__( def _assemble_input_for_training(self, data: list) -> dict: # fetch data - indices, X, missing_mask = map(lambda x: x.to(self.device), data) + indices, X, missing_mask = self._send_data_to_given_device(data) inputs = { "X": X, @@ -432,7 +435,7 @@ def _train_model( inputs = self._assemble_input_for_training(data) self.optimizer.zero_grad() results = self.model.forward(inputs, pretrain=True) - results["loss"].backward() + results["loss"].sum().backward() self.optimizer.step() # save pre-training loss logs into the tensorboard file for every step if in need @@ -486,16 +489,23 @@ def _train_model( continue # get GMM parameters - phi = np.log(gmm.weights_ + 1e-9) # inverse softmax mu = gmm.means_ var = inverse_softplus(gmm.covariances_) - # use trained GMM's parameters to init GMM layer's - self.model.gmm_layer.set_values( - torch.from_numpy(mu).to(self.device), - torch.from_numpy(var).to(self.device), - torch.from_numpy(phi).to(self.device), - ) + phi = np.log(gmm.weights_ + 1e-9) # inverse softmax + # use trained GMM's parameters to init GMM layer's + if isinstance(self.device, list): # if using multi-GPU + self.model.module.gmm_layer.set_values( + torch.from_numpy(mu).to(results["z"].device), + torch.from_numpy(var).to(results["z"].device), + torch.from_numpy(phi).to(results["z"].device), + ) + else: + self.model.gmm_layer.set_values( + torch.from_numpy(mu).to(results["z"].device), + torch.from_numpy(var).to(results["z"].device), + torch.from_numpy(phi).to(results["z"].device), + ) try: training_step = 0 for epoch in range(self.epochs): @@ -506,9 +516,9 @@ def _train_model( inputs = self._assemble_input_for_training(data) self.optimizer.zero_grad() results = self.model.forward(inputs) - results["loss"].backward() + results["loss"].sum().backward() self.optimizer.step() - epoch_train_loss_collector.append(results["loss"].item()) + epoch_train_loss_collector.append(results["loss"].sum().item()) # save training loss logs into the tensorboard file for every step if in need if self.summary_writer is not None: @@ -524,7 +534,9 @@ def _train_model( for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) results = self.model.forward(inputs) - epoch_val_loss_collector.append(results["loss"].item()) + epoch_val_loss_collector.append( + results["loss"].sum().item() + ) mean_val_loss = np.mean(epoch_val_loss_collector) @@ -562,15 +574,15 @@ def _train_model( ) break except Exception as e: - logger.info(f"Exception: {e}") + logger.error(f"Exception: {e}") if self.best_model_dict is None: raise RuntimeError( - "Training got interrupted. Model was not get trained. Please try fit() again." + "Training got interrupted. Model was not trained. Please investigate the error printed above." ) else: RuntimeWarning( - "Training got interrupted. " - "Model will load the best parameters so far for testing. " + "Training got interrupted. Please investigate the error printed above.\n" + "Model got trained and will load the best checkpoint so far for testing.\n" "If you don't want it, please try fit() again." ) @@ -617,7 +629,9 @@ def cluster(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.cluster(inputs) + results = self.model.forward(inputs, mode="clustering")[ + "clustering_pred" + ] clustering_results_collector.append(results) clustering_results = np.concatenate(clustering_results_collector) diff --git a/pypots/clustering/vader/modules.py b/pypots/clustering/vader/modules.py index 16b96bf0..175a6ea5 100644 --- a/pypots/clustering/vader/modules.py +++ b/pypots/clustering/vader/modules.py @@ -94,7 +94,7 @@ def __init__(self, d_hidden: int, n_clusters: int): super().__init__() self.mu_c_unscaled = Parameter(torch.Tensor(n_clusters, d_hidden)) self.var_c_unscaled = Parameter(torch.Tensor(n_clusters, d_hidden)) - self.phi_c_unscaled = torch.Tensor(n_clusters) + self.phi_c_unscaled = Parameter(torch.Tensor(n_clusters)) def set_values( self, @@ -107,7 +107,7 @@ def set_values( assert phi.shape == self.phi_c_unscaled.shape self.mu_c_unscaled = torch.nn.Parameter(mu) self.var_c_unscaled = torch.nn.Parameter(var) - self.phi_c_unscaled = phi + self.phi_c_unscaled = torch.nn.Parameter(phi) def forward(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: mu_c = self.mu_c_unscaled diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py index 2b102702..5188999b 100644 --- a/pypots/forecasting/base.py +++ b/pypots/forecasting/base.py @@ -23,9 +23,11 @@ class BaseForecaster(BaseModel): Parameters ---------- device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -43,12 +45,11 @@ class BaseForecaster(BaseModel): def __init__( self, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): - BaseModel.__init__( - self, + super().__init__( device, saving_path, model_saving_strategy, @@ -111,7 +112,7 @@ def forecast( raise NotImplementedError -class BaseNNForecaster(BaseNNModel, BaseForecaster): +class BaseNNForecaster(BaseNNModel): """The abstract class for all neural-network forecasting models in PyPOTS. Parameters @@ -132,9 +133,11 @@ class BaseNNForecaster(BaseNNModel, BaseForecaster): `0` means data loading will be in the main process, i.e. there won't be subprocesses. device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -164,12 +167,11 @@ def __init__( epochs: int, patience: int, num_workers: int = 0, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): - BaseNNModel.__init__( - self, + super().__init__( batch_size, epochs, patience, @@ -178,12 +180,6 @@ def __init__( saving_path, model_saving_strategy, ) - BaseForecaster.__init__( - self, - device, - saving_path, - model_saving_strategy, - ) @abstractmethod def _assemble_input_for_training(self, data) -> dict: @@ -261,9 +257,9 @@ def _train_model( inputs = self._assemble_input_for_training(data) self.optimizer.zero_grad() results = self.model.forward(inputs) - results["loss"].backward() + results["loss"].sum().backward() self.optimizer.step() - epoch_train_loss_collector.append(results["loss"].item()) + epoch_train_loss_collector.append(results["loss"].sum().item()) # save training loss logs into the tensorboard file for every step if in need if self.summary_writer is not None: @@ -279,7 +275,9 @@ def _train_model( for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) results = self.model.forward(inputs) - epoch_val_loss_collector.append(results["loss"].item()) + epoch_val_loss_collector.append( + results["loss"].sum().item() + ) mean_val_loss = np.mean(epoch_val_loss_collector) @@ -312,15 +310,15 @@ def _train_model( ) break except Exception as e: - logger.info(f"Exception: {e}") + logger.error(f"Exception: {e}") if self.best_model_dict is None: raise RuntimeError( - "Training got interrupted. Model was not get trained. Please try fit() again." + "Training got interrupted. Model was not trained. Please investigate the error printed above." ) else: RuntimeWarning( - "Training got interrupted. " - "Model will load the best parameters so far for testing. " + "Training got interrupted. Please investigate the error printed above.\n" + "Model got trained and will load the best checkpoint so far for testing.\n" "If you don't want it, please try fit() again." ) @@ -328,3 +326,59 @@ def _train_model( raise ValueError("Something is wrong. best_loss is Nan after training.") logger.info("Finished training.") + + @abstractmethod + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + """Train the classifier on the given data. + + Parameters + ---------- + train_set : + The dataset for model training, should be a dictionary including the key 'X', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for training, can contain missing values. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include the key 'X'. + + val_set : + The dataset for model validating, should be a dictionary including the key 'X', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validation, can contain missing values. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include the key 'X'. + + file_type : + The type of the given file if train_set and val_set are path strings. + + """ + raise NotImplementedError + + @abstractmethod + def forecast( + self, + X: dict or str, + file_type: str = "h5py", + ) -> np.ndarray: + """Forecast the future the input with the trained model. + + Parameters + ---------- + X : + Time-series data containing missing values. Shape [n_samples, sequence length (time steps), n_features]. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, prediction_horizon, n_features], + Forecasting results. + """ + raise NotImplementedError diff --git a/pypots/forecasting/bttf/model.py b/pypots/forecasting/bttf/model.py index 2e929584..500412a9 100644 --- a/pypots/forecasting/bttf/model.py +++ b/pypots/forecasting/bttf/model.py @@ -297,10 +297,12 @@ class BTTF(BaseForecaster): multi_step : int, default = 1, The number of time steps to forecast at each iteration. - device : str or `torch.device`, default = None, - The device for the model to run on. - If not given, will try to use CUDA devices first (will use the GPU with device number 0 only by default), + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. Notes @@ -321,7 +323,7 @@ def __init__( burn_iter: int, gibbs_iter: int, multi_step: int = 1, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, ): super().__init__(device) self.n_steps = n_steps diff --git a/pypots/forecasting/template/model.py b/pypots/forecasting/template/model.py index c77bb587..f761d795 100644 --- a/pypots/forecasting/template/model.py +++ b/pypots/forecasting/template/model.py @@ -57,7 +57,7 @@ def __init__( patience: int, num_workers: int = 0, optimizer: Optional[Optimizer] = Adam(), - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py index d8bae6a9..76988c7b 100644 --- a/pypots/imputation/base.py +++ b/pypots/imputation/base.py @@ -30,9 +30,11 @@ class BaseImputer(BaseModel): Parameters ---------- device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -50,12 +52,11 @@ class BaseImputer(BaseModel): def __init__( self, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): - BaseModel.__init__( - self, + super().__init__( device, saving_path, model_saving_strategy, @@ -119,7 +120,7 @@ def impute( raise NotImplementedError -class BaseNNImputer(BaseNNModel, BaseImputer): +class BaseNNImputer(BaseNNModel): """The abstract class for all neural-network imputation models in PyPOTS. Parameters @@ -140,9 +141,11 @@ class BaseNNImputer(BaseNNModel, BaseImputer): `0` means data loading will be in the main process, i.e. there won't be subprocesses. device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -172,12 +175,11 @@ def __init__( epochs: int, patience: int, num_workers: int = 0, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): - BaseNNModel.__init__( - self, + super().__init__( batch_size, epochs, patience, @@ -186,12 +188,6 @@ def __init__( saving_path, model_saving_strategy, ) - BaseImputer.__init__( - self, - device, - saving_path, - model_saving_strategy, - ) @abstractmethod def _assemble_input_for_training(self, data: list) -> dict: @@ -268,9 +264,10 @@ def _train_model( inputs = self._assemble_input_for_training(data) self.optimizer.zero_grad() results = self.model.forward(inputs) - results["loss"].backward() + # use sum() before backward() in case of multi-gpu training + results["loss"].sum().backward() self.optimizer.step() - epoch_train_loss_collector.append(results["loss"].item()) + epoch_train_loss_collector.append(results["loss"].sum().item()) # save training loss logs into the tensorboard file for every step if in need if self.summary_writer is not None: @@ -285,7 +282,8 @@ def _train_model( with torch.no_grad(): for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) - imputed_data = self.model.impute(inputs) + results = self.model.forward(inputs, training=False) + imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) imputation_collector = torch.cat(imputation_collector) @@ -339,15 +337,15 @@ def _train_model( break except Exception as e: - logger.info(f"Exception: {e}") + logger.error(f"Exception: {e}") if self.best_model_dict is None: raise RuntimeError( - "Training got interrupted. Model was not get trained. Please try fit() again." + "Training got interrupted. Model was not trained. Please investigate the error printed above." ) else: RuntimeWarning( - "Training got interrupted. " - "Model will load the best parameters so far for testing. " + "Training got interrupted. Please investigate the error printed above.\n" + "Model got trained and will load the best checkpoint so far for testing.\n" "If you don't want it, please try fit() again." ) @@ -355,3 +353,60 @@ def _train_model( raise ValueError("Something is wrong. best_loss is Nan after training.") logger.info("Finished training.") + + @abstractmethod + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + """Train the imputer on the given data. + + Parameters + ---------- + train_set : + The dataset for model training, should be a dictionary including the key 'X', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for training, can contain missing values. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include the key 'X'. + + val_set : + The dataset for model validating, should be a dictionary including the key 'X', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validating, can contain missing values. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include the key 'X'. + + file_type : str, default = "h5py", + The type of the given file if train_set and val_set are path strings. + + """ + raise NotImplementedError + + @abstractmethod + def impute( + self, + X: Union[dict, str], + file_type: str = "h5py", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Parameters + ---------- + X : + The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (time steps), n_features], + Imputed data. + """ + raise NotImplementedError diff --git a/pypots/imputation/brits/model.py b/pypots/imputation/brits/model.py index 568eadc3..0627a889 100644 --- a/pypots/imputation/brits/model.py +++ b/pypots/imputation/brits/model.py @@ -135,14 +135,14 @@ def impute( # create hidden states and cell states for the lstm cell hidden_states = torch.zeros( - (values.size()[0], self.rnn_hidden_size), device=self.device + (values.size()[0], self.rnn_hidden_size), device=values.device ) cell_states = torch.zeros( - (values.size()[0], self.rnn_hidden_size), device=self.device + (values.size()[0], self.rnn_hidden_size), device=values.device ) estimations = [] - reconstruction_loss = torch.tensor(0.0).to(self.device) + reconstruction_loss = torch.tensor(0.0).to(values.device) # imputation period for t in range(self.n_steps): @@ -202,7 +202,7 @@ def forward(self, inputs: dict, direction: str = "forward") -> dict: ret_dict = { "consistency_loss": torch.tensor( - 0.0, device=self.device + 0.0, device=imputed_data.device ), # single direction, has no consistency loss "reconstruction_loss": reconstruction_loss, "imputed_data": imputed_data, @@ -303,28 +303,7 @@ def reverse_tensor(tensor_): return ret - def impute(self, inputs: dict) -> torch.Tensor: - """Impute the missing data. Only impute, this is for test stage. - - Parameters - ---------- - inputs : - A dictionary includes all input data. - - Returns - ------- - array-like, the same shape with the input feature vectors. - The feature vectors with missing part imputed. - - """ - imputed_data_f, _, _ = self.rits_f.impute(inputs, "forward") - imputed_data_b, _, _ = self.rits_b.impute(inputs, "backward") - imputed_data_b = {"imputed_data_b": imputed_data_b} - imputed_data_b = self._reverse(imputed_data_b)["imputed_data_b"] - imputed_data = (imputed_data_f + imputed_data_b) / 2 - return imputed_data - - def forward(self, inputs: dict) -> dict: + def forward(self, inputs: dict, training: bool = True) -> dict: """Forward processing of BRITS. Parameters @@ -341,10 +320,17 @@ def forward(self, inputs: dict) -> dict: # Results from the backward RITS. ret_b = self._reverse(self.rits_b(inputs, "backward")) + imputed_data = (ret_f["imputed_data"] + ret_b["imputed_data"]) / 2 + + if not training: + # if not in training mode, return the classification result only + return { + "imputed_data": imputed_data, + } + consistency_loss = self._get_consistency_loss( ret_f["imputed_data"], ret_b["imputed_data"] ) - imputed_data = (ret_f["imputed_data"] + ret_b["imputed_data"]) / 2 # `loss` is always the item for backward propagating to update the model loss = ( @@ -390,9 +376,11 @@ class BRITS(BaseNNImputer): `0` means data loading will be in the main process, i.e. there won't be subprocesses. device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -426,7 +414,7 @@ def __init__( patience: int = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): @@ -446,9 +434,12 @@ def __init__( # set up the model self.model = _BRITS( - self.n_steps, self.n_features, self.rnn_hidden_size, self.device + self.n_steps, + self.n_features, + self.rnn_hidden_size, + self.device, ) - self.model = self.model.to(self.device) + self._send_model_to_given_device() self._print_model_size() # set up the optimizer @@ -457,9 +448,15 @@ def __init__( def _assemble_input_for_training(self, data: list) -> dict: # fetch data - indices, X, missing_mask, deltas, back_X, back_missing_mask, back_deltas = map( - lambda x: x.to(self.device), data - ) + ( + indices, + X, + missing_mask, + deltas, + back_X, + back_missing_mask, + back_deltas, + ) = self._send_data_to_given_device(data) # assemble input data inputs = { @@ -491,7 +488,9 @@ def fit( file_type: str = "h5py", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForBRITS(train_set, file_type=file_type) + training_set = DatasetForBRITS( + train_set, return_labels=False, file_type=file_type + ) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -513,7 +512,7 @@ def fit( "X_intact": hf["X_intact"][:], "indicating_mask": hf["indicating_mask"][:], } - val_set = DatasetForBRITS(val_set, file_type=file_type) + val_set = DatasetForBRITS(val_set, return_labels=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, @@ -547,7 +546,8 @@ def impute( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - imputed_data = self.model.impute(inputs) + results = self.model.forward(inputs, training=False) + imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) imputation_collector = torch.cat(imputation_collector) diff --git a/pypots/imputation/saits/model.py b/pypots/imputation/saits/model.py index 3df3e14e..d00ab610 100644 --- a/pypots/imputation/saits/model.py +++ b/pypots/imputation/saits/model.py @@ -146,15 +146,17 @@ def _process(self, inputs: dict) -> Tuple[torch.Tensor, list]: return X_c, [X_tilde_1, X_tilde_2, X_tilde_3] - def impute(self, inputs: dict) -> torch.Tensor: - imputed_data, _ = self._process(inputs) - return imputed_data - - def forward(self, inputs: dict) -> dict: + def forward(self, inputs: dict, training: bool = True) -> dict: X, masks = inputs["X"], inputs["missing_mask"] - ORT_loss = 0 imputed_data, [X_tilde_1, X_tilde_2, X_tilde_3] = self._process(inputs) + if not training: + # if not in training mode, return the classification result only + return { + "imputed_data": imputed_data, + } + + ORT_loss = 0 ORT_loss += cal_mae(X_tilde_1, X, masks) ORT_loss += cal_mae(X_tilde_2, X, masks) ORT_loss += cal_mae(X_tilde_3, X, masks) @@ -244,9 +246,11 @@ class SAITS(BaseNNImputer): `0` means data loading will be in the main process, i.e. there won't be subprocesses. device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -290,7 +294,7 @@ def __init__( patience: Optional[int] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: Optional[str] = None, model_saving_strategy: Optional[str] = "best", ): @@ -335,17 +339,21 @@ def __init__( self.ORT_weight, self.MIT_weight, ) - self.model = self.model.to(self.device) self._print_model_size() + self._send_model_to_given_device() # set up the optimizer self.optimizer = optimizer self.optimizer.init_optimizer(self.model.parameters()) def _assemble_input_for_training(self, data: list) -> dict: - indices, X_intact, X, missing_mask, indicating_mask = map( - lambda x: x.to(self.device), data - ) + ( + indices, + X_intact, + X, + missing_mask, + indicating_mask, + ) = self._send_data_to_given_device(data) inputs = { "X": X, @@ -357,7 +365,7 @@ def _assemble_input_for_training(self, data: list) -> dict: return inputs def _assemble_input_for_validating(self, data) -> dict: - indices, X, missing_mask = map(lambda x: x.to(self.device), data) + indices, X, missing_mask = self._send_data_to_given_device(data) inputs = { "X": X, @@ -375,7 +383,9 @@ def fit( file_type: str = "h5py", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForSAITS(train_set, file_type=file_type) + training_set = DatasetForSAITS( + train_set, return_labels=False, file_type=file_type + ) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -398,7 +408,7 @@ def fit( "indicating_mask": hf["indicating_mask"][:], } - val_set = BaseDataset(val_set, file_type=file_type) + val_set = BaseDataset(val_set, return_labels=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, @@ -434,7 +444,8 @@ def impute( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - imputed_data = self.model.impute(inputs) + results = self.model.forward(inputs, training=False) + imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) # Step 3: output collection and return diff --git a/pypots/imputation/template/model.py b/pypots/imputation/template/model.py index 9b17c540..bb9107d7 100644 --- a/pypots/imputation/template/model.py +++ b/pypots/imputation/template/model.py @@ -57,7 +57,7 @@ def __init__( patience: int, num_workers: int = 0, optimizer: Optional[Optimizer] = Adam(), - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): diff --git a/pypots/imputation/transformer/model.py b/pypots/imputation/transformer/model.py index f4ca0418..fd5e103b 100644 --- a/pypots/imputation/transformer/model.py +++ b/pypots/imputation/transformer/model.py @@ -90,15 +90,17 @@ def _process(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor]: ) # replace non-missing part with original data return imputed_data, learned_presentation - def impute(self, inputs: dict) -> torch.Tensor: - imputed_data, _ = self._process(inputs) - return imputed_data - - def forward(self, inputs: dict) -> dict: + def forward(self, inputs: dict, training: bool = True) -> dict: X, masks = inputs["X"], inputs["missing_mask"] imputed_data, learned_presentation = self._process(inputs) - ORT_loss = cal_mae(learned_presentation, X, masks) + if not training: + # if not in training mode, return the classification result only + return { + "imputed_data": imputed_data, + } + + ORT_loss = cal_mae(learned_presentation, X, masks) MIT_loss = cal_mae( learned_presentation, inputs["X_intact"], inputs["indicating_mask"] ) @@ -184,9 +186,11 @@ class Transformer(BaseNNImputer): `0` means data loading will be in the main process, i.e. there won't be subprocesses. device : - The device for the model to run on. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : @@ -229,7 +233,7 @@ def __init__( patience: int = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, model_saving_strategy: Optional[str] = "best", ): @@ -271,7 +275,7 @@ def __init__( self.ORT_weight, self.MIT_weight, ) - self.model = self.model.to(self.device) + self._send_model_to_given_device() self._print_model_size() # set up the optimizer @@ -279,9 +283,13 @@ def __init__( self.optimizer.init_optimizer(self.model.parameters()) def _assemble_input_for_training(self, data: dict) -> dict: - indices, X_intact, X, missing_mask, indicating_mask = map( - lambda x: x.to(self.device), data - ) + ( + indices, + X_intact, + X, + missing_mask, + indicating_mask, + ) = self._send_data_to_given_device(data) inputs = { "X": X, @@ -293,7 +301,7 @@ def _assemble_input_for_training(self, data: dict) -> dict: return inputs def _assemble_input_for_validating(self, data: list) -> dict: - indices, X, missing_mask = map(lambda x: x.to(self.device), data) + indices, X, missing_mask = self._send_data_to_given_device(data) inputs = { "X": X, @@ -312,7 +320,9 @@ def fit( file_type: str = "h5py", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForSAITS(train_set, file_type=file_type) + training_set = DatasetForSAITS( + train_set, return_labels=False, file_type=file_type + ) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -335,7 +345,7 @@ def fit( "indicating_mask": hf["indicating_mask"][:], } - val_set = BaseDataset(val_set, file_type=file_type) + val_set = BaseDataset(val_set, return_labels=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, @@ -365,7 +375,8 @@ def impute(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - imputed_data = self.model.impute(inputs) + results = self.model.forward(inputs, training=False) + imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) imputation_collector = torch.cat(imputation_collector) diff --git a/tests/test_training_on_multi_gpus.py b/tests/test_training_on_multi_gpus.py new file mode 100644 index 00000000..253c7b63 --- /dev/null +++ b/tests/test_training_on_multi_gpus.py @@ -0,0 +1,700 @@ +""" +Test cases for running models on multi cuda devices. +""" + +# Created by Wenjie Du +# License: GPL-v3 + + +import os.path +import unittest + +import numpy as np +import pytest + +import torch + +from pypots.classification import BRITS, GRUD, Raindrop +from pypots.clustering import VaDER, CRLI +from pypots.forecasting import BTTF +from pypots.imputation import ( + SAITS, + Transformer, + LOCF, +) +from pypots.imputation import BRITS as ImputationBRITS +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_binary_classification_metrics +from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import cal_rand_index, cal_cluster_purity +from tests.global_test_config import ( + DATA, + RESULT_SAVING_DIR, + check_tb_and_model_checkpoints_existence, +) + +EPOCHS = 5 + +DEVICES = [torch.device(i) for i in range(torch.cuda.device_count())] +LESS_THAN_TWO_DEVICES = len(DEVICES) < 2 + +# global skip test if less than two cuda-enabled devices +pytestmark = pytest.mark.skipif(LESS_THAN_TWO_DEVICES, reason="not enough cuda devices") + + +TRAIN_SET = {"X": DATA["train_X"], "y": DATA["train_y"]} + +VAL_SET = { + "X": DATA["val_X"], + "X_intact": DATA["val_X_intact"], + "indicating_mask": DATA["val_X_indicating_mask"], + "y": DATA["val_y"], +} +TEST_SET = {"X": DATA["test_X"]} + +RESULT_SAVING_DIR_FOR_IMPUTATION = os.path.join(RESULT_SAVING_DIR, "imputation") +RESULT_SAVING_DIR_FOR_CLASSIFICATION = os.path.join(RESULT_SAVING_DIR, "classification") +RESULT_SAVING_DIR_FOR_CLUSTERING = os.path.join(RESULT_SAVING_DIR, "clustering") + + +class TestSAITS(unittest.TestCase): + logger.info("Running tests for an imputation model SAITS...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "SAITS") + model_save_name = "saved_saits_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a SAITS model + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=2, + d_model=256, + d_inner=128, + n_heads=4, + d_k=64, + d_v=64, + dropout=0.1, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICES, + ) + + @pytest.mark.xdist_group(name="imputation-saits") + def test_0_fit(self): + self.saits.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-saits") + def test_1_impute(self): + imputed_X = self.saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="imputation-saits") + def test_2_parameters(self): + assert hasattr(self.saits, "model") and self.saits.model is not None + + assert hasattr(self.saits, "optimizer") and self.saits.optimizer is not None + + assert hasattr(self.saits, "best_loss") + self.assertNotEqual(self.saits.best_loss, float("inf")) + + assert ( + hasattr(self.saits, "best_model_dict") + and self.saits.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-saits") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.saits) + + # save the trained model into file, and check if the path exists + self.saits.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.saits.load_model(saved_model_path) + + +class TestTransformer(unittest.TestCase): + logger.info("Running tests for an imputation model Transformer...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "Transformer") + model_save_name = "saved_transformer_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a Transformer model + transformer = Transformer( + DATA["n_steps"], + DATA["n_features"], + n_layers=2, + d_model=256, + d_inner=128, + n_heads=4, + d_k=64, + d_v=64, + dropout=0.1, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICES, + ) + + @pytest.mark.xdist_group(name="imputation-transformer") + def test_0_fit(self): + self.transformer.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-transformer") + def test_1_impute(self): + imputed_X = self.transformer.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"Transformer test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="imputation-transformer") + def test_2_parameters(self): + assert hasattr(self.transformer, "model") and self.transformer.model is not None + + assert ( + hasattr(self.transformer, "optimizer") + and self.transformer.optimizer is not None + ) + + assert hasattr(self.transformer, "best_loss") + self.assertNotEqual(self.transformer.best_loss, float("inf")) + + assert ( + hasattr(self.transformer, "best_model_dict") + and self.transformer.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-transformer") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.transformer) + + # save the trained model into file, and check if the path exists + self.transformer.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.transformer.load_model(saved_model_path) + + +class TestImputationBRITS(unittest.TestCase): + logger.info("Running tests for an imputation model BRITS...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "BRITS") + model_save_name = "saved_BRITS_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a BRITS model + brits = ImputationBRITS( + DATA["n_steps"], + DATA["n_features"], + 256, + epochs=EPOCHS, + saving_path=f"{RESULT_SAVING_DIR_FOR_IMPUTATION}/BRITS", + optimizer=optimizer, + device=DEVICES, + ) + + @pytest.mark.xdist_group(name="imputation-brits") + def test_0_fit(self): + self.brits.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-brits") + def test_1_impute(self): + imputed_X = self.brits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"BRITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="imputation-brits") + def test_2_parameters(self): + assert hasattr(self.brits, "model") and self.brits.model is not None + + assert hasattr(self.brits, "optimizer") and self.brits.optimizer is not None + + assert hasattr(self.brits, "best_loss") + self.assertNotEqual(self.brits.best_loss, float("inf")) + + assert ( + hasattr(self.brits, "best_model_dict") + and self.brits.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-brits") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.brits) + + # save the trained model into file, and check if the path exists + self.brits.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.brits.load_model(saved_model_path) + + +class TestLOCF(unittest.TestCase): + logger.info("Running tests for an imputation model LOCF...") + locf = LOCF(nan=0) + + @pytest.mark.xdist_group(name="imputation-locf") + def test_0_impute(self): + test_X_imputed = self.locf.impute(TEST_SET) + assert not np.isnan( + test_X_imputed + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + test_X_imputed, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"LOCF test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="imputation-locf") + def test_1_parameters(self): + assert hasattr(self.locf, "nan") and self.locf.nan is not None + + +class TestClassificationBRITS(unittest.TestCase): + logger.info("Running tests for a classification model BRITS...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "BRITS") + model_save_name = "saved_BRITS_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a BRITS model + brits = BRITS( + DATA["n_steps"], + DATA["n_features"], + n_classes=DATA["n_classes"], + rnn_hidden_size=256, + epochs=EPOCHS, + saving_path=saving_path, + model_saving_strategy="better", + optimizer=optimizer, + device=DEVICES, + ) + + @pytest.mark.xdist_group(name="classification-brits") + def test_0_fit(self): + self.brits.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="classification-brits") + def test_1_classify(self): + predictions = self.brits.classify(TEST_SET) + metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) + logger.info( + f'ROC_AUC: {metrics["roc_auc"]}, \n' + f'PR_AUC: {metrics["pr_auc"]},\n' + f'F1: {metrics["f1"]},\n' + f'Precision: {metrics["precision"]},\n' + f'Recall: {metrics["recall"]},\n' + ) + assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5" + + @pytest.mark.xdist_group(name="classification-brits") + def test_2_parameters(self): + assert hasattr(self.brits, "model") and self.brits.model is not None + + assert hasattr(self.brits, "optimizer") and self.brits.optimizer is not None + + assert hasattr(self.brits, "best_loss") + self.assertNotEqual(self.brits.best_loss, float("inf")) + + assert ( + hasattr(self.brits, "best_model_dict") + and self.brits.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="classification-brits") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.brits) + + # save the trained model into file, and check if the path exists + self.brits.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.brits.load_model(saved_model_path) + + +class TestGRUD(unittest.TestCase): + logger.info("Running tests for a classification model GRUD...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "GRUD") + model_save_name = "saved_GRUD_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a GRUD model + grud = GRUD( + DATA["n_steps"], + DATA["n_features"], + n_classes=DATA["n_classes"], + rnn_hidden_size=256, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICES, + ) + + @pytest.mark.xdist_group(name="classification-grud") + def test_0_fit(self): + self.grud.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="classification-grud") + def test_1_classify(self): + predictions = self.grud.classify(TEST_SET) + metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) + logger.info( + f'ROC_AUC: {metrics["roc_auc"]}, \n' + f'PR_AUC: {metrics["pr_auc"]},\n' + f'F1: {metrics["f1"]},\n' + f'Precision: {metrics["precision"]},\n' + f'Recall: {metrics["recall"]},\n' + ) + assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5" + + @pytest.mark.xdist_group(name="classification-grud") + def test_2_parameters(self): + assert hasattr(self.grud, "model") and self.grud.model is not None + + assert hasattr(self.grud, "optimizer") and self.grud.optimizer is not None + + assert hasattr(self.grud, "best_loss") + self.assertNotEqual(self.grud.best_loss, float("inf")) + + assert ( + hasattr(self.grud, "best_model_dict") + and self.grud.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="classification-grud") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.grud) + + # save the trained model into file, and check if the path exists + self.grud.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.grud.load_model(saved_model_path) + + +class TestRaindrop(unittest.TestCase): + logger.info("Running tests for a classification model Raindrop...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "Raindrop") + model_save_name = "saved_Raindrop_model.pypots" + + # initialize a Raindrop model + raindrop = Raindrop( + DATA["n_steps"], + DATA["n_features"], + DATA["n_classes"], + n_layers=2, + d_model=DATA["n_features"] * 4, + d_inner=256, + n_heads=2, + dropout=0.3, + d_static=0, + aggregation="mean", + sensor_wise_mask=False, + static=False, + epochs=EPOCHS, + saving_path=saving_path, + ) + + @pytest.mark.xdist_group(name="classification-raindrop") + def test_0_fit(self): + self.raindrop.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="classification-raindrop") + def test_1_classify(self): + predictions = self.raindrop.classify(TEST_SET) + metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) + logger.info( + f'ROC_AUC: {metrics["roc_auc"]}, \n' + f'PR_AUC: {metrics["pr_auc"]},\n' + f'F1: {metrics["f1"]},\n' + f'Precision: {metrics["precision"]},\n' + f'Recall: {metrics["recall"]},\n' + ) + assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5" + + @pytest.mark.xdist_group(name="classification-raindrop") + def test_2_parameters(self): + assert hasattr(self.raindrop, "model") and self.raindrop.model is not None + + assert ( + hasattr(self.raindrop, "optimizer") and self.raindrop.optimizer is not None + ) + + assert hasattr(self.raindrop, "best_loss") + self.assertNotEqual(self.raindrop.best_loss, float("inf")) + + assert ( + hasattr(self.raindrop, "best_model_dict") + and self.raindrop.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="classification-raindrop") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.raindrop) + + # save the trained model into file, and check if the path exists + self.raindrop.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.raindrop.load_model(saved_model_path) + + +class TestCRLI(unittest.TestCase): + logger.info("Running tests for a clustering model CRLI...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLUSTERING, "CRLI") + model_save_name = "saved_CRLI_model.pypots" + + # initialize an Adam optimizer + G_optimizer = Adam(lr=0.001, weight_decay=1e-5) + D_optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a CRLI model + crli = CRLI( + n_steps=DATA["n_steps"], + n_features=DATA["n_features"], + n_clusters=DATA["n_classes"], + n_generator_layers=2, + rnn_hidden_size=128, + epochs=EPOCHS, + saving_path=saving_path, + G_optimizer=G_optimizer, + D_optimizer=D_optimizer, + ) + + @pytest.mark.xdist_group(name="clustering-crli") + def test_0_fit(self): + self.crli.fit(TRAIN_SET) + + @pytest.mark.xdist_group(name="clustering-crli") + def test_1_parameters(self): + assert hasattr(self.crli, "model") and self.crli.model is not None + + assert hasattr(self.crli, "G_optimizer") and self.crli.G_optimizer is not None + assert hasattr(self.crli, "D_optimizer") and self.crli.D_optimizer is not None + + assert hasattr(self.crli, "best_loss") + self.assertNotEqual(self.crli.best_loss, float("inf")) + + assert ( + hasattr(self.crli, "best_model_dict") + and self.crli.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="clustering-crli") + def test_2_cluster(self): + clustering = self.crli.cluster(TEST_SET) + RI = cal_rand_index(clustering, DATA["test_y"]) + CP = cal_cluster_purity(clustering, DATA["test_y"]) + logger.info(f"RI: {RI}\nCP: {CP}") + + @pytest.mark.xdist_group(name="clustering-crli") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.crli) + + # save the trained model into file, and check if the path exists + self.crli.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.crli.load_model(saved_model_path) + + +class TestVaDER(unittest.TestCase): + logger.info("Running tests for a clustering model Transformer...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLUSTERING, "VaDER") + model_save_name = "saved_VaDER_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a VaDER model + vader = VaDER( + n_steps=DATA["n_steps"], + n_features=DATA["n_features"], + n_clusters=DATA["n_classes"], + rnn_hidden_size=64, + d_mu_stddev=5, + pretrain_epochs=20, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICES, + ) + + @pytest.mark.xdist_group(name="clustering-vader") + def test_0_fit(self): + self.vader.fit(TRAIN_SET) + + @pytest.mark.xdist_group(name="clustering-vader") + def test_1_cluster(self): + try: + clustering = self.vader.cluster(TEST_SET) + RI = cal_rand_index(clustering, DATA["test_y"]) + CP = cal_cluster_purity(clustering, DATA["test_y"]) + logger.info(f"RI: {RI}\nCP: {CP}") + except np.linalg.LinAlgError as e: + logger.error( + f"{e}\n" + "Got singular matrix, please try to retrain the model to fix this" + ) + + @pytest.mark.xdist_group(name="clustering-vader") + def test_2_parameters(self): + assert hasattr(self.vader, "model") and self.vader.model is not None + + assert hasattr(self.vader, "optimizer") and self.vader.optimizer is not None + + assert hasattr(self.vader, "best_loss") + self.assertNotEqual(self.vader.best_loss, float("inf")) + + assert ( + hasattr(self.vader, "best_model_dict") + and self.vader.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="clustering-vader") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.vader) + + # save the trained model into file, and check if the path exists + self.vader.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.vader.load_model(saved_model_path) + + +class TestBTTF(unittest.TestCase): + logger.info("Running tests for a forecasting model BTTF...") + + # initialize a BTTF model + bttf = BTTF( + n_steps=50, + n_features=10, + pred_step=10, + rank=10, + time_lags=[1, 2, 3, 10, 10 + 1, 10 + 2, 20, 20 + 1, 20 + 2], + burn_iter=5, + gibbs_iter=5, + multi_step=1, + ) + + @pytest.mark.xdist_group(name="forecasting-bttf") + def test_0_forecasting(self): + predictions = self.bttf.forecast(TEST_SET) + logger.info(f"prediction shape: {predictions.shape}") + mae = cal_mae(predictions, DATA["test_X_intact"][:, 50:]) + logger.info(f"prediction MAE: {mae}") + + +if __name__ == "__main__": + unittest.main()