Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add tests for GraphConverterWithShape #3951

Merged
merged 1 commit into from
Jul 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions test/ut/retiarii/convert_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import torch

from nni.retiarii.converter.graph_gen import convert_to_graph, GraphConverterWithShape


class ConvertMixin:
@staticmethod
def _convert_model(model, input):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
return model_ir


class ConvertWithShapeMixin:
@staticmethod
def _convert_model(model, input):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model, converter=GraphConverterWithShape(), example_inputs=input)
return model_ir
11 changes: 7 additions & 4 deletions test/ut/retiarii/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script

from .convert_mixin import ConvertMixin, ConvertWithShapeMixin

class MnistNet(nn.Module):
def __init__(self):
super(MnistNet, self).__init__()
Expand Down Expand Up @@ -48,7 +49,7 @@ def forward(self, input):
out = self.linear(input.view(size[0] * size[1], -1))
return out.view(size[0], size[1], -1)

class TestConvert(unittest.TestCase):
class TestConvert(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
Expand All @@ -61,8 +62,7 @@ def _match_state_dict(current_values, expected_format):
return result

def checkExportImport(self, model, input):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)

exec_vars = {}
Expand Down Expand Up @@ -579,3 +579,6 @@ def test_alexnet(self):
self.checkExportImport(model, (x,))
finally:
remove_inject_pytorch_nn()

class TestConvertWithShape(TestConvert, ConvertWithShapeMixin):
pass
24 changes: 14 additions & 10 deletions test/ut/retiarii/test_convert_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit
from nni.retiarii.converter import convert_to_graph

from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
from nni.retiarii.codegen import model_to_pytorch_script

# following pytorch v1.7.1

class TestConvert(unittest.TestCase):
class TestConvert(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
Expand All @@ -27,8 +28,7 @@ def _match_state_dict(current_values, expected_format):
return result

def checkExportImport(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
print(model_code)

Expand Down Expand Up @@ -188,7 +188,7 @@ def forward(self, x, y, z):
out2 = torch.addmv(x, y, z, beta=0.1, alpha=0.2)
return out1, out2
self.checkExportImport(SimpleOp(), (torch.randn(2), torch.randn(2, 3), torch.randn(3), ))

def test_basic_addr(self):
class SimpleOp(nn.Module):
def forward(self, x, y, z):
Expand All @@ -204,7 +204,7 @@ def forward(self, x, y):
out2 = torch.allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False)
return out1, out2
self.checkExportImport(SimpleOp(), (torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08]), ))

def test_basic_angle(self):
class SimpleOp(nn.Module):
def forward(self, x):
Expand All @@ -229,7 +229,7 @@ def forward(self, x):
o4 = x.argmin(dim=1, keepdim=True)
return out1, out2, out3, out4, out5, o1, o2, o3, o4
self.checkExportImport(SimpleOp(), (torch.randn(4, 4), ))

def test_basic_argsort(self):
class SimpleOp(nn.Module):
def forward(self, x):
Expand All @@ -241,7 +241,7 @@ def forward(self, x):
self.checkExportImport(SimpleOp(), (torch.randn(4, 4), ))

# skip backward(gradient=None, retain_graph=None, create_graph=False)

def test_basic_bernoulli(self):
class SimpleOp(nn.Module):
def forward(self, x):
Expand All @@ -261,7 +261,7 @@ def forward(self, x, y):
out4 = x.bincount(weights=y, minlength=2)
return out1, out2, out3, out4
self.checkExportImport(SimpleOp(), (torch.randint(0, 8, (5,), dtype=torch.int64), torch.linspace(0, 1, steps=5), ))

def test_basic_bitwise(self):
class SimpleOp(nn.Module):
def forward(self, x, y):
Expand All @@ -279,4 +279,8 @@ class SimpleOp(nn.Module):
def forward(self, x):
out1 = x.ceil()
return out1
self.checkExportImport(SimpleOp(), (torch.randn(4), ))
self.checkExportImport(SimpleOp(), (torch.randn(4), ))


class TestConvertWithShape(TestConvert, ConvertWithShapeMixin):
pass
11 changes: 7 additions & 4 deletions test/ut/retiarii/test_convert_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script

from .convert_mixin import ConvertMixin, ConvertWithShapeMixin

class TestModels(unittest.TestCase):

