diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index d6151cdc29..de9c7f75fc 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -27,6 +27,7 @@ from . import expr_functor from . import struct_info from . import utils +from . import training # Expr @@ -88,4 +89,4 @@ ) # Training utils -from .training import optimizer +from .training import loss, optimizer diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index de3cb31086..3ce49ab7be 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -634,7 +634,7 @@ def nll_loss( The weight of each target value. If not specified, it is treated as if having all ones. - reduction : string + reduction : str The reduction method to apply to the output. Possible values are "mean", "sum" and "none". diff --git a/python/tvm/relax/training/__init__.py b/python/tvm/relax/training/__init__.py index 2cf602cb4f..f75b0a8ecf 100644 --- a/python/tvm/relax/training/__init__.py +++ b/python/tvm/relax/training/__init__.py @@ -17,3 +17,5 @@ """The Relax training APIs.""" from . import optimizer +from . import utils +from . import loss diff --git a/python/tvm/relax/training/_ffi_api.py b/python/tvm/relax/training/_ffi_api.py new file mode 100644 index 0000000000..70cb83fc0e --- /dev/null +++ b/python/tvm/relax/training/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.relax.training""" +import tvm._ffi + +tvm._ffi._init_api("relax.training", __name__) diff --git a/python/tvm/relax/training/loss.py b/python/tvm/relax/training/loss.py new file mode 100644 index 0000000000..10f3c45428 --- /dev/null +++ b/python/tvm/relax/training/loss.py @@ -0,0 +1,251 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
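The loss classes defined in the new `loss.py` below all follow the same recipe: turn incoming `StructInfo` arguments into fresh `Var`s, emit the loss computation inside a dataflow block with `BlockBuilder`, and return the resulting standalone `relax.Function`. A minimal sketch of that builder pattern, independent of this patch (the "toy_loss" name and the (3, 5) shapes are made up for illustration):

```python
# Generic relax.BlockBuilder usage mirroring what the loss classes below do.
# "toy_loss" and the (3, 5) shapes are illustrative, not part of this patch.
from tvm import relax
from tvm.relax.op import subtract, sum as rx_sum

x = relax.Var("x", relax.TensorStructInfo((3, 5), "float32"))
y = relax.Var("y", relax.TensorStructInfo((3, 5), "float32"))

bb = relax.BlockBuilder()
with bb.function("toy_loss", [x, y]):
    with bb.dataflow():
        diff = bb.emit(subtract(x, y))      # intermediate binding
        out = bb.emit_output(rx_sum(diff))  # dataflow block output
    bb.emit_func_output(out)                # function return value

toy_loss = bb.get()["toy_loss"]             # a standalone relax.Function
```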
+# pylint: disable=redefined-builtin, invalid-name
+"""Loss functions library for relax."""
+
+from typing import Optional, Union
+
+# isort: off
+from typing_extensions import Literal
+
+# isort: on
+
+from ..block_builder import BlockBuilder
+from ..expr import Expr, Var, Function, StructInfo
+
+from ..op import abs, sum, mean, subtract, multiply
+from ..op.nn import log_softmax, nll_loss
+
+
+def _create_param_var(param: Union[Var, StructInfo], param_name: str) -> Var:
+    if isinstance(param, StructInfo):
+        param = Var(param_name, param)
+    if not isinstance(param, Var):
+        raise TypeError(f"The type of param should be Var or StructInfo, but got {type(param)}")
+    return Var(param.name_hint, param.struct_info)
+
+
+class Loss:
+    r"""Base class of all loss functions.
+
+    Parameters
+    ----------
+    loss_name : str
+        The name of the loss function.
+
+    reduction : Literal["mean", "sum", "none"]
+        The reduction method to apply to output. Can be "mean", "sum" or "none".
+
+        none : no reduction will be applied,
+        mean : the sum of the output will be divided by the batch_size,
+        sum : the output will be summed.
+    """
+
+    _valid_reductions = ["mean", "sum", "none"]
+
+    def __init__(self, loss_name: str, reduction: Literal["mean", "sum", "none"] = "mean") -> None:
+        self._loss_name = loss_name
+        self._reduction = reduction
+
+        if self._reduction not in self._valid_reductions:
+            raise ValueError("Reduction can only be one of these values: ", self._valid_reductions)
+
+    def _with_reduction(self, expr: Expr) -> Expr:
+        """Add a reduction to the final loss.
+
+        Parameters
+        ----------
+        expr : Expr
+            The loss expr.
+        """
+        if self._reduction == "sum":
+            expr = sum(expr)
+        elif self._reduction == "mean":
+            expr = mean(expr)
+        elif self._reduction != "none":
+            raise ValueError("Reduction can only be one of these values: ", self._valid_reductions)
+        return expr
+
+
+class L1Loss(Loss):
+    r"""Mean element-wise absolute value difference.
+
+    Parameters
+    ----------
+    reduction : Literal["mean", "sum", "none"]
+        The reduction method to apply to output. Can be "mean", "sum" or "none".
+
+        none : no reduction will be applied,
+        mean : the sum of the output will be divided by the batch_size,
+        sum : the output will be summed.
+    """
+
+    def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean") -> None:
+        super().__init__("l1_loss", reduction)
+
+    def __call__(
+        self,
+        predictions: Union[Var, StructInfo],
+        targets: Union[Var, StructInfo],
+    ) -> Function:
+        """Get the relax function of L1Loss. If the parameters are
+        struct info, it will create corresponding variables.
+
+        Parameters
+        ----------
+        predictions : Union[Var, StructInfo]
+            The predictions of the model in the calculation of loss.
+        targets : Union[Var, StructInfo]
+            The ground truth in the calculation of loss.
+
+        Returns
+        ----------
+        The relax function of L1Loss with the loss name as its global symbol.
+        """
+        bb = BlockBuilder()
+
+        predictions = _create_param_var(predictions, "predictions")
+        targets = _create_param_var(targets, "targets")
+
+        with bb.function(self._loss_name, [predictions, targets]):
+            with bb.dataflow():
+                lv = abs(subtract(predictions, targets))
+                loss = bb.emit_output(self._with_reduction(lv))
+            bb.emit_func_output(loss)
+
+        return bb.get()[self._loss_name]
+
+
+class MSELoss(Loss):
+    r"""Measures the element-wise mean squared error.
+
+    Parameters
+    ----------
+    reduction : Literal["mean", "sum", "none"]
+        The reduction method to apply to output. Can be "mean", "sum" or "none".
+ + none : no reduction will be applied, + mean : the sum of the output will be divided by the batch_size, + sum : the output will be summed. + """ + + def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean") -> None: + super().__init__("mse_loss", reduction) + + def __call__( + self, + predictions: Union[Var, StructInfo], + targets: Union[Var, StructInfo], + ) -> Function: + """Get the relax function of MSELoss. If the parameters are + struct info, it will create corresponding variables. + + Parameters + ---------- + predictions : Union[Var, StructInfo] + The predictions of the model in the calculation of loss. + targets : Union[Var, StructInfo] + The ground truth in the calculation of loss. + + Returns + ---------- + The relax function of MSELoss with the loss name as its global symbol. + """ + bb = BlockBuilder() + + predictions = _create_param_var(predictions, "predictions") + targets = _create_param_var(targets, "targets") + + with bb.function(self._loss_name, [predictions, targets]): + with bb.dataflow(): + lv = subtract(predictions, targets) + lv = multiply(lv, lv) + loss = bb.emit_output(self._with_reduction(lv)) + bb.emit_func_output(loss) + + return bb.get()[self._loss_name] + + +class CrossEntropyLoss(Loss): + r"""CrossEntropyLoss. It is a combination of a log_softmax computation and a nll_loss. + + Parameters + ---------- + reduction : Literal["mean", "sum", "none"] + The reduction method to apply to output. Can be "mean", "sum" or "none". + + none : no reduction will be applied, + mean : the sum of the output will be divided by the batch_size, + sum : the output will be summed. + + ignore_index : int + Specifies a target value that is ignored and does not contribute to the input gradient. + """ + + ignore_index: int + + def __init__( + self, + reduction: Literal["mean", "sum", "none"] = "mean", + ignore_index: int = -100, + ) -> None: + super().__init__("cross_entropy_loss", reduction) + self.ignore_index = ignore_index + + def __call__( + self, + predictions: Union[Var, StructInfo], + targets: Union[Var, StructInfo], + weights: Optional[Union[Var, StructInfo]] = None, + ) -> Function: + """Get the relax function of CrossEntropyLoss. If the parameters are + struct info, it will create corresponding variables. + + Parameters + ---------- + predictions : Union[Var, StructInfo] + The predictions of the model in the calculation of loss. + + targets : Union[Var, StructInfo] + The ground truth in the calculation of loss. + + weights : Optional[Union[Var, StructInfo]] + a manual rescaling weight given to each class. It has to be a Tensor of size C. + + Returns + ---------- + The relax function of CrossEntropyLoss with the loss name as its global symbol. 
+ """ + bb = BlockBuilder() + + predictions = _create_param_var(predictions, "predictions") + targets = _create_param_var(targets, "targets") + + arg_list = [predictions, targets] + if weights: + weights = _create_param_var(weights, "weights") + arg_list.append(weights) + + with bb.function(self._loss_name, arg_list): + with bb.dataflow(): + logits = bb.emit(log_softmax(predictions)) + loss = bb.emit_output( + nll_loss(logits, targets, weights, self._reduction, self.ignore_index) + ) + bb.emit_func_output(loss) + + return bb.get()[self._loss_name] diff --git a/python/tvm/relax/training/optimizer.py b/python/tvm/relax/training/optimizer.py index 729b861e2b..b5fd9365f7 100644 --- a/python/tvm/relax/training/optimizer.py +++ b/python/tvm/relax/training/optimizer.py @@ -22,12 +22,14 @@ import numpy as np # type: ignore import tvm -from tvm import relax as rx -from tvm.relax.struct_info import TensorStructInfo -from tvm.relax.transform.legalize_ops import LegalizeOps from tvm.runtime.container import tuple_object -from tvm.relax.op import add, subtract, multiply, divide, sqrt -from tvm.relax import Var, Function + +from ..vm import VirtualMachine, build +from ..block_builder import BlockBuilder +from ..struct_info import TensorStructInfo, TupleStructInfo +from ..transform.legalize_ops import LegalizeOps +from ..op import add, subtract, multiply, divide, sqrt +from ..expr import const, Var, Function, TupleGetItem, Tuple as RxTuple # TODO(chaofan, yixin): Migrate key logics to C++ class Optimizer: @@ -74,7 +76,7 @@ class Optimizer: _dtype: str # these attributes are for the building and running process of the optimizer function - _vm_module: rx.VirtualMachine + _vm_module: VirtualMachine _target: Union[str, tvm.target.Target] _device: Union[tvm.runtime.Device, List[tvm.runtime.Device]] @@ -248,8 +250,8 @@ def __call__( mod = tvm.IRModule({self.name: self.get_function()}) # pylint: disable=not-callable lowered_mod = LegalizeOps()(mod) # type: ignore - executable = rx.vm.build(lowered_mod, self._target) - self._vm_module = rx.VirtualMachine(executable, self._device) + executable = build(lowered_mod, self._target) + self._vm_module = VirtualMachine(executable, self._device) new_params, self.state = self._vm_module[self.name](params_adt, grads_adt, self.state) return new_params @@ -333,38 +335,38 @@ def get_function(self) -> Function: dtype = self._dtype # input variables - param_var = Var("params", rx.TupleStructInfo([p.struct_info for p in plist])) - grad_var = Var("gradients", rx.TupleStructInfo([p.struct_info for p in plist])) - state_var = Var("optim_states", rx.TupleStructInfo([rx.TensorStructInfo((), "int64")])) + param_var = Var("params", TupleStructInfo([p.struct_info for p in plist])) + grad_var = Var("gradients", TupleStructInfo([p.struct_info for p in plist])) + state_var = Var("optim_states", TupleStructInfo([TensorStructInfo((), "int64")])) # constants - lr = rx.const(self.lr, dtype) - weight_decay = rx.const(self.weight_decay, dtype) - one = rx.const(1, "int64") + lr = const(self.lr, dtype) + weight_decay = const(self.weight_decay, dtype) + one = const(1, "int64") - builder = rx.BlockBuilder() + builder = BlockBuilder() with builder.function(self.name, [param_var, grad_var, state_var]): with builder.dataflow(): param_list_new, state_list_new = [], [] # handle num_steps - num_steps = builder.emit(rx.TupleGetItem(state_var, 0), "num_steps") + num_steps = builder.emit(TupleGetItem(state_var, 0), "num_steps") num_steps_new = builder.emit(add(num_steps, one), "num_steps_new") 
state_list_new.append(num_steps_new) # computation logics for i in range(len_param): name = self._param_list[i].name_hint - p = builder.emit(rx.TupleGetItem(param_var, i), name) - g = builder.emit(rx.TupleGetItem(grad_var, i), name + "_grad") + p = builder.emit(TupleGetItem(param_var, i), name) + g = builder.emit(TupleGetItem(grad_var, i), name + "_grad") if self.weight_decay: g = builder.emit(add(multiply(weight_decay, p), g), name + "_grad_new") p_new = builder.emit(subtract(p, multiply(lr, g)), name + "_new") param_list_new.append(p_new) # handle return values - params_new = builder.emit_output(rx.Tuple(param_list_new), "params_new") - optim_states_new = builder.emit_output(rx.Tuple(state_list_new), "optim_states_new") + params_new = builder.emit_output(RxTuple(param_list_new), "params_new") + optim_states_new = builder.emit_output(RxTuple(state_list_new), "optim_states_new") builder.emit_func_output((params_new, optim_states_new)) return builder.get()[self.name] @@ -471,36 +473,36 @@ def get_function(self) -> Function: dtype = self._dtype # input variables - param_var = Var("params", rx.TupleStructInfo([p.struct_info for p in plist])) - grad_var = Var("gradients", rx.TupleStructInfo([p.struct_info for p in plist])) + param_var = Var("params", TupleStructInfo([p.struct_info for p in plist])) + grad_var = Var("gradients", TupleStructInfo([p.struct_info for p in plist])) state_var = Var( "optim_states", - rx.TupleStructInfo([rx.TensorStructInfo((), "int64"), *(p.struct_info for p in plist)]), + TupleStructInfo([TensorStructInfo((), "int64"), *(p.struct_info for p in plist)]), ) # constants - lr = rx.const(self.lr, dtype) - momentum = rx.const(self.momentum, dtype) - weight_decay = rx.const(self.weight_decay, dtype) - dampening_inv = rx.const(_high_precision_subtract(1, self.dampening), dtype) - one = rx.const(1, "int64") + lr = const(self.lr, dtype) + momentum = const(self.momentum, dtype) + weight_decay = const(self.weight_decay, dtype) + dampening_inv = const(_high_precision_subtract(1, self.dampening), dtype) + one = const(1, "int64") - builder = rx.BlockBuilder() + builder = BlockBuilder() with builder.function(self.name, [param_var, grad_var, state_var]): with builder.dataflow(): param_list_new, state_list_new = [], [] # handle num_steps - num_steps = builder.emit(rx.TupleGetItem(state_var, 0), "num_steps") + num_steps = builder.emit(TupleGetItem(state_var, 0), "num_steps") num_steps_new = builder.emit(add(num_steps, one), "num_steps_new") state_list_new.append(num_steps_new) # computation logics for i in range(len_param): name = self._param_list[i].name_hint - p = builder.emit(rx.TupleGetItem(param_var, i), name) - g = builder.emit(rx.TupleGetItem(grad_var, i), name + "_grad") - v = builder.emit(rx.TupleGetItem(state_var, i + 1), name + "_v") + p = builder.emit(TupleGetItem(param_var, i), name) + g = builder.emit(TupleGetItem(grad_var, i), name + "_grad") + v = builder.emit(TupleGetItem(state_var, i + 1), name + "_v") if self.weight_decay: g = builder.emit(add(multiply(weight_decay, p), g), name + "_grad_new") damp_g = multiply(dampening_inv, g) if self.dampening else g @@ -515,8 +517,8 @@ def get_function(self) -> Function: state_list_new.append(v_new) # handle return values - params_new = builder.emit_output(rx.Tuple(param_list_new), "params_new") - optim_states_new = builder.emit_output(rx.Tuple(state_list_new), "optim_states_new") + params_new = builder.emit_output(RxTuple(param_list_new), "params_new") + optim_states_new = builder.emit_output(RxTuple(state_list_new), 
"optim_states_new") builder.emit_func_output((params_new, optim_states_new)) return builder.get()[self.name] @@ -638,15 +640,15 @@ def get_function(self) -> Function: dtype = self._dtype # input variables - param_var = Var("params", rx.TupleStructInfo([p.struct_info for p in plist])) - grad_var = Var("gradients", rx.TupleStructInfo([p.struct_info for p in plist])) + param_var = Var("params", TupleStructInfo([p.struct_info for p in plist])) + grad_var = Var("gradients", TupleStructInfo([p.struct_info for p in plist])) state_var = Var( "optim_states", - rx.TupleStructInfo( + TupleStructInfo( [ - rx.TensorStructInfo((), "int64"), - rx.TensorStructInfo((), dtype), - rx.TensorStructInfo((), dtype), + TensorStructInfo((), "int64"), + TensorStructInfo((), dtype), + TensorStructInfo((), dtype), *(p.struct_info for p in plist), *(p.struct_info for p in plist), ] @@ -654,42 +656,38 @@ def get_function(self) -> Function: ) # constants - lr = rx.const(self.lr, dtype) - beta1 = rx.const(self.beta1, dtype) - beta2 = rx.const(self.beta2, dtype) - beta1_inv = rx.const(_high_precision_subtract(1, self.beta1), dtype) - beta2_inv = rx.const(_high_precision_subtract(1, self.beta2), dtype) - eps = rx.const(self.eps, dtype) - weight_decay = rx.const(self.weight_decay, dtype) - one_int = rx.const(1, "int64") - one_float = rx.const(1, dtype) - - builder = rx.BlockBuilder() + lr = const(self.lr, dtype) + beta1 = const(self.beta1, dtype) + beta2 = const(self.beta2, dtype) + beta1_inv = const(_high_precision_subtract(1, self.beta1), dtype) + beta2_inv = const(_high_precision_subtract(1, self.beta2), dtype) + eps = const(self.eps, dtype) + weight_decay = const(self.weight_decay, dtype) + one_int = const(1, "int64") + one_float = const(1, dtype) + + builder = BlockBuilder() with builder.function(self.name, [param_var, grad_var, state_var]): with builder.dataflow(): param_list_new = [] state_list_new = [None] * (len_param * 2 + 3) # type: List[Optional[Var]] # handle num_steps - num_steps = builder.emit(rx.TupleGetItem(state_var, 0), "num_steps") + num_steps = builder.emit(TupleGetItem(state_var, 0), "num_steps") num_steps_new = builder.emit(add(num_steps, one_int), "num_steps_new") state_list_new[0] = num_steps_new - beta1_prod = builder.emit( - multiply(rx.TupleGetItem(state_var, 1), beta1), "beta1_prod" - ) - beta2_prod = builder.emit( - multiply(rx.TupleGetItem(state_var, 2), beta2), "beta2_prod" - ) + beta1_prod = builder.emit(multiply(TupleGetItem(state_var, 1), beta1), "beta1_prod") + beta2_prod = builder.emit(multiply(TupleGetItem(state_var, 2), beta2), "beta2_prod") state_list_new[1] = beta1_prod state_list_new[2] = beta2_prod # computation logics for i in range(len_param): name = self._param_list[i].name_hint - p = builder.emit(rx.TupleGetItem(param_var, i), name) - g = builder.emit(rx.TupleGetItem(grad_var, i), name + "_grad") - m = builder.emit(rx.TupleGetItem(state_var, i + 3), name + "_m") - v = builder.emit(rx.TupleGetItem(state_var, i + 3 + len_param), name + "_v") + p = builder.emit(TupleGetItem(param_var, i), name) + g = builder.emit(TupleGetItem(grad_var, i), name + "_grad") + m = builder.emit(TupleGetItem(state_var, i + 3), name + "_m") + v = builder.emit(TupleGetItem(state_var, i + 3 + len_param), name + "_v") if self.weight_decay: g = builder.emit(add(multiply(weight_decay, p), g), name + "_grad_new") m_new = builder.emit( @@ -714,7 +712,7 @@ def get_function(self) -> Function: state_list_new[i + 3 + len_param] = v_new # handle return values - params_new = 
builder.emit_output(rx.Tuple(param_list_new), "params_new") - optim_states_new = builder.emit_output(rx.Tuple(state_list_new), "optim_states_new") + params_new = builder.emit_output(RxTuple(param_list_new), "params_new") + optim_states_new = builder.emit_output(RxTuple(state_list_new), "optim_states_new") builder.emit_func_output((params_new, optim_states_new)) return builder.get()[self.name] diff --git a/python/tvm/relax/training/utils.py b/python/tvm/relax/training/utils.py new file mode 100644 index 0000000000..b2057c9f95 --- /dev/null +++ b/python/tvm/relax/training/utils.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utility functions for relax training.""" + +from ..expr import Function +from . import _ffi_api + + +def append_loss(orig_func: Function, loss_func: Function) -> Function: + """Local helper to append a specified loss function after the original function. + + In detail, the result function has the arguments list of orig_func and the combination + of their body, which passes the return values of orig_func as the arguments of loss_func. For + those arguments of loss_func which are not mapped to some return values, they will be lifted + and appended to the argument list of result function. + + Note + ------- + 1. This uitl is dedicated to loss functions, not for general purposes. + 2. This util can be replaced if we have Inline pass. It is equivalent to inline a tail call in + some sense. + + Example + ------- + >>> @R.function + ... def orig(x: R.Tensor((2, 4), "float32"), y: R.Tensor((2, 4), "float32")): + ... with R.dataflow(): + ... out = R.add(x, y) + ... R.output(out) + ... return out + + >>> @R.function + ... def loss(predictions: R.Tensor((2, 4), "float32"), labels: R.Tensor((2, 4), "float32")): + ... with R.dataflow(): + ... lv = R.subtract(predictions, labels) + ... lv1 = R.multiply(lv, lv) + ... gv = R.sum(lv1) + ... R.output(gv) + ... return gv + + >>> expected = append_loss(orig, loss) + >>> print(expected) + + Will get + + >>> @R.function + ... def expected(x: R.Tensor((2, 4), "float32"), y: R.Tensor((2, 4), "float32"), + ... labels: R.Tensor((2, 4), "float32")) -> R.Tensor((), "float32"): + ... with R.dataflow(): + ... out: R.Tensor((2, 4), "float32") = R.add(x, y) + ... lv: R.Tensor((2, 4), "float32") = R.subtract(out, labels) + ... lv1: R.Tensor((2, 4), "float32") = R.multiply(lv, lv) + ... gv: R.Tensor((), "float32") = R.sum(lv1) + ... R.output(gv) + ... return gv + + Parameters + ---------- + orig_func : Function + The function to be appended to. + + loss_func : Function + The loss function. + + Returns + ------- + ret : Function + The result function. 
+ """ + return _ffi_api.AppendLoss(orig_func, loss_func) # type: ignore diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc new file mode 100644 index 0000000000..e18e3c0ded --- /dev/null +++ b/src/relax/training/utils.cc @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "utils.h" + +namespace tvm { +namespace relax { + +/*! \brief Helper to implement append loss.*/ +class AppendLossMutator : public ExprMutator { + public: + explicit AppendLossMutator(const SeqExpr& loss_body) : loss_body_(loss_body) {} + + Expr VisitExpr_(const SeqExprNode* seq_expr) final { + // mutate only the last block. + Array blocks; + for (int i = 0; i < static_cast(seq_expr->blocks.size()); ++i) { + CHECK(seq_expr->blocks[i].as()) + << "All blocks in original functions should be Dataflow Block."; + if (i < static_cast(seq_expr->blocks.size()) - 1) { + blocks.push_back(seq_expr->blocks[i]); + } else { + BindingBlock new_block = this->VisitBindingBlock(seq_expr->blocks[i]); + if (!new_block->bindings.empty()) { + blocks.push_back(new_block); + } + } + } + return SeqExpr(blocks, loss_body_->body); + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final { + builder_->BeginDataflowBlock(); + // emit original bindings. + for (const auto& binding : block->bindings) { + this->VisitBinding(binding); + } + + ICHECK(orig_rets_var_.size() == orig_rets.size()); + for (int i = 0; i < static_cast(orig_rets_var_.size()); ++i) { + if (orig_rets_var_[i].defined()) { + builder_->EmitNormalized(VarBinding(orig_rets_var_[i].value(), orig_rets[i])); + } + } + + // emit blocks for loss function part. + for (const BindingBlock& block : loss_body_->blocks) { + CHECK(block.as()) + << "All blocks in loss functions should be Dataflow Block."; + for (const Binding& binding : block->bindings) { + this->VisitBinding(binding); + } + } + + return builder_->EndBlock(); + } + + // Remap the def site var using VisitExpr. + Var VisitVarDef(const Var& var) final { return Downcast(this->VisitExpr(var)); } + + // Remap original dataflow var. 
+ // TODO(chaofan): a better way to check whether new_ret_var should be dataflow + void RemapToDataflow(SeqExpr body) { + const auto& block = body->blocks.back(); + for (const Binding& binding : block->bindings) { + const auto* binding_node = binding.as(); + if (binding_node && !binding_node->var->IsInstance()) { + Var new_binding_var = DataflowVar(binding_node->var->vid, GetStructInfo(binding_node->var), + binding_node->var->span); + this->var_remap_[binding_node->var->vid] = new_binding_var; + } + } + } + + Array RemapLossParams(const Array& loss_func_params, Array new_params) { + for (int i = 0; i < static_cast(loss_func_params.size()); ++i) { + Var loss_param = loss_func_params[i]; + auto loss_param_sinfo = GetStructInfo(loss_param); + + if (i < static_cast(orig_rets.size())) { // map return value to loss param + auto orig_ret_sinfo = GetStructInfo(orig_rets[i]); + ICHECK(structural_equal_(orig_ret_sinfo, loss_param_sinfo)) + << "The struct info of the " << i + << "-th return value of orig func is: " << orig_ret_sinfo + << " while the corresponding struct info of parameter of loss function is " + << loss_param_sinfo << ", which is different."; + + if (const auto* var_node = orig_rets[i].as()) { + ICHECK(orig_rets[i].as()); + orig_rets_var_.push_back(NullOpt); + this->var_remap_[loss_param->vid] = GetRef(var_node); + } else { + Var new_ret_var = DataflowVar(/*name_hint=*/"ret_" + std::to_string(i), orig_ret_sinfo); + orig_rets_var_.push_back(new_ret_var); + this->var_remap_[loss_param->vid] = new_ret_var; + } + } else { // append to the param list + Var new_loss_param = Var(loss_param->vid, loss_param_sinfo, loss_param->span); + this->var_remap_[loss_param->vid] = new_loss_param; + new_params.push_back(new_loss_param); + } + } + return new_params; + } + + /*! \brief The original unpacked rets. */ + Array orig_rets; + + private: + /*! \brief The body of the loss function */ + SeqExpr loss_body_; + /*! \brief The var created for original rets. NullOpt if the original ret is already a var. */ + Array> orig_rets_var_; + /*! \brief The structural equality checker */ + StructuralEqual structural_equal_; +}; + +/*! + * \brief Local helper to append a specified loss function after the original function. + * \param orig_func The function to be appended to. + * \param loss_func The loss function. + * \return The result function after appended. + */ +Function AppendLoss(Function orig_func, Function loss_func) { + CHECK(orig_func->body->IsInstance()) + << "The body of the original function is expected to be a SeqExpr, but got" + << orig_func->body->GetTypeKey(); + CHECK(loss_func->body->IsInstance()) + << "The body of the loss function is expected to be a SeqExpr, but got" + << loss_func->body->GetTypeKey(); + + Function param_copied_func = CopyWithNewParams(orig_func); + auto seq_expr = Downcast(param_copied_func->body); + + AppendLossMutator mutator(Downcast(loss_func->body)); + mutator.RemapToDataflow(seq_expr); + // Get the original rets. If it is a Tuple, unpack it. + if (orig_func->ret_struct_info.as()) { + const auto* tuple_node = seq_expr->body.as(); + ICHECK(tuple_node != nullptr); + for (const Expr& field : tuple_node->fields) { + mutator.orig_rets.push_back(mutator.VisitExpr(field)); + } + } else { + mutator.orig_rets.push_back(mutator.VisitExpr(seq_expr->body)); + } + + CHECK(loss_func->params.size() >= mutator.orig_rets.size()) + << "The number of return values of original functions should be greater than the number of " + "parameters of loss function. 
Got " + << mutator.orig_rets.size() << " > " << loss_func->params.size(); + + Array new_params = mutator.RemapLossParams(loss_func->params, param_copied_func->params); + Expr new_body = mutator.VisitExpr(seq_expr); + return Function(std::move(new_params), std::move(new_body), loss_func->ret_struct_info, + param_copied_func->attrs); +} + +TVM_REGISTER_GLOBAL("relax.training.AppendLoss").set_body_typed(AppendLoss); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/training/utils.h b/src/relax/training/utils.h new file mode 100644 index 0000000000..d8008fe456 --- /dev/null +++ b/src/relax/training/utils.h @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/training/utils.h + * \brief Utility classes and functions for relax training. + */ +#ifndef TVM_RELAX_TRAINING_UTILS_H_ +#define TVM_RELAX_TRAINING_UTILS_H_ + +#include +#include + +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Local helper to append a specified loss function after the original function. + * \note + * 1. This uitl is dedicated to loss functions, not for general purposes. + * 2. This util can be replaced if we have Inline pass. It is equivalent to inline a tail call in + * some sense. + * \param orig_func The function to be appended to. + * \param loss_func The loss function. + * \return The result function after appended. + */ +Function AppendLoss(Function orig_func, Function loss_func); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_TRAINING_UTILS_H_ diff --git a/tests/python/relax/test_training_loss.py b/tests/python/relax/test_training_loss.py new file mode 100644 index 0000000000..0a4a154013 --- /dev/null +++ b/tests/python/relax/test_training_loss.py @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
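The tests below pin down the user-facing contract of the new loss classes: calling a loss object with either `Var`s or `TensorStructInfo`s yields a free-standing `relax.Function` whose global symbol is the loss name, which can then be welded onto a forward function with `append_loss`. A condensed sketch of that contract (shapes chosen to match the tests; assumes a build containing this patch):

```python
# Calling convention checked by the tests below: a loss object called with
# StructInfo (or Var) arguments returns a relax.Function named after the loss.
from tvm import relax
from tvm.relax.training.loss import L1Loss

pred_sinfo = relax.TensorStructInfo((3, 5), "float32")
target_sinfo = relax.TensorStructInfo((3, 5), "float32")

loss_fn = L1Loss(reduction="mean")(pred_sinfo, target_sinfo)
print(loss_fn)  # l1_loss(predictions, targets) -> R.Tensor((), "float32")
```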
+import tvm.testing +from tvm import relax +from tvm.ir.base import assert_structural_equal +from tvm.script import relax as R + + +@R.function +def forward( + x: R.Tensor((2, 4), "float32"), + w: R.Tensor((4, 4), "float32"), + b: R.Tensor((2, 4), "float32"), +) -> R.Tensor((2, 4), "float32"): + with R.dataflow(): + lv: R.Tensor((2, 4), "float32") = R.matmul(x, w) + out: R.Tensor((2, 4), "float32") = R.add(lv, b) + R.output(out) + return out + + +def test_l1_loss(): + N = 3 + C = 5 + predictions = relax.TensorStructInfo((N, C), "float32") + targets = relax.TensorStructInfo((N, C), "float32") + l1_loss = relax.training.loss.L1Loss() + + @R.function + def expected( + predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), "float32") + ) -> R.Tensor((), "float32"): + with R.dataflow(): + lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets) + lv1: R.Tensor((3, 5), "float32") = R.abs(lv) + gv: R.Tensor((), "float32") = R.mean(lv1, axis=None, keepdims=False) + R.output(gv) + return gv + + assert_structural_equal(l1_loss(predictions, targets), expected) + + +def test_l1_loss_append(): + s = forward.ret_struct_info + l1_loss = relax.training.loss.L1Loss(reduction="sum") + forward_with_loss = relax.training.utils.append_loss(forward, l1_loss(s, s)) + + @R.function + def expected( + x: R.Tensor((2, 4), "float32"), + w: R.Tensor((4, 4), "float32"), + b: R.Tensor((2, 4), "float32"), + targets: R.Tensor((2, 4), "float32"), + ) -> R.Tensor((), "float32"): + with R.dataflow(): + lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="") + out: R.Tensor((2, 4), "float32") = R.add(lv, b) + lv1: R.Tensor((2, 4), "float32") = R.subtract(out, targets) + lv11: R.Tensor((2, 4), "float32") = R.abs(lv1) + gv: R.Tensor((), "float32") = R.sum(lv11, axis=None, keepdims=False) + R.output(gv) + return gv + + assert_structural_equal(forward_with_loss, expected) + + +def test_mse_loss(): + N = 3 + C = 5 + predictions = relax.TensorStructInfo((N, C), "float32") + targets = relax.TensorStructInfo((N, C), "float32") + mse_loss = relax.training.loss.MSELoss() + + @R.function + def expected( + predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), "float32") + ) -> R.Tensor((), "float32"): + with R.dataflow(): + lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets) + lv1: R.Tensor((3, 5), "float32") = R.multiply(lv, lv) + gv: R.Tensor((), "float32") = R.mean(lv1, axis=None, keepdims=False) + R.output(gv) + return gv + + assert_structural_equal(mse_loss(predictions, targets), expected) + + +def test_mse_loss_append(): + s = forward.ret_struct_info + mse_loss = relax.training.loss.MSELoss(reduction="sum") + forward_with_loss = relax.training.utils.append_loss(forward, mse_loss(s, s)) + + @R.function + def expected( + x: R.Tensor((2, 4), "float32"), + w: R.Tensor((4, 4), "float32"), + b: R.Tensor((2, 4), "float32"), + targets: R.Tensor((2, 4), "float32"), + ) -> R.Tensor((), "float32"): + with R.dataflow(): + lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="") + out: R.Tensor((2, 4), "float32") = R.add(lv, b) + lv1: R.Tensor((2, 4), "float32") = R.subtract(out, targets) + lv11: R.Tensor((2, 4), "float32") = R.multiply(lv1, lv1) + gv: R.Tensor((), "float32") = R.sum(lv11, axis=None, keepdims=False) + R.output(gv) + return gv + + assert_structural_equal(forward_with_loss, expected) + + +def test_cross_entropy_loss(): + N = 3 + C = 5 + predictions = relax.TensorStructInfo((N, C), "float32") + targets = relax.TensorStructInfo((N,), "int64") + weights = 
relax.TensorStructInfo((C,), "float32") + cross_entropy_loss = relax.training.loss.CrossEntropyLoss(reduction="sum", ignore_index=1) + + @R.function + def expected( + predictions: R.Tensor((3, 5), "float32"), + targets: R.Tensor((3,), "int64"), + weights: R.Tensor((5,), "float32"), + ) -> R.Tensor((), "float32"): + with R.dataflow(): + lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) + gv: R.Tensor((), "float32") = R.nn.nll_loss( + lv, targets, weights, reduction="sum", ignore_index=1 + ) + R.output(gv) + return gv + + assert_structural_equal(cross_entropy_loss(predictions, targets, weights), expected) + + +def test_cross_entropy_loss_without_weights(): + N = 3 + C = 5 + predictions = relax.TensorStructInfo((N, C), "float32") + targets = relax.TensorStructInfo((N,), "int64") + cross_entropy_loss = relax.training.loss.CrossEntropyLoss() + + @R.function + def expected( + predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3,), "int64") + ) -> R.Tensor((), "float32"): + with R.dataflow(): + lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) + gv: R.Tensor((), "float32") = R.nn.nll_loss( + lv, targets, reduction="mean", ignore_index=-100 + ) + R.output(gv) + return gv + + assert_structural_equal(cross_entropy_loss(predictions, targets), expected) + + +def test_cross_entropy_loss_append(): + s = forward.ret_struct_info + N = s.shape[0] + C = s.shape[1] + targets = relax.TensorStructInfo((N,), "int64") + weights = relax.TensorStructInfo((C,), "float32") + cross_entropy_loss = relax.training.loss.CrossEntropyLoss(reduction="sum", ignore_index=1) + forward_with_loss = relax.training.utils.append_loss( + forward, cross_entropy_loss(s, targets, weights) + ) + + @R.function + def expected( + x: R.Tensor((2, 4), "float32"), + w: R.Tensor((4, 4), "float32"), + b: R.Tensor((2, 4), "float32"), + targets: R.Tensor((2,), "int64"), + weights: R.Tensor((4,), "float32"), + ) -> R.Tensor((), "float32"): + with R.dataflow(): + lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="") + out: R.Tensor((2, 4), "float32") = R.add(lv, b) + lv1: R.Tensor((2, 4), "float32") = R.nn.log_softmax(out, axis=-1) + gv: R.Tensor((), "float32") = R.nn.nll_loss( + lv1, targets, weights, reduction="sum", ignore_index=1 + ) + R.output(gv) + return gv + + assert_structural_equal(forward_with_loss, expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_training_optimizer_numeric.py b/tests/python/relax/test_training_optimizer_numeric.py index 8e6a5974f7..3a57bc66bf 100644 --- a/tests/python/relax/test_training_optimizer_numeric.py +++ b/tests/python/relax/test_training_optimizer_numeric.py @@ -23,10 +23,10 @@ from tvm import relax from tvm import IRModule from tvm.relax.training.optimizer import Adam, SGD, MomentumSGD +from tvm.relax.transform import LegalizeOps +from tvm.runtime.container import tuple_object from tvm.script.parser import relax as R from tvm.testing import assert_allclose -from tvm.runtime.container import tuple_object -from tvm.relax.transform import LegalizeOps def _legalize_and_build(mod: IRModule, target, dev): @@ -64,6 +64,7 @@ def _assert_run_result_same(tvm_func: Callable, np_func: Callable, np_inputs: Li _assert_allclose_nested(result, expected) +@tvm.testing.parametrize_targets("llvm") def _test_optimizer(target, dev, np_func, opt_type, *args, **kwargs): x = relax.Var("x", R.Tensor((3, 3), "float32")) y = relax.Var("y", R.Tensor((3,), "float32")) @@ -85,8 +86,14 @@ def _test_optimizer(target, dev, 
np_func, opt_type, *args, **kwargs): _assert_allclose_nested(_tvm_to_numpy(opt.state), expected_state) +lr, weight_decay = tvm.testing.parameters( + (0.01, 0), + (0.01, 0.02), +) + + @tvm.testing.parametrize_targets("llvm") -def test_sgd(target, dev): +def test_sgd(target, dev, lr, weight_decay): def np_func(param_tuple, grad_tuple, state_tuple): num_steps = state_tuple[0] param_tuple_new, state_tuple_new = [], [] @@ -97,14 +104,18 @@ def np_func(param_tuple, grad_tuple, state_tuple): param_tuple_new.append(param - lr * (grad + weight_decay * param)) return param_tuple_new, state_tuple_new - lr, weight_decay = 0.01, 0 - _test_optimizer(target, dev, np_func, SGD, lr) - lr, weight_decay = 0.01, 0.02 _test_optimizer(target, dev, np_func, SGD, lr, weight_decay) +lr, momentum, dampening, weight_decay, nesterov = tvm.testing.parameters( + (0.01, 0.9, 0, 0, False), + (0.01, 0.9, 0.85, 0.02, False), + (0.01, 0.9, 0.85, 0.02, True), +) + + @tvm.testing.parametrize_targets("llvm") -def test_momentum_sgd(target, dev): +def test_momentum_sgd(target, dev, lr, momentum, dampening, weight_decay, nesterov): def np_func(param_tuple, grad_tuple, state_tuple): num_steps = state_tuple[0] param_tuple_new, state_tuple_new = [], [] @@ -125,22 +136,19 @@ def np_func(param_tuple, grad_tuple, state_tuple): return param_tuple_new, state_tuple_new - lr, momentum, dampening, weight_decay, nesterov = 0.01, 0.9, 0, 0, False - _test_optimizer( - target, dev, np_func, MomentumSGD, lr, momentum, dampening, weight_decay, nesterov - ) - lr, momentum, dampening, weight_decay, nesterov = 0.01, 0.9, 0.85, 0.02, False - _test_optimizer( - target, dev, np_func, MomentumSGD, lr, momentum, dampening, weight_decay, nesterov - ) - lr, momentum, dampening, weight_decay, nesterov = 0.01, 0.9, 0.85, 0.02, True _test_optimizer( target, dev, np_func, MomentumSGD, lr, momentum, dampening, weight_decay, nesterov ) +lr, betas, eps, weight_decay = tvm.testing.parameters( + (0.01, (0.9, 0.999), 1e-08, 0), + (0.01, (0.8, 0.85), 1e-07, 0.1), +) + + @tvm.testing.parametrize_targets("llvm") -def test_adam(target, dev): +def test_adam(target, dev, lr, betas, eps, weight_decay): def np_func(param_tuple, grad_tuple, state_tuple): num_steps = state_tuple[0] num_steps_new = num_steps + 1 @@ -168,9 +176,6 @@ def np_func(param_tuple, grad_tuple, state_tuple): return param_tuple_new, state_tuple_new - lr, betas, eps, weight_decay = 0.01, (0.9, 0.999), 1e-08, 0 - _test_optimizer(target, dev, np_func, Adam, lr, betas, eps, weight_decay) - lr, betas, eps, weight_decay = 0.01, (0.8, 0.85), 1e-07, 0.1 _test_optimizer(target, dev, np_func, Adam, lr, betas, eps, weight_decay) diff --git a/tests/python/relax/test_training_utils.py b/tests/python/relax/test_training_utils.py new file mode 100644 index 0000000000..3ba93b689f --- /dev/null +++ b/tests/python/relax/test_training_utils.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm.testing +from tvm import relax, TVMError +from tvm.ir.base import assert_structural_equal +from tvm.script import relax as R + + +def test_append_loss_basic_extend(): + @R.function + def orig(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")): + with R.dataflow(): + gv0 = R.sum(x) + gv1 = R.sum(y) + R.output(gv0, gv1) + return gv0, gv1 + + @R.function + def loss(arg1: R.Tensor((), dtype="float32"), arg2: R.Tensor((), dtype="float32")): + with R.dataflow(): + gv0 = R.add(arg1, arg2) + R.output(gv0) + return gv0 + + @R.function + def expected( + x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=0): + # block 0 + with R.dataflow(): + gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) + gv1: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False) + gv01: R.Tensor((), dtype="float32") = R.add(gv0, gv1) + R.output(gv01) + return gv01 + + after = relax.training.utils.append_loss(orig, loss) + assert_structural_equal(after, expected) + + +def test_append_loss_extra_params(): + @R.function + def orig(x: R.Tensor((3, 3), dtype="float32")): + with R.dataflow(): + gv0 = R.sum(x) + gv1 = R.add(x, x) + R.output(gv0, gv1) + return gv0, gv1 + + @R.function + def loss( + arg1: R.Tensor((), dtype="float32"), + arg2: R.Tensor((3, 3), dtype="float32"), + arg3: R.Tensor((3, 3), dtype="float32"), + ): + with R.dataflow(): + gv0 = R.add(arg2, arg3) + R.output(gv0) + return gv0 + + @R.function + def expected( + x: R.Tensor((3, 3), dtype="float32"), arg3: R.Tensor((3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=2): + with R.dataflow(): + gv0: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) + gv1: R.Tensor((3, 3), dtype="float32") = R.add(x, x) + gv01: R.Tensor((3, 3), dtype="float32") = R.add(gv1, arg3) + R.output(gv01) + return gv01 + + after = relax.training.utils.append_loss(orig, loss) + assert_structural_equal(after, expected) + + +def test_append_loss_nested_tuple(): + @R.function + def orig(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")): + with R.dataflow(): + gv0 = R.add(x, x) + gv1 = R.sum(y) + gv2 = R.add(x, y) + R.output(gv0, gv1, gv2) + return (gv0, gv1), gv2 + + @R.function + def loss( + arg1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((), dtype="float32")), + arg2: R.Tensor((3, 3), dtype="float32"), + ): + with R.dataflow(): + arg10 = arg1[0] + gv0 = R.add(arg10, arg2) + R.output(gv0) + return gv0 + + @R.function + def expected( + x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32") + ) -> R.Tensor((3, 3), dtype="float32"): + # block 0 + with R.dataflow(): + gv0: R.Tensor((3, 3), dtype="float32") = R.add(x, x) + gv1: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False) + gv2: R.Tensor((3, 3), dtype="float32") = R.add(x, y) + ret_0: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((), dtype="float32")) = ( + gv0, + gv1, + ) + arg10: R.Tensor((3, 3), dtype="float32") = ret_0[0] + gv01: R.Tensor((3, 3), dtype="float32") = 
R.add(arg10, gv2) + R.output(gv01) + return gv01 + + after = relax.training.utils.append_loss(orig, loss) + assert_structural_equal(after, expected) + + +def test_append_loss_wrong_struct_info(): + @R.function + def orig(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")): + with R.dataflow(): + gv0 = R.sum(x) + gv1 = R.sum(y) + R.output(gv0, gv1) + return gv0, gv1 + + @R.function + def loss(arg1: R.Tensor((), dtype="float64"), arg2: R.Tensor((), dtype="float64")): + with R.dataflow(): + gv0 = R.add(arg1, arg2) + R.output(gv0) + return gv0 + + with pytest.raises(TVMError): + after = relax.training.utils.append_loss(orig, loss) + + +if __name__ == "__main__": + tvm.testing.main()
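Putting the pieces together, the intended workflow is: write a forward relax function, call a loss class to obtain a loss `Function`, fuse the two with `append_loss`, and later drive parameter updates with `relax.training.optimizer` once a gradient transformation (outside the scope of this patch) has produced the backward function. A minimal sketch up to the fused forward-plus-loss function, assuming a TVM build with this patch applied:

```python
# Sketch of the intended composition; stops before gradient computation,
# which is not part of this patch.
from tvm.relax.training.loss import MSELoss
from tvm.relax.training.utils import append_loss
from tvm.script import relax as R


@R.function
def forward(
    x: R.Tensor((2, 4), "float32"), w: R.Tensor((4, 4), "float32")
) -> R.Tensor((2, 4), "float32"):
    with R.dataflow():
        out = R.matmul(x, w)
        R.output(out)
    return out


out_sinfo = forward.ret_struct_info                 # R.Tensor((2, 4), "float32")
loss_fn = MSELoss(reduction="sum")(out_sinfo, out_sinfo)

# forward_with_loss(x, w, targets) returns the scalar loss; it is the input to a
# gradient pass and, eventually, to relax.training.optimizer.SGD and friends.
forward_with_loss = append_loss(forward, loss_fn)
print(forward_with_loss)
```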