Skip to content

Commit

Permalink
cover symbolic_trace by tracing several vision models in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jzhang533 committed Mar 8, 2023
1 parent c9f9a09 commit 4813430
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/paddlefx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def create_node(self, op, target=None, args=None, kwargs=None, name=None):
return n

def output(self, result):
return self.create_node(op='output', target='output', args=(result,))
return self.create_node(op='output', target='output', args=result)

def _name(self, op):
if hasattr(op, '__name__'):
Expand Down
2 changes: 1 addition & 1 deletion src/paddlefx/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def scope(method):
def impl(*args, **kwargs):
tracer = args[0].tracer
target = getattr(operator, method)
return _create_proxy(tracer, 'call_function', target, args, kwargs, method)
return _create_proxy(tracer, 'call_function', target, args, kwargs)

impl.__name__ = method
as_magic = f'__{method}__'
Expand Down
2 changes: 1 addition & 1 deletion src/paddlefx/symbolic_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def wrapped(*args, **kwargs):
proxy = _find_proxy(args, kwargs)
if proxy is not None:
return_proxy = _create_proxy(
proxy.tracer, 'call_function', orig_fn, args, kwargs, orig_fn.__name__
proxy.tracer, 'call_function', orig_fn, args, kwargs
)
return return_proxy
return orig_fn(*args, **kwargs)
Expand Down
54 changes: 54 additions & 0 deletions tests/test_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import unittest

import paddle

import paddlefx


class TestFx(unittest.TestCase):
def setUp(self):
super().setUp()
self.models_to_track = [
(paddle.vision.models.resnet18(), paddle.randn([2, 3, 224, 224])),
(paddle.vision.models.alexnet(), paddle.randn([2, 3, 224, 224])),
# DenseNet will failed on symbolic_trace, since it calls into _C_ops
# (paddle.vision.models.densenet121(), paddle.randn([2, 3, 224, 224])),
(paddle.vision.models.googlenet(), paddle.randn([2, 3, 224, 224])),
(paddle.vision.models.inception_v3(), paddle.randn([2, 3, 299, 299])),
(paddle.vision.models.mobilenet_v2(), paddle.randn([2, 3, 224, 224])),
]

def tearDown(self):
super().tearDown()

def test_trace(self):
for model, input_example in self.models_to_track:
traced_model = paddlefx.symbolic_trace(model)
paddle.seed(1234)
orig_output = model(input_example)
paddle.seed(1234)
traced_output = traced_model(input_example)

# some nets, e.g.: googlenet, return a list of tensors
orig_ret_list = (
list(orig_output)
if isinstance(orig_output, (list, tuple))
else [orig_output]
)
traced_ret_list = (
[*traced_output]
if isinstance(traced_output, (list, tuple))
else [traced_output]
)

self.assertEqual(
len(orig_ret_list),
len(traced_ret_list),
f"model: {type(model).__name__} failed",
)

for i, o in enumerate(traced_ret_list):
self.assertTrue(
paddle.allclose(orig_ret_list[i], traced_ret_list[i]),
f"model: {type(model).__name__} failed",
)

0 comments on commit 4813430

Please sign in to comment.