Skip to content

Commit

Permalink
[Relax] Add CopyWithNewParams interface (#91)
Browse files Browse the repository at this point in the history
* copy_func pass

* reformat

* move interface into utils

* simplify test

* Change name to copy_with_new_params

* remove zero-param constructor

* change visitor into Transform

* formatted

* finish revise

* revise
  • Loading branch information
Ubospica authored and MasterJH5574 committed Jan 28, 2023
1 parent a949b56 commit f1bd8a9
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 1 deletion.
9 changes: 9 additions & 0 deletions include/tvm/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true,
*/
TVM_DLL bool IsLeafOrTuple(const Expr& expr);

/*!
* \brief Copy the given function. The parameters of the original function would be copied to
* satisfy the restriction in the well-formed check: any two functions cannot share the same
* parameter variable.
* \param func The relax function to copy.
* \return The copied function.
*/
TVM_DLL Function CopyWithNewParams(Function func);

} // namespace relax
} // namespace tvm

Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import transform
from . import expr_functor
from . import struct_info
from . import utils

# Expr

Expand Down
21 changes: 20 additions & 1 deletion python/tvm/relax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from .. import tir
from ..runtime import String, convert_to_object
from ..tir import PrimExpr
from .expr import Expr, PrimValue, ShapeExpr, StringImm
from . import _ffi_api
from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm
from .expr import Tuple as rx_Tuple


Expand Down Expand Up @@ -254,3 +255,21 @@ def auto(func: FType) -> FType:


args_converter = _ArgsConverter() # pylint: disable=invalid-name


def copy_with_new_params(func: Function) -> Function:
"""Copy the given function. The parameters of the original function would be copied to
satisfy the restriction in the well-formed check: any two functions cannot share the same
parameter variable.
Parameters
----------
func : Function
The relax function to copy.
Returns
-------
ret : Function
The copied function.
"""
return _ffi_api.CopyWithNewParams(func) # type: ignore
23 changes: 23 additions & 0 deletions src/relax/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,28 @@ bool IsLeafOrTuple(const Expr& expr) {
expr.as<OpNode>() || expr.as<TupleNode>();
}

class FunctionCopier : public ExprMutator {
public:
static Function Transform(Function func) {
FunctionCopier copier;
// the parameters would be copied and substituted to satisfy the restriction in the well-formed
// check: any two functions cannot share the same parameter variable.
Array<Var> new_params;
for (Var param : func->params) {
Var new_param = Var(param->vid, GetStructInfo(param), param->span);
copier.var_remap_[param->vid] = new_param;
new_params.push_back(new_param);
}

Expr body = copier.VisitWithNewScope(func->body, new_params);

return Function(new_params, body, func->ret_struct_info, func->attrs);
}
};

Function CopyWithNewParams(Function func) { return FunctionCopier::Transform(func); }

TVM_REGISTER_GLOBAL("relax.CopyWithNewParams").set_body_typed(CopyWithNewParams);

} // namespace relax
} // namespace tvm
38 changes: 38 additions & 0 deletions tests/python/relax/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from tvm import relax
from tvm.ir.base import assert_structural_equal
from tvm.script.parser import relax as R


def test_copy_with_new_params():
@R.function
def before(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
gv = R.add(x, y)
return gv

after = relax.utils.copy_with_new_params(before)
assert_structural_equal(after, before)

assert len(after.params) == len(before.params)
for before_var, after_var in zip(before.params, after.params):
assert before_var != after_var


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit f1bd8a9

Please sign in to comment.