Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TVM Vertical Integration with PyTorch #11911

Merged
merged 28 commits into from
Jul 26, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions apps/pt_tvmdsoop/tests/test_as_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test script for tvm torch module"""
import numpy as np

import torch
import torch.nn

import tvm
import tvm.testing
from tvm.contrib.torch import as_torch
from tvm.script import tir as T


@as_torch
def matmul(M: int, N: int, K: int, dtype: str):
@T.prim_func
def main(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [M, K], dtype=dtype)
B = T.match_buffer(b, [N, K], dtype=dtype)
C = T.match_buffer(c, [M, N], dtype=dtype)
for i, j, k in T.grid(M, N, K):
with T.block():
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

return main


@as_torch
@tvm.script.ir_module
class MyModule:
@T.prim_func
def main(a: T.handle, b: T.handle):
# We exchange data between function by handles, which are similar to pointer.
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# Create buffer from handles.
A = T.match_buffer(a, (8,), dtype="float32")
B = T.match_buffer(b, (8,), dtype="float32")
for i in range(8):
# A block is an abstraction for computation.
with T.block("B"):
# Define a spatial block iterator and bind it to value i.
vi = T.axis.spatial(8, i)
B[vi] = A[vi] + 1.0


@as_torch
@tvm.script.ir_module
class ModuleGPU:
@T.prim_func
def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i_0 in T.thread_binding(2, thread="blockIdx.x"):
for i_2 in T.thread_binding(2, thread="threadIdx.x"):
for i_1 in T.serial(2):
with T.block("B"):
vi = T.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2)
T.reads(A[vi])
T.writes(B[vi])
B[vi] = A[vi] + T.float32(1)


class MinuesOnes(torch.nn.Module):
def __init__(self):
super(MinuesOnes, self).__init__()
self.engine = MyModule

def forward(self, *input):
self.engine.forward(*input)
return input[-1] - 1


def test_tvmscript_torch_matmul():
s1 = np.random.rand(128, 128).astype("float32")
s2 = np.random.rand(128, 128).astype("float32")
s3 = np.random.rand(128, 128).astype("float32")

q1 = torch.from_numpy(s1)
q2 = torch.from_numpy(s2)
q3 = torch.from_numpy(s3)

numpy_result = np.matmul(s1, np.transpose(s2))

nn_module = matmul(128, 128, 128, "float32")

nn_module(q1, q2, q3)

tvm.testing.assert_allclose(q3.numpy(), numpy_result, atol=1e-5, rtol=1e-5)


def test_tvmscript_torch_decorator():
q1 = torch.arange(8).type(torch.float32)
q2 = torch.zeros((8,), dtype=torch.float32)

MyModule(q1, q2)

tvm.testing.assert_allclose(q2.numpy(), (q1 + 1).numpy(), atol=1e-5, rtol=1e-5)


def test_tvmscript_torch_gpu():
cuda0 = torch.device("cuda:0")
q1 = torch.arange(8, device=cuda0).type(torch.float32)
q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0)

ModuleGPU(q1, q2)

tvm.testing.assert_allclose(q2.cpu().numpy(), (q1 + 1).cpu().numpy(), atol=1e-5, rtol=1e-5)


def test_torch_with_tvmscript():
ref_result = np.arange(8).astype("float32")

q1 = torch.arange(8).type(torch.float32)
q2 = torch.zeros((8,), dtype=torch.float32)

nn_module = MinuesOnes()

ret = nn_module.forward(q1, q2)

tvm.testing.assert_allclose(ret.numpy(), ref_result, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
test_tvmscript_torch_matmul()
test_tvmscript_torch_decorator()
test_tvmscript_torch_gpu()
test_torch_with_tvmscript()
161 changes: 161 additions & 0 deletions apps/pt_tvmdsoop/tests/test_optimize_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# pylint: disable=missing-class-docstring
#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test script for tvm torch module"""
import tempfile

import torch
from torch.utils import benchmark
from torchvision.models import resnet18

import tvm
import tvm.testing
from tvm.contrib.torch import optimize_torch
from tvm.meta_schedule import TuneConfig


def test_matmul_tuning_relay():
def matmul(x, w):
return torch.matmul(x, w)

