Support CPU Train/Inference (#86)

MeowZheng committed Jan 28, 2022
1 parent 198748d commit 0523c58
Showing 5 changed files with 53 additions and 3 deletions.
10 changes: 10 additions & 0 deletions docs/en/tutorials/1_inference.md
@@ -41,6 +41,7 @@ We provide testing scripts for evaluating an existing model on the whole dataset
The following testing environments are supported:

- single GPU
- CPU
- single node multiple GPUs
- multiple nodes

@@ -55,6 +56,15 @@ python tools/test.py \
[--out-dir ${OUTPUT_DIRECTORY}] \
[--show-dir ${VISUALIZATION_DIRECTORY}]

# CPU: disable GPUs and run single-gpu testing script
export CUDA_VISIBLE_DEVICES=-1
python tools/test.py \
${CONFIG_FILE} \
${CHECKPOINT_FILE} \
[--out ${RESULT_FILE}] \
[--eval ${EVAL_METRICS}] \
[--show]

# multi-gpu testing
bash tools/dist_test.sh \
${CONFIG_FILE} \
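To illustrate the CPU testing path documented above, here is a minimal sketch that hides the GPUs and runs flow estimation through MMFlow's high-level Python API instead of `tools/test.py`. The config, checkpoint, and image paths are placeholders, and the `init_model`/`inference_model` helpers are assumed to be available in the installed MMFlow version.

```python
# A minimal sketch of CPU inference; all paths below are placeholders, and the
# helpers are assumed to exist as mmflow.apis.init_model / mmflow.apis.inference_model.
import os

# Hide all GPUs before any CUDA context is created so inference falls back to CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

from mmflow.apis import inference_model, init_model

config_file = 'configs/pwcnet/pwcnet_slong_8x1_flyingchairs_384x448.py'  # placeholder
checkpoint_file = 'checkpoints/pwcnet_flyingchairs.pth'                  # placeholder

# Build the model on CPU and estimate optical flow between two frames.
model = init_model(config_file, checkpoint_file, device='cpu')
flow = inference_model(model, 'demo/frame_0001.png', 'demo/frame_0002.png')  # placeholder frames
print(type(flow))
```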
14 changes: 14 additions & 0 deletions docs/en/tutorials/2_finetune.md
@@ -65,6 +65,20 @@ Difference between `resume-from` and `load-from`:
It is usually used for resuming the training process that is interrupted accidentally.
`load-from` only loads the model weights and the training iteration starts from 0. It is usually used for finetuning.
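As a small illustration of the difference, both options are plain config fields and can also be set programmatically; the config and checkpoint paths below are placeholders, not files from this commit.

```python
# Sketch: finetuning (load_from) vs. resuming (resume_from); paths are placeholders.
from mmcv import Config

cfg = Config.fromfile('configs/pwcnet/pwcnet_slong_8x1_flyingchairs_384x448.py')  # placeholder

# Finetune: load pretrained weights only; the training iteration starts from 0.
cfg.load_from = 'checkpoints/pwcnet_flyingchairs.pth'  # placeholder

# Resume: restore weights, optimizer state and iteration of an interrupted run.
# cfg.resume_from = 'work_dirs/pwcnet/latest.pth'  # placeholder
```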

### Training on CPU

Training on CPU follows the same procedure as single-GPU training; we only need to disable the GPUs before starting training.

```shell
export CUDA_VISIBLE_DEVICES=-1
```

Then run the script [above](#training-on-a-single-gpu).

```{note}
We do not recommend training on CPU because it is too slow. This feature is supported so that users can conveniently debug on machines without a GPU.
```
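For a concrete picture of what disabling the GPUs does, the following standalone sketch (an illustration, not part of this commit) hides them from PyTorch and then launches the usual training entry point on CPU; the config path is a placeholder.

```python
# Sketch: hide the GPUs before any CUDA initialisation, then run tools/train.py on CPU.
import os
import subprocess

os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # the subprocess inherits this environment

import torch
# With no visible GPU, PyTorch falls back to CPU.
assert not torch.cuda.is_available()

# Equivalent to: CUDA_VISIBLE_DEVICES=-1 python tools/train.py ${CONFIG_FILE}
subprocess.run(
    ['python', 'tools/train.py',
     'configs/pwcnet/pwcnet_slong_8x1_flyingchairs_384x448.py'],  # placeholder config
    check=True)
```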

## Training on multiple GPUs

MMFlow implements **distributed** training with `MMDistributedDataParallel`.
16 changes: 14 additions & 2 deletions mmflow/apis/train.py
@@ -3,14 +3,17 @@
import warnings
from typing import Optional, Sequence, Union

import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, Fp16OptimizerHook, OptimizerHook,
build_optimizer, build_runner, get_dist_info)
from mmcv.utils import Config, build_from_cfg

from mmflow import digit_version
from mmflow.core import DistEvalHook, EvalHook
from mmflow.datasets import build_dataloader, build_dataset
from mmflow.utils import find_latest_checkpoint, get_root_logger
@@ -114,8 +117,17 @@ def train_model(model: Module,
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
        # SyncBN is not supported for DP
        warnings.warn(
            'SyncBN is only supported with DDP. To be compatible with DP, '
            'we convert SyncBN to BN. Please use dist_train.sh which can '
            'avoid this error.')
        model = revert_sync_batchnorm(model)
        if not torch.cuda.is_available():
            assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
                'Please use MMCV >= 1.4.4 for CPU training!'

        model = MMDataParallel(model, device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
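The non-distributed branch above relies on `revert_sync_batchnorm` so that configs using SyncBN still work with `MMDataParallel` or on CPU; a small standalone sketch of its effect (the toy model is illustrative only):

```python
# Sketch: SyncBN requires distributed training, so the DP/CPU path converts it to plain BN.
import torch.nn as nn
from mmcv.cnn.utils import revert_sync_batchnorm

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8), nn.ReLU())  # toy model
model = revert_sync_batchnorm(model)

# No SyncBatchNorm layers remain, so the model runs under MMDataParallel or on CPU
# without an initialised process group.
assert not any(isinstance(m, nn.SyncBatchNorm) for m in model.modules())
```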
15 changes: 14 additions & 1 deletion tools/test.py
@@ -1,16 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import warnings

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
wrap_fp16_model)
from mmcv.utils.logging import print_log

from mmflow import digit_version
from mmflow.apis import multi_gpu_test, single_gpu_test
from mmflow.core import online_evaluation
from mmflow.datasets import build_dataloader, build_dataset
@@ -143,8 +146,18 @@ def main():
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        warnings.warn(
            'SyncBN is only supported with DDP. To be compatible with DP, '
            'we convert SyncBN to BN. Please use dist_test.sh which can '
            'avoid this error.')
        model = revert_sync_batchnorm(model)
        if not torch.cuda.is_available():
            assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
                'Please use MMCV >= 1.4.4 for CPU inference!'
        model = MMDataParallel(model, device_ids=cfg.gpu_ids)

    else:
        model = MMDistributedDataParallel(
            model.cuda(),
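The MMCV version guard above relies on `digit_version`, which turns version strings into values that compare numerically; a standalone sketch of the same check:

```python
# Sketch: compare the installed MMCV against the first release assumed to support the CPU path.
import mmcv
from mmflow import digit_version

required = digit_version('1.4.4')
installed = digit_version(mmcv.__version__)

# True only when the installed MMCV is new enough for CPU testing/inference.
print(installed >= required)
```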
1 change: 1 addition & 0 deletions tools/train.py
@@ -164,6 +164,7 @@ def main():

    model = build_flow_estimator(cfg.model)
    model.init_weights()

    logger.info(model)
    if cfg.data.train_dataloader.get('sample_ratio') is None:
        # build_dataset will concat the list of dataset
