Enable models to run on multiple CUDA devices #125

Merged (21 commits, merged on May 21, 2023)

Commits
80a792a
feat: enable SAITS to run on multiple CUDA devices;
WenjieDu May 18, 2023
451bd05
fix: put raise error under else;
WenjieDu May 18, 2023
24bfbe7
fix: put isinstance at the head of the if statement;
WenjieDu May 18, 2023
0c4254a
fix: add func_send_data_to_given_device();
WenjieDu May 18, 2023
d23fe98
docs: add the error message, work not finished;
WenjieDu May 18, 2023
4bca033
Merge branch 'dev' into enable_multiGPU_training
WenjieDu May 19, 2023
48dcc96
refactor: add funcs _setup_device() and _setup_path();
WenjieDu May 19, 2023
b58d934
feat: balance workload between multiple device in func _send_data_to_…
WenjieDu May 19, 2023
fda70a8
feat: enable all models to run on multiple devices;
WenjieDu May 19, 2023
c684303
fix: device could be type of list;
WenjieDu May 19, 2023
ae4e8b9
fix: errors when running on multiple devices;
WenjieDu May 19, 2023
feb31ce
feat: add python 3.8 to testing_ci workflow;
WenjieDu May 19, 2023
2fdfa06
Merge branch 'dev' into enable_multiGPU_training
WenjieDu May 19, 2023
895a47f
fix: error of tensors not on the same device while parallelly trainin…
WenjieDu May 20, 2023
c7cb426
feat: make BaseNNTask classes not inherit from BaseTask classes;
WenjieDu May 20, 2023
31d0698
docs: update the docstring of parameter `device`;
WenjieDu May 20, 2023
7112314
feat: add testing cases for training on multiple GPUs;
WenjieDu May 20, 2023
a49676f
Merge pull request #122 from WenjieDu/enable_multiGPU_training
WenjieDu May 20, 2023
d0625ca
skip multi-gpu test if not multi-gpu host
MaciejSkrabski May 20, 2023
ca47365
use 2 as required minimum cuda devices for multi-gpu
MaciejSkrabski May 20, 2023
fe5a358
Merge pull request #124 from MaciejSkrabski/update-tests
WenjieDu May 21, 2023

Files changed (changes from all commits)
2 changes: 1 addition & 1 deletion .github/workflows/testing_ci.yml
@@ -20,7 +20,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macOS-latest]
python-version: ["3.7", "3.9", "3.10"]
python-version: ["3.7", "3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v3
3 changes: 2 additions & 1 deletion .github/workflows/testing_daily.yml
@@ -52,7 +52,8 @@ jobs:

- name: Test with pytest
run: |
coverage run --source=pypots -m pytest
coverage run --source=pypots -m pytest --ignore tests/test_training_on_multi_gpus.py
# ignore the test_training_on_multi_gpus.py because it requires multiple GPUs which are not available on GitHub Actions

- name: Generate the LCOV report
run: |
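
Related to the two MaciejSkrabski commits listed above, the excluded test file also guards itself on hosts without enough GPUs. A minimal sketch of such a guard, assuming pytest and the file path named in the workflow; the test function name and body are illustrative, not copied from the PR:

# tests/test_training_on_multi_gpus.py (illustrative guard only)
import pytest
import torch

@pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="training on multiple GPUs needs at least two CUDA devices",
)
def test_training_on_multiple_gpus():
    ...  # build a model with device=["cuda:0", "cuda:1"] and run fit()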
2 changes: 1 addition & 1 deletion docs/install.rst
@@ -68,7 +68,7 @@ GPU Acceleration
Neural-network models in PyPOTS are implemented in PyTorch. So far we only support CUDA-enabled GPUs for GPU acceleration.
If you have a CUDA device, you can install PyTorch with GPU support to accelerate the training and inference of neural-network models.
After that, you can set the ``device`` argument to ``"cuda"`` when initializing the model to enable GPU acceleration.
If you don't specify ``device``, PyPOTS will automatically detect and use the first CUDA device (i.e. ``cuda:0``) if multiple CUDA devices are available.
If you don't specify ``device``, PyPOTS will automatically detect and use the default CUDA device if multiple CUDA devices are available.

CPU Acceleration
****************
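
For illustration, a short sketch of the accepted forms of the ``device`` argument and what the new _setup_device() helper in pypots/base.py (below) resolves them to; the assignments are placeholders, not PyPOTS code:

import torch

device = "cuda"                               # a device string
device = torch.device("cuda:0")               # a torch.device object
device = ["cuda:0", torch.device("cuda:1")]   # a list -> parallel training (CUDA devices only)
device = None                                 # auto: CUDA if available, otherwise CPU

# roughly what _setup_device() stores on the model afterwards:
#   str / torch.device          -> a single torch.device
#   a one-element list          -> collapsed to that single torch.device
#   a list with several entries -> kept as a list, later wrapped with torch.nn.DataParallel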
121 changes: 98 additions & 23 deletions pypots/base.py
@@ -7,6 +7,7 @@

import os
from abc import ABC
from datetime import datetime
from typing import Optional, Union

import torch
@@ -22,9 +23,11 @@ class BaseModel(ABC):
Parameters
----------
device :
The device for the model to run on.
The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the
model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices).
Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.

saving_path :
@@ -56,7 +59,7 @@ class BaseModel(ABC):

def __init__(
self,
device: Optional[Union[str, torch.device]] = None,
device: Optional[Union[str, torch.device, list]] = None,
saving_path: str = None,
model_saving_strategy: Optional[str] = "best",
):
@@ -73,28 +76,63 @@ def __init__(
self.summary_writer = None

# set up the device for model running below
self._setup_device(device)

# set up saving_path to save the trained model and training logs
self._setup_path(saving_path)

def _setup_device(self, device):
if device is None:
# if it is None, then
self.device = torch.device(
"cuda"
if torch.cuda.is_available() and torch.cuda.device_count() > 0
else "cpu"
)
# if it is None, then use the first cuda device if cuda is available, otherwise use cpu
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
logger.info(f"No given device, using default device: {self.device}")
else:
if isinstance(device, str):
self.device = torch.device(device)
self.device = torch.device(device.lower())
elif isinstance(device, torch.device):
self.device = device
elif isinstance(device, list):
# parallely training on multiple CUDA devices
device_list = []
for idx, d in enumerate(device):
if isinstance(d, str):
d = d.lower()
assert (
"cuda" in d
), "The feature of training on multiple devices currently only support CUDA devices."
device_list.append(torch.device(d))
elif isinstance(d, torch.device):
assert (
"cuda" in d.type
), "The feature of training on multiple devices currently only support CUDA devices."
device_list.append(d)
else:
raise TypeError(
f"Devices in the list should be str or torch.device, "
f"but the device with index {idx} is {type(d)}."
)
if len(device_list) > 1:
self.device = device_list
else:
self.device = device_list[0]
else:
raise TypeError(
f"device should be str or torch.device, but got {type(device)}"
f"device should be str/torch.device/a list containing str or torch.device, but got {type(device)}"
)

# set up saving_path to save the trained model and training logs
if isinstance(saving_path, str):
from datetime import datetime
# check CUDA availability if using CUDA
if (isinstance(self.device, list) and "cuda" in self.device[0].type) or (
isinstance(self.device, torch.device) and "cuda" in self.device.type
):
assert (
torch.cuda.is_available() and torch.cuda.device_count() > 0
), "You are trying to use CUDA for model training, but CUDA is not available in your environment."

def _setup_path(self, saving_path):
if isinstance(saving_path, str):
# get the current time to append to saving_path,
# so you can use the same saving_path to run multiple times
# and also be aware of when they were run
@@ -109,9 +147,35 @@ def __init__(
tb_saving_path,
filename_suffix=".pypots",
)
logger.info(f"Model files will be saved to {self.saving_path}")
logger.info(f"Tensorboard file will be saved to {tb_saving_path}")
else:
logger.info(
"saving_path not given. Model files and tensorboard file will not be saved."
)

def _send_model_to_given_device(self):
if isinstance(self.device, list):
# parallely training on multiple devices
self.model = torch.nn.DataParallel(self.model, device_ids=self.device)
self.model = self.model.cuda()
logger.info(
f"Model has been allocated to the given multiple devices: {self.device}"
)
else:
self.model = self.model.to(self.device)

def _send_data_to_given_device(self, data):
if isinstance(self.device, torch.device): # single device
data = map(lambda x: x.to(self.device), data)
else:  # parallely training on multiple devices
# randomly choose one device to balance the workload
# device = np.random.choice(self.device)
data = map(lambda x: x.cuda(), data)
return data

def _save_log_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None:
"""Saving training logs into the tensorboard file specified by the given path `tb_file_saving_path`.
@@ -135,7 +199,7 @@ def _save_log_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None
# save all items containing "loss" or "error" in the name
# WDU: may enable customization keywords in the future
if ("loss" in item_name) or ("error" in item_name):
self.summary_writer.add_scalar(f"{stage}/{item_name}", loss, step)
self.summary_writer.add_scalar(f"{stage}/{item_name}", loss.sum(), step)

def save_model(
self,
@@ -175,7 +239,11 @@ def save_model(
logger.error(f"File {saving_path} exists. Saving operation aborted.")
try:
create_dir_if_not_exist(saving_dir)
torch.save(self.model, saving_path)
if isinstance(self.device, list):
# to save a DataParallel model generically, save the inner self.model.module rather than the wrapper
torch.save(self.model.module, saving_path)
else:
torch.save(self.model, saving_path)
logger.info(f"Saved the model to {saving_path}.")
except Exception as e:
raise RuntimeError(
@@ -226,9 +294,15 @@ def load_model(self, model_path: str) -> None:
assert os.path.exists(model_path), f"Model file {model_path} does not exist."

try:
loaded_model = torch.load(model_path, map_location=self.device)
if isinstance(self.device, torch.device):
loaded_model = torch.load(model_path, map_location=self.device)
else:
loaded_model = torch.load(model_path)
if isinstance(loaded_model, torch.nn.Module):
self.model.load_state_dict(loaded_model.state_dict())
if isinstance(self.device, torch.device):
self.model.load_state_dict(loaded_model.state_dict())
else:
self.model.module.load_state_dict(loaded_model.state_dict())
else:
self.model = loaded_model.model
except Exception as e:
@@ -257,9 +331,11 @@ class BaseNNModel(BaseModel):
`0` means data loading will be in the main process, i.e. there won't be subprocesses.

device :
The device for the model to run on.
The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the
model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices).
Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.

saving_path :
@@ -301,12 +377,11 @@ def __init__(
epochs: int,
patience: int,
num_workers: int = 0,
device: Optional[Union[str, torch.device]] = None,
device: Optional[Union[str, torch.device, list]] = None,
saving_path: str = None,
model_saving_strategy: Optional[str] = "best",
):
BaseModel.__init__(
self,
super().__init__(
device,
saving_path,
model_saving_strategy,
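
Putting the pieces together, a hedged end-to-end sketch of what this PR enables, loosely modeled on the new tests/test_training_on_multi_gpus.py (not shown in this diff). The SAITS arguments are illustrative and may not match the exact signature or the values used in the test file:

import torch
from pypots.imputation import SAITS

# use every visible CUDA device when there is more than one, otherwise let PyPOTS decide
cuda_devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
device = cuda_devices if len(cuda_devices) > 1 else None

saits = SAITS(
    n_steps=48,       # illustrative hyperparameters
    n_features=37,
    n_layers=2,
    d_model=256,
    d_inner=128,
    n_heads=4,
    d_k=64,
    d_v=64,
    dropout=0.1,
    epochs=5,
    device=device,    # a list here triggers the DataParallel path added in this PR
)
# train_set would be a dict like {"X": array of shape (n_samples, n_steps, n_features)}
# saits.fit(train_set)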