class TestModels(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
Expand All @@ -27,8 +28,7 @@ def _match_state_dict(current_values, expected_format):
return result

def run_test(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
print(model_code)

Expand Down Expand Up @@ -89,3 +89,6 @@ def forward(self, x: List[torch.Tensor]):
model = Net(4)
x = torch.rand((1, 16), dtype=torch.float)
self.run_test(model, ([x], ))

class TestModelsWithShape(TestModels, ConvertWithShapeMixin):
pass
37 changes: 20 additions & 17 deletions test/ut/retiarii/test_convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
import torchvision

import nni.retiarii.nn.pytorch as nn
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script

from .convert_mixin import ConvertMixin, ConvertWithShapeMixin

# following pytorch v1.7.1


class TestOperators(unittest.TestCase):
class TestOperators(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
Expand All @@ -34,8 +35,7 @@ def _match_state_dict(current_values, expected_format):
return result

def checkExportImport(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
#print(model_code)

Expand Down Expand Up @@ -1042,7 +1042,7 @@ def forward(self, x):

x = torch.tensor([[[[0.0, 1.0, 1.0, 1.0], [2.0, 3.0, 7.0, 7.0]]]], requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))


def test_basic_batchnorm(self):
class SimpleOp(nn.Module):
Expand All @@ -1056,7 +1056,7 @@ def forward(self, x):

x = torch.ones(2, 2, 2, 2, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))


def test_basic_batchnorm_1d(self):
class SimpleOp(nn.Module):
Expand Down Expand Up @@ -1084,7 +1084,7 @@ def forward(self, x):

x = torch.ones(20, 16, 50, 40, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))

def test_conv_onnx_irv4_opset8(self):
# This test point checks that for opset 8 (or lower), even if
# keep_initializers_as_inputs is set to False, it is ignored,
Expand Down Expand Up @@ -1129,7 +1129,7 @@ def forward(self, x):

x = torch.randn(20, 16, 50)
self.checkExportImport(SimpleOp(), (x, ))


def test_basic_maxpool_dilations(self):
class SimpleOp(nn.Module):
Expand All @@ -1143,7 +1143,7 @@ def forward(self, x):

x = torch.randn(20, 16, 50)
self.checkExportImport(SimpleOp(), (x, ))


def test_basic_avg_pool2d(self):
class SimpleOp(nn.Module):
Expand All @@ -1157,7 +1157,7 @@ def forward(self, x):

x = torch.randn(20, 16, 50, 32)
self.checkExportImport(SimpleOp(), (x, ))

@unittest.skip('jit error: "Return value was annotated as having type Tensor but is actually of type Tuple[Tensor, Tensor]"')
def test_basic_maxpool_indices(self):
class SimpleOp(nn.Module):
Expand Down Expand Up @@ -1200,7 +1200,7 @@ def forward(self, x):

x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))


def test_basic_elu(self):
class SimpleOp(nn.Module):
Expand All @@ -1214,7 +1214,7 @@ def forward(self, x):

x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))


def test_basic_selu(self):
class SimpleOp(nn.Module):
Expand Down Expand Up @@ -1261,7 +1261,7 @@ def forward(self, x):

x = torch.randn(128, 128, 1, 1, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))

def test_embedding_bags(self):
class SimpleOp(nn.Module):
def __init__(self):
Expand All @@ -1288,7 +1288,7 @@ def forward(self, x):

x = torch.randn(1, 2, 3, 4)
self.checkExportImport(SimpleOp(), (x, ))


def test_basic_prelu(self):
class SimpleOp(nn.Module):
Expand All @@ -1302,7 +1302,7 @@ def forward(self, x):

x = torch.randn(1, 2, 3, 4)
self.checkExportImport(SimpleOp(), (x, ))


def test_basic_log_sigmoid(self):
class SimpleOp(nn.Module):
Expand All @@ -1316,7 +1316,7 @@ def forward(self, x):

x = torch.randn(1, 2, 3, 4)
self.checkExportImport(SimpleOp(), (x, ))


def test_basic_linear(self):
class SimpleOp(nn.Module):
Expand Down Expand Up @@ -1385,4 +1385,7 @@ def forward(self, x):
return out

x = torch.randn(20, 5, 10, 10)
self.checkExportImport(SimpleOp(), (x, ))
self.checkExportImport(SimpleOp(), (x, ))

class TestOperatorsWithShape(TestOperators, ConvertWithShapeMixin):
pass
13 changes: 8 additions & 5 deletions test/ut/retiarii/test_convert_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script

from .convert_mixin import ConvertMixin, ConvertWithShapeMixin

class TestPytorch(unittest.TestCase):

class TestPytorch(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
Expand All @@ -32,8 +33,7 @@ def _match_state_dict(current_values, expected_format):
return result

def run_test(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
print(model_code)

Expand Down Expand Up @@ -1230,4 +1230,7 @@ def forward(self, input):
return torch.arange(input.size(0)), torch.arange(input.size(-1)), torch.ones(input.shape)

x = torch.randn(5, 3, 2)
self.run_test(SizeModel(10, 5), (x, ))
self.run_test(SizeModel(10, 5), (x, ))

class TestPytorchWithShape(TestPytorch, ConvertWithShapeMixin):
pass
Loading