# For every supported landmark layout: the channel ranges ``(start, stop)``
# (``stop`` exclusive) whose heatmaps are summed into each of the five output
# maps, listed in output order.
_FIVE_PART_CHANNEL_RANGES = {
    68: (
        ((36, 42), ),  # left eye
        ((42, 48), ),  # right eye
        ((27, 36), ),  # nose
        ((48, 68), ),  # mouth
        ((0, 27), ),  # face silhouette
    ),
    # Helen.
    # NOTE(review): channels 40, 57, 113, 133, 153, 173 and 193 belong to no
    # group. The ranges are kept exactly as they were so numerical results do
    # not change; confirm against the Helen annotation layout before
    # widening them.
    194: (
        ((134, 153), (174, 193)),  # left eye
        ((114, 133), (154, 173)),  # right eye
        ((41, 57), ),  # nose
        ((58, 113), ),  # mouth
        ((0, 40), ),  # face silhouette
    ),
}


def _sum_channel_ranges(heatmap, ranges):
    """Sum the channels of ``heatmap`` selected by ``(start, stop)`` ranges."""
    parts = [heatmap[:, start:stop] for start, stop in ranges]
    merged = parts[0] if len(parts) == 1 else torch.cat(parts, dim=1)
    return merged.sum(1)


def reduce_to_five_heatmaps(ori_heatmap, detach):
    """Reduce facial landmark heatmaps to 5 heatmaps.

    DIC realizes facial SR with the help of key points of the face.
    The number of key points in datasets are different from each other.
    This function reduces the input heatmaps into 5 heatmaps:
        left eye
        right eye
        nose
        mouth
        face silhouette

    Every input channel is first divided by its own spatial maximum (the
    maximum is clamped to at least 0.05 so near-empty channels are not
    blown up). The channels of each facial part are then summed. An input
    that already has 5 channels is only normalized.

    Args:
        ori_heatmap (Tensor): Input heatmap tensor. (B, N, H, W), where N
            is 5, 68 or 194. The tensor itself is not modified.
        detach (bool): Detached from the current tensor or not.

    Returns:
        Tensor: New heatmap tensor. (B, 5, H, W).

    Raises:
        NotImplementedError: If N is not one of 5, 68 or 194.
    """

    num_landmarks = ori_heatmap.size(1)
    # Reject unsupported layouts before doing any tensor work.
    if num_landmarks != 5 and num_landmarks not in _FIVE_PART_CHANNEL_RANGES:
        raise NotImplementedError(
            f'Face landmark number {num_landmarks} not implemented!')

    # Work on a copy: the normalization below is done in place.
    heatmap = ori_heatmap.clone()
    max_heat = heatmap.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
    heatmap /= max_heat.clamp_min_(0.05)

    if num_landmarks == 5:
        new_heatmap = heatmap
    else:
        new_heatmap = torch.stack([
            _sum_channel_ranges(heatmap, ranges)
            for ranges in _FIVE_PART_CHANNEL_RANGES[num_landmarks]
        ],
                                  dim=1)
    return new_heatmap.detach() if detach else new_heatmap
def test_reduce_to_five_heatmaps():
    """Check shape, ``detach`` handling and input safety for each layout."""
    for num_landmarks in (5, 68, 194):
        heatmap = torch.rand((2, num_landmarks, 64, 64), requires_grad=True)
        # ``detach()`` shares storage with ``heatmap``, so comparing against
        # this copy afterwards reveals any in-place change to the input.
        origin = heatmap.detach().clone()

        for detach in (False, True):
            new_heatmap = reduce_to_five_heatmaps(heatmap, detach)
            assert new_heatmap.shape == (2, 5, 64, 64)
            # detach=True must cut the result out of the autograd graph,
            # detach=False must keep it attached.
            assert new_heatmap.requires_grad == (not detach)

        assert torch.equal(heatmap.detach(), origin)

    # Keep only the call under test inside the ``raises`` context.
    heatmap = torch.rand((2, 12, 64, 64))
    with pytest.raises(NotImplementedError):
        reduce_to_five_heatmaps(heatmap, False)