diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h
index fd21db5a9c147..2e0d5d7934289 100644
--- a/include/tvm/relay/attrs/annotation.h
+++ b/include/tvm/relay/attrs/annotation.h
@@ -57,6 +57,19 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
   }
 };
 
+/*!
+ * \brief Options for the operators used to annotate a compiler.
+ */
+struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
+  /*! \brief The 3rd party compiler for code generation. */
+  std::string compiler;
+
+  TVM_DECLARE_ATTRS(CompilerAttrs, "relay.attrs.CompilerAttrs") {
+    TVM_ATTR_FIELD(compiler)
+        .describe("The 3rd party compiler used for code generation.");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_ANNOTATION_H_
diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h
index 741e8b4788286..fe7a7a0418f6d 100644
--- a/include/tvm/relay/op_attr_types.h
+++ b/include/tvm/relay/op_attr_types.h
@@ -29,6 +29,7 @@
 #include
 #include
 #include
+#include <string>
 
 namespace tvm {
 namespace relay {
@@ -122,7 +123,7 @@ using FTVMSchedule = runtime::TypedPackedFunc<
  * operator with other expressions. This function will be invoked
  * in AlterOpLayout pass.
  * \param attrs The attribute of the original node.
- * \param inputs The input symbols of the original node.
+ * \param args The input symbols of the original node.
  * \param tinfos An array of placeholders, use for getting the inferred shape
  *        and dtype of the inputs.
 * \return new_expr The modified expression.
@@ -136,8 +137,8 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc<
 * \brief Legalizes an expression with another expression. This function will be
 * invoked in Legalize pass. It is a target-dependent pass.
 * \param attrs The attribute of the original node.
- * \param inputs The input symbols of the original node.
- * \param tinfos An array of placeholders, use for getting the inferred shape
+ * \param args The input symbols of the original node.
+ * \param arg_types An array of placeholders, used for getting the inferred shape
 *        and dtype of the inputs.
 * \return new_expr The modified expression.
 */
@@ -146,6 +147,22 @@ using FTVMLegalize = runtime::TypedPackedFunc<
     const Array<Expr>& args,
     const Array<Type>& arg_types)>;
 
+/*!
+ * \brief Annotates an expression to indicate which compiler should be
+ * used for codegen of an op.
+ *
+ * \param attrs The attributes of the original expr.
+ * \param args The arguments of the original expr.
+ * \param compiler The compiler that is used to compile the op.
+ *
+ * \return true if this op should be registered to invoke a specific compiler
+ * for codegen, otherwise, false.
+ */
+using FTVMAnnotateCompiler = runtime::TypedPackedFunc<
+    bool(const Attrs& attrs,  // NOLINT(*)
+         const Array<Expr>& args,
+         const std::string& compiler)>;
+
 /*!
  * \brief Forward rewriting rule for a specific op.
  *
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index ddadbe4fc31db..92eb99f2cd94f 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -576,6 +576,14 @@ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
  */
 TVM_DLL Pass PrintIR(bool show_meta_data = true);
 
+/*!
+ * \brief Partition a Relay program into regions that can be executed on
+ * different backends.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass PartitionGraph();
+
 }  // namespace transform
 
 /*!
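The FTVMAnnotateCompiler hook above is what the Python-side register_annotate_compiler helper (added below in python/tvm/relay/op/op.py) fills in per operator. A minimal sketch of a direct registration, assuming a hypothetical external backend named "mycompiler" that is not part of this patch:

    from tvm.relay.op import register_annotate_compiler

    @register_annotate_compiler("nn.conv2d")
    def conv2d_annotator(attrs, args, compiler):
        # Offload only non-grouped conv2d to the hypothetical backend;
        # everything else stays on the default TVM codegen path.
        return compiler == "mycompiler" and int(attrs.groups) == 1
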
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index c7cbcf096a6cf..7901dc4f5074a 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -29,7 +29,7 @@
 from . import adt
 from . import analysis
 from . import transform
-from .build_module import build, create_executor, optimize
+from .build_module import build, create_executor, optimize, build_extern_compiler
 from .transform import build_config
 from . import prelude
 from . import parser
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 28ce16b9b4523..7775b7ca4c21a 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -30,6 +30,7 @@
 from .module import Module as _Module
 from .backend import interpreter as _interpreter
 from .backend.vm import VMExecutor
+from . import transform as _transform
 
 def _update_target(target):
     target = target if target else _target.current_target()
@@ -296,6 +297,34 @@ def optimize(mod, target=None, params=None):
     return mod, params
 
 
+def build_extern_compiler(mod, compiler):
+    """Helper function that annotates a Relay module and partitions the
+    expressions in it into various regions. These regions will be handled
+    by either the default compilers in the TVM stack or the provided
+    external compiler.
+
+    Parameters
+    ----------
+    mod : relay.Module
+        The module to build. Using relay.Function is deprecated.
+
+    compiler : str
+        The name of the external compiler.
+
+    Returns
+    -------
+    mod : relay.Module
+        The Relay module containing partitioned program regions (e.g. functions)
+        that will be compiled using different compilers.
+    """
+    if isinstance(mod, _expr.Function):
+        mod = _Module.from_expr(mod)
+
+    seq = _transform.Sequential([_transform.AnnotateCompiler(compiler),
+                                 _transform.PartitionGraph()])
+    mod = seq(mod)
+    return mod
+
+
 class GraphExecutor(_interpreter.Executor):
     """Wrapper around Executor interface.
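For orientation, a minimal usage sketch of build_extern_compiler, relying only on the "ccompiler" backend added later in this patch:

    from tvm import relay

    x = relay.var("x", shape=(8, 8))
    y = relay.var("y", shape=(8, 8))
    mod = relay.Module.from_expr(relay.Function([x, y], x + y))

    # Annotate and partition in one call; ops whose annotators return True
    # for "ccompiler" are lifted into external functions for that codegen.
    mod = relay.build_extern_compiler(mod, "ccompiler")
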
+ """ + return _make.compiler_begin(data, compiler) + + +def compiler_end(data, compiler): + """Annotate an expression to indicate that it is the end of a region that + is handled by the provided compiler. + + Parameters + ---------- + data : tvm.relay.Expr + The expression to be annotated. + + compiler : Str + The compiler used to generate code of the annotated region. + + Returns + ------- + result : tvm.relay.Expr + The annotated expression. + """ + return _make.compiler_end(data, compiler) diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index 3159006486b33..b7d6d92b9edd7 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -18,4 +18,5 @@ """Neural network related operators.""" from __future__ import absolute_import as _abs from .contrib import * +from .annotate_compiler import * from . import _contrib diff --git a/python/tvm/relay/op/contrib/annotate_compiler.py b/python/tvm/relay/op/contrib/annotate_compiler.py new file mode 100644 index 0000000000000..4d1eeaeb01cfe --- /dev/null +++ b/python/tvm/relay/op/contrib/annotate_compiler.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +""" +External compiler related feature registration. + +It implements dispatchers that check if an operator should use a given compiler +to generate code. + +Each compiler can customize the support of an operator. For example, they can +check the attribute of the operator and/or the features of the input arguments +to decide if we should use the compiler for codegen. +""" +from __future__ import absolute_import + +import logging +import pkgutil +from pathlib import Path +from importlib import import_module + +from .. import op as reg + +logger = logging.getLogger('AnnotateCompiler') + +# Load available contrib compilers +compilers = {} +for _, name, _ in pkgutil.iter_modules([Path(__file__).parent]): + compilers[name] = import_module( + '.%s' % name, package='.'.join(__name__.split('.')[:-1])) + + +def get_annotate_compiler(compiler, op_name): + """Get the annotate_compiler function from the registered compilers. + + Parameters + ---------- + compiler : Str + The name of a compiler that is used to generate code. + + op_name : Str + The name of an operator. + + Returns + ------- + ret : bool + If the operator uses the provided compiler for codegen. + """ + if compiler in compilers: + if hasattr(compilers[compiler], 'annotate_compiler'): + annotate_compiler = getattr(compilers[compiler], 'annotate_compiler') + if hasattr(annotate_compiler, op_name): + return getattr(annotate_compiler, op_name) + + logger.warning("%s in %s is not registered. 
Fallback to CPU", op_name, + compiler) + return lambda x, y: False + + +@reg.register_annotate_compiler("nn.conv2d") +def annotate_conv2d(attrs, args, compiler): + """Check if the provided compiler should be used for conv2d. + """ + return get_annotate_compiler(compiler, 'conv2d')(attrs, args) + + +@reg.register_annotate_compiler("nn.dense") +def annotate_dense(attrs, args, compiler): + """Check if the provided compiler should be used for dense. + """ + return get_annotate_compiler(compiler, 'dense')(attrs, args) + + +@reg.register_annotate_compiler("nn.relu") +def annotate_relu(attrs, args, compiler): + """Check if the provided compiler should be used for relu. + """ + return get_annotate_compiler(compiler, 'relu')(attrs, args) + + +@reg.register_annotate_compiler("nn.batch_norm") +def annotate_batch_norm(attrs, args, compiler): + """Check if the provided compiler should be used for batch_norm. + """ + return get_annotate_compiler(compiler, 'batch_norm')(attrs, args) + + +@reg.register_annotate_compiler("subtract") +def annotate_subtract(attrs, args, compiler): + """Check if the provided compiler should be used for subtract. + """ + return get_annotate_compiler(compiler, 'subtract')(attrs, args) + + +@reg.register_annotate_compiler("add") +def annotate_add(attrs, args, compiler): + """Check if the provided compiler should be used for add. + """ + return get_annotate_compiler(compiler, 'add')(attrs, args) + + +@reg.register_annotate_compiler("multiply") +def annotate_multiply(attrs, args, compiler): + """Check if the provided compiler should be used for multiply. + """ + return get_annotate_compiler(compiler, 'multiply')(attrs, args) diff --git a/python/tvm/relay/op/contrib/ccompiler/__init__.py b/python/tvm/relay/op/contrib/ccompiler/__init__.py new file mode 100644 index 0000000000000..57bf0595556ae --- /dev/null +++ b/python/tvm/relay/op/contrib/ccompiler/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""Neural network related operators.""" +from __future__ import absolute_import as _abs +from .annotate_compiler import * diff --git a/python/tvm/relay/op/contrib/ccompiler/annotate_compiler.py b/python/tvm/relay/op/contrib/ccompiler/annotate_compiler.py new file mode 100644 index 0000000000000..7a5f88f1fc75c --- /dev/null +++ b/python/tvm/relay/op/contrib/ccompiler/annotate_compiler.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
diff --git a/python/tvm/relay/op/contrib/ccompiler/__init__.py b/python/tvm/relay/op/contrib/ccompiler/__init__.py
new file mode 100644
index 0000000000000..57bf0595556ae
--- /dev/null
+++ b/python/tvm/relay/op/contrib/ccompiler/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import
+"""C/C++ compiler supported operators."""
+from __future__ import absolute_import as _abs
+from .annotate_compiler import *
diff --git a/python/tvm/relay/op/contrib/ccompiler/annotate_compiler.py b/python/tvm/relay/op/contrib/ccompiler/annotate_compiler.py
new file mode 100644
index 0000000000000..7a5f88f1fc75c
--- /dev/null
+++ b/python/tvm/relay/op/contrib/ccompiler/annotate_compiler.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""C/C++ compiler supported operators."""
+from __future__ import absolute_import
+
+def conv2d(attrs, args):
+    """Check if the external C source codegen should be used.
+    """
+    return False
+
+def subtract(attrs, args):
+    """Check if the external C source codegen should be used.
+    """
+    return True
+
+def add(attrs, args):
+    """Check if the external C source codegen should be used.
+    """
+    return True
+
+def multiply(attrs, args):
+    """Check if the external C source codegen should be used.
+    """
+    return True
diff --git a/python/tvm/relay/op/contrib/dnnl/__init__.py b/python/tvm/relay/op/contrib/dnnl/__init__.py
new file mode 100644
index 0000000000000..57bf0595556ae
--- /dev/null
+++ b/python/tvm/relay/op/contrib/dnnl/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import
+"""DNNL library supported operators."""
+from __future__ import absolute_import as _abs
+from .annotate_compiler import *
+ """ + return True + + +def dense(attrs, args): + """Check if the external DNNL codegen should be used. + """ + return True + + +def relu(attrs, args): + """Check if the external DNNL codegen should be used. + """ + return True + + +def batch_norm(attrs, args): + """Check if the external DNNL codegen should be used. + FIXME(@zhiics, @comaniac): Turn off due to not support of multiple outputs. + """ + return False + + +def add(attrs, args): + """Check if the external DNNL codegen should be used. + """ + return True diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 355496e42b489..98e44c1f061ee 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -229,6 +229,7 @@ def register_pattern(op_name, pattern, level=10): """ return register(op_name, "TOpPattern", pattern, level) + def register_gradient(op_name, fgradient=None, level=10): """Register operator pattern for an op. @@ -266,6 +267,25 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10): get(op_name).set_attr("TShapeDataDependant", data_dependant, level) return register(op_name, "FShapeFunc", shape_func, level) +def register_annotate_compiler(op_name, fannotate=None, level=10): + """Register the compiler for an op. + + Parameters + ---------- + op_name : str + The name of the operator. + + fannotate : function (attrs: Attrs, args: List[Expr], compiler: str) + -> new_expr: Expr + The function for wrapping a call expr with compiler_begin and + compiler_end. + + level : int + The priority level + """ + return register(op_name, "FTVMAnnotateCompiler", fannotate, level) + + _init_api("relay.op", __name__) @register_func("relay.op.compiler._lower") diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 540c1f5b79cd9..6fc480151876d 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -480,6 +480,24 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"): return _transform.Legalize(legalize_map_attr_name) +def AnnotateCompiler(compiler): + """Annotate ops in an experession with a provied compiler and then use it + for codegen. + + Parameters + ---------- + compiler : str + The compiler used for codegen. + + Returns + ------- + ret : tvm.relay.Pass + The annotated pass that wrapps ops with subgraph_start and + subgraph_end. + """ + return _transform.AnnotateCompiler(compiler) + + def RewriteAnnotatedOps(fallback_device): """Rewrite the annotated program where annotation operators, e.g. `on_deivce`, mark which device an expression should be scheduled to. @@ -635,6 +653,18 @@ def PrintIR(show_meta_data=True): return _transform.PrintIR(show_meta_data) +def PartitionGraph(): + """Partition a Relay program into regions that can be executed on different + backends. + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that partitions the Relay program. + """ + return _transform.PartitionGraph() + + def gradient(expr, mod=None, mode='higher_order'): """ Transform the input function, diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index f5674fa06adb0..7bdd2619529d9 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -171,5 +171,55 @@ Mark a checkpoint for checkpointing memory optimization. return outputs; }); +RELAY_REGISTER_OP("annotation.compiler_begin") +.describe(R"code( +Beginning of a region that is handled by a given compiler. 
+)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_support_level(10) +.add_type_rel("Identity", IdentityRel) +.set_attr("TOpPattern", kOpaque) +.set_attr("TOpIsStateful", false) +.set_attr("FInferCorrectLayout", + ElemwiseArbitraryLayout) +.set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype, const Target& target) -> Array { + return {topi::identity(inputs[0])}; + }); + +TVM_REGISTER_API("relay.op.annotation._make.compiler_begin") +.set_body_typed([](Expr expr, std::string compiler) { + auto attrs = make_node(); + attrs->compiler = compiler; + static const Op& op = Op::Get("annotation.compiler_begin"); + return CallNode::make(op, {expr}, Attrs(attrs), {}); +}); + +RELAY_REGISTER_OP("annotation.compiler_end") +.describe(R"code( +End of a region that is handled by a given compiler. +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_support_level(10) +.add_type_rel("Identity", IdentityRel) +.set_attr("TOpPattern", kOpaque) +.set_attr("TOpIsStateful", false) +.set_attr("FInferCorrectLayout", + ElemwiseArbitraryLayout) +.set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype, const Target& target) -> Array { + return {topi::identity(inputs[0])}; + }); + +TVM_REGISTER_API("relay.op.annotation._make.compiler_end") +.set_body_typed([](Expr expr, std::string compiler) { + auto attrs = make_node(); + attrs->compiler = compiler; + static const Op& op = Op::Get("annotation.compiler_end"); + return CallNode::make(op, {expr}, Attrs(attrs), {}); +}); + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/annotate_compiler.cc b/src/relay/pass/annotate_compiler.cc new file mode 100644 index 0000000000000..9cbdde3e361f8 --- /dev/null +++ b/src/relay/pass/annotate_compiler.cc @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/pass/annotate_compiler.cc + * \brief Wraps a call with compiler_begin and compiler_end to indicate that + * the op of this call node will use external compiler. + */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace annotate_compiler { + +// A helper class to insert annotation boundaries for a program region that will +// be handled by a specific compiler. 
diff --git a/src/relay/pass/annotate_compiler.cc b/src/relay/pass/annotate_compiler.cc
new file mode 100644
index 0000000000000..9cbdde3e361f8
--- /dev/null
+++ b/src/relay/pass/annotate_compiler.cc
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/pass/annotate_compiler.cc
+ * \brief Wraps a call with compiler_begin and compiler_end to indicate that
+ * the op of this call node will use an external compiler.
+ */
+
+#include
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace relay {
+namespace annotate_compiler {
+
+// A helper class to insert annotation boundaries for a program region that will
+// be handled by a specific compiler.
+class AnnotateCompilerWrapper : public ExprMutator {
+ public:
+  explicit AnnotateCompilerWrapper(const std::string& compiler) : compiler_(compiler) {}
+
+  Expr VisitExpr_(const CallNode* cn) {
+    auto new_e = ExprMutator::VisitExpr_(cn);
+
+    Call call = Downcast<Call>(new_e);
+    static auto fannotate = Op::GetAttr<FTVMAnnotateCompiler>("FTVMAnnotateCompiler");
+    Op op = Downcast<Op>(call->op);
+    CHECK(op.operator->());
+
+    if (fannotate.count(op)) {
+      bool external = fannotate[op](call->attrs, call->args, compiler_);
+      if (external) {
+        tvm::Array<Expr> compiler_begins;
+        for (const auto& it : call->args) {
+          const auto* begin_op =
+            runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
+          CHECK(begin_op);
+          Expr begin = (*begin_op)(it, compiler_);
+          compiler_begins.push_back(begin);
+        }
+        Expr update_call = CallNode::make(call->op, compiler_begins, call->attrs);
+        const auto* end_op =
+          runtime::Registry::Get("relay.op.annotation._make.compiler_end");
+        CHECK(end_op);
+        Expr end = (*end_op)(update_call, compiler_);
+        return end;
+      }
+    } else {
+      LOG(WARNING) << op.operator->()->name << " in " << compiler_ << " is not registered";
+    }
+    return new_e;
+  }
+
+ private:
+  std::string compiler_;
+};
+
+Expr AnnotateCompiler(const Expr& expr, const std::string& compiler) {
+  return AnnotateCompilerWrapper(compiler).Mutate(expr);
+}
+
+}  // namespace annotate_compiler
+
+namespace transform {
+
+Pass AnnotateCompiler(const std::string& compiler) {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(relay::annotate_compiler::AnnotateCompiler(f, compiler));
+    };
+  auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateCompilerFunc",
+                                      {ir::StringImm::make("InferType")});
+  return transform::Sequential({func_pass, InferType()}, "AnnotateCompiler");
+}
+
+TVM_REGISTER_API("relay._transform.AnnotateCompiler")
+.set_body_typed(AnnotateCompiler);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/pass/partition_graph.cc b/src/relay/pass/partition_graph.cc
new file mode 100644
index 0000000000000..c4ad797bad262
--- /dev/null
+++ b/src/relay/pass/partition_graph.cc
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+ * \file src/relay/pass/partition_graph.cc
+ *
+ * \brief Partition an input function into multiple functions according to
+ * the inserted annotation nodes (i.e. compiler_begin and compiler_end).
+ * These nodes are used as boundaries to partition the Relay function into
+ * multiple regions that can be offloaded to different accelerators/backends.
+ *
+ * Each of these partitioned functions, a.k.a. subgraphs, will be viewed as
+ * external functions, and they will use the provided compiler for codegen.
+ */
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace relay {
+namespace partitioning {
+
+/*!
+ * \brief The subgraph properties for partitioning.
+ */
+struct Subgraph {
+  /*! \brief The subgraph ID. */
+  int id;
+
+  /*! \brief The input arguments of this subgraph. */
+  std::vector<std::pair<Var, Expr>> args;
+
+  /*! \brief Nodes in this subgraph. */
+  std::unordered_set<Expr, NodeHash, NodeEqual> nodes;
+};
+
+/*!
+ * \brief The checker that verifies if a Relay program is annotated correctly
+ * for partitioning.
+ */
+class AnnotationChecker : public ExprVisitor {
+ public:
+  bool Check() {
+    if (!this->found_start && !this->found_end) {
+      LOG(WARNING) << "No compiler annotation found";
+    } else if (!this->found_start) {
+      LOG(ERROR) << "compiler_begin annotation is missing";
+      return false;
+    } else if (!this->found_end) {
+      LOG(ERROR) << "compiler_end annotation is missing";
+      return false;
+    }
+    return true;
+  }
+
+  void VisitExpr_(const CallNode* call) final {
+    auto op_node = call->op.as<OpNode>();
+    if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
+      return;
+    } else if (GetRef<Op>(op_node) == Op::Get("annotation.compiler_begin")) {
+      this->found_start = true;
+    } else if (GetRef<Op>(op_node) == Op::Get("annotation.compiler_end")) {
+      this->found_end = true;
+    }
+  }
+
+ private:
+  bool found_start = false;
+  bool found_end = false;
+};
+
+/*! \brief This class partitions the expr labeled with begin and end annotations
+ * into a function containing multiple regions. Each region is labeled with
+ * a compiler attribute so that it will be handled by a compiler that is not
+ * in the TVM stack.
+ *
+ * TODO(@zhiics) The following algorithm is not adequate to handle all cases,
+ * i.e. multiple `compiler_end` nodes.
+ */
+class Partitioner : public ExprMutator {
+ public:
+  Subgraph* GetSubgraph(const Expr node) {
+    for (auto candidate : this->subgraphs_) {
+      if (candidate->nodes.find(node) != candidate->nodes.end()) {
+        return candidate;
+      }
+    }
+    return nullptr;
+  }
+
+  void MergeSubgraph(Subgraph* subgraph1, Subgraph* subgraph2) {
+    if (subgraph1 == subgraph2) {
+      return;
+    }
+
+    // Merge subgraph 2 into subgraph 1 and erase subgraph 2.
+    subgraph1->nodes.insert(subgraph2->nodes.begin(), subgraph2->nodes.end());
+    for (auto arg : subgraph2->args) {
+      subgraph1->args.push_back(arg);
+    }
+    this->subgraphs_.erase(subgraph2);
+  }
+
+  void AddToSubgraph(Subgraph* subgraph, const Expr expr) {
+    auto subgraph2 = GetSubgraph(expr);
+    if (subgraph2) {
+      MergeSubgraph(subgraph, subgraph2);
+    } else {
+      subgraph->nodes.insert(expr);
+    }
+  }
+
+  Expr VisitExpr_(const CallNode* call) final {
+    auto op_node = call->op.as<OpNode>();
+
+    if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
+      // Propagate the subgraph to the arguments.
+      auto subgraph = GetSubgraph(GetRef<Call>(call));
+      if (subgraph) {
+        for (auto arg : call->args) {
+          AddToSubgraph(subgraph, arg);
+        }
+      }
+      return ExprMutator::VisitExpr_(call);
+    } else if (GetRef<Op>(op_node) == Op::Get("annotation.compiler_begin")) {
+      // The annotation node is inserted on an edge so it must have only one argument.
+      CHECK_EQ(call->args.size(), 1U);
+
+      // Traverse the rest of the graph.
+      auto input_expr = VisitExpr(call->args[0]);
+
+      // Replace the begin annotation with an external call input variable.
+      auto compiler_attrs = call->attrs.as<CompilerAttrs>();
+      auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++),
+                               input_expr->checked_type_);
+
+      // Find the corresponding subgraph and add the argument.
+      auto subgraph = GetSubgraph(GetRef<Call>(call));
+      if (!subgraph) {
+        throw Error(RELAY_ERROR("Cannot find the corresponding subgraph for start annotation:\n"
+                                << AsText(GetRef<Call>(call), false)));
+      }
+      subgraph->args.push_back({var, input_expr});
+      return std::move(var);
+    } else {
+      CHECK(GetRef<Op>(op_node) == Op::Get("annotation.compiler_end"));
+      // The annotation node is inserted on an edge so it must have only one argument.
+      CHECK_EQ(call->args.size(), 1U);
+
+      auto compiler_attrs = call->attrs.as<CompilerAttrs>();
+
+      // Check if the argument already belongs to an existing subgraph.
+      auto subgraph = GetSubgraph(call->args[0]);
+      if (!subgraph) {
+        auto ret = this->subgraphs_.emplace(new Subgraph());
+        subgraph = *ret.first;
+        subgraph->nodes.insert(call->args[0]);
+        subgraph->id = this->subgraph_id_++;
+      }
+      subgraph->nodes.insert(GetRef<Call>(call));
+
+      // Traverse towards the subgraph inputs.
+      auto input = VisitExpr(call->args[0]);
+      Array<Var> params;
+      Array<Expr> args;
+
+      // The subgraph may be merged so we need to update it again.
+      subgraph = GetSubgraph(GetRef<Call>(call));
+      for (auto pair : subgraph->args) {
+        params.push_back(pair.first);
+        args.push_back(pair.second);
+      }
+
+      auto subgraph_func =
+        FunctionNode::make(params, input, call->args[0]->checked_type_, {}, Attrs());
+
+      Expr arg0 = call->args[0];
+      std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id);
+      subgraph_func =
+        FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tvm::ir::StringImm::make(name));
+      subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1));
+      subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler,
+                                      tvm::ir::StringImm::make(compiler_attrs->compiler));
+      return CallNode::make(subgraph_func, args);
+    }
+  }
+
+  Expr VisitExpr_(const TupleNode* op) final {
+    auto subgraph = GetSubgraph(GetRef<Tuple>(op));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(op);
+    } else {
+      for (auto field : op->fields) {
+        AddToSubgraph(subgraph, field);
+      }
+      Array<Expr> fields;
+      for (auto field : op->fields) {
+        fields.push_back(VisitExpr(field));
+      }
+      return TupleNode::make(fields);
+    }
+  }
+
+  Expr VisitExpr_(const TupleGetItemNode* g) final {
+    auto subgraph = GetSubgraph(GetRef<TupleGetItem>(g));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(g);
+    } else {
+      AddToSubgraph(subgraph, g->tuple);
+      auto t = VisitExpr(g->tuple);
+      return TupleGetItemNode::make(t, g->index);
+    }
+  }
+
+  Expr VisitExpr_(const FunctionNode* op) final {
+    auto subgraph = GetSubgraph(GetRef<Function>(op));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(op);
+    } else {
+      Array<Var> params;
+      for (auto param : op->params) {
+        AddToSubgraph(subgraph, param);
+      }
+      for (auto param : op->params) {
+        Var new_param = Downcast<Var>(VisitExpr(param));
+        params.push_back(new_param);
+      }
+      auto body = VisitExpr(op->body);
+      return FunctionNode::make(params, body, op->ret_type, op->type_params, op->attrs);
+    }
+  }
+
+  Expr VisitExpr_(const LetNode* op) final {
+    auto subgraph = GetSubgraph(GetRef<Let>(op));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(op);
+    } else {
+      AddToSubgraph(subgraph, op->var);
+      AddToSubgraph(subgraph, op->value);
+      AddToSubgraph(subgraph, op->body);
+      Var var = Downcast<Var>(VisitExpr(op->var));
+      auto value = VisitExpr(op->value);
+      auto body = VisitExpr(op->body);
+
+      return LetNode::make(var, value, body);
+    }
+  }
+  Expr VisitExpr_(const IfNode* op) final {
+    auto subgraph = GetSubgraph(GetRef<If>(op));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(op);
+    } else {
+      AddToSubgraph(subgraph, op->cond);
+      AddToSubgraph(subgraph, op->true_branch);
+      AddToSubgraph(subgraph, op->false_branch);
+      auto guard = VisitExpr(op->cond);
+      auto true_b = VisitExpr(op->true_branch);
+      auto false_b = VisitExpr(op->false_branch);
+      return IfNode::make(guard, true_b, false_b);
+    }
+  }
+
+  Expr VisitExpr_(const RefCreateNode* op) final {
+    auto subgraph = GetSubgraph(GetRef<RefCreate>(op));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(op);
+    } else {
+      AddToSubgraph(subgraph, op->value);
+      Expr value = VisitExpr(op->value);
+      return RefCreateNode::make(value);
+    }
+  }
+
+  Expr VisitExpr_(const RefReadNode* op) final {
+    auto subgraph = GetSubgraph(GetRef<RefRead>(op));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(op);
+    } else {
+      AddToSubgraph(subgraph, op->ref);
+      Expr ref = VisitExpr(op->ref);
+      return RefReadNode::make(ref);
+    }
+  }
+
+  Expr VisitExpr_(const RefWriteNode* op) final {
+    auto subgraph = GetSubgraph(GetRef<RefWrite>(op));
+    if (!subgraph) {
+      return ExprMutator::VisitExpr_(op);
+    } else {
+      AddToSubgraph(subgraph, op->ref);
+      Expr ref = VisitExpr(op->ref);
+      Expr value = VisitExpr(op->value);
+      return RefWriteNode::make(ref, value);
+    }
+  }
+
+ private:
+  int var_id_{0};
+  int subgraph_id_{0};
+  std::unordered_set<Subgraph*> subgraphs_;
+};
+
+/*!
+ * \brief TODO(@zhiics, @comaniac) Combine parallel regions that belong to
+ * the same codegen backend. This reduces round trips between TVM and external
+ * backends. Likely we can borrow some ideas from operator fusion.
+ *
+ * For example, sg1 and sg2 should be combined if they belong to the same
+ * codegen tool in the following case.
+ *
+ *      op1
+ *     /   \
+ *   sg1   sg2
+ *
+ *      |
+ *     \|/
+ *
+ *      op1
+ *       |
+ *    sg1_sg2
+ *
+ * where the return type of the new subgraph sg1_sg2 is a tuple, and op1 has two
+ * inputs that are obtained from the tuple.
+ */
+
+Expr PartitionGraph(const Expr& expr) {
+  Partitioner part;
+  return part.Mutate(expr);
+}
+
+}  // namespace partitioning
+
+namespace transform {
+
+Pass PartitionGraph() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> part_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(partitioning::PartitionGraph(f));
+    };
+  auto partitioned = CreateFunctionPass(part_func, 0, "PartitionGraph", {});
+  return Sequential({partitioned, InferType()});
+}
+
+TVM_REGISTER_API("relay._transform.PartitionGraph")
+.set_body_typed(transform::PartitionGraph);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
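To make the result concrete, this is roughly the shape of the module the partitioner produces for the single-op example annotated earlier (a sketch, not verbatim output; the variable and symbol names follow the naming scheme in the code above, e.g. ccompiler_input0 and ccompiler_0):

    def @main(%x: Tensor[(8, 8), float32], %y: Tensor[(8, 8), float32]) {
      %f = fn (%ccompiler_input0, %ccompiler_input1,
               Primitive=1, Compiler="ccompiler", ExternalSymbol="ccompiler_0") {
        add(%ccompiler_input0, %ccompiler_input1)
      };
      %f(%x, %y)
    }
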
+"""Unit tests for graph partitioning.""" +import os +import sys +import numpy as np +import pytest + +import tvm +import tvm.relay.testing +import tvm.relay.transform +from tvm import relay +from tvm.contrib import util +from tvm.relay.annotation import compiler_begin, compiler_end +from tvm.relay.expr_functor import ExprMutator + + +class CcompilerAnnotator(ExprMutator): + """ + A simple annotator that creates the following program: + | + -- begin -- + | + add + | + subtract + | + multiply + | + -- end -- + | + """ + + def __init__(self): + super(CcompilerAnnotator, self).__init__() + self.in_compiler = 0 + + def visit_call(self, call): + if call.op.name == "add": # Annotate begin at args + if self.in_compiler == 1: + lhs = compiler_begin(super().visit(call.args[0]), "ccompiler") + rhs = compiler_begin(super().visit(call.args[1]), "ccompiler") + op = relay.add(lhs, rhs) + self.in_compiler = 2 + return op + elif call.op.name == "subtract": + if self.in_compiler == 1: + lhs = super().visit(call.args[0]) + rhs = super().visit(call.args[1]) + if isinstance(lhs, relay.expr.Var): + lhs = compiler_begin(lhs, "ccompiler") + if isinstance(rhs, relay.expr.Var): + rhs = compiler_begin(rhs, "ccompiler") + return relay.subtract(lhs, rhs) + elif call.op.name == "multiply": # Annotate end at output + self.in_compiler = 1 + lhs = super().visit(call.args[0]) + rhs = super().visit(call.args[1]) + if isinstance(lhs, relay.expr.Var): + lhs = compiler_begin(lhs, "ccompiler") + if isinstance(rhs, relay.expr.Var): + rhs = compiler_begin(rhs, "ccompiler") + op = relay.multiply(lhs, rhs) + if self.in_compiler == 2: + op = compiler_end(op, "ccompiler") + self.in_compiler = 0 + return op + return super().visit_call(call) + + +class WholeGraphAnnotator(ExprMutator): + """ + An annotator that creates a compiler for an entire graph. + """ + + def __init__(self, compiler): + super(WholeGraphAnnotator, self).__init__() + self.compiler = compiler + self.last_call = True + + def visit_call(self, call): + curr_last = self.last_call + self.last_call = False + + params = [] + for arg in call.args: + param = super().visit(arg) + if isinstance(param, relay.expr.Var): + param = compiler_begin(param, self.compiler) + params.append(param) + + new_call = relay.Call(call.op, params, call.attrs) + if curr_last: + new_call = compiler_end(new_call, self.compiler) + return new_call + + +class MobileNetAnnotator(ExprMutator): + """ + Annotate mobilenet until global_avg_pool. 
+ """ + + def __init__(self, compiler): + super(MobileNetAnnotator, self).__init__() + self.compiler = compiler + self.compiler_open = False + + def visit_call(self, call): + + if call.op.name == 'nn.global_avg_pool2d': + self.compiler_open = True + compiler_open = self.compiler_open + + params = [] + for arg in call.args: + param = super().visit(arg) + if call.op.name == 'nn.global_avg_pool2d': + param = compiler_end(param, self.compiler) + if compiler_open and isinstance(param, relay.expr.Var): + param = compiler_begin(param, self.compiler) + params.append(param) + + new_call = relay.Call(call.op, params, call.attrs) + return new_call + + +def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", + ctx=tvm.cpu(), params=None): + if sys.platform == "win32": + print("Skip test on Windows for now") + return + + def update_lib(lib): + test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) + source_dir = os.path.join(test_dir, "..", "..", "..") + contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") + + kwargs = {} + kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] + tmp_path = util.tempdir() + lib_name = 'lib.so' + lib_path = tmp_path.relpath(lib_name) + lib.export_library(lib_path, fcompile=False, **kwargs) + lib = tvm.module.load(lib_path) + + return lib + + def check_vm_result(): + with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + exe = relay.vm.compile(mod, target=target, params=params) + code, lib = exe.save() + lib = update_lib(lib) + exe = relay.vm.Executable.load_exec(code, lib) + vm = relay.vm.VirtualMachine(exe) + vm.init(ctx) + out = vm.run(**map_inputs) + tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) + + def check_graph_runtime_result(): + with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + json, lib, param = relay.build(mod, target=target, params=params) + lib = update_lib(lib) + rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx) + + for name, data in map_inputs.items(): + rt_mod.set_input(name, data) + rt_mod.set_input(**param) + rt_mod.run() + out = tvm.nd.empty(out_shape, ctx=ctx) + out = rt_mod.get_output(0, out) + + tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) + + check_vm_result() + check_graph_runtime_result() + + +def test_multi_node_compiler(): + x = relay.var('x', shape=(10, 10)) + w0 = relay.var('w0', shape=(10, 10)) + w1 = relay.var('w1', shape=(10, 10)) + w2 = relay.var('w2', shape=(10, 10)) + w3 = relay.var('w3', shape=(10, 10)) + w4 = relay.var('w4', shape=(10, 10)) + w5 = relay.var('w5', shape=(10, 10)) + w6 = relay.var('w6', shape=(10, 10)) + w7 = relay.var('w7', shape=(10, 10)) + + # C compiler + # FIXME: We generate two compilers for this case but they should be merged to one + # due to the common input (x). 
+    z0 = relay.add(x, w0)
+    p0 = relay.subtract(z0, w1)
+    q0 = relay.multiply(p0, w2)
+
+    z1 = relay.add(x, w3)
+    p1 = relay.subtract(z1, w4)
+    q1 = relay.multiply(p1, w5)
+
+    # Other parts on TVM
+    z2 = relay.add(x, w6)
+    q2 = relay.subtract(z2, w7)
+
+    r = relay.concatenate((q0, q1, q2), axis=0)
+    f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r)
+    mod = relay.Module()
+    ann = CcompilerAnnotator()
+    mod["main"] = ann.visit(f)
+    mod = relay.transform.PartitionGraph()(mod)
+    mod = relay.transform.InferType()(mod)
+
+    x_data = np.random.rand(10, 10).astype('float32')
+    w_data = []
+    for _ in range(8):
+        w_data.append(np.random.rand(10, 10).astype('float32'))
+
+    map_inputs = {"w{}".format(i): w_data[i] for i in range(8)}
+    map_inputs["x"] = x_data
+    check_result(
+        mod, map_inputs, (30, 10),
+        np.concatenate((((x_data + w_data[0]) - w_data[1]) * w_data[2],
+                        ((x_data + w_data[3]) - w_data[4]) * w_data[5],
+                        x_data + w_data[6] - w_data[7]),
+                       axis=0))
+
+
+def test_extern_ccompiler_single_op():
+    x = relay.var('x', shape=(8, 8))
+    y = relay.var('y', shape=(8, 8))
+    z = x + y
+    f = relay.Function([x, y], z)
+    x_data = np.random.rand(8, 8).astype('float32')
+    y_data = np.random.rand(8, 8).astype('float32')
+    mod = relay.Module()
+    mod["main"] = f
+    mod = relay.build_extern_compiler(mod, "ccompiler")
+
+    check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)
+
+
+def test_extern_ccompiler():
+    x = relay.var('x', shape=(2, 2))
+    y = relay.var('y', shape=(2, 2))
+    z = x + x
+    p = y * y
+    f = relay.Function([x, y], p - z)
+    x_data = np.random.rand(2, 2).astype('float32')
+    y_data = np.random.rand(2, 2).astype('float32')
+    mod = relay.Module()
+    mod["main"] = f
+    mod = relay.build_extern_compiler(mod, "ccompiler")
+
+    check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data))
+
+
+def test_extern_dnnl():
+    if not tvm.get_global_func("relay.ext.dnnl", True):
+        print("skip because DNNL codegen is not available")
+        return
+
+    dtype = 'float32'
+    ishape = (1, 32, 14, 14)
+    w1shape = (32, 1, 3, 3)
+    data = relay.var('data', shape=(ishape), dtype=dtype)
+    weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
+    depthwise_conv2d_1 = relay.nn.conv2d(data,
+                                         weight1,
+                                         kernel_size=(3, 3),
+                                         padding=(1, 1),
+                                         groups=32)
+    depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
+                                         weight1,
+                                         kernel_size=(3, 3),
+                                         padding=(1, 1),
+                                         groups=32)
+    out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
+
+    f = relay.Function([data, weight1], out)
+
+    mod = relay.Module()
+    mod['main'] = WholeGraphAnnotator('dnnl').visit(f)
+    mod = relay.transform.PartitionGraph()(mod)
+
+    ref_mod = relay.Module()
+    ref_mod['main'] = f
+
+    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+    w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
+
+    ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu())
+    ref_res = ref_ex.evaluate()(i_data, w1_data)
+    check_result(mod, {"data": i_data, "weight1": w1_data},
+                 (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
+
+
+def test_extern_dnnl_mobilenet():
+    if not tvm.get_global_func("relay.ext.dnnl", True):
+        print("skip because DNNL codegen is not available")
+        return
+
+    dtype = 'float32'
+    ishape = (1, 3, 224, 224)
+    mod, params = relay.testing.mobilenet.get_workload(
+        batch_size=1, dtype='float32')
+
+    mod = relay.build_extern_compiler(mod, "dnnl")
+
+    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+
+    ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1,
+                                                           dtype='float32')
+    ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
+    ref_res = ref_ex.evaluate()(i_data, **params)
+
+    check_result(mod, {"data": i_data},
+                 (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
+
+
+if __name__ == "__main__":
+    test_multi_node_compiler()
+    test_extern_ccompiler_single_op()
+    test_extern_ccompiler()
+    test_extern_dnnl()
+    test_extern_dnnl_mobilenet()