diff --git a/include/tvm/te/autodiff.h b/include/tvm/te/autodiff.h new file mode 100644 index 000000000000..180ec0bf676c --- /dev/null +++ b/include/tvm/te/autodiff.h @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/te/autodiff.h + * \brief Automatic differentiation of tensor expressions. + */ + +#ifndef TVM_TE_AUTODIFF_H_ +#define TVM_TE_AUTODIFF_H_ + +#include +#include +#include "tensor.h" + +namespace tvm { +/*! \brief Tensor expression language DSL. */ +namespace te { + +/*! + * \brief Take the derivative of the expression with respect to the given variable. + * \param expr The expression to differentiate. + * \param var The variable to differentiate with respect to. + * \return The expression for the derivative. + */ +PrimExpr Derivative(const PrimExpr& expr, const Var& var); + +/*! + * \brief Get the tensor representing the Jacobian of the output with respect to the input. + * + * Note that if \p output depends on \p input indirectly (by using some other tensor + * depending on \p input), this dependency won't contribute to the resulting Jacobian. + * For such cases use the function ::Gradient. + * + * \param output The tensor to differentiate. + * \param input The input tensor, which \p output should directly use. + * \return The tensor representing the Jacobian of shape `output.shape + input.shape`. + */ +Tensor Jacobian(const Tensor& output, const Tensor& input); + +/*! + * \brief The building block for reverse-mode AD. + * + * Differentiate \p output wrt \p input and multiply the result by \p head on the left using tensor + * dot product. \p input must be an immediate dependency of \p output (must be called from within + * the body of \p output). That is, the function will compute one summand of the adjoint for \p input + * given the adjoint for \p output (which is called \p head here). + * + * \param output The tensor to differentiate. + * \param input The input tensor, which \p output should directly use. + * \param head The adjoint of \p output. Must be of shape `prefix + output.shape` + * \return The tensor of shape `prefix + input.shape` + * representing the partial adjoint of \p input wrt one of its consumers (output) + */ +Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Tensor &head); + +/*! + * \brief Perform reverse mode automatic differentiation. + * + * Each item of the `result` field of the result is an adjoint for the corresponding item of + * \p inputs, i.e. \p head multiplied by the Jacobian of \p output with respect to the + * corresponding item of \p inputs. + * + * \param output The tensor to differentiate. + * \param inputs The array of input tensors. When the array is empty, will perform differentiation + * wrt all tensors the output depends on. + * \param head The adjoint of the output, in other words, some tensor, by which the Jacobians + * will be multiplied (using tensordot axes=`output.shape`). + * Its shape must be of the form `prefix + output.shape`. If the null pointer is provided, + * the identity tensor of shape `output.shape + output.shape` will be used. + * \return An array of adjoints corresponding to \p inputs. + */ +TVM_DLL Array Gradient( + const Tensor& output, + const Array& inputs, + const Tensor& head = Tensor()); + +} // namespace te +} // namespace tvm + +#endif // TVM_TE_AUTODIFF_H_ diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index e88c17a7b005..1ba554960ef7 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -33,3 +33,4 @@ from .operation import thread_axis, reduce_axis from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp +from .autodiff import gradient diff --git a/python/tvm/te/autodiff.py b/python/tvm/te/autodiff.py new file mode 100644 index 000000000000..f8650839948d --- /dev/null +++ b/python/tvm/te/autodiff.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Automatic differentiation of tensor expressions.""" +from . import _ffi_api + + +def gradient(output, inputs, head=None): + """Perform reverse-mode automatic differentiation. + + Parameters + ---------- + output : Tensor + The tensor to differentiate. + + inputs : List[Tensor] + The list of input tensors to be differentiated wrt. + + head : Tensor + The adjoint of the output, in other words, some tensor, by which the Jacobians + will be multiplied. Its shape must be of the form `prefix + output.shape`. + If `None` is passed, the identity tensor of shape `output.shape + output.shape` + will be used. + + Returns + ------- + tensors: List[Tensor] + The result gradient, in the same order as the inputs + + Example + ------- + .. code-block:: python + + x = tvm.placeholder((32, 3, 28, 28), name='x') + w1 = tvm.placeholder((10, 3, 3, 3), name='w1') + w2 = tvm.placeholder((10, 10, 3, 3), name='w2') + z1 = topi.nn.conv2d(x, w1, 1, 1, 1) + z2 = topi.nn.conv2d(z1, w2, 1, 1, 1) + y = topi.sum(z2) + + # produce gradients + [dw1, dw2] = tvm.gradient(y, [w1, w2]) + + # produce Jacobians + [jw1, jw2] = tvm.gradient(z2, [w1, w2]) + + # produce gradients, the head adjoint for z2 is provided manually + [dw1, dw2] = tvm.gradient(z2, [w1, w2], topi.full_like(z2, 1.0)) + + """ + if not isinstance(inputs, list): + inputs = [inputs] + return _ffi_api.Gradient(output, inputs, head) diff --git a/src/te/autodiff/ad_util.cc b/src/te/autodiff/ad_util.cc new file mode 100644 index 000000000000..3a90beff4822 --- /dev/null +++ b/src/te/autodiff/ad_util.cc @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ad_util.cc + * \brief Utility for tensor-level auto-differentiation. + */ +#include +#include +#include +#include "ad_util.h" + +namespace tvm { +namespace te { + +std::pair, Map> CloneIterVars(const Array& vars) { + Array new_vars; + Map vmap; + for (const IterVar& iv : vars) { + IterVar new_v = + IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""), + iv->iter_type, iv->thread_tag); + new_vars.push_back(new_v); + vmap.Set(iv->var, new_v->var); + } + return std::make_pair(std::move(new_vars), std::move(vmap)); +} + +PrimExpr CloneReduction(const PrimExpr& expr) { + if (const ReduceNode* red = expr.as()) { + Array new_axis; + Map vmap; + std::tie(new_axis, vmap) = CloneIterVars(red->axis); + + Array src_with_newaxis; + for (const auto& src : red->source) { + src_with_newaxis.push_back(tir::Substitute(src, vmap)); + } + + return ReduceNode::make(red->combiner, src_with_newaxis, + new_axis, tir::Substitute(red->condition, vmap), red->value_index); + } else { + return expr; + } +} + +} // namespace te +} // namespace tvm diff --git a/src/te/autodiff/ad_util.h b/src/te/autodiff/ad_util.h new file mode 100644 index 000000000000..7e511b1c5a22 --- /dev/null +++ b/src/te/autodiff/ad_util.h @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ad_util.h + * \brief Helper utilities to implement auto-differentiation. + */ +#ifndef TVM_TE_AUTODIFF_AD_UTIL_H_ +#define TVM_TE_AUTODIFF_AD_UTIL_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace te { + +/*! + * \brief Clone iter vars and return both the new vars and the substitution from old to new. + * + * \param vars The original iter vars. + * \return A pair containing the array of new iter vars and the map from old vars to new ones. + */ +std::pair, Map> CloneIterVars(const Array& vars); + +/*! + * \brief Clone reduction by cloning the axis variables. + * \param expr A reduction expr to clone. Non-reduction expressions are left intact. + */ +PrimExpr CloneReduction(const PrimExpr& expr); + +} // namespace te +} // namespace tvm +#endif // TVM_TE_AUTODIFF_AD_UTIL_H_ diff --git a/src/te/autodiff/adjoint.cc b/src/te/autodiff/adjoint.cc new file mode 100644 index 000000000000..0c54764e601a --- /dev/null +++ b/src/te/autodiff/adjoint.cc @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file adjoint.cc + * \brief Perform reverse-mode autodiff. + * Suppose we have f(x) = g(h1(x), h2(x), ..., hn(x)), + * df/dx = \sum_i df/dhi * dhi/dx + * We call df/dx as adjoint(x), df/dhi as adjoint(hi), dhi/dx is the Jacobian + * The idea is to first construct the reverse-dependency {input->outputs} between tensors, + * start from one input, + * (1) collect adjoints from all its dependencies (outputs), + * (2) multiply the Jacobian (PartialAdjoint), + * (3) and sum them together to get the adjoint of the input itself. + * The three steps are computed recursively. + */ +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace te { + +Tensor Identity(const Tensor& output) { + Array shape = output->shape; + for (auto e : output->shape) { + // add extra dimension for Jacobian + shape.push_back(e); + } + auto func = + [&output](const Array& input_indices) { + PrimExpr res = const_true(); + for (size_t i = 0; i < output->shape.size(); ++i) { + res = res && (PrimExpr(input_indices[i]) == + PrimExpr(input_indices[output->shape.size() + i])); + } + return CastNode::make(output->dtype, res); + }; + return te::compute(shape, func, "identity"); +} + +Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Tensor &head) { + Tensor jac = Jacobian(output, input); + Tensor result = topi::tensordot(head, jac, /*axes=*/output->shape.size(), + output->op->name + "." + input->op->name + ".grad"); + return result; +} + +Array Gradient(const Tensor& output, + const Array& inputs, + const Tensor& head_or_null) { + // Diagonal identity tensor + Tensor head = head_or_null.get() ? head_or_null : Identity(output); + + // This Map{input -> outputs} maps a tensor to the list of tensors + // immediately depending on it (using it in their bodies) + std::unordered_map> reverse_dependencies; + std::vector stack({output}); + while (!stack.empty()) { + Tensor tensor = stack.back(); + stack.pop_back(); + for (const Tensor& input : tensor->op->InputTensors()) { + if (!reverse_dependencies.count(input)) { + stack.push_back(input); + } + reverse_dependencies[input].push_back(tensor); + } + } + + // This map maps tensors to the corresponding adjoints (dLoss/dTensor) + std::unordered_map adjoints; + // head is the adjoint of output by definition + adjoints[output] = head; + + // This is a recursive function that does all the work. It computes the adjoint for a given + // tensor, adds it to the map, and returns it + std::function compute_adjoint; + compute_adjoint = + [&compute_adjoint, &adjoints, &reverse_dependencies, &head, &output] + (const Tensor& tensor) { + if (!adjoints.count(tensor)) { + // Here the adjoint hasn't been computed yet + Tensor res_adjoint; + std::vector direct_consumers = reverse_dependencies[tensor]; + if (direct_consumers.empty()) { + // No reverse dependencies means that the output does not depend on this tensor, + // return a zero tensor of the appropriate shape + // (i.e., output shape + tensor shape, aka shape of Jacobian) + Array result_shape(head->shape.begin(), + head->shape.end() + (-output->shape.size())); + for (auto e : tensor->shape) { + result_shape.push_back(e); + } + res_adjoint = topi::full(result_shape, output->dtype, make_zero(output->dtype)); + } else { + // The new adjoint is computed as a sum of the reverse dependencies' adjoints multiplied + // by the corresponding "local" jacobians (dDep/dTensor). The computation of the jacobian + // and the multiplication is done in the function VectorJacobianProduct + for (const Tensor& direct_consumer : direct_consumers) { + // part = (adjoint of direct_consumer) * Jacobian(direct_consumer, tensor) + Tensor part = VectorJacobianProduct( + direct_consumer, tensor, compute_adjoint(direct_consumer)); + res_adjoint = res_adjoint.get() ? topi::add(res_adjoint, part) : part; + } + } + + adjoints[tensor] = res_adjoint; + return res_adjoint; + } else { + return adjoints[tensor]; + } + }; + + // Adjoints corresponding to inputs + Array result; + // Compute an adjoint for each input + for (const Tensor& input : inputs) { + result.push_back(compute_adjoint(input)); + } + + return result; +} + +TVM_REGISTER_GLOBAL("te.Gradient") +.set_body([](TVMArgs args, TVMRetValue *ret) { + LOG(WARNING) << "te.Gradient is an experimental feature."; + if (args.size() == 2) { + *ret = Gradient(args[0], args[1]); + } else if (args.size() == 3) { + *ret = Gradient(args[0], args[1], args[2]); + } + }); + +} // namespace te +} // namespace tvm diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc new file mode 100644 index 000000000000..1a324588537f --- /dev/null +++ b/src/te/autodiff/jacobian.cc @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file jacobian.cc + * \brief Calculate Jacobian of two tensors dY/dX. + * X must be direct input tensor of Y. + * The result Jacobian shape will be (Y.shape, X.shape) + */ +#include +#include +#include +#include +#include +#include "ad_util.h" + +namespace tvm { +namespace te { + +#define NOT_IMPLEMENTED \ + { LOG(FATAL) << "Derivative of this expr is not implemented: " << GetRef(op); throw; } + +/*! \brief Differentiate an expression wrt a variable or a tensor element */ +class JacobianMutator : public ExprMutator { + public: + /*! + * \brief Differentiate wrt `input(indices)`. + * \param input The input tensor. + * \param indices The indices of the element with respect to which to differentiate. + */ + explicit JacobianMutator(Tensor input, Array indices) + : input_(input), indices_(indices) {} + /*! + * \brief Differentiate wrt the input variable. + * \param input The input variable. + */ + explicit JacobianMutator(Var input) : input_var_(input) {} + + PrimExpr Mutate(PrimExpr e) { + if (e.dtype().is_int() || e.dtype().is_uint()) { + LOG(WARNING) << "For now we assume that the derivative of any integer expression is always 0." + << " e = " << e; + return make_zero(e.dtype()); + } else { + return ExprMutator::VisitExpr(e); + } + } + + PrimExpr VisitExpr_(const VarNode* op) { + if (input_var_.get() && input_var_.get() == op && op->dtype.is_float()) { + return FloatImm(op->dtype, 1.0); + } else { + return make_zero(op->dtype); + } + } + + PrimExpr VisitExpr_(const LoadNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const LetNode* op) NOT_IMPLEMENTED + + PrimExpr VisitExpr_(const CallNode* op) { + PrimExpr expr = GetRef(op); + if (op->call_type == CallNode::CallType::Halide) { + if (input_.get() && op->func.same_as(input_->op) && + op->value_index == input_->value_index) { + // Tensor(indices) + CHECK_EQ(indices_.size(), op->args.size()); + PrimExpr condition = const_true(); + for (size_t i = 0; i < input_.ndim(); ++i) { + condition = AndNode::make(condition, EQNode::make(indices_[i], op->args[i])); + } + return CastNode::make(op->dtype, condition); + } else { + return make_zero(op->dtype); + } + } else if (op->call_type == CallNode::CallType::PureIntrinsic) { + static std::unordered_set piecewise_const = {"floor", "ceil", "trunc", "round"}; + if (op->name == "exp") { + return MulNode::make(Mutate(op->args[0]), expr); + } else if (op->name == "log") { + return DivNode::make(Mutate(op->args[0]), op->args[0]); + } else if (op->name == "sigmoid") { + return MulNode::make(Mutate(op->args[0]), + MulNode::make(expr, SubNode::make(FloatImm(expr.dtype(), 1.0), expr))); + } else if (op->name == "sqrt") { + return DivNode::make(Mutate(op->args[0]), + MulNode::make(expr, FloatImm(expr.dtype(), 2.0))); + } else if (op->name == "tanh") { + return MulNode::make(Mutate(op->args[0]), + SubNode::make(FloatImm(expr.dtype(), 1.0), MulNode::make(expr, expr))); + } else if (op->name == "pow") { + auto x = op->args[0], y = op->args[1]; + return expr * (Mutate(y)*log(x) + Mutate(x)*y/x); + } else if (op->name == "fabs") { + auto type = op->args[0].dtype(); + return MulNode::make(Mutate(op->args[0]), + SelectNode::make(GENode::make(op->args[0], make_zero(type)), + FloatImm(type, 1.0), FloatImm(type, -1.0))); + } else if (op->name == intrinsic::tvm_if_then_else) { + Array new_args = {op->args[0], + Mutate(op->args[1]), + Mutate(op->args[2])}; + return CallNode::make(op->dtype, op->name, new_args, + op->call_type, op->func, op->value_index); + } else if (piecewise_const.count(op->name)) { + return FloatImm(expr.dtype(), 0.0); + } else { + throw dmlc::Error("Derivative of this intrinsic is not implemented: " + op->name); + } + } + NOT_IMPLEMENTED + } + + PrimExpr VisitExpr_(const AddNode* op) { + return AddNode::make(Mutate(op->a), Mutate(op->b)); + } + + PrimExpr VisitExpr_(const SubNode* op) { + return SubNode::make(Mutate(op->a), Mutate(op->b)); + } + + PrimExpr VisitExpr_(const MulNode* op) { + return AddNode::make( + MulNode::make(Mutate(op->a), op->b), + MulNode::make(op->a, Mutate(op->b))); + } + + PrimExpr VisitExpr_(const DivNode* op) { + return DivNode::make( + SubNode::make( + MulNode::make(Mutate(op->a), op->b), + MulNode::make(op->a, Mutate(op->b))), + MulNode::make(op->b, op->b)); + } + + PrimExpr VisitExpr_(const ModNode* op) NOT_IMPLEMENTED + + PrimExpr VisitExpr_(const FloorDivNode* op) { + return FloorDivNode::make( + SubNode::make( + MulNode::make(Mutate(op->a), op->b), + MulNode::make(op->a, Mutate(op->b))), + MulNode::make(op->b, op->b)); + } + + PrimExpr VisitExpr_(const FloorModNode* op) NOT_IMPLEMENTED + + PrimExpr VisitExpr_(const MinNode* op) { + return SelectNode::make(LENode::make(op->a, op->b), + Mutate(op->a), Mutate(op->b)); + } + + PrimExpr VisitExpr_(const MaxNode* op) { + return SelectNode::make(GENode::make(op->a, op->b), + Mutate(op->a), Mutate(op->b)); + } + + PrimExpr VisitExpr_(const EQNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const NENode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const LTNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const LENode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const GTNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const GENode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const AndNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const OrNode* op) NOT_IMPLEMENTED + + PrimExpr VisitExpr_(const ReduceNode* op) { + // This case is relatively difficult because a reduction expression + // may use an arbitrary combiner. + // The resulting reduction expression will return a tuple containing + // both derivatives and the original results (in exactly this order). + // The order matters when original init value is different from its derivative init value, + // and they depend on each other during gradient calculation, + // we must calculate derivatives first (using origin's init value), + // switching the order (original results first, then derivatives) + // makes the origin value be replaced before using, + // produces incorrect results. + + // Example of a ReduceNode, + // reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f]), + // source=[A(k)], axis=[iter_var(k, range(min=0, ext=5))], where=(bool)1, value_index=0) + + // We have to clone the reduction axes because otherwise the original expression + // cannot be used together with the derivative (it will lead to errors during lowering) + PrimExpr expr_with_new_axes = te::CloneReduction(GetRef(op)); + const ReduceNode* new_op = expr_with_new_axes.as(); + + // New lhs and rhs variables of the new combiner consist of + // variables representing derivatives (which are later derived from new_op->source) + // followed by the original variables. + Array new_lhs; + for (const auto& var : new_op->combiner->lhs) { + new_lhs.push_back(var.copy_with_suffix(".jac")); + } + for (const auto& var : new_op->combiner->lhs) { + new_lhs.push_back(var); + } + + Array new_rhs; + for (const auto& var : new_op->combiner->rhs) { + new_rhs.push_back(var.copy_with_suffix(".jac")); + } + for (const auto& var : new_op->combiner->rhs) { + new_rhs.push_back(var); + } + + // The new combiner result also consists of the resulting derivatives + // followed by the original results. + Array new_result; + for (const auto& res : new_op->combiner->result) { + // Each resulting derivative is computed as a sum of derivatives + // wrt lhs and rhs multiplied by the derivatives of lhs and rhs + PrimExpr new_res = make_zero(res.dtype()); + for (size_t i = 0; i < new_op->combiner->lhs.size(); ++i) { + PrimExpr res_di = Derivative(res, new_op->combiner->lhs[i]); + // new_lhs[i] is the derivative of lhs[i] (wrt our input tensor) + new_res = AddNode::make(new_res, MulNode::make(new_lhs[i], res_di)); + } + for (size_t i = 0; i < new_op->combiner->rhs.size(); ++i) { + PrimExpr res_di = Derivative(res, new_op->combiner->rhs[i]); + // new_rhs[i] is the derivative of rhs[i] (wrt our input tensor) + new_res = AddNode::make(new_res, MulNode::make(new_rhs[i], res_di)); + } + new_result.push_back(new_res); + } + // add original results + for (const auto& res : new_op->combiner->result) { + new_result.push_back(res); + } + + // The identity is transformed in a similar way + Array new_identity; + for (const auto& id : new_op->combiner->identity_element) { + new_identity.push_back(Mutate(id)); + } + for (const auto& id : new_op->combiner->identity_element) { + new_identity.push_back(id); + } + + // Same as source + Array new_source; + for (const auto& src : new_op->source) { + new_source.push_back(Mutate(src)); + } + for (const auto& src : new_op->source) { + new_source.push_back(src); + } + + CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); + // Also simplify the resulting combiner + // (mostly to get rid of unused components, e.g., the original expressions) + return Simplify( + ReduceNode::make(new_combiner, new_source, new_op->axis, + new_op->condition, new_op->value_index)); + } + + PrimExpr VisitExpr_(const CastNode* op) { + if (op->dtype.is_float()) { + return CastNode::make(op->dtype, Mutate(op->value)); + } else { + return make_zero(op->dtype); + } + } + + PrimExpr VisitExpr_(const NotNode* op) NOT_IMPLEMENTED + + PrimExpr VisitExpr_(const SelectNode* op) { + return SelectNode::make(op->condition, + Mutate(op->true_value), Mutate(op->false_value)); + } + + PrimExpr VisitExpr_(const RampNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const BroadcastNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const ShuffleNode* op) NOT_IMPLEMENTED + + PrimExpr VisitExpr_(const IntImmNode* op) { + return IntImm(op->dtype, 0); + } + + PrimExpr VisitExpr_(const FloatImmNode* op) { + return FloatImm(op->dtype, 0); + } + + PrimExpr VisitExpr_(const StringImmNode* op) NOT_IMPLEMENTED + + private: + Tensor input_; + Array indices_; + Var input_var_; +}; + +PrimExpr Derivative(const PrimExpr& expr, const Var& var) { + return JacobianMutator(var).Mutate(expr); +} + +PrimExpr Jacobian(const PrimExpr& expr, const Tensor& input, const Array& indices) { + return JacobianMutator(input, indices).Mutate(expr); +} + +Tensor Jacobian(const Tensor& output, const Tensor& input) { + const ComputeOpNode* op = output->op.as(); + CHECK(op) << "Derivative of this operation is not implemented: " << output->op; + bool is_input_tensor = false; + for (const Tensor& child : op->InputTensors()) { + if (input == child) { + is_input_tensor = true; + break; + } + } + CHECK(is_input_tensor) << "Jacobian is called on a pair of tensors such that the output " + << "does not directly depend on the input."; + + // We have to clone the iteration axes because otherwise the original expression + // cannot be used together with the derivative (it will lead to errors during lowering) + Array new_axis; + Map vmap; + std::tie(new_axis, vmap) = te::CloneIterVars(op->axis); + + Array input_indices; + size_t i = 0; + for (PrimExpr ext : input->shape) { + IterVar new_v = IterVarNode::make(Range(0, ext), Var("jac_i" + std::to_string(i++)), + IterVarType::kDataPar); + // Append jacobian iter to new_axis + new_axis.push_back(new_v); + // Differentiate wrt input[input_indices] + input_indices.push_back(new_v); + } + + // Compute Jacobian + PrimExpr new_body = Jacobian( + Substitute(op->body[output->value_index], vmap), input, input_indices); + new_body = Simplify(new_body); + + int value_index = 0; + Array new_bodies; + + // If this is a reduction then it may return a tuple and we have + // to repeat the body several times + if (const ReduceNode* red = new_body.as()) { + value_index = red->value_index; + for (size_t idx = 0; idx < red->source.size(); ++idx) { + new_bodies.push_back( + ReduceNode::make(red->combiner, red->source, red->axis, red->condition, idx)); + } + } else { + new_bodies.push_back(new_body); + } + + auto new_op = ComputeOpNode::make( + op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies); + + // Jacobian shape = output.shape + input.shape + Array new_shape = output->shape; + for (const auto& e : input->shape) { + new_shape.push_back(e); + } + + return TensorNode::make(new_shape, output->dtype, new_op, value_index); +} + +} // namespace te +} // namespace tvm diff --git a/src/te/operation/compute_op.h b/src/te/operation/compute_op.h index 3e07532c18f4..08db74f0d9a5 100644 --- a/src/te/operation/compute_op.h +++ b/src/te/operation/compute_op.h @@ -24,7 +24,6 @@ #ifndef TVM_TE_OPERATION_COMPUTE_OP_H_ #define TVM_TE_OPERATION_COMPUTE_OP_H_ -#include #include #include #include diff --git a/tests/python/unittest/test_te_autodiff.py b/tests/python/unittest/test_te_autodiff.py new file mode 100644 index 000000000000..c756de050b08 --- /dev/null +++ b/tests/python/unittest/test_te_autodiff.py @@ -0,0 +1,172 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import te +from tvm.testing import check_numerical_grads, assert_allclose +import topi +from topi.util import get_const_tuple + +import numpy as np + + +def check_grad(out, inputs, data_range=(-10, 10), desired_grads=None): + inputs = inputs if isinstance(inputs, list) else [inputs] + + def check_device(device, host="llvm"): + ctx = tvm.context(device, 0) + if not tvm.runtime.enabled(host): + return + if not ctx.exist: + print("skip because %s is not enabled.." % device) + return + + sout = te.create_schedule(out.op) + mout = tvm.build(sout, [out] + inputs) + out_shape = get_const_tuple(out.shape) + + l, h = data_range + input_data = [tvm.nd.array( + np.random.uniform(l, h, size=get_const_tuple(input.shape)).astype(input.dtype)) + for input in inputs] + + ones = topi.full_like(out, 1.0) + # we provide head to sum and reduce the output dimension, + # which equals to grad(out.sum(), inputs) + grads = te.gradient(out, inputs, head=ones) + grad_sched = te.create_schedule([grad.op for grad in grads]) + mgrad = tvm.build(grad_sched, list(grads) + inputs) + # print(tvm.lower(grad_sched, list(grads) + inputs, simple_mode=True)) + + grad_data = [tvm.nd.empty(get_const_tuple(i.shape), g.dtype) + for i, g in zip(inputs, grads)] + + mgrad(*grad_data, *input_data) + g_res = [g.asnumpy() for g in grad_data] + + if desired_grads: + assert isinstance(desired_grads, list) + for actual, desired in zip(g_res, desired_grads): + assert_allclose(actual, desired, rtol=0.1, atol=1e-2) + else: + def forward(*in_data): + out_data = tvm.nd.empty(out_shape, out.dtype) + mout(out_data, *[tvm.nd.array(d) for d in list(in_data)]) + return out_data.asnumpy().sum() + check_numerical_grads(forward, [d.asnumpy() for d in input_data], g_res) + + check_device("cpu") + + +def test_basic_operation(): + np.random.seed(0) + shape = (10, 10) + x = te.var("x", dtype='float32') + k = te.reduce_axis((0, 10), name="k") + l = te.reduce_axis((0, 10), name="l") + A0 = te.placeholder(shape, name='A0') + A1 = te.placeholder(shape, name='A1') + zeros = np.zeros(shape) + + B = te.compute(shape, lambda i, j: A0[i, j], name='B') + check_grad(B, [A0]) + + B = te.compute(shape, lambda i, j: A0[i, j] + A1[i, j], name='B') + check_grad(B, [A0, A1]) + + B = te.compute(shape, lambda i, j: A0[i, j] + A0[j, i], name='B') + check_grad(B, A0) + + B = te.compute(shape, lambda i, j: te.floor(A0[i, j]), name='B') + check_grad(B, A0, desired_grads=[zeros]) + + B = te.compute(shape, lambda i, j: te.ceil(A0[i, j]), name='B') + check_grad(B, A0, desired_grads=[zeros]) + + B = te.compute(shape, lambda i, j: te.trunc(A0[i, j]), name='B') + check_grad(B, A0, desired_grads=[zeros]) + + B = te.compute(shape, lambda i, j: te.round(A0[i, j]), name='B') + check_grad(B, A0, desired_grads=[zeros]) + + B = te.compute(shape, lambda i, j: A0[i, j] + te.exp(A0[j, i]), name='B') + check_grad(B, A0) + + B = te.compute(shape, lambda i, j: te.log(0.1 + te.abs(A0[i, j] + te.exp(A0[j, i]))), name='B') + check_grad(B, A0) + + B = te.compute(shape, lambda i, j: te.sigmoid(A0[i, j]*A0[i, j]*A0[j, i]), name='B') + check_grad(B, A0) + + B = te.compute(shape, lambda i, j: te.tanh(A0[i, j]*A0[i, j]*A0[j, i]), name='B') + check_grad(B, A0) + + B = te.compute(shape, lambda i, j: te.sqrt(A0[i, j]*A0[i, j]*A0[j, i]), name='B') + check_grad(B, A0, data_range=(0.1, 10)) + + B = te.compute(shape, lambda i, j: te.power(te.abs(A0[i, j]), A0[j, i]), name='B') + check_grad(B, A0, data_range=(-4, 4)) + + B = te.compute(shape, lambda i, j: A0[i, j] * A0[j, i], name='B') + check_grad(B, A0) + + B = te.compute((10,), lambda i: te.sum(A0[i, k]*A0[k, i], axis=k), name='B') + check_grad(B, A0) + + B = te.compute(shape, lambda i, j: te.sum(A0[i, k]*A0[k, i] + 5, axis=k), name='B') + check_grad(B, A0) + + B = te.compute(shape, lambda i, j: te.max(A0[i, k]*A0[k, j] + 5, axis=k), name='B') + check_grad(B, A0) + + B = te.compute(shape, lambda i, j: A0[i, j] * (A1[j, i] + A0[j, i]), name='B') + check_grad(B, [A0, A1]) + + B = te.compute(shape, lambda i, j: te.sum(A0[k, k] - + A0[te.min(j + k, 9), j]*A0[i, k], + axis=k), name='B') + check_grad(B, A0) + + def fcombine(x, y): + return x*y + + def fidentity(t0): + return tvm.tir.const(1, t0) + + prod = te.comm_reducer(fcombine, fidentity, name='prod') + B = te.compute((10, 10), lambda i, j: prod(A0[i, k] + A0[k, i], axis=k), name='B') + check_grad(B, A0) + + X = te.placeholder((10,), name='X') + A = te.compute((10,), lambda i: X[i] + X[9 - i]) + B = te.compute((10,), lambda i: X[i] * X[9 - i]) + Y = topi.tensordot(A, B, 1) + check_grad(Y, X) + + +def test_conv2d(): + np.random.seed(0) + X = te.placeholder((1, 2, 4, 4), name='X') + W = te.placeholder((5, 2, 3, 3), name='W') + + R = topi.nn.conv2d(X, W, 1, 1, 1) + check_grad(R, [X, W]) + + +if __name__ == "__main__": + test_basic_operation() + test_conv2d() diff --git a/tests/python/unittest/test_testing.py b/tests/python/unittest/test_testing.py index cfa13845e4e8..ea1680111ee0 100644 --- a/tests/python/unittest/test_testing.py +++ b/tests/python/unittest/test_testing.py @@ -34,6 +34,8 @@ def test_check_numerical_grads(): lambda x: (np.tan(x), 1.0 / (np.cos(x) * np.cos(x))), ] + np.random.seed(0) + # Avoid values too close to 0 since singularities of our functions are there min_x = 0.5