diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 715c96eb6ea52..f6c15f9590df0 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -489,7 +489,7 @@ struct AttrInitEntry { ~AttrInitEntry() DMLC_THROW_EXCEPTION { if (value_missing_) { std::ostringstream os; - os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization." + os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization. " << "If the key is defined check that its type matches the declared type."; throw AttrError(os.str()); } @@ -806,7 +806,7 @@ class AttrsNode : public BaseAttrsNode { ICHECK_EQ(args.size() % 2, 0); const int kLinearSearchBound = 16; int hit_count = 0; - // applies two stratgies to lookup + // applies two strategies to lookup if (args.size() < kLinearSearchBound) { // linear search. auto ffind = [&args](const char* key, runtime::TVMArgValue* val) { diff --git a/include/tvm/target/se_scope.h b/include/tvm/target/se_scope.h new file mode 100644 index 0000000000000..5573fd1f22ea3 --- /dev/null +++ b/include/tvm/target/se_scope.h @@ -0,0 +1,384 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/target/se_scope.h + * \brief A compile time representation for a Storage or Execution Scope. + */ + +#ifndef TVM_TARGET_SE_SCOPE_H_ +#define TVM_TARGET_SE_SCOPE_H_ + +#include +#include + +#include +#include + +namespace tvm { + +class SEScope; + +/*! + * \brief Describes at compile time where data is to be stored down to the device and memory + * scope level, or where execution is to take place, down to the device level. It is a quadruple of: + * - A \p device_type (\p DLDeviceType). + * - An uninterpreted \p virtual_device_id (\p int) distinguishing the intended device from all + * other devices (either of the same \p device_type, or across all available devices in the + * system). The \p virtual_device_id may be left as 0 if not significant. It is up to downstream + * compilation passes and/or the runtime to map a \p virtual_device_id to an actual physical + * device if required. In particular the \p virtual_device_id need not correspond exactly to + * any runtime \p Device's \p device_id. + * - A \p target (\p Target) describing how to compile code for the intended device. The + * \p target->kind->device_type must match the above \p device_type. + * - A \p memory_scope (currently just \p String) describing which memory area is to be used to + * hold data. The area should be reachable from the device but need not be 'on' the device, + * see below. (We're using a \p String for now but would prefer a more structured representation.) + * + * All of these fields may be 'unconstrained', signaling that device planning is free to choose + * a value consistent with the whole program. However if a \p target is given then the + * \p device_type must equal \p target->kind->device_type. + * + * Note that currently we assume if a function returns its result on a particular device + * then the function body is also executed on that device. See the overview comment in + * src/relay/transforms/device_planner.cc for more details. 
+ * + * By 'data' we include both tensors and additional supporting datastructures such as shapes, + * Relay AST items, Relay tuples, and Relay references. Typically non-tensor data must reside + * on a 'CPU'-like device with good support for scalars. + * + * By 'execution' we include both (fused) primitive operators, and all the Relay expressions + * surrounding them which coordinate data and control flow. Again, typically non-primitive + * operators must be executed on a 'CPU'-like device with good support for control flow. + * + * Targets vs Devices + * ------------------ + * Generally \p Targets (a compile-time only datastructure) describe compiler options for a specific + * microarchitecture and toolchain, while \p Devices (a runtime datastructure also available at + * compile time) describe a physical device on the target system. Obviously the target must agree + * with the device's microarchitecture, but we otherwise don't impose any constraints between them: + * - It's ok to use different \p Targets for the same \p Device, eg to squeeze some extra perf + * out of a particular primitive. + * - It's ok to use the same \p Target for multiple \p Devices, eg if we have multiple CPUs. + * + * Traditionally TVM assumes at most one \p Target per \p DLDeviceType. We are moving away from that + * assumption. + * + * Memory scopes and devices + * ------------------------- + * Multi-device systems can have complex memory hierarchies. For example + * \code + * (kDLCPU, 0, "llvm", "global") + * \endcode + * and + * \code + * (kDLCPU, 1, "llvm", "global") + * \endcode + * could denote: + * - The same memory area accessible from two separate CPUs without any CPU affinity; + * - Distinct memory areas in a NUMA architecture for which cross-device access is handled + * by the memory system; + * - Outright distinct memory areas, where one device cannot directly address the memory of + * another. 
+ * + * Similarly: + * \code + * (kDLCPU, 0, "llvm", "global") + * \endcode + * and + * \code + * (kDLCUDA, 0, "cuda", "host") + * \endcode + * could denote the same memory area, but with very different access costs. + * + * We don't currently try to build any of this system-level understanding into \p SEScope. Device + * planning will simply insert "device_copy" operators wherever \p SEScopes are not exactly + * pointwise equal, and we leave it to downstream compilation to elide unnecessary copies. We + * may revisit this in the future. + * + * Object identity + * --------------- + * \p SEScopes can only be constructed by the memoizing helpers. This means code can assume + * \code + * se_scope1 != se_scope2 => se_scope1 and se_scope2 differ pointwise + * \endcode + * This simplifies the device planner which needs to solve equality constraints between \p SEScopes. + * + * Joining and Defaulting + * ---------------------- + * It is possible to 'join' two \p SEScopes to yield the most constrained \p SEScope which agrees + * with both join arguments. Eg: + * \code + * Join((kDLCPU, -1, "llvm", ""), (kInvalidDeviceType, 3, null, "global")) + * => (kDLCPU, 3, "llvm", "global") + * Join((kDLCPU, -1, "llvm", "global"), (kInvalidDeviceType, 3, null, "local")) + * => null (no join possible) + * \endcode + * + * Related to 'join' is 'default', which only takes constrained fields from the rhs when the + * lhs is unconstrained: + * \code + * Default((kDLCPU, -1, "llvm", "local"), (kDLCPU, 3, null, "global")) + * => (kDLCPU, 3, "llvm", "local") + * \endcode + * + * These operations are needed during device planning. + * + */ +class SEScopeNode : public Object { + public: + /*! + * \brief The \p DLDeviceType of the device. If \p target is known then this will be equal to + * \p target->kind->device_type. If \p target is null then the target is to be determined by + * a later pass. 
+ * + * This is needed to support the legacy "on_device" and "device_copy" calls which only allow + * a \p DLDeviceType (as an integer) to be given. + * + * kInvalidDeviceType denotes unconstrained. + */ + DLDeviceType device_type() const { return device_type_; } + + /*! + * \brief The 'virtual' device identifier for the device. This must be resolved to a physical + * device identifier either during compilation or at runtime. + * + * -1 denotes unconstrained. May be 0 if not significant. + */ + int virtual_device_id() const { return virtual_device_id_; } + + /*! + * \brief The \p Target describing how to compile for the device. + * + * Null denotes unconstrained (though if device_type is known then only a target of that + * type is allowed). + */ + const Target& target() const { return target_; } + + /*! + * \brief The scope of memory within the device. + * + * Empty denotes unconstrained. + */ + // TODO(mbs): We are using String as a stand-in pending a more structured representation, such + // as runtime::StorageScope or a memory pool. + const String& memory_scope() const { return memory_scope_; } + + /*! + * \brief Returns true if scope is fully unconstrained, ie no target/device type, virtual device + * id or memory scope is specified. + */ + bool is_fully_unconstrained() const { + return !target_.defined() && device_type_ == kInvalidDeviceType && virtual_device_id_ == -1 && + memory_scope_.empty(); + } + + /*! + * \brief Returns true if scope is fully constrained, ie target, virtual device id and + * memory scope are all specified. 
+ */ + bool is_fully_constrained() const { + return target_.defined() && virtual_device_id_ != -1 && !memory_scope_.empty(); + } + + Device ToDevice() const { + ICHECK(device_type_ != kInvalidDeviceType); + ICHECK(virtual_device_id_ != -1); + Device device; + device.device_type = device_type_; + device.device_id = virtual_device_id_; + return device; + } + + void VisitAttrs(AttrVisitor* v); + + bool SEqualReduce(const SEScopeNode* other, SEqualReducer equal) const { + // Since we memoize all constructors we can just use pointer equality. + return this == other; + } + + void SHashReduce(SHashReducer hash_reduce) const { + // Since we memoize all constructors we can just use the pointer hash + hash_reduce->SHashReduceHashedValue(std::hash()(this)); + } + + static constexpr const char* _type_key = "SEScope"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(SEScopeNode, Object); + + private: + // We keep the fields private so the constructor memoization can't be upset by mutation. + DLDeviceType device_type_ = kInvalidDeviceType; + int virtual_device_id_ = -1; + Target target_{nullptr}; + String memory_scope_; // = "" + + friend class SEScope; +}; + +/*! + * \brief Managed reference class to \p SEScopeNode. + * + * \sa SEScopeNode. + */ +class SEScope : public ObjectRef { + private: + /*! + * \brief Construct an SEScope. + * \param device_type The device type for the device, or kInvalidDeviceType if unconstrained. + * If \p target is defined then must match its \p target->kind->device_type. + * \param virtual_device_id The virtual device id for the device, or -1 if unconstrained. + * \param target The target describing how to compile for the device, or null if unconstrained. + * \param memory_scope The memory scope within the device, or "" if unconstrained. 
+ * \return The SEScope + * + * This constructor is private -- use the memoizing smart constructors below. + */ + explicit SEScope(DLDeviceType device_type, int virtual_device_id, Target target, + String memory_scope); + + public: + /*! + * \brief Returns the unique \p SEScope object for \p device_type, \p virtual_device_id, \p + * target, and \p memory_scope. Any/all of these fields may be unconstrained as per their default + * values. However if \p target is defined then \p device_type must be + * \p target->kind->device_type. + */ + static SEScope MakeSEScope(DLDeviceType device_type = kInvalidDeviceType, + int virtual_device_id = -1, Target target = {}, + String memory_scope = {}); + + /*! \brief Returns the unique fully unconstrained \p SEScope. */ + static SEScope FullyUnconstrained() { return MakeSEScope(); } + + /*! + * \brief Returns the unique \p SEScope for \p device_type and (if not -1) \p virtual_device_id. + * The target and memory scope will be unconstrained. + */ + static SEScope ForDeviceType(DLDeviceType device_type, int virtual_device_id = -1) { + ICHECK_GT(device_type, 0); + return MakeSEScope(device_type, virtual_device_id); + } + static SEScope ForDeviceType(int device_type, int virtual_device_id = -1) { + return ForDeviceType(static_cast(device_type), virtual_device_id); + } + static SEScope ForDeviceType(const Integer& device_type, int virtual_device_id = -1) { + return ForDeviceType(static_cast(device_type->value), virtual_device_id); + } + + /*! \brief Returns the unique \p SEScope for \p device. */ + TVM_DLL static SEScope ForDevice(const Device& device) { + return ForDeviceType(device.device_type, device.device_id); + } + + /*! \brief Returns the unique \p SEScope for \p device and \p target. */ + TVM_DLL static SEScope ForDeviceAndTarget(const Device& device, Target target) { + return MakeSEScope(device.device_type, device.device_id, std::move(target)); + } + + /*! 
\brief Returns the unique \p SEScope for \p device, \p target and \p memory_scope. */ + TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target, + String memory_scope) { + return MakeSEScope(device.device_type, device.device_id, std::move(target), + std::move(memory_scope)); + } + + /*! + * \brief Returns the 'join' of \p lhs and \p rhs. The result will agree pointwise with + * \p lhs and \p rhs on all their constrained fields. Returns the null optional if no such + * join exists, ie there's disagreement on at least one constrained field. + */ + static Optional Join(const SEScope& lhs, const SEScope& rhs); + + /*! + * \brief Returns the 'default' of \p lhs and \p rhs. The result will be \p lhs, except any + * unconstrained fields in \p lhs will take their value from \p rhs. Always well-defined. + */ + static SEScope Default(const SEScope& lhs, const SEScope& rhs); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SEScope, ObjectRef, SEScopeNode); + + friend class SEScopeCache; // Private implementation helper. +}; + +/*! + * \brief Gathers the targets and scopes needed to compile a Relay module, and centralizes + * the target checking and defaulting logic. + * + * TODO(mbs): This is a temporary class to help us bridge legacy and new target/device handling + * and reduce code dup between VM (relay/backend/vm/compile.cc), graph/AOT + * (relay/backend/build_module.cc) and interpreter (relay/backend/interpreter.cc) 'executors'. It + * should probably get merged into something more sensible, ideally just THE standard compilation + * flow once we have one. + */ +struct CompilationConfig { + /*! + * \brief The legacy targets map, mapping device type to \p Targets. Does not include any + * entry for the host target. Intended to give a unique \p Target for every \p DLDeviceType, + * though we want to get rid of that limitation. + * + * CAUTION: Since keys are \p Integers they are compared by object equality not integer + * value. 
+ * + * TODO(mbs): Remove once codegen updated for new target conventions. + */ + TargetMap legacy_target_map; + + /*! \brief The optional host target. Used for 'scalar' data and code (such as shapes and shape + * functions) and residual Relay expressions and data (such as conditionals and ADTs). */ + Target optional_host_target; + + /*! + * \brief Vector of all available targets, including for primitive operators, host, and any + * default targets added for required device types. + */ + Array targets; + + /*! + * \brief \p SEScope for primitive operators which are not otherwise constrained to a particular + * device. + */ + SEScope default_primitive_se_scope = SEScope::FullyUnconstrained(); + + /*! \brief SEScope for the host. */ + SEScope host_se_scope = SEScope::FullyUnconstrained(); + + /*! + * \brief If defined then in 'homogeneous execution mode' and all primitives will be compiled + * for this target. This is to support legacy passes which have not been adapted to heterogeneous + * execution. + */ + Target homogeneous_target; + + CompilationConfig() = default; + + /*! + * \brief Constructs the compilation config given the available \p Targets in the + * \p legacy_target_map_arg and an optional \p optional_host_target_arg. May use + * 'relay.fallback_device_type' and the availability of the LLVM compilation module + * to decide on appropriate default devices. + */ + CompilationConfig(const transform::PassContext& pass_ctx, TargetMap legacy_target_map_arg, + Target optional_host_target_arg); +}; + +} // namespace tvm + +#endif // TVM_TARGET_SE_SCOPE_H_ diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 64a1023158e1e..660ec63582bf2 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -179,6 +179,9 @@ class Target : public ObjectRef { */ TVM_DLL void ExitWithScope(); }; + +using TargetMap = Map; + /*! * \brief Check and update host field of the given legacy target and target host pair. 
* Note that this function is for legacy target api compatibility issue only, not @@ -187,22 +190,24 @@ class Target : public ObjectRef { * \param host The pointer to a Target typed object for target host to be updated */ void CheckAndUpdateHostConsistency(Target* target, Target* host); + /*! * \brief Check and update host field of the given legacy heterogeneous targets and * target host.Note that this function is for legacy target api compatibility issue only, * not recommended for other use. - * \param target The pointer to a Map objects with values being Target objects + * \param target_map The pointer to a Map objects with values being Target objects * \param host The Target typed object for target host to be updated */ -void CheckAndUpdateHostConsistency(Map* target, Target* host); +void CheckAndUpdateHostConsistency(TargetMap* target_map, Target* host); + /*! * \brief Check and update host field of the given legacy heterogeneous targets and * target host.Note that this function is for legacy target api compatibility issue only, * not recommended for other use. 
- * \param target The pointer to a Map objects with keys being Target objects + * \param ir_modules The pointer to a Map objects with keys being Target objects * \param host The Target typed object for target host to be updated */ -void CheckAndUpdateHostConsistency(Map* target, Target* host); +void CheckAndUpdateHostConsistency(Map* ir_modules, Target* host); } // namespace tvm #endif // TVM_TARGET_TARGET_H_ diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index 1e906cb381d8c..f26da76f07da9 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -71,6 +71,7 @@ riscv_cpu, hexagon, ) +from .se_scope import make_se_scope from .tag import list_tags from .generic_func import GenericFunc from .generic_func import generic_func, get_native_generic_func, override_native_generic_func diff --git a/python/tvm/target/se_scope.py b/python/tvm/target/se_scope.py new file mode 100644 index 0000000000000..83df5ae3448aa --- /dev/null +++ b/python/tvm/target/se_scope.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Python bindings for creating SEScopes.""" +from . 
import _ffi_api + + +def make_se_scope(device, target=None, memory_scope=""): + return _ffi_api.SEScope_ForDeviceTargetAndMemoryScope(device, target, memory_scope) diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h index 12b4f6f65b11e..6869f4d372634 100644 --- a/src/ir/attr_functor.h +++ b/src/ir/attr_functor.h @@ -31,6 +31,7 @@ #define TVM_IR_ATTR_FUNCTOR_H_ #include +#include #include #include @@ -105,6 +106,7 @@ class AttrFunctor { virtual R VisitAttr_(const tir::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::CallNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::SelectNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const SEScopeNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; private: // initialize the vtable. @@ -139,6 +141,7 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(CastNode); ATTR_FUNCTOR_DISPATCH(CallNode); ATTR_FUNCTOR_DISPATCH(SelectNode); + ATTR_FUNCTOR_DISPATCH(SEScopeNode); return vtable; } }; diff --git a/src/parser/parser.cc b/src/parser/parser.cc index ebd6566889dcd..486799603354c 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1437,8 +1437,8 @@ class Parser { String attr_key = Downcast(raw_attrs["attrs_type_key"]); if (attr_key.size()) { raw_attrs.erase("attrs_type_key"); - auto tbl = tvm::ReflectionVTable::Global(); - auto attr_obj = tbl->CreateObject(attr_key, raw_attrs); + auto attr_obj = + tvm::ReflectionVTable::Global()->CreateObject(attr_key, raw_attrs); ICHECK(attr_obj.defined()); attrs = Downcast(attr_obj); } diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index ea97bb35a09f2..1b8beabae4bc7 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -773,8 +773,6 @@ Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { printed_attr << "?"; } else if (auto str_obj = value.as()) { printed_attr << Doc::StrLiteral(GetRef(str_obj)); - } else if (const auto* 
on_device_attrs = value.as()) { - printed_attr << "device_type=" << on_device_attrs->device_type; } else if (meta) { printed_attr = meta_->GetMetaNode(Downcast(value)); } else { @@ -787,7 +785,7 @@ Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { } Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) { - return PrintAttr(GetRef(op), true); + return PrintAttr(GetRef(op), /*meta=*/true); } Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) { @@ -814,6 +812,17 @@ Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) { return Doc::StrLiteral(op->value); } +Doc RelayTextPrinter::VisitAttr_(const SEScopeNode* op) { + if (show_meta_data_) { + return VisitAttrDefault_(op); + } else { + // TODO(mbs): Surely there's a better way? + std::ostringstream os; + os << GetRef(op); + return Doc::Text(os.str()); + } +} + /*! * \brief Attribute printer which prints the attributes in the call. */ diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 316d596317823..fc0a8741697a6 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -168,6 +168,7 @@ class RelayTextPrinter : public ExprFunctor, Doc VisitAttr_(const tir::IntImmNode* op) final; Doc VisitAttr_(const tir::FloatImmNode* op) final; Doc VisitAttr_(const tir::StringImmNode* op) final; + Doc VisitAttr_(const SEScopeNode* op) final; private: /*! \brief Whether to print meta data. 
*/ diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 3c9c35c4f2543..7e5702296542b 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -715,7 +715,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: " << "runtime::Module mod and Map targets"; void* mod = args[0]; - Map targets = args[1]; + TargetMap targets = args[1]; init(mod, targets); }); } else if (name == "codegen") { @@ -758,7 +758,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } private: - void init(void* mod, Map tmp) { + void init(void* mod, TargetMap tmp) { tec::TargetMap targets; Target target_host; for (const auto& it : tmp) { diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 7005e94c24110..78978e192f0e4 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -43,7 +43,6 @@ Pass LabelOps(); } namespace backend { -using TargetsMap = Map; using namespace tvm::relay::transform; /*! 
@@ -56,7 +55,7 @@ struct BuildOutput { }; struct ExecutorCodegen { - void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); } + void Init(runtime::Module* m, TargetMap targets) { CallFunc("init", m, targets); } void Codegen(const Function& func, String mod_name) { CallFunc("codegen", func, mod_name); } @@ -278,7 +277,7 @@ class RelayBuildModule : public runtime::ModuleNode { * \param target Target device * \param target_host Host target device */ - void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host, + void Build(IRModule mod, const TargetMap& targets, const tvm::Target& target_host, const String executor, const String mod_name) { for (const auto& pair : targets) { VLOG(0) << "Build target " << pair.first << " = " << pair.second->str(); @@ -307,7 +306,7 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return relay::IRModule The updated Relay IR module after optimization. */ - IRModule Optimize(IRModule relay_module, const TargetsMap& targets, + IRModule Optimize(IRModule relay_module, const TargetMap& targets, const std::unordered_map& params) { targets_ = targets; // No target_host setup it seems. @@ -444,7 +443,7 @@ class RelayBuildModule : public runtime::ModuleNode { const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); if (!target_host.defined()) target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm"); - // Update all the targets in the targets_ TargetsMap + // Update all the targets in the targets_ TargetMap CheckAndUpdateHostConsistency(&targets_, &target_host); // Relay IRModule -> IRModule optimizations. @@ -540,7 +539,7 @@ class RelayBuildModule : public runtime::ModuleNode { protected: std::unique_ptr executor_codegen_; /*! \brief target device */ - TargetsMap targets_; + TargetMap targets_; /*! \brief target host device */ tvm::Target target_host_; /*! 
\brief parameters */ diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index debd669126c4c..d32ded3796886 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -619,7 +619,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: " << "runtime::Module mod and Map targets"; void* mod = args[0]; - Map tmp = args[1]; + TargetMap tmp = args[1]; tec::TargetMap targets; for (const auto& it : tmp) { auto dev_type = it.first.as(); diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index e3b7d46457add..d0401e9605f7f 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -148,7 +148,7 @@ void UpdateFunctionMetadata(Function relay_func, * \param dev_type * \return Target */ -Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); +Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets); /*! * \brief Update the "main" control function's metadata diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index a647aa1a3fd2f..f89a099b0d4fb 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -437,7 +437,7 @@ inline bool IsAutoSchedulerEnabled() { * \param is_vm A boolean indicating if the passes are used for vm or graph runtime. * \return An array of passes. */ -Array GetPassPrefix(const Map& targets, bool is_vm); +Array GetPassPrefix(const TargetMap& targets, bool is_vm); /*! 
\brief Target hash function */ struct TargetStrHash { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b3c1cd81274fc..908d3e6fa6f74 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -80,7 +80,6 @@ namespace vm { using namespace tvm::runtime; using namespace tvm::runtime::vm; using namespace relay::transform; -using namespace tec; // (@jroesch): VM passes, eventually declare as passes. bool IsClosure(const Function& func); @@ -251,7 +250,7 @@ int GetFallbackDevice() { class VMFunctionCompiler : DeviceAwareExprFunctor { public: - VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) + VMFunctionCompiler(VMCompilerContext* context, TargetMap targets, Target target_host) : DeviceAwareExprFunctor(context->module), last_register_(0), registers_num_(0), @@ -458,7 +457,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { void EmitShapeFunc(Function func, Array inputs, Array outputs) { // Lower shape function - CCacheKey key(func, target_host_); + tec::CCacheKey key(func, target_host_); auto cfunc = context_->compiler->LowerShapeFunc(key); int op_index = -1; // pick the only function inside the context @@ -534,7 +533,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { } } - CCacheKey key(func, target); + tec::CCacheKey key(func, target); auto mangle_fn = [](String name) { return name; }; auto cfunc = context_->compiler->Lower(key, mangle_fn); // <<<< one-func-at-a-time lowering @@ -904,7 +903,8 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } -void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { +void VMCompiler::Lower(IRModule mod, const tvm::TargetMap& targets, + const tvm::Target& target_host) { exec_ = make_object(); targets_ = targets; target_host_ = target_host; @@ -970,7 +970,7 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, 
const tvm::Targe backend::UpdateAutoSchedulerOpWeights(context_.compiler); } -transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { +transform::Sequential MemoryOpt(tvm::Target host_target, tvm::TargetMap targets) { Array pass_seqs; // Remove unused functions Array entry_functions{"main"}; @@ -1016,9 +1016,9 @@ transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { return transform::Sequential(pass_seqs); } -IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg, +IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets_arg, const Target& target_host_arg) { - TargetsMap targets = targets_arg; + TargetMap targets = targets_arg; Target target_host = target_host_arg; CheckAndUpdateHostConsistency(&targets, &target_host); if (params_.size()) { diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index af3c5bccbeff2..5b51d7821d78b 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -62,7 +62,6 @@ using TagNameMap = std::unordered_map; using GlobalMap = NodeMap; using ConstMap = NodeMap; using ConstTensorShapeMap = NodeMap>; -using TargetsMap = Map; struct VMCompilerContext { // The module context for the compilation @@ -111,7 +110,7 @@ class VMCompiler : public runtime::ModuleNode { * to target mapping. For homogeneous compilation, it is a singleton build target. * \param target_host Host compilation target, if target is device. */ - void Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host); + void Lower(IRModule mod, const TargetMap& targets, const tvm::Target& target_host); /*! \brief Generate the machine code for lowered functions. */ void Codegen(); @@ -127,7 +126,7 @@ class VMCompiler : public runtime::ModuleNode { * * \return The optimized IRModule. 
*/ - IRModule OptimizeModule(IRModule mod, const TargetsMap& targets, const Target& target_host); + IRModule OptimizeModule(IRModule mod, const TargetMap& targets, const Target& target_host); /*! * \brief Populate the global function names in a map where the value is used @@ -137,7 +136,7 @@ class VMCompiler : public runtime::ModuleNode { protected: /*! \brief Target devices. */ - TargetsMap targets_; + TargetMap targets_; /*! \brief Target host device. */ tvm::Target target_host_; /*! \brief Global shared meta data */ diff --git a/src/target/se_scope.cc b/src/target/se_scope.cc new file mode 100644 index 0000000000000..4b95213d65e97 --- /dev/null +++ b/src/target/se_scope.cc @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/target/se_scope.cc + * \brief Implementation of \p SEScope for representing a Storage or Execution scope. + */ +#include +#include +#include + +namespace tvm { + +/*! \brief A cache of \p SEScopes. */ +class SEScopeCache { + public: + SEScope MakeSEScope(DLDeviceType device_type, int virtual_device_id, Target target, + String memory_scope) { + // Not the most efficient, but reducing the key to a string seems to be the simplest. 
+ // Note this means we are effectively collapsing Targets by their str() representation. + std::ostringstream os; + os << device_type; + os << ":" << virtual_device_id; + if (target.defined()) { + os << ":'" << target->str() << "'"; + } else { + os << ":null"; + } + os << ":'" << memory_scope << "'"; + std::string key = os.str(); + auto itr = cache_.find(key); + if (itr != cache_.end()) { + return itr->second; + } + SEScope scope(device_type, virtual_device_id, std::move(target), std::move(memory_scope)); + cache_.emplace(key, scope); + VLOG(1) << "new scope \"" << key << "\" -> " << scope; + return scope; + } + + private: + std::unordered_map cache_; +}; + +/*! \brief Thread local cache of already constructed \p SEScopes. */ +using ThreadLocalSEScopeCache = dmlc::ThreadLocalStore; + +TVM_REGISTER_NODE_TYPE(SEScopeNode); + +void SEScopeNode::VisitAttrs(AttrVisitor* v) { + int i = static_cast(device_type_); + v->Visit("device_type", &i); + device_type_ = static_cast(i); + v->Visit("virtual_device_id", &virtual_device_id_); + v->Visit("target", &target_); + v->Visit("memory_scope", &memory_scope_); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = ref.as(); + p->stream << "SEScopeNode("; + if (node->is_fully_unconstrained()) { + p->stream << "?"; + } else { + bool need_sep = false; + if (node->device_type() != kInvalidDeviceType) { + p->stream << "device_type=" << node->device_type(); + need_sep = true; + } + if (node->virtual_device_id() >= 0) { + if (need_sep) { + p->stream << ", "; + } + p->stream << "virtual_device_id=" << node->virtual_device_id(); + need_sep = true; + } + if (node->target().defined()) { + if (need_sep) { + p->stream << ", "; + } + p->stream << "target='" << node->target()->str() << "'"; + need_sep = true; + } + if (!node->memory_scope().empty()) { + if (need_sep) { + p->stream << ", "; + } + p->stream << "memory_scope='" << node->memory_scope() << "'"; + } + } + p->stream 
<< ")"; + }); + +SEScope::SEScope(DLDeviceType device_type, int virtual_device_id, Target target, + String memory_scope) { + ICHECK(!target.defined() || device_type == target->kind->device_type) + << "target '" << target->str() << "' has device type " << target->kind->device_type + << " but scope has device type " << device_type; + auto object = make_object(); + object->device_type_ = device_type; + object->virtual_device_id_ = virtual_device_id; + object->target_ = std::move(target); + object->memory_scope_ = std::move(memory_scope); + data_ = std::move(object); +} + +/* static */ +SEScope SEScope::MakeSEScope(DLDeviceType device_type, int virtual_device_id, Target target, + String memory_scope) { + return ThreadLocalSEScopeCache ::Get()->MakeSEScope(device_type, virtual_device_id, + std::move(target), std::move(memory_scope)); +} + +/* static */ +Optional SEScope::Join(const SEScope& lhs, const SEScope& rhs) { + if (lhs == rhs) { + return lhs; + } + DLDeviceType joined_device_type; + if (lhs->device_type_ != kInvalidDeviceType) { + joined_device_type = lhs->device_type_; + if (rhs->device_type_ != kInvalidDeviceType && lhs->device_type_ != rhs->device_type_) { + return {}; + } + } else { + joined_device_type = rhs->device_type_; + } + int joined_virtual_device_id; + if (lhs->virtual_device_id_ >= 0) { + joined_virtual_device_id = lhs->virtual_device_id_; + if (rhs->virtual_device_id_ >= 0 && lhs->virtual_device_id_ != rhs->virtual_device_id_) { + return {}; + } + } else { + joined_virtual_device_id = rhs->virtual_device_id_; + } + Target joined_target; + if (lhs->target_.defined()) { + joined_target = lhs->target_; + if (rhs->target_.defined() && lhs->target_ != rhs->target_) { + return {}; + } + } else { + joined_target = rhs->target_; + } + String joined_memory_scope; + if (!lhs->memory_scope_.empty()) { + joined_memory_scope = lhs->memory_scope_; + if (!rhs->memory_scope_.empty() && lhs->memory_scope_ != rhs->memory_scope_) { + return {}; + } + } else { + 
joined_memory_scope = rhs->memory_scope_; + } + return MakeSEScope(joined_device_type, joined_virtual_device_id, joined_target, + joined_memory_scope); +} + +/* static */ +SEScope SEScope::Default(const SEScope& lhs, const SEScope& rhs) { + if (lhs == rhs) { + return lhs; + } + DLDeviceType defaulted_device_type; + if (lhs->device_type_ != kInvalidDeviceType) { + defaulted_device_type = lhs->device_type_; + } else { + defaulted_device_type = rhs->device_type_; + } + int defaulted_virtual_device_id; + if (lhs->virtual_device_id_ >= 0) { + defaulted_virtual_device_id = lhs->virtual_device_id_; + } else { + defaulted_virtual_device_id = rhs->virtual_device_id_; + } + Target defaulted_target; + if (lhs->target_.defined()) { + defaulted_target = lhs->target_; + } else { + // We can only default to the rhs's target if it is consistent with the device type + if (rhs->target_.defined() && rhs->target_->kind->device_type == defaulted_device_type) { + defaulted_target = rhs->target_; + } + // else: leave as null + } + String defaulted_memory_scope; + if (!lhs->memory_scope_.empty()) { + defaulted_memory_scope = lhs->memory_scope_; + } else { + defaulted_memory_scope = rhs->memory_scope_; + } + return MakeSEScope(defaulted_device_type, defaulted_virtual_device_id, defaulted_target, + defaulted_memory_scope); +} + +namespace { +/*! + * \brief Returns a freshly constructed \p Target to represent \p device_type. + */ +Target MakeDefaultTarget(DLDeviceType device_type) { + std::string name = runtime::DeviceName(device_type); + if (name == "cpu") { + if (runtime::Registry::Get("codegen.LLVMModuleCreate")) { + // LLVM is available. + return Target("llvm"); + } else { + // LLVM is not available. + return Target("stackvm"); + } + } else { + return Target(name); + } +} + +/*! + * \brief Return the \p Target to use for \p device_type, possibly by stealing the \p host_target, + * or by creating a fresh target. 
+ */ +Target FindOrAddDefault(Array* targets, const Target& optional_host_target, + DLDeviceType device_type) { + auto itr = std::find_if(targets->begin(), targets->end(), [device_type](const Target& target) { + return target->kind->device_type == device_type; + }); + if (itr == targets->end()) { + if (optional_host_target.defined() && optional_host_target->kind->device_type == device_type) { + LOG(INFO) << "Using the given host target '" << optional_host_target->str() + << "' for device type " << device_type; + targets->push_back(optional_host_target); + return optional_host_target; + } else { + Target target = MakeDefaultTarget(device_type); + LOG(WARNING) << "No target has been given for the device type " << device_type + << " in the targets list. Creating a default target '" << target->str() + << "' for that device"; + targets->push_back(target); + return target; + } + } else { + return *itr; + } +} + +/*! + * \brief Returns the default \p SEScope for primitives and the \p SEScope for the host + * given vector of available \p targets. If necessary, add new \p Targets to \p targets + * to match the required devices. + */ +std::pair EstablishDefaultSEScopes(const transform::PassContext& pass_ctx, + Array* targets, + const Target& optional_host_target) { + // + // Gather the hints as to what our default device type for primitives should be. 
+ // + DLDeviceType default_primitive_device_type; + Optional opt_fallback_dev = pass_ctx->GetConfig("relay.fallback_device_type"); + if (opt_fallback_dev) { + const int64_t v = opt_fallback_dev.value()->value; + if (v <= 0) { + LOG(FATAL) + << "The 'relay.fallback_device_type' pass attribute is set to an invalid device type " + << v; + default_primitive_device_type = kDLCPU; + } else { + default_primitive_device_type = static_cast(v); + LOG(INFO) << "Using the 'relay.fallback_device_type' pass attribute " + << default_primitive_device_type + << " as the default device type for all primitive operations"; + } + } else if (targets->size() == 1) { + // In the homogeneous case there's no free choice. + default_primitive_device_type = static_cast(targets->front()->kind->device_type); + LOG(INFO) << "Using the unique target '" << targets->front()->str() << "' of device type " + << default_primitive_device_type + << " as the default device type for all primitive operations"; + } else { + default_primitive_device_type = kDLCPU; + LOG(WARNING) << "Using " << default_primitive_device_type + << " as the default device type for all primitive operations"; + } + + // + // Gather the hints as to what our default device type for the 'host' should be. 
+ // + DLDeviceType host_device_type; + if (optional_host_target.defined()) { + host_device_type = static_cast(optional_host_target->kind->device_type); + if (host_device_type != kDLCPU) { + LOG(WARNING) << "Using the host target '" << optional_host_target->str() + << "' of non-CPU device type " << host_device_type + << " for all host operations and data"; + } else { + LOG(INFO) << "Using the host target '" << optional_host_target->str() << "' of device type " + << host_device_type << " for all host operations and data"; + } + } else { + host_device_type = kDLCPU; + LOG(INFO) << "Using " << host_device_type + << " as the device type for all host operations and data"; + } + + // + // Now establish default targets + // + Target default_primitive_target = + FindOrAddDefault(targets, optional_host_target, default_primitive_device_type); + Target actual_host_target = + optional_host_target.defined() + ? optional_host_target + : FindOrAddDefault(targets, optional_host_target, host_device_type); + + return {SEScope::MakeSEScope(default_primitive_device_type, + /*virtual_device_id=*/0, default_primitive_target), + SEScope::MakeSEScope(host_device_type, + /*virtual_device_id=*/0, actual_host_target)}; +} +} // namespace + +CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, + TargetMap legacy_target_map_arg, + Target optional_host_target_arg) + : legacy_target_map(std::move(legacy_target_map_arg)), + optional_host_target(std::move(optional_host_target_arg)) { + VLOG_CONTEXT << "CompilationConfig"; + for (const auto& pair : legacy_target_map) { + VLOG(0) << "Available target " << pair.first << " = '" << pair.second->str() << "'"; + } + if (optional_host_target.defined()) { + VLOG(0) << "Available host target '" << optional_host_target->str() << "'"; + } + + // Legacy: Host & primitive targets need to be consistent. + CheckAndUpdateHostConsistency(&legacy_target_map, &optional_host_target); + + // Gather the primitive targets as an ordinary vector. 
+ for (const auto& pair : legacy_target_map) { + targets.push_back(pair.second); + } + + // Complete the targets vector and establish default scopes. After this targets_ will contain + // the definitive list of all required targets, both for host and primitives. + auto pair = EstablishDefaultSEScopes(pass_ctx, &targets, optional_host_target); + default_primitive_se_scope = pair.first; + host_se_scope = pair.second; + + ICHECK(default_primitive_se_scope->target().defined()); + ICHECK(host_se_scope->target().defined()); + ICHECK_GT(targets.size(), 0U); + + // If we added a target to targets_ for the default primitive scope then we need to do the same in + // the legacy target map. Note that we don't do the same for the host since the legacy map + // is only supposed to track the targets for primitives. I think. Also note that TargetMap is + // indexed by the *object identity* of the Integers for the device types so conveys nothing + // beyond just vector of targets. + auto itr = std::find_if(legacy_target_map.begin(), legacy_target_map.end(), + [this](const std::pair& pair) { + return pair.second->kind->device_type == + default_primitive_se_scope->device_type(); + }); + if (itr == legacy_target_map.end()) { + legacy_target_map.Set(static_cast(default_primitive_se_scope->device_type()), + default_primitive_se_scope->target()); + } + + // Legacy: Some passes only support homogenous compilation and expect the target to be + // given by the global target context. + homogeneous_target = + legacy_target_map.size() == 1 ? 
(*legacy_target_map.begin()).second : Target(); + + for (const auto& target : targets) { + VLOG(0) << "Established build target " << target->kind->device_type << " = '" << target->str() + << "'"; + } + VLOG(0) << "Established default primitive SEScope " << default_primitive_se_scope; + VLOG(0) << "Established host SEScope " << host_se_scope; +} + +TVM_REGISTER_GLOBAL("target.SEScope_ForDeviceTargetAndMemoryScope") + .set_body_typed(SEScope::ForDeviceTargetAndMemoryScope); + +} // namespace tvm diff --git a/src/target/target.cc b/src/target/target.cc index e0b9539380d79..60c6303412e46 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -74,7 +74,7 @@ void CheckAndUpdateHostConsistency(Target* target, Target* host) { *host = (*target)->GetHost().value_or(Target()); } -void CheckAndUpdateHostConsistency(Map* targets, Target* host) { +void CheckAndUpdateHostConsistency(TargetMap* targets, Target* host) { Map new_targets; for (auto& it : *targets) { auto target = it.second; diff --git a/tests/cpp/target/se_scope_test.cc b/tests/cpp/target/se_scope_test.cc new file mode 100644 index 0000000000000..af56be2f909a8 --- /dev/null +++ b/tests/cpp/target/se_scope_test.cc @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +namespace tvm { +namespace target { +namespace { + +TEST(SEScope, MemoizedConstructors) { + Target target_a = Target("cuda"); + Target target_b = Target("llvm"); + SEScope se_scope_a = SEScope::MakeSEScope(kDLCUDA, 3, target_a, "local"); + SEScope se_scope_b = SEScope::MakeSEScope(kDLCPU, 1, target_b, "global"); + + EXPECT_EQ(SEScope::MakeSEScope(kDLCUDA, 3, target_a, "local"), se_scope_a); + EXPECT_EQ(SEScope::MakeSEScope(kDLCPU, 1, target_b, "global"), se_scope_b); + EXPECT_NE(SEScope::MakeSEScope(kDLCUDA, 2, target_a, "local"), se_scope_a); + EXPECT_NE(SEScope::MakeSEScope(kDLCPU, 3, target_b, "local"), se_scope_a); + EXPECT_NE(SEScope::MakeSEScope(kDLCUDA, 3, target_a, "global"), se_scope_a); +} + +TEST(SEScope, Join_Defined) { + Target target_a = Target("cuda"); + SEScope lhs = SEScope::MakeSEScope(kDLCUDA, 3); + SEScope rhs = SEScope::MakeSEScope(kDLCUDA, -1, target_a, "global"); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + SEScope expected = SEScope::MakeSEScope(kDLCUDA, 3, target_a, "global"); + EXPECT_EQ(actual.value(), expected); +} + +TEST(SEScope, Join_Undefined) { + SEScope lhs = SEScope::MakeSEScope(kDLCUDA, 3); + SEScope rhs = SEScope::MakeSEScope(kDLCUDA, 4); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_FALSE(actual); +} + +TEST(SEScope, Default) { + Target target_a = Target("cuda"); + SEScope lhs = SEScope::MakeSEScope(kDLCUDA, -1, Target(), "global"); + SEScope rhs = SEScope::MakeSEScope(kDLCUDA, 3, target_a, "local"); + SEScope actual = SEScope::Default(lhs, rhs); + SEScope expected = SEScope::MakeSEScope(kDLCUDA, 3, target_a, "global"); + EXPECT_EQ(actual, expected); +} + +TEST(CompilationConfig, Constructor) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("relay.fallback_device_type", 
Integer(static_cast(kDLCUDA))); + + Target cuda_target = Target("nvidia/tesla-p40"); + Target default_cpu_target = Target("llvm"); + + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/Target()); + + EXPECT_EQ(config.legacy_target_map.size(), 1); + EXPECT_EQ((*config.legacy_target_map.begin()).second->str(), cuda_target->str()); + EXPECT_FALSE(config.optional_host_target.defined()); + EXPECT_EQ(config.targets.size(), 2); + EXPECT_EQ(config.targets[0]->str(), cuda_target->str()); + EXPECT_EQ(config.targets[1]->str(), default_cpu_target->str()); + EXPECT_EQ(config.default_primitive_se_scope->device_type(), kDLCUDA); + EXPECT_EQ(config.default_primitive_se_scope->target()->str(), cuda_target->str()); + EXPECT_EQ(config.host_se_scope->device_type(), kDLCPU); + EXPECT_EQ(config.host_se_scope->target()->str(), default_cpu_target->str()); + EXPECT_EQ(config.homogeneous_target->str(), cuda_target->str()); +} + +} // namespace +} // namespace target +} // namespace tvm diff --git a/tests/python/target/test_se_scope.py b/tests/python/target/test_se_scope.py new file mode 100644 index 0000000000000..0a9384fa9c046 --- /dev/null +++ b/tests/python/target/test_se_scope.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import pytest +import tvm + + +def test_make_se_scope_for_device(): + se_scope = tvm.target.make_se_scope(tvm.device("cuda")) + assert se_scope.device_type == 2 + # ie kDLCUDA + assert se_scope.virtual_device_id == 0 + assert se_scope.target is None + assert se_scope.memory_scope == "" + + +def test_make_se_scope_for_device_and_target(): + target = tvm.target.Target("cuda") + se_scope = tvm.target.make_se_scope(tvm.device("cuda"), target) + assert se_scope.device_type == 2 # ie kDLCUDA + assert se_scope.target == target + assert se_scope.memory_scope == "" + + +def test_make_se_scope_for_device_target_and_memory_scope(): + target = tvm.target.Target("cuda") + scope = "local" + se_scope = tvm.target.make_se_scope(tvm.device("cuda"), target, scope) + assert se_scope.device_type == 2 # ie kDLCUDA + assert se_scope.target == target + assert se_scope.memory_scope == scope + + +if __name__ == "__main__": + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:]))