From bba188bcfffb5803483e6b89f4ac08d790d202f8 Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Wed, 3 Nov 2021 09:37:28 -0700 Subject: [PATCH] Adds SEScope (Storage/Execution Scope) for use as new unit of planning in 'device' planning. (#9313) [Target] Adds SEScope (Storage/Execution Scope) for use as the new unit of planning in 'device' planning. This is the first step in https://github.com/apache/tvm-rfcs/pull/38 to bring devices and targets together when doing device planning. I've gone ahead and also included a memory scope in this object since we will also need to propagate memory scopes across Relay expressions once this basic preparation is in place. In the meantime that field will be left as "". Once device planning works in units of SEScopes it will be possible to directly read off the device and target for any Relay sub-expression without the need for TargetMaps or the construction of default Targets. SEScopes also support the 'Join' and 'Default' operations needed for constraint solving in the device planner. You can see those in use in my scratchpad branch: https://github.com/mbs-octoml/mbs-tvm/tree/mbs-scopes This PR also brings some duplicated and ad-hoc 'default target' handling logic together into a CompilationConfig class. (Again, see the scratchpad branch for how that will end up being used). I've placed that next to SEScope since its main purpose is to a) establish the default SEScope for primitive ops, b) establish the SEScope for the 'host', and c) feed a definitive vector of Targets into device planning so it can resolve all "on_device" and "device_copy" device references to their full SEScope form. * Reworked to avoid global SEScopeCache. Realized while working through unit tests in the sequel that it's reasonable for folks to call build multiple times with distinct Target objects, in which case the global cache would grow without bound. So instead placed the cache in the CompilationConfig class. Since that class now has everything the device planner needs to do its job, promoted it to be an FFI-able Object, which is now in compilation_config.{h,cc}. I think we can do much better with CompilationConfig, but for now keeping it to the minimum I needed to prepare for device planning from all the executor compilation codepaths.
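To make the 'Join' and 'Default' semantics concrete, here is a minimal illustrative sketch (not part of this patch) that uses only the constructors and static helpers introduced in include/tvm/target/se_scope.h; the expected results follow the examples in that header's documentation:

    // Sketch only: assumes the SEScope API exactly as declared in this patch.
    #include <tvm/target/se_scope.h>
    #include <tvm/target/target.h>

    void SEScopeExample() {
      using namespace tvm;
      // Constrained on device type and target, but not on device id or memory scope.
      SEScope lhs(kDLCPU, /*virtual_device_id=*/-1, Target("llvm"), /*memory_scope=*/"");
      // Constrained on device id and memory scope only.
      SEScope rhs(kInvalidDeviceType, /*virtual_device_id=*/3, /*target=*/{}, "global");
      // Join agrees with both arguments on every constrained field, or returns an empty
      // Optional if the constrained fields disagree.
      Optional<SEScope> joined = SEScope::Join(lhs, rhs);  // (kDLCPU, 3, "llvm", "global")
      // Default keeps lhs and fills its unconstrained fields from rhs; always succeeds.
      SEScope defaulted = SEScope::Default(lhs, rhs);      // (kDLCPU, 3, "llvm", "global")
    }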
--- include/tvm/ir/attrs.h | 4 +- include/tvm/target/compilation_config.h | 170 ++++++++++ include/tvm/target/se_scope.h | 349 ++++++++++++++++++++ include/tvm/target/target.h | 18 +- python/tvm/target/__init__.py | 2 + python/tvm/target/compilation_config.py | 27 ++ python/tvm/target/se_scope.py | 22 ++ src/parser/parser.cc | 4 +- src/printer/relay_text_printer.cc | 4 +- src/relay/backend/aot_executor_codegen.cc | 4 +- src/relay/backend/build_module.cc | 11 +- src/relay/backend/graph_executor_codegen.cc | 2 +- src/relay/backend/te_compiler.cc | 10 + src/relay/backend/te_compiler.h | 2 +- src/relay/backend/utils.h | 2 +- src/relay/backend/vm/compiler.cc | 16 +- src/relay/backend/vm/compiler.h | 7 +- src/target/compilation_config.cc | 229 +++++++++++++ src/target/se_scope.cc | 192 +++++++++++ src/target/target.cc | 15 +- tests/cpp/target/compilation_config_test.cc | 184 +++++++++++ tests/cpp/target/se_scope_test.cc | 119 +++++++ tests/python/target/test_se_scope.py | 52 +++ 23 files changed, 1410 insertions(+), 35 deletions(-) create mode 100644 include/tvm/target/compilation_config.h create mode 100644 include/tvm/target/se_scope.h create mode 100644 python/tvm/target/compilation_config.py create mode 100644 python/tvm/target/se_scope.py create mode 100644 src/target/compilation_config.cc create mode 100644 src/target/se_scope.cc create mode 100644 tests/cpp/target/compilation_config_test.cc create mode 100644 tests/cpp/target/se_scope_test.cc create mode 100644 tests/python/target/test_se_scope.py diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 715c96eb6ea5..f6c15f9590df 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -489,7 +489,7 @@ struct AttrInitEntry { ~AttrInitEntry() DMLC_THROW_EXCEPTION { if (value_missing_) { std::ostringstream os; - os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization." + os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization. " << "If the key is defined check that its type matches the declared type."; throw AttrError(os.str()); } @@ -806,7 +806,7 @@ class AttrsNode : public BaseAttrsNode { ICHECK_EQ(args.size() % 2, 0); const int kLinearSearchBound = 16; int hit_count = 0; - // applies two stratgies to lookup + // applies two strategies to lookup if (args.size() < kLinearSearchBound) { // linear search. auto ffind = [&args](const char* key, runtime::TVMArgValue* val) { diff --git a/include/tvm/target/compilation_config.h b/include/tvm/target/compilation_config.h new file mode 100644 index 000000000000..facb74d6278e --- /dev/null +++ b/include/tvm/target/compilation_config.h @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/target/compilation_config.h + * \brief A helper class to collect all the targets in canonical form necessary for compilation. + * CAUTION: Preliminary, currently only used to support device planning, very likely to change. + */ + +#ifndef TVM_TARGET_COMPILATION_CONFIG_H_ +#define TVM_TARGET_COMPILATION_CONFIG_H_ +#include + +namespace tvm { + +/*! + * \brief Gathers the \p Targets and distinguished \p SEScopes in canonical form needed to + * compile a Relay module. Also centralizes any setup and validation logic needed to transition + * from configuration options conveyed implicitly (eg in \p PassContexts) or explicitly + * (eg as a list of \p Targets) to the configuration. + * + * CAUTION: This is subject to change as we rework compilation options in general. See + * https://github.com/apache/tvm-rfcs/blob/main/rfcs/0028-command-line-registry-composition.md. + * So far this class is only focussed on carrying just the configuration needed by PlanDevices, + * and removing target-munging code duplication and inconsistencies between the three major build + * flows for the VM (relay/backend/vm/compile.cc), Graph/AOT (relay/backend/build_module.cc) and + * Interpreter (relay/backend/interpreter.cc). Over time we expect more global compiler + * configuration (eg for executor and runtime config, for system memory pool configuration, etc) + * to migrate into this class, and instances thereof to be attached to \p IRModules using a + * well-known attribute. + */ +class CompilationConfigNode : public Object { + public: + /*! + * \brief The legacy targets map, mapping device type to \p Targets. Does not include any + * entry for the host target. Intended to give a unique \p Target for every \p DLDeviceType, + * though we want to get rid of that limitation. + * + * CAUTION: Since keys are \p Integers they are compared by object equality, not integer + * value. + * + * TODO(mbs): Remove once codegen updated for new target conventions. + */ + TargetMap legacy_target_map; + + /*! + * \brief The host target. Used for 'scalar' data and code (such as shapes and shape + * functions) and residual Relay expressions and data (such as conditionals and ADTs). + */ + Target host_target; + + /*! + * \brief Vector of all available targets for primitive operators. May contain a \p Target + * for the same device type as for the \p host_target, however the \p host_target should + * be preferred for all host computations and data. + */ + Array<Target> primitive_targets; + + /*! + * \brief \p SEScope for primitive operators which are not otherwise constrained to a particular + * device. + */ + SEScope default_primitive_se_scope = SEScope::FullyUnconstrained(); + + /*! \brief SEScope for the host. */ + SEScope host_se_scope = SEScope::FullyUnconstrained(); + + /*! + * \brief If defined then compile and/or run in 'homogeneous execution mode'. In this mode all + * primitives are compiled for this target only. + * + * This is to support legacy passes which have not been adapted to heterogeneous execution and + * rely on an implicit global \p Target to be in scope. + * + * TODO(mbs): Remove once all passes are 'heterogeneous aware'. + */ + Target optional_homogeneous_target; + + void VisitAttrs(AttrVisitor* v); + + /*! + * \brief Returns a \p SEScope agreeing with \p se_scope on all its constrained fields, however: + * - If the target is null then it is filled in from the known available primitive targets by + * matching on device type. Fails if no such target is known.
+ * - The returned object is unique for the field values w.r.t. all other \p SEScopes returned + * by this method. + * + * We call the result the 'canonical' \p SEScope. Two canonical \p SEScopes are structurally + * equal if and only if they are pointer equal. + */ + SEScope CanonicalSEScope(const SEScope& se_scope) const; + + static constexpr const char* _type_key = "CompilationConfig"; + TVM_DECLARE_FINAL_OBJECT_INFO(CompilationConfigNode, Object) + + private: + /*! + * \brief Establishes the default \p SEScope for primitives and the \p SEScope for the host + * given: + * - the vector of available primitive \p Targets. + * - any host \p Target. + * - any "relay.fallback_device_type" attribute on \p pass_ctx. + * - whether the LLVM backend is available. + * If necessary, creates new default \p Targets to match the required devices. + * + * NOTE: The implementation is a bit convoluted since it tries to maintain backwards + * compatibility with legacy methods for conveying \p Targets. + * + * CAUTION: Recreated the primitive_targets so that they all have the given/constructed + * host_target as their host (cf CheckAndUpdateHostConsistency). + */ + void EstablishDefaultSEScopes(const transform::PassContext& pass_ctx); + + /*! + * \brief Returns a freshly constructed \p Target to represent \p device_type. + */ + static Target MakeDefaultTarget(DLDeviceType device_type); + + /*! + * \brief Return the \p Target to use for \p device_type. Fail if no such target exists. + */ + Target FindPrimitiveTargetOrFail(DLDeviceType device_type) const; + + /*! + * \brief A cache of constructed SEScopes. + */ + mutable SEScopeCache se_scope_cache_; + + friend class CompilationConfig; +}; + +/*! + * \brief Managed reference class to \p CompilationConfig + * + * \sa CompilationConfig + */ +class CompilationConfig : public ObjectRef { + public: + /*! + * \brief Constructs the compilation config given the available \p Targets in the + * \p legacy_target_map_arg and an optional \p optional_host_target_arg. May use + * 'relay.fallback_device_type' and the availability of the LLVM compilation module + * to decide on appropriate default devices. + */ + TVM_DLL CompilationConfig(const transform::PassContext& pass_ctx, TargetMap legacy_target_map_arg, + Target optional_host_target_arg); + + TVM_DEFINE_OBJECT_REF_METHODS(CompilationConfig, ObjectRef, CompilationConfigNode); +}; + +} // namespace tvm + +#endif // TVM_TARGET_COMPILATION_CONFIG_H_ diff --git a/include/tvm/target/se_scope.h b/include/tvm/target/se_scope.h new file mode 100644 index 000000000000..981a0b85ab13 --- /dev/null +++ b/include/tvm/target/se_scope.h @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/target/se_scope.h + * \brief A compile time representation for a Storage or Execution Scope. + */ + +#ifndef TVM_TARGET_SE_SCOPE_H_ +#define TVM_TARGET_SE_SCOPE_H_ + +#include +#include + +#include +#include +#include + +namespace tvm { + +/*! + * Abstract label for an area of memory. + * + * Currently uninterpreted and arbitrary. Likely to be replaced by a structured representation + * of a memory pool in the future. Please try to use this alias instead of String to aid future + * code migration. + */ +using MemoryScope = String; + +/*! + * \brief Describes at compile time where data is to be stored down to the device and memory + * scope level, or where execution is to take place, down to the device level. It is a quadruple of: + * - A \p device_type (\p DLDeviceType). May be kInvalidDeviceType if unconstrained. + * - A \p virtual_device_id (\p int). This allows us to distinguish distinct devices + * with the same \p Target, for example in a multi-GPU system. May be -1 if unconstrained. + * See "Virtual Devices" below. + * - A \p target (\p Target) describing how to compile code for the intended device. May be null + * if unconstrained. + * - A \p memory_scope (\p MemoryScope, which is currently just \p String) describing which memory + * area is to be used to hold data. May be "" if unconstrained. See "Memory Scopes and Devices" + * below. + * + * Some or all of these fields may be unconstrained, signaling that device planning is free to + * choose a value consistent with the whole program. However if a \p target is given then the \p + * device_type must equal \p target->kind->device_type. + * + * Note that currently we assume if a function returns its result on a particular device + * then the function body is also executed on that device. See the overview comment in + * src/relay/transforms/device_planner.cc for more details. + * + * By 'data' we include both tensors and additional supporting datastructures such as shapes, + * Relay AST items, Relay tuples, and Relay references. Typically non-tensor data must reside + * on a 'CPU'-like device with good support for scalars. + * + * By 'execution' we include both (fused) primitive operators, and all the Relay expressions + * surrounding them which coordinates data and control flow. Again, typically non-primitive + * operators must be executed on a 'CPU'-like device with good support for control flow. + * + * Since TVM targets such a wide range of systems it is not possible for \p SEScope to impose + * much semantics on these fields, particularly for \p virtual_device_id and \p memory_scope. + * Instead we assume downstream passes and codegen will interpret an validate these fields + * appropriately. + * + * Targets vs Devices + * ------------------ + * Generally \p Targets (a compile-time only datastructue) describe compiler options for a specific + * microarchitecture and toolchain, while \p Devices (a runtime datastructure also available at + * compile time) describe a physical device on the target system. Obviously the target must agree + * with the device's microarchitecture, but we otherwise don't impose any constraints between them: + * - It's ok to use different \p Targets for the same \p Device, eg to squeeze some extra perf + * out of a particular primitive. + * - It's ok to use the same \p Target for multiple \p Devices, eg if we have multiple CPUs. + * + * Traditionally TVM assumes at most one \p Target per \p DLDeviceType. We are moving away from that + * assumption. 
+ * + * Virtual vs Physical Devices + * --------------------------- + * The \p virtual_device_id may be used by downstream passes or the runtime to help decide which + * \p device_id to use for a particular physical runtime \p Device. For example: + * - Some runtimes may support passing in an array of actual `device` specifications, and the + * \p virtual_device_id can be used at runtime as an index into that array. + * - Some runtimes may support dynamically allocating computations to physical devices. On these + * systems a large space of \p virtual_device_ids could be used at compile time, even though + * at runtime only a few physical devices will be present. + * + * The \p virtual_device_id may also be left unconstrained if not needed. + * + * Memory Scopes and Devices + * ------------------------- + * Multi-device systems can have complex memory hierarchies. For example + * \code + * (kDLCPU, 0, "llvm", "global") + * \endcode + * and + * \code + * (kDLCPU, 1, "llvm", "global") + * \endcode + * could denote: + * - The same memory area accessible from two separate CPUs without any CPU affinity; + * - Distinct memory areas in a NUMA architecture for which cross-device access is handled + * by the memory system; + * - Outright distinct memory areas, where one device cannot directly address the memory of + * another. + * + * Similarly: + * \code + * (kDLCPU, 0, "llvm", "global") + * \endcode + * and + * \code + * (kDLCUDA, 0, "cuda", "host") + * \endcode + * could denote the same memory area, but with very different access costs. + * + * Furthermore, not all memory scopes are accessible to all devices, and it is possible for + * a memory scope to only be accessible to a device when code is compiled with particular + * \p Target options. + * + * \p SEScopes themselves have no system-level understanding. Currently device planning will + * simply insert "device_copy" operators wherever \p SEScopes are not exactly pointwise equal. + * We may revisit this in the future as the work on memory pools matures. + * + * Joining and Defaulting + * ---------------------- + * It is possible to 'join' two \p SEScopes to yield the most constrained \p SEScope which agrees + * with both join arguments. Eg: + * \code + * Join((kDLCPU, -1, "llvm", ""), (kInvalidDeviceType, 3, null, "global")) + * => (kDLCPU, 3, "llvm", "global") + * Join((kDLCPU, -1, "llvm", ""), (kInvalidDeviceType, 3, null, "local")) + * => null (no join possible) + * \endcode + * + * Related to 'join' is 'default', which only takes constrained fields from the rhs when the + * lhs is unconstrained: + * \code + * Default((kDLCPU, -1, "llvm", "local"), (kDLCPU, 3, null, "global")) + * => (kDLCPU, 3, "llvm", "local") + * \endcode + * + * These operations are needed during device planning. + * + */ +class SEScopeNode : public AttrsNode<SEScopeNode> { + public: + /*! + * \brief The \p DLDeviceType (represented as an int) of the virtual device. If \p target is + * known then this will be equal to \p target->kind->device_type. If \p target is null then the + * target is to be determined later. + * + * This is needed to support the legacy "on_device" and "device_copy" calls which only allow + * a \p DLDeviceType (as an integer) to be given. + * + * kInvalidDeviceType denotes unconstrained. + */ + int device_type_int; + + DLDeviceType device_type() const { return static_cast<DLDeviceType>(device_type_int); } + + /*! + * \brief The device identifier for the virtual device. This must be resolved to a physical + * device identifier either during compilation or at runtime.
+ * + * -1 denotes unconstrained. + */ + int virtual_device_id; + + /*! + * \brief The \p Target describing how to compile for the virtual device. + * + * Null denotes unconstrained. Note that if a target later becomes known for this \p SEScope + * then it must be consistent with the \p device_type if already known. This is enforced by the + * Join and Default methods. + */ + Target target; + + /*! + * \brief The scope of memory w.r.t. the virtual device which holds data. + * + * Empty denotes unconstrained. + */ + MemoryScope memory_scope; + + /*! + * \brief Returns true if scope is fully unconstrained, ie no target/device type, device id + * or memory scope is specified. + */ + bool IsFullyUnconstrained() const { + return !target.defined() && device_type() == kInvalidDeviceType && virtual_device_id == -1 && + memory_scope.empty(); + } + + /*! + * \brief Returns true if scope is fully constrained, ie target, device id and memory scope are + * all specified. + */ + bool IsFullyConstrained() const { + return target.defined() && virtual_device_id != -1 && !memory_scope.empty(); + } + + /*! + * \brief Returns the (virtual) \p Device implied by this \p SEScope. Both the \p device_type and + * \p virtual_device_must be constrained. The returned \p Device may not correspond to any + * physical device available at compile time or even runtime: see "Virtual vs Physical Devices" + * above. + */ + Device ToDevice() const { + ICHECK(device_type() != kInvalidDeviceType); + ICHECK(virtual_device_id != -1); + Device device; + device.device_type = device_type(); + device.device_id = virtual_device_id; + return device; + } + + TVM_DECLARE_ATTRS(SEScopeNode, "SEScope") { + TVM_ATTR_FIELD(device_type_int) + .describe("The type of the virtual device.") + .set_default(kInvalidDeviceType); + TVM_ATTR_FIELD(virtual_device_id) + .describe("The device id of the virtual device.") + .set_default(-1); + TVM_ATTR_FIELD(target) + .describe("The target describing how to compile for the virtual device.") + .set_default(Target()); + TVM_ATTR_FIELD(memory_scope) + .describe("The area of memory w.r.t. the virtual device where data is stored.") + .set_default(""); + } + + friend class SEScope; +}; + +/*! + * \brief Managed reference class to \p SEScopeNode. + * + * \sa SEScopeNode. + */ +class SEScope : public ObjectRef { + public: + /*! + * \brief Construct an SEScope. + * \param device_type The device type for the virtual device, or kInvalidDeviceType if + * unconstrained. If \p target is defined then must match its \p target->kind->device_type. + * \param virtual_device_id The device id for the virtual device, or -1 if unconstrained. + * \param target The target describing how to compile for the virtual device, or null if + * unconstrained. + * \param memory_scope The memory scope w.r.t. the virtual device which holds data, or "" if + * unconstrained. + * \return The SEScope + */ + explicit SEScope(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1, + Target target = {}, MemoryScope memory_scope = {}); + + /*! \brief Returns the unique fully unconstrained \p SEScope. */ + static SEScope FullyUnconstrained(); + + /*! + * \brief Returns the \p SEScope for \p device_type and (if not -1) \p virtual_device_id. + * The target and memory scope will be unconstrained. 
+ */ + static SEScope ForDeviceType(DLDeviceType device_type, int virtual_device_id = -1) { + ICHECK_GT(device_type, 0); + return SEScope(device_type, virtual_device_id); + } + static SEScope ForDeviceType(int device_type, int virtual_device_id = -1) { + return ForDeviceType(static_cast(device_type), virtual_device_id); + } + static SEScope ForDeviceType(const Integer& device_type, int virtual_device_id = -1) { + return ForDeviceType(static_cast(device_type->value), virtual_device_id); + } + + /*! \brief Returns the \p SEScope for \p device. */ + static SEScope ForDevice(const Device& device) { + return ForDeviceType(device.device_type, device.device_id); + } + + /*! \brief Returns the \p SEScope for \p device and \p target. */ + static SEScope ForDeviceAndTarget(const Device& device, Target target) { + return SEScope(device.device_type, device.device_id, std::move(target)); + } + + /*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */ + TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target, + MemoryScope memory_scope) { + return SEScope(device.device_type, device.device_id, std::move(target), + std::move(memory_scope)); + } + + /*! + * \brief Returns the 'join' of \p lhs and \p rhs. The result will agree pointwise with + * \p lhs and \p rhs on all their constrained fields. Returns the null optional if no such + * join exists, ie there's disagreement on at least one constrained field. + */ + static Optional Join(const SEScope& lhs, const SEScope& rhs); + + /*! + * \brief Returns the 'default' of \p lhs and \p rhs. The result will be \p lhs, except any + * unconstrained fields in \p lhs will take their value from \p rhs. Always well-defined. + */ + static SEScope Default(const SEScope& lhs, const SEScope& rhs); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SEScope, ObjectRef, SEScopeNode); + + friend class SEScopeCache; // Private implementation helper. +}; + +/*! + * \brief A cache of \p SEScopes. This can be used: + * - To avoid ending up with lots of identical instances, since the space of SEScopes for any + * one compilation is very small but the number of points they need to be constructed can + * be very large (eg during device planning). + * - So we can assume \p SEScopes are pointer equal if and only if they are structurally equal. + * This simplifies the unification of 'device domains' which are built on \p SEScopes. + */ +class SEScopeCache { + public: + /*! \brief Returns the unique \p SEScope representing given fields. */ + SEScope Make(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1, + Target target = {}, MemoryScope memory_scope = {}); + + /*! \brief Returns the unique \p SEScope structurally equal to the given \p se_scope. */ + SEScope Unique(const SEScope& scope); + + private: + /*! \brief Already constructed SEScopes. */ + std::unordered_set cache_; +}; + +} // namespace tvm + +#endif // TVM_TARGET_SE_SCOPE_H_ diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 64a1023158e1..e0d34c87dda7 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -110,7 +110,12 @@ class TargetNode : public Object { /*! 
\brief Get the keys for this target as an unordered_set of string */ TVM_DLL std::unordered_set GetLibs() const; + bool SEqualReduce(const TargetNode* other, SEqualReducer equal) const; + void SHashReduce(SHashReducer hash_reduce) const; + static constexpr const char* _type_key = "Target"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object); private: @@ -179,6 +184,9 @@ class Target : public ObjectRef { */ TVM_DLL void ExitWithScope(); }; + +using TargetMap = Map; + /*! * \brief Check and update host field of the given legacy target and target host pair. * Note that this function is for legacy target api compatibility issue only, not @@ -187,22 +195,24 @@ class Target : public ObjectRef { * \param host The pointer to a Target typed object for target host to be updated */ void CheckAndUpdateHostConsistency(Target* target, Target* host); + /*! * \brief Check and update host field of the given legacy heterogeneous targets and * target host.Note that this function is for legacy target api compatibility issue only, * not recommended for other use. - * \param target The pointer to a Map objects with values being Target objects + * \param target_map The pointer to a Map objects with values being Target objects * \param host The Target typed object for target host to be updated */ -void CheckAndUpdateHostConsistency(Map* target, Target* host); +void CheckAndUpdateHostConsistency(TargetMap* target_map, Target* host); + /*! * \brief Check and update host field of the given legacy heterogeneous targets and * target host.Note that this function is for legacy target api compatibility issue only, * not recommended for other use. - * \param target The pointer to a Map objects with keys being Target objects + * \param ir_modules The pointer to a Map objects with keys being Target objects * \param host The Target typed object for target host to be updated */ -void CheckAndUpdateHostConsistency(Map* target, Target* host); +void CheckAndUpdateHostConsistency(Map* ir_modules, Target* host); } // namespace tvm #endif // TVM_TARGET_TARGET_H_ diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index 1e906cb381d8..5dc95c3ae675 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -71,6 +71,8 @@ riscv_cpu, hexagon, ) +from .se_scope import make_se_scope +from .compilation_config import make_compilation_config from .tag import list_tags from .generic_func import GenericFunc from .generic_func import generic_func, get_native_generic_func, override_native_generic_func diff --git a/python/tvm/target/compilation_config.py b/python/tvm/target/compilation_config.py new file mode 100644 index 000000000000..2796ec4b5135 --- /dev/null +++ b/python/tvm/target/compilation_config.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Python bindings for creating CompilationConfigs.""" +from . import _ffi_api + + +def make_compilation_config(ctxt, targets, host_target=None): + """Returns a CompilationConfig appropriate for targets and an optional host_target. + Currently intended just for unit tests and will be replaced by a Python CompilationConfig + class in the future. Note that targets must be a dictionary from IntImm objects to Targets + and we do not support any of the lighter-weight conventions used by the various build(...) + APIs.""" + return _ffi_api.MakeCompilationConfig(ctxt, targets, host_target) diff --git a/python/tvm/target/se_scope.py b/python/tvm/target/se_scope.py new file mode 100644 index 000000000000..83df5ae3448a --- /dev/null +++ b/python/tvm/target/se_scope.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Python bindings for creating SEScopes.""" +from . 
import _ffi_api + + +def make_se_scope(device, target=None, memory_scope=""): + return _ffi_api.SEScope_ForDeviceTargetAndMemoryScope(device, target, memory_scope) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index ebd6566889dc..486799603354 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1437,8 +1437,8 @@ class Parser { String attr_key = Downcast(raw_attrs["attrs_type_key"]); if (attr_key.size()) { raw_attrs.erase("attrs_type_key"); - auto tbl = tvm::ReflectionVTable::Global(); - auto attr_obj = tbl->CreateObject(attr_key, raw_attrs); + auto attr_obj = + tvm::ReflectionVTable::Global()->CreateObject(attr_key, raw_attrs); ICHECK(attr_obj.defined()); attrs = Downcast(attr_obj); } diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index ea97bb35a09f..9eca038e5c93 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -773,8 +773,6 @@ Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { printed_attr << "?"; } else if (auto str_obj = value.as()) { printed_attr << Doc::StrLiteral(GetRef(str_obj)); - } else if (const auto* on_device_attrs = value.as()) { - printed_attr << "device_type=" << on_device_attrs->device_type; } else if (meta) { printed_attr = meta_->GetMetaNode(Downcast(value)); } else { @@ -787,7 +785,7 @@ Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { } Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) { - return PrintAttr(GetRef(op), true); + return PrintAttr(GetRef(op), /*meta=*/true); } Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) { diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 3c9c35c4f254..7e5702296542 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -715,7 +715,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: " << "runtime::Module mod and Map targets"; void* mod = args[0]; - Map targets = args[1]; + TargetMap targets = args[1]; init(mod, targets); }); } else if (name == "codegen") { @@ -758,7 +758,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } private: - void init(void* mod, Map tmp) { + void init(void* mod, TargetMap tmp) { tec::TargetMap targets; Target target_host; for (const auto& it : tmp) { diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 7005e94c2411..78978e192f0e 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -43,7 +43,6 @@ Pass LabelOps(); } namespace backend { -using TargetsMap = Map; using namespace tvm::relay::transform; /*! 
@@ -56,7 +55,7 @@ struct BuildOutput { }; struct ExecutorCodegen { - void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); } + void Init(runtime::Module* m, TargetMap targets) { CallFunc("init", m, targets); } void Codegen(const Function& func, String mod_name) { CallFunc("codegen", func, mod_name); } @@ -278,7 +277,7 @@ class RelayBuildModule : public runtime::ModuleNode { * \param target Target device * \param target_host Host target device */ - void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host, + void Build(IRModule mod, const TargetMap& targets, const tvm::Target& target_host, const String executor, const String mod_name) { for (const auto& pair : targets) { VLOG(0) << "Build target " << pair.first << " = " << pair.second->str(); @@ -307,7 +306,7 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return relay::IRModule The updated Relay IR module after optimization. */ - IRModule Optimize(IRModule relay_module, const TargetsMap& targets, + IRModule Optimize(IRModule relay_module, const TargetMap& targets, const std::unordered_map& params) { targets_ = targets; // No target_host setup it seems. @@ -444,7 +443,7 @@ class RelayBuildModule : public runtime::ModuleNode { const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); if (!target_host.defined()) target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm"); - // Update all the targets in the targets_ TargetsMap + // Update all the targets in the targets_ TargetMap CheckAndUpdateHostConsistency(&targets_, &target_host); // Relay IRModule -> IRModule optimizations. @@ -540,7 +539,7 @@ class RelayBuildModule : public runtime::ModuleNode { protected: std::unique_ptr executor_codegen_; /*! \brief target device */ - TargetsMap targets_; + TargetMap targets_; /*! \brief target host device */ tvm::Target target_host_; /*! \brief parameters */ diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index debd669126c4..d32ded379688 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -619,7 +619,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: " << "runtime::Module mod and Map targets"; void* mod = args[0]; - Map tmp = args[1]; + TargetMap tmp = args[1]; tec::TargetMap targets; for (const auto& it : tmp) { auto dev_type = it.first.as(); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index ed774eccd8dd..e1ed3d47d36d 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -578,6 +578,15 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node); } + Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // Nothing to lower inside primitive functions. + return GetRef(function_node); + } else { + return DeviceAwareExprMutator::DeviceAwareVisitExpr_(function_node); + } + } + Expr DeviceAwareVisitExpr_(const CallNode* call_node) override { Call call = GetRef(call_node); // Look for (indirect) calls to primitives. @@ -620,6 +629,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // Replace with direct call to lowered primitive, and attach annotations to record calling // convention. 
+ // =====> in new call_lowered form return Call(pair.first, args, pair.second); } diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index e3b7d46457ad..d0401e9605f7 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -148,7 +148,7 @@ void UpdateFunctionMetadata(Function relay_func, * \param dev_type * \return Target */ -Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); +Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets); /*! * \brief Update the "main" control function's metadata diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index a647aa1a3fd2..f89a099b0d4f 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -437,7 +437,7 @@ inline bool IsAutoSchedulerEnabled() { * \param is_vm A boolean indicating if the passes are used for vm or graph runtime. * \return An array of passes. */ -Array GetPassPrefix(const Map& targets, bool is_vm); +Array GetPassPrefix(const TargetMap& targets, bool is_vm); /*! \brief Target hash function */ struct TargetStrHash { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 6a085adad3d1..c4c50c6c5646 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -80,7 +80,6 @@ namespace vm { using namespace tvm::runtime; using namespace tvm::runtime::vm; using namespace relay::transform; -using namespace tec; // (@jroesch): VM passes, eventually declare as passes. bool IsClosure(const Function& func); @@ -251,7 +250,7 @@ int GetFallbackDevice() { class VMFunctionCompiler : DeviceAwareExprFunctor { public: - VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) + VMFunctionCompiler(VMCompilerContext* context, TargetMap targets, Target target_host) : DeviceAwareExprFunctor(context->module), last_register_(0), registers_num_(0), @@ -464,7 +463,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { void EmitShapeFunc(Function func, Array inputs, Array outputs) { // Lower shape function - CCacheKey key(func, target_host_); + tec::CCacheKey key(func, target_host_); auto cfunc = context_->compiler->LowerShapeFunc(key); int op_index = -1; // pick the only function inside the context @@ -540,7 +539,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { } } - CCacheKey key(func, target); + tec::CCacheKey key(func, target); auto mangle_fn = [](String name) { return name; }; auto cfunc = context_->compiler->Lower(key, mangle_fn); // <<<< one-func-at-a-time lowering @@ -910,7 +909,8 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } -void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { +void VMCompiler::Lower(IRModule mod, const tvm::TargetMap& targets, + const tvm::Target& target_host) { exec_ = make_object(); targets_ = targets; target_host_ = target_host; @@ -976,7 +976,7 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe backend::UpdateAutoSchedulerOpWeights(context_.compiler); } -transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { +transform::Sequential MemoryOpt(tvm::Target host_target, tvm::TargetMap targets) { Array pass_seqs; // Remove unused functions Array entry_functions{"main"}; @@ -1022,10 +1022,10 @@ transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { return transform::Sequential(pass_seqs); } -IRModule 
VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg, +IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets_arg, const Target& target_host_arg) { VLOG_CONTEXT << "VMCompiler::OptimizeModule"; - TargetsMap targets = targets_arg; + TargetMap targets = targets_arg; Target target_host = target_host_arg; CheckAndUpdateHostConsistency(&targets, &target_host); if (params_.size()) { diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index af3c5bccbeff..5b51d7821d78 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -62,7 +62,6 @@ using TagNameMap = std::unordered_map; using GlobalMap = NodeMap; using ConstMap = NodeMap; using ConstTensorShapeMap = NodeMap>; -using TargetsMap = Map; struct VMCompilerContext { // The module context for the compilation @@ -111,7 +110,7 @@ class VMCompiler : public runtime::ModuleNode { * to target mapping. For homogeneous compilation, it is a singleton build target. * \param target_host Host compilation target, if target is device. */ - void Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host); + void Lower(IRModule mod, const TargetMap& targets, const tvm::Target& target_host); /*! \brief Generate the machine code for lowered functions. */ void Codegen(); @@ -127,7 +126,7 @@ class VMCompiler : public runtime::ModuleNode { * * \return The optimized IRModule. */ - IRModule OptimizeModule(IRModule mod, const TargetsMap& targets, const Target& target_host); + IRModule OptimizeModule(IRModule mod, const TargetMap& targets, const Target& target_host); /*! * \brief Populate the global function names in a map where the value is used @@ -137,7 +136,7 @@ class VMCompiler : public runtime::ModuleNode { protected: /*! \brief Target devices. */ - TargetsMap targets_; + TargetMap targets_; /*! \brief Target host device. */ tvm::Target target_host_; /*! \brief Global shared meta data */ diff --git a/src/target/compilation_config.cc b/src/target/compilation_config.cc new file mode 100644 index 000000000000..b3491d656625 --- /dev/null +++ b/src/target/compilation_config.cc @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/compilation_config.cc + * \brief Implementation of \p CompilationConfig for collecting \p Targets. 
+ */ + +#include +#include + +namespace tvm { + +TVM_REGISTER_NODE_TYPE(CompilationConfigNode); + +void CompilationConfigNode::VisitAttrs(AttrVisitor* v) { + v->Visit("legacy_target_map", &legacy_target_map); + v->Visit("host_target", &host_target); + v->Visit("primitive_targets", &primitive_targets); + v->Visit("default_primitive_se_scope", &default_primitive_se_scope); + v->Visit("host_se_scope", &host_se_scope); + v->Visit("optional_homogenous_target", &optional_homogeneous_target); + // NOTE: The se_scope_cache_ is not accessible via FFI. +} + +SEScope CompilationConfigNode::CanonicalSEScope(const SEScope& se_scope) const { + if (se_scope->target.defined()) { + return se_scope_cache_.Unique(se_scope); + } + DLDeviceType device_type = se_scope->device_type(); + // TODO(mbs): Proper diagnostics. + CHECK(device_type != kInvalidDeviceType) + << "SEScope annotations must include at least a device_type"; + Target target = FindPrimitiveTargetOrFail(se_scope->device_type()); + return se_scope_cache_.Unique( + SEScope(device_type, se_scope->virtual_device_id, target, se_scope->memory_scope)); +} + +void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContext& pass_ctx) { + // + // Gather the hints as to what our default device type for the 'host' should be, and + // create an appropriate target if we don't already have one. + // + DLDeviceType host_device_type; + if (host_target.defined()) { + CHECK(!host_target->host.defined()) << "Host targets are not expected to have hosts"; + host_device_type = static_cast(host_target->kind->device_type); + if (host_device_type != kDLCPU) { + LOG(WARNING) << "Using the given host target '" << host_target << "' of non-CPU device type " + << host_device_type << " for all host operations and data"; + } else { + LOG(INFO) << "Using the given host target '" << host_target << "' of device type " + << host_device_type << " for all host operations and data"; + } + } else if (primitive_targets.size() == 1 && + primitive_targets.front()->kind->device_type == kDLCPU) { + // In the homogenous case without an explicit host target just use the given target so long as + // it's a CPU. However make sure we 'forget' any host it may already have. + host_device_type = kDLCPU; + host_target = Target(primitive_targets.front()); + LOG(INFO) << "Using the unique target '" << host_target << "' of device type " + << host_device_type << " for all host operations and data"; + } else { + // Fallback. + host_device_type = kDLCPU; + // Even if the list of available targets already includes one for kDLCPU we won't use it + // since its options may not be appropriate for host code (eg shape functions). Instead, + // create a fresh default Target. + host_target = MakeDefaultTarget(host_device_type); + LOG(WARNING) << "Using the default host target '" << host_target << "' of device type " + << host_device_type << " for all host operations and data"; + } + ICHECK(host_target.defined()); + + // + // Establish the host SEScope. + // + host_se_scope = se_scope_cache_.Unique(SEScope(host_device_type, + /*virtual_device_id=*/0, host_target)); + + // + // Now that we've settled on a host, make sure all the primitive Targets agree on it for + // their 'host' field. This mutates the primitives. 
+ // + Array new_primitve_targets; + new_primitve_targets.reserve(primitive_targets.size()); + for (const auto& primitive_target : primitive_targets) { + new_primitve_targets.push_back(Target(primitive_target, host_target)); + } + primitive_targets = new_primitve_targets; + + // + // Gather the hints as to what our default device type for primitives should be. + // + DLDeviceType default_primitive_device_type; + Optional opt_fallback_dev = pass_ctx->GetConfig("relay.fallback_device_type"); + if (opt_fallback_dev) { + const int64_t v = opt_fallback_dev.value()->value; + if (v <= 0) { + LOG(FATAL) + << "The 'relay.fallback_device_type' pass attribute is set to an invalid device type " + << v; + default_primitive_device_type = kDLCPU; + } else { + default_primitive_device_type = static_cast(v); + LOG(INFO) << "Using the 'relay.fallback_device_type' pass attribute " + << default_primitive_device_type + << " as the default device type for all primitive operations"; + } + } else if (primitive_targets.size() == 1) { + // In the homogeneous case there's no free choice. + default_primitive_device_type = + static_cast(primitive_targets.front()->kind->device_type); + LOG(INFO) << "Using the unique target '" << primitive_targets.front() << "' of device type " + << default_primitive_device_type + << " as the default device type for all primitive operations"; + } else { + // Fallback. Note that we'll require a primitive Target of kDLCPU device_type to be given + // and won't manufacture one out of thin air. + default_primitive_device_type = kDLCPU; + LOG(WARNING) << "Using " << default_primitive_device_type + << " as the default device type for all primitive operations"; + } + + // + // Establish the default primitive SEScope, choosing a known Target to match the device type. + // + default_primitive_se_scope = se_scope_cache_.Unique( + SEScope(default_primitive_device_type, + /*virtual_device_id=*/0, FindPrimitiveTargetOrFail(default_primitive_device_type))); +} + +/* static */ Target CompilationConfigNode::MakeDefaultTarget(DLDeviceType device_type) { + std::string name = runtime::DeviceName(device_type); + if (name == "cpu") { + if (runtime::Registry::Get("codegen.LLVMModuleCreate")) { + // LLVM is available. + return Target("llvm"); + } else { + // LLVM is not available. + return Target("stackvm"); + } + } else { + return Target(name); + } +} + +Target CompilationConfigNode::FindPrimitiveTargetOrFail(DLDeviceType device_type) const { + auto itr = std::find_if( + primitive_targets.begin(), primitive_targets.end(), + [device_type](const Target& target) { return target->kind->device_type == device_type; }); + CHECK(itr != primitive_targets.end()) << "No target for device type " << device_type << " in the " + << primitive_targets.size() << " given by the targets list"; + return *itr; +} + +CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, + TargetMap legacy_target_map_arg, + Target optional_host_target_arg) { + VLOG_CONTEXT << "CompilationConfig"; + + auto node = make_object(); + + for (const auto& pair : legacy_target_map_arg) { + VLOG(0) << "Available primitive target " << pair.first << " = '" << pair.second << "'"; + } + if (optional_host_target_arg.defined()) { + VLOG(0) << "Available host target '" << optional_host_target_arg << "'"; + } + + // Capture the arguments in our representation. 
+ for (const auto& pair : legacy_target_map_arg) { + node->primitive_targets.push_back(pair.second); + } + node->host_target = optional_host_target_arg; + + // Complete the targets vector and establish default scopes. After this primitive_targets will + // contain the definitive list of all required targets, target_host will be defined, and + // all primitive targets will have host target_host. + node->EstablishDefaultSEScopes(pass_ctx); + + // LEGACY: Reconstruct the target map with all the primitive targets. + for (const auto& primitive_target : node->primitive_targets) { + node->legacy_target_map.Set(Integer(primitive_target->kind->device_type), primitive_target); + } + + ICHECK(node->default_primitive_se_scope->target.defined()); + ICHECK(node->host_se_scope->target.defined()); + ICHECK_GT(node->primitive_targets.size(), 0U); + + // Legacy: Some passes only support homogenous compilation and expect the target to be + // given by the global target context. Make this easy to detect. + node->optional_homogeneous_target = + node->primitive_targets.size() == 1 ? *node->primitive_targets.begin() : Target(); + + for (const auto& target : node->primitive_targets) { + LOG(INFO) << "Target '" << target << "' of device type " << target->kind->device_type + << " is available for primitives"; + } + LOG(INFO) << "Using default primitive scope " << node->default_primitive_se_scope; + LOG(INFO) << "Using host scope " << node->host_se_scope; + + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("target.MakeCompilationConfig") + .set_body_typed([](const transform::PassContext& pass_ctx, TargetMap legacy_target_map, + Target optional_host_target) -> CompilationConfig { + return CompilationConfig(pass_ctx, std::move(legacy_target_map), + std::move(optional_host_target)); + }); + +} // namespace tvm diff --git a/src/target/se_scope.cc b/src/target/se_scope.cc new file mode 100644 index 000000000000..150a883cb565 --- /dev/null +++ b/src/target/se_scope.cc @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/se_scope.cc + * \brief Implementation of \p SEScope for representing a Storage or Execution scope. 
+ */ +#include +#include +#include + +namespace tvm { + +TVM_REGISTER_NODE_TYPE(SEScopeNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = ref.as(); + p->stream << "SEScope("; + if (node->IsFullyUnconstrained()) { + p->stream << "?"; + } else { + bool need_sep = false; + if (node->device_type() != kInvalidDeviceType) { + p->stream << "device_type=" << node->device_type(); + need_sep = true; + } + if (node->virtual_device_id >= 0) { + if (need_sep) { + p->stream << ", "; + } + p->stream << "virtual_device_id=" << node->virtual_device_id; + need_sep = true; + } + if (node->target.defined()) { + if (need_sep) { + p->stream << ", "; + } + p->stream << "target='" << node->target << "'"; + need_sep = true; + } + if (!node->memory_scope.empty()) { + if (need_sep) { + p->stream << ", "; + } + p->stream << "memory_scope='" << node->memory_scope << "'"; + } + } + p->stream << ")"; + }); + +SEScope::SEScope(DLDeviceType device_type, int virtual_device_id, Target target, + MemoryScope memory_scope) { + ICHECK(!target.defined() || device_type == target->kind->device_type) + << "target '" << target << "' has device type " << target->kind->device_type + << " but scope has device type " << device_type; + auto node = make_object(); + node->device_type_int = device_type; + node->virtual_device_id = virtual_device_id; + node->target = std::move(target); + node->memory_scope = std::move(memory_scope); + data_ = std::move(node); +} + +/* static */ SEScope SEScope::FullyUnconstrained() { + static const SEScope unconstrained{}; + return unconstrained; +} + +/* static */ +Optional SEScope::Join(const SEScope& lhs, const SEScope& rhs) { + if (lhs == rhs) { + return lhs; + } + DLDeviceType joined_device_type; + if (lhs->device_type() != kInvalidDeviceType) { + joined_device_type = lhs->device_type(); + if (rhs->device_type() != kInvalidDeviceType && lhs->device_type() != rhs->device_type()) { + return {}; + } + } else { + joined_device_type = rhs->device_type(); + } + int joined_virtual_device_id; + if (lhs->virtual_device_id >= 0) { + joined_virtual_device_id = lhs->virtual_device_id; + if (rhs->virtual_device_id >= 0 && lhs->virtual_device_id != rhs->virtual_device_id) { + return {}; + } + } else { + joined_virtual_device_id = rhs->virtual_device_id; + } + Target joined_target; + if (lhs->target.defined()) { + joined_target = lhs->target; + if (rhs->target.defined() && lhs->target != rhs->target) { + return {}; + } + } else { + joined_target = rhs->target; + } + MemoryScope joined_memory_scope; + if (!lhs->memory_scope.empty()) { + joined_memory_scope = lhs->memory_scope; + if (!rhs->memory_scope.empty() && lhs->memory_scope != rhs->memory_scope) { + return {}; + } + } else { + joined_memory_scope = rhs->memory_scope; + } + return SEScope(joined_device_type, joined_virtual_device_id, joined_target, joined_memory_scope); +} + +/* static */ +SEScope SEScope::Default(const SEScope& lhs, const SEScope& rhs) { + if (lhs == rhs) { + return lhs; + } + DLDeviceType defaulted_device_type; + if (lhs->device_type() != kInvalidDeviceType) { + defaulted_device_type = lhs->device_type(); + } else { + defaulted_device_type = rhs->device_type(); + } + int defaulted_virtual_device_id; + if (lhs->virtual_device_id >= 0) { + defaulted_virtual_device_id = lhs->virtual_device_id; + } else { + defaulted_virtual_device_id = rhs->virtual_device_id; + } + Target defaulted_target; + if (lhs->target.defined()) { + defaulted_target = lhs->target; + } else { + // We 
+    if (rhs->target.defined() && rhs->target->kind->device_type == defaulted_device_type) {
+      defaulted_target = rhs->target;
+    }
+    // else: leave as null
+  }
+  MemoryScope defaulted_memory_scope;
+  if (!lhs->memory_scope.empty()) {
+    defaulted_memory_scope = lhs->memory_scope;
+  } else {
+    defaulted_memory_scope = rhs->memory_scope;
+  }
+  return SEScope(defaulted_device_type, defaulted_virtual_device_id, defaulted_target,
+                 defaulted_memory_scope);
+}
+
+SEScope SEScopeCache::Make(DLDeviceType device_type, int virtual_device_id, Target target,
+                           MemoryScope memory_scope) {
+  SEScope prototype(device_type, virtual_device_id, std::move(target), std::move(memory_scope));
+  auto itr = cache_.find(prototype);
+  if (itr == cache_.end()) {
+    VLOG(1) << "added new scope " << prototype;
+    cache_.emplace(prototype);
+    return prototype;
+  } else {
+    VLOG(1) << "reusing '" << *itr << "' for '" << prototype << "'";
+    ICHECK_EQ(prototype->target.defined(), (*itr)->target.defined());
+    if (prototype->target.defined()) {
+      ICHECK_EQ(prototype->target->host.defined(), (*itr)->target->host.defined());
+    }
+    return *itr;
+  }
+}
+
+SEScope SEScopeCache::Unique(const SEScope& scope) {
+  return Make(scope->device_type(), scope->virtual_device_id, scope->target, scope->memory_scope);
+}
+
+TVM_REGISTER_GLOBAL("target.SEScope_ForDeviceTargetAndMemoryScope")
+    .set_body_typed(SEScope::ForDeviceTargetAndMemoryScope);
+
+}  // namespace tvm
diff --git a/src/target/target.cc b/src/target/target.cc
index e0b9539380d7..d1c85c583b3b 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -74,7 +74,7 @@ void CheckAndUpdateHostConsistency(Target* target, Target* host) {
   *host = (*target)->GetHost().value_or(Target());
 }
 
-void CheckAndUpdateHostConsistency(Map<Integer, Target>* targets, Target* host) {
+void CheckAndUpdateHostConsistency(TargetMap* targets, Target* host) {
   Map<Integer, Target> new_targets;
   for (auto& it : *targets) {
     auto target = it.second;
@@ -531,6 +531,19 @@ Optional<Target> TargetNode::GetHost() const {
   return GetRef<Optional<Target>>(this->host.as<TargetNode>());
 }
 
+bool TargetNode::SEqualReduce(const TargetNode* other, SEqualReducer equal) const {
+  return equal(kind.get(), other->kind.get()) && equal(host, other->host) &&
+         equal(tag, other->tag) && equal(keys, other->keys) && equal(attrs, other->attrs);
+}
+
+void TargetNode::SHashReduce(SHashReducer hash_reduce) const {
+  hash_reduce(kind.get());
+  hash_reduce(host);
+  hash_reduce(tag);
+  hash_reduce(keys);
+  hash_reduce(attrs);
+}
+
 /*! \brief Entry to hold the Target context stack. */
 struct TVMTargetThreadLocalEntry {
   /*! \brief The current target context */
diff --git a/tests/cpp/target/compilation_config_test.cc b/tests/cpp/target/compilation_config_test.cc
new file mode 100644
index 000000000000..5c2b7990a498
--- /dev/null
+++ b/tests/cpp/target/compilation_config_test.cc
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +namespace tvm { +namespace { + +Target TestCpuTarget() { return Target("llvm -mcpu arm64"); } + +Target TestCudaTarget() { return Target("nvidia/tesla-p40"); } + +Target TestDefaultCpuTarget() { return Target("llvm"); } + +Target TestExtDevTarget() { return Target("ext_dev"); } + +CompilationConfig TestCompilationConfig() { + transform::PassContext pass_ctx = transform::PassContext::Create(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); + legacy_target_map.Set(Integer(static_cast(kDLCPU)), TestCpuTarget()); + return CompilationConfig(pass_ctx, legacy_target_map, TestDefaultCpuTarget()); +} + +TEST(CompilationConfig, Constructor_Homogeneous_DefaultHost) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + Target host_target = TestDefaultCpuTarget(); + Target cuda_target = TestCudaTarget(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{}); + + SEScope expected_default_primitive_se_scope(kDLCUDA, 0, + Target::WithHost(cuda_target, host_target)); + SEScope expected_host_se_scope(kDLCPU, 0, host_target); + + ASSERT_EQ(config->legacy_target_map.size(), 1); + EXPECT_TRUE(StructuralEqual()((*config->legacy_target_map.begin()).second, + Target::WithHost(cuda_target, host_target))); + EXPECT_TRUE(config->host_target.defined()); + EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); + ASSERT_EQ(config->primitive_targets.size(), 1); + EXPECT_TRUE( + StructuralEqual()(config->primitive_targets[0], Target::WithHost(cuda_target, host_target))); + EXPECT_TRUE( + StructuralEqual()(config->default_primitive_se_scope, expected_default_primitive_se_scope)); + EXPECT_TRUE(StructuralEqual()(config->host_se_scope, expected_host_se_scope)); + ASSERT_TRUE(config->optional_homogeneous_target.defined()); + EXPECT_TRUE(StructuralEqual()(config->optional_homogeneous_target, + Target::WithHost(cuda_target, host_target))); +} + +TEST(CompilationConfig, Constructor_Hetrogeneous_DefaultHost) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("relay.fallback_device_type", Integer(static_cast(kDLCUDA))); + Target host_target = TestDefaultCpuTarget(); + Target cuda_target = TestCudaTarget(); + Target cpu_target = TestCpuTarget(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCPU)), cpu_target); + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{}); + + SEScope expected_default_primitive_se_scope(kDLCUDA, 0, + Target::WithHost(cuda_target, host_target)); + SEScope expected_host_se_scope(kDLCPU, 0, host_target); + + ASSERT_EQ(config->legacy_target_map.size(), 2); + EXPECT_TRUE(config->host_target.defined()); + EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); + EXPECT_TRUE( + StructuralEqual()(config->default_primitive_se_scope, 
expected_default_primitive_se_scope)); + EXPECT_TRUE(StructuralEqual()(config->host_se_scope, expected_host_se_scope)); + EXPECT_FALSE(config->optional_homogeneous_target.defined()); +} + +TEST(CompilationConfig, Constructor_Hetrogeneous_ExplicitHost) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("relay.fallback_device_type", Integer(static_cast(kDLCUDA))); + Target host_target = TestCpuTarget(); + Target cuda_target = TestCudaTarget(); + Target cpu_target = TestCpuTarget(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCPU)), cpu_target); + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + CompilationConfig config(pass_ctx, legacy_target_map, host_target); + + SEScope expected_default_primitive_se_scope(kDLCUDA, 0, + Target::WithHost(cuda_target, host_target)); + SEScope expected_host_se_scope(kDLCPU, 0, host_target); + + ASSERT_EQ(config->legacy_target_map.size(), 2); + EXPECT_TRUE(config->host_target.defined()); + EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); + ASSERT_EQ(config->primitive_targets.size(), 2); + EXPECT_TRUE( + StructuralEqual()(config->default_primitive_se_scope, expected_default_primitive_se_scope)); + EXPECT_TRUE(StructuralEqual()(config->host_se_scope, expected_host_se_scope)); + EXPECT_FALSE(config->optional_homogeneous_target.defined()); +} + +TEST(CompilationConfig, Constructor_InvalidAttribute) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("relay.fallback_device_type", Integer(static_cast(kInvalidDeviceType))); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); + EXPECT_ANY_THROW( + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{})); +} + +TEST(CompilationConfig, Constructor_NoMatchingPrimitiveTarget) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("relay.fallback_device_type", Integer(static_cast(kDLMetal))); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); + EXPECT_ANY_THROW( + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{})); +} + +TEST(CompilationConfig, Constructor_DefaultNoMatchingPrimitiveTarget) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); + legacy_target_map.Set(Integer(static_cast(kDLExtDev)), TestExtDevTarget()); + EXPECT_ANY_THROW( + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{})); +} + +TEST(CompilationConfig, CanonicalSEScope) { + Target host_target = TestDefaultCpuTarget(); + Target cuda_target = TestCudaTarget(); + Target cpu_target = TestCpuTarget(); + CompilationConfig config = TestCompilationConfig(); + + { + SEScope in = SEScope(kDLCPU); + SEScope actual = config->CanonicalSEScope(in); + ASSERT_TRUE(actual->target.defined()); + EXPECT_TRUE(StructuralEqual()(actual->target, Target::WithHost(cpu_target, host_target))); + EXPECT_EQ(config->CanonicalSEScope(in), actual); + } + { + SEScope in = SEScope(kDLCUDA); + SEScope actual = config->CanonicalSEScope(in); + ASSERT_TRUE(actual->target.defined()); + EXPECT_TRUE(StructuralEqual()(actual->target, Target::WithHost(cuda_target, host_target))); + EXPECT_EQ(config->CanonicalSEScope(in), actual); + } +} + +TEST(CompilationConfig, 
CanonicalSEScope_NoDevice) { + CompilationConfig config = TestCompilationConfig(); + SEScope fully_unconstrained; + EXPECT_ANY_THROW(config->CanonicalSEScope(fully_unconstrained)); + SEScope missing_device(kInvalidDeviceType, 3, {}, "local"); + EXPECT_ANY_THROW(config->CanonicalSEScope(missing_device)); +} + +TEST(CompilationConfig, CanonicalSEScope_NoMatchingTarget) { + CompilationConfig config = TestCompilationConfig(); + SEScope no_such_target(kDLMetal); + EXPECT_ANY_THROW(config->CanonicalSEScope(no_such_target)); +} + +} // namespace +} // namespace tvm diff --git a/tests/cpp/target/se_scope_test.cc b/tests/cpp/target/se_scope_test.cc new file mode 100644 index 000000000000..166ba46faf37 --- /dev/null +++ b/tests/cpp/target/se_scope_test.cc @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +namespace tvm { +namespace { + +TEST(SEScope, Join_Defined) { + { + Target target_a = Target("cuda"); + SEScope lhs = SEScope(kDLCUDA, 3); + SEScope rhs = SEScope(kDLCUDA, -1, target_a, "global"); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + SEScope expected = SEScope(kDLCUDA, 3, target_a, "global"); + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + } + { + Target target_a = Target("cuda"); + SEScope lhs = SEScope(kDLCUDA, -1, target_a, "global"); + SEScope rhs = SEScope(kDLCUDA, 3); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + SEScope expected = SEScope(kDLCUDA, 3, target_a, "global"); + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + } + { + Target target_a = Target("cuda"); + SEScope lhs = SEScope(kDLCUDA); + SEScope rhs = SEScope(kDLCUDA, 2, target_a); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + SEScope expected = SEScope(kDLCUDA, 2, target_a); + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + } + { + Target target_a = Target("cuda"); + SEScope lhs = SEScope(); + SEScope rhs = SEScope(kDLCUDA, 3, target_a, "global"); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + SEScope expected = rhs; + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + } +} + +TEST(SEScope, Join_Undefined) { + { + SEScope lhs = SEScope(kDLCUDA); + SEScope rhs = SEScope(kDLCPU); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_FALSE(actual); + } + { + SEScope lhs = SEScope(kDLCUDA, 3); + SEScope rhs = SEScope(kDLCUDA, 4); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_FALSE(actual); + } + { + SEScope lhs = SEScope(kDLCUDA, 3, Target("cuda")); + SEScope rhs = SEScope(kDLCUDA, 3, Target("cuda")); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_FALSE(actual); + } + { + SEScope lhs = SEScope(kDLCUDA, 
3, Target("cuda"), "local"); + SEScope rhs = SEScope(kDLCUDA, 3, Target("cuda"), "global"); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_FALSE(actual); + } +} + +TEST(SEScope, Default) { + Target target_a = Target("cuda"); + SEScope lhs = SEScope(kDLCUDA, -1, Target(), "global"); + SEScope rhs = SEScope(kDLCUDA, 3, target_a, "local"); + SEScope actual = SEScope::Default(lhs, rhs); + SEScope expected = SEScope(kDLCUDA, 3, target_a, "global"); + EXPECT_TRUE(StructuralEqual()(actual, expected)); +} + +TEST(SEScope, Constructor_Invalid) { EXPECT_ANY_THROW(SEScope(kDLCPU, -1, Target("cuda"))); } + +TEST(SEScopeCache, Memoized) { + SEScopeCache cache; + Target target_a = Target("cuda"); + Target target_b = Target("llvm"); + SEScope se_scope_a = cache.Make(kDLCUDA, 3, target_a, "local"); + SEScope se_scope_b = cache.Make(kDLCPU, 1, target_b, "global"); + + EXPECT_EQ(cache.Make(kDLCUDA, 3, target_a, "local"), se_scope_a); + EXPECT_EQ(cache.Make(kDLCPU, 1, target_b, "global"), se_scope_b); + EXPECT_NE(cache.Make(kDLCUDA, 2, target_a, "local"), se_scope_a); + EXPECT_NE(cache.Make(kDLCPU, 3, target_b, "local"), se_scope_a); + EXPECT_NE(cache.Make(kDLCUDA, 3, target_a, "global"), se_scope_a); +} + +} // namespace +} // namespace tvm diff --git a/tests/python/target/test_se_scope.py b/tests/python/target/test_se_scope.py new file mode 100644 index 000000000000..0a9384fa9c04 --- /dev/null +++ b/tests/python/target/test_se_scope.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import pytest +import tvm + + +def test_make_se_scope_for_device(): + se_scope = tvm.target.make_se_scope(tvm.device("cuda")) + assert se_scope.device_type == 2 + # ie kDLCUDA + assert se_scope.virtual_device_id == 0 + assert se_scope.target is None + assert se_scope.memory_scope == "" + + +def test_make_se_scope_for_device_and_target(): + target = tvm.target.Target("cuda") + se_scope = tvm.target.make_se_scope(tvm.device("cuda"), target) + assert se_scope.device_type == 2 # ie kDLCUDA + assert se_scope.target == target + assert se_scope.memory_scope == "" + + +def test_make_se_scope_for_device_target_and_memory_scope(): + target = tvm.target.Target("cuda") + scope = "local" + se_scope = tvm.target.make_se_scope(tvm.device("cuda"), target, scope) + assert se_scope.device_type == 2 # ie kDLCUDA + assert se_scope.target == target + assert se_scope.memory_scope == scope + + +if __name__ == "__main__": + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:]))
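
A minimal sketch (not part of the diff) of how the pieces above are intended to compose, mirroring the behaviour exercised by compilation_config_test.cc and se_scope_test.cc. The include paths and the PlannerSketch function name are assumptions made for illustration only; all calls use the API introduced by this patch.

// Illustrative sketch only; PlannerSketch is a hypothetical name, and the
// include paths are assumed to match the headers added by this patch.
#include <tvm/target/compilation_config.h>
#include <tvm/target/se_scope.h>

namespace tvm {

void PlannerSketch() {
  // Build a CompilationConfig from the legacy device-type -> Target map, as the tests do.
  transform::PassContext pass_ctx = transform::PassContext::Create();
  TargetMap legacy_target_map;
  legacy_target_map.Set(Integer(static_cast<int>(kDLCUDA)), Target("cuda"));
  legacy_target_map.Set(Integer(static_cast<int>(kDLCPU)), Target("llvm"));
  CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/Target("llvm"));

  // A scope constrained only by device type can be canonicalized: the config fills in
  // the matching primitive Target (with the host target attached).
  SEScope cuda_scope = config->CanonicalSEScope(SEScope(kDLCUDA));
  ICHECK(cuda_scope->target.defined());

  // Join unifies two partial constraints, and fails (returns nullopt) if they conflict.
  Target cuda = Target("cuda");
  Optional<SEScope> joined =
      SEScope::Join(SEScope(kDLCUDA, 3), SEScope(kDLCUDA, -1, cuda, "global"));
  ICHECK(joined);  // structurally equal to SEScope(kDLCUDA, 3, cuda, "global")
  ICHECK(!SEScope::Join(SEScope(kDLCUDA), SEScope(kDLCPU)));  // conflicting device types

  // Default keeps the lhs's constraints and fills any gaps from the rhs.
  SEScope defaulted = SEScope::Default(SEScope(kDLCUDA, -1, Target(), "global"),
                                       SEScope(kDLCUDA, 3, cuda, "local"));
  // defaulted is structurally equal to SEScope(kDLCUDA, 3, cuda, "global")
  (void)defaulted;
}

}  // namespace tvm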