From 67d0b502eb18c1343247820d49095db961fa1683 Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Wed, 25 Jan 2023 18:50:22 +0800 Subject: [PATCH 01/17] init --- include/tvm/relax/utils.h | 18 ++++ python/tvm/relax/utils.py | 51 ++++++++++++ src/relax/utils.cc | 137 +++++++++++++++++++++++++++++++ tests/python/relax/test_utils.py | 118 +++++++++++++++++++++++++- 4 files changed, 322 insertions(+), 2 deletions(-) diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index 1457a16427..1826491d8d 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -149,6 +149,24 @@ TVM_DLL bool IsLeafOrTuple(const Expr& expr); */ TVM_DLL Function CopyWithNewParams(Function func); +/*! + * \brief Extend a relax function by another given function. It will link orig_func with + * ex_func and return a new function. + * + * In detail, the result function has the arguments list of orig_func and the combination + * of their body, which passes the return values of orig_func as the arguments of ex_func. For + * those arguments of ex_func which are not mapped to some return values, they will be lifted and + * appended to the argument list of result function. + * + * This util can be replaced if we have Inline pass. It is equivalent to inline a tail call in some + * sense. + * + * \param orig_func The function to be extended. + * \param ex_func The function to be linked after the orig_func. + * \return The result function after extending. + */ +TVM_DLL Function ExtendFunc(Function orig_func, Function ex_func); + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index e1d5bf50c1..4598ec4c0f 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -273,3 +273,54 @@ def copy_with_new_params(func: Function) -> Function: The copied function. """ return _ffi_api.CopyWithNewParams(func) # type: ignore + + +def extend_func(orig_func: Function, ex_func: Function) -> Function: + """Extend a relax function by another given function. It will link orig_func with + ex_func and return a new function. + + In detail, the result function has the arguments list of orig_func and the combination + of their body, which passes the return values of orig_func as the arguments of ex_func. For + those arguments of ex_func which are not mapped to some return values, they will be lifted and + appended to the argument list of result function. + + This util can be replaced if we have Inline pass. It is equivalent to inline a tail call in some + sense. + + Note: the return value of orig_func will be bound to DataflowVar. So it is a bad idea to use this + util if the params of ex_func present in its R.output. + + Example: + + .. code-block:: python + # Before. + @R.function + def func1(a, b): + return a + b, a * b + + @R.function + def func2(c, d, e): + return d, c, c + e + + # After. func1_func2 = extend_func(orig_func=func1, ex_func=func2). + @R.function + def func1_func2(a, b, e): + c = a + b + d = a * b + return d, c, c + e + + Parameters + ---------- + orig_func : Function + The function to be extended. + + ex_func : Function + The function to be linked after the orig_func. + + Returns + ------- + ret : Function + The result function. 
+ """ + + return _ffi_api.ExtendFunc(orig_func, ex_func) # type: ignore diff --git a/src/relax/utils.cc b/src/relax/utils.cc index a77e4342e2..0c812e8f95 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -105,5 +105,142 @@ Function CopyWithNewParams(Function func) { return FunctionCopier::Transform(fun TVM_REGISTER_GLOBAL("relax.CopyWithNewParams").set_body_typed(CopyWithNewParams); +/*! \brief Helper to implement extend function.*/ +class ExtendFuncMutator : public ExprMutator { + public: + explicit ExtendFuncMutator(const SeqExpr& ex_body) : ex_body_(ex_body) {} + + Expr VisitExpr_(const SeqExprNode* seq_expr) override { + // mutate only the last block. + Array blocks; + for (int i = 0; i < static_cast(seq_expr->blocks.size()); ++i) { + if (i < static_cast(seq_expr->blocks.size()) - 1) { + blocks.push_back(seq_expr->blocks[i]); + } else { + BindingBlock new_block = this->VisitBindingBlock(seq_expr->blocks[i]); + if (!new_block->bindings.empty()) { + blocks.push_back(new_block); + } + } + } + this->VisitExpr(seq_expr->body); + return SeqExpr(blocks, ex_body_->body); + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override { + builder_->BeginDataflowBlock(); + // emit original bindings. + for (const auto& binding : block->bindings) { + this->VisitBinding(binding); + } + + ICHECK(orig_rets_var_.size() == orig_rets.size()); + for (int i = 0; i < static_cast(orig_rets_var_.size()); ++i) { + if (orig_rets_var_[i].defined()) { + builder_->EmitNormalized(VarBinding(orig_rets_var_[i].value(), orig_rets[i])); + } + } + + // emit blocks for extend part. + for (BindingBlock block : ex_body_->blocks) { + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + } + + return builder_->EndBlock(); + } + + void VisitBinding_(const VarBindingNode* binding) override { + Var new_var = Downcast(this->VisitExpr(binding->var)); + Expr new_value = this->VisitExpr(binding->value); + builder_->EmitNormalized(VarBinding(new_var, new_value)); + } + + // remap orignal dataflow var + // TODO(chaofan): a better way to check whether new_ret_var should be dataflow + void RemapToDataflow(SeqExpr body) { + for (BindingBlock block : body->blocks) { + for (Binding binding : block->bindings) { + const auto* binding_node = binding.as(); + if (binding_node && !binding_node->var->IsInstance()) { + Var new_binding_var = DataflowVar( + binding_node->var->vid, GetStructInfo(binding_node->var), binding_node->var->span); + this->var_remap_[binding_node->var->vid] = new_binding_var; + } + } + } + } + + Array RemapExParams(const Array& ex_func_params, Array new_params) { + for (int i = 0; i < static_cast(ex_func_params.size()); ++i) { + Var ex_param = ex_func_params[i]; + if (i < static_cast(orig_rets.size())) { + // map return value to ex param + if (const auto* var_node = orig_rets[i].as()) { + ICHECK(orig_rets[i].as()); + orig_rets_var_.push_back(NullOpt); + this->var_remap_[ex_param->vid] = GetRef(var_node); + } else { + Var new_ret_var = + DataflowVar(/*name_hint=*/"ret_" + std::to_string(i), GetStructInfo(orig_rets[i])); + orig_rets_var_.push_back(new_ret_var); + this->var_remap_[ex_param->vid] = new_ret_var; + } + } else { + // append to the param list + Var new_ex_param = Var(ex_param->vid, GetStructInfo(ex_param), ex_param->span); + this->var_remap_[ex_param->vid] = new_ex_param; + new_params.push_back(new_ex_param); + } + } + return new_params; + } + + Array orig_rets; + + private: + SeqExpr ex_body_; + Array> orig_rets_var_; +}; + +/*! 
+ * \brief Extend a relax function by another given function. + * \param orig_func The function to be extended. + * \param ex_func The function to be linked after the orig_func. + * \return The result function after extending. + */ +Function ExtendFunc(Function orig_func, Function ex_func) { + CHECK(orig_func->body->IsInstance()) + << "the body of the original function is not SeqExpr."; + CHECK(ex_func->body->IsInstance()) << "the body of the ex function is not SeqExpr."; + + auto param_copied_func = CopyWithNewParams(orig_func); + auto seq_expr = Downcast(param_copied_func->body); + + ExtendFuncMutator mutator(Downcast(ex_func->body)); + mutator.RemapToDataflow(seq_expr); + // Get the orignal rets. If it is a Tuple, unpack it. + if (orig_func->ret_struct_info.as()) { + const auto* tuple_node = seq_expr->body.as(); + ICHECK(tuple_node != nullptr); + for (Expr field : tuple_node->fields) { + mutator.orig_rets.push_back(mutator.VisitExpr(field)); + } + } else { + mutator.orig_rets.push_back(mutator.VisitExpr(seq_expr->body)); + } + + CHECK(ex_func->params.size() >= mutator.orig_rets.size()) + << "The number of return values of original functions should be greater than the number of " + "parameters of ex function"; + + auto new_params = mutator.RemapExParams(ex_func->params, param_copied_func->params); + Expr new_body = mutator.VisitExpr(seq_expr); + return Function(new_params, new_body, ex_func->ret_struct_info, param_copied_func->attrs); +} + +TVM_REGISTER_GLOBAL("relax.ExtendFunc").set_body_typed(ExtendFunc); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index 1cf2b56fa9..c43b23a519 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-import pytest +import tvm.testing from tvm import relax from tvm.ir.base import assert_structural_equal from tvm.script.parser import relax as R @@ -34,5 +34,119 @@ def before(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): assert before_var != after_var +def test_extend_func_basic_extend(): + @R.function + def orig(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")): + with R.dataflow(): + gv0 = R.sum(x) + gv1 = R.sum(y) + R.output(gv0, gv1) + return gv0, gv1 + + @R.function + def ex(arg1: R.Tensor((), dtype="float32"), arg2: R.Tensor((), dtype="float32")): + R.func_attr({"global_symbol": "ex"}) + with R.dataflow(): + gv0 = R.add(arg1, arg2) + R.output(gv0) + return gv0 + + @R.function + def orig_ex( + x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=0): + # block 0 + with R.dataflow(): + gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) + gv1: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False) + gv01: R.Tensor((), dtype="float32") = R.add(gv0, gv1) + R.output(gv01) + return gv01 + + after = relax.utils.extend_func(orig, ex) + assert_structural_equal(after, orig_ex) + + +def test_extend_func_extra_params(): + @R.function + def orig(x: R.Tensor((3, 3), dtype="float32")): + with R.dataflow(): + gv0 = R.sum(x) + gv1 = R.add(x, x) + R.output(gv0, gv1) + return gv0, gv1 + + @R.function + def ex( + arg1: R.Tensor((), dtype="float32"), + arg2: R.Tensor((3, 3), dtype="float32"), + arg3: R.Tensor((3, 3), dtype="float32"), + ): + R.func_attr({"global_symbol": "ex"}) + with R.dataflow(): + gv0 = R.add(arg2, arg3) + R.output(gv0) + return gv0 + + @R.function + def orig_ex( + x: R.Tensor((3, 3), dtype="float32"), arg3: R.Tensor((3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=2): + # block 0 + with R.dataflow(): + gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) + gv1: R.Tensor((3, 3), dtype="float32") = R.add(x, x) + gv01: R.Tensor((3, 3), dtype="float32") = R.add(gv1, arg3) + R.output(gv01) + return gv01 + + after = relax.utils.extend_func(orig, ex) + assert_structural_equal(after, orig_ex) + + +def test_extend_func_nested_tuple(): + @R.function + def orig(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")): + with R.dataflow(): + gv0 = R.add(x, x) + gv1 = R.sum(y) + gv2 = R.add(x, y) + R.output(gv0, gv1, gv2) + return (gv0, gv1), gv2 + + @R.function + def ex( + arg1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((), dtype="float32")), + arg2: R.Tensor((), dtype="float32"), + ): + R.func_attr({"global_symbol": "ex"}) + with R.dataflow(): + arg10 = arg1[0] + gv0 = R.add(arg10, arg2) + R.output(gv0) + return gv0 + + @R.function + def orig_ex( + x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32") + ) -> R.Tensor((3, 3), dtype="float32"): + # block 0 + with R.dataflow(): + gv0: R.Tensor((3, 3), dtype="float32") = R.add(x, x) + gv1: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False) + gv2: R.Tensor((3, 3), dtype="float32") = R.add(x, y) + ret_0: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((), dtype="float32")) = ( + gv0, + gv1, + ) + arg10: R.Tensor((3, 3), dtype="float32") = ret_0[0] + gv01: R.Tensor((3, 3), dtype="float32") = R.add(arg10, gv2) + R.output(gv01) + return gv01 + + after = relax.utils.extend_func(orig, ex) + assert_structural_equal(after, orig_ex) + + if __name__ == "__main__": - pytest.main([__file__]) + tvm.testing.main() 
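A minimal usage sketch of the `relax.utils.extend_func` utility introduced in the patch above, assuming it is applied on a Relax-enabled TVM build. The function names, shapes, and operators below are illustrative only and mirror the pattern used in the new tests; they are not part of the patch itself.

from tvm import relax
from tvm.script.parser import relax as R


@R.function
def forward(x: R.Tensor((3, 3), dtype="float32"), w: R.Tensor((3, 3), dtype="float32")):
    with R.dataflow():
        out = R.add(x, w)
        R.output(out)
    return out


@R.function
def sq_err(pred: R.Tensor((3, 3), dtype="float32"), label: R.Tensor((3, 3), dtype="float32")):
    with R.dataflow():
        diff = R.subtract(pred, label)
        sq = R.multiply(diff, diff)
        loss = R.sum(sq)
        R.output(loss)
    return loss


# `forward` returns a single value, so it is bound to the first parameter of
# `sq_err` (`pred`); `label` is not mapped to any return value and is lifted
# into the combined signature, giving forward_with_loss(x, w, label).
forward_with_loss = relax.utils.extend_func(forward, sq_err)
assert isinstance(forward_with_loss, relax.Function)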
From 72affb774c879925bfc2a3160b8226841322bfeb Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Wed, 25 Jan 2023 19:04:59 +0800 Subject: [PATCH 02/17] lint --- python/tvm/relax/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 4598ec4c0f..6ef392bb1e 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -287,8 +287,8 @@ def extend_func(orig_func: Function, ex_func: Function) -> Function: This util can be replaced if we have Inline pass. It is equivalent to inline a tail call in some sense. - Note: the return value of orig_func will be bound to DataflowVar. So it is a bad idea to use this - util if the params of ex_func present in its R.output. + Note: the return value of orig_func will be bound to DataflowVar. So it is a bad idea to use + this util if the params of ex_func present in its R.output. Example: From d7cdd59c745554146ac105e593c13ac16b9cd416 Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Wed, 25 Jan 2023 21:56:15 +0800 Subject: [PATCH 03/17] draft --- python/tvm/relax/training/__init__.py | 1 + python/tvm/relax/training/loss.py | 158 ++++++++++++++++++++++++++ 2 files changed, 159 insertions(+) create mode 100644 python/tvm/relax/training/loss.py diff --git a/python/tvm/relax/training/__init__.py b/python/tvm/relax/training/__init__.py index 2cf602cb4f..9ecc0573d5 100644 --- a/python/tvm/relax/training/__init__.py +++ b/python/tvm/relax/training/__init__.py @@ -17,3 +17,4 @@ """The Relax training APIs.""" from . import optimizer +from . import loss diff --git a/python/tvm/relax/training/loss.py b/python/tvm/relax/training/loss.py new file mode 100644 index 0000000000..b9439f5e1a --- /dev/null +++ b/python/tvm/relax/training/loss.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Loss functions library for relax.""" + +from typing import Any, List, Optional, Union +from tvm.script.parser import relax as R +from tvm.relax import Var, Function, StructInfo, BlockBuilder +from tvm import relax + + +def _create_param_var(param: Union[Var, StructInfo], param_name) -> Var: + if isinstance(param, StructInfo): + param = Var(param_name, param) + assert isinstance(param, Var) + return param + + +class Loss: + """Base class of all loss. + + Parameters + ---------- + """ + + reduction: str + loss_name: str + + def __init__(self, loss_name: str, reduction: str = "mean") -> None: + self.loss_name = loss_name + self.reduction = reduction + + def __call__(self) -> Function: + raise NotImplementedError() + + +class L1Loss(Loss): + """Mean element-wise absolute value difference. 
+ + Parameters + ---------- + """ + + def __init__(self, reduction: str = "mean") -> None: + super(L1Loss, self).__init__("l1_loss", reduction) + + def __call__( + self, + predictions: Union[Var, StructInfo], + targets: Union[Var, StructInfo], + ) -> Function: + bb = BlockBuilder() + + predictions = _create_param_var(predictions, "predictions") + targets = _create_param_var(targets, "targets") + + with bb.function(self.loss_name, [predictions, targets]): + with bb.dataflow(): + lv = bb.emit(R.subtract(logits, targets)) + if self.reduction == "none": + loss = bb.emit_output(R.abs(lv)) # TODO: R.abs + else: + loss = bb.emit(R.abs(lv)) + if self.reduction == "sum": + loss = bb.emit_output(R.sum(loss)) + else: + # TODO: mean + pass + bb.emit_func_output(loss) + + return bb.get()[self.loss_name].with_attr("global_symbol", self.loss_name) + + +class MSELoss(Loss): + """Measures the element-wise mean squared error. + + Parameters + ---------- + """ + + def __init__(self, reduction: str = "mean") -> None: + super(MSELoss, self).__init__("mse_loss", reduction) + + def __call__( + self, + predictions: Union[Var, StructInfo], + targets: Union[Var, StructInfo], + ) -> Function: + bb = BlockBuilder() + + predictions = _create_param_var(predictions, "predictions") + targets = _create_param_var(targets, "targets") + + with bb.function(self.loss_name, [predictions, targets]): + with bb.dataflow(): + lv = bb.emit(R.subtract(logits, targets)) + if self.reduction == "none": + loss = bb.emit_output(R.mutiply(lv, lv)) + else: + loss = bb.emit(R.mutiply(lv, lv)) + if self.reduction == "sum": + loss = bb.emit_output(R.sum(loss)) + else: + # TODO: mean + pass + bb.emit_func_output(loss) + + return bb.get()[self.loss_name].with_attr("global_symbol", self.loss_name) + + +class CrossEntropyLoss(Loss): + """CrossEntropyLoss. 
+ + Parameters + ---------- + """ + + ignore_index: int + + def __init__(self, ignore_index: int = -100, reduction: str = "mean") -> None: + super(CrossEntropyLoss, self).__init__("cross_entropy_loss", reduction) + self.ignore_index = ignore_index + + def __call__( + self, + predictions: Union[Var, StructInfo], + targets: Union[Var, StructInfo], + weights: Optional[Union[Var, StructInfo]], + ) -> Function: + bb = BlockBuilder() + + predictions = _create_param_var(predictions, "predictions") + targets = _create_param_var(targets, "targets") + if weights: + weights = _create_param_var(weights, "predictions") + + with bb.function(self.loss_name, [predictions, targets, weights]): + with bb.dataflow(): + logits = bb.emit(R.nn.log_softmax(predictions)) + loss = bb.emit_output( + R.nn.nll_loss(logits, targets, weights, self.reduction, self.ignore_index) + ) + bb.emit_func_output(loss) + + return bb.get()[self.loss_name].with_attr("global_symbol", self.loss_name) From af3934a0d323a191a3232dd6727dd193cf625414 Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Thu, 26 Jan 2023 14:59:23 +0800 Subject: [PATCH 04/17] append_loss refactor and add tests --- include/tvm/relax/utils.h | 18 --- python/tvm/relax/__init__.py | 1 + python/tvm/relax/op/nn/nn.py | 2 +- python/tvm/relax/training/__init__.py | 5 +- python/tvm/relax/training/_ffi_api.py | 19 +++ python/tvm/relax/training/loss.py | 111 ++++++++++---- python/tvm/relax/training/utils.py | 68 +++++++++ python/tvm/relax/utils.py | 51 ------- src/relax/training/utils.cc | 167 +++++++++++++++++++++ src/relax/training/utils.h | 50 +++++++ src/relax/utils.cc | 137 ----------------- tests/python/relax/test_training_loss.py | 175 ++++++++++++++++++++++ tests/python/relax/test_training_utils.py | 135 +++++++++++++++++ tests/python/relax/test_utils.py | 114 -------------- 14 files changed, 699 insertions(+), 354 deletions(-) create mode 100644 python/tvm/relax/training/_ffi_api.py create mode 100644 python/tvm/relax/training/utils.py create mode 100644 src/relax/training/utils.cc create mode 100644 src/relax/training/utils.h create mode 100644 tests/python/relax/test_training_loss.py create mode 100644 tests/python/relax/test_training_utils.py diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index 1826491d8d..1457a16427 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -149,24 +149,6 @@ TVM_DLL bool IsLeafOrTuple(const Expr& expr); */ TVM_DLL Function CopyWithNewParams(Function func); -/*! - * \brief Extend a relax function by another given function. It will link orig_func with - * ex_func and return a new function. - * - * In detail, the result function has the arguments list of orig_func and the combination - * of their body, which passes the return values of orig_func as the arguments of ex_func. For - * those arguments of ex_func which are not mapped to some return values, they will be lifted and - * appended to the argument list of result function. - * - * This util can be replaced if we have Inline pass. It is equivalent to inline a tail call in some - * sense. - * - * \param orig_func The function to be extended. - * \param ex_func The function to be linked after the orig_func. - * \return The result function after extending. 
- */ -TVM_DLL Function ExtendFunc(Function orig_func, Function ex_func); - } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index d6151cdc29..4f82420bbf 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -27,6 +27,7 @@ from . import expr_functor from . import struct_info from . import utils +from . import training # Expr diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index de3cb31086..3ce49ab7be 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -634,7 +634,7 @@ def nll_loss( The weight of each target value. If not specified, it is treated as if having all ones. - reduction : string + reduction : str The reduction method to apply to the output. Possible values are "mean", "sum" and "none". diff --git a/python/tvm/relax/training/__init__.py b/python/tvm/relax/training/__init__.py index 9ecc0573d5..b9da9ad5c6 100644 --- a/python/tvm/relax/training/__init__.py +++ b/python/tvm/relax/training/__init__.py @@ -17,4 +17,7 @@ """The Relax training APIs.""" from . import optimizer -from . import loss +from . import utils + +# loss functions +from .loss import * diff --git a/python/tvm/relax/training/_ffi_api.py b/python/tvm/relax/training/_ffi_api.py new file mode 100644 index 0000000000..34dc4a0669 --- /dev/null +++ b/python/tvm/relax/training/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for tvm.relax.training""" +import tvm._ffi + +tvm._ffi._init_api("relax.training", __name__) diff --git a/python/tvm/relax/training/loss.py b/python/tvm/relax/training/loss.py index b9439f5e1a..8d468501ec 100644 --- a/python/tvm/relax/training/loss.py +++ b/python/tvm/relax/training/loss.py @@ -14,12 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=redefined-builtin """Loss functions library for relax.""" from typing import Any, List, Optional, Union -from tvm.script.parser import relax as R -from tvm.relax import Var, Function, StructInfo, BlockBuilder from tvm import relax +from ..expr import Expr, Var, Function, StructInfo + +from ..op import sum, mean, subtract, multiply +from ..op.nn import log_softmax, nll_loss + + +__all__ = ["L1Loss", "MSELoss", "CrossEntropyLoss"] def _create_param_var(param: Union[Var, StructInfo], param_name) -> Var: @@ -34,6 +40,15 @@ class Loss: Parameters ---------- + loss_name : str + The name of the loss function. + + reduction : str + The reduction method to apply to output. Can be "mean", "sum" or "none". + + none : no reduction will be applied, + mean : the sum of the output will be divided by the batch_size, + sum : the output will be summed. 
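+
+        For example, for an element-wise loss of shape (N, C), "none" keeps the
+        (N, C) tensor, while "sum" and "mean" both reduce it to a scalar.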
""" reduction: str @@ -43,15 +58,42 @@ def __init__(self, loss_name: str, reduction: str = "mean") -> None: self.loss_name = loss_name self.reduction = reduction + valid_reductions = ["mean", "sum", "none"] + + if self.reduction not in valid_reductions: + raise ValueError("Reduction can only be one of these values: ", valid_reductions) + def __call__(self) -> Function: + """Calling a loss will get its relax function. + + Usually it has some parameters with type Union[Var, StructInfo]. It means + the necessary inputs of the loss function. If a struct info is given, it will + construct a corresponding Var using the struct info; if a Var is given, it will + directly use this Var as the param. + + Returns + ---------- + The relax function of the loss with the loss name as its global symbol. + """ raise NotImplementedError() + def _with_reduction(self, expr: Expr): + if self.reduction == "sum": + expr = sum(expr) + elif self.reduction == "mean": + expr = sum(mean(expr, axis=0)) + else: + assert self.reduction == "none" + return expr + class L1Loss(Loss): """Mean element-wise absolute value difference. Parameters ---------- + reduction : str + See the doc of Loss. """ def __init__(self, reduction: str = "mean") -> None: @@ -62,23 +104,15 @@ def __call__( predictions: Union[Var, StructInfo], targets: Union[Var, StructInfo], ) -> Function: - bb = BlockBuilder() + bb = relax.BlockBuilder() predictions = _create_param_var(predictions, "predictions") targets = _create_param_var(targets, "targets") with bb.function(self.loss_name, [predictions, targets]): with bb.dataflow(): - lv = bb.emit(R.subtract(logits, targets)) - if self.reduction == "none": - loss = bb.emit_output(R.abs(lv)) # TODO: R.abs - else: - loss = bb.emit(R.abs(lv)) - if self.reduction == "sum": - loss = bb.emit_output(R.sum(loss)) - else: - # TODO: mean - pass + lv = abs(subtract(predictions, targets)) # TODO: R.abs + loss = bb.emit_output(self._with_reduction(lv)) bb.emit_func_output(loss) return bb.get()[self.loss_name].with_attr("global_symbol", self.loss_name) @@ -89,6 +123,8 @@ class MSELoss(Loss): Parameters ---------- + reduction : str + See the doc of Loss. """ def __init__(self, reduction: str = "mean") -> None: @@ -99,23 +135,16 @@ def __call__( predictions: Union[Var, StructInfo], targets: Union[Var, StructInfo], ) -> Function: - bb = BlockBuilder() + bb = relax.BlockBuilder() predictions = _create_param_var(predictions, "predictions") targets = _create_param_var(targets, "targets") with bb.function(self.loss_name, [predictions, targets]): with bb.dataflow(): - lv = bb.emit(R.subtract(logits, targets)) - if self.reduction == "none": - loss = bb.emit_output(R.mutiply(lv, lv)) - else: - loss = bb.emit(R.mutiply(lv, lv)) - if self.reduction == "sum": - loss = bb.emit_output(R.sum(loss)) - else: - # TODO: mean - pass + lv = subtract(predictions, targets) + lv = multiply(lv, lv) + loss = bb.emit_output(self._with_reduction(lv)) bb.emit_func_output(loss) return bb.get()[self.loss_name].with_attr("global_symbol", self.loss_name) @@ -126,32 +155,50 @@ class CrossEntropyLoss(Loss): Parameters ---------- + reduction : str + See the doc of Loss. + + weights : Optional[Union[Var, StructInfo]] + a manual rescaling weight given to each class. It has to be a Tensor of size C. + + ignore_index : int + Specifies a target value that is ignored and does not contribute to the input gradient. 
""" ignore_index: int - def __init__(self, ignore_index: int = -100, reduction: str = "mean") -> None: + def __init__( + self, + reduction: str = "mean", + ignore_index: int = -100, + weights: Optional[Union[Var, StructInfo]] = None, + ) -> None: super(CrossEntropyLoss, self).__init__("cross_entropy_loss", reduction) self.ignore_index = ignore_index + if weights: + self.weights = _create_param_var(weights, "weights") + else: + self.weights = None def __call__( self, predictions: Union[Var, StructInfo], targets: Union[Var, StructInfo], - weights: Optional[Union[Var, StructInfo]], ) -> Function: - bb = BlockBuilder() + bb = relax.BlockBuilder() predictions = _create_param_var(predictions, "predictions") targets = _create_param_var(targets, "targets") - if weights: - weights = _create_param_var(weights, "predictions") - with bb.function(self.loss_name, [predictions, targets, weights]): + arg_list = [predictions, targets] + if self.weights: + arg_list.append(self.weights) + + with bb.function(self.loss_name, arg_list): with bb.dataflow(): - logits = bb.emit(R.nn.log_softmax(predictions)) + logits = bb.emit(log_softmax(predictions)) loss = bb.emit_output( - R.nn.nll_loss(logits, targets, weights, self.reduction, self.ignore_index) + nll_loss(logits, targets, self.weights, self.reduction, self.ignore_index) ) bb.emit_func_output(loss) diff --git a/python/tvm/relax/training/utils.py b/python/tvm/relax/training/utils.py new file mode 100644 index 0000000000..596ba9c176 --- /dev/null +++ b/python/tvm/relax/training/utils.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utility functions for relax training.""" + +from ..expr import Function +from . import _ffi_api + + +def append_loss(orig_func: Function, loss_func: Function) -> Function: + """Local helper to append a specified loss function after the original function. + + In detail, the result function has the arguments list of orig_func and the combination + of their body, which passes the return values of orig_func as the arguments of loss_func. For + those arguments of loss_func which are not mapped to some return values, they will be lifted + and appended to the argument list of result function. + + Notice: + 1. This uitl is dedicated to loss functions, not for general purposes. + 2. This util can be replaced if we have Inline pass. It is equivalent to inline a tail call in + some sense. + + Example: + + .. code-block:: python + # Before. + @R.function + def orig(x, y): + out = x + y + return out + + @R.function + def loss(predictions, labels): + return R.sum((predictions - labels)^2) + + # After. + @R.function + def orig(x, y, labels): + out = x + y + return R.sum((out - labels)^2) + + Parameters + ---------- + orig_func : Function + The function to be appended. 
+ + loss_func : Function + The loss function. + + Returns + ------- + ret : Function + The result function. + """ + return _ffi_api.AppendLoss(orig_func, loss_func) # type: ignore diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 6ef392bb1e..e1d5bf50c1 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -273,54 +273,3 @@ def copy_with_new_params(func: Function) -> Function: The copied function. """ return _ffi_api.CopyWithNewParams(func) # type: ignore - - -def extend_func(orig_func: Function, ex_func: Function) -> Function: - """Extend a relax function by another given function. It will link orig_func with - ex_func and return a new function. - - In detail, the result function has the arguments list of orig_func and the combination - of their body, which passes the return values of orig_func as the arguments of ex_func. For - those arguments of ex_func which are not mapped to some return values, they will be lifted and - appended to the argument list of result function. - - This util can be replaced if we have Inline pass. It is equivalent to inline a tail call in some - sense. - - Note: the return value of orig_func will be bound to DataflowVar. So it is a bad idea to use - this util if the params of ex_func present in its R.output. - - Example: - - .. code-block:: python - # Before. - @R.function - def func1(a, b): - return a + b, a * b - - @R.function - def func2(c, d, e): - return d, c, c + e - - # After. func1_func2 = extend_func(orig_func=func1, ex_func=func2). - @R.function - def func1_func2(a, b, e): - c = a + b - d = a * b - return d, c, c + e - - Parameters - ---------- - orig_func : Function - The function to be extended. - - ex_func : Function - The function to be linked after the orig_func. - - Returns - ------- - ret : Function - The result function. - """ - - return _ffi_api.ExtendFunc(orig_func, ex_func) # type: ignore diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc new file mode 100644 index 0000000000..7f0e16296d --- /dev/null +++ b/src/relax/training/utils.cc @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "utils.h" + +namespace tvm { +namespace relax { + +/*! \brief Helper to implement append loss.*/ +class AppendLossMutator : public ExprMutator { + public: + explicit AppendLossMutator(const SeqExpr& loss_body) : loss_body_(loss_body) {} + + Expr VisitExpr_(const SeqExprNode* seq_expr) override { + // mutate only the last block. 
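+    // The bindings from the loss function body are appended into this last
+    // (dataflow) block, so all earlier blocks can be kept unchanged.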
+ Array blocks; + for (int i = 0; i < static_cast(seq_expr->blocks.size()); ++i) { + CHECK(seq_expr->blocks[i].as()) + << "All blocks in original functions should be Dataflow Block"; + if (i < static_cast(seq_expr->blocks.size()) - 1) { + blocks.push_back(seq_expr->blocks[i]); + } else { + BindingBlock new_block = this->VisitBindingBlock(seq_expr->blocks[i]); + if (!new_block->bindings.empty()) { + blocks.push_back(new_block); + } + } + } + return SeqExpr(blocks, loss_body_->body); + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override { + builder_->BeginDataflowBlock(); + // emit original bindings. + for (const auto& binding : block->bindings) { + this->VisitBinding(binding); + } + + ICHECK(orig_rets_var_.size() == orig_rets.size()); + for (int i = 0; i < static_cast(orig_rets_var_.size()); ++i) { + if (orig_rets_var_[i].defined()) { + builder_->EmitNormalized(VarBinding(orig_rets_var_[i].value(), orig_rets[i])); + } + } + + // emit blocks for loss function part. + for (BindingBlock block : loss_body_->blocks) { + CHECK(block.as()) + << "All blocks in loss functions should be Dataflow Block"; + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + } + + return builder_->EndBlock(); + } + + void VisitBinding_(const VarBindingNode* binding) override { + Var new_var = Downcast(this->VisitExpr(binding->var)); + Expr new_value = this->VisitExpr(binding->value); + builder_->EmitNormalized(VarBinding(new_var, new_value)); + } + + // remap orignal dataflow var + // TODO(chaofan): a better way to check whether new_ret_var should be dataflow + void RemapToDataflow(SeqExpr body) { + for (BindingBlock block : body->blocks) { + for (Binding binding : block->bindings) { + const auto* binding_node = binding.as(); + if (binding_node && !binding_node->var->IsInstance()) { + Var new_binding_var = DataflowVar( + binding_node->var->vid, GetStructInfo(binding_node->var), binding_node->var->span); + this->var_remap_[binding_node->var->vid] = new_binding_var; + } + } + } + } + + Array RemapLossParams(const Array& loss_func_params, Array new_params) { + for (int i = 0; i < static_cast(loss_func_params.size()); ++i) { + Var loss_param = loss_func_params[i]; + if (i < static_cast(orig_rets.size())) { + // map return value to loss param + if (const auto* var_node = orig_rets[i].as()) { + ICHECK(orig_rets[i].as()); + orig_rets_var_.push_back(NullOpt); + this->var_remap_[loss_param->vid] = GetRef(var_node); + } else { + Var new_ret_var = + DataflowVar(/*name_hint=*/"ret_" + std::to_string(i), GetStructInfo(orig_rets[i])); + orig_rets_var_.push_back(new_ret_var); + this->var_remap_[loss_param->vid] = new_ret_var; + } + } else { + // append to the param list + Var new_loss_param = Var(loss_param->vid, GetStructInfo(loss_param), loss_param->span); + this->var_remap_[loss_param->vid] = new_loss_param; + new_params.push_back(new_loss_param); + } + } + return new_params; + } + + Array orig_rets; + + private: + SeqExpr loss_body_; + Array> orig_rets_var_; +}; + +/*! + * \brief Local helper to append a specified loss function after the original function. + * \param orig_func The function to be appended. + * \param loss_func The loss function. + * \return The result function after appended. 
+ */ +Function AppendLoss(Function orig_func, Function loss_func) { + CHECK(orig_func->body->IsInstance()) + << "the body of the original function is not SeqExpr."; + CHECK(loss_func->body->IsInstance()) + << "the body of the loss function is not SeqExpr."; + + auto param_copied_func = CopyWithNewParams(orig_func); + auto seq_expr = Downcast(param_copied_func->body); + + AppendLossMutator mutator(Downcast(loss_func->body)); + mutator.RemapToDataflow(seq_expr); + // Get the orignal rets. If it is a Tuple, unpack it. + if (orig_func->ret_struct_info.as()) { + const auto* tuple_node = seq_expr->body.as(); + ICHECK(tuple_node != nullptr); + for (Expr field : tuple_node->fields) { + mutator.orig_rets.push_back(mutator.VisitExpr(field)); + } + } else { + mutator.orig_rets.push_back(mutator.VisitExpr(seq_expr->body)); + } + + CHECK(loss_func->params.size() >= mutator.orig_rets.size()) + << "The number of return values of original functions should be greater than the number of " + "parameters of loss function"; + + auto new_params = mutator.RemapLossParams(loss_func->params, param_copied_func->params); + Expr new_body = mutator.VisitExpr(seq_expr); + return Function(new_params, new_body, loss_func->ret_struct_info, param_copied_func->attrs); +} + +TVM_REGISTER_GLOBAL("relax.training.AppendLoss").set_body_typed(AppendLoss); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/training/utils.h b/src/relax/training/utils.h new file mode 100644 index 0000000000..ae8c5a7e18 --- /dev/null +++ b/src/relax/training/utils.h @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/utils.h + * \brief Utility classes and functions for working with the Relax IR. + */ +#ifndef TVM_RELAX_TRAINING_UTILS_H_ +#define TVM_RELAX_TRAINING_UTILS_H_ + +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Local helper to append a specified loss function after the original function. + * + * Notice: + * 1. This uitl is dedicated to loss functions, not for general purposes. + * 2. This util can be replaced if we have Inline pass. It is equivalent to inline a tail call in + * some sense. + * + * \param orig_func The function to be appended. + * \param loss_func The loss function. + * \return The result function after appended. 
+ */ +TVM_DLL Function AppendLoss(Function orig_func, Function loss_func); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_TRAINING_UTILS_H_ diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 0c812e8f95..a77e4342e2 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -105,142 +105,5 @@ Function CopyWithNewParams(Function func) { return FunctionCopier::Transform(fun TVM_REGISTER_GLOBAL("relax.CopyWithNewParams").set_body_typed(CopyWithNewParams); -/*! \brief Helper to implement extend function.*/ -class ExtendFuncMutator : public ExprMutator { - public: - explicit ExtendFuncMutator(const SeqExpr& ex_body) : ex_body_(ex_body) {} - - Expr VisitExpr_(const SeqExprNode* seq_expr) override { - // mutate only the last block. - Array blocks; - for (int i = 0; i < static_cast(seq_expr->blocks.size()); ++i) { - if (i < static_cast(seq_expr->blocks.size()) - 1) { - blocks.push_back(seq_expr->blocks[i]); - } else { - BindingBlock new_block = this->VisitBindingBlock(seq_expr->blocks[i]); - if (!new_block->bindings.empty()) { - blocks.push_back(new_block); - } - } - } - this->VisitExpr(seq_expr->body); - return SeqExpr(blocks, ex_body_->body); - } - - BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override { - builder_->BeginDataflowBlock(); - // emit original bindings. - for (const auto& binding : block->bindings) { - this->VisitBinding(binding); - } - - ICHECK(orig_rets_var_.size() == orig_rets.size()); - for (int i = 0; i < static_cast(orig_rets_var_.size()); ++i) { - if (orig_rets_var_[i].defined()) { - builder_->EmitNormalized(VarBinding(orig_rets_var_[i].value(), orig_rets[i])); - } - } - - // emit blocks for extend part. - for (BindingBlock block : ex_body_->blocks) { - for (Binding binding : block->bindings) { - this->VisitBinding(binding); - } - } - - return builder_->EndBlock(); - } - - void VisitBinding_(const VarBindingNode* binding) override { - Var new_var = Downcast(this->VisitExpr(binding->var)); - Expr new_value = this->VisitExpr(binding->value); - builder_->EmitNormalized(VarBinding(new_var, new_value)); - } - - // remap orignal dataflow var - // TODO(chaofan): a better way to check whether new_ret_var should be dataflow - void RemapToDataflow(SeqExpr body) { - for (BindingBlock block : body->blocks) { - for (Binding binding : block->bindings) { - const auto* binding_node = binding.as(); - if (binding_node && !binding_node->var->IsInstance()) { - Var new_binding_var = DataflowVar( - binding_node->var->vid, GetStructInfo(binding_node->var), binding_node->var->span); - this->var_remap_[binding_node->var->vid] = new_binding_var; - } - } - } - } - - Array RemapExParams(const Array& ex_func_params, Array new_params) { - for (int i = 0; i < static_cast(ex_func_params.size()); ++i) { - Var ex_param = ex_func_params[i]; - if (i < static_cast(orig_rets.size())) { - // map return value to ex param - if (const auto* var_node = orig_rets[i].as()) { - ICHECK(orig_rets[i].as()); - orig_rets_var_.push_back(NullOpt); - this->var_remap_[ex_param->vid] = GetRef(var_node); - } else { - Var new_ret_var = - DataflowVar(/*name_hint=*/"ret_" + std::to_string(i), GetStructInfo(orig_rets[i])); - orig_rets_var_.push_back(new_ret_var); - this->var_remap_[ex_param->vid] = new_ret_var; - } - } else { - // append to the param list - Var new_ex_param = Var(ex_param->vid, GetStructInfo(ex_param), ex_param->span); - this->var_remap_[ex_param->vid] = new_ex_param; - new_params.push_back(new_ex_param); - } - } - return new_params; - } - - Array orig_rets; - - private: - 
SeqExpr ex_body_; - Array> orig_rets_var_; -}; - -/*! - * \brief Extend a relax function by another given function. - * \param orig_func The function to be extended. - * \param ex_func The function to be linked after the orig_func. - * \return The result function after extending. - */ -Function ExtendFunc(Function orig_func, Function ex_func) { - CHECK(orig_func->body->IsInstance()) - << "the body of the original function is not SeqExpr."; - CHECK(ex_func->body->IsInstance()) << "the body of the ex function is not SeqExpr."; - - auto param_copied_func = CopyWithNewParams(orig_func); - auto seq_expr = Downcast(param_copied_func->body); - - ExtendFuncMutator mutator(Downcast(ex_func->body)); - mutator.RemapToDataflow(seq_expr); - // Get the orignal rets. If it is a Tuple, unpack it. - if (orig_func->ret_struct_info.as()) { - const auto* tuple_node = seq_expr->body.as(); - ICHECK(tuple_node != nullptr); - for (Expr field : tuple_node->fields) { - mutator.orig_rets.push_back(mutator.VisitExpr(field)); - } - } else { - mutator.orig_rets.push_back(mutator.VisitExpr(seq_expr->body)); - } - - CHECK(ex_func->params.size() >= mutator.orig_rets.size()) - << "The number of return values of original functions should be greater than the number of " - "parameters of ex function"; - - auto new_params = mutator.RemapExParams(ex_func->params, param_copied_func->params); - Expr new_body = mutator.VisitExpr(seq_expr); - return Function(new_params, new_body, ex_func->ret_struct_info, param_copied_func->attrs); -} - -TVM_REGISTER_GLOBAL("relax.ExtendFunc").set_body_typed(ExtendFunc); - } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_training_loss.py b/tests/python/relax/test_training_loss.py new file mode 100644 index 0000000000..025c2e6241 --- /dev/null +++ b/tests/python/relax/test_training_loss.py @@ -0,0 +1,175 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest +import tvm.testing +from tvm import relax +from tvm.ir.base import assert_structural_equal +from tvm.script.parser import relax as R + + +@R.function +def forward( + x: R.Tensor((2, 4), dtype="float32"), + w: R.Tensor((4, 4), dtype="float32"), + b: R.Tensor((2, 4), dtype="float32"), +) -> R.Tensor((2, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 4), dtype="float32") = R.matmul(x, w) + out: R.Tensor((2, 4), dtype="float32") = R.add(lv, b) + R.output(out) + return out + + +@pytest.mark.skip("Waiting for the operator of R.abs") +def test_l1_loss(): + N = 3 + C = 5 + predictions = relax.TensorStructInfo((N, C), "float32") + targets = relax.TensorStructInfo((N, C), "float32") + l1_loss = relax.training.L1Loss()(predictions, targets) + + +@pytest.mark.skip("Waiting for the operator of R.abs") +def test_l1_loss_append(): + pass + + +def test_mse_loss(): + N = 3 + C = 5 + predictions = relax.TensorStructInfo((N, C), "float32") + targets = relax.TensorStructInfo((N, C), "float32") + mse_loss = relax.training.MSELoss()(predictions, targets) + + # fmt: off + @R.function + def expected(predictions: R.Tensor((3, 5), dtype="float32"), targets: R.Tensor((3, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): + # function attr dict + R.func_attr({"global_symbol": "mse_loss"}) + # block 0 + with R.dataflow(): + lv: R.Tensor((3, 5), dtype="float32") = R.subtract(predictions, targets) + lv1: R.Tensor((3, 5), dtype="float32") = R.multiply(lv, lv) + lv2: R.Tensor((5,), dtype="float32") = R.mean(lv1, axis=[0], keepdims=False) + gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None, keepdims=False) + R.output(gv) + return gv + # fmt: on + + assert_structural_equal(mse_loss, expected) + + +def test_mse_loss_append(): + s = forward.ret_struct_info + mse_loss = relax.training.MSELoss(reduction="sum")(s, s) + forward_with_loss = relax.training.utils.append_loss(forward, mse_loss) + + # fmt: off + @R.function + def expected(x: R.Tensor((2, 4), dtype="float32"), w: R.Tensor((4, 4), dtype="float32"), b: R.Tensor((2, 4), dtype="float32"), targets: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((2, 4), dtype="float32") = R.matmul(x, w, out_dtype="") + out: R.Tensor((2, 4), dtype="float32") = R.add(lv, b) + lv1: R.Tensor((2, 4), dtype="float32") = R.subtract(out, targets) + lv11: R.Tensor((2, 4), dtype="float32") = R.multiply(lv1, lv1) + gv: R.Tensor((), dtype="float32") = R.sum(lv11, axis=None, keepdims=False) + R.output(gv) + return gv + # fmt: on + + assert_structural_equal(forward_with_loss, expected) + + +def test_cross_entropy_loss(): + N = 3 + C = 5 + predictions = relax.TensorStructInfo((N, C), "float32") + targets = relax.TensorStructInfo((N,), "int64") + weights = relax.TensorStructInfo((C,), "float32") + cross_entropy_loss = relax.training.CrossEntropyLoss( + reduction="sum", ignore_index=1, weights=weights + )(predictions, targets) + + # fmt: off + @R.function + def expected(predictions: R.Tensor((3, 5), dtype="float32"), targets: R.Tensor((3,), dtype="int64"), weights: R.Tensor((5,), dtype="float32")) -> R.Tensor((), dtype="float32"): + # function attr dict + R.func_attr({"global_symbol": "cross_entropy_loss"}) + # block 0 + with R.dataflow(): + lv: R.Tensor((3, 5), dtype="float32") = R.nn.log_softmax(predictions, axis=-1) + gv: R.Tensor((), dtype="float32") = R.nn.nll_loss(lv, targets, weights, reduction="sum", ignore_index=1) + R.output(gv) + return gv + # fmt: on + + 
assert_structural_equal(cross_entropy_loss, expected) + + +def test_cross_entropy_loss_without_weights(): + N = 3 + C = 5 + predictions = relax.TensorStructInfo((N, C), "float32") + targets = relax.TensorStructInfo((N,), "int64") + cross_entropy_loss = relax.training.CrossEntropyLoss()(predictions, targets) + + # fmt: off + @R.function + def expected(predictions: R.Tensor((3, 5), dtype="float32"), targets: R.Tensor((3,), dtype="int64")) -> R.Tensor((), dtype="float32"): + # function attr dict + R.func_attr({"global_symbol": "cross_entropy_loss"}) + # block 0 + with R.dataflow(): + lv: R.Tensor((3, 5), dtype="float32") = R.nn.log_softmax(predictions, axis=-1) + gv: R.Tensor((), dtype="float32") = R.nn.nll_loss(lv, targets, reduction="mean", ignore_index=-100) + R.output(gv) + return gv + # fmt: on + + assert_structural_equal(cross_entropy_loss, expected) + + +def test_cross_entropy_loss_append(): + s = forward.ret_struct_info + N = s.shape[0] + C = s.shape[1] + targets = relax.TensorStructInfo((N,), "int64") + weights = relax.TensorStructInfo((C,), "float32") + cross_entropy_loss = relax.training.CrossEntropyLoss( + reduction="sum", ignore_index=1, weights=weights + )(s, targets) + forward_with_loss = relax.training.utils.append_loss(forward, cross_entropy_loss) + + # fmt: off + @R.function + def expected(x: R.Tensor((2, 4), dtype="float32"), w: R.Tensor((4, 4), dtype="float32"), b: R.Tensor((2, 4), dtype="float32"), targets: R.Tensor((2,), dtype="int64"), weights: R.Tensor((4,), dtype="float32")) -> R.Tensor((), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((2, 4), dtype="float32") = R.matmul(x, w, out_dtype="") + out: R.Tensor((2, 4), dtype="float32") = R.add(lv, b) + lv1: R.Tensor((2, 4), dtype="float32") = R.nn.log_softmax(out, axis=-1) + gv: R.Tensor((), dtype="float32") = R.nn.nll_loss(lv1, targets, weights, reduction="sum", ignore_index=1) + R.output(gv) + return gv + # fmt: on + + assert_structural_equal(forward_with_loss, expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_training_utils.py b/tests/python/relax/test_training_utils.py new file mode 100644 index 0000000000..a885eedf97 --- /dev/null +++ b/tests/python/relax/test_training_utils.py @@ -0,0 +1,135 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm.testing +from tvm import relax +from tvm.ir.base import assert_structural_equal +from tvm.script.parser import relax as R + + +def test_append_loss_basic_extend(): + @R.function + def orig(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")): + with R.dataflow(): + gv0 = R.sum(x) + gv1 = R.sum(y) + R.output(gv0, gv1) + return gv0, gv1 + + @R.function + def loss(arg1: R.Tensor((), dtype="float32"), arg2: R.Tensor((), dtype="float32")): + with R.dataflow(): + gv0 = R.add(arg1, arg2) + R.output(gv0) + return gv0 + + @R.function + def expected( + x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=0): + # block 0 + with R.dataflow(): + gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) + gv1: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False) + gv01: R.Tensor((), dtype="float32") = R.add(gv0, gv1) + R.output(gv01) + return gv01 + + after = relax.training.utils.append_loss(orig, loss) + assert_structural_equal(after, expected) + + +def test_append_loss_extra_params(): + @R.function + def orig(x: R.Tensor((3, 3), dtype="float32")): + with R.dataflow(): + gv0 = R.sum(x) + gv1 = R.add(x, x) + R.output(gv0, gv1) + return gv0, gv1 + + @R.function + def loss( + arg1: R.Tensor((), dtype="float32"), + arg2: R.Tensor((3, 3), dtype="float32"), + arg3: R.Tensor((3, 3), dtype="float32"), + ): + with R.dataflow(): + gv0 = R.add(arg2, arg3) + R.output(gv0) + return gv0 + + @R.function + def expected( + x: R.Tensor((3, 3), dtype="float32"), arg3: R.Tensor((3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=2): + # block 0 + with R.dataflow(): + gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) + gv1: R.Tensor((3, 3), dtype="float32") = R.add(x, x) + gv01: R.Tensor((3, 3), dtype="float32") = R.add(gv1, arg3) + R.output(gv01) + return gv01 + + after = relax.training.utils.append_loss(orig, loss) + assert_structural_equal(after, expected) + + +def test_append_loss_nested_tuple(): + @R.function + def orig(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")): + with R.dataflow(): + gv0 = R.add(x, x) + gv1 = R.sum(y) + gv2 = R.add(x, y) + R.output(gv0, gv1, gv2) + return (gv0, gv1), gv2 + + @R.function + def loss( + arg1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((), dtype="float32")), + arg2: R.Tensor((), dtype="float32"), + ): + with R.dataflow(): + arg10 = arg1[0] + gv0 = R.add(arg10, arg2) + R.output(gv0) + return gv0 + + @R.function + def expected( + x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32") + ) -> R.Tensor((3, 3), dtype="float32"): + # block 0 + with R.dataflow(): + gv0: R.Tensor((3, 3), dtype="float32") = R.add(x, x) + gv1: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False) + gv2: R.Tensor((3, 3), dtype="float32") = R.add(x, y) + ret_0: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((), dtype="float32")) = ( + gv0, + gv1, + ) + arg10: R.Tensor((3, 3), dtype="float32") = ret_0[0] + gv01: R.Tensor((3, 3), dtype="float32") = R.add(arg10, gv2) + R.output(gv01) + return gv01 + + after = relax.training.utils.append_loss(orig, loss) + assert_structural_equal(after, expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index c43b23a519..61502d737d 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -34,119 +34,5 @@ def before(x: 
R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): assert before_var != after_var -def test_extend_func_basic_extend(): - @R.function - def orig(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")): - with R.dataflow(): - gv0 = R.sum(x) - gv1 = R.sum(y) - R.output(gv0, gv1) - return gv0, gv1 - - @R.function - def ex(arg1: R.Tensor((), dtype="float32"), arg2: R.Tensor((), dtype="float32")): - R.func_attr({"global_symbol": "ex"}) - with R.dataflow(): - gv0 = R.add(arg1, arg2) - R.output(gv0) - return gv0 - - @R.function - def orig_ex( - x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32") - ) -> R.Tensor(None, dtype="float32", ndim=0): - # block 0 - with R.dataflow(): - gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) - gv1: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False) - gv01: R.Tensor((), dtype="float32") = R.add(gv0, gv1) - R.output(gv01) - return gv01 - - after = relax.utils.extend_func(orig, ex) - assert_structural_equal(after, orig_ex) - - -def test_extend_func_extra_params(): - @R.function - def orig(x: R.Tensor((3, 3), dtype="float32")): - with R.dataflow(): - gv0 = R.sum(x) - gv1 = R.add(x, x) - R.output(gv0, gv1) - return gv0, gv1 - - @R.function - def ex( - arg1: R.Tensor((), dtype="float32"), - arg2: R.Tensor((3, 3), dtype="float32"), - arg3: R.Tensor((3, 3), dtype="float32"), - ): - R.func_attr({"global_symbol": "ex"}) - with R.dataflow(): - gv0 = R.add(arg2, arg3) - R.output(gv0) - return gv0 - - @R.function - def orig_ex( - x: R.Tensor((3, 3), dtype="float32"), arg3: R.Tensor((3, 3), dtype="float32") - ) -> R.Tensor(None, dtype="float32", ndim=2): - # block 0 - with R.dataflow(): - gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) - gv1: R.Tensor((3, 3), dtype="float32") = R.add(x, x) - gv01: R.Tensor((3, 3), dtype="float32") = R.add(gv1, arg3) - R.output(gv01) - return gv01 - - after = relax.utils.extend_func(orig, ex) - assert_structural_equal(after, orig_ex) - - -def test_extend_func_nested_tuple(): - @R.function - def orig(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")): - with R.dataflow(): - gv0 = R.add(x, x) - gv1 = R.sum(y) - gv2 = R.add(x, y) - R.output(gv0, gv1, gv2) - return (gv0, gv1), gv2 - - @R.function - def ex( - arg1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((), dtype="float32")), - arg2: R.Tensor((), dtype="float32"), - ): - R.func_attr({"global_symbol": "ex"}) - with R.dataflow(): - arg10 = arg1[0] - gv0 = R.add(arg10, arg2) - R.output(gv0) - return gv0 - - @R.function - def orig_ex( - x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32") - ) -> R.Tensor((3, 3), dtype="float32"): - # block 0 - with R.dataflow(): - gv0: R.Tensor((3, 3), dtype="float32") = R.add(x, x) - gv1: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False) - gv2: R.Tensor((3, 3), dtype="float32") = R.add(x, y) - ret_0: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((), dtype="float32")) = ( - gv0, - gv1, - ) - arg10: R.Tensor((3, 3), dtype="float32") = ret_0[0] - gv01: R.Tensor((3, 3), dtype="float32") = R.add(arg10, gv2) - R.output(gv01) - return gv01 - - after = relax.utils.extend_func(orig, ex) - assert_structural_equal(after, orig_ex) - - if __name__ == "__main__": tvm.testing.main() From 8e4592a04e5126857a991205cf740a1f7bc5c314 Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Sat, 28 Jan 2023 03:46:37 +0800 Subject: [PATCH 05/17] fix l1_loss --- 
python/tvm/relax/training/loss.py | 6 ++-- tests/python/relax/test_training_loss.py | 40 ++++++++++++++++++++---- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/training/loss.py b/python/tvm/relax/training/loss.py index 8d468501ec..72dcfdffad 100644 --- a/python/tvm/relax/training/loss.py +++ b/python/tvm/relax/training/loss.py @@ -21,7 +21,7 @@ from tvm import relax from ..expr import Expr, Var, Function, StructInfo -from ..op import sum, mean, subtract, multiply +from ..op import abs, sum, mean, subtract, multiply from ..op.nn import log_softmax, nll_loss @@ -81,7 +81,7 @@ def _with_reduction(self, expr: Expr): if self.reduction == "sum": expr = sum(expr) elif self.reduction == "mean": - expr = sum(mean(expr, axis=0)) + expr = mean(expr) else: assert self.reduction == "none" return expr @@ -111,7 +111,7 @@ def __call__( with bb.function(self.loss_name, [predictions, targets]): with bb.dataflow(): - lv = abs(subtract(predictions, targets)) # TODO: R.abs + lv = abs(subtract(predictions, targets)) loss = bb.emit_output(self._with_reduction(lv)) bb.emit_func_output(loss) diff --git a/tests/python/relax/test_training_loss.py b/tests/python/relax/test_training_loss.py index 025c2e6241..60e83138f1 100644 --- a/tests/python/relax/test_training_loss.py +++ b/tests/python/relax/test_training_loss.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import pytest import tvm.testing from tvm import relax from tvm.ir.base import assert_structural_equal @@ -34,7 +33,6 @@ def forward( return out -@pytest.mark.skip("Waiting for the operator of R.abs") def test_l1_loss(): N = 3 C = 5 @@ -42,10 +40,41 @@ def test_l1_loss(): targets = relax.TensorStructInfo((N, C), "float32") l1_loss = relax.training.L1Loss()(predictions, targets) + # fmt: off + @R.function + def expected(predictions: R.Tensor((3, 5), dtype="float32"), targets: R.Tensor((3, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): + R.func_attr({"global_symbol": "l1_loss"}) + with R.dataflow(): + lv: R.Tensor((3, 5), dtype="float32") = R.subtract(predictions, targets) + lv1: R.Tensor((3, 5), dtype="float32") = R.abs(lv) + gv: R.Tensor((), dtype="float32") = R.mean(lv1, axis=None, keepdims=False) + R.output(gv) + return gv + # fmt: on + + assert_structural_equal(l1_loss, expected) + -@pytest.mark.skip("Waiting for the operator of R.abs") def test_l1_loss_append(): - pass + s = forward.ret_struct_info + l1_loss = relax.training.L1Loss(reduction="sum")(s, s) + forward_with_loss = relax.training.utils.append_loss(forward, l1_loss) + + # fmt: off + @R.function + def expected(x: R.Tensor((2, 4), dtype="float32"), w: R.Tensor((4, 4), dtype="float32"), b: R.Tensor((2, 4), dtype="float32"), targets: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((2, 4), dtype="float32") = R.matmul(x, w, out_dtype="") + out: R.Tensor((2, 4), dtype="float32") = R.add(lv, b) + lv1: R.Tensor((2, 4), dtype="float32") = R.subtract(out, targets) + lv11: R.Tensor((2, 4), dtype="float32") = R.abs(lv1) + gv: R.Tensor((), dtype="float32") = R.sum(lv11, axis=None, keepdims=False) + R.output(gv) + return gv + # fmt: on + + assert_structural_equal(forward_with_loss, expected) def test_mse_loss(): @@ -64,8 +93,7 @@ def expected(predictions: R.Tensor((3, 5), dtype="float32"), targets: R.Tensor(( with R.dataflow(): lv: R.Tensor((3, 5), dtype="float32") = R.subtract(predictions, 
targets) lv1: R.Tensor((3, 5), dtype="float32") = R.multiply(lv, lv) - lv2: R.Tensor((5,), dtype="float32") = R.mean(lv1, axis=[0], keepdims=False) - gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None, keepdims=False) + gv: R.Tensor((), dtype="float32") = R.mean(lv1, axis=None, keepdims=False) R.output(gv) return gv # fmt: on From 5c311cda266565e71d16601e500a4aa55945ac7d Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Sat, 28 Jan 2023 03:50:10 +0800 Subject: [PATCH 06/17] lint --- python/tvm/relax/training/loss.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/training/loss.py b/python/tvm/relax/training/loss.py index 72dcfdffad..859553cebc 100644 --- a/python/tvm/relax/training/loss.py +++ b/python/tvm/relax/training/loss.py @@ -17,7 +17,7 @@ # pylint: disable=redefined-builtin """Loss functions library for relax.""" -from typing import Any, List, Optional, Union +from typing import Optional, Union from tvm import relax from ..expr import Expr, Var, Function, StructInfo @@ -78,6 +78,13 @@ def __call__(self) -> Function: raise NotImplementedError() def _with_reduction(self, expr: Expr): + """Add a reduction to the final loss. + + Parameters + ---------- + expr : Expr + The loss expr. + """ if self.reduction == "sum": expr = sum(expr) elif self.reduction == "mean": From ccc5f8acc13d3cfe682a1ae44fb88b1c2527f9a7 Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Sat, 28 Jan 2023 03:53:37 +0800 Subject: [PATCH 07/17] lint --- python/tvm/relax/training/loss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/training/loss.py b/python/tvm/relax/training/loss.py index 859553cebc..a1dc274cfe 100644 --- a/python/tvm/relax/training/loss.py +++ b/python/tvm/relax/training/loss.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=redefined-builtin +# pylint: disable=redefined-builtin, invalid-name """Loss functions library for relax.""" from typing import Optional, Union @@ -104,7 +104,7 @@ class L1Loss(Loss): """ def __init__(self, reduction: str = "mean") -> None: - super(L1Loss, self).__init__("l1_loss", reduction) + super().__init__("l1_loss", reduction) def __call__( self, @@ -135,7 +135,7 @@ class MSELoss(Loss): """ def __init__(self, reduction: str = "mean") -> None: - super(MSELoss, self).__init__("mse_loss", reduction) + super().__init__("mse_loss", reduction) def __call__( self, @@ -180,7 +180,7 @@ def __init__( ignore_index: int = -100, weights: Optional[Union[Var, StructInfo]] = None, ) -> None: - super(CrossEntropyLoss, self).__init__("cross_entropy_loss", reduction) + super().__init__("cross_entropy_loss", reduction) self.ignore_index = ignore_index if weights: self.weights = _create_param_var(weights, "weights") From effd72c48982ac2d71931e4306cc84d2f1627a81 Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Sat, 28 Jan 2023 04:13:38 +0800 Subject: [PATCH 08/17] doc --- python/tvm/relax/training/loss.py | 58 +++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/python/tvm/relax/training/loss.py b/python/tvm/relax/training/loss.py index a1dc274cfe..e50f440183 100644 --- a/python/tvm/relax/training/loss.py +++ b/python/tvm/relax/training/loss.py @@ -63,20 +63,6 @@ def __init__(self, loss_name: str, reduction: str = "mean") -> None: if self.reduction not in valid_reductions: raise ValueError("Reduction can only be one of these values: ", valid_reductions) - def __call__(self) -> Function: - """Calling a loss will get its relax function. - - Usually it has some parameters with type Union[Var, StructInfo]. It means - the necessary inputs of the loss function. If a struct info is given, it will - construct a corresponding Var using the struct info; if a Var is given, it will - directly use this Var as the param. - - Returns - ---------- - The relax function of the loss with the loss name as its global symbol. - """ - raise NotImplementedError() - def _with_reduction(self, expr: Expr): """Add a reduction to the final loss. @@ -111,6 +97,20 @@ def __call__( predictions: Union[Var, StructInfo], targets: Union[Var, StructInfo], ) -> Function: + """Get the relax function of L1Loss. If the parameters are + struct info, it will create corresponding variables. + + Parameters + ---------- + predictions : Union[Var, StructInfo] + The predictions of the model in the calculation of loss. + targets : Union[Var, StructInfo] + The ground truth in the calculation of loss. + + Returns + ---------- + The relax function of L1Loss with the loss name as its global symbol. + """ bb = relax.BlockBuilder() predictions = _create_param_var(predictions, "predictions") @@ -142,6 +142,20 @@ def __call__( predictions: Union[Var, StructInfo], targets: Union[Var, StructInfo], ) -> Function: + """Get the relax function of MSELoss. If the parameters are + struct info, it will create corresponding variables. + + Parameters + ---------- + predictions : Union[Var, StructInfo] + The predictions of the model in the calculation of loss. + targets : Union[Var, StructInfo] + The ground truth in the calculation of loss. + + Returns + ---------- + The relax function of MSELoss with the loss name as its global symbol. 
+ """ bb = relax.BlockBuilder() predictions = _create_param_var(predictions, "predictions") @@ -158,7 +172,7 @@ def __call__( class CrossEntropyLoss(Loss): - """CrossEntropyLoss. + """CrossEntropyLoss. It is a combination of a log_softmax computation and a nll_loss. Parameters ---------- @@ -192,6 +206,20 @@ def __call__( predictions: Union[Var, StructInfo], targets: Union[Var, StructInfo], ) -> Function: + """Get the relax function of CrossEntropyLoss. If the parameters are + struct info, it will create corresponding variables. + + Parameters + ---------- + predictions : Union[Var, StructInfo] + The predictions of the model in the calculation of loss. + targets : Union[Var, StructInfo] + The ground truth in the calculation of loss. + + Returns + ---------- + The relax function of CrossEntropyLoss with the loss name as its global symbol. + """ bb = relax.BlockBuilder() predictions = _create_param_var(predictions, "predictions") From 4d65fb737c71487f2f31245170f9ec832c2ca852 Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Sat, 28 Jan 2023 22:03:56 +0800 Subject: [PATCH 09/17] upd --- python/tvm/relax/training/loss.py | 22 +++++++-------- tests/python/relax/test_training_loss.py | 34 +++++++++++------------- 2 files changed, 26 insertions(+), 30 deletions(-) diff --git a/python/tvm/relax/training/loss.py b/python/tvm/relax/training/loss.py index e50f440183..7472858e05 100644 --- a/python/tvm/relax/training/loss.py +++ b/python/tvm/relax/training/loss.py @@ -32,7 +32,7 @@ def _create_param_var(param: Union[Var, StructInfo], param_name) -> Var: if isinstance(param, StructInfo): param = Var(param_name, param) assert isinstance(param, Var) - return param + return Var(param.name_hint, param.struct_info) class Loss: @@ -179,9 +179,6 @@ class CrossEntropyLoss(Loss): reduction : str See the doc of Loss. - weights : Optional[Union[Var, StructInfo]] - a manual rescaling weight given to each class. It has to be a Tensor of size C. - ignore_index : int Specifies a target value that is ignored and does not contribute to the input gradient. """ @@ -192,19 +189,15 @@ def __init__( self, reduction: str = "mean", ignore_index: int = -100, - weights: Optional[Union[Var, StructInfo]] = None, ) -> None: super().__init__("cross_entropy_loss", reduction) self.ignore_index = ignore_index - if weights: - self.weights = _create_param_var(weights, "weights") - else: - self.weights = None def __call__( self, predictions: Union[Var, StructInfo], targets: Union[Var, StructInfo], + weights: Optional[Union[Var, StructInfo]] = None, ) -> Function: """Get the relax function of CrossEntropyLoss. If the parameters are struct info, it will create corresponding variables. @@ -213,9 +206,13 @@ def __call__( ---------- predictions : Union[Var, StructInfo] The predictions of the model in the calculation of loss. + targets : Union[Var, StructInfo] The ground truth in the calculation of loss. + weights : Optional[Union[Var, StructInfo]] + a manual rescaling weight given to each class. It has to be a Tensor of size C. + Returns ---------- The relax function of CrossEntropyLoss with the loss name as its global symbol. 
@@ -226,14 +223,15 @@ def __call__( targets = _create_param_var(targets, "targets") arg_list = [predictions, targets] - if self.weights: - arg_list.append(self.weights) + if weights: + weights = _create_param_var(weights, "weights") + arg_list.append(weights) with bb.function(self.loss_name, arg_list): with bb.dataflow(): logits = bb.emit(log_softmax(predictions)) loss = bb.emit_output( - nll_loss(logits, targets, self.weights, self.reduction, self.ignore_index) + nll_loss(logits, targets, weights, self.reduction, self.ignore_index) ) bb.emit_func_output(loss) diff --git a/tests/python/relax/test_training_loss.py b/tests/python/relax/test_training_loss.py index 60e83138f1..db5701c4bd 100644 --- a/tests/python/relax/test_training_loss.py +++ b/tests/python/relax/test_training_loss.py @@ -38,7 +38,7 @@ def test_l1_loss(): C = 5 predictions = relax.TensorStructInfo((N, C), "float32") targets = relax.TensorStructInfo((N, C), "float32") - l1_loss = relax.training.L1Loss()(predictions, targets) + l1_loss = relax.training.L1Loss() # fmt: off @R.function @@ -52,13 +52,13 @@ def expected(predictions: R.Tensor((3, 5), dtype="float32"), targets: R.Tensor(( return gv # fmt: on - assert_structural_equal(l1_loss, expected) + assert_structural_equal(l1_loss(predictions, targets), expected) def test_l1_loss_append(): s = forward.ret_struct_info - l1_loss = relax.training.L1Loss(reduction="sum")(s, s) - forward_with_loss = relax.training.utils.append_loss(forward, l1_loss) + l1_loss = relax.training.L1Loss(reduction="sum") + forward_with_loss = relax.training.utils.append_loss(forward, l1_loss(s, s)) # fmt: off @R.function @@ -82,7 +82,7 @@ def test_mse_loss(): C = 5 predictions = relax.TensorStructInfo((N, C), "float32") targets = relax.TensorStructInfo((N, C), "float32") - mse_loss = relax.training.MSELoss()(predictions, targets) + mse_loss = relax.training.MSELoss() # fmt: off @R.function @@ -98,13 +98,13 @@ def expected(predictions: R.Tensor((3, 5), dtype="float32"), targets: R.Tensor(( return gv # fmt: on - assert_structural_equal(mse_loss, expected) + assert_structural_equal(mse_loss(predictions, targets), expected) def test_mse_loss_append(): s = forward.ret_struct_info - mse_loss = relax.training.MSELoss(reduction="sum")(s, s) - forward_with_loss = relax.training.utils.append_loss(forward, mse_loss) + mse_loss = relax.training.MSELoss(reduction="sum") + forward_with_loss = relax.training.utils.append_loss(forward, mse_loss(s, s)) # fmt: off @R.function @@ -129,9 +129,7 @@ def test_cross_entropy_loss(): predictions = relax.TensorStructInfo((N, C), "float32") targets = relax.TensorStructInfo((N,), "int64") weights = relax.TensorStructInfo((C,), "float32") - cross_entropy_loss = relax.training.CrossEntropyLoss( - reduction="sum", ignore_index=1, weights=weights - )(predictions, targets) + cross_entropy_loss = relax.training.CrossEntropyLoss(reduction="sum", ignore_index=1) # fmt: off @R.function @@ -146,7 +144,7 @@ def expected(predictions: R.Tensor((3, 5), dtype="float32"), targets: R.Tensor(( return gv # fmt: on - assert_structural_equal(cross_entropy_loss, expected) + assert_structural_equal(cross_entropy_loss(predictions, targets, weights), expected) def test_cross_entropy_loss_without_weights(): @@ -154,7 +152,7 @@ def test_cross_entropy_loss_without_weights(): C = 5 predictions = relax.TensorStructInfo((N, C), "float32") targets = relax.TensorStructInfo((N,), "int64") - cross_entropy_loss = relax.training.CrossEntropyLoss()(predictions, targets) + cross_entropy_loss = 
relax.training.CrossEntropyLoss() # fmt: off @R.function @@ -169,7 +167,7 @@ def expected(predictions: R.Tensor((3, 5), dtype="float32"), targets: R.Tensor(( return gv # fmt: on - assert_structural_equal(cross_entropy_loss, expected) + assert_structural_equal(cross_entropy_loss(predictions, targets), expected) def test_cross_entropy_loss_append(): @@ -178,10 +176,10 @@ def test_cross_entropy_loss_append(): C = s.shape[1] targets = relax.TensorStructInfo((N,), "int64") weights = relax.TensorStructInfo((C,), "float32") - cross_entropy_loss = relax.training.CrossEntropyLoss( - reduction="sum", ignore_index=1, weights=weights - )(s, targets) - forward_with_loss = relax.training.utils.append_loss(forward, cross_entropy_loss) + cross_entropy_loss = relax.training.CrossEntropyLoss(reduction="sum", ignore_index=1) + forward_with_loss = relax.training.utils.append_loss( + forward, cross_entropy_loss(s, targets, weights) + ) # fmt: off @R.function From ffc09029f9f0ef3716122c8d4c63a813a8faf050 Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Sat, 28 Jan 2023 22:40:45 +0800 Subject: [PATCH 10/17] address comments --- python/tvm/relax/training/_ffi_api.py | 1 + python/tvm/relax/training/loss.py | 54 ++++++---- python/tvm/relax/training/utils.py | 37 ++++--- src/relax/training/utils.cc | 28 ++--- src/relax/training/utils.h | 8 +- tests/python/relax/test_training_loss.py | 122 +++++++++++++--------- tests/python/relax/test_training_utils.py | 2 +- 7 files changed, 149 insertions(+), 103 deletions(-) diff --git a/python/tvm/relax/training/_ffi_api.py b/python/tvm/relax/training/_ffi_api.py index 34dc4a0669..70cb83fc0e 100644 --- a/python/tvm/relax/training/_ffi_api.py +++ b/python/tvm/relax/training/_ffi_api.py @@ -13,6 +13,7 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations +# under the License. """FFI APIs for tvm.relax.training""" import tvm._ffi diff --git a/python/tvm/relax/training/loss.py b/python/tvm/relax/training/loss.py index 7472858e05..f4f6b1aa1a 100644 --- a/python/tvm/relax/training/loss.py +++ b/python/tvm/relax/training/loss.py @@ -17,7 +17,7 @@ # pylint: disable=redefined-builtin, invalid-name """Loss functions library for relax.""" -from typing import Optional, Union +from typing import Optional, Union, Literal from tvm import relax from ..expr import Expr, Var, Function, StructInfo @@ -28,22 +28,23 @@ __all__ = ["L1Loss", "MSELoss", "CrossEntropyLoss"] -def _create_param_var(param: Union[Var, StructInfo], param_name) -> Var: +def _create_param_var(param: Union[Var, StructInfo], param_name: str) -> Var: if isinstance(param, StructInfo): param = Var(param_name, param) - assert isinstance(param, Var) + if not isinstance(param, Var): + raise TypeError("The type of param should be Var or StructInfo, but got " + type(param)) return Var(param.name_hint, param.struct_info) class Loss: - """Base class of all loss. + r"""Base class of all loss. Parameters ---------- loss_name : str The name of the loss function. - reduction : str + reduction : Literal["mean", "sum", "none"] The reduction method to apply to output. Can be "mean", "sum" or "none". none : no reduction will be applied, @@ -51,10 +52,7 @@ class Loss: sum : the output will be summed. 
""" - reduction: str - loss_name: str - - def __init__(self, loss_name: str, reduction: str = "mean") -> None: + def __init__(self, loss_name: str, reduction: Literal["mean", "sum", "none"] = "mean") -> None: self.loss_name = loss_name self.reduction = reduction @@ -63,7 +61,7 @@ def __init__(self, loss_name: str, reduction: str = "mean") -> None: if self.reduction not in valid_reductions: raise ValueError("Reduction can only be one of these values: ", valid_reductions) - def _with_reduction(self, expr: Expr): + def _with_reduction(self, expr: Expr) -> Expr: """Add a reduction to the final loss. Parameters @@ -81,15 +79,19 @@ def _with_reduction(self, expr: Expr): class L1Loss(Loss): - """Mean element-wise absolute value difference. + r"""Mean element-wise absolute value difference. Parameters ---------- - reduction : str - See the doc of Loss. + reduction : Literal["mean", "sum", "none"] + The reduction method to apply to output. Can be "mean", "sum" or "none". + + none : no reduction will be applied, + mean : the sum of the output will be divided by the batch_size, + sum : the output will be summed. """ - def __init__(self, reduction: str = "mean") -> None: + def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean") -> None: super().__init__("l1_loss", reduction) def __call__( @@ -126,15 +128,19 @@ def __call__( class MSELoss(Loss): - """Measures the element-wise mean squared error. + r"""Measures the element-wise mean squared error. Parameters ---------- - reduction : str - See the doc of Loss. + reduction : Literal["mean", "sum", "none"] + The reduction method to apply to output. Can be "mean", "sum" or "none". + + none : no reduction will be applied, + mean : the sum of the output will be divided by the batch_size, + sum : the output will be summed. """ - def __init__(self, reduction: str = "mean") -> None: + def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean") -> None: super().__init__("mse_loss", reduction) def __call__( @@ -172,12 +178,16 @@ def __call__( class CrossEntropyLoss(Loss): - """CrossEntropyLoss. It is a combination of a log_softmax computation and a nll_loss. + r"""CrossEntropyLoss. It is a combination of a log_softmax computation and a nll_loss. Parameters ---------- - reduction : str - See the doc of Loss. + reduction : Literal["mean", "sum", "none"] + The reduction method to apply to output. Can be "mean", "sum" or "none". + + none : no reduction will be applied, + mean : the sum of the output will be divided by the batch_size, + sum : the output will be summed. ignore_index : int Specifies a target value that is ignored and does not contribute to the input gradient. @@ -187,7 +197,7 @@ class CrossEntropyLoss(Loss): def __init__( self, - reduction: str = "mean", + reduction: Literal["mean", "sum", "none"] = "mean", ignore_index: int = -100, ) -> None: super().__init__("cross_entropy_loss", reduction) diff --git a/python/tvm/relax/training/utils.py b/python/tvm/relax/training/utils.py index 596ba9c176..8c947f3f48 100644 --- a/python/tvm/relax/training/utils.py +++ b/python/tvm/relax/training/utils.py @@ -37,25 +37,38 @@ def append_loss(orig_func: Function, loss_func: Function) -> Function: .. code-block:: python # Before. 
- @R.function - def orig(x, y): - out = x + y - return out + @R.function + def orig(x: R.Tensor((2, 4), "float32"), y: R.Tensor((2, 4), "float32")): + with R.dataflow(): + out = R.add(x, y) + R.output(out) + return out - @R.function - def loss(predictions, labels): - return R.sum((predictions - labels)^2) + @R.function + def loss(predictions: R.Tensor((2, 4), "float32"), labels: R.Tensor((2, 4), "float32")): + with R.dataflow(): + lv = R.subtract(predictions, labels) + lv1 = R.multiply(lv, lv) + gv = R.sum(lv1) + R.output(gv) + return gv # After. - @R.function - def orig(x, y, labels): - out = x + y - return R.sum((out - labels)^2) + @R.function + def expected(x: R.Tensor((2, 4), "float32"), y: R.Tensor((2, 4), "float32"), + labels: R.Tensor((2, 4), "float32")) -> R.Tensor((), "float32"): + with R.dataflow(): + out: R.Tensor((2, 4), "float32") = R.add(x, y) + lv: R.Tensor((2, 4), "float32") = R.subtract(out, labels) + lv1: R.Tensor((2, 4), "float32") = R.multiply(lv, lv) + gv: R.Tensor((), "float32") = R.sum(lv1) + R.output(gv) + return gv Parameters ---------- orig_func : Function - The function to be appended. + The function to be appended to. loss_func : Function The loss function. diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc index 7f0e16296d..deb8e03be2 100644 --- a/src/relax/training/utils.cc +++ b/src/relax/training/utils.cc @@ -32,7 +32,7 @@ class AppendLossMutator : public ExprMutator { Array blocks; for (int i = 0; i < static_cast(seq_expr->blocks.size()); ++i) { CHECK(seq_expr->blocks[i].as()) - << "All blocks in original functions should be Dataflow Block"; + << "All blocks in original functions should be Dataflow Block."; if (i < static_cast(seq_expr->blocks.size()) - 1) { blocks.push_back(seq_expr->blocks[i]); } else { @@ -60,10 +60,10 @@ class AppendLossMutator : public ExprMutator { } // emit blocks for loss function part. - for (BindingBlock block : loss_body_->blocks) { + for (const BindingBlock& block : loss_body_->blocks) { CHECK(block.as()) - << "All blocks in loss functions should be Dataflow Block"; - for (Binding binding : block->bindings) { + << "All blocks in loss functions should be Dataflow Block."; + for (const Binding& binding : block->bindings) { this->VisitBinding(binding); } } @@ -80,8 +80,8 @@ class AppendLossMutator : public ExprMutator { // remap orignal dataflow var // TODO(chaofan): a better way to check whether new_ret_var should be dataflow void RemapToDataflow(SeqExpr body) { - for (BindingBlock block : body->blocks) { - for (Binding binding : block->bindings) { + for (const BindingBlock& block : body->blocks) { + for (const Binding& binding : block->bindings) { const auto* binding_node = binding.as(); if (binding_node && !binding_node->var->IsInstance()) { Var new_binding_var = DataflowVar( @@ -126,15 +126,17 @@ class AppendLossMutator : public ExprMutator { /*! * \brief Local helper to append a specified loss function after the original function. - * \param orig_func The function to be appended. + * \param orig_func The function to be appended to. * \param loss_func The loss function. * \return The result function after appended. 
*/ Function AppendLoss(Function orig_func, Function loss_func) { CHECK(orig_func->body->IsInstance()) - << "the body of the original function is not SeqExpr."; + << "The body of the original function is expected to be a SeqExpr, but got" + << orig_func->body->GetTypeKey(); CHECK(loss_func->body->IsInstance()) - << "the body of the loss function is not SeqExpr."; + << "The body of the loss function is expected to be a SeqExpr, but got" + << loss_func->body->GetTypeKey(); auto param_copied_func = CopyWithNewParams(orig_func); auto seq_expr = Downcast(param_copied_func->body); @@ -145,7 +147,7 @@ Function AppendLoss(Function orig_func, Function loss_func) { if (orig_func->ret_struct_info.as()) { const auto* tuple_node = seq_expr->body.as(); ICHECK(tuple_node != nullptr); - for (Expr field : tuple_node->fields) { + for (const Expr& field : tuple_node->fields) { mutator.orig_rets.push_back(mutator.VisitExpr(field)); } } else { @@ -154,11 +156,13 @@ Function AppendLoss(Function orig_func, Function loss_func) { CHECK(loss_func->params.size() >= mutator.orig_rets.size()) << "The number of return values of original functions should be greater than the number of " - "parameters of loss function"; + "parameters of loss function. Got " + << mutator.orig_rets.size() << " > " << loss_func->params.size(); auto new_params = mutator.RemapLossParams(loss_func->params, param_copied_func->params); Expr new_body = mutator.VisitExpr(seq_expr); - return Function(new_params, new_body, loss_func->ret_struct_info, param_copied_func->attrs); + return Function(std::move(new_params), std::move(new_body), loss_func->ret_struct_info, + param_copied_func->attrs); } TVM_REGISTER_GLOBAL("relax.training.AppendLoss").set_body_typed(AppendLoss); diff --git a/src/relax/training/utils.h b/src/relax/training/utils.h index ae8c5a7e18..677729485b 100644 --- a/src/relax/training/utils.h +++ b/src/relax/training/utils.h @@ -18,8 +18,8 @@ */ /*! - * \file tvm/relax/utils.h - * \brief Utility classes and functions for working with the Relax IR. + * \file tvm/relax/training/utils.h + * \brief Utility classes and functions for relax training. */ #ifndef TVM_RELAX_TRAINING_UTILS_H_ #define TVM_RELAX_TRAINING_UTILS_H_ @@ -38,11 +38,11 @@ namespace relax { * 2. This util can be replaced if we have Inline pass. It is equivalent to inline a tail call in * some sense. * - * \param orig_func The function to be appended. + * \param orig_func The function to be appended to. * \param loss_func The loss function. * \return The result function after appended. 
*/ -TVM_DLL Function AppendLoss(Function orig_func, Function loss_func); +Function AppendLoss(Function orig_func, Function loss_func); } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_training_loss.py b/tests/python/relax/test_training_loss.py index db5701c4bd..6880e0517f 100644 --- a/tests/python/relax/test_training_loss.py +++ b/tests/python/relax/test_training_loss.py @@ -17,18 +17,18 @@ import tvm.testing from tvm import relax from tvm.ir.base import assert_structural_equal -from tvm.script.parser import relax as R +from tvm.script import relax as R @R.function def forward( - x: R.Tensor((2, 4), dtype="float32"), - w: R.Tensor((4, 4), dtype="float32"), - b: R.Tensor((2, 4), dtype="float32"), -) -> R.Tensor((2, 4), dtype="float32"): + x: R.Tensor((2, 4), "float32"), + w: R.Tensor((4, 4), "float32"), + b: R.Tensor((2, 4), "float32"), +) -> R.Tensor((2, 4), "float32"): with R.dataflow(): - lv: R.Tensor((2, 4), dtype="float32") = R.matmul(x, w) - out: R.Tensor((2, 4), dtype="float32") = R.add(lv, b) + lv: R.Tensor((2, 4), "float32") = R.matmul(x, w) + out: R.Tensor((2, 4), "float32") = R.add(lv, b) R.output(out) return out @@ -40,17 +40,17 @@ def test_l1_loss(): targets = relax.TensorStructInfo((N, C), "float32") l1_loss = relax.training.L1Loss() - # fmt: off @R.function - def expected(predictions: R.Tensor((3, 5), dtype="float32"), targets: R.Tensor((3, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): + def expected( + predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), "float32") + ) -> R.Tensor((), "float32"): R.func_attr({"global_symbol": "l1_loss"}) with R.dataflow(): - lv: R.Tensor((3, 5), dtype="float32") = R.subtract(predictions, targets) - lv1: R.Tensor((3, 5), dtype="float32") = R.abs(lv) - gv: R.Tensor((), dtype="float32") = R.mean(lv1, axis=None, keepdims=False) + lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets) + lv1: R.Tensor((3, 5), "float32") = R.abs(lv) + gv: R.Tensor((), "float32") = R.mean(lv1, axis=None, keepdims=False) R.output(gv) return gv - # fmt: on assert_structural_equal(l1_loss(predictions, targets), expected) @@ -60,19 +60,22 @@ def test_l1_loss_append(): l1_loss = relax.training.L1Loss(reduction="sum") forward_with_loss = relax.training.utils.append_loss(forward, l1_loss(s, s)) - # fmt: off @R.function - def expected(x: R.Tensor((2, 4), dtype="float32"), w: R.Tensor((4, 4), dtype="float32"), b: R.Tensor((2, 4), dtype="float32"), targets: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((), dtype="float32"): + def expected( + x: R.Tensor((2, 4), "float32"), + w: R.Tensor((4, 4), "float32"), + b: R.Tensor((2, 4), "float32"), + targets: R.Tensor((2, 4), "float32"), + ) -> R.Tensor((), "float32"): # block 0 with R.dataflow(): - lv: R.Tensor((2, 4), dtype="float32") = R.matmul(x, w, out_dtype="") - out: R.Tensor((2, 4), dtype="float32") = R.add(lv, b) - lv1: R.Tensor((2, 4), dtype="float32") = R.subtract(out, targets) - lv11: R.Tensor((2, 4), dtype="float32") = R.abs(lv1) - gv: R.Tensor((), dtype="float32") = R.sum(lv11, axis=None, keepdims=False) + lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="") + out: R.Tensor((2, 4), "float32") = R.add(lv, b) + lv1: R.Tensor((2, 4), "float32") = R.subtract(out, targets) + lv11: R.Tensor((2, 4), "float32") = R.abs(lv1) + gv: R.Tensor((), "float32") = R.sum(lv11, axis=None, keepdims=False) R.output(gv) return gv - # fmt: on assert_structural_equal(forward_with_loss, expected) @@ -84,19 +87,19 @@ def test_mse_loss(): targets = 
relax.TensorStructInfo((N, C), "float32") mse_loss = relax.training.MSELoss() - # fmt: off @R.function - def expected(predictions: R.Tensor((3, 5), dtype="float32"), targets: R.Tensor((3, 5), dtype="float32")) -> R.Tensor((), dtype="float32"): + def expected( + predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), "float32") + ) -> R.Tensor((), "float32"): # function attr dict R.func_attr({"global_symbol": "mse_loss"}) # block 0 with R.dataflow(): - lv: R.Tensor((3, 5), dtype="float32") = R.subtract(predictions, targets) - lv1: R.Tensor((3, 5), dtype="float32") = R.multiply(lv, lv) - gv: R.Tensor((), dtype="float32") = R.mean(lv1, axis=None, keepdims=False) + lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets) + lv1: R.Tensor((3, 5), "float32") = R.multiply(lv, lv) + gv: R.Tensor((), "float32") = R.mean(lv1, axis=None, keepdims=False) R.output(gv) return gv - # fmt: on assert_structural_equal(mse_loss(predictions, targets), expected) @@ -106,19 +109,22 @@ def test_mse_loss_append(): mse_loss = relax.training.MSELoss(reduction="sum") forward_with_loss = relax.training.utils.append_loss(forward, mse_loss(s, s)) - # fmt: off @R.function - def expected(x: R.Tensor((2, 4), dtype="float32"), w: R.Tensor((4, 4), dtype="float32"), b: R.Tensor((2, 4), dtype="float32"), targets: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((), dtype="float32"): + def expected( + x: R.Tensor((2, 4), "float32"), + w: R.Tensor((4, 4), "float32"), + b: R.Tensor((2, 4), "float32"), + targets: R.Tensor((2, 4), "float32"), + ) -> R.Tensor((), "float32"): # block 0 with R.dataflow(): - lv: R.Tensor((2, 4), dtype="float32") = R.matmul(x, w, out_dtype="") - out: R.Tensor((2, 4), dtype="float32") = R.add(lv, b) - lv1: R.Tensor((2, 4), dtype="float32") = R.subtract(out, targets) - lv11: R.Tensor((2, 4), dtype="float32") = R.multiply(lv1, lv1) - gv: R.Tensor((), dtype="float32") = R.sum(lv11, axis=None, keepdims=False) + lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="") + out: R.Tensor((2, 4), "float32") = R.add(lv, b) + lv1: R.Tensor((2, 4), "float32") = R.subtract(out, targets) + lv11: R.Tensor((2, 4), "float32") = R.multiply(lv1, lv1) + gv: R.Tensor((), "float32") = R.sum(lv11, axis=None, keepdims=False) R.output(gv) return gv - # fmt: on assert_structural_equal(forward_with_loss, expected) @@ -131,18 +137,22 @@ def test_cross_entropy_loss(): weights = relax.TensorStructInfo((C,), "float32") cross_entropy_loss = relax.training.CrossEntropyLoss(reduction="sum", ignore_index=1) - # fmt: off @R.function - def expected(predictions: R.Tensor((3, 5), dtype="float32"), targets: R.Tensor((3,), dtype="int64"), weights: R.Tensor((5,), dtype="float32")) -> R.Tensor((), dtype="float32"): + def expected( + predictions: R.Tensor((3, 5), "float32"), + targets: R.Tensor((3,), "int64"), + weights: R.Tensor((5,), "float32"), + ) -> R.Tensor((), "float32"): # function attr dict R.func_attr({"global_symbol": "cross_entropy_loss"}) # block 0 with R.dataflow(): - lv: R.Tensor((3, 5), dtype="float32") = R.nn.log_softmax(predictions, axis=-1) - gv: R.Tensor((), dtype="float32") = R.nn.nll_loss(lv, targets, weights, reduction="sum", ignore_index=1) + lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) + gv: R.Tensor((), "float32") = R.nn.nll_loss( + lv, targets, weights, reduction="sum", ignore_index=1 + ) R.output(gv) return gv - # fmt: on assert_structural_equal(cross_entropy_loss(predictions, targets, weights), expected) @@ -154,18 +164,20 @@ def 
test_cross_entropy_loss_without_weights(): targets = relax.TensorStructInfo((N,), "int64") cross_entropy_loss = relax.training.CrossEntropyLoss() - # fmt: off @R.function - def expected(predictions: R.Tensor((3, 5), dtype="float32"), targets: R.Tensor((3,), dtype="int64")) -> R.Tensor((), dtype="float32"): + def expected( + predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3,), "int64") + ) -> R.Tensor((), "float32"): # function attr dict R.func_attr({"global_symbol": "cross_entropy_loss"}) # block 0 with R.dataflow(): - lv: R.Tensor((3, 5), dtype="float32") = R.nn.log_softmax(predictions, axis=-1) - gv: R.Tensor((), dtype="float32") = R.nn.nll_loss(lv, targets, reduction="mean", ignore_index=-100) + lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) + gv: R.Tensor((), "float32") = R.nn.nll_loss( + lv, targets, reduction="mean", ignore_index=-100 + ) R.output(gv) return gv - # fmt: on assert_structural_equal(cross_entropy_loss(predictions, targets), expected) @@ -181,18 +193,24 @@ def test_cross_entropy_loss_append(): forward, cross_entropy_loss(s, targets, weights) ) - # fmt: off @R.function - def expected(x: R.Tensor((2, 4), dtype="float32"), w: R.Tensor((4, 4), dtype="float32"), b: R.Tensor((2, 4), dtype="float32"), targets: R.Tensor((2,), dtype="int64"), weights: R.Tensor((4,), dtype="float32")) -> R.Tensor((), dtype="float32"): + def expected( + x: R.Tensor((2, 4), "float32"), + w: R.Tensor((4, 4), "float32"), + b: R.Tensor((2, 4), "float32"), + targets: R.Tensor((2,), "int64"), + weights: R.Tensor((4,), "float32"), + ) -> R.Tensor((), "float32"): # block 0 with R.dataflow(): - lv: R.Tensor((2, 4), dtype="float32") = R.matmul(x, w, out_dtype="") - out: R.Tensor((2, 4), dtype="float32") = R.add(lv, b) - lv1: R.Tensor((2, 4), dtype="float32") = R.nn.log_softmax(out, axis=-1) - gv: R.Tensor((), dtype="float32") = R.nn.nll_loss(lv1, targets, weights, reduction="sum", ignore_index=1) + lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="") + out: R.Tensor((2, 4), "float32") = R.add(lv, b) + lv1: R.Tensor((2, 4), "float32") = R.nn.log_softmax(out, axis=-1) + gv: R.Tensor((), "float32") = R.nn.nll_loss( + lv1, targets, weights, reduction="sum", ignore_index=1 + ) R.output(gv) return gv - # fmt: on assert_structural_equal(forward_with_loss, expected) diff --git a/tests/python/relax/test_training_utils.py b/tests/python/relax/test_training_utils.py index a885eedf97..6b2fcd88a6 100644 --- a/tests/python/relax/test_training_utils.py +++ b/tests/python/relax/test_training_utils.py @@ -17,7 +17,7 @@ import tvm.testing from tvm import relax from tvm.ir.base import assert_structural_equal -from tvm.script.parser import relax as R +from tvm.script import relax as R def test_append_loss_basic_extend(): From 6dea3249c07a44e545adeed88345004779732309 Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Sat, 28 Jan 2023 22:48:28 +0800 Subject: [PATCH 11/17] lint --- python/tvm/relax/training/__init__.py | 2 +- python/tvm/relax/training/loss.py | 19 +++++++++++-------- src/relax/training/utils.h | 2 ++ 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/training/__init__.py b/python/tvm/relax/training/__init__.py index b9da9ad5c6..c3bda65860 100644 --- a/python/tvm/relax/training/__init__.py +++ b/python/tvm/relax/training/__init__.py @@ -20,4 +20,4 @@ from . 
import utils # loss functions -from .loss import * +from .loss import L1Loss, MSELoss, CrossEntropyLoss diff --git a/python/tvm/relax/training/loss.py b/python/tvm/relax/training/loss.py index f4f6b1aa1a..8b0b41e544 100644 --- a/python/tvm/relax/training/loss.py +++ b/python/tvm/relax/training/loss.py @@ -17,7 +17,13 @@ # pylint: disable=redefined-builtin, invalid-name """Loss functions library for relax.""" -from typing import Optional, Union, Literal +from typing import Optional, Union + +# isort: off +from typing_extensions import Literal + +# isort: on + from tvm import relax from ..expr import Expr, Var, Function, StructInfo @@ -25,9 +31,6 @@ from ..op.nn import log_softmax, nll_loss -__all__ = ["L1Loss", "MSELoss", "CrossEntropyLoss"] - - def _create_param_var(param: Union[Var, StructInfo], param_name: str) -> Var: if isinstance(param, StructInfo): param = Var(param_name, param) @@ -36,7 +39,7 @@ def _create_param_var(param: Union[Var, StructInfo], param_name: str) -> Var: return Var(param.name_hint, param.struct_info) -class Loss: +class _Loss: r"""Base class of all loss. Parameters @@ -78,7 +81,7 @@ def _with_reduction(self, expr: Expr) -> Expr: return expr -class L1Loss(Loss): +class L1Loss(_Loss): r"""Mean element-wise absolute value difference. Parameters @@ -127,7 +130,7 @@ def __call__( return bb.get()[self.loss_name].with_attr("global_symbol", self.loss_name) -class MSELoss(Loss): +class MSELoss(_Loss): r"""Measures the element-wise mean squared error. Parameters @@ -177,7 +180,7 @@ def __call__( return bb.get()[self.loss_name].with_attr("global_symbol", self.loss_name) -class CrossEntropyLoss(Loss): +class CrossEntropyLoss(_Loss): r"""CrossEntropyLoss. It is a combination of a log_softmax computation and a nll_loss. Parameters diff --git a/src/relax/training/utils.h b/src/relax/training/utils.h index 677729485b..82c5d4d461 100644 --- a/src/relax/training/utils.h +++ b/src/relax/training/utils.h @@ -27,6 +27,8 @@ #include #include +#include + namespace tvm { namespace relax { From 5c37e5ff01be19e972c8d8e8ce4aa873652b6731 Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Sun, 29 Jan 2023 19:32:57 +0800 Subject: [PATCH 12/17] adjust namespace --- python/tvm/relax/__init__.py | 5 +- python/tvm/relax/training/loss.py | 8 +- python/tvm/relax/training/optimizer.py | 128 ++++++++++++------------- 3 files changed, 71 insertions(+), 70 deletions(-) diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 4f82420bbf..a12bc1509b 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -89,4 +89,7 @@ ) # Training utils -from .training import optimizer +from .training import ( + optimizer, + loss, +) diff --git a/python/tvm/relax/training/loss.py b/python/tvm/relax/training/loss.py index 8b0b41e544..c9503afaeb 100644 --- a/python/tvm/relax/training/loss.py +++ b/python/tvm/relax/training/loss.py @@ -24,7 +24,7 @@ # isort: on -from tvm import relax +from ..block_builder import BlockBuilder from ..expr import Expr, Var, Function, StructInfo from ..op import abs, sum, mean, subtract, multiply @@ -116,7 +116,7 @@ def __call__( ---------- The relax function of L1Loss with the loss name as its global symbol. """ - bb = relax.BlockBuilder() + bb = BlockBuilder() predictions = _create_param_var(predictions, "predictions") targets = _create_param_var(targets, "targets") @@ -165,7 +165,7 @@ def __call__( ---------- The relax function of MSELoss with the loss name as its global symbol. 
""" - bb = relax.BlockBuilder() + bb = BlockBuilder() predictions = _create_param_var(predictions, "predictions") targets = _create_param_var(targets, "targets") @@ -230,7 +230,7 @@ def __call__( ---------- The relax function of CrossEntropyLoss with the loss name as its global symbol. """ - bb = relax.BlockBuilder() + bb = BlockBuilder() predictions = _create_param_var(predictions, "predictions") targets = _create_param_var(targets, "targets") diff --git a/python/tvm/relax/training/optimizer.py b/python/tvm/relax/training/optimizer.py index 729b861e2b..b5fd9365f7 100644 --- a/python/tvm/relax/training/optimizer.py +++ b/python/tvm/relax/training/optimizer.py @@ -22,12 +22,14 @@ import numpy as np # type: ignore import tvm -from tvm import relax as rx -from tvm.relax.struct_info import TensorStructInfo -from tvm.relax.transform.legalize_ops import LegalizeOps from tvm.runtime.container import tuple_object -from tvm.relax.op import add, subtract, multiply, divide, sqrt -from tvm.relax import Var, Function + +from ..vm import VirtualMachine, build +from ..block_builder import BlockBuilder +from ..struct_info import TensorStructInfo, TupleStructInfo +from ..transform.legalize_ops import LegalizeOps +from ..op import add, subtract, multiply, divide, sqrt +from ..expr import const, Var, Function, TupleGetItem, Tuple as RxTuple # TODO(chaofan, yixin): Migrate key logics to C++ class Optimizer: @@ -74,7 +76,7 @@ class Optimizer: _dtype: str # these attributes are for the building and running process of the optimizer function - _vm_module: rx.VirtualMachine + _vm_module: VirtualMachine _target: Union[str, tvm.target.Target] _device: Union[tvm.runtime.Device, List[tvm.runtime.Device]] @@ -248,8 +250,8 @@ def __call__( mod = tvm.IRModule({self.name: self.get_function()}) # pylint: disable=not-callable lowered_mod = LegalizeOps()(mod) # type: ignore - executable = rx.vm.build(lowered_mod, self._target) - self._vm_module = rx.VirtualMachine(executable, self._device) + executable = build(lowered_mod, self._target) + self._vm_module = VirtualMachine(executable, self._device) new_params, self.state = self._vm_module[self.name](params_adt, grads_adt, self.state) return new_params @@ -333,38 +335,38 @@ def get_function(self) -> Function: dtype = self._dtype # input variables - param_var = Var("params", rx.TupleStructInfo([p.struct_info for p in plist])) - grad_var = Var("gradients", rx.TupleStructInfo([p.struct_info for p in plist])) - state_var = Var("optim_states", rx.TupleStructInfo([rx.TensorStructInfo((), "int64")])) + param_var = Var("params", TupleStructInfo([p.struct_info for p in plist])) + grad_var = Var("gradients", TupleStructInfo([p.struct_info for p in plist])) + state_var = Var("optim_states", TupleStructInfo([TensorStructInfo((), "int64")])) # constants - lr = rx.const(self.lr, dtype) - weight_decay = rx.const(self.weight_decay, dtype) - one = rx.const(1, "int64") + lr = const(self.lr, dtype) + weight_decay = const(self.weight_decay, dtype) + one = const(1, "int64") - builder = rx.BlockBuilder() + builder = BlockBuilder() with builder.function(self.name, [param_var, grad_var, state_var]): with builder.dataflow(): param_list_new, state_list_new = [], [] # handle num_steps - num_steps = builder.emit(rx.TupleGetItem(state_var, 0), "num_steps") + num_steps = builder.emit(TupleGetItem(state_var, 0), "num_steps") num_steps_new = builder.emit(add(num_steps, one), "num_steps_new") state_list_new.append(num_steps_new) # computation logics for i in range(len_param): name = 
self._param_list[i].name_hint - p = builder.emit(rx.TupleGetItem(param_var, i), name) - g = builder.emit(rx.TupleGetItem(grad_var, i), name + "_grad") + p = builder.emit(TupleGetItem(param_var, i), name) + g = builder.emit(TupleGetItem(grad_var, i), name + "_grad") if self.weight_decay: g = builder.emit(add(multiply(weight_decay, p), g), name + "_grad_new") p_new = builder.emit(subtract(p, multiply(lr, g)), name + "_new") param_list_new.append(p_new) # handle return values - params_new = builder.emit_output(rx.Tuple(param_list_new), "params_new") - optim_states_new = builder.emit_output(rx.Tuple(state_list_new), "optim_states_new") + params_new = builder.emit_output(RxTuple(param_list_new), "params_new") + optim_states_new = builder.emit_output(RxTuple(state_list_new), "optim_states_new") builder.emit_func_output((params_new, optim_states_new)) return builder.get()[self.name] @@ -471,36 +473,36 @@ def get_function(self) -> Function: dtype = self._dtype # input variables - param_var = Var("params", rx.TupleStructInfo([p.struct_info for p in plist])) - grad_var = Var("gradients", rx.TupleStructInfo([p.struct_info for p in plist])) + param_var = Var("params", TupleStructInfo([p.struct_info for p in plist])) + grad_var = Var("gradients", TupleStructInfo([p.struct_info for p in plist])) state_var = Var( "optim_states", - rx.TupleStructInfo([rx.TensorStructInfo((), "int64"), *(p.struct_info for p in plist)]), + TupleStructInfo([TensorStructInfo((), "int64"), *(p.struct_info for p in plist)]), ) # constants - lr = rx.const(self.lr, dtype) - momentum = rx.const(self.momentum, dtype) - weight_decay = rx.const(self.weight_decay, dtype) - dampening_inv = rx.const(_high_precision_subtract(1, self.dampening), dtype) - one = rx.const(1, "int64") + lr = const(self.lr, dtype) + momentum = const(self.momentum, dtype) + weight_decay = const(self.weight_decay, dtype) + dampening_inv = const(_high_precision_subtract(1, self.dampening), dtype) + one = const(1, "int64") - builder = rx.BlockBuilder() + builder = BlockBuilder() with builder.function(self.name, [param_var, grad_var, state_var]): with builder.dataflow(): param_list_new, state_list_new = [], [] # handle num_steps - num_steps = builder.emit(rx.TupleGetItem(state_var, 0), "num_steps") + num_steps = builder.emit(TupleGetItem(state_var, 0), "num_steps") num_steps_new = builder.emit(add(num_steps, one), "num_steps_new") state_list_new.append(num_steps_new) # computation logics for i in range(len_param): name = self._param_list[i].name_hint - p = builder.emit(rx.TupleGetItem(param_var, i), name) - g = builder.emit(rx.TupleGetItem(grad_var, i), name + "_grad") - v = builder.emit(rx.TupleGetItem(state_var, i + 1), name + "_v") + p = builder.emit(TupleGetItem(param_var, i), name) + g = builder.emit(TupleGetItem(grad_var, i), name + "_grad") + v = builder.emit(TupleGetItem(state_var, i + 1), name + "_v") if self.weight_decay: g = builder.emit(add(multiply(weight_decay, p), g), name + "_grad_new") damp_g = multiply(dampening_inv, g) if self.dampening else g @@ -515,8 +517,8 @@ def get_function(self) -> Function: state_list_new.append(v_new) # handle return values - params_new = builder.emit_output(rx.Tuple(param_list_new), "params_new") - optim_states_new = builder.emit_output(rx.Tuple(state_list_new), "optim_states_new") + params_new = builder.emit_output(RxTuple(param_list_new), "params_new") + optim_states_new = builder.emit_output(RxTuple(state_list_new), "optim_states_new") builder.emit_func_output((params_new, optim_states_new)) return 
builder.get()[self.name] @@ -638,15 +640,15 @@ def get_function(self) -> Function: dtype = self._dtype # input variables - param_var = Var("params", rx.TupleStructInfo([p.struct_info for p in plist])) - grad_var = Var("gradients", rx.TupleStructInfo([p.struct_info for p in plist])) + param_var = Var("params", TupleStructInfo([p.struct_info for p in plist])) + grad_var = Var("gradients", TupleStructInfo([p.struct_info for p in plist])) state_var = Var( "optim_states", - rx.TupleStructInfo( + TupleStructInfo( [ - rx.TensorStructInfo((), "int64"), - rx.TensorStructInfo((), dtype), - rx.TensorStructInfo((), dtype), + TensorStructInfo((), "int64"), + TensorStructInfo((), dtype), + TensorStructInfo((), dtype), *(p.struct_info for p in plist), *(p.struct_info for p in plist), ] @@ -654,42 +656,38 @@ def get_function(self) -> Function: ) # constants - lr = rx.const(self.lr, dtype) - beta1 = rx.const(self.beta1, dtype) - beta2 = rx.const(self.beta2, dtype) - beta1_inv = rx.const(_high_precision_subtract(1, self.beta1), dtype) - beta2_inv = rx.const(_high_precision_subtract(1, self.beta2), dtype) - eps = rx.const(self.eps, dtype) - weight_decay = rx.const(self.weight_decay, dtype) - one_int = rx.const(1, "int64") - one_float = rx.const(1, dtype) - - builder = rx.BlockBuilder() + lr = const(self.lr, dtype) + beta1 = const(self.beta1, dtype) + beta2 = const(self.beta2, dtype) + beta1_inv = const(_high_precision_subtract(1, self.beta1), dtype) + beta2_inv = const(_high_precision_subtract(1, self.beta2), dtype) + eps = const(self.eps, dtype) + weight_decay = const(self.weight_decay, dtype) + one_int = const(1, "int64") + one_float = const(1, dtype) + + builder = BlockBuilder() with builder.function(self.name, [param_var, grad_var, state_var]): with builder.dataflow(): param_list_new = [] state_list_new = [None] * (len_param * 2 + 3) # type: List[Optional[Var]] # handle num_steps - num_steps = builder.emit(rx.TupleGetItem(state_var, 0), "num_steps") + num_steps = builder.emit(TupleGetItem(state_var, 0), "num_steps") num_steps_new = builder.emit(add(num_steps, one_int), "num_steps_new") state_list_new[0] = num_steps_new - beta1_prod = builder.emit( - multiply(rx.TupleGetItem(state_var, 1), beta1), "beta1_prod" - ) - beta2_prod = builder.emit( - multiply(rx.TupleGetItem(state_var, 2), beta2), "beta2_prod" - ) + beta1_prod = builder.emit(multiply(TupleGetItem(state_var, 1), beta1), "beta1_prod") + beta2_prod = builder.emit(multiply(TupleGetItem(state_var, 2), beta2), "beta2_prod") state_list_new[1] = beta1_prod state_list_new[2] = beta2_prod # computation logics for i in range(len_param): name = self._param_list[i].name_hint - p = builder.emit(rx.TupleGetItem(param_var, i), name) - g = builder.emit(rx.TupleGetItem(grad_var, i), name + "_grad") - m = builder.emit(rx.TupleGetItem(state_var, i + 3), name + "_m") - v = builder.emit(rx.TupleGetItem(state_var, i + 3 + len_param), name + "_v") + p = builder.emit(TupleGetItem(param_var, i), name) + g = builder.emit(TupleGetItem(grad_var, i), name + "_grad") + m = builder.emit(TupleGetItem(state_var, i + 3), name + "_m") + v = builder.emit(TupleGetItem(state_var, i + 3 + len_param), name + "_v") if self.weight_decay: g = builder.emit(add(multiply(weight_decay, p), g), name + "_grad_new") m_new = builder.emit( @@ -714,7 +712,7 @@ def get_function(self) -> Function: state_list_new[i + 3 + len_param] = v_new # handle return values - params_new = builder.emit_output(rx.Tuple(param_list_new), "params_new") - optim_states_new = 
builder.emit_output(rx.Tuple(state_list_new), "optim_states_new") + params_new = builder.emit_output(RxTuple(param_list_new), "params_new") + optim_states_new = builder.emit_output(RxTuple(state_list_new), "optim_states_new") builder.emit_func_output((params_new, optim_states_new)) return builder.get()[self.name] From 9291a07e6a8ef1ea0bb605e10de8d106b26ccb7f Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Mon, 30 Jan 2023 01:23:19 +0800 Subject: [PATCH 13/17] upd --- python/tvm/relax/__init__.py | 5 +---- python/tvm/relax/training/loss.py | 32 +++++++++++++++---------------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index a12bc1509b..de9c7f75fc 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -89,7 +89,4 @@ ) # Training utils -from .training import ( - optimizer, - loss, -) +from .training import loss, optimizer diff --git a/python/tvm/relax/training/loss.py b/python/tvm/relax/training/loss.py index c9503afaeb..358ddea08f 100644 --- a/python/tvm/relax/training/loss.py +++ b/python/tvm/relax/training/loss.py @@ -55,14 +55,14 @@ class _Loss: sum : the output will be summed. """ - def __init__(self, loss_name: str, reduction: Literal["mean", "sum", "none"] = "mean") -> None: - self.loss_name = loss_name - self.reduction = reduction + _valid_reductions = ["mean", "sum", "none"] - valid_reductions = ["mean", "sum", "none"] + def __init__(self, loss_name: str, reduction: Literal["mean", "sum", "none"] = "mean") -> None: + self._loss_name = loss_name + self._reduction = reduction - if self.reduction not in valid_reductions: - raise ValueError("Reduction can only be one of these values: ", valid_reductions) + if self._reduction not in self._valid_reductions: + raise ValueError("Reduction can only be one of these values: ", self._valid_reductions) def _with_reduction(self, expr: Expr) -> Expr: """Add a reduction to the final loss. @@ -72,12 +72,12 @@ def _with_reduction(self, expr: Expr) -> Expr: expr : Expr The loss expr. 
""" - if self.reduction == "sum": + if self._reduction == "sum": expr = sum(expr) - elif self.reduction == "mean": + elif self._reduction == "mean": expr = mean(expr) else: - assert self.reduction == "none" + raise ValueError("Reduction can only be one of these values: ", self._valid_reductions) return expr @@ -121,13 +121,13 @@ def __call__( predictions = _create_param_var(predictions, "predictions") targets = _create_param_var(targets, "targets") - with bb.function(self.loss_name, [predictions, targets]): + with bb.function(self._loss_name, [predictions, targets]): with bb.dataflow(): lv = abs(subtract(predictions, targets)) loss = bb.emit_output(self._with_reduction(lv)) bb.emit_func_output(loss) - return bb.get()[self.loss_name].with_attr("global_symbol", self.loss_name) + return bb.get()[self._loss_name].with_attr("global_symbol", self._loss_name) class MSELoss(_Loss): @@ -170,14 +170,14 @@ def __call__( predictions = _create_param_var(predictions, "predictions") targets = _create_param_var(targets, "targets") - with bb.function(self.loss_name, [predictions, targets]): + with bb.function(self._loss_name, [predictions, targets]): with bb.dataflow(): lv = subtract(predictions, targets) lv = multiply(lv, lv) loss = bb.emit_output(self._with_reduction(lv)) bb.emit_func_output(loss) - return bb.get()[self.loss_name].with_attr("global_symbol", self.loss_name) + return bb.get()[self._loss_name].with_attr("global_symbol", self._loss_name) class CrossEntropyLoss(_Loss): @@ -240,12 +240,12 @@ def __call__( weights = _create_param_var(weights, "weights") arg_list.append(weights) - with bb.function(self.loss_name, arg_list): + with bb.function(self._loss_name, arg_list): with bb.dataflow(): logits = bb.emit(log_softmax(predictions)) loss = bb.emit_output( - nll_loss(logits, targets, weights, self.reduction, self.ignore_index) + nll_loss(logits, targets, weights, self._reduction, self.ignore_index) ) bb.emit_func_output(loss) - return bb.get()[self.loss_name].with_attr("global_symbol", self.loss_name) + return bb.get()[self._loss_name].with_attr("global_symbol", self._loss_name) From e0287750fcc1fd54c62eb191177e9812427da9ea Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Mon, 30 Jan 2023 01:31:41 +0800 Subject: [PATCH 14/17] fix test --- python/tvm/relax/training/loss.py | 2 +- .../python/relax/test_op_gradient_numeric.py | 27 ----------- .../relax/test_training_optimizer_numeric.py | 47 ++++++++++--------- 3 files changed, 25 insertions(+), 51 deletions(-) diff --git a/python/tvm/relax/training/loss.py b/python/tvm/relax/training/loss.py index 358ddea08f..0e11e7958f 100644 --- a/python/tvm/relax/training/loss.py +++ b/python/tvm/relax/training/loss.py @@ -76,7 +76,7 @@ def _with_reduction(self, expr: Expr) -> Expr: expr = sum(expr) elif self._reduction == "mean": expr = mean(expr) - else: + elif self._reduction != "none": raise ValueError("Reduction can only be one of these values: ", self._valid_reductions) return expr diff --git a/tests/python/relax/test_op_gradient_numeric.py b/tests/python/relax/test_op_gradient_numeric.py index 6243b609cd..60e64803b8 100644 --- a/tests/python/relax/test_op_gradient_numeric.py +++ b/tests/python/relax/test_op_gradient_numeric.py @@ -132,7 +132,6 @@ def forward(*inputs): check_numerical_grads(forward, inputs_numpy, _tvm_to_numpy(result)) -@tvm.testing.parametrize_targets("llvm") def test_add(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) @@ 
-141,7 +140,6 @@ def test_add(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_subtract(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) @@ -150,7 +148,6 @@ def test_subtract(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_multiply(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) @@ -159,7 +156,6 @@ def test_multiply(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_permute_dims(target, dev): data1_numpy = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32) relax_check_gradients( @@ -167,7 +163,6 @@ def test_permute_dims(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_permute_dims_with_axes(target, dev): data1_numpy = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32) relax_check_gradients( @@ -181,13 +176,11 @@ def test_permute_dims_with_axes(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_relu(target, dev): data1_numpy = np.random.uniform(-1, 1, (3, 3)).astype(np.float32) relax_check_gradients(relax.op.nn.relu, "relax.nn.relu", [data1_numpy], target, dev, (3, 3)) -@tvm.testing.parametrize_targets("llvm") def test_matmul_2_2(target, dev): data1_numpy = np.random.randint(0, 16, (2, 3)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (3, 4)).astype(np.float32) @@ -196,7 +189,6 @@ def test_matmul_2_2(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_matmul_1_1(target, dev): data1_numpy = np.random.randint(0, 16, (4,)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (4,)).astype(np.float32) @@ -205,7 +197,6 @@ def test_matmul_1_1(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_matmul_1_4(target, dev): data1_numpy = np.random.randint(0, 16, (4,)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32) @@ -214,7 +205,6 @@ def test_matmul_1_4(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_matmul_4_1(target, dev): data1_numpy = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (5,)).astype(np.float32) @@ -223,7 +213,6 @@ def test_matmul_4_1(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_matmul_5_4(target, dev): data1_numpy = np.random.randint(0, 16, (2, 3, 1, 4, 5)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (3, 2, 5, 4)).astype(np.float32) @@ -237,7 +226,6 @@ def test_matmul_5_4(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_softmax(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) relax_check_gradients( @@ -245,7 +233,6 @@ def test_softmax(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_softmax_with_axis(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) relax_check_gradients( @@ -253,7 +240,6 @@ def test_softmax_with_axis(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_log_softmax(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) relax_check_gradients( @@ -261,7 +247,6 @@ def test_log_softmax(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_log_softmax_with_axis(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) relax_check_gradients( @@ -269,13 +254,11 @@ def test_log_softmax_with_axis(target, dev): ) 
-@tvm.testing.parametrize_targets("llvm") def test_sum(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) relax_check_gradients(relax.op.sum, "relax.sum", [data1_numpy], target, dev, ()) -@tvm.testing.parametrize_targets("llvm") def test_sum_with_axis(target, dev): data1_numpy = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32) relax_check_gradients( @@ -283,7 +266,6 @@ def test_sum_with_axis(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_sum_keepdims(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) relax_check_gradients( @@ -291,19 +273,16 @@ def test_sum_keepdims(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_sigmoid(target, dev): data_numpy = np.random.randint(1, 16, (3,)).astype(np.float32) relax_check_gradients(relax.op.sigmoid, "relax.sigmoid", [data_numpy], target, dev, (3,)) -@tvm.testing.parametrize_targets("llvm") def test_tanh(target, dev): data_numpy = np.random.randint(1, 16, (3, 3)).astype(np.float32) relax_check_gradients(relax.op.tanh, "relax.tanh", [data_numpy], target, dev, (3, 3)) -@tvm.testing.parametrize_targets("llvm") def test_concat(target, dev): data_numpy1 = np.random.randint(1, 16, (3, 3)).astype(np.float32) data_numpy2 = np.random.randint(1, 16, (3, 4)).astype(np.float32) @@ -320,7 +299,6 @@ def test_concat(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_split_indices(target, dev): data_numpy = np.random.randint(1, 16, (3, 12)).astype(np.float32) relax_check_gradients( @@ -335,7 +313,6 @@ def test_split_indices(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_split_section(target, dev): data_numpy = np.random.randint(1, 16, (3, 12)).astype(np.float32) relax_check_gradients( @@ -350,7 +327,6 @@ def test_split_section(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_cross_entropy_without_logits(target, dev): data_numpy1 = np.random.randint(1, 16, (3,)).astype(np.float32) data_numpy2 = np.random.randint(1, 16, (3,)).astype(np.float32) @@ -364,7 +340,6 @@ def test_cross_entropy_without_logits(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_cross_entropy_without_logits_batch(target, dev): data_numpy1 = np.random.randint(1, 16, (2, 3)).astype(np.float32) data_numpy2 = np.random.randint(1, 16, (2, 3)).astype(np.float32) @@ -378,7 +353,6 @@ def test_cross_entropy_without_logits_batch(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_cross_entropy_with_logits(target, dev): data_numpy1 = np.random.randint(1, 16, (3,)).astype(np.float32) data_numpy2 = np.random.randint(1, 16, (3,)).astype(np.float32) @@ -392,7 +366,6 @@ def test_cross_entropy_with_logits(target, dev): ) -@tvm.testing.parametrize_targets("llvm") def test_cross_entropy_with_logits_batch(target, dev): data_numpy1 = np.random.randint(1, 16, (2, 3)).astype(np.float32) data_numpy2 = np.random.randint(1, 16, (2, 3)).astype(np.float32) diff --git a/tests/python/relax/test_training_optimizer_numeric.py b/tests/python/relax/test_training_optimizer_numeric.py index 8e6a5974f7..ca895b889a 100644 --- a/tests/python/relax/test_training_optimizer_numeric.py +++ b/tests/python/relax/test_training_optimizer_numeric.py @@ -23,10 +23,10 @@ from tvm import relax from tvm import IRModule from tvm.relax.training.optimizer import Adam, SGD, MomentumSGD +from tvm.relax.transform import LegalizeOps +from tvm.runtime.container import tuple_object from tvm.script.parser import relax as R from tvm.testing import assert_allclose -from 
tvm.runtime.container import tuple_object -from tvm.relax.transform import LegalizeOps def _legalize_and_build(mod: IRModule, target, dev): @@ -85,8 +85,13 @@ def _test_optimizer(target, dev, np_func, opt_type, *args, **kwargs): _assert_allclose_nested(_tvm_to_numpy(opt.state), expected_state) -@tvm.testing.parametrize_targets("llvm") -def test_sgd(target, dev): +lr, weight_decay = tvm.testing.parameters( + (0.01, 0), + (0.01, 0.02), +) + + +def test_sgd(target, dev, lr, weight_decay): def np_func(param_tuple, grad_tuple, state_tuple): num_steps = state_tuple[0] param_tuple_new, state_tuple_new = [], [] @@ -97,14 +102,17 @@ def np_func(param_tuple, grad_tuple, state_tuple): param_tuple_new.append(param - lr * (grad + weight_decay * param)) return param_tuple_new, state_tuple_new - lr, weight_decay = 0.01, 0 - _test_optimizer(target, dev, np_func, SGD, lr) - lr, weight_decay = 0.01, 0.02 _test_optimizer(target, dev, np_func, SGD, lr, weight_decay) -@tvm.testing.parametrize_targets("llvm") -def test_momentum_sgd(target, dev): +lr, momentum, dampening, weight_decay, nesterov = tvm.testing.parameters( + (0.01, 0.9, 0, 0, False), + (0.01, 0.9, 0.85, 0.02, False), + (0.01, 0.9, 0.85, 0.02, True), +) + + +def test_momentum_sgd(target, dev, lr, momentum, dampening, weight_decay, nesterov): def np_func(param_tuple, grad_tuple, state_tuple): num_steps = state_tuple[0] param_tuple_new, state_tuple_new = [], [] @@ -125,22 +133,18 @@ def np_func(param_tuple, grad_tuple, state_tuple): return param_tuple_new, state_tuple_new - lr, momentum, dampening, weight_decay, nesterov = 0.01, 0.9, 0, 0, False - _test_optimizer( - target, dev, np_func, MomentumSGD, lr, momentum, dampening, weight_decay, nesterov - ) - lr, momentum, dampening, weight_decay, nesterov = 0.01, 0.9, 0.85, 0.02, False - _test_optimizer( - target, dev, np_func, MomentumSGD, lr, momentum, dampening, weight_decay, nesterov - ) - lr, momentum, dampening, weight_decay, nesterov = 0.01, 0.9, 0.85, 0.02, True _test_optimizer( target, dev, np_func, MomentumSGD, lr, momentum, dampening, weight_decay, nesterov ) -@tvm.testing.parametrize_targets("llvm") -def test_adam(target, dev): +lr, betas, eps, weight_decay = tvm.testing.parameters( + (0.01, (0.9, 0.999), 1e-08, 0), + (0.01, (0.8, 0.85), 1e-07, 0.1), +) + + +def test_adam(target, dev, lr, betas, eps, weight_decay): def np_func(param_tuple, grad_tuple, state_tuple): num_steps = state_tuple[0] num_steps_new = num_steps + 1 @@ -168,9 +172,6 @@ def np_func(param_tuple, grad_tuple, state_tuple): return param_tuple_new, state_tuple_new - lr, betas, eps, weight_decay = 0.01, (0.9, 0.999), 1e-08, 0 - _test_optimizer(target, dev, np_func, Adam, lr, betas, eps, weight_decay) - lr, betas, eps, weight_decay = 0.01, (0.8, 0.85), 1e-07, 0.1 _test_optimizer(target, dev, np_func, Adam, lr, betas, eps, weight_decay) From 4b04acc15a4e76308c21a96c8cfd65cbea53b3f0 Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Mon, 30 Jan 2023 08:55:20 +0800 Subject: [PATCH 15/17] fix test --- .../python/relax/test_op_gradient_numeric.py | 27 +++++++++++++++++++ .../relax/test_training_optimizer_numeric.py | 4 +++ 2 files changed, 31 insertions(+) diff --git a/tests/python/relax/test_op_gradient_numeric.py b/tests/python/relax/test_op_gradient_numeric.py index 60e64803b8..6243b609cd 100644 --- a/tests/python/relax/test_op_gradient_numeric.py +++ b/tests/python/relax/test_op_gradient_numeric.py @@ -132,6 +132,7 @@ def forward(*inputs): check_numerical_grads(forward, inputs_numpy, _tvm_to_numpy(result)) 
+@tvm.testing.parametrize_targets("llvm") def test_add(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) @@ -140,6 +141,7 @@ def test_add(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_subtract(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) @@ -148,6 +150,7 @@ def test_subtract(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_multiply(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) @@ -156,6 +159,7 @@ def test_multiply(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_permute_dims(target, dev): data1_numpy = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32) relax_check_gradients( @@ -163,6 +167,7 @@ def test_permute_dims(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_permute_dims_with_axes(target, dev): data1_numpy = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32) relax_check_gradients( @@ -176,11 +181,13 @@ def test_permute_dims_with_axes(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_relu(target, dev): data1_numpy = np.random.uniform(-1, 1, (3, 3)).astype(np.float32) relax_check_gradients(relax.op.nn.relu, "relax.nn.relu", [data1_numpy], target, dev, (3, 3)) +@tvm.testing.parametrize_targets("llvm") def test_matmul_2_2(target, dev): data1_numpy = np.random.randint(0, 16, (2, 3)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (3, 4)).astype(np.float32) @@ -189,6 +196,7 @@ def test_matmul_2_2(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_matmul_1_1(target, dev): data1_numpy = np.random.randint(0, 16, (4,)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (4,)).astype(np.float32) @@ -197,6 +205,7 @@ def test_matmul_1_1(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_matmul_1_4(target, dev): data1_numpy = np.random.randint(0, 16, (4,)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32) @@ -205,6 +214,7 @@ def test_matmul_1_4(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_matmul_4_1(target, dev): data1_numpy = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (5,)).astype(np.float32) @@ -213,6 +223,7 @@ def test_matmul_4_1(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_matmul_5_4(target, dev): data1_numpy = np.random.randint(0, 16, (2, 3, 1, 4, 5)).astype(np.float32) data2_numpy = np.random.randint(0, 16, (3, 2, 5, 4)).astype(np.float32) @@ -226,6 +237,7 @@ def test_matmul_5_4(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_softmax(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) relax_check_gradients( @@ -233,6 +245,7 @@ def test_softmax(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_softmax_with_axis(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) relax_check_gradients( @@ -240,6 +253,7 @@ def test_softmax_with_axis(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_log_softmax(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) relax_check_gradients( @@ -247,6 +261,7 @@ def test_log_softmax(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def 
test_log_softmax_with_axis(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) relax_check_gradients( @@ -254,11 +269,13 @@ def test_log_softmax_with_axis(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_sum(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) relax_check_gradients(relax.op.sum, "relax.sum", [data1_numpy], target, dev, ()) +@tvm.testing.parametrize_targets("llvm") def test_sum_with_axis(target, dev): data1_numpy = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32) relax_check_gradients( @@ -266,6 +283,7 @@ def test_sum_with_axis(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_sum_keepdims(target, dev): data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32) relax_check_gradients( @@ -273,16 +291,19 @@ def test_sum_keepdims(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_sigmoid(target, dev): data_numpy = np.random.randint(1, 16, (3,)).astype(np.float32) relax_check_gradients(relax.op.sigmoid, "relax.sigmoid", [data_numpy], target, dev, (3,)) +@tvm.testing.parametrize_targets("llvm") def test_tanh(target, dev): data_numpy = np.random.randint(1, 16, (3, 3)).astype(np.float32) relax_check_gradients(relax.op.tanh, "relax.tanh", [data_numpy], target, dev, (3, 3)) +@tvm.testing.parametrize_targets("llvm") def test_concat(target, dev): data_numpy1 = np.random.randint(1, 16, (3, 3)).astype(np.float32) data_numpy2 = np.random.randint(1, 16, (3, 4)).astype(np.float32) @@ -299,6 +320,7 @@ def test_concat(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_split_indices(target, dev): data_numpy = np.random.randint(1, 16, (3, 12)).astype(np.float32) relax_check_gradients( @@ -313,6 +335,7 @@ def test_split_indices(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_split_section(target, dev): data_numpy = np.random.randint(1, 16, (3, 12)).astype(np.float32) relax_check_gradients( @@ -327,6 +350,7 @@ def test_split_section(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_cross_entropy_without_logits(target, dev): data_numpy1 = np.random.randint(1, 16, (3,)).astype(np.float32) data_numpy2 = np.random.randint(1, 16, (3,)).astype(np.float32) @@ -340,6 +364,7 @@ def test_cross_entropy_without_logits(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_cross_entropy_without_logits_batch(target, dev): data_numpy1 = np.random.randint(1, 16, (2, 3)).astype(np.float32) data_numpy2 = np.random.randint(1, 16, (2, 3)).astype(np.float32) @@ -353,6 +378,7 @@ def test_cross_entropy_without_logits_batch(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_cross_entropy_with_logits(target, dev): data_numpy1 = np.random.randint(1, 16, (3,)).astype(np.float32) data_numpy2 = np.random.randint(1, 16, (3,)).astype(np.float32) @@ -366,6 +392,7 @@ def test_cross_entropy_with_logits(target, dev): ) +@tvm.testing.parametrize_targets("llvm") def test_cross_entropy_with_logits_batch(target, dev): data_numpy1 = np.random.randint(1, 16, (2, 3)).astype(np.float32) data_numpy2 = np.random.randint(1, 16, (2, 3)).astype(np.float32) diff --git a/tests/python/relax/test_training_optimizer_numeric.py b/tests/python/relax/test_training_optimizer_numeric.py index ca895b889a..3a57bc66bf 100644 --- a/tests/python/relax/test_training_optimizer_numeric.py +++ b/tests/python/relax/test_training_optimizer_numeric.py @@ -64,6 +64,7 @@ def _assert_run_result_same(tvm_func: Callable, np_func: Callable, np_inputs: Li 
_assert_allclose_nested(result, expected) +@tvm.testing.parametrize_targets("llvm") def _test_optimizer(target, dev, np_func, opt_type, *args, **kwargs): x = relax.Var("x", R.Tensor((3, 3), "float32")) y = relax.Var("y", R.Tensor((3,), "float32")) @@ -91,6 +92,7 @@ def _test_optimizer(target, dev, np_func, opt_type, *args, **kwargs): ) +@tvm.testing.parametrize_targets("llvm") def test_sgd(target, dev, lr, weight_decay): def np_func(param_tuple, grad_tuple, state_tuple): num_steps = state_tuple[0] @@ -112,6 +114,7 @@ def np_func(param_tuple, grad_tuple, state_tuple): ) +@tvm.testing.parametrize_targets("llvm") def test_momentum_sgd(target, dev, lr, momentum, dampening, weight_decay, nesterov): def np_func(param_tuple, grad_tuple, state_tuple): num_steps = state_tuple[0] @@ -144,6 +147,7 @@ def np_func(param_tuple, grad_tuple, state_tuple): ) +@tvm.testing.parametrize_targets("llvm") def test_adam(target, dev, lr, betas, eps, weight_decay): def np_func(param_tuple, grad_tuple, state_tuple): num_steps = state_tuple[0] From 55c9fb93c8b72f62d5a677e29bdad10bcf65e737 Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Tue, 31 Jan 2023 03:09:32 +0800 Subject: [PATCH 16/17] check sinfo in append loss --- src/relax/training/utils.cc | 20 +++++++++++------ tests/python/relax/test_training_utils.py | 26 ++++++++++++++++++++--- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc index deb8e03be2..940b957d47 100644 --- a/src/relax/training/utils.cc +++ b/src/relax/training/utils.cc @@ -95,21 +95,27 @@ class AppendLossMutator : public ExprMutator { Array RemapLossParams(const Array& loss_func_params, Array new_params) { for (int i = 0; i < static_cast(loss_func_params.size()); ++i) { Var loss_param = loss_func_params[i]; - if (i < static_cast(orig_rets.size())) { - // map return value to loss param + auto loss_param_sinfo = GetStructInfo(loss_param); + + if (i < static_cast(orig_rets.size())) { // map return value to loss param + auto orig_ret_sinfo = GetStructInfo(orig_rets[i]); + ICHECK(StructuralEqual()(orig_ret_sinfo, loss_param_sinfo)) + << "The struct info of the " << i + << "-th return value of orig func is: " << orig_ret_sinfo + << " while the corresponding struct info of parameter of loss function is " + << loss_param_sinfo << ", which is different."; + if (const auto* var_node = orig_rets[i].as()) { ICHECK(orig_rets[i].as()); orig_rets_var_.push_back(NullOpt); this->var_remap_[loss_param->vid] = GetRef(var_node); } else { - Var new_ret_var = - DataflowVar(/*name_hint=*/"ret_" + std::to_string(i), GetStructInfo(orig_rets[i])); + Var new_ret_var = DataflowVar(/*name_hint=*/"ret_" + std::to_string(i), orig_ret_sinfo); orig_rets_var_.push_back(new_ret_var); this->var_remap_[loss_param->vid] = new_ret_var; } - } else { - // append to the param list - Var new_loss_param = Var(loss_param->vid, GetStructInfo(loss_param), loss_param->span); + } else { // append to the param list + Var new_loss_param = Var(loss_param->vid, loss_param_sinfo, loss_param->span); this->var_remap_[loss_param->vid] = new_loss_param; new_params.push_back(new_loss_param); } diff --git a/tests/python/relax/test_training_utils.py b/tests/python/relax/test_training_utils.py index 6b2fcd88a6..3ba93b689f 100644 --- a/tests/python/relax/test_training_utils.py +++ b/tests/python/relax/test_training_utils.py @@ -14,8 +14,9 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +import pytest import tvm.testing -from tvm import relax +from tvm import relax, TVMError from tvm.ir.base import assert_structural_equal from tvm.script import relax as R @@ -76,7 +77,6 @@ def loss( def expected( x: R.Tensor((3, 3), dtype="float32"), arg3: R.Tensor((3, 3), dtype="float32") ) -> R.Tensor(None, dtype="float32", ndim=2): - # block 0 with R.dataflow(): gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) gv1: R.Tensor((3, 3), dtype="float32") = R.add(x, x) @@ -101,7 +101,7 @@ def orig(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float3 @R.function def loss( arg1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((), dtype="float32")), - arg2: R.Tensor((), dtype="float32"), + arg2: R.Tensor((3, 3), dtype="float32"), ): with R.dataflow(): arg10 = arg1[0] @@ -131,5 +131,25 @@ def expected( assert_structural_equal(after, expected) +def test_append_loss_wrong_struct_info(): + @R.function + def orig(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")): + with R.dataflow(): + gv0 = R.sum(x) + gv1 = R.sum(y) + R.output(gv0, gv1) + return gv0, gv1 + + @R.function + def loss(arg1: R.Tensor((), dtype="float64"), arg2: R.Tensor((), dtype="float64")): + with R.dataflow(): + gv0 = R.add(arg1, arg2) + R.output(gv0) + return gv0 + + with pytest.raises(TVMError): + after = relax.training.utils.append_loss(orig, loss) + + if __name__ == "__main__": tvm.testing.main() From 77781fc74fa81db7fa78d6b749d96f8b5a2830c3 Mon Sep 17 00:00:00 2001 From: SiriusNEO <1713833595@qq.com> Date: Tue, 31 Jan 2023 17:11:32 +0800 Subject: [PATCH 17/17] address comments --- python/tvm/relax/training/__init__.py | 4 +- python/tvm/relax/training/loss.py | 14 +++--- python/tvm/relax/training/utils.py | 61 +++++++++++++----------- src/relax/training/utils.cc | 41 ++++++++-------- src/relax/training/utils.h | 4 +- tests/python/relax/test_training_loss.py | 27 +++-------- tests/python/relax/test_utils.py | 4 +- 7 files changed, 71 insertions(+), 84 deletions(-) diff --git a/python/tvm/relax/training/__init__.py b/python/tvm/relax/training/__init__.py index c3bda65860..f75b0a8ecf 100644 --- a/python/tvm/relax/training/__init__.py +++ b/python/tvm/relax/training/__init__.py @@ -18,6 +18,4 @@ from . import optimizer from . import utils - -# loss functions -from .loss import L1Loss, MSELoss, CrossEntropyLoss +from . import loss diff --git a/python/tvm/relax/training/loss.py b/python/tvm/relax/training/loss.py index 0e11e7958f..10f3c45428 100644 --- a/python/tvm/relax/training/loss.py +++ b/python/tvm/relax/training/loss.py @@ -39,7 +39,7 @@ def _create_param_var(param: Union[Var, StructInfo], param_name: str) -> Var: return Var(param.name_hint, param.struct_info) -class _Loss: +class Loss: r"""Base class of all loss. Parameters @@ -81,7 +81,7 @@ def _with_reduction(self, expr: Expr) -> Expr: return expr -class L1Loss(_Loss): +class L1Loss(Loss): r"""Mean element-wise absolute value difference. Parameters @@ -127,10 +127,10 @@ def __call__( loss = bb.emit_output(self._with_reduction(lv)) bb.emit_func_output(loss) - return bb.get()[self._loss_name].with_attr("global_symbol", self._loss_name) + return bb.get()[self._loss_name] -class MSELoss(_Loss): +class MSELoss(Loss): r"""Measures the element-wise mean squared error. 
Parameters @@ -177,10 +177,10 @@ def __call__( loss = bb.emit_output(self._with_reduction(lv)) bb.emit_func_output(loss) - return bb.get()[self._loss_name].with_attr("global_symbol", self._loss_name) + return bb.get()[self._loss_name] -class CrossEntropyLoss(_Loss): +class CrossEntropyLoss(Loss): r"""CrossEntropyLoss. It is a combination of a log_softmax computation and a nll_loss. Parameters @@ -248,4 +248,4 @@ def __call__( ) bb.emit_func_output(loss) - return bb.get()[self._loss_name].with_attr("global_symbol", self._loss_name) + return bb.get()[self._loss_name] diff --git a/python/tvm/relax/training/utils.py b/python/tvm/relax/training/utils.py index 8c947f3f48..b2057c9f95 100644 --- a/python/tvm/relax/training/utils.py +++ b/python/tvm/relax/training/utils.py @@ -28,42 +28,45 @@ def append_loss(orig_func: Function, loss_func: Function) -> Function: those arguments of loss_func which are not mapped to some return values, they will be lifted and appended to the argument list of result function. - Notice: + Note + ------- 1. This uitl is dedicated to loss functions, not for general purposes. 2. This util can be replaced if we have Inline pass. It is equivalent to inline a tail call in some sense. - Example: + Example + ------- + >>> @R.function + ... def orig(x: R.Tensor((2, 4), "float32"), y: R.Tensor((2, 4), "float32")): + ... with R.dataflow(): + ... out = R.add(x, y) + ... R.output(out) + ... return out + + >>> @R.function + ... def loss(predictions: R.Tensor((2, 4), "float32"), labels: R.Tensor((2, 4), "float32")): + ... with R.dataflow(): + ... lv = R.subtract(predictions, labels) + ... lv1 = R.multiply(lv, lv) + ... gv = R.sum(lv1) + ... R.output(gv) + ... return gv - .. code-block:: python - # Before. - @R.function - def orig(x: R.Tensor((2, 4), "float32"), y: R.Tensor((2, 4), "float32")): - with R.dataflow(): - out = R.add(x, y) - R.output(out) - return out + >>> expected = append_loss(orig, loss) + >>> print(expected) - @R.function - def loss(predictions: R.Tensor((2, 4), "float32"), labels: R.Tensor((2, 4), "float32")): - with R.dataflow(): - lv = R.subtract(predictions, labels) - lv1 = R.multiply(lv, lv) - gv = R.sum(lv1) - R.output(gv) - return gv + Will get - # After. - @R.function - def expected(x: R.Tensor((2, 4), "float32"), y: R.Tensor((2, 4), "float32"), - labels: R.Tensor((2, 4), "float32")) -> R.Tensor((), "float32"): - with R.dataflow(): - out: R.Tensor((2, 4), "float32") = R.add(x, y) - lv: R.Tensor((2, 4), "float32") = R.subtract(out, labels) - lv1: R.Tensor((2, 4), "float32") = R.multiply(lv, lv) - gv: R.Tensor((), "float32") = R.sum(lv1) - R.output(gv) - return gv + >>> @R.function + ... def expected(x: R.Tensor((2, 4), "float32"), y: R.Tensor((2, 4), "float32"), + ... labels: R.Tensor((2, 4), "float32")) -> R.Tensor((), "float32"): + ... with R.dataflow(): + ... out: R.Tensor((2, 4), "float32") = R.add(x, y) + ... lv: R.Tensor((2, 4), "float32") = R.subtract(out, labels) + ... lv1: R.Tensor((2, 4), "float32") = R.multiply(lv, lv) + ... gv: R.Tensor((), "float32") = R.sum(lv1) + ... R.output(gv) + ... 
return gv Parameters ---------- diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc index 940b957d47..e18e3c0ded 100644 --- a/src/relax/training/utils.cc +++ b/src/relax/training/utils.cc @@ -27,7 +27,7 @@ class AppendLossMutator : public ExprMutator { public: explicit AppendLossMutator(const SeqExpr& loss_body) : loss_body_(loss_body) {} - Expr VisitExpr_(const SeqExprNode* seq_expr) override { + Expr VisitExpr_(const SeqExprNode* seq_expr) final { // mutate only the last block. Array blocks; for (int i = 0; i < static_cast(seq_expr->blocks.size()); ++i) { @@ -45,7 +45,7 @@ class AppendLossMutator : public ExprMutator { return SeqExpr(blocks, loss_body_->body); } - BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override { + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final { builder_->BeginDataflowBlock(); // emit original bindings. for (const auto& binding : block->bindings) { @@ -71,23 +71,19 @@ class AppendLossMutator : public ExprMutator { return builder_->EndBlock(); } - void VisitBinding_(const VarBindingNode* binding) override { - Var new_var = Downcast(this->VisitExpr(binding->var)); - Expr new_value = this->VisitExpr(binding->value); - builder_->EmitNormalized(VarBinding(new_var, new_value)); - } + // Remap the def site var using VisitExpr. + Var VisitVarDef(const Var& var) final { return Downcast(this->VisitExpr(var)); } - // remap orignal dataflow var + // Remap original dataflow var. // TODO(chaofan): a better way to check whether new_ret_var should be dataflow void RemapToDataflow(SeqExpr body) { - for (const BindingBlock& block : body->blocks) { - for (const Binding& binding : block->bindings) { - const auto* binding_node = binding.as(); - if (binding_node && !binding_node->var->IsInstance()) { - Var new_binding_var = DataflowVar( - binding_node->var->vid, GetStructInfo(binding_node->var), binding_node->var->span); - this->var_remap_[binding_node->var->vid] = new_binding_var; - } + const auto& block = body->blocks.back(); + for (const Binding& binding : block->bindings) { + const auto* binding_node = binding.as(); + if (binding_node && !binding_node->var->IsInstance()) { + Var new_binding_var = DataflowVar(binding_node->var->vid, GetStructInfo(binding_node->var), + binding_node->var->span); + this->var_remap_[binding_node->var->vid] = new_binding_var; } } } @@ -99,7 +95,7 @@ class AppendLossMutator : public ExprMutator { if (i < static_cast(orig_rets.size())) { // map return value to loss param auto orig_ret_sinfo = GetStructInfo(orig_rets[i]); - ICHECK(StructuralEqual()(orig_ret_sinfo, loss_param_sinfo)) + ICHECK(structural_equal_(orig_ret_sinfo, loss_param_sinfo)) << "The struct info of the " << i << "-th return value of orig func is: " << orig_ret_sinfo << " while the corresponding struct info of parameter of loss function is " @@ -123,11 +119,16 @@ class AppendLossMutator : public ExprMutator { return new_params; } + /*! \brief The original unpacked rets. */ Array orig_rets; private: + /*! \brief The body of the loss function */ SeqExpr loss_body_; + /*! \brief The var created for original rets. NullOpt if the original ret is already a var. */ Array> orig_rets_var_; + /*! \brief The structural equality checker */ + StructuralEqual structural_equal_; }; /*! 
@@ -144,12 +145,12 @@ Function AppendLoss(Function orig_func, Function loss_func) { << "The body of the loss function is expected to be a SeqExpr, but got" << loss_func->body->GetTypeKey(); - auto param_copied_func = CopyWithNewParams(orig_func); + Function param_copied_func = CopyWithNewParams(orig_func); auto seq_expr = Downcast(param_copied_func->body); AppendLossMutator mutator(Downcast(loss_func->body)); mutator.RemapToDataflow(seq_expr); - // Get the orignal rets. If it is a Tuple, unpack it. + // Get the original rets. If it is a Tuple, unpack it. if (orig_func->ret_struct_info.as()) { const auto* tuple_node = seq_expr->body.as(); ICHECK(tuple_node != nullptr); @@ -165,7 +166,7 @@ Function AppendLoss(Function orig_func, Function loss_func) { "parameters of loss function. Got " << mutator.orig_rets.size() << " > " << loss_func->params.size(); - auto new_params = mutator.RemapLossParams(loss_func->params, param_copied_func->params); + Array new_params = mutator.RemapLossParams(loss_func->params, param_copied_func->params); Expr new_body = mutator.VisitExpr(seq_expr); return Function(std::move(new_params), std::move(new_body), loss_func->ret_struct_info, param_copied_func->attrs); diff --git a/src/relax/training/utils.h b/src/relax/training/utils.h index 82c5d4d461..d8008fe456 100644 --- a/src/relax/training/utils.h +++ b/src/relax/training/utils.h @@ -34,12 +34,10 @@ namespace relax { /*! * \brief Local helper to append a specified loss function after the original function. - * - * Notice: + * \note * 1. This uitl is dedicated to loss functions, not for general purposes. * 2. This util can be replaced if we have Inline pass. It is equivalent to inline a tail call in * some sense. - * * \param orig_func The function to be appended to. * \param loss_func The loss function. * \return The result function after appended. 
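Since this commit makes the `Loss` base class public (renaming `_Loss` to `Loss`), a user-defined loss can plausibly be written by subclassing it. The snippet below is a minimal sketch that follows the `L1Loss`/`MSELoss` pattern in `loss.py`; the class name `ScaledL1Loss`, its `scale` argument, and the float32 dtype of the scale constant are illustrative assumptions rather than part of this patch, and `_create_param_var` is the module-private helper the built-in losses use.

from typing import Union

from tvm.relax import BlockBuilder, Function, StructInfo, Var, const
from tvm.relax.op import abs, multiply, subtract
from tvm.relax.training.loss import Loss, _create_param_var


class ScaledL1Loss(Loss):
    """L1 loss scaled by a constant factor (illustrative sketch only)."""

    def __init__(self, scale: float = 1.0, reduction: str = "mean") -> None:
        super().__init__("scaled_l1_loss", reduction)
        self._scale = scale

    def __call__(
        self, predictions: Union[Var, StructInfo], targets: Union[Var, StructInfo]
    ) -> Function:
        bb = BlockBuilder()
        predictions = _create_param_var(predictions, "predictions")
        targets = _create_param_var(targets, "targets")

        with bb.function(self._loss_name, [predictions, targets]):
            with bb.dataflow():
                # Element-wise absolute difference, scaled by a constant.
                # The scale constant assumes float32 inputs for simplicity.
                lv = abs(subtract(predictions, targets))
                lv = multiply(lv, const(self._scale, "float32"))
                # The base class handles the "mean"/"sum"/"none" reduction.
                loss = bb.emit_output(self._with_reduction(lv))
            bb.emit_func_output(loss)

        return bb.get()[self._loss_name]

Calling `ScaledL1Loss(2.0, "sum")(pred_sinfo, target_sinfo)` would then return a standalone Relax Function of the same shape as the built-in losses, suitable for `append_loss` in the tests that follow.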
diff --git a/tests/python/relax/test_training_loss.py b/tests/python/relax/test_training_loss.py index 6880e0517f..0a4a154013 100644 --- a/tests/python/relax/test_training_loss.py +++ b/tests/python/relax/test_training_loss.py @@ -38,13 +38,12 @@ def test_l1_loss(): C = 5 predictions = relax.TensorStructInfo((N, C), "float32") targets = relax.TensorStructInfo((N, C), "float32") - l1_loss = relax.training.L1Loss() + l1_loss = relax.training.loss.L1Loss() @R.function def expected( predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), "float32") ) -> R.Tensor((), "float32"): - R.func_attr({"global_symbol": "l1_loss"}) with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets) lv1: R.Tensor((3, 5), "float32") = R.abs(lv) @@ -57,7 +56,7 @@ def expected( def test_l1_loss_append(): s = forward.ret_struct_info - l1_loss = relax.training.L1Loss(reduction="sum") + l1_loss = relax.training.loss.L1Loss(reduction="sum") forward_with_loss = relax.training.utils.append_loss(forward, l1_loss(s, s)) @R.function @@ -67,7 +66,6 @@ def expected( b: R.Tensor((2, 4), "float32"), targets: R.Tensor((2, 4), "float32"), ) -> R.Tensor((), "float32"): - # block 0 with R.dataflow(): lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="") out: R.Tensor((2, 4), "float32") = R.add(lv, b) @@ -85,15 +83,12 @@ def test_mse_loss(): C = 5 predictions = relax.TensorStructInfo((N, C), "float32") targets = relax.TensorStructInfo((N, C), "float32") - mse_loss = relax.training.MSELoss() + mse_loss = relax.training.loss.MSELoss() @R.function def expected( predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), "float32") ) -> R.Tensor((), "float32"): - # function attr dict - R.func_attr({"global_symbol": "mse_loss"}) - # block 0 with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets) lv1: R.Tensor((3, 5), "float32") = R.multiply(lv, lv) @@ -106,7 +101,7 @@ def expected( def test_mse_loss_append(): s = forward.ret_struct_info - mse_loss = relax.training.MSELoss(reduction="sum") + mse_loss = relax.training.loss.MSELoss(reduction="sum") forward_with_loss = relax.training.utils.append_loss(forward, mse_loss(s, s)) @R.function @@ -116,7 +111,6 @@ def expected( b: R.Tensor((2, 4), "float32"), targets: R.Tensor((2, 4), "float32"), ) -> R.Tensor((), "float32"): - # block 0 with R.dataflow(): lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="") out: R.Tensor((2, 4), "float32") = R.add(lv, b) @@ -135,7 +129,7 @@ def test_cross_entropy_loss(): predictions = relax.TensorStructInfo((N, C), "float32") targets = relax.TensorStructInfo((N,), "int64") weights = relax.TensorStructInfo((C,), "float32") - cross_entropy_loss = relax.training.CrossEntropyLoss(reduction="sum", ignore_index=1) + cross_entropy_loss = relax.training.loss.CrossEntropyLoss(reduction="sum", ignore_index=1) @R.function def expected( @@ -143,9 +137,6 @@ def expected( targets: R.Tensor((3,), "int64"), weights: R.Tensor((5,), "float32"), ) -> R.Tensor((), "float32"): - # function attr dict - R.func_attr({"global_symbol": "cross_entropy_loss"}) - # block 0 with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) gv: R.Tensor((), "float32") = R.nn.nll_loss( @@ -162,15 +153,12 @@ def test_cross_entropy_loss_without_weights(): C = 5 predictions = relax.TensorStructInfo((N, C), "float32") targets = relax.TensorStructInfo((N,), "int64") - cross_entropy_loss = relax.training.CrossEntropyLoss() + cross_entropy_loss = relax.training.loss.CrossEntropyLoss() 
@R.function def expected( predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3,), "int64") ) -> R.Tensor((), "float32"): - # function attr dict - R.func_attr({"global_symbol": "cross_entropy_loss"}) - # block 0 with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) gv: R.Tensor((), "float32") = R.nn.nll_loss( @@ -188,7 +176,7 @@ def test_cross_entropy_loss_append(): C = s.shape[1] targets = relax.TensorStructInfo((N,), "int64") weights = relax.TensorStructInfo((C,), "float32") - cross_entropy_loss = relax.training.CrossEntropyLoss(reduction="sum", ignore_index=1) + cross_entropy_loss = relax.training.loss.CrossEntropyLoss(reduction="sum", ignore_index=1) forward_with_loss = relax.training.utils.append_loss( forward, cross_entropy_loss(s, targets, weights) ) @@ -201,7 +189,6 @@ def expected( targets: R.Tensor((2,), "int64"), weights: R.Tensor((4,), "float32"), ) -> R.Tensor((), "float32"): - # block 0 with R.dataflow(): lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="") out: R.Tensor((2, 4), "float32") = R.add(lv, b) diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index 61502d737d..1cf2b56fa9 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm.testing +import pytest from tvm import relax from tvm.ir.base import assert_structural_equal from tvm.script.parser import relax as R @@ -35,4 +35,4 @@ def before(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): if __name__ == "__main__": - tvm.testing.main() + pytest.main([__file__])
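Taken together, the pieces touched in this series compose as in the sketch below: a loss object from `relax.training.loss` builds a standalone loss function, `relax.training.utils.append_loss` splices it after a forward function, and the result can be legalized and compiled like any other Relax function. The forward function and its shapes here are illustrative assumptions (the same flow appears in `test_mse_loss_append`); wiring in an optimizer step via `get_function()` would follow separately and is not shown.

import tvm
from tvm import relax
from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R


@R.function
def forward(x: R.Tensor((2, 4), "float32"), w: R.Tensor((4, 4), "float32")):
    with R.dataflow():
        out = R.matmul(x, w, out_dtype="")
        R.output(out)
    return out


# Build a standalone loss function, then splice it after the forward function.
mse_loss = relax.training.loss.MSELoss(reduction="sum")
s = forward.ret_struct_info
forward_with_loss = relax.training.utils.append_loss(forward, mse_loss(s, s))

# The combined function takes (x, w, targets) and returns a scalar loss.
mod = LegalizeOps()(tvm.IRModule.from_expr(forward_with_loss))
mod.show()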