From 97d642f52fc7f9faedadfee87e411f0fe49cf133 Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Fri, 24 Sep 2021 13:19:13 -0700 Subject: [PATCH] [Relay] Prepare for merging context_analysis.cc and device_annotation.cc (#9077) * [Relay] Prepare for merging context_analysis.cc and device_annotation.cc - Improve construction and deconstruction of "on_device" and "device_copy" calls since they will be center stage. - Move "device_copy" support out of memory.h into own module to mirror "on_device". - Clearing out some DLOG -> VLOG changes I found helped me debug. - Clearing out some whitespace-only changes I accumulated. * [checkpoint] Address Christopher's comments. Some stray py formatting changes snuck in since I just run black . at the root. --- include/tvm/relay/attrs/annotation.h | 28 +++- include/tvm/relay/attrs/device_copy.h | 1 + include/tvm/relay/attrs/function.h | 66 +++++++++ include/tvm/relay/expr_functor.h | 3 +- include/tvm/relay/transform.h | 4 +- include/tvm/runtime/container/array.h | 2 +- include/tvm/runtime/ndarray.h | 26 ++-- python/tvm/relay/op/annotation/annotation.py | 56 ++++++-- python/tvm/relay/transform/transform.py | 2 +- python/tvm/target/target.py | 8 +- src/node/structural_equal.cc | 7 +- src/relay/backend/te_compiler.cc | 8 +- src/relay/backend/vm/inline_primitives.cc | 4 +- src/relay/op/annotation/annotation.cc | 136 +++++++++++++++++- src/relay/op/annotation/annotation.h | 116 +++++++++++++++ src/relay/op/memory/device_copy.cc | 117 +++++++++++++++ src/relay/op/memory/device_copy.h | 79 ++++++++++ src/relay/op/memory/memory.cc | 44 ++---- src/relay/op/memory/memory.h | 1 - src/relay/quantize/partition.cc | 2 +- src/relay/quantize/realize.cc | 2 +- src/relay/transforms/device_annotation.cc | 2 +- src/relay/transforms/memory_alloc.cc | 14 +- src/runtime/ndarray.cc | 4 +- src/tir/analysis/verify_memory.cc | 5 +- .../relay/op/annotation/test_annotation.py | 71 +++++++++ 26 files changed, 709 
insertions(+), 99 deletions(-) create mode 100644 include/tvm/relay/attrs/function.h create mode 100644 src/relay/op/annotation/annotation.h create mode 100644 src/relay/op/memory/device_copy.cc create mode 100644 src/relay/op/memory/device_copy.h create mode 100644 tests/python/relay/op/annotation/test_annotation.py diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index 8379e6471561..bc55965ee852 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -32,15 +32,37 @@ namespace tvm { namespace relay { /*! - * \brief Options for the device annotation operators. + * \brief Attributes for the "on_device" operator. + * + * The relay call + * \code + * on_device(expr, device_type=2) + * \endcode + * denotes that the result of \p expr should be stored on the device with \p DLDeviceType 2 + * (i.e. \p kDLCuda). Semantically the operator is the identity function. + * + * See also FunctionOnDeviceAttrs in include/relay/attrs/function.h for the function-level + * companion. */ struct OnDeviceAttrs : public tvm::AttrsNode { - int device_type; + // TODO(mbs): Replace device types with TargetDevice. + /*! \brief Device type on which argument expression should be evaluated. */ + int device_type = kInvalidDeviceType; + /*! + * \brief If true, the result device must also be \p device_type and device planning should + * not insert any "device_copy" calls to respect this annotation. + * + * This is used by the device planning pass itself when annotating the planned program. 
+ */ + bool is_fixed = false; TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") { TVM_ATTR_FIELD(device_type) - .describe("The virutal device/context type that an expression is annotated with.") + .describe("The type of the virtual device which should hold the expression result.") .set_default(0); + TVM_ATTR_FIELD(is_fixed) + .describe("If true, do not insert a \"device_copy\" call to respect this annotation.") + .set_default(false); } }; diff --git a/include/tvm/relay/attrs/device_copy.h b/include/tvm/relay/attrs/device_copy.h index 7da92b3ff763..f7b0a04f45fa 100644 --- a/include/tvm/relay/attrs/device_copy.h +++ b/include/tvm/relay/attrs/device_copy.h @@ -35,6 +35,7 @@ namespace relay { * \brief Options for the device copy operators. */ struct DeviceCopyAttrs : public tvm::AttrsNode { + // TODO(mbs): Should be TargetDevice. int dst_dev_type; int src_dev_type; diff --git a/include/tvm/relay/attrs/function.h b/include/tvm/relay/attrs/function.h new file mode 100644 index 000000000000..f4f94131da1f --- /dev/null +++ b/include/tvm/relay/attrs/function.h @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/relay/attrs/function.h + * \brief Attributes for Relay Functions which don't make sense on PrimFuncs. + */ +#ifndef TVM_RELAY_ATTRS_FUNCTION_H_ +#define TVM_RELAY_ATTRS_FUNCTION_H_ + +namespace tvm { +namespace relay { +/*! + * \brief Attributes for Relay function definitions which capture the devices for the + * function parameters and result. + * + * See also OnDeviceAttrs in include/tvm/relay/attrs/annotation.h for the companion "on_device" + * call attributes. + */ +struct FunctionOnDeviceAttrs : public tvm::AttrsNode { + /*! \brief Device type on which each of the function's arguments already resides. */ + Array param_device_types; + // TODO(mbs): Replace device types with TargetDevice. + /*! \brief Device type on which function body should be evaluated. */ + int result_device_type = kInvalidDeviceType; + + TVM_DECLARE_ATTRS(FunctionOnDeviceAttrs, "relay.attrs.FunctionOnDeviceAttrs") { + TVM_ATTR_FIELD(param_device_types) + .describe("The type of the virtual device which holds each function parameters."); + TVM_ATTR_FIELD(result_device_type) + .describe("The type of the virtual device which will hold the function's result.") + .set_default(0); + } +}; + +namespace attr { + +/*! + * \brief Device annotations for function parameters and results. 
+ * + * Type: FunctionOnDeviceAttrs + */ +constexpr static const char* kFunctionAttrsKey = "on_device"; + +} // namespace attr + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_ATTRS_FUNCTION_H_ diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 688ad8254fa8..f96faffb24f4 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -37,6 +37,7 @@ #include #include #include + namespace tvm { namespace relay { @@ -227,7 +228,7 @@ class ExprMutator : public ::tvm::relay::ExprFunctor { * * MixedModeVisitor provides the same recursive API as ExprVisitor, and uses * recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions - * of the graph and processes them iteratatively to prevent stack overflows + * of the graph and processes them iteratively to prevent stack overflows */ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { public: diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 912879dc8a4b..cdd4c9c1dbd2 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -437,8 +437,8 @@ TVM_DLL Pass RelayToTIRTargetHook(); * \brief A pass for manifesting explicit memory allocations and rewriting * specific dialects. * - * \param target_host The target used by the host for compliation. - * \param targets The device type and target pairs for compliation. + * \param target_host The target used by the host for compilation. + * \param targets The device type and target pairs for compilation. * * \return The pass. */ diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index 8830653da88c..26f4e545deb7 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -249,7 +249,7 @@ class ArrayNode : public Object, public InplaceArrayBase { }; /*! - * \brief Array, container representing a contigious sequence of ObjectRefs. 
+ * \brief Array, container representing a contiguous sequence of ObjectRefs. * * Array implements in-place copy-on-write semantics. * diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 1127a9ae732c..a4c285e3dd08 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -38,9 +38,19 @@ #include namespace tvm { -namespace runtime { -typedef DLDevice Device; +// alias DLDevice +using Device = DLDevice; + +// A 'null' device type, does not correspond to any DLDeviceType enum. +// TODO(mbs): This is to help us as we transition away from representing the 'homogenous' case +// as a singleton target map indexed by the invalid DLDeviceType '0'. +constexpr DLDeviceType kNullDeviceType = static_cast(0); + +// An 'invalid' device type, does not correspond to any DLDeviceType enum. +constexpr DLDeviceType kInvalidDeviceType = static_cast(-1); + +namespace runtime { /*! * \brief Managed NDArray. @@ -481,23 +491,19 @@ inline bool NDArray::Load(dmlc::Stream* strm) { } } // namespace runtime - -// alias Device -using tvm::runtime::Device; - } // namespace tvm namespace std { template <> -struct hash { - std::size_t operator()(const tvm::runtime::Device& dev) const { +struct hash { + std::size_t operator()(const tvm::Device& dev) const { return ((dev.device_id << 8) | dev.device_type); } }; template <> -struct equal_to { - bool operator()(const tvm::runtime::Device& lhs, const tvm::runtime::Device& rhs) const { +struct equal_to { + bool operator()(const tvm::Device& lhs, const tvm::Device& rhs) const { return (lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id); } }; diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py index 809b6369b085..f5f8870ab015 100644 --- a/python/tvm/relay/op/annotation/annotation.py +++ b/python/tvm/relay/op/annotation/annotation.py @@ -22,8 +22,16 @@ from .. 
import op as reg -def on_device(data, device): - """Annotate an expression with a certain device type. +def _device_to_int(device): + if isinstance(device, _Device): + return device.device_type + if isinstance(device, str): + return _nd.device(device).device_type + raise ValueError("expecting a Device or device name, but received a %s" % (type(device))) + + +def on_device(data, device, is_fixed=False): + """Annotates an expression with the device type on which its result should be stored. Parameters ---------- @@ -31,23 +39,45 @@ def on_device(data, device): The expression to be annotated. device : Union[:py:class:`Device`, str] - The device type to annotate. + The device to annotate with. Only the device's type is significant. + + is_fixed : bool + If false (the default), a "device_copy" call may be inserted to reconcile the + device of the data argument with the device for the context of the annotated + expression. If true, no "device_copy" call is implied; the data is asserted to + already be on the given device. Returns ------- result : tvm.relay.Expr The annotated expression. """ - if isinstance(device, _Device): - device = device.device_type - elif isinstance(device, str): - device = _nd.device(device).device_type - else: - raise ValueError( - "device is expected to be the type of Device or " - "str, but received %s" % (type(device)) - ) - return _make.on_device(data, device) + return _make.on_device(data, _device_to_int(device), is_fixed) + + +def function_on_device(function, param_devices, result_device): + """Annotates a Relay function with the device types on which its parameters and result should + be stored. + + Parameters + ---------- + function : tvm.relay.Function + The function to be annotated. + + param_devices : Array[Union[:py:class:`Device`, str]] + The devices for each parameter. Only the device types are significant. + + result_device: Union[:py:class:`Device`, str] + The device for the function result. Only the device type is significant. 
+ + Returns + ------- + result : tvm.relay.Function + The annotated function. + """ + return _make.function_on_device( + function, [_device_to_int(d) for d in param_devices], _device_to_int(result_device) + ) def stop_fusion(data): diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 688422284c0f..7c79464bdd30 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -546,7 +546,7 @@ def MergeCompilerRegions(): def RewriteAnnotatedOps(fallback_device): """Rewrite the annotated program where annotation operators, e.g. - `on_deivce`, mark which device an expression should be scheduled to. + `on_device`, mark which device an expression should be scheduled to. This pass helps heterogeneous execution where different operators may need to be allocated on various devices. diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index af2f5d857293..4ce888170134 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -525,7 +525,7 @@ def hexagon(cpu_ver="v66", **kwargs): # LLVM target string def create_llvm_target(cpu_ver, config): - """ Create LLVM target string. """ + """Create LLVM target string.""" target = " -mtriple=hexagon" mcpu = " -mcpu=hexagon" + cpu_ver @@ -547,7 +547,7 @@ def create_target_features(config): # Simulator options string def create_sim_options(cpu_ver, config): - """ Create simulator option string. """ + """Create simulator option string.""" def validate_hvx_length(codegen_hvx, sim_options): if sim_options and "--hvx_length" in sim_options: @@ -606,7 +606,7 @@ def validate_hvx_length(codegen_hvx, sim_options): # LLVM options string def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument - """ Create LLVM options string. 
""" + """Create LLVM options string.""" llvm_options = config["llvm_options"] @@ -620,7 +620,7 @@ def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument # TVM target attributes string def create_tvm_options(cpu_ver, config): # pylint: disable=unused-argument - """ Create TVM target features string. """ + """Create TVM target features string.""" features = { "link_params": "link-params", diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 1fa72c92b6fc..8e52af60d235 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -19,6 +19,7 @@ /*! * \file src/node/structural_equal.cc */ +#include #include #include #include @@ -119,8 +120,10 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { // Check the result. bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) { if (assert_mode_ && !result) { - LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by\n" - << "lhs = " << lhs << "\nrhs = " << rhs; + LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl + << PrettyPrint(lhs) << std::endl + << "and rhs:" << std::endl + << PrettyPrint(rhs); } return result; } diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 0d32cc61e2e6..d37fbeabc277 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -439,11 +439,10 @@ class LowerTensorExprMutator : public ExprMutator { } // Non-External Relay Function - DLOG(INFO) << "lowering to target '" << target->str() << "' for primitive:\n" - << PrettyPrint(func); + VLOG(1) << "lowering to target '" << target->str() << "' for primitive:\n" << PrettyPrint(func); CCacheKey key = CCacheKey(func, target); CachedFunc lowered_func = compiler_->Lower(key, module_name_); - DLOG(INFO) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'"; + VLOG(1) << "lowered primitive bound to '" << 
PrettyPrint(lowered_func->prim_fn_var) << "'"; // Collect all the lowered functions produced for this primitive function. Map prim_fns; @@ -452,8 +451,7 @@ class LowerTensorExprMutator : public ExprMutator { CHECK(prim_fn.second.as()) << "must be a prim fn"; prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); all_prim_fn_vars.push_back(prim_fn.first); - DLOG(INFO) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) - << "'"; + VLOG(1) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) << "'"; } // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 05fb2a120620..6924f2598f6f 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -136,13 +136,13 @@ struct PrimitiveInliner : ExprMutator { if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); - DLOG(INFO) << "Before inlining primitives: " << global << std::endl << AsText(func, false); + VLOG(1) << "Before inlining primitives: " << global << std::endl << PrettyPrint(func); func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, func->attrs); module_->Add(global, func, true); - DLOG(INFO) << "After inlining primitives: " << global << std::endl << AsText(func, false); + VLOG(1) << "After inlining primitives: " << global << std::endl << PrettyPrint(func); } } return module_; diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index b59c5a3e9ff3..4eda15937f3a 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -20,10 +20,13 @@ /*! * * \file src/relay/op/annotation/annotation.cc - * \brief Registration of annotation operators. + * \brief Helpers for working with various 'annotations' attributes. 
*/ +#include "./annotation.h" + #include +#include #include #include #include @@ -36,15 +39,51 @@ namespace tvm { namespace relay { -// relay.annotation.on_device TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); +const Op& OnDeviceOp() { + static const Op& op = Op::Get("on_device"); + return op; +} + +Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { + auto attrs = make_object(); + attrs->device_type = device_type; + attrs->is_fixed = is_fixed; + Span span = expr->span; + return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, span); +} + +Expr OptOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { + if (device_type == kInvalidDeviceType) { + // Undefined signals no annotation is required. + return expr; + } + if (expr->IsInstance() || expr->IsInstance()) { + // These operators are device polymorphic so no annotation is required. + // TODO(mbs): The device planning pass does NOT currently support device polymorphism for + // constructors, so we could remove them from this condition. However most constructors + // accept type parameters, and it is not well-formed Relay to simply wrap such a + // constructor in an "on_device" call. So we'll pretend they are device polymorphic to + // avoid that difficultly. Overall ADTs need more work to be fully supported. + return expr; + } + if (expr->IsInstance() || expr->IsInstance()) { + // The device can be recovered from the binding site of the global or local variable. + return expr; + } + if (const auto* function_node = expr.as()) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // Primitive functions are device polymorphic, matching our interpretation for OpNode above. 
+ return expr; + } + } + return OnDevice(expr, device_type, is_fixed); +} + TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device") - .set_body_typed([](Expr data, int device_type) { - auto attrs = make_object(); - attrs->device_type = device_type; - static const Op& op = Op::Get("on_device"); - return Call(op, {data}, Attrs(attrs), {}); + .set_body_typed([](Expr expr, int device_type, bool is_fixed) { + return OnDevice(expr, static_cast(device_type), is_fixed); }); RELAY_REGISTER_OP("on_device") @@ -56,12 +95,95 @@ RELAY_REGISTER_OP("on_device") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("TNonComputational", true) .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, const Type& out_type) -> Array { return {topi::identity(inputs[0])}; }); +OnDeviceProps GetOnDeviceProps(const CallNode* call_node) { + if (call_node->op == OnDeviceOp()) { + ICHECK_EQ(call_node->args.size(), 1) << "on_device expects one argument"; + ICHECK(call_node->attrs.defined()) << "on_device requires attributes"; + const auto* on_device_attrs = call_node->attrs.as(); + ICHECK(on_device_attrs != nullptr) << "on_device requires OnDeviceAttrs"; + auto device_type = static_cast(on_device_attrs->device_type); + // Follow nesting: + // on_device(on_device(expr, device_type=1), device_type=2) == {expr, 1} + auto inner = GetOnDeviceProps(call_node->args[0]); + if (inner.body.defined()) { + return {inner.body, inner.device_type, on_device_attrs->is_fixed || inner.is_fixed}; + } else { + return {call_node->args[0], device_type, on_device_attrs->is_fixed}; + } + } + return {}; +} + +OnDeviceProps GetOnDeviceProps(const Expr& expr) { + if (const auto* call_node = expr.as()) { + return GetOnDeviceProps(call_node); + } + return {}; +} + +TVM_REGISTER_NODE_TYPE(FunctionOnDeviceAttrs); + +Function FunctionOnDevice(Function function, Array param_device_types, + DLDeviceType result_device_type) { + 
auto attrs = make_object(); + attrs->param_device_types = std::move(param_device_types); + attrs->result_device_type = result_device_type; + return WithAttr(std::move(function), attr::kFunctionAttrsKey, Attrs(std::move(attrs))); +} + +Function FunctionOnDevice(Function function, const std::vector& param_device_types, + DLDeviceType result_device_type) { + Array arr; + arr.reserve(param_device_types.size()); + for (const auto device_type : param_device_types) { + arr.push_back(static_cast(device_type)); + } + return FunctionOnDevice(function, arr, result_device_type); +} + +TVM_REGISTER_GLOBAL("relay.op.annotation._make.function_on_device") + .set_body_typed([](Function function, Array param_device_types, + int result_device_type) { + return FunctionOnDevice(function, param_device_types, + static_cast(result_device_type)); + }); + +DLDeviceType GetFunctionResultDeviceType(const FunctionNode* function_node) { + auto opt_attrs = function_node->GetAttr(attr::kFunctionAttrsKey); + if (!opt_attrs) { + // No annotation. + return kInvalidDeviceType; + } + const auto* opt_function_on_device_attrs = opt_attrs.value().as(); + ICHECK(opt_function_on_device_attrs != nullptr) + << "function '" << attr::kFunctionAttrsKey << "' annotation must be a FunctionOnDeviceAttrs"; + return static_cast(opt_function_on_device_attrs->result_device_type); +} + +DLDeviceType GetFunctionParamDeviceType(const FunctionNode* function_node, size_t i) { + ICHECK_LT(i, function_node->params.size()) + << "param index " << i << " out of range for function of arity " + << function_node->params.size(); + auto opt_attrs = function_node->GetAttr(attr::kFunctionAttrsKey); + if (!opt_attrs) { + // No annotation. 
+ return kInvalidDeviceType; + } + const auto* opt_function_on_device_attrs = opt_attrs.value().as(); + ICHECK(opt_function_on_device_attrs != nullptr) + << "function '" << attr::kFunctionAttrsKey << "' annotation must be a FunctionOnDeviceAttrs"; + ICHECK_EQ(opt_function_on_device_attrs->param_device_types.size(), function_node->params.size()) + << "annotation parameters do not match function arity"; + return static_cast(opt_function_on_device_attrs->param_device_types[i]->value); +} + Expr StopFusion(Expr data) { static const Op& op = Op::Get("annotation.stop_fusion"); return Call(op, {data}, Attrs{}, {}); diff --git a/src/relay/op/annotation/annotation.h b/src/relay/op/annotation/annotation.h new file mode 100644 index 000000000000..e3a4aea4708c --- /dev/null +++ b/src/relay/op/annotation/annotation.h @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/op/annotation/annotation.h + * \brief Helpers for working with various 'annotation' attributes. + */ +#ifndef TVM_RELAY_OP_ANNOTATION_ANNOTATION_H_ +#define TVM_RELAY_OP_ANNOTATION_ANNOTATION_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relay { + +/*! 
\brief Returns the "on_device" operator. */ +const Op& OnDeviceOp(); + +/*! + * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed. + */ +Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); + +/*! + * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed. However + * returns \p expr directly if: + * - \p device_type is \p kInvalidDeviceType, which signals there are no device annotations + * already in play. + * - \p expr is an operator or primitive function literal. These are device polymorphic. + * - \p expr is a global or local var. These already have an implied device. + * - \p expr is a constructor. There should probably be device polymorphic but are in an + * in-between state at the moment. + */ +Expr OptOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); + +/*! \brief Result of \p GetOnDeviceProps. */ +struct OnDeviceProps { + Expr body; // = null + DLDeviceType device_type = kInvalidDeviceType; + bool is_fixed = false; + + OnDeviceProps() = default; + + OnDeviceProps(const Expr& body, DLDeviceType deviceType, bool isFixed) + : body(body), device_type(deviceType), is_fixed(isFixed) {} +}; + +/*! + * \brief Returns the body expression, device type and is_fixed field for \p call_node if it is + * an "on_device" CallNode. Otherwise returns the null expression, \p kInvalidDeviceType and \p + * false. + */ +OnDeviceProps GetOnDeviceProps(const CallNode* call_node); + +/*! + * \brief Returns the body expression, device type and is_fixed field for \p expr if it is an + * "on_device" CallNode. Otherwise returns the null expression, \p kInvalidDeviceType and \p false. + */ +OnDeviceProps GetOnDeviceProps(const Expr& expr); + +/*! \brief Returns true if \p expr is an on_device CallNode. */ +inline bool IsOnDeviceCall(const Expr& expr) { return GetOnDeviceProps(expr).body.defined(); } + +/*! 
+ * \brief Returns \p function annotated with "on_device" attributes capturing parameter and result + * device types. However returns \p function directly if all device types are \p + * kInvalidDeviceType. + */ +Function FunctionOnDevice(Function function, Array param_device_types, + DLDeviceType body_device_type); +Function FunctionOnDevice(Function function, const std::vector& param_device_types, + DLDeviceType body_device_type); + +/*! + * \brief Returns the device type for the result of \p function_node, or \p kInvalidDeviceType + * if function does not have "on_device" annotation. + */ +DLDeviceType GetFunctionResultDeviceType(const FunctionNode* function_node); + +/*! + * \brief Returns the device type for the \p i'th parameter of \p function_node, or + * \p kInvalidDeviceType if function does not have "on_device" annotation. + */ +DLDeviceType GetFunctionParamDeviceType(const FunctionNode* function_node, size_t i); + +/*! \brief Wraps \p data in a "stop_fusion" annotation. */ +Expr StopFusion(Expr data); + +/*! \brief Wraps \p data in a "cast_hint" annotation for \p dtype. */ +Expr CastHint(Expr data, DataType dtype); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_OP_ANNOTATION_ANNOTATION_H_ diff --git a/src/relay/op/memory/device_copy.cc b/src/relay/op/memory/device_copy.cc new file mode 100644 index 000000000000..b94caac2c3d9 --- /dev/null +++ b/src/relay/op/memory/device_copy.cc @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/op/memory/device_copy.cc + * \brief Helpers for working with "device_copy" attributes. + */ + +#include "./device_copy.h" + +#include +#include +#include +#include +#include + +#include "../../transforms/infer_layout_utils.h" +#include "../type_relations.h" + +namespace tvm { +namespace relay { + +// relay.device_copy +TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs); + +const Op& DeviceCopyOp() { + static const Op& op = Op::Get("device_copy"); + return op; +} + +Expr DeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type) { + auto attrs = make_object(); + attrs->src_dev_type = src_dev_type; + attrs->dst_dev_type = dst_dev_type; + Span span = expr->span; + return Call(DeviceCopyOp(), {std::move(expr)}, Attrs(attrs), /*type_args=*/{}, span); +} + +Expr OptDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type) { + if (src_dev_type == dst_dev_type) { + return expr; + } + ICHECK_NE(src_dev_type, kInvalidDeviceType); + ICHECK_NE(dst_dev_type, kInvalidDeviceType); + return DeviceCopy(expr, src_dev_type, dst_dev_type); +} + +TVM_REGISTER_GLOBAL("relay.op._make.device_copy") + .set_body_typed([](Expr expr, int src_dev_type, int dst_dev_type) { + return DeviceCopy(expr, static_cast(src_dev_type), + static_cast(dst_dev_type)); + }); + +RELAY_REGISTER_OP("device_copy") + .describe(R"code( +Copy data from one tensor to another. The source and destination might be +on different devices. 
+)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input data.") + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); + +DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) { + if (call_node->op == DeviceCopyOp()) { + ICHECK_EQ(call_node->args.size(), 1) << "device_copy expects one argument"; + ICHECK(call_node->attrs.defined()) << "device_copy requires attributes"; + const auto* device_copy_attrs = call_node->attrs.as(); + ICHECK(device_copy_attrs != nullptr) << "device_copy requires DeviceCopyAttrs"; + auto src_dev_type = static_cast(device_copy_attrs->src_dev_type); + auto dst_dev_type = static_cast(device_copy_attrs->dst_dev_type); + // Follow nesting: + // device_copy(device_copy(expr, src_dev_type=1, dst_dev_type=2), + // src_dev_type=2, dst_dev_type=3) ==> {expr, 1, 3} + auto inner = GetDeviceCopyProps(call_node->args[0]); + if (inner.body.defined()) { + return {inner.body, inner.src_dev_type, inner.dst_dev_type}; + } else { + return {call_node->args[0], src_dev_type, dst_dev_type}; + } + } + return {}; +} + +DeviceCopyProps GetDeviceCopyProps(const Expr& expr) { + if (const auto* call_node = expr.as()) { + return GetDeviceCopyProps(call_node); + } + return {}; +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/memory/device_copy.h b/src/relay/op/memory/device_copy.h new file mode 100644 index 000000000000..d590d8510f17 --- /dev/null +++ b/src/relay/op/memory/device_copy.h @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/op/memory/device_copy.h + * \brief Helpers for working with "device_copy" attributes. + */ + +#ifndef TVM_RELAY_OP_MEMORY_DEVICE_COPY_H_ +#define TVM_RELAY_OP_MEMORY_DEVICE_COPY_H_ + +#include <tvm/relay/expr.h> +#include <tvm/runtime/ndarray.h> + +namespace tvm { +namespace relay { + +/*! \brief Returns the "device_copy" operator. */ +const Op& DeviceCopyOp(); + +/*! + * \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated on + * a device of type \p src_dev_type but then copied to a device of type \p dst_dev_type. + */ +Expr DeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type); + +/*! + * \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated on + * a device of type \p src_dev_type but then copied to a device of type \p dst_dev_type. + * However, return \p expr directly if \p src_dev_type equals \p dst_dev_type. + */ +Expr OptDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type); + +/*! \brief Result of \p GetDeviceCopyProps.
*/ +struct DeviceCopyProps { + Expr body; // = null + DLDeviceType src_dev_type = kInvalidDeviceType; + DLDeviceType dst_dev_type = kInvalidDeviceType; + + DeviceCopyProps() = default; + + DeviceCopyProps(const Expr& body, DLDeviceType srcDevType, DLDeviceType dstDevType) + : body(body), src_dev_type(srcDevType), dst_dev_type(dstDevType) {} +}; + +/*! + * \brief Returns the body expression, source, and destination device types for \p call_node if it + * is a "device_copy" CallNode. Otherwise returns the null expression and \p kInvalidDeviceType + * device types. + */ +DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node); + +/*! + * \brief Returns the body expression, source, and destination device types for \p expr if it + * is a "device_copy" CallNode. Otherwise returns the null expression and \p kInvalidDeviceType + * device types. + */ +DeviceCopyProps GetDeviceCopyProps(const Expr& expr); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_OP_MEMORY_DEVICE_COPY_H_ diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index c2997fb6cf95..68a83ebba1fe 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -35,9 +35,9 @@ #include #include "../../transforms/infer_layout_utils.h" +#include "../annotation/annotation.h" #include "../op_common.h" #include "../type_relations.h" -#include "tvm/relay/attrs/device_copy.h" namespace tvm { namespace relay { @@ -97,14 +97,21 @@ RELAY_REGISTER_OP("memory.alloc_storage") return {topi::identity(inputs[0])}; }); -Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype, +Expr AllocTensor(Expr storage, Expr offset, Expr shape, DataType dtype, Array assert_shape) { auto attrs = make_object(); attrs->dtype = dtype; if (assert_shape.defined()) { attrs->assert_shape = assert_shape; } else { - attrs->const_shape = Downcast(shape); + // Look through any on_device for the shape argument expression. 
+ Expr literal_shape = shape; + auto props = GetOnDeviceProps(literal_shape); + if (props.body.defined()) { + // See through on_device calls. + literal_shape = props.body; + } + attrs->const_shape = Downcast<Array<IndexExpr>>(literal_shape); } static const Op& op = Op::Get("memory.alloc_tensor"); return Call(op, {storage, offset, shape}, Attrs(attrs), {}); } @@ -307,36 +314,5 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.ToTupleType") return ToTupleType(t, std::vector<Expr>(array.begin(), array.end())); }); -// relay.device_copy -TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs); - -Expr DeviceCopy(Expr data, int src_dev_type, int dst_dev_type) { - auto attrs = make_object<DeviceCopyAttrs>(); - attrs->src_dev_type = src_dev_type; - attrs->dst_dev_type = dst_dev_type; - static const Op& op = Op::Get("device_copy"); - return Call(op, {data}, Attrs(attrs), {}); -} - -TVM_REGISTER_GLOBAL("relay.op._make.device_copy").set_body_typed(DeviceCopy); - -RELAY_REGISTER_OP("device_copy") - .describe(R"code( -Copy data from one tensor to another. The source and destination might be -on different devices.
-)code" TVM_ADD_FILELINE) - .set_num_inputs(1) - .add_argument("data", "Tensor", "The input data.") - .set_support_level(10) - .add_type_rel("Identity", IdentityRel) - .set_attr("TOpPattern", kOpaque) - .set_attr("TOpIsStateful", false) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); - } // namespace relay } // namespace tvm diff --git a/src/relay/op/memory/memory.h b/src/relay/op/memory/memory.h index bbbd11867549..558c409782f5 100644 --- a/src/relay/op/memory/memory.h +++ b/src/relay/op/memory/memory.h @@ -33,7 +33,6 @@ namespace tvm { namespace relay { Expr AllocStorage(Expr size, Expr alignment, Device dev, DataType dtype_hint); -Expr DeviceCopy(Expr data, int src_dev_type, int dst_dev_type); Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype, Array assert_shape); Expr ToTupleType(const Type& ty, const std::vector& exprs); diff --git a/src/relay/quantize/partition.cc b/src/relay/quantize/partition.cc index c65cc1879932..6cd596a814ac 100644 --- a/src/relay/quantize/partition.cc +++ b/src/relay/quantize/partition.cc @@ -26,7 +26,7 @@ #include -#include "../transforms/pattern_utils.h" +#include "../op/annotation/annotation.h" #include "./quantize.h" namespace tvm { diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index 968628fbfe39..e636130f8553 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -29,8 +29,8 @@ #include #include +#include "../op/annotation/annotation.h" #include "../qnn/utils.h" -#include "../transforms/pattern_utils.h" #include "./quantize.h" namespace tvm { diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index 02f9d474411a..7457457e4c5c 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -18,7 
+18,7 @@ */ /*! - * \file deivce_annotation.cc + * \file device_annotation.cc * \brief Passes to rewrite annotated program and retrieve the device allocation * of expression. * diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 657e2c392455..31d3b2c8991a 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -43,14 +43,15 @@ #include "../backend/te_compiler.h" #include "../backend/te_compiler_cache.h" +#include "../op/annotation/annotation.h" +#include "../op/memory/device_copy.h" #include "../op/memory/memory.h" #include "../op/vm/vm.h" +#include "./let_list.h" #include "./pass_utils.h" -#include "let_list.h" -#include "pattern_utils.h" +#include "./pattern_utils.h" using namespace tvm::runtime; -using namespace tvm::relay::tec; namespace tvm { namespace relay { @@ -193,7 +194,8 @@ class DialectRewriter : public ExprMutator { private: // Insert a device copy node. Expr DeviceCopy(const Expr& inp, int src_dev, int dst_dev) { - return ExprMutator::Mutate(relay::DeviceCopy(inp, src_dev, dst_dev)); + return ExprMutator::Mutate(relay::DeviceCopy(inp, static_cast<DLDeviceType>(src_dev), + static_cast<DLDeviceType>(dst_dev))); } // Check if a call invokes a primitive function.
@@ -274,9 +276,9 @@ class DialectRewriter : public ExprMutator { const std::vector& new_args) { Array shape_func_ins; - TECompiler compiler; + tec::TECompiler compiler; - CCacheKey key(func, target_host_); + tec::CCacheKey key(func, target_host_); auto cfunc = compiler->LowerShapeFunc(key); auto input_states = cfunc->shape_func_param_states; diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 968a4488bbcf..8db89c59a85d 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -272,7 +272,7 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_ dtype.code = static_cast(dtype_code); dtype.bits = static_cast(dtype_bits); dtype.lanes = static_cast(dtype_lanes); - Device dev; + tvm::Device dev; dev.device_type = static_cast(device_type); dev.device_id = device_id; auto ndarray = NDArray::Empty(ShapeTuple(shape, shape + ndim), dtype, dev); @@ -286,7 +286,7 @@ TVM_REGISTER_GLOBAL("runtime.TVMArrayAllocWithScope").set_body([](TVMArgs args, int ndim = args[1]; ShapeTuple shape(shape_ptr, shape_ptr + ndim); DataType dtype = args[2]; - Device dev = args[3]; + tvm::Device dev = args[3]; Optional mem_scope = args[4]; auto ndarray = NDArray::Empty(shape, dtype, dev, mem_scope); *ret = ndarray; diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 0382b8071de7..b6c41b958c31 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -172,8 +172,9 @@ std::vector VerifyMemory_(const PrimFunc& func) { auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "VerifyMemory: Require the target attribute"; - DLOG(INFO) << "verifying memory for target '" << target.value()->str() << "' for primitive\n" - << PrettyPrint(func); + VLOG(1) << "verifying memory for target '" << target.value()->str() + << "' for primitive:" << std::endl + << PrettyPrint(func); if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == 
CallingConv::kDefault) { diff --git a/tests/python/relay/op/annotation/test_annotation.py b/tests/python/relay/op/annotation/test_annotation.py new file mode 100644 index 000000000000..51daa9aaa06a --- /dev/null +++ b/tests/python/relay/op/annotation/test_annotation.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Unit tests for annotations.""" +import tvm +from tvm import relay +import pytest + + +def test_on_device_via_string(): + x = relay.Var("x") + call = relay.annotation.on_device(x, "cuda") + assert isinstance(call, relay.Call) + assert len(call.args) == 1 + assert call.args[0] == x + assert call.attrs.device_type == 2 # ie kDLCUDA + assert not call.attrs.is_fixed + + +def test_on_device_via_device(): + x = relay.Var("x") + call = relay.annotation.on_device(x, tvm.device("llvm")) + assert call.attrs.device_type == 1 # ie kDLCPU + + +def test_on_device_invalid_device(): + x = relay.Var("x") + pytest.raises(ValueError, lambda: relay.annotation.on_device(x, "bogus")) + + +def test_on_device_is_fixed(): + x = relay.Var("x") + call = relay.annotation.on_device(x, "cuda", True) + assert call.attrs.device_type == 2 + assert call.attrs.is_fixed + + +def test_function_on_device(): + x = relay.Var("x") + y = relay.Var("y") + f = relay.Function([x, y], relay.add(x, y)) + func = relay.annotation.function_on_device(f, ["cpu", "cuda"], "cuda") + assert isinstance(func, relay.Function) + assert len(func.attrs["on_device"].param_device_types) == 2 + assert func.attrs["on_device"].param_device_types[0] == 1 + # ie kDLCPU + assert func.attrs["on_device"].param_device_types[1] == 2 + # ie kDLCUDA + assert func.attrs["on_device"].result_device_type == 2 + # ie KDLCUDA + + +if __name__ == "__main__": + test_on_device_via_string() + test_on_device_via_device() + test_on_device_invalid_device() + test_on_device_is_fixed() + test_function_on_device()