x = torch.randn(15, 20)
w = torch.randn(20, 30)
example_inputs = (x, w)

rt_mod = optimize_torch(matmul, example_inputs)
torch_answer = torch.matmul(x, w).numpy()
tvm_answer = rt_mod(x, w).numpy()

tvm.testing.assert_allclose(torch_answer, tvm_answer, atol=1e-5, rtol=1e-5)


class InnerModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(1, 20, 5)

def forward(self, x):
return torch.nn.functional.relu(self.conv(x))


class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(20, 20, 5)
self.relu = InnerModel()

def forward(self, x):
x = self.relu(x)
return torch.nn.functional.relu(self.conv(x))


def test_nested_module():
simple_module = SimpleModel()
example_input = torch.randn(20, 1, 10, 10)
optimized_module = optimize_torch(simple_module, example_input)
ret1 = simple_module(example_input).detach().numpy()
ret2 = optimized_module(example_input).detach().numpy()
tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5)


def test_save_load_function():
def foo(x):
return 2 * x + 1

example_input = torch.rand(3)
opt_foo = optimize_torch(foo, example_input)
ret1 = opt_foo(example_input)
with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
torch.save(opt_foo, tmp.name)
loaded_mod = torch.load(tmp.name)
ret2 = loaded_mod(example_input)
tvm.testing.assert_allclose(ret1.numpy(), ret2.numpy(), atol=1e-5, rtol=1e-5)


class MyResNet18(torch.nn.Module):
def __init__(self, config, target=None):
super(MyResNet18, self).__init__()
self.means = torch.nn.Parameter(
torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1)
).cuda()
self.resnet = optimize_torch(resnet18(), [torch.rand(1, 3, 224, 224)], config, target)

def forward(self, input):
return self.resnet(input - self.means)


class JitModule(torch.nn.Module):
def __init__(self):
super(JitModule, self).__init__()
self.means = torch.nn.Parameter(
torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1)
).cuda()
self.resnet = torch.jit.optimize_for_inference(torch.jit.script(resnet18().cuda().eval()))

def forward(self, input):
return self.resnet(input - self.means)


# default config for testing
config = TuneConfig(
strategy="evolutionary",
num_trials_per_iter=4,
max_trials_per_task=8,
max_trials_global=16,
)

if torch.cuda.is_available():
target_cuda = "nvidia/geforce-rtx-3070"
meta_module_resnet18 = MyResNet18(config, target_cuda)
jit_module_resnet18 = JitModule()


def compare_optimize_resnet18_to_torchscript():
results = []
for i in range(20):
test_input = torch.rand(1, 3, 224, 224).half().cuda()
sub_label = f"[test {i}]"
results.append(
benchmark.Timer(
stmt="meta_module_resnet18(test_input)",
setup="from __main__ import meta_module_resnet18",
globals={"test_input": test_input},
sub_label=sub_label,
description="tuning by meta",
).blocked_autorange()
)
results.append(
benchmark.Timer(
stmt="jit_module_resnet18(test_input)",
setup="from __main__ import jit_module_resnet18",
globals={"test_input": test_input},
sub_label=sub_label,
description="tuning by jit",
).blocked_autorange()
)
compare = benchmark.Compare(results)
compare.print()


if __name__ == "__main__":
test_matmul_tuning_relay()
test_nested_module()
test_save_load_function()
if torch.cuda.is_available():
compare_optimize_resnet18_to_torchscript()
12 changes: 11 additions & 1 deletion python/tvm/contrib/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import platform
import torch
from tvm._ffi import libinfo
from tvm.relay.frontend import pytorch


def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
Expand All @@ -39,6 +38,7 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):

_load_platform_specific_library()


from . import module

GraphModule = module.GraphModule
Expand All @@ -49,3 +49,13 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):

PyTorchTVMModule = pytorch_tvm.PyTorchTVMModule
compile = pytorch_tvm.compile

from . import as_torch

TVMScriptIRModule = as_torch.OperatorModuleWrapper
as_torch = as_torch.as_torch

from . import optimize_torch

GraphExecutorFactoryWrapper = optimize_torch.GraphExecutorFactoryWrapper
optimize_torch = optimize_torch.optimize_torch
Loading