Skip to content

Commit

Permalink
onnx model saving was implemented through io.BytesIO. creating/removi…
Browse files Browse the repository at this point in the history
…ng tmp dir was removed. remove unneccessary comments
  • Loading branch information
vvchernov committed Jul 19, 2021
1 parent 67d380e commit a944157
Showing 1 changed file with 7 additions and 22 deletions.
29 changes: 7 additions & 22 deletions tests/python/frontend/pytorch/test_lstms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@
import numpy as np
import torch
import onnx
import io
import sys
import shutil
import pytest

from tvm import relay
from tvm.contrib import graph_executor

from pathlib import Path
from torch import nn

## Model parameters
Expand Down Expand Up @@ -74,7 +73,6 @@ def __init__(
self.batch_first = batch_first
self.use_bias = use_bias

# Network defition
if check_torch_version_for_proj_in_lstm():
self.lstm = nn.LSTM(
input_size=model_feature_size,
Expand Down Expand Up @@ -173,18 +171,9 @@ def get_dummy_input(self):

def compare(input, gold_data, rtol=1e-5, atol=1e-5):
tvm.testing.assert_allclose(input, gold_data, rtol=rtol, atol=atol)
# remain = np.abs(gold_data - input)
# err = np.max(remain)
# if err < 1e-6:
# print("SUCCESS: RESULTS ARE THE SAME WITH MAX ERROR {} AND EPSILON {}".format(err, 1e-6))
# else:
# print("WARNING: RESULTS ARE NOT THE SAME WITH ERROR {}".format(err))


def check_lstm_with_type(lstm_type):
# Create outdir directory to keep temporal files
out_dir = Path.cwd().joinpath("output")
out_dir.mkdir(exist_ok=True, parents=True)
has_proj = "p" in lstm_type

device = torch.device("cpu")
Expand Down Expand Up @@ -296,7 +285,7 @@ def check_lstm_with_type(lstm_type):
mod, params = relay.frontend.from_pytorch(traced_script_module, shape_list)

# Model compilation by tvm
target = tvm.target.Target("llvm", host="llvm")
target = tvm.target.Target("llvm -mcpu=core-avx2")
dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params=params)
Expand All @@ -307,8 +296,7 @@ def check_lstm_with_type(lstm_type):
"from pytorch! TODO: waiting for the support and correct test after that."
)
continue
onnx_fpath = out_dir.joinpath("model_{}.onnx".format(lstm_type))

onnx_io = io.BytesIO()
with torch.no_grad():
h0 = torch.rand(input_hidden_shape)
if has_proj:
Expand All @@ -318,10 +306,10 @@ def check_lstm_with_type(lstm_type):

# default export (without dynamic input)
torch.onnx.export(
model, (dummy_input, (h0, c0)), onnx_fpath, input_names=input_names
model, (dummy_input, (h0, c0)), onnx_io, input_names=input_names
)

onnx_model = onnx.load(onnx_fpath)
onnx_io.seek(0,0)
onnx_model = onnx.load_model(onnx_io)

# Import model to Relay
shape_dict = {
Expand All @@ -338,7 +326,7 @@ def check_lstm_with_type(lstm_type):
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

# Model compilation by tvm
target = "llvm"
target = tvm.target.Target("llvm -mcpu=core-avx2")
dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=1):
lib = relay.build(mod, target=target, params=params)
Expand All @@ -359,9 +347,6 @@ def check_lstm_with_type(lstm_type):

compare(tvm_output, golden_output_batch)

# Remove output directory with tmp files
shutil.rmtree(out_dir)


def test_lstms():
check_lstm_with_type("uni")
Expand Down

0 comments on commit a944157

Please sign in to comment.