diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 715c96eb6ea5..f6c15f9590df 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -489,7 +489,7 @@ struct AttrInitEntry { ~AttrInitEntry() DMLC_THROW_EXCEPTION { if (value_missing_) { std::ostringstream os; - os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization." + os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization. " << "If the key is defined check that its type matches the declared type."; throw AttrError(os.str()); } @@ -806,7 +806,7 @@ class AttrsNode : public BaseAttrsNode { ICHECK_EQ(args.size() % 2, 0); const int kLinearSearchBound = 16; int hit_count = 0; - // applies two stratgies to lookup + // applies two strategies to lookup if (args.size() < kLinearSearchBound) { // linear search. auto ffind = [&args](const char* key, runtime::TVMArgValue* val) { diff --git a/include/tvm/target/compilation_config.h b/include/tvm/target/compilation_config.h new file mode 100644 index 000000000000..facb74d6278e --- /dev/null +++ b/include/tvm/target/compilation_config.h @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/target/compilation_config.h + * \brief A helper class to collect all the targets in canonical form necessary for compilation. + * CAUTION: Preliminary, currently only used to support device planning, very likely to change. + */ + +#ifndef TVM_TARGET_COMPILATION_CONFIG_H_ +#define TVM_TARGET_COMPILATION_CONFIG_H_ + +#include + +namespace tvm { + +/*! + * \brief Gathers the \p Targets and distinguished \p SEScopes in canonical form needed to + * compile a Relay module. It centralizes any setup and validation logic needed to transition + * from configuration options conveyed implicitly (eg in \p PassContexts) or explicitly + * (eg a list of \p Targets) to the configuration. + * + * CAUTION: This is subject to change as we rework compilation options in general. See + * https://github.com/apache/tvm-rfcs/blob/main/rfcs/0028-command-line-registry-composition.md. + * So far this class is only focussed on carrying just the configuration needed by PlanDevices, + * and removing target-munging code duplication and inconsistencies between the three major build + * flows for the VM (relay/backend/vm/compile.cc), Graph/AOT (relay/backend/build_module.cc) and + * Interpreter (relay/backend/interpreter.cc). Over time we expect more global compiler + * configuration (eg for executor and runtime config, for system memory pool configuration, etc) + * to migrate into this class, and instances thereof to be attached to \p IRModules using a + * well-known attribute. + */ +class CompilationConfigNode : public Object { + public: + /*! + * \brief The legacy targets map, mapping device type to \p Targets. Does not include any + * entry for the host target. Intended to give a unique \p Target for every \p DLDeviceType, + * though we want to get rid of that limitation. + * + * CAUTION: Since keys are \p Integers they are compared by object equality not integer + * value. + * + * TODO(mbs): Remove once codegen updated for new target conventions. 
+ */ + TargetMap legacy_target_map; + + /*! + * \brief The host target. Used for 'scalar' data and code (such as shapes and shape + * functions) and residual Relay expressions and data (such as conditionals and ADTs). + */ + Target host_target; + + /*! + * \brief Vector of all available targets for primitive operators. May contain a \p Target + * for the same device type as for the \p host_target, however the \p host_target should + * be preferred for all host computations and data. + */ + Array primitive_targets; + + /*! + * \brief \p SEScope for primitive operators which are not otherwise constrained to a particular + * device. + */ + SEScope default_primitive_se_scope = SEScope::FullyUnconstrained(); + + /*! \brief SEScope for the host. */ + SEScope host_se_scope = SEScope::FullyUnconstrained(); + + /*! + * \brief If defined then compile and/or run in 'homogeneous execution mode'. In this mode all + * primitives are compiled for this target only. + * + * This is to support legacy passes which have not been adapted to heterogeneous execution and + * rely on an implicit global \p Target to be in scope. + * + * TODO(mbs): Remove once all passes are 'heterogeneous aware'. + */ + Target optional_homogeneous_target; + + void VisitAttrs(AttrVisitor* v); + + /*! + * \brief Returns a \p SEScope agreeing with \p se_scope on all its constrained fields, however: + * - If the target is null then it is filled in from the known available primitive targets by + * matching on device type. Fails if no such target is known. + * - The returned object is unique for the field values w.r.t. all other \p SEScopes returned + * by this method. + * + * We call the result the 'canonical' \p SEScope. Two canonical \p SEScopes are structurally + * equal if and only if they are pointer equal. 
+ */ + SEScope CanonicalSEScope(const SEScope& se_scope) const; + + static constexpr const char* _type_key = "CompilationConfig"; + TVM_DECLARE_FINAL_OBJECT_INFO(CompilationConfigNode, Object) + + private: + /*! + * \brief Establishes the default \p SEScope for primitives and the \p SEScope for the host + * given: + * - the vector of available primitive \p Targets. + * - any host \p Target. + * - any "relay.fallback_device_type" attribute on \p pass_ctx. + * - whether the LLVM backend is available. + * If necessary, creates new default \p Targets to match the required devices. + * + * NOTE: The implementation is a bit convoluted since it tries to maintain backwards + * compatibility with legacy methods for conveying \p Targets. + * + * CAUTION: Recreated the primitive_targets so that they all have the given/constructed + * host_target as their host (cf CheckAndUpdateHostConsistency). + */ + void EstablishDefaultSEScopes(const transform::PassContext& pass_ctx); + + /*! + * \brief Returns a freshly constructed \p Target to represent \p device_type. + */ + static Target MakeDefaultTarget(DLDeviceType device_type); + + /*! + * \brief Return the \p Target to use for \p device_type. Fail if no such target exists. + */ + Target FindPrimitiveTargetOrFail(DLDeviceType device_type) const; + + /*! + * \brief A cache of constructed SEScopes. + */ + mutable SEScopeCache se_scope_cache_; + + friend class CompilationConfig; +}; + +/*! + * \brief Managed reference class to \p CompilationConfig + * + * \sa CompilationConfig + */ +class CompilationConfig : public ObjectRef { + public: + /*! + * \brief Constructs the compilation config given the available \p Targets in the + * \p legacy_target_map_arg and an optional \p optional_host_target_arg. May use + * 'relay.fallback_device_type' and the availability of the LLVM compilation module + * to decide on appropriate default devices. 
+ */ + TVM_DLL CompilationConfig(const transform::PassContext& pass_ctx, TargetMap legacy_target_map_arg, + Target optional_host_target_arg); + + TVM_DEFINE_OBJECT_REF_METHODS(CompilationConfig, ObjectRef, CompilationConfigNode); +}; + +} // namespace tvm + +#endif // TVM_TARGET_COMPILATION_CONFIG_H_ diff --git a/include/tvm/target/se_scope.h b/include/tvm/target/se_scope.h new file mode 100644 index 000000000000..981a0b85ab13 --- /dev/null +++ b/include/tvm/target/se_scope.h @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/se_scope.h + * \brief A compile time representation for a Storage or Execution Scope. + */ + +#ifndef TVM_TARGET_SE_SCOPE_H_ +#define TVM_TARGET_SE_SCOPE_H_ + +#include +#include + +#include +#include +#include + +namespace tvm { + +/*! + * Abstract label for an area of memory. + * + * Currently uninterpreted and arbitrary. Likely to be replaced by a structured representation + * of a memory pool in the future. Please try to use this alias instead of String to aid future + * code migration. + */ +using MemoryScope = String; + +/*! 
+ * \brief Describes at compile time where data is to be stored down to the device and memory + * scope level, or where execution is to take place, down to the device level. It is a quadruple of: + * - A \p device_type (\p DLDeviceType). May be kInvalidDeviceType if unconstrained. + * - A \p virtual_device_id (\p int). This allows us to distinguish distinct devices + * with the same \p Target, for example in a multi-GPU system. May be -1 if unconstrained. + * See "Virtual Devices" below. + * - A \p target (\p Target) describing how to compile code for the intended device. May be null + * if unconstrained. + * - A \p memory_scope (\p MemoryScope, which is currently just \p String) describing which memory + * area is to be used to hold data. May be "" if unconstrained. See "Memory Scopes and Devices" + * below. + * + * Some or all of these fields may be unconstrained, signaling that device planning is free to + * choose a value consistent with the whole program. However if a \p target is given then the \p + * device_type must equal \p target->kind->device_type. + * + * Note that currently we assume if a function returns its result on a particular device + * then the function body is also executed on that device. See the overview comment in + * src/relay/transforms/device_planner.cc for more details. + * + * By 'data' we include both tensors and additional supporting datastructures such as shapes, + * Relay AST items, Relay tuples, and Relay references. Typically non-tensor data must reside + * on a 'CPU'-like device with good support for scalars. + * + * By 'execution' we include both (fused) primitive operators, and all the Relay expressions + * surrounding them which coordinates data and control flow. Again, typically non-primitive + * operators must be executed on a 'CPU'-like device with good support for control flow. 
+ * + * Since TVM targets such a wide range of systems it is not possible for \p SEScope to impose + * much semantics on these fields, particularly for \p virtual_device_id and \p memory_scope. + * Instead we assume downstream passes and codegen will interpret and validate these fields + * appropriately. + * + * Targets vs Devices + * ------------------ + * Generally \p Targets (a compile-time only datastructure) describe compiler options for a specific + * microarchitecture and toolchain, while \p Devices (a runtime datastructure also available at + * compile time) describe a physical device on the target system. Obviously the target must agree + * with the device's microarchitecture, but we otherwise don't impose any constraints between them: + * - It's ok to use different \p Targets for the same \p Device, eg to squeeze some extra perf + * out of a particular primitive. + * - It's ok to use the same \p Target for multiple \p Devices, eg if we have multiple CPUs. + * + * Traditionally TVM assumes at most one \p Target per \p DLDeviceType. We are moving away from that + * assumption. + * + * Virtual vs Physical Devices + * --------------------------- + * The \p virtual_device_id may be used by downstream passes or the runtime to help decide which + * \p device_id to use for a particular physical runtime \p Device. For example: + * - Some runtimes may support passing in an array of actual `device` specifications, and the + * \p virtual_device_id can be used at runtime as an index into that array. + * - Some runtimes may support dynamically allocating computations to physical devices. On these + * systems a large space of \p virtual_device_ids could be used at compile time, even though + * at runtime only a few physical devices will be present. + * + * The \p virtual_device_id may also be left unconstrained if not needed. + * + * Memory Scopes and Devices + * ------------------------- + * Multi-device systems can have complex memory hierarchies. 
For example + * \code + * (kDLCPU, 0, "llvm", "global") + * \endcode + * and + * \code + * (kDLCPU, 1, "llvm", "global") + * \endcode + * could denote: + * - The same memory area accessible from two separate CPUs without any CPU affinity; + * - Distinct memory areas in a NUMA architecture for which cross-device access is handled + * by the memory system; + * - Outright distinct memory areas, where one device cannot directly address the memory of + * another. + * + * Similarly: + * \code + * (kDLCPU, 0, "llvm", "global") + * \endcode + * and + * \code + * (kDLCUDA, 0, "cuda", "host") + * \endcode + * could denote the same memory area, but with very different access costs. + * + * Furthermore, not all memory scopes are accessible to all devices, and it is possible for + * a memory scope to only be accessible to a device when code is compiled with particular + * \p Target options. + * + * \p SEScopes themselves have no system-level understanding. Currently device planning will + * simply insert "device_copy" operators wherever \p SEScopes are not exactly pointwise equal. + * We may revisit this in the future as the work on memory pools matures. + * + * Joining and Defaulting + * ---------------------- + * It is possible to 'join' two \p SEScopes to yield the most constrained \p SEScope which agrees + * with both join arguments. Eg: + * \code + * Join((kDLCPU, -1, "llvm", ""), (kInvalidDeviceType, 3, null, "global")) + * => (kDLCPU, 3, "llvm", "global") + * Join((kDLCPU, -1, "llvm", ""), (kInvalidDeviceType, 3, null, "local")) + * => null (no join possible) + * \endcode + * + * Related to 'join' is 'default', which only takes constrained fields from the rhs when the + * lhs is unconstrained: + * \code + * Default((kDLCPU, -1, "llvm", "local"), (kDLCPU, 3, null, "global")) + * => (kDLCPU, 3, "llvm", "local") + * \endcode + * + * These operations are needed during device planning. + * + */ +class SEScopeNode : public AttrsNode { + public: + /*! 
+ * \brief The \p DLDeviceType (represented as an int) of the virtual device. If \p target is + * known then this will be equal to \p target->kind->device_type. If \p target is null then the + * target is to be determined later. + * + * This is needed to support the legacy "on_device" and "device_copy" calls which only allow + * a \p DLDeviceType (as an integer) to be given. + * + * kInvalidDeviceType denotes unconstrained. + */ + int device_type_int; + + DLDeviceType device_type() const { return static_cast(device_type_int); } + + /*! + * \brief The device identifier for the virtual device. This must be resolved to a physical + * device identifier either during compilation or at runtime. + * + * -1 denotes unconstrained. + */ + int virtual_device_id; + + /*! + * \brief The \p Target describing how to compile for the virtual device. + * + * Null denotes unconstrained. Note that if a target later becomes known for this \p SEScope + * then it must be consistent with the \p device_type if already known. This is enforced by the + * Join and Default methods. + */ + Target target; + + /*! + * \brief The scope of memory w.r.t. the virtual device which holds data. + * + * Empty denotes unconstrained. + */ + MemoryScope memory_scope; + + /*! + * \brief Returns true if scope is fully unconstrained, ie no target/device type, device id + * or memory scope is specified. + */ + bool IsFullyUnconstrained() const { + return !target.defined() && device_type() == kInvalidDeviceType && virtual_device_id == -1 && + memory_scope.empty(); + } + + /*! + * \brief Returns true if scope is fully constrained, ie target, device id and memory scope are + * all specified. + */ + bool IsFullyConstrained() const { + return target.defined() && virtual_device_id != -1 && !memory_scope.empty(); + } + + /*! + * \brief Returns the (virtual) \p Device implied by this \p SEScope. Both the \p device_type and + * \p virtual_device_id must be constrained. 
The returned \p Device may not correspond to any + * physical device available at compile time or even runtime: see "Virtual vs Physical Devices" + * above. + */ + Device ToDevice() const { + ICHECK(device_type() != kInvalidDeviceType); + ICHECK(virtual_device_id != -1); + Device device; + device.device_type = device_type(); + device.device_id = virtual_device_id; + return device; + } + + TVM_DECLARE_ATTRS(SEScopeNode, "SEScope") { + TVM_ATTR_FIELD(device_type_int) + .describe("The type of the virtual device.") + .set_default(kInvalidDeviceType); + TVM_ATTR_FIELD(virtual_device_id) + .describe("The device id of the virtual device.") + .set_default(-1); + TVM_ATTR_FIELD(target) + .describe("The target describing how to compile for the virtual device.") + .set_default(Target()); + TVM_ATTR_FIELD(memory_scope) + .describe("The area of memory w.r.t. the virtual device where data is stored.") + .set_default(""); + } + + friend class SEScope; +}; + +/*! + * \brief Managed reference class to \p SEScopeNode. + * + * \sa SEScopeNode. + */ +class SEScope : public ObjectRef { + public: + /*! + * \brief Construct an SEScope. + * \param device_type The device type for the virtual device, or kInvalidDeviceType if + * unconstrained. If \p target is defined then must match its \p target->kind->device_type. + * \param virtual_device_id The device id for the virtual device, or -1 if unconstrained. + * \param target The target describing how to compile for the virtual device, or null if + * unconstrained. + * \param memory_scope The memory scope w.r.t. the virtual device which holds data, or "" if + * unconstrained. + * \return The SEScope + */ + explicit SEScope(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1, + Target target = {}, MemoryScope memory_scope = {}); + + /*! \brief Returns the unique fully unconstrained \p SEScope. */ + static SEScope FullyUnconstrained(); + + /*! 
+ * \brief Returns the \p SEScope for \p device_type and (if not -1) \p virtual_device_id. + * The target and memory scope will be unconstrained. + */ + static SEScope ForDeviceType(DLDeviceType device_type, int virtual_device_id = -1) { + ICHECK_GT(device_type, 0); + return SEScope(device_type, virtual_device_id); + } + static SEScope ForDeviceType(int device_type, int virtual_device_id = -1) { + return ForDeviceType(static_cast(device_type), virtual_device_id); + } + static SEScope ForDeviceType(const Integer& device_type, int virtual_device_id = -1) { + return ForDeviceType(static_cast(device_type->value), virtual_device_id); + } + + /*! \brief Returns the \p SEScope for \p device. */ + static SEScope ForDevice(const Device& device) { + return ForDeviceType(device.device_type, device.device_id); + } + + /*! \brief Returns the \p SEScope for \p device and \p target. */ + static SEScope ForDeviceAndTarget(const Device& device, Target target) { + return SEScope(device.device_type, device.device_id, std::move(target)); + } + + /*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */ + TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target, + MemoryScope memory_scope) { + return SEScope(device.device_type, device.device_id, std::move(target), + std::move(memory_scope)); + } + + /*! + * \brief Returns the 'join' of \p lhs and \p rhs. The result will agree pointwise with + * \p lhs and \p rhs on all their constrained fields. Returns the null optional if no such + * join exists, ie there's disagreement on at least one constrained field. + */ + static Optional Join(const SEScope& lhs, const SEScope& rhs); + + /*! + * \brief Returns the 'default' of \p lhs and \p rhs. The result will be \p lhs, except any + * unconstrained fields in \p lhs will take their value from \p rhs. Always well-defined. 
+ */ + static SEScope Default(const SEScope& lhs, const SEScope& rhs); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SEScope, ObjectRef, SEScopeNode); + + friend class SEScopeCache; // Private implementation helper. +}; + +/*! + * \brief A cache of \p SEScopes. This can be used: + * - To avoid ending up with lots of identical instances, since the space of SEScopes for any + * one compilation is very small but the number of points they need to be constructed can + * be very large (eg during device planning). + * - So we can assume \p SEScopes are pointer equal if and only if they are structurally equal. + * This simplifies the unification of 'device domains' which are built on \p SEScopes. + */ +class SEScopeCache { + public: + /*! \brief Returns the unique \p SEScope representing given fields. */ + SEScope Make(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1, + Target target = {}, MemoryScope memory_scope = {}); + + /*! \brief Returns the unique \p SEScope structurally equal to the given \p se_scope. */ + SEScope Unique(const SEScope& scope); + + private: + /*! \brief Already constructed SEScopes. */ + std::unordered_set cache_; +}; + +} // namespace tvm + +#endif // TVM_TARGET_SE_SCOPE_H_ diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 64a1023158e1..e0d34c87dda7 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -110,7 +110,12 @@ class TargetNode : public Object { /*! 
\brief Get the keys for this target as an unordered_set of string */ TVM_DLL std::unordered_set GetLibs() const; + bool SEqualReduce(const TargetNode* other, SEqualReducer equal) const; + void SHashReduce(SHashReducer hash_reduce) const; + static constexpr const char* _type_key = "Target"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object); private: @@ -179,6 +184,9 @@ class Target : public ObjectRef { */ TVM_DLL void ExitWithScope(); }; + +using TargetMap = Map; + /*! * \brief Check and update host field of the given legacy target and target host pair. * Note that this function is for legacy target api compatibility issue only, not @@ -187,22 +195,24 @@ class Target : public ObjectRef { * \param host The pointer to a Target typed object for target host to be updated */ void CheckAndUpdateHostConsistency(Target* target, Target* host); + /*! * \brief Check and update host field of the given legacy heterogeneous targets and * target host.Note that this function is for legacy target api compatibility issue only, * not recommended for other use. - * \param target The pointer to a Map objects with values being Target objects + * \param target_map The pointer to a Map objects with values being Target objects * \param host The Target typed object for target host to be updated */ -void CheckAndUpdateHostConsistency(Map* target, Target* host); +void CheckAndUpdateHostConsistency(TargetMap* target_map, Target* host); + /*! * \brief Check and update host field of the given legacy heterogeneous targets and * target host.Note that this function is for legacy target api compatibility issue only, * not recommended for other use. 
- * \param target The pointer to a Map objects with keys being Target objects + * \param ir_modules The pointer to a Map objects with keys being Target objects * \param host The Target typed object for target host to be updated */ -void CheckAndUpdateHostConsistency(Map* target, Target* host); +void CheckAndUpdateHostConsistency(Map* ir_modules, Target* host); } // namespace tvm #endif // TVM_TARGET_TARGET_H_ diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index 1e906cb381d8..5dc95c3ae675 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -71,6 +71,8 @@ riscv_cpu, hexagon, ) +from .se_scope import make_se_scope +from .compilation_config import make_compilation_config from .tag import list_tags from .generic_func import GenericFunc from .generic_func import generic_func, get_native_generic_func, override_native_generic_func diff --git a/python/tvm/target/compilation_config.py b/python/tvm/target/compilation_config.py new file mode 100644 index 000000000000..2796ec4b5135 --- /dev/null +++ b/python/tvm/target/compilation_config.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Python bindings for creating CompilationConfigs.""" +from . 
import _ffi_api + + +def make_compilation_config(ctxt, targets, host_target=None): + """Returns a CompilationConfig appropriate for targets and an optional host_target. + Currently intended just for unit tests and will be replaced by a Python CompilationConfig + class in the future. Note that targets must be a dictionary from IntImm objects to Targets + and we do not support any of the lighter-weight conventions used by the various build(...) + APIs.""" + return _ffi_api.MakeCompilationConfig(ctxt, targets, host_target) diff --git a/python/tvm/target/se_scope.py b/python/tvm/target/se_scope.py new file mode 100644 index 000000000000..83df5ae3448a --- /dev/null +++ b/python/tvm/target/se_scope.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Python bindings for creating SEScopes.""" +from . 
import _ffi_api + + +def make_se_scope(device, target=None, memory_scope=""): + return _ffi_api.SEScope_ForDeviceTargetAndMemoryScope(device, target, memory_scope) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index ebd6566889dc..486799603354 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1437,8 +1437,8 @@ class Parser { String attr_key = Downcast(raw_attrs["attrs_type_key"]); if (attr_key.size()) { raw_attrs.erase("attrs_type_key"); - auto tbl = tvm::ReflectionVTable::Global(); - auto attr_obj = tbl->CreateObject(attr_key, raw_attrs); + auto attr_obj = + tvm::ReflectionVTable::Global()->CreateObject(attr_key, raw_attrs); ICHECK(attr_obj.defined()); attrs = Downcast(attr_obj); } diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index ea97bb35a09f..9eca038e5c93 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -773,8 +773,6 @@ Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { printed_attr << "?"; } else if (auto str_obj = value.as()) { printed_attr << Doc::StrLiteral(GetRef(str_obj)); - } else if (const auto* on_device_attrs = value.as()) { - printed_attr << "device_type=" << on_device_attrs->device_type; } else if (meta) { printed_attr = meta_->GetMetaNode(Downcast(value)); } else { @@ -787,7 +785,7 @@ Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { } Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) { - return PrintAttr(GetRef(op), true); + return PrintAttr(GetRef(op), /*meta=*/true); } Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) { diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 3c9c35c4f254..7e5702296542 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -715,7 +715,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: " 
<< "runtime::Module mod and Map targets"; void* mod = args[0]; - Map targets = args[1]; + TargetMap targets = args[1]; init(mod, targets); }); } else if (name == "codegen") { @@ -758,7 +758,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } private: - void init(void* mod, Map tmp) { + void init(void* mod, TargetMap tmp) { tec::TargetMap targets; Target target_host; for (const auto& it : tmp) { diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 7005e94c2411..78978e192f0e 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -43,7 +43,6 @@ Pass LabelOps(); } namespace backend { -using TargetsMap = Map; using namespace tvm::relay::transform; /*! @@ -56,7 +55,7 @@ struct BuildOutput { }; struct ExecutorCodegen { - void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); } + void Init(runtime::Module* m, TargetMap targets) { CallFunc("init", m, targets); } void Codegen(const Function& func, String mod_name) { CallFunc("codegen", func, mod_name); } @@ -278,7 +277,7 @@ class RelayBuildModule : public runtime::ModuleNode { * \param target Target device * \param target_host Host target device */ - void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host, + void Build(IRModule mod, const TargetMap& targets, const tvm::Target& target_host, const String executor, const String mod_name) { for (const auto& pair : targets) { VLOG(0) << "Build target " << pair.first << " = " << pair.second->str(); @@ -307,7 +306,7 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return relay::IRModule The updated Relay IR module after optimization. 
*/ - IRModule Optimize(IRModule relay_module, const TargetsMap& targets, + IRModule Optimize(IRModule relay_module, const TargetMap& targets, const std::unordered_map& params) { targets_ = targets; // No target_host setup it seems. @@ -444,7 +443,7 @@ class RelayBuildModule : public runtime::ModuleNode { const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); if (!target_host.defined()) target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm"); - // Update all the targets in the targets_ TargetsMap + // Update all the targets in the targets_ TargetMap CheckAndUpdateHostConsistency(&targets_, &target_host); // Relay IRModule -> IRModule optimizations. @@ -540,7 +539,7 @@ class RelayBuildModule : public runtime::ModuleNode { protected: std::unique_ptr executor_codegen_; /*! \brief target device */ - TargetsMap targets_; + TargetMap targets_; /*! \brief target host device */ tvm::Target target_host_; /*! \brief parameters */ diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index debd669126c4..d32ded379688 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -619,7 +619,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: " << "runtime::Module mod and Map targets"; void* mod = args[0]; - Map tmp = args[1]; + TargetMap tmp = args[1]; tec::TargetMap targets; for (const auto& it : tmp) { auto dev_type = it.first.as(); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index ed774eccd8dd..e1ed3d47d36d 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -578,6 +578,15 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node); } + Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override { + if 
(function_node->HasNonzeroAttr(attr::kPrimitive)) { + // Nothing to lower inside primitive functions. + return GetRef(function_node); + } else { + return DeviceAwareExprMutator::DeviceAwareVisitExpr_(function_node); + } + } + Expr DeviceAwareVisitExpr_(const CallNode* call_node) override { Call call = GetRef(call_node); // Look for (indirect) calls to primitives. @@ -620,6 +629,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // Replace with direct call to lowered primitive, and attach annotations to record calling // convention. + // =====> in new call_lowered form return Call(pair.first, args, pair.second); } diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index e3b7d46457ad..d0401e9605f7 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -148,7 +148,7 @@ void UpdateFunctionMetadata(Function relay_func, * \param dev_type * \return Target */ -Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); +Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets); /*! * \brief Update the "main" control function's metadata diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index a647aa1a3fd2..f89a099b0d4f 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -437,7 +437,7 @@ inline bool IsAutoSchedulerEnabled() { * \param is_vm A boolean indicating if the passes are used for vm or graph runtime. * \return An array of passes. */ -Array GetPassPrefix(const Map& targets, bool is_vm); +Array GetPassPrefix(const TargetMap& targets, bool is_vm); /*! 
\brief Target hash function */ struct TargetStrHash { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 6a085adad3d1..c4c50c6c5646 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -80,7 +80,6 @@ namespace vm { using namespace tvm::runtime; using namespace tvm::runtime::vm; using namespace relay::transform; -using namespace tec; // (@jroesch): VM passes, eventually declare as passes. bool IsClosure(const Function& func); @@ -251,7 +250,7 @@ int GetFallbackDevice() { class VMFunctionCompiler : DeviceAwareExprFunctor { public: - VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) + VMFunctionCompiler(VMCompilerContext* context, TargetMap targets, Target target_host) : DeviceAwareExprFunctor(context->module), last_register_(0), registers_num_(0), @@ -464,7 +463,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { void EmitShapeFunc(Function func, Array inputs, Array outputs) { // Lower shape function - CCacheKey key(func, target_host_); + tec::CCacheKey key(func, target_host_); auto cfunc = context_->compiler->LowerShapeFunc(key); int op_index = -1; // pick the only function inside the context @@ -540,7 +539,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { } } - CCacheKey key(func, target); + tec::CCacheKey key(func, target); auto mangle_fn = [](String name) { return name; }; auto cfunc = context_->compiler->Lower(key, mangle_fn); // <<<< one-func-at-a-time lowering @@ -910,7 +909,8 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } -void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { +void VMCompiler::Lower(IRModule mod, const tvm::TargetMap& targets, + const tvm::Target& target_host) { exec_ = make_object(); targets_ = targets; target_host_ = target_host; @@ -976,7 +976,7 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const 
tvm::Targe backend::UpdateAutoSchedulerOpWeights(context_.compiler); } -transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { +transform::Sequential MemoryOpt(tvm::Target host_target, tvm::TargetMap targets) { Array pass_seqs; // Remove unused functions Array entry_functions{"main"}; @@ -1022,10 +1022,10 @@ transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { return transform::Sequential(pass_seqs); } -IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg, +IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets_arg, const Target& target_host_arg) { VLOG_CONTEXT << "VMCompiler::OptimizeModule"; - TargetsMap targets = targets_arg; + TargetMap targets = targets_arg; Target target_host = target_host_arg; CheckAndUpdateHostConsistency(&targets, &target_host); if (params_.size()) { diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index af3c5bccbeff..5b51d7821d78 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -62,7 +62,6 @@ using TagNameMap = std::unordered_map; using GlobalMap = NodeMap; using ConstMap = NodeMap; using ConstTensorShapeMap = NodeMap>; -using TargetsMap = Map; struct VMCompilerContext { // The module context for the compilation @@ -111,7 +110,7 @@ class VMCompiler : public runtime::ModuleNode { * to target mapping. For homogeneous compilation, it is a singleton build target. * \param target_host Host compilation target, if target is device. */ - void Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host); + void Lower(IRModule mod, const TargetMap& targets, const tvm::Target& target_host); /*! \brief Generate the machine code for lowered functions. */ void Codegen(); @@ -127,7 +126,7 @@ class VMCompiler : public runtime::ModuleNode { * * \return The optimized IRModule. 
*/ - IRModule OptimizeModule(IRModule mod, const TargetsMap& targets, const Target& target_host); + IRModule OptimizeModule(IRModule mod, const TargetMap& targets, const Target& target_host); /*! * \brief Populate the global function names in a map where the value is used @@ -137,7 +136,7 @@ class VMCompiler : public runtime::ModuleNode { protected: /*! \brief Target devices. */ - TargetsMap targets_; + TargetMap targets_; /*! \brief Target host device. */ tvm::Target target_host_; /*! \brief Global shared meta data */ diff --git a/src/target/compilation_config.cc b/src/target/compilation_config.cc new file mode 100644 index 000000000000..b3491d656625 --- /dev/null +++ b/src/target/compilation_config.cc @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/compilation_config.cc + * \brief Implementation of \p CompilationConfig for collecting \p Targets. 
+ */ + +#include +#include + +namespace tvm { + +TVM_REGISTER_NODE_TYPE(CompilationConfigNode); + +void CompilationConfigNode::VisitAttrs(AttrVisitor* v) { + v->Visit("legacy_target_map", &legacy_target_map); + v->Visit("host_target", &host_target); + v->Visit("primitive_targets", &primitive_targets); + v->Visit("default_primitive_se_scope", &default_primitive_se_scope); + v->Visit("host_se_scope", &host_se_scope); + v->Visit("optional_homogenous_target", &optional_homogeneous_target); + // NOTE: The se_scope_cache_ is not accessible via FFI. +} + +SEScope CompilationConfigNode::CanonicalSEScope(const SEScope& se_scope) const { + if (se_scope->target.defined()) { + return se_scope_cache_.Unique(se_scope); + } + DLDeviceType device_type = se_scope->device_type(); + // TODO(mbs): Proper diagnostics. + CHECK(device_type != kInvalidDeviceType) + << "SEScope annotations must include at least a device_type"; + Target target = FindPrimitiveTargetOrFail(se_scope->device_type()); + return se_scope_cache_.Unique( + SEScope(device_type, se_scope->virtual_device_id, target, se_scope->memory_scope)); +} + +void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContext& pass_ctx) { + // + // Gather the hints as to what our default device type for the 'host' should be, and + // create an appropriate target if we don't already have one. 
+ // + DLDeviceType host_device_type; + if (host_target.defined()) { + CHECK(!host_target->host.defined()) << "Host targets are not expected to have hosts"; + host_device_type = static_cast(host_target->kind->device_type); + if (host_device_type != kDLCPU) { + LOG(WARNING) << "Using the given host target '" << host_target << "' of non-CPU device type " + << host_device_type << " for all host operations and data"; + } else { + LOG(INFO) << "Using the given host target '" << host_target << "' of device type " + << host_device_type << " for all host operations and data"; + } + } else if (primitive_targets.size() == 1 && + primitive_targets.front()->kind->device_type == kDLCPU) { + // In the homogenous case without an explicit host target just use the given target so long as + // it's a CPU. However make sure we 'forget' any host it may already have. + host_device_type = kDLCPU; + host_target = Target(primitive_targets.front()); + LOG(INFO) << "Using the unique target '" << host_target << "' of device type " + << host_device_type << " for all host operations and data"; + } else { + // Fallback. + host_device_type = kDLCPU; + // Even if the list of available targets already includes one for kDLCPU we won't use it + // since its options may not be appropriate for host code (eg shape functions). Instead, + // create a fresh default Target. + host_target = MakeDefaultTarget(host_device_type); + LOG(WARNING) << "Using the default host target '" << host_target << "' of device type " + << host_device_type << " for all host operations and data"; + } + ICHECK(host_target.defined()); + + // + // Establish the host SEScope. + // + host_se_scope = se_scope_cache_.Unique(SEScope(host_device_type, + /*virtual_device_id=*/0, host_target)); + + // + // Now that we've settled on a host, make sure all the primitive Targets agree on it for + // their 'host' field. This mutates the primitives. 
+ // + Array new_primitve_targets; + new_primitve_targets.reserve(primitive_targets.size()); + for (const auto& primitive_target : primitive_targets) { + new_primitve_targets.push_back(Target(primitive_target, host_target)); + } + primitive_targets = new_primitve_targets; + + // + // Gather the hints as to what our default device type for primitives should be. + // + DLDeviceType default_primitive_device_type; + Optional opt_fallback_dev = pass_ctx->GetConfig("relay.fallback_device_type"); + if (opt_fallback_dev) { + const int64_t v = opt_fallback_dev.value()->value; + if (v <= 0) { + LOG(FATAL) + << "The 'relay.fallback_device_type' pass attribute is set to an invalid device type " + << v; + default_primitive_device_type = kDLCPU; + } else { + default_primitive_device_type = static_cast(v); + LOG(INFO) << "Using the 'relay.fallback_device_type' pass attribute " + << default_primitive_device_type + << " as the default device type for all primitive operations"; + } + } else if (primitive_targets.size() == 1) { + // In the homogeneous case there's no free choice. + default_primitive_device_type = + static_cast(primitive_targets.front()->kind->device_type); + LOG(INFO) << "Using the unique target '" << primitive_targets.front() << "' of device type " + << default_primitive_device_type + << " as the default device type for all primitive operations"; + } else { + // Fallback. Note that we'll require a primitive Target of kDLCPU device_type to be given + // and won't manufacture one out of thin air. + default_primitive_device_type = kDLCPU; + LOG(WARNING) << "Using " << default_primitive_device_type + << " as the default device type for all primitive operations"; + } + + // + // Establish the default primitive SEScope, choosing a known Target to match the device type. 
+ // + default_primitive_se_scope = se_scope_cache_.Unique( + SEScope(default_primitive_device_type, + /*virtual_device_id=*/0, FindPrimitiveTargetOrFail(default_primitive_device_type))); +} + +/* static */ Target CompilationConfigNode::MakeDefaultTarget(DLDeviceType device_type) { + std::string name = runtime::DeviceName(device_type); + if (name == "cpu") { + if (runtime::Registry::Get("codegen.LLVMModuleCreate")) { + // LLVM is available. + return Target("llvm"); + } else { + // LLVM is not available. + return Target("stackvm"); + } + } else { + return Target(name); + } +} + +Target CompilationConfigNode::FindPrimitiveTargetOrFail(DLDeviceType device_type) const { + auto itr = std::find_if( + primitive_targets.begin(), primitive_targets.end(), + [device_type](const Target& target) { return target->kind->device_type == device_type; }); + CHECK(itr != primitive_targets.end()) << "No target for device type " << device_type << " in the " + << primitive_targets.size() << " given by the targets list"; + return *itr; +} + +CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, + TargetMap legacy_target_map_arg, + Target optional_host_target_arg) { + VLOG_CONTEXT << "CompilationConfig"; + + auto node = make_object(); + + for (const auto& pair : legacy_target_map_arg) { + VLOG(0) << "Available primitive target " << pair.first << " = '" << pair.second << "'"; + } + if (optional_host_target_arg.defined()) { + VLOG(0) << "Available host target '" << optional_host_target_arg << "'"; + } + + // Capture the arguments in our representation. + for (const auto& pair : legacy_target_map_arg) { + node->primitive_targets.push_back(pair.second); + } + node->host_target = optional_host_target_arg; + + // Complete the targets vector and establish default scopes. After this primitive_targets will + // contain the definitive list of all required targets, target_host will be defined, and + // all primitive targets will have host target_host. 
+ node->EstablishDefaultSEScopes(pass_ctx); + + // LEGACY: Reconstruct the target map with all the primitive targets. + for (const auto& primitive_target : node->primitive_targets) { + node->legacy_target_map.Set(Integer(primitive_target->kind->device_type), primitive_target); + } + + ICHECK(node->default_primitive_se_scope->target.defined()); + ICHECK(node->host_se_scope->target.defined()); + ICHECK_GT(node->primitive_targets.size(), 0U); + + // Legacy: Some passes only support homogenous compilation and expect the target to be + // given by the global target context. Make this easy to detect. + node->optional_homogeneous_target = + node->primitive_targets.size() == 1 ? *node->primitive_targets.begin() : Target(); + + for (const auto& target : node->primitive_targets) { + LOG(INFO) << "Target '" << target << "' of device type " << target->kind->device_type + << " is available for primitives"; + } + LOG(INFO) << "Using default primitive scope " << node->default_primitive_se_scope; + LOG(INFO) << "Using host scope " << node->host_se_scope; + + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("target.MakeCompilationConfig") + .set_body_typed([](const transform::PassContext& pass_ctx, TargetMap legacy_target_map, + Target optional_host_target) -> CompilationConfig { + return CompilationConfig(pass_ctx, std::move(legacy_target_map), + std::move(optional_host_target)); + }); + +} // namespace tvm diff --git a/src/target/se_scope.cc b/src/target/se_scope.cc new file mode 100644 index 000000000000..150a883cb565 --- /dev/null +++ b/src/target/se_scope.cc @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/se_scope.cc + * \brief Implementation of \p SEScope for representing a Storage or Execution scope. + */ +#include +#include +#include + +namespace tvm { + +TVM_REGISTER_NODE_TYPE(SEScopeNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = ref.as(); + p->stream << "SEScope("; + if (node->IsFullyUnconstrained()) { + p->stream << "?"; + } else { + bool need_sep = false; + if (node->device_type() != kInvalidDeviceType) { + p->stream << "device_type=" << node->device_type(); + need_sep = true; + } + if (node->virtual_device_id >= 0) { + if (need_sep) { + p->stream << ", "; + } + p->stream << "virtual_device_id=" << node->virtual_device_id; + need_sep = true; + } + if (node->target.defined()) { + if (need_sep) { + p->stream << ", "; + } + p->stream << "target='" << node->target << "'"; + need_sep = true; + } + if (!node->memory_scope.empty()) { + if (need_sep) { + p->stream << ", "; + } + p->stream << "memory_scope='" << node->memory_scope << "'"; + } + } + p->stream << ")"; + }); + +SEScope::SEScope(DLDeviceType device_type, int virtual_device_id, Target target, + MemoryScope memory_scope) { + ICHECK(!target.defined() || device_type == target->kind->device_type) + << "target '" << target << "' has device type " << target->kind->device_type + << " but scope has device type " << device_type; + auto node = make_object(); + node->device_type_int = device_type; + node->virtual_device_id = virtual_device_id; + node->target = std::move(target); + 
node->memory_scope = std::move(memory_scope); + data_ = std::move(node); +} + +/* static */ SEScope SEScope::FullyUnconstrained() { + static const SEScope unconstrained{}; + return unconstrained; +} + +/* static */ +Optional SEScope::Join(const SEScope& lhs, const SEScope& rhs) { + if (lhs == rhs) { + return lhs; + } + DLDeviceType joined_device_type; + if (lhs->device_type() != kInvalidDeviceType) { + joined_device_type = lhs->device_type(); + if (rhs->device_type() != kInvalidDeviceType && lhs->device_type() != rhs->device_type()) { + return {}; + } + } else { + joined_device_type = rhs->device_type(); + } + int joined_virtual_device_id; + if (lhs->virtual_device_id >= 0) { + joined_virtual_device_id = lhs->virtual_device_id; + if (rhs->virtual_device_id >= 0 && lhs->virtual_device_id != rhs->virtual_device_id) { + return {}; + } + } else { + joined_virtual_device_id = rhs->virtual_device_id; + } + Target joined_target; + if (lhs->target.defined()) { + joined_target = lhs->target; + if (rhs->target.defined() && lhs->target != rhs->target) { + return {}; + } + } else { + joined_target = rhs->target; + } + MemoryScope joined_memory_scope; + if (!lhs->memory_scope.empty()) { + joined_memory_scope = lhs->memory_scope; + if (!rhs->memory_scope.empty() && lhs->memory_scope != rhs->memory_scope) { + return {}; + } + } else { + joined_memory_scope = rhs->memory_scope; + } + return SEScope(joined_device_type, joined_virtual_device_id, joined_target, joined_memory_scope); +} + +/* static */ +SEScope SEScope::Default(const SEScope& lhs, const SEScope& rhs) { + if (lhs == rhs) { + return lhs; + } + DLDeviceType defaulted_device_type; + if (lhs->device_type() != kInvalidDeviceType) { + defaulted_device_type = lhs->device_type(); + } else { + defaulted_device_type = rhs->device_type(); + } + int defaulted_virtual_device_id; + if (lhs->virtual_device_id >= 0) { + defaulted_virtual_device_id = lhs->virtual_device_id; + } else { + defaulted_virtual_device_id = 
rhs->virtual_device_id; + } + Target defaulted_target; + if (lhs->target.defined()) { + defaulted_target = lhs->target; + } else { + // We can only default to the rhs's target if it is consistent with the device type + if (rhs->target.defined() && rhs->target->kind->device_type == defaulted_device_type) { + defaulted_target = rhs->target; + } + // else: leave as null + } + MemoryScope defaulted_memory_scope; + if (!lhs->memory_scope.empty()) { + defaulted_memory_scope = lhs->memory_scope; + } else { + defaulted_memory_scope = rhs->memory_scope; + } + return SEScope(defaulted_device_type, defaulted_virtual_device_id, defaulted_target, + defaulted_memory_scope); +} + +SEScope SEScopeCache::Make(DLDeviceType device_type, int virtual_device_id, Target target, + MemoryScope memory_scope) { + SEScope prototype(device_type, virtual_device_id, std::move(target), std::move(memory_scope)); + auto itr = cache_.find(prototype); + if (itr == cache_.end()) { + VLOG(1) << "added new scope " << prototype; + cache_.emplace(prototype); + return prototype; + } else { + VLOG(1) << "reusing '" << *itr << "' for '" << prototype << "'"; + ICHECK_EQ(prototype->target.defined(), (*itr)->target.defined()); + if (prototype->target.defined()) { + ICHECK_EQ(prototype->target->host.defined(), (*itr)->target->host.defined()); + } + return *itr; + } +} + +SEScope SEScopeCache::Unique(const SEScope& scope) { + return Make(scope->device_type(), scope->virtual_device_id, scope->target, scope->memory_scope); +} + +TVM_REGISTER_GLOBAL("target.SEScope_ForDeviceTargetAndMemoryScope") + .set_body_typed(SEScope::ForDeviceTargetAndMemoryScope); + +} // namespace tvm diff --git a/src/target/target.cc b/src/target/target.cc index e0b9539380d7..d1c85c583b3b 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -74,7 +74,7 @@ void CheckAndUpdateHostConsistency(Target* target, Target* host) { *host = (*target)->GetHost().value_or(Target()); } -void CheckAndUpdateHostConsistency(Map* targets, Target* 
host) { +void CheckAndUpdateHostConsistency(TargetMap* targets, Target* host) { Map new_targets; for (auto& it : *targets) { auto target = it.second; @@ -531,6 +531,19 @@ Optional TargetNode::GetHost() const { return GetRef>(this->host.as()); } +bool TargetNode::SEqualReduce(const TargetNode* other, SEqualReducer equal) const { + return equal(kind.get(), other->kind.get()) && equal(host, other->host) && + equal(tag, other->tag) && equal(keys, other->keys) && equal(attrs, other->attrs); +} + +void TargetNode::SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(kind.get()); + hash_reduce(host); + hash_reduce(tag); + hash_reduce(keys); + hash_reduce(attrs); +} + /*! \brief Entry to hold the Target context stack. */ struct TVMTargetThreadLocalEntry { /*! \brief The current target context */ diff --git a/tests/cpp/target/compilation_config_test.cc b/tests/cpp/target/compilation_config_test.cc new file mode 100644 index 000000000000..5c2b7990a498 --- /dev/null +++ b/tests/cpp/target/compilation_config_test.cc @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include + +namespace tvm { +namespace { + +Target TestCpuTarget() { return Target("llvm -mcpu arm64"); } + +Target TestCudaTarget() { return Target("nvidia/tesla-p40"); } + +Target TestDefaultCpuTarget() { return Target("llvm"); } + +Target TestExtDevTarget() { return Target("ext_dev"); } + +CompilationConfig TestCompilationConfig() { + transform::PassContext pass_ctx = transform::PassContext::Create(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); + legacy_target_map.Set(Integer(static_cast(kDLCPU)), TestCpuTarget()); + return CompilationConfig(pass_ctx, legacy_target_map, TestDefaultCpuTarget()); +} + +TEST(CompilationConfig, Constructor_Homogeneous_DefaultHost) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + Target host_target = TestDefaultCpuTarget(); + Target cuda_target = TestCudaTarget(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{}); + + SEScope expected_default_primitive_se_scope(kDLCUDA, 0, + Target::WithHost(cuda_target, host_target)); + SEScope expected_host_se_scope(kDLCPU, 0, host_target); + + ASSERT_EQ(config->legacy_target_map.size(), 1); + EXPECT_TRUE(StructuralEqual()((*config->legacy_target_map.begin()).second, + Target::WithHost(cuda_target, host_target))); + EXPECT_TRUE(config->host_target.defined()); + EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); + ASSERT_EQ(config->primitive_targets.size(), 1); + EXPECT_TRUE( + StructuralEqual()(config->primitive_targets[0], Target::WithHost(cuda_target, host_target))); + EXPECT_TRUE( + StructuralEqual()(config->default_primitive_se_scope, expected_default_primitive_se_scope)); + EXPECT_TRUE(StructuralEqual()(config->host_se_scope, expected_host_se_scope)); + ASSERT_TRUE(config->optional_homogeneous_target.defined()); + 
EXPECT_TRUE(StructuralEqual()(config->optional_homogeneous_target, + Target::WithHost(cuda_target, host_target))); +} + +TEST(CompilationConfig, Constructor_Hetrogeneous_DefaultHost) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("relay.fallback_device_type", Integer(static_cast(kDLCUDA))); + Target host_target = TestDefaultCpuTarget(); + Target cuda_target = TestCudaTarget(); + Target cpu_target = TestCpuTarget(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCPU)), cpu_target); + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{}); + + SEScope expected_default_primitive_se_scope(kDLCUDA, 0, + Target::WithHost(cuda_target, host_target)); + SEScope expected_host_se_scope(kDLCPU, 0, host_target); + + ASSERT_EQ(config->legacy_target_map.size(), 2); + EXPECT_TRUE(config->host_target.defined()); + EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); + EXPECT_TRUE( + StructuralEqual()(config->default_primitive_se_scope, expected_default_primitive_se_scope)); + EXPECT_TRUE(StructuralEqual()(config->host_se_scope, expected_host_se_scope)); + EXPECT_FALSE(config->optional_homogeneous_target.defined()); +} + +TEST(CompilationConfig, Constructor_Hetrogeneous_ExplicitHost) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("relay.fallback_device_type", Integer(static_cast(kDLCUDA))); + Target host_target = TestCpuTarget(); + Target cuda_target = TestCudaTarget(); + Target cpu_target = TestCpuTarget(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCPU)), cpu_target); + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + CompilationConfig config(pass_ctx, legacy_target_map, host_target); + + SEScope expected_default_primitive_se_scope(kDLCUDA, 0, + Target::WithHost(cuda_target, host_target)); + 
SEScope expected_host_se_scope(kDLCPU, 0, host_target); + + ASSERT_EQ(config->legacy_target_map.size(), 2); + EXPECT_TRUE(config->host_target.defined()); + EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); + ASSERT_EQ(config->primitive_targets.size(), 2); + EXPECT_TRUE( + StructuralEqual()(config->default_primitive_se_scope, expected_default_primitive_se_scope)); + EXPECT_TRUE(StructuralEqual()(config->host_se_scope, expected_host_se_scope)); + EXPECT_FALSE(config->optional_homogeneous_target.defined()); +} + +TEST(CompilationConfig, Constructor_InvalidAttribute) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("relay.fallback_device_type", Integer(static_cast(kInvalidDeviceType))); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); + EXPECT_ANY_THROW( + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{})); +} + +TEST(CompilationConfig, Constructor_NoMatchingPrimitiveTarget) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("relay.fallback_device_type", Integer(static_cast(kDLMetal))); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); + EXPECT_ANY_THROW( + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{})); +} + +TEST(CompilationConfig, Constructor_DefaultNoMatchingPrimitiveTarget) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); + legacy_target_map.Set(Integer(static_cast(kDLExtDev)), TestExtDevTarget()); + EXPECT_ANY_THROW( + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{})); +} + +TEST(CompilationConfig, CanonicalSEScope) { + Target host_target = TestDefaultCpuTarget(); + Target cuda_target = TestCudaTarget(); + Target 
cpu_target = TestCpuTarget(); + CompilationConfig config = TestCompilationConfig(); + + { + SEScope in = SEScope(kDLCPU); + SEScope actual = config->CanonicalSEScope(in); + ASSERT_TRUE(actual->target.defined()); + EXPECT_TRUE(StructuralEqual()(actual->target, Target::WithHost(cpu_target, host_target))); + EXPECT_EQ(config->CanonicalSEScope(in), actual); + } + { + SEScope in = SEScope(kDLCUDA); + SEScope actual = config->CanonicalSEScope(in); + ASSERT_TRUE(actual->target.defined()); + EXPECT_TRUE(StructuralEqual()(actual->target, Target::WithHost(cuda_target, host_target))); + EXPECT_EQ(config->CanonicalSEScope(in), actual); + } +} + +TEST(CompilationConfig, CanonicalSEScope_NoDevice) { + CompilationConfig config = TestCompilationConfig(); + SEScope fully_unconstrained; + EXPECT_ANY_THROW(config->CanonicalSEScope(fully_unconstrained)); + SEScope missing_device(kInvalidDeviceType, 3, {}, "local"); + EXPECT_ANY_THROW(config->CanonicalSEScope(missing_device)); +} + +TEST(CompilationConfig, CanonicalSEScope_NoMatchingTarget) { + CompilationConfig config = TestCompilationConfig(); + SEScope no_such_target(kDLMetal); + EXPECT_ANY_THROW(config->CanonicalSEScope(no_such_target)); +} + +} // namespace +} // namespace tvm diff --git a/tests/cpp/target/se_scope_test.cc b/tests/cpp/target/se_scope_test.cc new file mode 100644 index 000000000000..166ba46faf37 --- /dev/null +++ b/tests/cpp/target/se_scope_test.cc @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +namespace tvm { +namespace { + +TEST(SEScope, Join_Defined) { + { + Target target_a = Target("cuda"); + SEScope lhs = SEScope(kDLCUDA, 3); + SEScope rhs = SEScope(kDLCUDA, -1, target_a, "global"); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + SEScope expected = SEScope(kDLCUDA, 3, target_a, "global"); + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + } + { + Target target_a = Target("cuda"); + SEScope lhs = SEScope(kDLCUDA, -1, target_a, "global"); + SEScope rhs = SEScope(kDLCUDA, 3); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + SEScope expected = SEScope(kDLCUDA, 3, target_a, "global"); + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + } + { + Target target_a = Target("cuda"); + SEScope lhs = SEScope(kDLCUDA); + SEScope rhs = SEScope(kDLCUDA, 2, target_a); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + SEScope expected = SEScope(kDLCUDA, 2, target_a); + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + } + { + Target target_a = Target("cuda"); + SEScope lhs = SEScope(); + SEScope rhs = SEScope(kDLCUDA, 3, target_a, "global"); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + SEScope expected = rhs; + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + } +} + +TEST(SEScope, Join_Undefined) { + { + SEScope lhs = SEScope(kDLCUDA); + SEScope rhs = SEScope(kDLCPU); + Optional actual = SEScope::Join(lhs, rhs); + 
EXPECT_FALSE(actual); + } + { + SEScope lhs = SEScope(kDLCUDA, 3); + SEScope rhs = SEScope(kDLCUDA, 4); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_FALSE(actual); + } + { + SEScope lhs = SEScope(kDLCUDA, 3, Target("cuda")); + SEScope rhs = SEScope(kDLCUDA, 3, Target("cuda")); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_FALSE(actual); + } + { + SEScope lhs = SEScope(kDLCUDA, 3, Target("cuda"), "local"); + SEScope rhs = SEScope(kDLCUDA, 3, Target("cuda"), "global"); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_FALSE(actual); + } +} + +TEST(SEScope, Default) { + Target target_a = Target("cuda"); + SEScope lhs = SEScope(kDLCUDA, -1, Target(), "global"); + SEScope rhs = SEScope(kDLCUDA, 3, target_a, "local"); + SEScope actual = SEScope::Default(lhs, rhs); + SEScope expected = SEScope(kDLCUDA, 3, target_a, "global"); + EXPECT_TRUE(StructuralEqual()(actual, expected)); +} + +TEST(SEScope, Constructor_Invalid) { EXPECT_ANY_THROW(SEScope(kDLCPU, -1, Target("cuda"))); } + +TEST(SEScopeCache, Memoized) { + SEScopeCache cache; + Target target_a = Target("cuda"); + Target target_b = Target("llvm"); + SEScope se_scope_a = cache.Make(kDLCUDA, 3, target_a, "local"); + SEScope se_scope_b = cache.Make(kDLCPU, 1, target_b, "global"); + + EXPECT_EQ(cache.Make(kDLCUDA, 3, target_a, "local"), se_scope_a); + EXPECT_EQ(cache.Make(kDLCPU, 1, target_b, "global"), se_scope_b); + EXPECT_NE(cache.Make(kDLCUDA, 2, target_a, "local"), se_scope_a); + EXPECT_NE(cache.Make(kDLCPU, 3, target_b, "local"), se_scope_a); + EXPECT_NE(cache.Make(kDLCUDA, 3, target_a, "global"), se_scope_a); +} + +} // namespace +} // namespace tvm diff --git a/tests/python/target/test_se_scope.py b/tests/python/target/test_se_scope.py new file mode 100644 index 000000000000..0a9384fa9c04 --- /dev/null +++ b/tests/python/target/test_se_scope.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import pytest +import tvm + + +def test_make_se_scope_for_device(): + se_scope = tvm.target.make_se_scope(tvm.device("cuda")) + assert se_scope.device_type == 2 + # ie kDLCUDA + assert se_scope.virtual_device_id == 0 + assert se_scope.target is None + assert se_scope.memory_scope == "" + + +def test_make_se_scope_for_device_and_target(): + target = tvm.target.Target("cuda") + se_scope = tvm.target.make_se_scope(tvm.device("cuda"), target) + assert se_scope.device_type == 2 # ie kDLCUDA + assert se_scope.target == target + assert se_scope.memory_scope == "" + + +def test_make_se_scope_for_device_target_and_memory_scope(): + target = tvm.target.Target("cuda") + scope = "local" + se_scope = tvm.target.make_se_scope(tvm.device("cuda"), target, scope) + assert se_scope.device_type == 2 # ie kDLCUDA + assert se_scope.target == target + assert se_scope.memory_scope == scope + + +if __name__ == "__main__": + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:]))