Skip to content

Commit

Permalink
Add TVMC Frontend for PaddlePaddle (apache#9083)
Browse files Browse the repository at this point in the history
* fix some problems for matmul

* fix some problems for matmul

* add alpha parameter for matmul

* remove unnecessary condition

* add TranslatedLayer which support model loaded by jit.load

* add mul operator support

* Add padding mode support for conv/pool2d

* support 4 two-tuples

* add paddle test case

* add paddle conv2d  case

* update test_forward.py

* fix paddle convert_matmul

* add paddle multiply and matmul op test case

* add test case and fix bug

* delete import pandas

* add paddlepaddle tests

* modify the variable name of convert_reshape

* formatting

* formatting

* use black to format python code

* pylint check

* Remove fluid api

* black format

* Add Paddle Frontend for TVMC

* refine code format

* add test case for tvmc

* fix pylint check

* gen_requirements add paddlepaddle

* Trigger CI

Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com>
Co-authored-by: wjj19950828 <wjjisloser@163.com>
Co-authored-by: heliqi <1101791222@qq.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
  • Loading branch information
5 people authored and ylc committed Jan 7, 2022
1 parent 647ec47 commit f5f34f6
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/gen_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
"future", # Hidden dependency of torch.
"onnx",
"onnxruntime",
"paddlepaddle",
"tensorflow",
"tflite",
"torch",
Expand Down
26 changes: 26 additions & 0 deletions python/tvm/driver/tvmc/frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,12 +265,38 @@ def load(self, path, shape_dict=None, **kwargs):
return relay.frontend.from_pytorch(traced_model, input_shapes, **kwargs)


class PaddleFrontend(Frontend):
"""PaddlePaddle frontend for TVMC"""

@staticmethod
def name():
return "paddle"

@staticmethod
def suffixes():
return ["pdmodel", "pdiparams"]

def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import paddle

paddle.enable_static()
paddle.disable_signal_handler()

# pylint: disable=E1101
exe = paddle.static.Executor(paddle.CPUPlace())
prog, _, _ = paddle.static.load_inference_model(path, exe)

return relay.frontend.from_paddle(prog, shape_dict=shape_dict, **kwargs)


ALL_FRONTENDS = [
KerasFrontend,
OnnxFrontend,
TensorflowFrontend,
TFLiteFrontend,
PyTorchFrontend,
PaddleFrontend,
]


Expand Down
14 changes: 13 additions & 1 deletion tests/python/driver/tvmc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def download_and_untar(model_url, model_sub_path, temp_dir):
model_tar_name = os.path.basename(model_url)
model_path = download_testdata(model_url, model_tar_name, module=["tvmc"])

if model_path.endswith("tgz") or model_path.endswith("gz"):
if model_path.endswith("tgz") or model_path.endswith("gz") or model_path.endswith("tar"):
tar = tarfile.open(model_path)
tar.extractall(path=temp_dir)
tar.close()
Expand Down Expand Up @@ -137,6 +137,18 @@ def onnx_resnet50():
return model_file


@pytest.fixture(scope="session")
def paddle_resnet50(tmpdir_factory):
base_url = "https://bj.bcebos.com/x2paddle/models"
model_url = "paddle_resnet50.tar"
model_file = download_and_untar(
"{}/{}".format(base_url, model_url),
"paddle_resnet50/model",
temp_dir=tmpdir_factory.mktemp("data"),
)
return model_file


@pytest.fixture(scope="session")
def onnx_mnist():
base_url = "https://github.com/onnx/models/raw/master/vision/classification/mnist/model"
Expand Down
78 changes: 78 additions & 0 deletions tests/python/driver/tvmc/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,84 @@ def test_cross_compile_options_aarch64_onnx_module(onnx_resnet50):
assert os.path.exists(dumps_path)


def verify_compile_paddle_module(model, shape_dict=None):
pytest.importorskip("paddle")
tvmc_model = tvmc.load(model, "paddle", shape_dict=shape_dict)
tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW")
dumps_path = tvmc_package.package_path + ".ll"

# check for output types
assert type(tvmc_package) is TVMCPackage
assert type(tvmc_package.graph) is str
assert type(tvmc_package.lib_path) is str
assert type(tvmc_package.params) is bytearray
assert os.path.exists(dumps_path)


def test_compile_paddle_module(paddle_resnet50):
# some CI environments wont offer Paddle, so skip in case it is not present
pytest.importorskip("paddle")
# Check default compilation.
verify_compile_paddle_module(paddle_resnet50)
# Check with manual shape override
shape_string = "inputs:[1,3,224,224]"
shape_dict = tvmc.common.parse_shape_string(shape_string)
verify_compile_paddle_module(paddle_resnet50, shape_dict)


# This test will be skipped if the AArch64 cross-compilation toolchain is not installed.
@pytest.mark.skipif(
not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed"
)
def test_cross_compile_aarch64_paddle_module(paddle_resnet50):
# some CI environments wont offer paddle, so skip in case it is not present
pytest.importorskip("paddle")

tvmc_model = tvmc.load(paddle_resnet50, "paddle")
tvmc_package = tvmc.compile(
tvmc_model,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
dump_code="asm",
cross="aarch64-linux-gnu-gcc",
)
dumps_path = tvmc_package.package_path + ".asm"

# check for output types
assert type(tvmc_package) is TVMCPackage
assert type(tvmc_package.graph) is str
assert type(tvmc_package.lib_path) is str
assert type(tvmc_package.params) is bytearray
assert os.path.exists(dumps_path)


# This test will be skipped if the AArch64 cross-compilation toolchain is not installed.
@pytest.mark.skipif(
not shutil.which("aarch64-linux-gnu-gcc"), reason="cross-compilation toolchain not installed"
)
def test_cross_compile_options_aarch64_paddle_module(paddle_resnet50):
# some CI environments wont offer paddle, so skip in case it is not present
pytest.importorskip("paddle")

fake_sysroot_dir = utils.tempdir().relpath("")

tvmc_model = tvmc.load(paddle_resnet50, "paddle")
tvmc_package = tvmc.compile(
tvmc_model,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
dump_code="asm",
cross="aarch64-linux-gnu-gcc",
cross_options="--sysroot=" + fake_sysroot_dir,
)
dumps_path = tvmc_package.package_path + ".asm"

# check for output types
assert type(tvmc_package) is TVMCPackage
assert type(tvmc_package.graph) is str
assert type(tvmc_package.lib_path) is str
assert type(tvmc_package.params) is bytearray
assert os.path.exists(dumps_path)


@tvm.testing.requires_opencl
def test_compile_opencl(tflite_mobilenet_v1_0_25_128):
pytest.importorskip("tflite")
Expand Down
18 changes: 18 additions & 0 deletions tests/python/driver/tvmc/test_frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ def test_guess_frontend_tensorflow():
assert type(sut) is tvmc.frontends.TensorflowFrontend


def test_guess_frontend_paddle():
# some CI environments wont offer Paddle, so skip in case it is not present
pytest.importorskip("paddle")

sut = tvmc.frontends.guess_frontend("a_model.pdmodel")
assert type(sut) is tvmc.frontends.PaddleFrontend


def test_guess_frontend_invalid():
with pytest.raises(TVMCException):
tvmc.frontends.guess_frontend("not/a/file.txt")
Expand Down Expand Up @@ -161,6 +169,16 @@ def test_load_model__pb(pb_mobilenet_v1_1_quant):
assert "MobilenetV1/Conv2d_0/weights" in tvmc_model.params.keys()


def test_load_model__paddle(paddle_resnet50):
# some CI environments wont offer Paddle, so skip in case it is not present
pytest.importorskip("paddle")

tvmc_model = tvmc.load(paddle_resnet50, model_format="paddle")
assert type(tvmc_model) is TVMCModel
assert type(tvmc_model.mod) is IRModule
assert type(tvmc_model.params) is dict


def test_load_model___wrong_language__to_keras(tflite_mobilenet_v1_1_quant):
# some CI environments wont offer TensorFlow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")
Expand Down
25 changes: 25 additions & 0 deletions tests/python/driver/tvmc/test_tvmc_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,31 @@ def _is_layout_transform(node):
assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found"


def test_compile_paddle_module_nchw_to_nhwc(paddle_resnet50):
# some CI environments wont offer Paddle, so skip in case it is not present
pytest.importorskip("paddle")

tvmc_model = tvmc.frontends.load_model(paddle_resnet50, "paddle")
before = tvmc_model.mod

expected_layout = "NHWC"
after = tvmc.common.convert_graph_layout(before, expected_layout)

layout_transform_calls = []

def _is_layout_transform(node):
if isinstance(node, tvm.relay.expr.Call):
layout_transform_calls.append(
node.op.name == "layout_transform"
and node.attrs.src_layout == "NCHW"
and node.attrs.dst_layout == "NHWC"
)

tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found"


def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_quant):
# some CI environments wont offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")
Expand Down

0 comments on commit f5f34f6

Please sign in to comment.