diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index e814f87083bf..c609d556aca6 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -195,9 +195,10 @@ class ExprMutator(ExprFunctor): and reconstructs the AST. """ def visit_function(self, fn): + new_params = [self.visit(x) for x in fn.params] new_body = self.visit(fn.body) return Function( - list(fn.params), + list(new_params), new_body, fn.ret_type, fn.type_params, @@ -214,8 +215,8 @@ def visit_call(self, call): new_args = [self.visit(arg) for arg in call.args] return Call(new_fn, new_args, call.attrs) - def visit_var(self, rvar): - return rvar + def visit_var(self, var): + return var def visit_global_id(self, global_var): return global_var diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 3c53eb323c79..83837591ea41 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name,arguments-differ,no-else-return,unused-argument,missing-docstring """ Relay pass transformation infrastructure. """ @@ -22,7 +22,9 @@ import inspect import functools +import tvm from tvm._ffi.runtime_ctypes import TVMContext +from tvm import relay from . import _transform from .base import RelayNode, register_relay_node from .. import nd as _nd @@ -908,3 +910,40 @@ def create_function_pass(pass_arg): if pass_func: return create_function_pass(pass_func) return create_function_pass + +@function_pass(opt_level=1) +class ChangeBatch: + """ + Change the batch size. + + Parameters + ---------- + data: Dict[relay.Var, int] + A dictionary of all the params to change. + The keys are all params, and the values is which dimension hold the batch. + + batch_size: int + The batch size to change to. + + Returns + ------- + pass: FunctionPass + The pass. + """ + def __init__(self, data, batch_size=16): + self.data = data + self.batch_size = batch_size + + def transform_function(self, func, mod, ctx): + func = relay.Function(func.params, func.body, None, func.type_params, func.attrs) + change_batch = self + class ChangeBatchMutator(tvm.relay.ExprMutator): + def visit_var(self, var): + if var in change_batch.data: + ty = var.type_annotation + new_shape = list(ty.shape) + new_shape[change_batch.data[var]] = change_batch.batch_size + return relay.Var(var.name_hint, relay.TensorType(new_shape, ty.dtype)) + else: + return var + return ChangeBatchMutator().visit(func) diff --git a/tests/python/relay/test_change_batch.py b/tests/python/relay/test_change_batch.py new file mode 100644 index 000000000000..e822bbb05910 --- /dev/null +++ b/tests/python/relay/test_change_batch.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay +from tvm.relay.testing import resnet +from tvm.relay import transform + +def test_change_batch_resnet(): + net, params = resnet.get_workload() + new_net = transform.ChangeBatch({net["main"].params[0]: 0}, batch_size=123)(net) + assert new_net["main"].checked_type.ret_type == relay.TensorType((123, 1000)) + +if __name__ == "__main__": + test_change_batch_resnet()