From c00a566aa65955dd04b6a5b3cd72568ceece8e8c Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Mon, 18 Oct 2021 17:34:51 -0700 Subject: [PATCH] Adds SEScope (Storage/Execution Scope) for use as new unit of planning in 'device' planning This is the first step in https://github.com/apache/tvm-rfcs/pull/38 to bring devices and targets together when doing device planning. I've gone ahead and also included a memory scope in this object since we will also need to propagate memory scopes across Relay expressions once this basic preparation is in place. In the meantime that field will be left as "". Once device planning works in units of SEScopes it will be possible to directly read off the device and target for any Relay sub-expression without the need for TargetMaps ort the construction of default Targets. SEScopes also support 'Join' and 'Default' operations needed when constraint solving in the device planner. You can see those in use in my scratchpad branch: https://github.com/mbs-octoml/mbs-tvm/tree/mbs-scopes This PR also brings some duplicated and the ad-hoc 'default target' handling logic together into a CompilationConfig class. (Again, see the scratchpad branch for how that will end up being used). I've placed that next to SEScope since it's main purpose is to a) establish the default SEScope for primitive ops b) establish the SEScope for the 'host' c) feed a definitive vector of Targets into device planning so it can resolve all "on_device" and "device_copy" device references to their full SEScope form. --- include/tvm/ir/attrs.h | 4 +- include/tvm/target/se_scope.h | 384 +++++++++++++++++++ include/tvm/target/target.h | 13 +- python/tvm/target/__init__.py | 1 + python/tvm/target/se_scope.py | 22 ++ src/ir/attr_functor.h | 3 + src/parser/parser.cc | 4 +- src/printer/relay_text_printer.cc | 15 +- src/printer/text_printer.h | 1 + src/relay/backend/aot_executor_codegen.cc | 4 +- src/relay/backend/build_module.cc | 11 +- src/relay/backend/graph_executor_codegen.cc | 2 +- src/relay/backend/te_compiler.h | 2 +- src/relay/backend/utils.h | 2 +- src/relay/backend/vm/compiler.cc | 16 +- src/relay/backend/vm/compiler.h | 7 +- src/target/se_scope.cc | 396 ++++++++++++++++++++ src/target/target.cc | 2 +- tests/cpp/target/se_scope_test.cc | 94 +++++ tests/python/target/test_se_scope.py | 52 +++ 20 files changed, 1000 insertions(+), 35 deletions(-) create mode 100644 include/tvm/target/se_scope.h create mode 100644 python/tvm/target/se_scope.py create mode 100644 src/target/se_scope.cc create mode 100644 tests/cpp/target/se_scope_test.cc create mode 100644 tests/python/target/test_se_scope.py diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 715c96eb6ea52..f6c15f9590df0 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -489,7 +489,7 @@ struct AttrInitEntry { ~AttrInitEntry() DMLC_THROW_EXCEPTION { if (value_missing_) { std::ostringstream os; - os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization." + os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization. " << "If the key is defined check that its type matches the declared type."; throw AttrError(os.str()); } @@ -806,7 +806,7 @@ class AttrsNode : public BaseAttrsNode { ICHECK_EQ(args.size() % 2, 0); const int kLinearSearchBound = 16; int hit_count = 0; - // applies two stratgies to lookup + // applies two strategies to lookup if (args.size() < kLinearSearchBound) { // linear search. auto ffind = [&args](const char* key, runtime::TVMArgValue* val) { diff --git a/include/tvm/target/se_scope.h b/include/tvm/target/se_scope.h new file mode 100644 index 0000000000000..5573fd1f22ea3 --- /dev/null +++ b/include/tvm/target/se_scope.h @@ -0,0 +1,384 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/se_scope.h + * \brief A compile time representation for a Storage or Execution Scope. + */ + +#ifndef TVM_TARGET_SE_SCOPE_H_ +#define TVM_TARGET_SE_SCOPE_H_ + +#include +#include + +#include +#include + +namespace tvm { + +class SEScope; + +/*! + * \brief Describes at compile time where data is to be stored down to the device and memory + * scope level, or where execution is to take place, down to the device level. It is a quadruple of: + * - A \p device_type (\p DLDeviceType). + * - An uninterpreted \p virtual_device_id (\p int) distinguishing the intended device from all + * other devices (either of the same \p device_type, or across all availabel devices in the + * system). The \p virtual_device_id may be left as 0 if not significant. It is up to downstream + * compilation passes and/or the runtime to map a \p virtual_device_id to an actual physical + * device if required. In particular the \p virtual_device_id need not correspond exactly to + * any runtime \p Device's \p device_id. + * - A \p target (\p Target) describing how to compile code for the intended device. The + * \p target->kind->device_type must match the above \p device_type. + * - A \p memory_scope (currently just \p String) describing which memory area is to be used to + * hold data. The area should be reachable from the device but need not be 'on' the device, + * see below. (We're using a \p String for now but would prefer a more structured representation.) + * + * All of these fields may be 'unconstrained', signaling that device planning is free to choose + * a value consistent with the whole program. However if a \p target is given then the + * \p device_type must equal \p target->kind->device_type. + * + * Note that currently we assume if a function returns its result on a particular device + * then the function body is also executed on that device. See the overview comment in + * src/relay/transforms/device_planner.cc for more details. + * + * By 'data' we include both tensors and additional supporting datastructures such as shapes, + * Relay AST items, Relay tuples, and Relay references. Typically non-tensor data must reside + * on a 'CPU'-like device with good support for scalars. + * + * By 'execution' we include both (fused) primitive operators, and all the Relay expressions + * surrounding them which coordinates data and control flow. Again, typically non-primitive + * operators must be executed on a 'CPU'-like device with good support for control flow. + * + * Targets vs Devices + * ------------------ + * Generally \p Targets (a compile-time only datastructue) describe compiler options for a specific + * microarchitecture and toolchain, while \p Devices (a runtime datastructure alsa available at + * compile time) describe a physical device on the target system. Obviously the target must agree + * with the device's microarchitecture, but we otherwise don't impose any constraints between them: + * - It's ok to use different \p Targets for the same \p Device, eg to squeeze some extra perf + * out of a particular primitive. + * - It's ok to use the same \p Target for multiple \p Devices, eg if we have multiple CPUs. + * + * Traditionally TVM assumes at most one \p Target per \p DLviceType. We are moving away from that + * assumption. + * + * Memory scopes and devices + * ------------------------- + * Multi-device systems can have complex memory hierarchies. For example + * \code + * (kDLCPU, 0, "llvm", "global") + * \endcode + * and + * \code + * (kDLCPU, 1, "llvm", "global") + * \endcode + * could denote: + * - The same memory area accessible from two separate CPUs without any CPU affinity; + * - Distinct memory areas in a NUMA architecture for which cross-device access is handled + * by the memory system; + * - Outright distinct memory areas, where one device cannot directly address the memory of + * another. + * + * Similarly: + * \code + * (kDLCPU, 0, "llvm", "global") + * \endcode + * and + * \code + * (kDLCUDA, 0, "cuda", "host") + * \endcode + * could denote the same memory area, but with very different access costs. + * + * We don't currently try to build any of this system-level understanding into \p SEScope. Device + * planning will simply insert "device_copy" operators wherever \p SEScopes are not exactly + * pointwise equal, and we leave it to downstream compilation to elide unnecessary copies. We + * may revisit this in the future. + * + * Object identity + * --------------- + * \p SEScopes can only be constructed by the memoizing helpers. This means code can assume + * \code + * se_scope1 != se_scope2 => se_scope1 and se_scope2 differ pointwise + * \endcode + * This simplifies the device planner which needs to solve equality constraints between \p SEScopes. + * + * Joining and Defaulting + * ---------------------- + * It is possible to 'join' two \p SEScopes to yield the most constrained \p SEScope which agrees + * with both join arguments. Eg: + * \code + * Join((kDLCPU, -1, "llvm", ""), (kInvalidDeviceType, 3, null, "global)) + * => (kDLCPU, 3, "llvm", "global") + * Join((kDLCPU, -1, "llvm", ""), (kInvalidDeviceType, 3, null, "local)) + * => null (no join possible) + * \endcode + * + * Related to 'join' is 'default', which only takes constrained fields from the rhs when the + * lhs is unconstrained: + * \code + * Default(kDLCPU, -1, "llvm", "local"), (kDLCPU, 3, null, "global")) + * => (kDLCPU, 3, "llvm", "local") + * \endcode + * + * These operations are needed during device planning. + * + */ +class SEScopeNode : public Object { + public: + /*! + * \brief The \p DLDeviceType of the device. If \p target is known then this will be equal to + * \p target->kind->device_type. If \p target is null then the target is to be determined by + * a later pass. + * + * This is needed to support the legacy "on_device" and "device_copy" calls which only allow + * a \p DLDeviceTypes (as an integer) to be given. + * + * kInvalidDeviceType denotes unconstrained. + */ + DLDeviceType device_type() const { return device_type_; } + + /*! + * \brief The 'virtual' device identifier for the device. This must be resolved to a physical + * device identifier either during compilation or at runtime. + * + * -1 denotes unconstrained. May be 0 if not significant. + */ + int virtual_device_id() const { return virtual_device_id_; } + + /*! + * \brief The \p Target describing how to compile for the device. + * + * Null denotes unconstrained (though if device_type is known then only a target of that + * type is allowed). + */ + const Target& target() const { return target_; } + + /*! + * \brief The scope of memory within the device. + * + * Empty denotes unconstrained. + */ + // TODO(mbs): We are using String as a stand-in pending a more structured representation, such + // as runtime::StorageScope or a memory pool. + const String& memory_scope() const { return memory_scope_; } + + /*! + * \brief Returns true if scope is fully unconstrained, ie no target/device type, virtual device + * id or memory scope is specified. + */ + bool is_fully_unconstrained() const { + return !target_.defined() && device_type_ == kInvalidDeviceType && virtual_device_id_ == -1 && + memory_scope_.empty(); + } + + /*! + * \brief Returns true if scope is fully constrained, ie target, virtual device id and + * memory scope are all specified. + */ + bool is_fully_constrained() const { + return target_.defined() && virtual_device_id_ != -1 && !memory_scope_.empty(); + } + + Device ToDevice() const { + ICHECK(device_type_ != kInvalidDeviceType); + ICHECK(virtual_device_id_ != -1); + Device device; + device.device_type = device_type_; + device.device_id = virtual_device_id_; + return device; + } + + void VisitAttrs(AttrVisitor* v); + + bool SEqualReduce(const SEScopeNode* other, SEqualReducer equal) const { + // Since we memoize all constructors we can just use pointer equality. + return this == other; + } + + void SHashReduce(SHashReducer hash_reduce) const { + // Since we memoize all constructors we can just use the pointer hash + hash_reduce->SHashReduceHashedValue(std::hash()(this)); + } + + static constexpr const char* _type_key = "SEScope"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(SEScopeNode, Object); + + private: + // We keep the fields private so the constructor memoization can't be upset by mutation. + DLDeviceType device_type_ = kInvalidDeviceType; + int virtual_device_id_ = -1; + Target target_{nullptr}; + String memory_scope_; // = "" + + friend class SEScope; +}; + +/*! + * \brief Managed reference class to \p SEScopeNode. + * + * \sa SEScopeNode. + */ +class SEScope : public ObjectRef { + private: + /*! + * \brief Construct an SEScope. + * \param device_type The device type for the device, or kInvalidDeviceType if unconstrained. + * If \p target is defined then must match its \p target->kind->device_type. + * \param virtual_device_id The virtual device id for the device, or -1 if unconstrained. + * \param target The target describing how to compile for the device, or null if unconstrained. + * \param memory_scope The memory scope within the device, or "" if unconstrained. + * \return The SEScope + * + * This constructor is private -- use the memoizing smart constructors below. + */ + explicit SEScope(DLDeviceType device_type, int virtual_device_id, Target target, + String memory_scope); + + public: + /*! + * \brief Returns the unique \p SEScope object for \p device_type, \p virtual_device_id, \p + * target, and \p memory_scope. Any/all of these fields may be unconstrained as per their default + * values. However if \p target is defined then \p device_type must be + * \p target->kind->device_type. + */ + static SEScope MakeSEScope(DLDeviceType device_type = kInvalidDeviceType, + int virtual_device_id = -1, Target target = {}, + String memory_scope = {}); + + /*! \brief Returns the unique fully unconstrained \p SEScope. */ + static SEScope FullyUnconstrained() { return MakeSEScope(); } + + /*! + * \brief Returns the unique \p SEScope for \p device_type and (if not -1) \p virtual_device_id. + * The target and memory scope will be unconstrained. + */ + static SEScope ForDeviceType(DLDeviceType device_type, int virtual_device_id = -1) { + ICHECK_GT(device_type, 0); + return MakeSEScope(device_type, virtual_device_id); + } + static SEScope ForDeviceType(int device_type, int virtual_device_id = -1) { + return ForDeviceType(static_cast(device_type), virtual_device_id); + } + static SEScope ForDeviceType(const Integer& device_type, int virtual_device_id = -1) { + return ForDeviceType(static_cast(device_type->value), virtual_device_id); + } + + /*! \brief Returns the unique \p SEScope for \p device. */ + TVM_DLL static SEScope ForDevice(const Device& device) { + return ForDeviceType(device.device_type, device.device_id); + } + + /*! \brief Returns the unique \p SEScope for \p device and \p target. */ + TVM_DLL static SEScope ForDeviceAndTarget(const Device& device, Target target) { + return MakeSEScope(device.device_type, device.device_id, std::move(target)); + } + + /*! \brief Returns the unique \p SEScope for \p device, \p target and \p memory_scope. */ + TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target, + String memory_scope) { + return MakeSEScope(device.device_type, device.device_id, std::move(target), + std::move(memory_scope)); + } + + /*! + * \brief Returns the 'join' of \p lhs and \p rhs. The result will agree pointwise with + * \p lhs and \p rhs on all their constrained fields. Returns the null optional if no such + * join exists, ie there's disagreement on at least one constrained field. + */ + static Optional Join(const SEScope& lhs, const SEScope& rhs); + + /*! + * \brief Returns the 'default' of \p lhs and \p rhs. The result will be \p lhs, except any + * unconstrained fields in \p lhs will take their value from \p rhs. Always well-defined. + */ + static SEScope Default(const SEScope& lhs, const SEScope& rhs); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SEScope, ObjectRef, SEScopeNode); + + friend class SEScopeCache; // Private implementation helper. +}; + +/*! + * \brief Gathers the targets and scopes needed to compile a Relay module, and centralizes + * the target checking and defaulting logic. + * + * TODO(mbs): This is a temporary class to help us bridge legacy and new target/device handling + * and reduce code dup between VM (relay/backend/vm/compile.cc), graph/AOT + * (relay/backend/build_module.cc) and interpreter (relay/backend/interpreter.cc) 'executors'. It + * should probably get merged into something more sensible, ideally just THE standard compilation + * flow once we have one. + */ +struct CompilationConfig { + /*! + * \brief The legacy targets map, mapping device type to \p Targets. Does not include any + * entry for the host target. Intended to give a unique \p Target for every \p DLDeviceType, + * though we want to get rid of that limitation. + * + * CAUTION: Since keys are \p Integers they are compared by object equality not integer + * value. + * + * TODO(mbs): Remove once codegen updated for new target conventions. + */ + TargetMap legacy_target_map; + + /*! \brief The optional host target. Used for 'scalar' data and code (such as shapes and shape + * functions) and residual Relay expressions and data (such as conditionals and ADTs). */ + Target optional_host_target; + + /*! + * \brief Vector of all available targets, including for primitive operators, host, and any + * default targets added for required device types. + */ + Array targets; + + /*! + * \brief \p SEScope for primitive operators which are not otherwise constrained to a particular + * device. + */ + SEScope default_primitive_se_scope = SEScope::FullyUnconstrained(); + + /*! \brief SEScope for the host. */ + SEScope host_se_scope = SEScope::FullyUnconstrained(); + + /*! + * \brief If defined then in 'homogenous execution mode' and all primitives will be compiled + * for this target. This is to support legacy passes which have not been adapted to hetrogeneous + * execution. + */ + Target homogeneous_target; + + CompilationConfig() = default; + + /*! + * \brief Constructs the compilation config given the available \p Targets in the + * \p legacy_target_map_arg and an optional \p optional_host_target_arg. May use + * 'relay.fallback_device_type' and the availability of the LLVM compilation module + * to decide on appropriatte default devices. + */ + CompilationConfig(const transform::PassContext& pass_ctx, TargetMap legacy_target_map_arg, + Target optional_host_target_arg); +}; + +} // namespace tvm + +#endif // TVM_TARGET_SE_SCOPE_H_ diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 64a1023158e1e..660ec63582bf2 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -179,6 +179,9 @@ class Target : public ObjectRef { */ TVM_DLL void ExitWithScope(); }; + +using TargetMap = Map; + /*! * \brief Check and update host field of the given legacy target and target host pair. * Note that this function is for legacy target api compatibility issue only, not @@ -187,22 +190,24 @@ class Target : public ObjectRef { * \param host The pointer to a Target typed object for target host to be updated */ void CheckAndUpdateHostConsistency(Target* target, Target* host); + /*! * \brief Check and update host field of the given legacy heterogeneous targets and * target host.Note that this function is for legacy target api compatibility issue only, * not recommended for other use. - * \param target The pointer to a Map objects with values being Target objects + * \param target_map The pointer to a Map objects with values being Target objects * \param host The Target typed object for target host to be updated */ -void CheckAndUpdateHostConsistency(Map* target, Target* host); +void CheckAndUpdateHostConsistency(TargetMap* target_map, Target* host); + /*! * \brief Check and update host field of the given legacy heterogeneous targets and * target host.Note that this function is for legacy target api compatibility issue only, * not recommended for other use. - * \param target The pointer to a Map objects with keys being Target objects + * \param ir_modules The pointer to a Map objects with keys being Target objects * \param host The Target typed object for target host to be updated */ -void CheckAndUpdateHostConsistency(Map* target, Target* host); +void CheckAndUpdateHostConsistency(Map* ir_modules, Target* host); } // namespace tvm #endif // TVM_TARGET_TARGET_H_ diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index 1e906cb381d8c..f26da76f07da9 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -71,6 +71,7 @@ riscv_cpu, hexagon, ) +from .se_scope import make_se_scope from .tag import list_tags from .generic_func import GenericFunc from .generic_func import generic_func, get_native_generic_func, override_native_generic_func diff --git a/python/tvm/target/se_scope.py b/python/tvm/target/se_scope.py new file mode 100644 index 0000000000000..83df5ae3448aa --- /dev/null +++ b/python/tvm/target/se_scope.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Python bindings for creating SEScopes.""" +from . import _ffi_api + + +def make_se_scope(device, target=None, memory_scope=""): + return _ffi_api.SEScope_ForDeviceTargetAndMemoryScope(device, target, memory_scope) diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h index 12b4f6f65b11e..6869f4d372634 100644 --- a/src/ir/attr_functor.h +++ b/src/ir/attr_functor.h @@ -31,6 +31,7 @@ #define TVM_IR_ATTR_FUNCTOR_H_ #include +#include #include #include @@ -105,6 +106,7 @@ class AttrFunctor { virtual R VisitAttr_(const tir::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::CallNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::SelectNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const SEScopeNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; private: // initialize the vtable. @@ -139,6 +141,7 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(CastNode); ATTR_FUNCTOR_DISPATCH(CallNode); ATTR_FUNCTOR_DISPATCH(SelectNode); + ATTR_FUNCTOR_DISPATCH(SEScopeNode); return vtable; } }; diff --git a/src/parser/parser.cc b/src/parser/parser.cc index ebd6566889dcd..486799603354c 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1437,8 +1437,8 @@ class Parser { String attr_key = Downcast(raw_attrs["attrs_type_key"]); if (attr_key.size()) { raw_attrs.erase("attrs_type_key"); - auto tbl = tvm::ReflectionVTable::Global(); - auto attr_obj = tbl->CreateObject(attr_key, raw_attrs); + auto attr_obj = + tvm::ReflectionVTable::Global()->CreateObject(attr_key, raw_attrs); ICHECK(attr_obj.defined()); attrs = Downcast(attr_obj); } diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index ea97bb35a09f2..1b8beabae4bc7 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -773,8 +773,6 @@ Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { printed_attr << "?"; } else if (auto str_obj = value.as()) { printed_attr << Doc::StrLiteral(GetRef(str_obj)); - } else if (const auto* on_device_attrs = value.as()) { - printed_attr << "device_type=" << on_device_attrs->device_type; } else if (meta) { printed_attr = meta_->GetMetaNode(Downcast(value)); } else { @@ -787,7 +785,7 @@ Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { } Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) { - return PrintAttr(GetRef(op), true); + return PrintAttr(GetRef(op), /*meta=*/true); } Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) { @@ -814,6 +812,17 @@ Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) { return Doc::StrLiteral(op->value); } +Doc RelayTextPrinter::VisitAttr_(const SEScopeNode* op) { + if (show_meta_data_) { + return VisitAttrDefault_(op); + } else { + // TODO(mbs): Surely there's a better way? + std::ostringstream os; + os << GetRef(op); + return Doc::Text(os.str()); + } +} + /*! * \brief Attribute printer which prints the attributes in the call. */ diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 316d596317823..fc0a8741697a6 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -168,6 +168,7 @@ class RelayTextPrinter : public ExprFunctor, Doc VisitAttr_(const tir::IntImmNode* op) final; Doc VisitAttr_(const tir::FloatImmNode* op) final; Doc VisitAttr_(const tir::StringImmNode* op) final; + Doc VisitAttr_(const SEScopeNode* op) final; private: /*! \brief Whether to print meta data. */ diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 3c9c35c4f2543..7e5702296542b 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -715,7 +715,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: " << "runtime::Module mod and Map targets"; void* mod = args[0]; - Map targets = args[1]; + TargetMap targets = args[1]; init(mod, targets); }); } else if (name == "codegen") { @@ -758,7 +758,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } private: - void init(void* mod, Map tmp) { + void init(void* mod, TargetMap tmp) { tec::TargetMap targets; Target target_host; for (const auto& it : tmp) { diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 7005e94c24110..78978e192f0e4 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -43,7 +43,6 @@ Pass LabelOps(); } namespace backend { -using TargetsMap = Map; using namespace tvm::relay::transform; /*! @@ -56,7 +55,7 @@ struct BuildOutput { }; struct ExecutorCodegen { - void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); } + void Init(runtime::Module* m, TargetMap targets) { CallFunc("init", m, targets); } void Codegen(const Function& func, String mod_name) { CallFunc("codegen", func, mod_name); } @@ -278,7 +277,7 @@ class RelayBuildModule : public runtime::ModuleNode { * \param target Target device * \param target_host Host target device */ - void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host, + void Build(IRModule mod, const TargetMap& targets, const tvm::Target& target_host, const String executor, const String mod_name) { for (const auto& pair : targets) { VLOG(0) << "Build target " << pair.first << " = " << pair.second->str(); @@ -307,7 +306,7 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return relay::IRModule The updated Relay IR module after optimization. */ - IRModule Optimize(IRModule relay_module, const TargetsMap& targets, + IRModule Optimize(IRModule relay_module, const TargetMap& targets, const std::unordered_map& params) { targets_ = targets; // No target_host setup it seems. @@ -444,7 +443,7 @@ class RelayBuildModule : public runtime::ModuleNode { const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); if (!target_host.defined()) target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm"); - // Update all the targets in the targets_ TargetsMap + // Update all the targets in the targets_ TargetMap CheckAndUpdateHostConsistency(&targets_, &target_host); // Relay IRModule -> IRModule optimizations. @@ -540,7 +539,7 @@ class RelayBuildModule : public runtime::ModuleNode { protected: std::unique_ptr executor_codegen_; /*! \brief target device */ - TargetsMap targets_; + TargetMap targets_; /*! \brief target host device */ tvm::Target target_host_; /*! \brief parameters */ diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index debd669126c4c..d32ded3796886 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -619,7 +619,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: " << "runtime::Module mod and Map targets"; void* mod = args[0]; - Map tmp = args[1]; + TargetMap tmp = args[1]; tec::TargetMap targets; for (const auto& it : tmp) { auto dev_type = it.first.as(); diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index e3b7d46457add..d0401e9605f7f 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -148,7 +148,7 @@ void UpdateFunctionMetadata(Function relay_func, * \param dev_type * \return Target */ -Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); +Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets); /*! * \brief Update the "main" control function's metadata diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index a647aa1a3fd2f..f89a099b0d4fb 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -437,7 +437,7 @@ inline bool IsAutoSchedulerEnabled() { * \param is_vm A boolean indicating if the passes are used for vm or graph runtime. * \return An array of passes. */ -Array GetPassPrefix(const Map& targets, bool is_vm); +Array GetPassPrefix(const TargetMap& targets, bool is_vm); /*! \brief Target hash function */ struct TargetStrHash { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b3c1cd81274fc..908d3e6fa6f74 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -80,7 +80,6 @@ namespace vm { using namespace tvm::runtime; using namespace tvm::runtime::vm; using namespace relay::transform; -using namespace tec; // (@jroesch): VM passes, eventually declare as passes. bool IsClosure(const Function& func); @@ -251,7 +250,7 @@ int GetFallbackDevice() { class VMFunctionCompiler : DeviceAwareExprFunctor { public: - VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) + VMFunctionCompiler(VMCompilerContext* context, TargetMap targets, Target target_host) : DeviceAwareExprFunctor(context->module), last_register_(0), registers_num_(0), @@ -458,7 +457,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { void EmitShapeFunc(Function func, Array inputs, Array outputs) { // Lower shape function - CCacheKey key(func, target_host_); + tec::CCacheKey key(func, target_host_); auto cfunc = context_->compiler->LowerShapeFunc(key); int op_index = -1; // pick the only function inside the context @@ -534,7 +533,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { } } - CCacheKey key(func, target); + tec::CCacheKey key(func, target); auto mangle_fn = [](String name) { return name; }; auto cfunc = context_->compiler->Lower(key, mangle_fn); // <<<< one-func-at-a-time lowering @@ -904,7 +903,8 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } -void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { +void VMCompiler::Lower(IRModule mod, const tvm::TargetMap& targets, + const tvm::Target& target_host) { exec_ = make_object(); targets_ = targets; target_host_ = target_host; @@ -970,7 +970,7 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe backend::UpdateAutoSchedulerOpWeights(context_.compiler); } -transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { +transform::Sequential MemoryOpt(tvm::Target host_target, tvm::TargetMap targets) { Array pass_seqs; // Remove unused functions Array entry_functions{"main"}; @@ -1016,9 +1016,9 @@ transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { return transform::Sequential(pass_seqs); } -IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg, +IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets_arg, const Target& target_host_arg) { - TargetsMap targets = targets_arg; + TargetMap targets = targets_arg; Target target_host = target_host_arg; CheckAndUpdateHostConsistency(&targets, &target_host); if (params_.size()) { diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index af3c5bccbeff2..5b51d7821d78b 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -62,7 +62,6 @@ using TagNameMap = std::unordered_map; using GlobalMap = NodeMap; using ConstMap = NodeMap; using ConstTensorShapeMap = NodeMap>; -using TargetsMap = Map; struct VMCompilerContext { // The module context for the compilation @@ -111,7 +110,7 @@ class VMCompiler : public runtime::ModuleNode { * to target mapping. For homogeneous compilation, it is a singleton build target. * \param target_host Host compilation target, if target is device. */ - void Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host); + void Lower(IRModule mod, const TargetMap& targets, const tvm::Target& target_host); /*! \brief Generate the machine code for lowered functions. */ void Codegen(); @@ -127,7 +126,7 @@ class VMCompiler : public runtime::ModuleNode { * * \return The optimized IRModule. */ - IRModule OptimizeModule(IRModule mod, const TargetsMap& targets, const Target& target_host); + IRModule OptimizeModule(IRModule mod, const TargetMap& targets, const Target& target_host); /*! * \brief Populate the global function names in a map where the value is used @@ -137,7 +136,7 @@ class VMCompiler : public runtime::ModuleNode { protected: /*! \brief Target devices. */ - TargetsMap targets_; + TargetMap targets_; /*! \brief Target host device. */ tvm::Target target_host_; /*! \brief Global shared meta data */ diff --git a/src/target/se_scope.cc b/src/target/se_scope.cc new file mode 100644 index 0000000000000..4b95213d65e97 --- /dev/null +++ b/src/target/se_scope.cc @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/target/se_scope.cc + * \brief Implementation of \p SEScope for representing a Storage or Execution scope. + */ +#include +#include +#include + +namespace tvm { + +/*! \brief A cache of \p SEScopes. */ +class SEScopeCache { + public: + SEScope MakeSEScope(DLDeviceType device_type, int virtual_device_id, Target target, + String memory_scope) { + // Not the most efficient, but reducing the key to a string seems to be the simplest. + // Note this means we are effectively collapsing Targets by their str() representation. + std::ostringstream os; + os << device_type; + os << ":" << virtual_device_id; + if (target.defined()) { + os << ":'" << target->str() << "'"; + } else { + os << ":null"; + } + os << ":'" << memory_scope << "'"; + std::string key = os.str(); + auto itr = cache_.find(key); + if (itr != cache_.end()) { + return itr->second; + } + SEScope scope(device_type, virtual_device_id, std::move(target), std::move(memory_scope)); + cache_.emplace(key, scope); + VLOG(1) << "new scope \"" << key << "\" -> " << scope; + return scope; + } + + private: + std::unordered_map cache_; +}; + +/*! \brief Thread local cache of already constructed \p SEScopes. */ +using ThreadLocalSEScopeCache = dmlc::ThreadLocalStore; + +TVM_REGISTER_NODE_TYPE(SEScopeNode); + +void SEScopeNode::VisitAttrs(AttrVisitor* v) { + int i = static_cast(device_type_); + v->Visit("device_type", &i); + device_type_ = static_cast(i); + v->Visit("virtual_device_id", &virtual_device_id_); + v->Visit("target", &target_); + v->Visit("memory_scope", &memory_scope_); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = ref.as(); + p->stream << "SEScopeNode("; + if (node->is_fully_unconstrained()) { + p->stream << "?"; + } else { + bool need_sep = false; + if (node->device_type() != kInvalidDeviceType) { + p->stream << "device_type=" << node->device_type(); + need_sep = true; + } + if (node->virtual_device_id() >= 0) { + if (need_sep) { + p->stream << ", "; + } + p->stream << "virtual_device_id=" << node->virtual_device_id(); + need_sep = true; + } + if (node->target().defined()) { + if (need_sep) { + p->stream << ", "; + } + p->stream << "target='" << node->target()->str() << "'"; + need_sep = true; + } + if (!node->memory_scope().empty()) { + if (need_sep) { + p->stream << ", "; + } + p->stream << "memory_scope='" << node->memory_scope() << "'"; + } + } + p->stream << ")"; + }); + +SEScope::SEScope(DLDeviceType device_type, int virtual_device_id, Target target, + String memory_scope) { + ICHECK(!target.defined() || device_type == target->kind->device_type) + << "target '" << target->str() << "' has device type " << target->kind->device_type + << " but scope has device type " << device_type; + auto object = make_object(); + object->device_type_ = device_type; + object->virtual_device_id_ = virtual_device_id; + object->target_ = std::move(target); + object->memory_scope_ = std::move(memory_scope); + data_ = std::move(object); +} + +/* static */ +SEScope SEScope::MakeSEScope(DLDeviceType device_type, int virtual_device_id, Target target, + String memory_scope) { + return ThreadLocalSEScopeCache ::Get()->MakeSEScope(device_type, virtual_device_id, + std::move(target), std::move(memory_scope)); +} + +/* static */ +Optional SEScope::Join(const SEScope& lhs, const SEScope& rhs) { + if (lhs == rhs) { + return lhs; + } + DLDeviceType joined_device_type; + if (lhs->device_type_ != kInvalidDeviceType) { + joined_device_type = lhs->device_type_; + if (rhs->device_type_ != kInvalidDeviceType && lhs->device_type_ != rhs->device_type_) { + return {}; + } + } else { + joined_device_type = rhs->device_type_; + } + int joined_virtual_device_id; + if (lhs->virtual_device_id_ >= 0) { + joined_virtual_device_id = lhs->virtual_device_id_; + if (rhs->virtual_device_id_ >= 0 && lhs->virtual_device_id_ != rhs->virtual_device_id_) { + return {}; + } + } else { + joined_virtual_device_id = rhs->virtual_device_id_; + } + Target joined_target; + if (lhs->target_.defined()) { + joined_target = lhs->target_; + if (rhs->target_.defined() && lhs->target_ != rhs->target_) { + return {}; + } + } else { + joined_target = rhs->target_; + } + String joined_memory_scope; + if (!lhs->memory_scope_.empty()) { + joined_memory_scope = lhs->memory_scope_; + if (!rhs->memory_scope_.empty() && lhs->memory_scope_ != rhs->memory_scope_) { + return {}; + } + } else { + joined_memory_scope = rhs->memory_scope_; + } + return MakeSEScope(joined_device_type, joined_virtual_device_id, joined_target, + joined_memory_scope); +} + +/* static */ +SEScope SEScope::Default(const SEScope& lhs, const SEScope& rhs) { + if (lhs == rhs) { + return lhs; + } + DLDeviceType defaulted_device_type; + if (lhs->device_type_ != kInvalidDeviceType) { + defaulted_device_type = lhs->device_type_; + } else { + defaulted_device_type = rhs->device_type_; + } + int defaulted_virtual_device_id; + if (lhs->virtual_device_id_ >= 0) { + defaulted_virtual_device_id = lhs->virtual_device_id_; + } else { + defaulted_virtual_device_id = rhs->virtual_device_id_; + } + Target defaulted_target; + if (lhs->target_.defined()) { + defaulted_target = lhs->target_; + } else { + // We can only default to the rhs's target if it is consistent with the device type + if (rhs->target_.defined() && rhs->target_->kind->device_type == defaulted_device_type) { + defaulted_target = rhs->target_; + } + // else: leave as null + } + String defaulted_memory_scope; + if (!lhs->memory_scope_.empty()) { + defaulted_memory_scope = lhs->memory_scope_; + } else { + defaulted_memory_scope = rhs->memory_scope_; + } + return MakeSEScope(defaulted_device_type, defaulted_virtual_device_id, defaulted_target, + defaulted_memory_scope); +} + +namespace { +/*! + * \brief Returns a freshly constructed \p Target to represent \p device_type. + */ +Target MakeDefaultTarget(DLDeviceType device_type) { + std::string name = runtime::DeviceName(device_type); + if (name == "cpu") { + if (runtime::Registry::Get("codegen.LLVMModuleCreate")) { + // LLVM is available. + return Target("llvm"); + } else { + // LLVM is not available. + return Target("stackvm"); + } + } else { + return Target(name); + } +} + +/*! + * \brief Return the \p Target to use for \p device_type, possibly by stealing the \p host_target, + * or by creating a fresh target. + */ +Target FindOrAddDefault(Array* targets, const Target& optional_host_target, + DLDeviceType device_type) { + auto itr = std::find_if(targets->begin(), targets->end(), [device_type](const Target& target) { + return target->kind->device_type == device_type; + }); + if (itr == targets->end()) { + if (optional_host_target.defined() && optional_host_target->kind->device_type == device_type) { + LOG(INFO) << "Using the given host target '" << optional_host_target->str() + << "' for device type " << device_type; + targets->push_back(optional_host_target); + return optional_host_target; + } else { + Target target = MakeDefaultTarget(device_type); + LOG(WARNING) << "No target has been given for the device type " << device_type + << " in the targets list. Creating a default target '" << target->str() + << "' for that device"; + targets->push_back(target); + return target; + } + } else { + return *itr; + } +} + +/*! + * \brief Returns the default \p SEScope for primitives and the \p SEScope for the host + * given vector of available \p targets. If necessary, add new \p Targets to \p targets + * to match the required devices. + */ +std::pair EstablishDefaultSEScopes(const transform::PassContext& pass_ctx, + Array* targets, + const Target& optional_host_target) { + // + // Gather the hints as to what our default device type for primitives should be. + // + DLDeviceType default_primitive_device_type; + Optional opt_fallback_dev = pass_ctx->GetConfig("relay.fallback_device_type"); + if (opt_fallback_dev) { + const int64_t v = opt_fallback_dev.value()->value; + if (v <= 0) { + LOG(FATAL) + << "The 'relay.fallback_device_type' pass attribute is set to an invalid device type " + << v; + default_primitive_device_type = kDLCPU; + } else { + default_primitive_device_type = static_cast(v); + LOG(INFO) << "Using the 'relay.fallback_device_type' pass attribute " + << default_primitive_device_type + << " as the default device type for all primitive operations"; + } + } else if (targets->size() == 1) { + // In the homogeneous case there's no free choice. + default_primitive_device_type = static_cast(targets->front()->kind->device_type); + LOG(INFO) << "Using the unique target '" << targets->front()->str() << "' of device type " + << default_primitive_device_type + << " as the default device type for all primitive operations"; + } else { + default_primitive_device_type = kDLCPU; + LOG(WARNING) << "Using " << default_primitive_device_type + << " as the default device type for all primitive operations"; + } + + // + // Gather the hints as to what our default device type for the 'host' should be. + // + DLDeviceType host_device_type; + if (optional_host_target.defined()) { + host_device_type = static_cast(optional_host_target->kind->device_type); + if (host_device_type != kDLCPU) { + LOG(WARNING) << "Using the host target '" << optional_host_target->str() + << "' of non-CPU device type " << host_device_type + << " for all host operations and data"; + } else { + LOG(INFO) << "Using the host target '" << optional_host_target->str() << "' of device type " + << host_device_type << " for all host operations and data"; + } + } else { + host_device_type = kDLCPU; + LOG(INFO) << "Using " << host_device_type + << " as the device type for all host operations and data"; + } + + // + // Now establish default targets + // + Target default_primitive_target = + FindOrAddDefault(targets, optional_host_target, default_primitive_device_type); + Target actual_host_target = + optional_host_target.defined() + ? optional_host_target + : FindOrAddDefault(targets, optional_host_target, host_device_type); + + return {SEScope::MakeSEScope(default_primitive_device_type, + /*virtual_device_id=*/0, default_primitive_target), + SEScope::MakeSEScope(host_device_type, + /*virtual_device_id=*/0, actual_host_target)}; +} +} // namespace + +CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, + TargetMap legacy_target_map_arg, + Target optional_host_target_arg) + : legacy_target_map(std::move(legacy_target_map_arg)), + optional_host_target(std::move(optional_host_target_arg)) { + VLOG_CONTEXT << "CompilationConfig"; + for (const auto& pair : legacy_target_map) { + VLOG(0) << "Available target " << pair.first << " = '" << pair.second->str() << "'"; + } + if (optional_host_target.defined()) { + VLOG(0) << "Available host target '" << optional_host_target->str() << "'"; + } + + // Legacy: Host & primitive targets need to be consistent. + CheckAndUpdateHostConsistency(&legacy_target_map, &optional_host_target); + + // Gather the primitive targets as an ordinary vector. + for (const auto& pair : legacy_target_map) { + targets.push_back(pair.second); + } + + // Complete the targets vector and establish default scopes. After this targets_ will contain + // the definitive list of all required targets, both for host and primitives. + auto pair = EstablishDefaultSEScopes(pass_ctx, &targets, optional_host_target); + default_primitive_se_scope = pair.first; + host_se_scope = pair.second; + + ICHECK(default_primitive_se_scope->target().defined()); + ICHECK(host_se_scope->target().defined()); + ICHECK_GT(targets.size(), 0U); + + // If we added a target to targets_ for the default primitive scope then we need to do the same in + // the legacy target map. Note that we don't do the same for the host since the legacy map + // is only supposed to track the targets for primitives. I think. Also note that TargetMap is + // indexed by the *object identity* of the Integers for the device types so conveys nothing + // beyond just vector of targets. + auto itr = std::find_if(legacy_target_map.begin(), legacy_target_map.end(), + [this](const std::pair& pair) { + return pair.second->kind->device_type == + default_primitive_se_scope->device_type(); + }); + if (itr == legacy_target_map.end()) { + legacy_target_map.Set(static_cast(default_primitive_se_scope->device_type()), + default_primitive_se_scope->target()); + } + + // Legacy: Some passes only support homogenous compilation and expect the target to be + // given by the global target context. + homogeneous_target = + legacy_target_map.size() == 1 ? (*legacy_target_map.begin()).second : Target(); + + for (const auto& target : targets) { + VLOG(0) << "Established build target " << target->kind->device_type << " = '" << target->str() + << "'"; + } + VLOG(0) << "Established default primitive SEScope " << default_primitive_se_scope; + VLOG(0) << "Established host SEScope " << host_se_scope; +} + +TVM_REGISTER_GLOBAL("target.SEScope_ForDeviceTargetAndMemoryScope") + .set_body_typed(SEScope::ForDeviceTargetAndMemoryScope); + +} // namespace tvm diff --git a/src/target/target.cc b/src/target/target.cc index e0b9539380d79..60c6303412e46 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -74,7 +74,7 @@ void CheckAndUpdateHostConsistency(Target* target, Target* host) { *host = (*target)->GetHost().value_or(Target()); } -void CheckAndUpdateHostConsistency(Map* targets, Target* host) { +void CheckAndUpdateHostConsistency(TargetMap* targets, Target* host) { Map new_targets; for (auto& it : *targets) { auto target = it.second; diff --git a/tests/cpp/target/se_scope_test.cc b/tests/cpp/target/se_scope_test.cc new file mode 100644 index 0000000000000..af56be2f909a8 --- /dev/null +++ b/tests/cpp/target/se_scope_test.cc @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +namespace tvm { +namespace target { +namespace { + +TEST(SEScope, MemoizedConstructors) { + Target target_a = Target("cuda"); + Target target_b = Target("llvm"); + SEScope se_scope_a = SEScope::MakeSEScope(kDLCUDA, 3, target_a, "local"); + SEScope se_scope_b = SEScope::MakeSEScope(kDLCPU, 1, target_b, "global"); + + EXPECT_EQ(SEScope::MakeSEScope(kDLCUDA, 3, target_a, "local"), se_scope_a); + EXPECT_EQ(SEScope::MakeSEScope(kDLCPU, 1, target_b, "global"), se_scope_b); + EXPECT_NE(SEScope::MakeSEScope(kDLCUDA, 2, target_a, "local"), se_scope_a); + EXPECT_NE(SEScope::MakeSEScope(kDLCPU, 3, target_b, "local"), se_scope_a); + EXPECT_NE(SEScope::MakeSEScope(kDLCUDA, 3, target_a, "global"), se_scope_a); +} + +TEST(SEScope, Join_Defined) { + Target target_a = Target("cuda"); + SEScope lhs = SEScope::MakeSEScope(kDLCUDA, 3); + SEScope rhs = SEScope::MakeSEScope(kDLCUDA, -1, target_a, "global"); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + SEScope expected = SEScope::MakeSEScope(kDLCUDA, 3, target_a, "global"); + EXPECT_EQ(actual.value(), expected); +} + +TEST(SEScope, Join_Undefined) { + SEScope lhs = SEScope::MakeSEScope(kDLCUDA, 3); + SEScope rhs = SEScope::MakeSEScope(kDLCUDA, 4); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_FALSE(actual); +} + +TEST(SEScope, Default) { + Target target_a = Target("cuda"); + SEScope lhs = SEScope::MakeSEScope(kDLCUDA, -1, Target(), "global"); + SEScope rhs = SEScope::MakeSEScope(kDLCUDA, 3, target_a, "local"); + SEScope actual = SEScope::Default(lhs, rhs); + SEScope expected = SEScope::MakeSEScope(kDLCUDA, 3, target_a, "global"); + EXPECT_EQ(actual, expected); +} + +TEST(CompilationConfig, Constructor) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("relay.fallback_device_type", Integer(static_cast(kDLCUDA))); + + Target cuda_target = Target("nvidia/tesla-p40"); + Target default_cpu_target = Target("llvm"); + + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/Target()); + + EXPECT_EQ(config.legacy_target_map.size(), 1); + EXPECT_EQ((*config.legacy_target_map.begin()).second->str(), cuda_target->str()); + EXPECT_FALSE(config.optional_host_target.defined()); + EXPECT_EQ(config.targets.size(), 2); + EXPECT_EQ(config.targets[0]->str(), cuda_target->str()); + EXPECT_EQ(config.targets[1]->str(), default_cpu_target->str()); + EXPECT_EQ(config.default_primitive_se_scope->device_type(), kDLCUDA); + EXPECT_EQ(config.default_primitive_se_scope->target()->str(), cuda_target->str()); + EXPECT_EQ(config.host_se_scope->device_type(), kDLCPU); + EXPECT_EQ(config.host_se_scope->target()->str(), default_cpu_target->str()); + EXPECT_EQ(config.homogeneous_target->str(), cuda_target->str()); +} + +} // namespace +} // namespace target +} // namespace tvm diff --git a/tests/python/target/test_se_scope.py b/tests/python/target/test_se_scope.py new file mode 100644 index 0000000000000..0a9384fa9c046 --- /dev/null +++ b/tests/python/target/test_se_scope.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import pytest +import tvm + + +def test_make_se_scope_for_device(): + se_scope = tvm.target.make_se_scope(tvm.device("cuda")) + assert se_scope.device_type == 2 + # ie kDLCUDA + assert se_scope.virtual_device_id == 0 + assert se_scope.target is None + assert se_scope.memory_scope == "" + + +def test_make_se_scope_for_device_and_target(): + target = tvm.target.Target("cuda") + se_scope = tvm.target.make_se_scope(tvm.device("cuda"), target) + assert se_scope.device_type == 2 # ie kDLCUDA + assert se_scope.target == target + assert se_scope.memory_scope == "" + + +def test_make_se_scope_for_device_target_and_memory_scope(): + target = tvm.target.Target("cuda") + scope = "local" + se_scope = tvm.target.make_se_scope(tvm.device("cuda"), target, scope) + assert se_scope.device_type == 2 # ie kDLCUDA + assert se_scope.target == target + assert se_scope.memory_scope == scope + + +if __name__ == "__main__": + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:]))