Add tests for GraphConverterWithShape (#3951)
kaleid-liner authored Jul 20, 2021
1 parent 403195f commit bb39748
Showing 7 changed files with 154 additions and 40 deletions.
19 changes: 19 additions & 0 deletions test/ut/retiarii/convert_mixin.py
@@ -0,0 +1,19 @@
+import torch
+
+from nni.retiarii.converter.graph_gen import convert_to_graph, GraphConverterWithShape
+
+
+class ConvertMixin:
+    @staticmethod
+    def _convert_model(model, input):
+        script_module = torch.jit.script(model)
+        model_ir = convert_to_graph(script_module, model)
+        return model_ir
+
+
+class ConvertWithShapeMixin:
+    @staticmethod
+    def _convert_model(model, input):
+        script_module = torch.jit.script(model)
+        model_ir = convert_to_graph(script_module, model, converter=GraphConverterWithShape(), example_inputs=input)
+        return model_ir
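
The mixin pair above is the heart of the commit: each test class calls self._convert_model(...), and swapping the mixin swaps the conversion path (plain versus shape-aware with example_inputs). One subtlety worth checking is Python's C3 method resolution order. Below is a minimal, self-contained sketch (stand-in return values, hypothetical class names) of how _convert_model resolves for a class declared, as in the diffs below, with the parent test class listed before the mixin:

import unittest

class ConvertMixin:
    @staticmethod
    def _convert_model(model, input):
        return 'plain conversion'        # stand-in body

class ConvertWithShapeMixin:
    @staticmethod
    def _convert_model(model, input):
        return 'shape-aware conversion'  # stand-in body

class ParentTests(unittest.TestCase, ConvertMixin):
    pass

class WithShape(ParentTests, ConvertWithShapeMixin):
    pass

# C3 MRO: WithShape -> ParentTests -> TestCase -> ConvertMixin
#         -> ConvertWithShapeMixin -> object
print(WithShape._convert_model(None, None))       # 'plain conversion'

class WithShapeFirst(ConvertWithShapeMixin, ParentTests):
    pass

print(WithShapeFirst._convert_model(None, None))  # 'shape-aware conversion'

Because the parent test class already carries ConvertMixin, a second mixin listed after the parent does not override _convert_model; it takes effect only when listed before the parent. Printing the MRO as in the sketch is the quickest way to verify which conversion path the WithShape classes in the files below actually exercise.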
11 changes: 7 additions & 4 deletions test/ut/retiarii/test_convert.py
@@ -13,9 +13,10 @@
 
 import nni.retiarii.nn.pytorch as nn
 from nni.retiarii import basic_unit
-from nni.retiarii.converter import convert_to_graph
 from nni.retiarii.codegen import model_to_pytorch_script
 
+from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
+
 class MnistNet(nn.Module):
     def __init__(self):
         super(MnistNet, self).__init__()
@@ -48,7 +49,7 @@ def forward(self, input):
         out = self.linear(input.view(size[0] * size[1], -1))
         return out.view(size[0], size[1], -1)
 
-class TestConvert(unittest.TestCase):
+class TestConvert(unittest.TestCase, ConvertMixin):
     @staticmethod
     def _match_state_dict(current_values, expected_format):
         result = {}
@@ -61,8 +62,7 @@ def _match_state_dict(current_values, expected_format):
         return result
 
     def checkExportImport(self, model, input):
-        script_module = torch.jit.script(model)
-        model_ir = convert_to_graph(script_module, model)
+        model_ir = self._convert_model(model, input)
         model_code = model_to_pytorch_script(model_ir)
 
         exec_vars = {}
@@ -579,3 +579,6 @@ def test_alexnet(self):
             self.checkExportImport(model, (x,))
         finally:
             remove_inject_pytorch_nn()
+
+class TestConvertWithShape(TestConvert, ConvertWithShapeMixin):
+    pass
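
For orientation, this is the round trip that checkExportImport drives, as a condensed sketch using only the calls visible in this commit (the assertions that compare outputs lie outside the diff):

import torch
from nni.retiarii.converter.graph_gen import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script

def round_trip(model):
    # model -> TorchScript -> Retiarii graph IR -> regenerated PyTorch source
    script_module = torch.jit.script(model)
    model_ir = convert_to_graph(script_module, model)
    return model_to_pytorch_script(model_ir)

The regenerated source is then executed (hence the exec_vars = {} context line above) and its outputs are checked against the original model's.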
24 changes: 14 additions & 10 deletions test/ut/retiarii/test_convert_basic.py
@@ -9,12 +9,13 @@
 
 import nni.retiarii.nn.pytorch as nn
 from nni.retiarii import basic_unit
-from nni.retiarii.converter import convert_to_graph
+
+from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
 from nni.retiarii.codegen import model_to_pytorch_script
 
 # following pytorch v1.7.1
 
-class TestConvert(unittest.TestCase):
+class TestConvert(unittest.TestCase, ConvertMixin):
     @staticmethod
     def _match_state_dict(current_values, expected_format):
         result = {}
@@ -27,8 +28,7 @@ def _match_state_dict(current_values, expected_format):
         return result
 
     def checkExportImport(self, model, input, check_value=True):
-        script_module = torch.jit.script(model)
-        model_ir = convert_to_graph(script_module, model)
+        model_ir = self._convert_model(model, input)
         model_code = model_to_pytorch_script(model_ir)
         print(model_code)
 
@@ -188,7 +188,7 @@ def forward(self, x, y, z):
                 out2 = torch.addmv(x, y, z, beta=0.1, alpha=0.2)
                 return out1, out2
         self.checkExportImport(SimpleOp(), (torch.randn(2), torch.randn(2, 3), torch.randn(3), ))
-
+
     def test_basic_addr(self):
         class SimpleOp(nn.Module):
             def forward(self, x, y, z):
@@ -204,7 +204,7 @@ def forward(self, x, y):
                 out2 = torch.allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False)
                 return out1, out2
         self.checkExportImport(SimpleOp(), (torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08]), ))
-
+
     def test_basic_angle(self):
         class SimpleOp(nn.Module):
             def forward(self, x):
@@ -229,7 +229,7 @@ def forward(self, x):
                 o4 = x.argmin(dim=1, keepdim=True)
                 return out1, out2, out3, out4, out5, o1, o2, o3, o4
         self.checkExportImport(SimpleOp(), (torch.randn(4, 4), ))
-
+
     def test_basic_argsort(self):
         class SimpleOp(nn.Module):
             def forward(self, x):
@@ -241,7 +241,7 @@ def forward(self, x):
         self.checkExportImport(SimpleOp(), (torch.randn(4, 4), ))
 
     # skip backward(gradient=None, retain_graph=None, create_graph=False)
-
+
     def test_basic_bernoulli(self):
         class SimpleOp(nn.Module):
             def forward(self, x):
@@ -261,7 +261,7 @@ def forward(self, x, y):
                 out4 = x.bincount(weights=y, minlength=2)
                 return out1, out2, out3, out4
         self.checkExportImport(SimpleOp(), (torch.randint(0, 8, (5,), dtype=torch.int64), torch.linspace(0, 1, steps=5), ))
-
+
     def test_basic_bitwise(self):
         class SimpleOp(nn.Module):
             def forward(self, x, y):
@@ -279,4 +279,8 @@ class SimpleOp(nn.Module):
             def forward(self, x):
                 out1 = x.ceil()
                 return out1
-        self.checkExportImport(SimpleOp(), (torch.randn(4), ))
+        self.checkExportImport(SimpleOp(), (torch.randn(4), ))
+
+
+class TestConvertWithShape(TestConvert, ConvertWithShapeMixin):
+    pass
11 changes: 7 additions & 4 deletions test/ut/retiarii/test_convert_models.py
@@ -10,11 +10,12 @@
 
 import nni.retiarii.nn.pytorch as nn
 from nni.retiarii import serialize
-from nni.retiarii.converter import convert_to_graph
 from nni.retiarii.codegen import model_to_pytorch_script
 
+from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
 
-class TestModels(unittest.TestCase):
+
+class TestModels(unittest.TestCase, ConvertMixin):
     @staticmethod
     def _match_state_dict(current_values, expected_format):
         result = {}
@@ -27,8 +28,7 @@ def _match_state_dict(current_values, expected_format):
         return result
 
     def run_test(self, model, input, check_value=True):
-        script_module = torch.jit.script(model)
-        model_ir = convert_to_graph(script_module, model)
+        model_ir = self._convert_model(model, input)
         model_code = model_to_pytorch_script(model_ir)
         print(model_code)
 
@@ -89,3 +89,6 @@ def forward(self, x: List[torch.Tensor]):
         model = Net(4)
         x = torch.rand((1, 16), dtype=torch.float)
         self.run_test(model, ([x], ))
+
+class TestModelsWithShape(TestModels, ConvertWithShapeMixin):
+    pass
37 changes: 20 additions & 17 deletions test/ut/retiarii/test_convert_operators.py
@@ -15,13 +15,14 @@
 import torchvision
 
 import nni.retiarii.nn.pytorch as nn
-from nni.retiarii.converter import convert_to_graph
 from nni.retiarii.codegen import model_to_pytorch_script
 
+from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
+
 # following pytorch v1.7.1
 
 
-class TestOperators(unittest.TestCase):
+class TestOperators(unittest.TestCase, ConvertMixin):
     @staticmethod
     def _match_state_dict(current_values, expected_format):
         result = {}
@@ -34,8 +35,7 @@ def _match_state_dict(current_values, expected_format):
         return result
 
     def checkExportImport(self, model, input, check_value=True):
-        script_module = torch.jit.script(model)
-        model_ir = convert_to_graph(script_module, model)
+        model_ir = self._convert_model(model, input)
         model_code = model_to_pytorch_script(model_ir)
         #print(model_code)
 
@@ -1042,7 +1042,7 @@ def forward(self, x):
 
         x = torch.tensor([[[[0.0, 1.0, 1.0, 1.0], [2.0, 3.0, 7.0, 7.0]]]], requires_grad=True)
         self.checkExportImport(SimpleOp(), (x, ))
-
+
 
     def test_basic_batchnorm(self):
         class SimpleOp(nn.Module):
@@ -1056,7 +1056,7 @@ def forward(self, x):
 
         x = torch.ones(2, 2, 2, 2, requires_grad=True)
         self.checkExportImport(SimpleOp(), (x, ))
-
+
 
     def test_basic_batchnorm_1d(self):
         class SimpleOp(nn.Module):
@@ -1084,7 +1084,7 @@ def forward(self, x):
 
         x = torch.ones(20, 16, 50, 40, requires_grad=True)
         self.checkExportImport(SimpleOp(), (x, ))
-
+
     def test_conv_onnx_irv4_opset8(self):
         # This test point checks that for opset 8 (or lower), even if
         # keep_initializers_as_inputs is set to False, it is ignored,
@@ -1129,7 +1129,7 @@ def forward(self, x):
 
         x = torch.randn(20, 16, 50)
         self.checkExportImport(SimpleOp(), (x, ))
-
+
 
     def test_basic_maxpool_dilations(self):
         class SimpleOp(nn.Module):
@@ -1143,7 +1143,7 @@ def forward(self, x):
 
         x = torch.randn(20, 16, 50)
         self.checkExportImport(SimpleOp(), (x, ))
-
+
 
     def test_basic_avg_pool2d(self):
         class SimpleOp(nn.Module):
@@ -1157,7 +1157,7 @@ def forward(self, x):
 
         x = torch.randn(20, 16, 50, 32)
         self.checkExportImport(SimpleOp(), (x, ))
-
+
     @unittest.skip('jit error: "Return value was annotated as having type Tensor but is actually of type Tuple[Tensor, Tensor]"')
     def test_basic_maxpool_indices(self):
         class SimpleOp(nn.Module):
@@ -1200,7 +1200,7 @@ def forward(self, x):
 
         x = torch.randn(1, 2, 3, 4, requires_grad=True)
         self.checkExportImport(SimpleOp(), (x, ))
-
+
 
     def test_basic_elu(self):
         class SimpleOp(nn.Module):
@@ -1214,7 +1214,7 @@ def forward(self, x):
 
         x = torch.randn(1, 2, 3, 4, requires_grad=True)
         self.checkExportImport(SimpleOp(), (x, ))
-
+
 
     def test_basic_selu(self):
         class SimpleOp(nn.Module):
@@ -1261,7 +1261,7 @@ def forward(self, x):
 
         x = torch.randn(128, 128, 1, 1, requires_grad=True)
         self.checkExportImport(SimpleOp(), (x, ))
-
+
     def test_embedding_bags(self):
         class SimpleOp(nn.Module):
             def __init__(self):
@@ -1288,7 +1288,7 @@ def forward(self, x):
 
         x = torch.randn(1, 2, 3, 4)
         self.checkExportImport(SimpleOp(), (x, ))
-
+
 
     def test_basic_prelu(self):
         class SimpleOp(nn.Module):
@@ -1302,7 +1302,7 @@ def forward(self, x):
 
         x = torch.randn(1, 2, 3, 4)
         self.checkExportImport(SimpleOp(), (x, ))
-
+
 
     def test_basic_log_sigmoid(self):
         class SimpleOp(nn.Module):
@@ -1316,7 +1316,7 @@ def forward(self, x):
 
         x = torch.randn(1, 2, 3, 4)
         self.checkExportImport(SimpleOp(), (x, ))
-
+
 
     def test_basic_linear(self):
         class SimpleOp(nn.Module):
@@ -1385,4 +1385,7 @@ def forward(self, x):
                 return out
 
         x = torch.randn(20, 5, 10, 10)
-        self.checkExportImport(SimpleOp(), (x, ))
+        self.checkExportImport(SimpleOp(), (x, ))
+
+class TestOperatorsWithShape(TestOperators, ConvertWithShapeMixin):
+    pass
13 changes: 8 additions & 5 deletions test/ut/retiarii/test_convert_pytorch.py
@@ -15,11 +15,12 @@
 
 import nni.retiarii.nn.pytorch as nn
 from nni.retiarii import serialize
-from nni.retiarii.converter import convert_to_graph
 from nni.retiarii.codegen import model_to_pytorch_script
 
+from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
 
-class TestPytorch(unittest.TestCase):
+
+class TestPytorch(unittest.TestCase, ConvertMixin):
     @staticmethod
     def _match_state_dict(current_values, expected_format):
         result = {}
@@ -32,8 +33,7 @@ def _match_state_dict(current_values, expected_format):
         return result
 
     def run_test(self, model, input, check_value=True):
-        script_module = torch.jit.script(model)
-        model_ir = convert_to_graph(script_module, model)
+        model_ir = self._convert_model(model, input)
        model_code = model_to_pytorch_script(model_ir)
         print(model_code)
 
@@ -1230,4 +1230,7 @@ def forward(self, input):
             return torch.arange(input.size(0)), torch.arange(input.size(-1)), torch.ones(input.shape)
 
         x = torch.randn(5, 3, 2)
-        self.run_test(SizeModel(10, 5), (x, ))
+        self.run_test(SizeModel(10, 5), (x, ))
+
+class TestPytorchWithShape(TestPytorch, ConvertWithShapeMixin):
+    pass
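
The pass-bodied WithShape classes work because unittest collects test_* methods through inheritance: an empty subclass re-runs every inherited test, under whichever _convert_model its MRO selects (see the resolution sketch after convert_mixin.py above). A generic illustration, with placeholder names:

import unittest

class BaseTests(unittest.TestCase):
    def test_something(self):
        self.assertTrue(True)

class VariantTests(BaseTests):   # empty body, still re-runs test_something
    pass

if __name__ == '__main__':
    # Runs two tests: BaseTests.test_something and VariantTests.test_something
    unittest.main(verbosity=2)

This is how the commit doubles the coverage of the existing converter test files with a few three-line class declarations.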