From 6bf1991c0a3d5d1d11bb3b0ad27c9f5dfb69273b Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 29 Jul 2019 15:39:59 -0700 Subject: [PATCH 1/3] init --- python/tvm/relay/expr_functor.py | 3 +- python/tvm/relay/transform.py | 39 +++++++++++++++++++++++++ tests/python/relay/test_change_batch.py | 12 ++++++++ 3 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 tests/python/relay/test_change_batch.py diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index e814f87083bf..6f8b19aab05b 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -195,9 +195,10 @@ class ExprMutator(ExprFunctor): and reconstructs the AST. """ def visit_function(self, fn): + new_params = [self.visit(x) for x in fn.params] new_body = self.visit(fn.body) return Function( - list(fn.params), + list(new_params), new_body, fn.ret_type, fn.type_params, diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 3c53eb323c79..f80fa79ff8ed 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -22,7 +22,9 @@ import inspect import functools +import tvm from tvm._ffi.runtime_ctypes import TVMContext +from tvm import relay from . import _transform from .base import RelayNode, register_relay_node from .. import nd as _nd @@ -908,3 +910,40 @@ def create_function_pass(pass_arg): if pass_func: return create_function_pass(pass_func) return create_function_pass + +@function_pass(opt_level=1) +class ChangeBatch: + def __init__(self, data, batch_size=16): + """ + Change the batch size. + + Parameters + ---------- + data: Dict[relay.Var, int] + A dictionary of all the params to change. + The keys are all params, and the values is which dimension hold the batch. + + batch_size: int + The batch size to change to. + + Returns + ------- + pass: FunctionPass + The pass. + """ + self.data = data + self.batch_size = batch_size + + def transform_function(self, func, mod, ctx): + func = relay.Function(func.params, func.body, None, func.type_params, func.attrs) + change_batch = self + class ChangeBatchMutator(tvm.relay.ExprMutator): + def visit_var(self, var): + if var in change_batch.data: + ty = var.type_annotation + new_shape = list(ty.shape) + new_shape[change_batch.data[var]] = change_batch.batch_size + return relay.Var(var.name_hint, relay.TensorType(new_shape, ty.dtype)) + else: + return var + return ChangeBatchMutator().visit(func) diff --git a/tests/python/relay/test_change_batch.py b/tests/python/relay/test_change_batch.py new file mode 100644 index 000000000000..5cb32f1d4f64 --- /dev/null +++ b/tests/python/relay/test_change_batch.py @@ -0,0 +1,12 @@ +import tvm +from tvm import relay +from tvm.relay.testing import resnet +from tvm.relay import transform + +def test_change_batch_resnet(): + net, params = resnet.get_workload() + new_net = transform.ChangeBatch({net["main"].params[0]: 0}, batch_size=123)(net) + assert new_net["main"].checked_type.ret_type == relay.TensorType((123, 1000)) + +if __name__ == "__main__": + test_change_batch_resnet() From f6f1c961a31e535e46e5018cd5b9db4463dbec75 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 29 Jul 2019 16:42:31 -0700 Subject: [PATCH 2/3] lint --- tests/python/relay/test_change_batch.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/python/relay/test_change_batch.py b/tests/python/relay/test_change_batch.py index 5cb32f1d4f64..e822bbb05910 100644 --- a/tests/python/relay/test_change_batch.py +++ b/tests/python/relay/test_change_batch.py @@ -1,3 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. import tvm from tvm import relay from tvm.relay.testing import resnet From 3b083ec230923b836825b7909a82d96fb56d98d2 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 29 Jul 2019 16:53:42 -0700 Subject: [PATCH 3/3] lint --- python/tvm/relay/expr_functor.py | 4 ++-- python/tvm/relay/transform.py | 32 ++++++++++++++++---------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index 6f8b19aab05b..c609d556aca6 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -215,8 +215,8 @@ def visit_call(self, call): new_args = [self.visit(arg) for arg in call.args] return Call(new_fn, new_args, call.attrs) - def visit_var(self, rvar): - return rvar + def visit_var(self, var): + return var def visit_global_id(self, global_var): return global_var diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index f80fa79ff8ed..83837591ea41 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name,arguments-differ,no-else-return,unused-argument,missing-docstring """ Relay pass transformation infrastructure. """ @@ -913,24 +913,24 @@ def create_function_pass(pass_arg): @function_pass(opt_level=1) class ChangeBatch: - def __init__(self, data, batch_size=16): - """ - Change the batch size. + """ + Change the batch size. - Parameters - ---------- - data: Dict[relay.Var, int] - A dictionary of all the params to change. - The keys are all params, and the values is which dimension hold the batch. + Parameters + ---------- + data: Dict[relay.Var, int] + A dictionary of all the params to change. + The keys are all params, and the values is which dimension hold the batch. - batch_size: int - The batch size to change to. + batch_size: int + The batch size to change to. - Returns - ------- - pass: FunctionPass - The pass. - """ + Returns + ------- + pass: FunctionPass + The pass. + """ + def __init__(self, data, batch_size=16): self.data = data self.batch_size = batch_size