Skip to content

Commit

Permalink
device-agnostic code (open-mmlab#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
jin-s13 authored Jul 24, 2020
1 parent 0375a21 commit d8e6741
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 20 deletions.
3 changes: 1 addition & 2 deletions mmpose/models/detectors/bottom_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ def forward_test(self, img, img_metas, **kwargs):
aggregated_heatmaps = None
tags_list = []
for idx, s in enumerate(sorted(test_scale_factor, reverse=True)):
image_resized = aug_data[idx]
image_resized = image_resized.cuda(non_blocking=True)
image_resized = aug_data[idx].to(img.device)

outputs = self.backbone(image_resized)
outputs = self.keypoint_head(outputs)
Expand Down
19 changes: 11 additions & 8 deletions mmpose/models/losses/multi_loss_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,19 @@
from ..registry import LOSSES


def _make_input(t, requires_grad=False, need_cuda=True):
def _make_input(t, requires_grad=False, device=torch.device('cpu')):
"""Make zero inputs for AE loss.
Args:
t (torch.Tensor): input
requires_grad (bool): Option to use requires_grad.
need_cuda (bool): Opthin to use cuda.
device: torch device
Returns:
inp (torch.Tensor): zero input.
"""
inp = torch.autograd.Variable(t, requires_grad=requires_grad)
inp = inp.sum()
if need_cuda:
inp = inp.cuda()
inp = inp.to(device)
return inp


Expand Down Expand Up @@ -73,7 +72,7 @@ def singleTagLoss(self, pred_tag, joints):
max_num_people: M
num_keypoints: K
Args:
tags(torch.Tensor[(KxHxW)x1]): tag of output for one image.
pred_tag(torch.Tensor[(KxHxW)x1]): tag of output for one image.
joints(torch.Tensor[MxKx2]): joints information for one image.
"""
tags = []
Expand All @@ -91,10 +90,14 @@ def singleTagLoss(self, pred_tag, joints):

num_tags = len(tags)
if num_tags == 0:
return (_make_input(torch.zeros(1).float()),
_make_input(torch.zeros(1).float()))
return (_make_input(
torch.zeros(1).float(), device=pred_tag.device),
_make_input(
torch.zeros(1).float(), device=pred_tag.device))
elif num_tags == 1:
return (_make_input(torch.zeros(1).float()), pull / (num_tags))
return (_make_input(
torch.zeros(1).float(),
device=pred_tag.device), pull / (num_tags))

tags = torch.stack(tags)

Expand Down
5 changes: 2 additions & 3 deletions tests/test_model/test_bottom_up_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,8 @@ def test_bottomup_forward():

# Test forward test
with torch.no_grad():
detector = detector.cuda()
_ = detector.forward(
imgs.cuda(), img_metas=img_metas, return_loss=False)
detector = detector
_ = detector.forward(imgs, img_metas=img_metas, return_loss=False)


def _demo_mm_inputs(input_shape=(1, 3, 256, 256)):
Expand Down
8 changes: 2 additions & 6 deletions tests/test_model/test_top_down_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def test_topdown_forward():

# Test forward test
with torch.no_grad():
detector = detector.cuda()
_ = detector.forward(
imgs.cuda(), img_metas=img_metas, return_loss=False)
_ = detector.forward(imgs, img_metas=img_metas, return_loss=False)

# flip test
model_cfg = dict(
Expand Down Expand Up @@ -84,9 +82,7 @@ def test_topdown_forward():

# Test forward test
with torch.no_grad():
detector = detector.cuda()
_ = detector.forward(
imgs.cuda(), img_metas=img_metas, return_loss=False)
_ = detector.forward(imgs, img_metas=img_metas, return_loss=False)


def _demo_mm_inputs(input_shape=(1, 3, 256, 256)):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pipelines/test_top_down_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _box2cs(box, image_size):
return center, scale


def test_pipeline():
def test_top_down_pipeline():
# test loading
data_prefix = 'tests/data/'
ann_file = osp.join(data_prefix, 'test_coco.json')
Expand Down

0 comments on commit d8e6741

Please sign in to comment.