Skip to content

Commit

Permalink
[Feature]: support CPU training with MMDataParallel (#972)
Browse files Browse the repository at this point in the history
* support for CPU training

* Update .pre-commit-config.yaml

* Update data_parallel.py
  • Loading branch information
wangruohui authored Apr 24, 2021
1 parent 841a078 commit 5a99f58
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mmcv/parallel/data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def train_step(self, *inputs, **kwargs):
# We add the following line thus the module could gather and
# convert data containers as those in GPU inference
inputs, kwargs = self.scatter(inputs, kwargs, [-1])
return self.module.train_step(*inputs, **kwargs)
return self.module.train_step(*inputs[0], **kwargs[0])

assert len(self.device_ids) == 1, \
('MMDataParallel only supports single GPU training, if you need to'
Expand All @@ -71,7 +71,7 @@ def val_step(self, *inputs, **kwargs):
# We add the following line thus the module could gather and
# convert data containers as those in GPU inference
inputs, kwargs = self.scatter(inputs, kwargs, [-1])
return self.module.val_step(*inputs, **kwargs)
return self.module.val_step(*inputs[0], **kwargs[0])

assert len(self.device_ids) == 1, \
('MMDataParallel only supports single GPU training, if you need to'
Expand Down

0 comments on commit 5a99f58

Please sign in to comment.