Skip to content

Commit

Permalink
revise mock
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoXing1996 committed Aug 21, 2023
1 parent 3cba9d3 commit 1b38e52
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions tests/test_models/test_archs/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,25 @@
register_all_modules()


def mock_torch_device():
if digit_version(TORCH_VERSION) < digit_version('2.0.0'):

def mock_fn(device):
_orig_device_fn = torch.device
if device == 'meta':
return MagicMock()
else:
return _orig_device_fn(device)

torch.device = mock_fn


class TestWrapper(TestCase):

def test_build(self):
# mock SiLU
if digit_version(TORCH_VERSION) <= digit_version('1.6.0'):
from mmagic.models.editors.ddpm.denoising_unet import SiLU
torch.nn.SiLU = SiLU

# mock torch.device for pt<2.0.0
mock_torch_device()
_orig_device_fn = torch.device
if digit_version(TORCH_VERSION) < digit_version('2.0.0'):

def mock_fn(device):
if device == 'meta':
return _orig_device_fn('cpu')
else:
return _orig_device_fn(device)

torch.device = mock_fn

# 1. test from config
model = MODELS.build(
Expand Down Expand Up @@ -121,3 +118,5 @@ def test_build(self):
model.registrer_buffer('buffer', 123)
called_args, _ = register_buffer_mock.call_args
self.assertEqual(called_args, ('buffer', 123))

torch.device = _orig_device_fn

0 comments on commit 1b38e52

Please sign in to comment.