Skip to content

Commit

Permalink
[Relay][VTA] Add ChangeBatch pass (apache#3656)
Browse files Browse the repository at this point in the history
* init

* lint

* lint
  • Loading branch information
MarisaKirisame authored and wweic committed Sep 6, 2019
1 parent b579791 commit 66aad8d
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 4 deletions.
7 changes: 4 additions & 3 deletions python/tvm/relay/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,10 @@ class ExprMutator(ExprFunctor):
and reconstructs the AST.
"""
def visit_function(self, fn):
new_params = [self.visit(x) for x in fn.params]
new_body = self.visit(fn.body)
return Function(
list(fn.params),
list(new_params),
new_body,
fn.ret_type,
fn.type_params,
Expand All @@ -214,8 +215,8 @@ def visit_call(self, call):
new_args = [self.visit(arg) for arg in call.args]
return Call(new_fn, new_args, call.attrs)

def visit_var(self, rvar):
return rvar
def visit_var(self, var):
return var

def visit_global_id(self, global_var):
return global_var
Expand Down
41 changes: 40 additions & 1 deletion python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=invalid-name,arguments-differ,no-else-return,unused-argument,missing-docstring
"""
Relay pass transformation infrastructure.
"""
import types
import inspect
import functools

import tvm
from tvm._ffi.runtime_ctypes import TVMContext
from tvm import relay
from . import _transform
from .base import RelayNode, register_relay_node
from .. import nd as _nd
Expand Down Expand Up @@ -908,3 +910,40 @@ def create_function_pass(pass_arg):
if pass_func:
return create_function_pass(pass_func)
return create_function_pass

@function_pass(opt_level=1)
class ChangeBatch:
"""
Change the batch size.
Parameters
----------
data: Dict[relay.Var, int]
A dictionary of all the params to change.
The keys are all params, and the values is which dimension hold the batch.
batch_size: int
The batch size to change to.
Returns
-------
pass: FunctionPass
The pass.
"""
def __init__(self, data, batch_size=16):
self.data = data
self.batch_size = batch_size

def transform_function(self, func, mod, ctx):
func = relay.Function(func.params, func.body, None, func.type_params, func.attrs)
change_batch = self
class ChangeBatchMutator(tvm.relay.ExprMutator):
def visit_var(self, var):
if var in change_batch.data:
ty = var.type_annotation
new_shape = list(ty.shape)
new_shape[change_batch.data[var]] = change_batch.batch_size
return relay.Var(var.name_hint, relay.TensorType(new_shape, ty.dtype))
else:
return var
return ChangeBatchMutator().visit(func)
28 changes: 28 additions & 0 deletions tests/python/relay/test_change_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay.testing import resnet
from tvm.relay import transform

def test_change_batch_resnet():
net, params = resnet.get_workload()
new_net = transform.ChangeBatch({net["main"].params[0]: 0}, batch_size=123)(net)
assert new_net["main"].checked_type.ret_type == relay.TensorType((123, 1000))

if __name__ == "__main__":
test_change_batch_resnet()

0 comments on commit 66aad8d

Please sign in to comment.