From 7f164826c6c2a14fa621220111dd7b397dd2b5bd Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Mon, 18 Oct 2021 17:34:51 -0700 Subject: [PATCH] Switch PlanDevices pass to be w.r.t. SEScopes instead of DLDeviceTypes. CAUTION: Breaking VM executable serialization change. I needed a new 'virtual devices' array in the executable so that instructions can continue to refer to devices by a simple index yet the VM can respect both the device type and id for runtime devices. Continuing from #9313, and as part of apache/tvm-rfcs#38, we switch PlanDevices to plan with respect to SEScopes instead of just DLDeviceTypes. Our ultimate goal is to be able to flow memory scopes between PrimFuncs by re-running PlanDevices after the LowerTE pass. This PR at least gets us to being able to flow the memory scopes, but the actual changes to PlanDevices to look inside PrimFuncs is still two PR's in the future. However, we get two nice side effects right away: - Since SEScopes contain Targets we can isolate all the device-to-target resolution machinery within PlanDevices (with the help of CompilationConfig). After PlanDevices has run we can retrieve the Target for any sub-expression directly from that sub-expression's SEScope. For now we retain the one-Target-per-DLDeviceType constraint since it baked into the public 'TargetMap' API, but the path to breaking that constraint is clearer. - Device ids are now respected all the way from annotation to executor. Previously though we had a bit of plumbing using Devices the device_id therein was ignored or defaulted to zero. The Python "on_device" annotation helpers still work w.r.t. devices. Thus though they now respect device ids, they do not allow the user to specify a Target or memory scope as supported by the underlying SEScope. --- include/tvm/ir/function.h | 12 +- include/tvm/relay/attrs/annotation.h | 65 +- include/tvm/relay/attrs/device_copy.h | 16 +- include/tvm/relay/attrs/memory.h | 7 +- include/tvm/relay/attrs/on_device.h | 101 +++ include/tvm/relay/transform.h | 24 +- include/tvm/runtime/vm/bytecode.h | 26 +- include/tvm/runtime/vm/executable.h | 44 +- include/tvm/runtime/vm/vm.h | 44 +- python/tvm/micro/contrib/stm32/emitter.py | 32 +- python/tvm/relay/op/annotation/annotation.py | 21 +- python/tvm/relay/op/tensor.py | 36 +- python/tvm/relay/transform/transform.py | 25 +- python/tvm/runtime/vm.py | 14 +- src/relay/backend/aot_executor_codegen.cc | 27 +- src/relay/backend/build_module.cc | 171 ++-- src/relay/backend/graph_executor_codegen.cc | 13 +- src/relay/backend/graph_plan_memory.cc | 54 +- src/relay/backend/interpreter.cc | 9 +- src/relay/backend/te_compiler.cc | 120 ++- src/relay/backend/te_compiler.h | 9 +- src/relay/backend/utils.cc | 71 +- src/relay/backend/utils.h | 19 +- src/relay/backend/vm/compiler.cc | 264 ++++--- src/relay/backend/vm/compiler.h | 20 +- src/relay/backend/vm/lambda_lift.cc | 9 +- src/relay/ir/expr_functor.cc | 13 +- src/relay/op/annotation/annotation.cc | 152 ---- src/relay/op/annotation/annotation.h | 106 --- src/relay/op/memory/device_copy.cc | 40 +- src/relay/op/memory/device_copy.h | 38 +- src/relay/op/memory/memory.cc | 9 +- src/relay/op/memory/memory.h | 4 +- src/relay/op/memory/on_device.cc | 167 ++++ src/relay/op/memory/on_device.h | 144 ++++ src/relay/transforms/device_aware_visitors.cc | 98 +-- src/relay/transforms/device_aware_visitors.h | 81 +- src/relay/transforms/device_domains.cc | 344 ++++---- src/relay/transforms/device_domains.h | 229 +++--- src/relay/transforms/device_planner.cc | 392 +++++----- src/relay/transforms/fold_constant.cc | 52 +- src/relay/transforms/memory_alloc.cc | 119 ++- src/relay/transforms/pass_utils.h | 1 + src/relay/transforms/to_a_normal_form.cc | 20 +- src/runtime/vm/bytecode.cc | 33 +- src/runtime/vm/executable.cc | 64 +- src/runtime/vm/profiler/vm.cc | 12 +- src/runtime/vm/serialize_utils.h | 12 +- src/runtime/vm/vm.cc | 102 ++- .../relay/transforms/device_domains_test.cc | 22 +- .../relay/op/annotation/test_annotation.py | 27 +- tests/python/relay/op/test_tensor.py | 52 ++ tests/python/relay/test_pass_plan_devices.py | 736 ++++++++++++------ tests/python/relay/test_vm.py | 57 +- .../unittest/test_auto_scheduler_measure.py | 17 +- .../test_micro_model_library_format.py | 6 +- .../python/unittest/test_runtime_profiling.py | 5 +- .../unittest/test_runtime_vm_profiler.py | 5 +- 58 files changed, 2452 insertions(+), 1960 deletions(-) create mode 100644 include/tvm/relay/attrs/on_device.h create mode 100644 src/relay/op/memory/on_device.cc create mode 100644 src/relay/op/memory/on_device.h create mode 100644 tests/python/relay/op/test_tensor.py diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 5ee719f9964f8..e466cde097ac1 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -191,24 +191,24 @@ constexpr const char* kTarget = "target"; constexpr const char* kGlobalSymbol = "global_symbol"; /*! - * \brief The device type which will hold each of the functions parameters. + * \brief The SEScope which will hold each of the functions parameters. * * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but * may be included as an annotation on user programs. * - * Type: Array (but interpreted as Array) + * Type: Array */ -constexpr const char* kParamDeviceTypes = "param_device_types"; +constexpr const char* kParamSEScopes = "param_se_scopes"; /*! - * \brief The device type which will hold the function result. + * \brief The SEScope which will hold the function result. * * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but * may be included as an annotation on user programs. * - * Type: Integer (but interpreted as DLDeviceType) + * Type: SEScope */ -constexpr const char* kResultDeviceType = "result_device_type"; +constexpr const char* kResultSEScope = "result_se_scope"; } // namespace attr } // namespace tvm diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index 85ac3f36ff607..30839d725aab0 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -25,74 +25,13 @@ #define TVM_RELAY_ATTRS_ANNOTATION_H_ #include +#include #include namespace tvm { namespace relay { -/*! - * \brief Attributes for the "on_device" special operator. - * - * The Relay call (aka 'annotation'): - * \code - * on_device(sub_expr, device_type=2) - * \endcode - * constrains \p sub_expr to execute and store its result on a device with \p DLDeviceType \p 2 - * (i.e. a \p kDLCuda device). However the annotation itself may appear in an expression to be - * executed and stored on a different device. If so the compiler will automatically insert a - * "device_copy" call to mediate the transition between devices. - * - * E.g.: Assuming %x and %y reside on the GPU and %z on the CPU then: - * \code - * multiply(on_device(add(%x, %y), device_type=2), %z) - * \endcode - * indicates the \p add should execute on the GPU but the \p multiply should execute on the CPU. - * The compiler will rewrite this to: - * \code - * multiply(device_copy(add(%x, %y), src_dev_type=2, dst_dev_type=1), %z) - * \endcode - * - * The Relay call - * \code - * on_device(sub_expr, device_type=2, is_fixed=True) - * \endcode - * is similar to the above, however the annotation itself must appear in an expression on the - * same device. The compiler will check the devices are consistent, and will not insert any - * "device_copy" call. This form of annotation shouldn't be necessary in user programs. However - * it is needed by the \p PlanDevices pass to fully specify the results of device planning so that - * the pass is idempotent. - * - * E.g.: The following program is equivalent to the above: - * \code - * let %a = on_device(add(%x, %y), device_type=2, is_fixed=True) - * multiply(device_copy(%a, src_dev_type=2, dst_dev_type=1), %z) - * \endcode - * The "on_device" annotation with \p is_fixed=True indicates unambiguously that \p %a is stored - * on the GPU. - */ -struct OnDeviceAttrs : public tvm::AttrsNode { - // TODO(mbs): Replace device types with TargetDevice. - /*! \brief Device type on which argument expression should be evaluated. */ - int device_type = kInvalidDeviceType; - /*! - * \brief If true, the result device must also be \p device_type and device planning should - * not insert any "device_copy" calls to respect this annotation. - * - * This is used by the device planning pass itself when annotating the planned program. - */ - bool is_fixed = false; - - TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") { - TVM_ATTR_FIELD(device_type) - .describe("The type of the virtual device which should hold the expression result.") - .set_default(0); - TVM_ATTR_FIELD(is_fixed) - .describe("If true, do not insert a \"device_copy\" call to respect this annotation.") - .set_default(false); - } -}; - /*! * \brief Annotate an expression to be cast into specific data type. */ @@ -118,6 +57,8 @@ struct CompilerAttrs : public tvm::AttrsNode { /*! * \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR. + * + * TODO(mbs): Replace with typed fields once attributes have stabilized. */ struct TIRCallAttrs : public tvm::AttrsNode { /*! \brief The metadata attached to the call node. */ diff --git a/include/tvm/relay/attrs/device_copy.h b/include/tvm/relay/attrs/device_copy.h index f7b0a04f45fa8..6d97ab79be4a2 100644 --- a/include/tvm/relay/attrs/device_copy.h +++ b/include/tvm/relay/attrs/device_copy.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_DEVICE_COPY_H_ #include +#include #include @@ -35,17 +36,14 @@ namespace relay { * \brief Options for the device copy operators. */ struct DeviceCopyAttrs : public tvm::AttrsNode { - // TODO(mbs): Should be TargetDevice. - int dst_dev_type; - int src_dev_type; + SEScope src_se_scope = SEScope::FullyUnconstrained(); + SEScope dst_se_scope = SEScope::FullyUnconstrained(); TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") { - TVM_ATTR_FIELD(src_dev_type) - .describe("The virtual device/context type where the op copies data from.") - .set_default(0); - TVM_ATTR_FIELD(dst_dev_type) - .describe("The virtual device/context type where the op copies data to.") - .set_default(0); + TVM_ATTR_FIELD(src_se_scope) + .describe("The (virtual) device and scope where the op copies data from."); + TVM_ATTR_FIELD(dst_se_scope) + .describe("The (virtual) device and scope where the op copies data to."); } }; diff --git a/include/tvm/relay/attrs/memory.h b/include/tvm/relay/attrs/memory.h index 85462c087cee0..952d4affc5849 100644 --- a/include/tvm/relay/attrs/memory.h +++ b/include/tvm/relay/attrs/memory.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -42,15 +43,13 @@ Expr ToTupleType(const Type& t, const std::vector& exprs); */ struct AllocStorageAttrs : public tvm::AttrsNode { DataType dtype; - int device_id; - int device_type; + SEScope se_scope = SEScope::FullyUnconstrained(); TVM_DECLARE_ATTRS(AllocStorageAttrs, "relay.attrs.AllocStorageAttrs") { TVM_ATTR_FIELD(dtype) .describe("The dtype of the tensor to allocate.") .set_default(DataType::Float(32, 1)); - TVM_ATTR_FIELD(device_id).describe("The device id on which to allocate memory."); - TVM_ATTR_FIELD(device_type).describe("The device type on which to allocate memory."); + TVM_ATTR_FIELD(se_scope).describe("The SEScope on which to allocate memory."); } }; diff --git a/include/tvm/relay/attrs/on_device.h b/include/tvm/relay/attrs/on_device.h new file mode 100644 index 0000000000000..405926e209c69 --- /dev/null +++ b/include/tvm/relay/attrs/on_device.h @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/attrs/on_device.h + * \brief Attribute for the on device annotation. + */ +#ifndef TVM_RELAY_ATTRS_ON_DEVICE_H_ +#define TVM_RELAY_ATTRS_ON_DEVICE_H_ + +#include +#include + +#include + +namespace tvm { +namespace relay { + +/*! + * \brief Attributes for the "on_device" special operator. + * + * The Relay call (aka 'annotation'): + * \code + * on_device(sub_expr, se_scope=S) + * \endcode + * constrains \p sub_expr to execute and store its result on the \p SEScope \p S. + * However the annotation itself may appear in an expression to be executed and stored on a + * different \p SEScope. If so the compiler will automatically insert a "device_copy" call to + * mediate the transition between \p SEScopes. + * + * E.g.: Assuming %x and %y reside on the GPU and %z on the CPU then: + * \code + * multiply(on_device(add(%x, %y), se_scope=GPU), %z) + * \endcode + * indicates the \p add should execute on the GPU but the \p multiply should execute on the CPU. + * The compiler will rewrite this to: + * \code + * multiply(device_copy(add(%x, %y), src_se_scope=GPU, dst_se_scope=CPU), %z) + * \endcode + * + * The Relay call + * \code + * on_device(sub_expr, se_scope=S, is_fixed=True) + * \endcode + * is similar to the above, however the annotation itself must appear in an expression on the + * same \p SEScope \p S. The compiler will check the \p SEScopes are consistent, and will not + * insert any "device_copy" call. This form of annotation shouldn't be necessary in user programs. + * However it is needed by the \p PlanDevices pass to fully specify the results of device planning + * so that the pass is idempotent. + * + * E.g.: The following program is equivalent to the above: + * \code + * let %a = on_device(add(%x, %y), se_scope=GPU, is_fixed=True) + * multiply(device_copy(%a, src_se_scope=GPU, dst_se_scope=CPU), %z) + * \endcode + * The "on_device" annotation with \p is_fixed=True indicates unambiguously that \p %a is stored + * on the GPU. + */ +struct OnDeviceAttrs : public tvm::AttrsNode { + /*! + * \brief (Virtual) \p SEScope on which the result of the argument expression should be stored. + */ + SEScope se_scope = SEScope::FullyUnconstrained(); + /*! + * \brief If true, the result \p SEScope must also be \p se_scope, and device planning should + * not insert any "device_copy" calls to respect this annotation. + * + * This is used by the device planning pass itself when annotating the planned program. + */ + bool is_fixed = false; + + TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") { + TVM_ATTR_FIELD(se_scope) + .describe("The (virtual) device and scope holding the expression result.") + .set_default(SEScope::FullyUnconstrained()); + TVM_ATTR_FIELD(is_fixed) + .describe("If true, do not insert a \"device_copy\" call to respect this annotation.") + .set_default(false); + } +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_ATTRS_ON_DEVICE_H_ diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index e740776d6d4f4..aa9d3b41554c5 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -30,6 +30,8 @@ #include #include #include +#include +#include #include #include @@ -437,23 +439,27 @@ TVM_DLL Pass RelayToTIRTargetHook(); * \brief A pass for manifesting explicit memory allocations and rewriting * specific dialects. * - * \param target_host The target used by the host for compilation. - * \param targets The device type and target pairs for compilation. + * \param cpu_se_scope SEScope for computations and data which must reside on a CPU, such as + * shapes and shape functions. * * \return The pass. */ -TVM_DLL Pass ManifestAlloc(Target target_host, Map targets); +TVM_DLL Pass ManifestAlloc(SEScope cpu_se_scope); /*! - * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the device on which - * every Relay sub-expression should run (and the result stored). Captures the result of that - * analysis using new "on_device" and "device_copy" CallNodes. See - * tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator} + * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the \p SEScope on which + * every Relay sub-expression should run and the result stored. Captures the result of that + * analysis using new "on_device" and "device_copy" CallNodes. + * + * See tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator} * for help recovering the device for an arbitrary sub-expression in downstream transformations. * - * \param default_device_type DLDeviceType for default device. + * \param config Describes the targets and default \p SEScope for all primitive operators and + * host sub-expressions. + * + * \return The pass. */ -TVM_DLL Pass PlanDevices(DLDeviceType default_device_type); +TVM_DLL Pass PlanDevices(CompilationConfig config); } // namespace transform diff --git a/include/tvm/runtime/vm/bytecode.h b/include/tvm/runtime/vm/bytecode.h index 72a557fa93b1e..a2a64d76ce869 100644 --- a/include/tvm/runtime/vm/bytecode.h +++ b/include/tvm/runtime/vm/bytecode.h @@ -176,6 +176,7 @@ struct Instruction { RegName object; } get_tag; struct /* AllocADT Operands */ { + // TODO(mbs): Needs a DeviceAndScope. /*! \brief The datatype's constructor tag. */ Index constructor_tag; /*! \brief The number of fields to store in the datatype. */ @@ -184,6 +185,7 @@ struct Instruction { RegName* datatype_fields; }; struct /* AllocClosure Operands */ { + // TODO(mbs): Needs a DeviceAndScope. /*! \brief The index into the function table. */ Index clo_index; /*! \brief The number of free variables to capture. */ @@ -198,8 +200,8 @@ struct Instruction { Index alignment; /*! \brief The hint of the dtype. */ DLDataType dtype_hint; - /*! \brief The device type of the allocation. */ - Index device_type; + /*! \brief The index of the device on which the allocation will be made. */ + Index device_index; } alloc_storage; struct /* ShapeOf Operands */ { RegName tensor; @@ -210,11 +212,11 @@ struct Instruction { } reshape_tensor; struct /* DeviceCopy Operands */ { RegName src; - /*! \brief The source device type. */ - Index src_device_type; - /*! \brief The destination device type. */ - Index dst_device_type; - }; + /*! \brief The index of the source device to copy from. */ + Index src_device_index; + /*! \brief The index of the destination deviceto copy to. */ + Index dst_device_index; + } device_copy; }; /*! @@ -352,12 +354,12 @@ struct Instruction { * \param size The size of the allocation. * \param alignment The allocation's alignment. * \param dtype_hint The data type hint for the allocator. - * \param device_type The device type for the allocator. + * \param device_index The index of the device to allocate on. * \param dst The destination to place the storage. * \return The alloc storage instruction. */ static Instruction AllocStorage(RegName size, Index alignment, DLDataType dtype_hint, - Index device_type, RegName dst); + Index device_index, RegName dst); /*! * \brief Get the shape of an input tensor. * \param tensor The input tensor. @@ -376,12 +378,12 @@ struct Instruction { /*! * \brief Copy tensor cross different devices. * \param src The source register. - * \param src_device_type The device type of the tensor for the source register. - * \param dst_device_type The device type of the tensor ofr the destination register. + * \param src_device_index The index of the device holding the tensor in the source register. + * \param dst_device_index The index of the device to hold the tensor in the destination register. * \param dst The destination register to store the copied tensor. * \return The device copy instruction. */ - static Instruction DeviceCopy(RegName src, Index src_device_type, Index dst_device_type, + static Instruction DeviceCopy(RegName src, Index src_device_index, Index dst_device_index, RegName dst); Instruction(); diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index 6e564fd623802..311667904df61 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -132,12 +132,18 @@ class Executable : public ModuleNode { std::string GetBytecode() const; /*! - * \brief Returns a description of all the contants in the executable in human-readable + * \brief Returns a description of all the constants in the executable in human-readable * format. Not intended to be machine readable, but rather to help with debugging and * diffing generated code. */ std::string GetConstants() const; + /*! + * \brief Returns a description of all the (virtual) devices in the executable in human-readable + * format. + */ + std::string GetVirtualDevices() const; + /*! * \brief Print the detailed statistics of the given code, i.e. number of * globls and constants, etc. @@ -183,6 +189,16 @@ class Executable : public ModuleNode { const char* type_key() const final { return "VMExecutable"; } + /*! + * \brief The (compile-time, virtual) devices corresponding to each device index. + * Currently we only support at most one device per device type. + */ + std::vector virtual_devices; + /*! + * \brief The device index corresponding to the 'host' device. That will hold and evaluate + * shape-related data and code. + */ + int host_device_index = -1; /*! \brief The global constant pool. */ std::vector constants; /*! \brief A map from globals (as strings) to their index in the function map. */ @@ -195,38 +211,52 @@ class Executable : public ModuleNode { std::map> op_attrs; /*! \brief The virtual machine's function table. */ std::vector functions; - /*! \brief The device type for each constant. */ - std::vector const_device_type; + /*! \brief The index of the device holding each constant. */ + std::vector const_device_indexes; private: + /*! + * \brief Save the virtual devices + * + * /param strm The output stream. + */ + void SaveVirtualDevicesSection(dmlc::Stream* strm); + /*! * \brief Save the globals. * - * \param strm The input stream. + * \param strm The output stream. */ void SaveGlobalSection(dmlc::Stream* strm); /*! * \brief Save the constant pool. * - * \param strm The input stream. + * \param strm The output stream. */ void SaveConstantSection(dmlc::Stream* strm); /*! * \brief Save primitive op names. * - * \param strm The input stream. + * \param strm The output stream. */ void SavePrimitiveOpNames(dmlc::Stream* strm); /*! * \brief Save the vm functions. * - * \param strm The input stream. + * \param strm The output stream. */ void SaveCodeSection(dmlc::Stream* strm); + /*! + * \brief Load the virtual devices + * + * /param strm The input stream. + */ + void LoadVirtualDevicesSection(dmlc::Stream* strm); + /*! * \brief Load the globals. * diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index ece73fcfda34d..604c97330d995 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -82,20 +82,20 @@ struct VMFunction { /*! \brief The instructions representing the function. */ std::vector instructions; /*! \brief The size of the frame for this function */ - Index register_file_size; - /*! \brief The device type of each parameter for this function. */ - std::vector params_device_type; - - VMFunction(const std::string& name, std::vector params, - const std::vector& instructions, Index register_file_size, - const std::vector params_device_type = {}) - : name(name), - params(params), - instructions(instructions), + Index register_file_size = 0; + /*! \brief The indexes for the device holding each function parameter. */ + std::vector param_device_indexes; + + VMFunction(std::string name, std::vector params, + std::vector instructions, Index register_file_size, + std::vector param_device_indexes) + : name(std::move(name)), + params(std::move(params)), + instructions(std::move(instructions)), register_file_size(register_file_size), - params_device_type(params_device_type) {} + param_device_indexes(std::move(param_device_indexes)) {} - VMFunction() {} + VMFunction() = default; friend std::ostream& operator<<(std::ostream& os, const VMFunction&); }; @@ -239,17 +239,19 @@ class VirtualMachine : public runtime::ModuleNode { Index output_size, const std::vector& args); /*! - * \brief Initialize the virtual machine for a set of devices. - * \param devices The set of TVM devices. + * \brief Initialize the virtual machine for a set of (physical) devices. + * \param physical_devices The set of TVM devices. * \param alloc_types The allocator types for each device. */ - void Init(const std::vector& devices, const std::vector& alloc_types); + void Init(const std::vector& physical_devices, + const std::vector& alloc_types); /*! \brief Run VM dispatch loop. */ void RunLoop(); - /*! \brief Get device from the device list based on a given device type. */ - Device GetDevice(Index device_type) const; + /*! \brief Get device from the device list based on a given device index. */ + Device GetDevice(Index device_index) const; + Allocator* GetAllocator(Index device_index) const; /*! * \brief Invoke a global setting up the VM state to execute. @@ -301,9 +303,13 @@ class VirtualMachine : public runtime::ModuleNode { const Executable* exec_; /*! \brief The function name to inputs mapping. */ std::unordered_map> inputs_; - /*! \brief The set of TVM devices the VM is currently executing on. */ + /*! + * \brief The "physical" devices the VM can execute primitives on. All "device indexes" + * are w.r.t. this vector. Each entry in this vector must match the corresponding entry + * in the executable's "virtual" devices vector. + */ std::vector devices_; - /*! \brief The cached memory allocators. */ + /*! \brief The cached memory allocators, one per device. */ std::vector allocators_; /*! * \brief The constant pool for runtime. It caches the device dependent diff --git a/python/tvm/micro/contrib/stm32/emitter.py b/python/tvm/micro/contrib/stm32/emitter.py index 8453ea78e012a..aec5912871fd5 100644 --- a/python/tvm/micro/contrib/stm32/emitter.py +++ b/python/tvm/micro/contrib/stm32/emitter.py @@ -44,7 +44,7 @@ def _fix_name(node_name): - """ Replace ':' with '_' in names like 'InputImg:0' """ + """Replace ':' with '_' in names like 'InputImg:0'""" return node_name.replace(":", "_") @@ -116,7 +116,7 @@ def _get_tensor_size_bytes(dims, dltype): def _preprocess_code(src): - """ Hack the C code implementing the model. """ + """Hack the C code implementing the model.""" dst = "#include \n" "#include \n\n" dst = dst + src return dst @@ -193,7 +193,7 @@ def __init__(self, include_activations=True, include_inputs=True, include_output self._quantization = {} def _extract_quantization_info(self, quantization): - """ Build dictionary with quantization infos.""" + """Build dictionary with quantization infos.""" for dl_tensor_name in self._input_data: if dl_tensor_name in quantization: @@ -258,7 +258,7 @@ def _get_tensor_from_node(self, nid, idx): return tensor def _compute_data_placement(self): - """ Compute inputs, outputs, weight, activation sizes""" + """Compute inputs, outputs, weight, activation sizes""" self._inputs = self._arg_nodes.copy() @@ -548,7 +548,7 @@ def parse_module(self, module, quantization=None): self._parse_model(quantization) def _emit_params_data(self, name, out_h, out_c): - """ Emits the network_data[c,h] files with parameters.""" + """Emits the network_data[c,h] files with parameters.""" name_upper = name.upper() @@ -674,7 +674,7 @@ def _emit_open(self, name, out_h, out_c): ) def _emit_close(self, name, out_h, out_c): - """ Emits the ai_model_info structure. """ + """Emits the ai_model_info structure.""" name_upper = name.upper() @@ -794,7 +794,7 @@ def _emit_tensor_quant(self, dl_tensor_name, out_c): return None def _emit_tensor_init(self, dl_tensor_name, tensor, out_c): - """ Emits the tensor instantiation code. """ + """Emits the tensor instantiation code.""" dltype = tensor["dltype"] dims = tensor["dims"] @@ -838,7 +838,7 @@ def _emit_tensor_init(self, dl_tensor_name, tensor, out_c): def _emit_activation_buffers(self, name, out_c): # pylint: disable=unused-argument - """ Emits activation tensors, including inputs/outputs.""" + """Emits activation tensors, including inputs/outputs.""" out_c.write( textwrap.dedent( @@ -905,7 +905,7 @@ def _emit_activation_buffers(self, name, out_c): out_c.write(f"\n") def _emit_params_buffers(self, name, out_c): - """ Emits all parameter tensors.""" + """Emits all parameter tensors.""" out_c.write( textwrap.dedent( @@ -922,7 +922,7 @@ def _emit_params_buffers(self, name, out_c): out_c.write(f"\n") def _emit_network(self, name, out_c): - """ Emits prototypes for the network operator functions.""" + """Emits prototypes for the network operator functions.""" out_c.write( textwrap.dedent( @@ -967,7 +967,7 @@ def _emit_tensor_activation(self, dl_tensor_name, tensor, out_c): ) def _emit_activation_init(self, name, out_c): - """ Emits buffer initialization code for activation tensors.""" + """Emits buffer initialization code for activation tensors.""" out_c.write( textwrap.dedent( @@ -1015,7 +1015,7 @@ def _emit_activation_init(self, name, out_c): ) def _emit_params_init(self, name, out_c): - """ Emits buffer initialization code for params tensors.""" + """Emits buffer initialization code for params tensors.""" out_c.write( textwrap.dedent( @@ -1063,13 +1063,13 @@ def _emit_params_init(self, name, out_c): ) def _emit_init(self, name, out_c): - """ Emits buffer initialization code.""" + """Emits buffer initialization code.""" self._emit_activation_init(name, out_c) self._emit_params_init(name, out_c) def _emit_run(self, name, out_h, out_c): - """ Emits the run function code.""" + """Emits the run function code.""" out_h.write( textwrap.dedent( @@ -1230,7 +1230,7 @@ def _emit_run(self, name, out_h, out_c): out_c.write(f"\n") def _emit_create_destroy(self, name, out_h, out_c): - """ Emits the create/destroy functions.""" + """Emits the create/destroy functions.""" out_h.write( textwrap.dedent( @@ -1296,7 +1296,7 @@ def _emit_create_destroy(self, name, out_h, out_c): ) def emit_code(self, dest_dir, model_name): - """ Emits the C code implementing the model. """ + """Emits the C code implementing the model.""" # Build the directory structure if os.path.exists(dest_dir): diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py index f5f8870ab0153..cf70dc6e267e5 100644 --- a/python/tvm/relay/op/annotation/annotation.py +++ b/python/tvm/relay/op/annotation/annotation.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Annotation operations.""" +from tvm import target from tvm.runtime import ndarray as _nd from tvm.runtime import Device as _Device @@ -22,11 +23,11 @@ from .. import op as reg -def _device_to_int(device): +def _make_se_scope(device): if isinstance(device, _Device): - return device.device_type + return target.make_se_scope(device) if isinstance(device, str): - return _nd.device(device).device_type + return target.make_se_scope(_nd.device(device)) raise ValueError("expecting a Device or device name, but received a %s" % (type(device))) @@ -39,7 +40,7 @@ def on_device(data, device, is_fixed=False): The expression to be annotated. device : Union[:py:class:`Device`, str] - The device to annotate with. Only the device's type is significant. + The device to annotate with. is_fixed : bool If false (the default), a device_copy @@ -52,7 +53,7 @@ def on_device(data, device, is_fixed=False): result : tvm.relay.Expr The annotated expression. """ - return _make.on_device(data, _device_to_int(device), is_fixed) + return _make.OnDevice(data, _make_se_scope(device), is_fixed) def function_on_device(function, param_devices, result_device): @@ -65,18 +66,18 @@ def function_on_device(function, param_devices, result_device): The function to be annotated. param_devices : Array[Union[:py:class:`Device`, str]] - The devices for each parameter. Only the device types are significant. + The devices for each parameter. result_device: Union[:py:class:`Device`, str] - The device for the function result. Only the device type is significant. + The device for the function result. Returns ------- - result : tvm.rleay.Function + result : tvm.relay.Function The annotated function. """ - return _make.function_on_device( - function, [_device_to_int(d) for d in param_devices], _device_to_int(result_device) + return _make.FunctionOnDevice( + function, [_make_se_scope(d) for d in param_devices], _make_se_scope(result_device) ) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index e615bbf21b864..d9847a4535694 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -16,6 +16,7 @@ # under the License. """Basic tensor operations.""" # pylint: disable=redefined-builtin, unused-argument +from tvm import target from tvm.runtime import ndarray as _nd from tvm.runtime import Device as _Device from tvm.te.hybrid import script @@ -26,6 +27,14 @@ from . import op as reg +def _make_se_scope(device): + if isinstance(device, _Device): + return target.make_se_scope(device) + if isinstance(device, str): + return target.make_se_scope(_nd.device(device)) + raise ValueError("expecting a Device or device name, but received a %s" % (type(device))) + + # We create a wrapper function for each operator in the # python side to call into the positional _make.OpName function. # @@ -1181,7 +1190,7 @@ def copy_shape_func(attrs, inputs, _): return [_copy_shape_func(inputs[0])] -def device_copy(data, src_dev, dst_dev): +def device_copy(data, src_device, dst_device): """Copy data from the source device to the destination device. This operator helps data transferring between difference devices for heterogeneous execution. @@ -1191,10 +1200,10 @@ def device_copy(data, src_dev, dst_dev): data : tvm.relay.Expr The tensor to be copied. - src_dev : Union[:py:class:`Device`, str] + src_device : Union[:py:class:`Device`, str] The source device where the data is copied from. - dst_dev : Union[:py:class:`Device`, str] + dst_device : Union[:py:class:`Device`, str] The destination device where the data is copied to. Returns @@ -1202,26 +1211,7 @@ def device_copy(data, src_dev, dst_dev): result : tvm.relay.Expr The copied result. """ - if isinstance(src_dev, _Device): - src_dev = src_dev.device_type - elif isinstance(src_dev, str): - src_dev = _nd.device(src_dev).device_type - else: - raise ValueError( - "src_dev is expected to be the type of Device or " - "str, but received %s" % (type(src_dev)) - ) - - if isinstance(dst_dev, _Device): - dst_dev = dst_dev.device_type - elif isinstance(dst_dev, str): - dst_dev = _nd.device(dst_dev).device_type - else: - raise ValueError( - "dst_dev is expected to be the type of Device or " - "str, but received %s" % (type(dst_dev)) - ) - return _make.device_copy(data, src_dev, dst_dev) + return _make.DeviceCopy(data, _make_se_scope(src_device), _make_se_scope(dst_device)) def shape_of(data, dtype="int32"): diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 0dc07944836dc..01473a82fb3ac 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -891,6 +891,7 @@ def __init__(self, *args, **kwargs): # initialize handle in cass pass_cls creation failed.fg self.handle = None inst = pass_cls(*args, **kwargs) + # it is important not to capture self to # avoid a cyclic dependency def _pass_func(func, mod, ctx): @@ -1146,14 +1147,26 @@ def SimplifyExpr(): return _ffi_api.SimplifyExpr() -def PlanDevices(default_device): +def PlanDevices(config): """ - Uses existing "on_device" and "device_copy" CallNodes to infer the device on which - every Relay sub-expression should run (and the result stored). Captures the result of that - analysis using new "on_device" and "device_copy" CallNodes. Note that the device_id of - the default_device is ignored. + Uses existing "on_device" and "device_copy" CallNodes to infer the SEScope on which + every Relay sub-expression should run and the result stored. Captures the result of that + analysis using new "on_device" and "device_copy" CallNodes. Sub-expressions which are + not otherwise constrained are assigned to the default_primitive_se_scope. However data and + computations which must be hosted on a CPU (such as shapes and shape functions) use the + cpu_se_scope. + + Parameters + ---------- + config : tvm.CompilationConfig + The compilation configuration, specifying available targets and default devices. + + Returns + ------- + ret : tvm.transforms.Pass + The pass. """ - return _ffi_api.PlanDevices(default_device) + return _ffi_api.PlanDevices(config) def FoldExplicitPadding(): diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index c1cbc966acdc6..365e38c6e06c0 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -72,6 +72,7 @@ def __init__(self, mod): self._get_lib = self.mod["get_lib"] self._get_bytecode = self.mod["get_bytecode"] self._get_constants = self.mod["get_constants"] + self._get_virtual_devices = self.mod["get_virtual_devices"] self._get_stats = self.mod["get_stats"] self._get_function_arity = self.mod["get_function_arity"] self._get_function_param_name = self.mod["get_function_param_name"] @@ -251,6 +252,11 @@ def constants(self): Useful for debugging and diffing generated executables in unit tests.""" return self._get_constants() + @property + def virtual_devices(self): + """Returns a human-readable description of all the (virtual) devices in the executable.""" + return self._get_virtual_devices() + @property def globals(self): """Get the globals used by the Relay VM executable. @@ -295,7 +301,8 @@ class VirtualMachine(object): The VM executable. device : tvm.runtime.Device or List[tvm.runtime.Device] - The device to deploy the module + The device(s) on which the model will run. + Currently at most one device per device type is supported. memory_cfg : str or Dict[tvm.runtime.Device, str], optional Config the type of memory allocator. The allocator type can be ["naive", @@ -363,10 +370,7 @@ def _setup_device(self, dev, memory_cfg): devs = dev if not isinstance(dev, (list, tuple)): if not isinstance(dev, tvm.runtime.Device): - raise TypeError( - "dev is expected to be Device or \ - List[Device]" - ) + raise TypeError("dev is expected to be Device or List[Device]") devs = [dev] # CPU is required for executing shape functions diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 7e5702296542b..a266f185e9962 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -109,18 +109,18 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { void VisitExpr_(const TupleNode* op) final { std::vector storage_ids; - std::vector device_types; + std::vector se_scopes; std::vector storage_sizes_in_bytes; Expr expr = GetRef(op); for (Expr field : op->fields) { auto sid = GetStorage(field); storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end()); - device_types.insert(device_types.end(), sid->device_types.begin(), sid->device_types.end()); + se_scopes.insert(se_scopes.end(), sid->se_scopes.begin(), sid->se_scopes.end()); storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(), sid->storage_sizes_in_bytes.begin(), sid->storage_sizes_in_bytes.end()); } - storage_device_map_[expr] = StorageInfo(storage_ids, device_types, storage_sizes_in_bytes); + storage_device_map_[expr] = StorageInfo(storage_ids, se_scopes, storage_sizes_in_bytes); AssignReturnSid(expr); } @@ -129,7 +129,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { auto sids = GetStorage(op->tuple); ICHECK_LT(static_cast(op->index), sids->storage_ids.size()); storage_device_map_[expr] = - StorageInfo({sids->storage_ids[op->index]}, {sids->device_types[op->index]}, + StorageInfo({sids->storage_ids[op->index]}, {sids->se_scopes[op->index]}, {sids->storage_sizes_in_bytes[op->index]}); AssignReturnSid(expr); } @@ -163,7 +163,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { * \param prototype The prototype token. * \return The required memory size. * - * TODO(mbs): Cf CalculateRelayExprSizeBytes in utils.cc + * TODO(mbs): Cf CalculateRelayExprSizeBytes in utils.cc, GetMemorySize is graph_plan_memory.cc */ size_t GetMemorySizeBytes(const TensorType& ttype) { size_t size = 1; @@ -195,24 +195,25 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { */ void CreateStorage(const ExprNode* op) { Expr expr = GetRef(op); - return CreateStorage(expr, GetInScopeDeviceType(expr)); + return CreateStorage(expr, GetSEScope(expr)); } /*! - * \brief Create storage to hold the result of evaluating \p expr on \p device_type. + * \brief Create storage to hold the result of evaluating \p expr in \p se_scope. */ - void CreateStorage(const Expr& expr, DLDeviceType device_type) { - ICHECK(device_type != kInvalidDeviceType) << "invalid device type for expr:" << std::endl + void CreateStorage(const Expr& expr, SEScope se_scope) { + ICHECK(!se_scope->IsFullyUnconstrained()) << "invalid SEScope for expr:" << std::endl << PrettyPrint(expr); std::vector storage_ids; - std::vector device_types; + std::vector se_scopes; std::vector storage_sizes_in_bytes; for (const auto& ttype : FlattenTupleType(expr->checked_type())) { storage_ids.push_back(next_available_sid_++); - device_types.push_back(device_type); + se_scopes.push_back(se_scope); storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype)); } - storage_device_map_[expr] = StorageInfo(storage_ids, device_types, storage_sizes_in_bytes); + storage_device_map_[expr] = StorageInfo(std::move(storage_ids), std::move(se_scopes), + std::move(storage_sizes_in_bytes)); } /*! \brief mapping of expression -> storageInfo */ @@ -589,7 +590,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { mod = WithAttr(mod, "main_func_info", func_info); } - IRModule lowered_mod = tec::LowerTEPass(targets_, mod_name, [this](Function func) { + IRModule lowered_mod = tec::LowerTEPass(mod_name, [this](Function func) { // We need to maintain the constant map for external // functions so we pass this processing function which // allows us to process each function as we lower it. diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 78978e192f0e4..cd9c7d68366d1 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include @@ -161,6 +162,8 @@ std::unique_ptr MakeExecutorCodegen(String executor_str) { */ class RelayBuildModule : public runtime::ModuleNode { public: + RelayBuildModule() = default; + /*! * \brief Get member function to front-end * \param name The name of the function. @@ -207,7 +210,7 @@ class RelayBuildModule : public runtime::ModuleNode { } else if (name == "optimize") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.num_args, 2); - *rv = this->Optimize(args[0], args[1], this->params_); + *rv = this->Optimize(args[0], args[1]); }); } else { LOG(FATAL) << "Unknown packed function: " << name; @@ -274,26 +277,16 @@ class RelayBuildModule : public runtime::ModuleNode { * \brief Build relay IRModule for graph executor * * \param mod Relay IRModule - * \param target Target device + * \param targets Target devices * \param target_host Host target device */ void Build(IRModule mod, const TargetMap& targets, const tvm::Target& target_host, const String executor, const String mod_name) { - for (const auto& pair : targets) { - VLOG(0) << "Build target " << pair.first << " = " << pair.second->str(); - } - if (target_host.defined()) { - VLOG(0) << "Build target_host = " << target_host->str(); - } - VLOG(0) << "Build executor = '" << executor << "'"; - VLOG(0) << "Build mod_name = '" << mod_name << "'"; - - // Create protected variable targets_ from ground up - targets_ = targets; - target_host_ = target_host; + VLOG_CONTEXT << "Build"; executor_ = executor; - CheckAndUpdateHostConsistency(&targets_, &target_host_); - BuildRelay(mod, params_, mod_name); + config_ = CompilationConfig(PassContext::Current(), targets, target_host); + + BuildRelay(std::move(mod), mod_name); } protected: @@ -302,95 +295,58 @@ class RelayBuildModule : public runtime::ModuleNode { * * \param relay_module The input IRModule where optmization will be applied on. * \param targets The device type to `Target` mapping. - * \param params The param name to value mapping. * * \return relay::IRModule The updated Relay IR module after optimization. */ - IRModule Optimize(IRModule relay_module, const TargetMap& targets, - const std::unordered_map& params) { - targets_ = targets; - // No target_host setup it seems. - return OptimizeImpl(relay_module, params); + IRModule Optimize(IRModule relay_module, const TargetMap& targets) { + VLOG_CONTEXT << "Optimize"; + // TODO(mbs): executor_ will be whatever was left over from last Build. Note that + // the empty executor string will CHECK fail, so how are folks using this API? + config_ = CompilationConfig(transform::PassContext::Current(), targets, + /*optional_host_target=*/Target()); + return OptimizeImpl(std::move(relay_module)); } - IRModule OptimizeImpl(IRModule relay_module, - const std::unordered_map& params) { + IRModule OptimizeImpl(IRModule relay_module) { ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler."; - if (params.size()) { + if (!params_.empty()) { ICHECK(relay_module->ContainGlobalVar("main")) << "Missing the main entry function"; GlobalVar main_glb_var = relay_module->GetGlobalVar("main"); Function main_func = Downcast(relay_module->Lookup(main_glb_var)); - auto new_main = BindParamsByName(main_func, params); + auto new_main = BindParamsByName(main_func, params_); IRModuleNode* relay_module_ptr = relay_module.CopyOnWrite(); relay_module_ptr->Update(main_glb_var, new_main); } - Array pass_seqs = GetPassPrefix(targets_, false); + Array pass_seqs = GetPassPrefix( + /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false); transform::PassContext pass_ctx = PassContext::Current(); - // TODO(mbs): Centralize this logic and reconcile with similar in relay/backend/vm/compiler.cc - DLDeviceType default_device_type; - if (targets_.size() == 1) { - // Homogenous execution. - default_device_type = static_cast((*targets_.begin()).first->value); - const auto& target = (*targets_.begin()).second; - - // This pass currently only supports the homogeneous case. - pass_seqs.push_back( - transform::SplitArgs(target->GetAttr("max_function_args", -1).value())); - } else { - // Heterogeneous execution. - Optional opt_fallback_dev = - pass_ctx->GetConfig("relay.fallback_device_type"); - if (opt_fallback_dev) { - default_device_type = static_cast(opt_fallback_dev.value()->value); - Integer integer(static_cast(default_device_type)); - CHECK_GT(default_device_type, 0U) - << "The 'relay.fallback_device_type' is set to an invalid device type."; - if (targets_.count(integer) == 0) { - LOG(WARNING) - << "The 'relay.fallback_device_type' has been set to " << default_device_type - << " however no target has been given for that device type in the targets map. " - "Creating an appropriate default target."; - targets_.Set(integer, CreateDefaultTarget(default_device_type)); - } - } else { - default_device_type = kDLCPU; - Integer integer(static_cast(default_device_type)); - if (targets_.count(integer) == 0) { - LOG(WARNING) << "Using the default device type of kDLCPU, however no target has been " - "given for that device type in the targets map. Creating an appropriate " - "default target."; - targets_.Set(integer, CreateDefaultTarget(default_device_type)); - } - } - } - // Always plan devices so the remaining passes don't need to distinguish homogeneous vs // hetrogenous execution. - pass_seqs.push_back(transform::PlanDevices(default_device_type)); + pass_seqs.push_back(transform::PlanDevices(config_)); // Fuse the operations if it is needed. pass_seqs.push_back(transform::FuseOps()); // Create a sequential pass and perform optimizations. transform::Pass seq = transform::Sequential(pass_seqs); - if (targets_.size() == 1) { - With tctx((*targets_.begin()).second); + if (config_->optional_homogeneous_target.defined()) { + With tctx(config_->optional_homogeneous_target); relay_module = seq(relay_module); } else { relay_module = seq(relay_module); } // Do layout rewrite for auto-scheduler. - if (backend::IsAutoSchedulerEnabled() && targets_.size() == 1) { - const auto& target = (*targets_.begin()).second; + if (backend::IsAutoSchedulerEnabled() && config_->optional_homogeneous_target.defined()) { Pass major_pass = transform::AutoSchedulerLayoutRewrite(); bool enable_layout_rewrite_targets = - target->kind->device_type == kDLCPU || target->GetAttr("device", "") == "mali"; + config_->optional_homogeneous_target->kind->device_type == kDLCPU || + config_->optional_homogeneous_target->GetAttr("device", "") == "mali"; if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) { - With tctx(target); + With tctx(config_->optional_homogeneous_target); relay_module = major_pass(relay_module); // Defuse ops to fold constants, then fuse them again relay_module = transform::DefuseOps()(relay_module); @@ -416,45 +372,22 @@ class RelayBuildModule : public runtime::ModuleNode { return relay_module; } - /*! - * \brief Returns a default target to represent \p device_type. - */ - static Target CreateDefaultTarget(DLDeviceType device_type) { - std::string name = runtime::DeviceName(device_type); - if (name == "cpu") { - return Target("llvm"); - } else { - return Target(name); - } - } - /*! * \brief Compile a Relay IR module to runtime module. * * \param relay_module The Relay IR module. * \param params The parameters. */ - void BuildRelay(IRModule relay_module, - const std::unordered_map& params, - const String mod_name) { - Target target_host = GetTargetHost(); - // If no target_host has been set, we choose a default one, which is - // llvm if "codegen.LLVMModuleCreate" is accessible. - const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); - if (!target_host.defined()) target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm"); - - // Update all the targets in the targets_ TargetMap - CheckAndUpdateHostConsistency(&targets_, &target_host); - + void BuildRelay(IRModule relay_module, const String& mod_name) { // Relay IRModule -> IRModule optimizations. - relay_module = OptimizeImpl(relay_module, params); + relay_module = OptimizeImpl(std::move(relay_module)); // Get the updated function. auto func = Downcast(relay_module->Lookup("main")); // Generate code for the updated function. executor_codegen_ = MakeExecutorCodegen(executor_); - executor_codegen_->Init(nullptr, targets_); + executor_codegen_->Init(nullptr, config_->legacy_target_map); executor_codegen_->Codegen(func, mod_name); executor_codegen_->UpdateOutput(&ret_); ret_.params = executor_codegen_->GetParams(); @@ -467,9 +400,12 @@ class RelayBuildModule : public runtime::ModuleNode { lowered_funcs.Set(ext_dev, IRModule()); } + const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); + // Generate a placeholder function that attaches linked params as its arguments. - if (target_host->GetAttr("link-params").value_or(Bool(false))) { - CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen."; + const Target& host_target = config_->host_se_scope->target; + if (host_target->GetAttr("link-params").value_or(Bool(false))) { + CHECK(pf != nullptr) << "Unable to link-params without llvm codegen."; auto param_ids = executor_codegen_->GetParamIds(); auto link_params = Map(); for (auto param : ret_.params) { @@ -482,18 +418,19 @@ class RelayBuildModule : public runtime::ModuleNode { DictAttrs attrs{dict}; auto prim = tir::PrimFunc(Array(), tir::SeqStmt(Array()), VoidType(), Map(), attrs); - if (lowered_funcs.find(target_host) == lowered_funcs.end()) { - lowered_funcs.Set(target_host, IRModule(Map({}))); + if (lowered_funcs.find(host_target) == lowered_funcs.end()) { + lowered_funcs.Set(host_target, IRModule(Map({}))); } - lowered_funcs[target_host]->Add(GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), + lowered_funcs[host_target]->Add(GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), prim); } // When there is no lowered_funcs due to reasons such as optimization. if (lowered_funcs.size() == 0) { - if (target_host.defined() && target_host->kind->name == "llvm") { + if (host_target->kind->name == "llvm") { + CHECK(pf != nullptr) << "Unable to create empty module for llvm without llvm codegen."; // If we can decide the target is LLVM, we then create an empty LLVM module. - ret_.mod = (*pf)(target_host->str(), "empty_module"); + ret_.mod = (*pf)(host_target->str(), "empty_module"); } else { // If we cannot decide the target is LLVM, we create an empty CSourceModule. // The code content is initialized with ";" to prevent complaining @@ -501,11 +438,11 @@ class RelayBuildModule : public runtime::ModuleNode { ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array{}); } } else { - ret_.mod = tvm::build(lowered_funcs, target_host_); + ret_.mod = tvm::build(lowered_funcs, host_target); } auto ext_mods = executor_codegen_->GetExternalModules(); - ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, GetTargetHost(), + ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, host_target, executor_codegen_->GetMetadata()); // Remove external params which were stored in metadata module. for (tvm::runtime::Module mod : ext_mods) { @@ -522,26 +459,8 @@ class RelayBuildModule : public runtime::ModuleNode { } } - private: - Target GetTargetHost() { - Target target_host = target_host_; - if (!target_host_.defined()) { - for (const auto& it : targets_) { - if (it.second->kind->device_type == kDLCPU) { - target_host = it.second; - break; - } - } - } - return target_host; - } - protected: std::unique_ptr executor_codegen_; - /*! \brief target device */ - TargetMap targets_; - /*! \brief target host device */ - tvm::Target target_host_; /*! \brief parameters */ std::unordered_map params_; /*! \brief building output */ @@ -552,6 +471,8 @@ class RelayBuildModule : public runtime::ModuleNode { * - aot: use the aot executor */ String executor_; + /*! \brief Collects all the targets and scopes we need during compilation. */ + CompilationConfig config_; }; runtime::Module RelayBuildCreate() { diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index d32ded3796886..1fe8a5c953558 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -203,7 +203,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorattrs_["storage_id"] = std::move(storage_ids); // type std::vector device_types; - for (auto v : storage_info->device_types) { - device_types.push_back(static_cast(v)); + for (const auto& se_scope : storage_info->se_scopes) { + // TODO(mbs): Keeping only the device type. + ICHECK_GT(se_scope->device_type(), 0); + device_types.push_back(se_scope->device_type()); } size_t num_unknown_devices = std::count(device_types.begin(), device_types.end(), 0); if (num_unknown_devices != 0 && num_unknown_devices != device_types.size()) { @@ -446,7 +448,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator VisitExpr_(const CallNode* call_node) override { relay::Call call = GetRef(call_node); - auto props = GetOnDeviceProps(call_node); + OnDeviceProps props = GetOnDeviceProps(call_node); if (props.body.defined()) { // See through "on_device" calls. return VisitExpr(props.body); @@ -472,6 +474,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatortuple); return {vtuple[op->index]}; } + std::vector VisitExpr_(const OpNode* op) override { LOG(FATAL) << "All OpNodes should have been expanded"; return {}; diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 961252a14fa76..c92caba0862fe 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -51,21 +51,19 @@ struct StorageToken { size_t max_bytes{0}; /*! \brief The corresponding tensor type. */ TensorType ttype{nullptr}; - /*! \brief Device on which memory will reside. */ - Device device{kInvalidDeviceType, -1}; + /*! \brief SEScope on which the memory will reside. */ + SEScope se_scope = SEScope::FullyUnconstrained(); /*! \brief The storage id */ int64_t storage_id{-1}; - bool is_valid() const { return device.device_type != kInvalidDeviceType; } + bool is_valid() const { return !se_scope->IsFullyUnconstrained(); } - bool is_compatible(const StorageToken& that) const { - return device.device_type == that.device.device_type; - } + bool is_compatible(const StorageToken& that) const { return se_scope == that.se_scope; } std::string ToString() const { std::ostringstream os; - os << "{id: " << storage_id << ", bytes: " << max_bytes << ", type: " << PrettyPrint(ttype) - << ", device: " << device.device_type << "}"; + os << "{storage_id: " << storage_id << ", max_bytes: " << max_bytes + << ", ttype: " << PrettyPrint(ttype) << ", se_scope: " << se_scope << "}"; return os.str(); } }; @@ -160,14 +158,14 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor { * the result of evaluating \p op. */ void CreateToken(const ExprNode* op, bool can_realloc) { - return CreateTokenOnDevice(op, GetInScopeDeviceType(GetRef(op)), can_realloc); + return CreateTokenOnDevice(op, GetSEScope(GetRef(op)), can_realloc); } /*! * \brief Allocates (or reuses if \p can_realloc is true) a storage token for holding * the result of evaluating \p op on \p device_type. */ - virtual void CreateTokenOnDevice(const ExprNode* op, DLDeviceType device_type, + virtual void CreateTokenOnDevice(const ExprNode* op, const SEScope& se_scope, bool can_realloc) = 0; }; @@ -186,16 +184,13 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { protected: using StorageAllocaBaseVisitor::VisitExpr_; - void CreateTokenOnDevice(const ExprNode* op, DLDeviceType device_type, - bool can_realloc) override { + void CreateTokenOnDevice(const ExprNode* op, const SEScope& se_scope, bool can_realloc) override { ICHECK(!token_map_.count(op)); std::vector tokens; for (const auto& ttype : FlattenTupleType(op->checked_type())) { - StorageToken* token = arena_->make(); + auto* token = arena_->make(); token->ttype = ttype; - // TODO(mbs): Should be TargetDevice. - token->device.device_type = device_type; - token->device.device_id = 0; + token->se_scope = se_scope; tokens.push_back(token); } token_map_[op] = tokens; @@ -251,8 +246,11 @@ class StorageAllocator : public StorageAllocaBaseVisitor { for (const auto& kv : token_map_) { std::vector storage_ids; - std::vector device_types; + storage_ids.reserve(kv.second.size()); + std::vector se_scopes; + se_scopes.reserve(kv.second.size()); std::vector sid_sizes_byte; + sid_sizes_byte.reserve(kv.second.size()); for (StorageToken* tok : kv.second) { VLOG(1) << "token: " << tok->ToString(); @@ -261,10 +259,11 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } num_nodes++; storage_ids.push_back(tok->storage_id); - device_types.push_back(static_cast(tok->device.device_type)); + se_scopes.push_back(tok->se_scope); sid_sizes_byte.push_back(GetMemorySize(tok)); } - auto storage_info = backend::StorageInfo(storage_ids, device_types, sid_sizes_byte); + auto storage_info = backend::StorageInfo(std::move(storage_ids), std::move(se_scopes), + std::move(sid_sizes_byte)); smap.Set(GetRef(kv.first), storage_info); } // Either all or none of the nodes should be annotated. @@ -279,20 +278,20 @@ class StorageAllocator : public StorageAllocaBaseVisitor { protected: // override create token by getting token as prototype requirements. - void CreateTokenOnDevice(const ExprNode* op, DLDeviceType device_type, bool can_realloc) final { + void CreateTokenOnDevice(const ExprNode* op, const SEScope& se_scope, bool can_realloc) final { ICHECK(!token_map_.count(op)); auto it = prototype_.find(op); ICHECK(it != prototype_.end()); std::vector tokens; for (StorageToken* tok : it->second) { - ICHECK_EQ(tok->device.device_type, device_type); + ICHECK(tok->se_scope == se_scope); if (can_realloc) { tokens.push_back(Request(tok)); } else { // Allocate a new token, StorageToken* allocated_tok = Alloc(tok, GetMemorySize(tok)); - allocated_tok->device = tok->device; + allocated_tok->se_scope = tok->se_scope; // ensure it never get de-allocated. allocated_tok->ref_counter += 1; tokens.push_back(allocated_tok); @@ -363,7 +362,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { * \param size The original size. * \param word_size The element size. */ - static size_t DivRoundUp(size_t size, size_t word_size) { + static int64_t DivRoundUp(int64_t size, int64_t word_size) { return (size + word_size - 1) / word_size; } /*! @@ -390,16 +389,19 @@ class StorageAllocator : public StorageAllocaBaseVisitor { * \brief Get the memory requirement. * \param prototype The prototype token. * \return The required memory size. + * + * TODO(mbs): Gf GetMemorySizeBytes in aot_executor_codegen.cc, + * CalculateRelayExprSizeBytes in utils.cc */ - size_t GetMemorySize(StorageToken* prototype) { + static int64_t GetMemorySize(StorageToken* prototype) { TensorType ttype = prototype->ttype; ICHECK(ttype.defined()); - size_t size = 1; + int64_t size = 1; for (IndexExpr dim : ttype->shape) { const int64_t* pval = tir::as_const_int(dim); ICHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape " << ttype->shape; ICHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative shape" << *pval; - size *= static_cast(pval[0]); + size *= pval[0]; } size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8); return size; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 13b855624461a..ecca1fac03d97 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -908,16 +908,12 @@ class Interpreter : public ExprFunctor, * functions needed by the rewritten module. */ IRModule Prepare(IRModule mod, CompilationConfig config) { - tec::TargetMap tec_target_map; - for (const auto& pair : config->legacy_target_map) { - tec_target_map.emplace(static_cast(pair.first->value), pair.second); - } // Run minimal transforms on module to establish invariants needed by interpreter. transform::Sequential seq( {transform::SimplifyInference(), // Figure out which devices should be used to execute. // TODO(mbs): Should ignore all existing annotations when constant folding - transform::PlanDevices(config->default_primitive_se_scope->device_type()), + transform::PlanDevices(std::move(config)), // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' // attribute. transform::FuseOps(/*fuse_opt_level=*/0), @@ -927,8 +923,7 @@ IRModule Prepare(IRModule mod, CompilationConfig config) { transform::EtaExpand( /*expand_constructor=*/true, /*expand_global_var=*/false), transform::InferType(), - tec::LowerTEPass(tec_target_map, /*module_name=*/"intrp", - [](Function func) { /* no-op */ })}); + tec::LowerTEPass(/*module_name=*/"intrp", [](Function func) { /* no-op */ })}); transform::PassContext pass_ctx = transform::PassContext::Current(); With ctx(pass_ctx); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 163bb9f71f9c1..8a6c344fb0844 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -43,6 +43,7 @@ #include #include "../op/annotation/annotation.h" +#include "../op/memory/device_copy.h" #include "../transforms/device_aware_visitors.h" #include "./te_compiler_cache.h" #include "./utils.h" @@ -359,21 +360,6 @@ TVM_REGISTER_GLOBAL("relay.backend._TECompilerListItems").set_body_typed([](TECo using AnalysisRemapping = std::unordered_map; -std::tuple IsDeviceCopy(const Function& func) { - if (auto call_node = func->body.as()) { - if (auto op_node = call_node->op.as()) { - if (op_node->name == "device_copy") { - auto attrs = call_node->attrs.as(); - auto dst = attrs->dst_dev_type; - auto src = attrs->src_dev_type; - return std::tuple(true, src, dst); - } - } - } - - return std::tuple(false, -1, -1); -} - /*! * \brief Rewrites call expressions to Relay functions marked as 'primitive' * to calls to the corresponding TIR primitive for the appropriate target. @@ -415,11 +401,10 @@ std::tuple IsDeviceCopy(const Function& func) { */ class LowerTensorExprMutator : public DeviceAwareExprMutator { public: - LowerTensorExprMutator(const IRModule& module, const TargetMap& targets, ProcessFn process_fn, - const String& module_name, TECompiler compiler) + LowerTensorExprMutator(const IRModule& module, ProcessFn process_fn, const String& module_name, + TECompiler compiler) : DeviceAwareExprMutator(module), module_(module), - targets_(targets), process_fn_(process_fn), module_name_(module_name), compiler_(compiler), @@ -484,7 +469,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } // Non-External Relay Function - VLOG(1) << "lowering to target '" << target->str() << "' for primitive:\n" << PrettyPrint(func); + VLOG(1) << "lowering to target " << target->ToDebugString() << " for primitive:" << std::endl + << PrettyPrint(func); CCacheKey key = CCacheKey(func, target); CachedFunc lowered_func = compiler_->Lower(key, module_name_); VLOG(1) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'"; @@ -514,14 +500,12 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { tir_call_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); } - auto device_copy = IsDeviceCopy(func); - if (std::get<0>(device_copy)) { - // Record that device copy source and destination devices so the device planner can - // still follow along. - auto source_device = std::get<1>(device_copy); - auto dst_device = std::get<2>(device_copy); - tir_call_attrs->metadata.Set("source_device", tvm::Integer(source_device)); - tir_call_attrs->metadata.Set("dst_device", tvm::Integer(dst_device)); + DeviceCopyProps props = GetDeviceCopyProps(func); + if (props.body.defined()) { + // Record the device copy source and destination SEScopes so the device planner can + // still follow along even after lowering. + tir_call_attrs->metadata.Set("src_se_scope", props.src_se_scope); + tir_call_attrs->metadata.Set("dst_se_scope", props.dst_se_scope); } tir_call_attrs->metadata.Set("relay_attrs", func->attrs); @@ -534,8 +518,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // on the host cpu irrespective of where the primitive runs. // TODO(mbs): Cleanup target handling. Target shape_target("llvm"); - VLOG(1) << "lowering to target '" << shape_target->str() - << "' for dynamic shape function for primitive"; + VLOG(1) << "lowering to target " << shape_target->ToDebugString() + << " for dynamic shape function for primitive"; CCacheKey shape_key(func, shape_target); CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); // Capture the shape function's global var and parameters 'states' in call @@ -624,9 +608,10 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { target = Target("ext_dev"); } else { // The target corresponding to the call_node expression's annotation. - DLDeviceType device_type = GetInScopeDeviceType(call); - // TODO(mbs): Replace device_type with target so this lookup is unnecessary. - target = GetTargetFromInteger(device_type, targets_); + SEScope se_scope = GetSEScope(call); + ICHECK(!se_scope->IsFullyUnconstrained()); + target = se_scope->target; + ICHECK(target.defined()); } // Lower the primitive function for that target. @@ -640,7 +625,6 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } IRModule module_; - TargetMap targets_; ProcessFn process_fn_; // Map from in-scope let-bound variables to Relay functions known to be // primitive. We'll rewrite these to the fresh global vars bound to the lowered @@ -686,11 +670,11 @@ Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets) { } } -Pass LowerTensorExpr(TargetMap targets, const String& module_name, TECompiler compiler, +Pass LowerTensorExpr(const String& module_name, TECompiler compiler, std::function process_fn) { runtime::TypedPackedFunc pass_func = [=](Function func, IRModule module, PassContext ctx) { - LowerTensorExprMutator lower_te(module, targets, process_fn, module_name, compiler); + LowerTensorExprMutator lower_te(module, process_fn, module_name, compiler); return Downcast(lower_te.Mutate(func)); }; return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {}); @@ -707,6 +691,7 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa } // This is a Map> + // TODO(mbs): Collapsing SEScopes to just device type. std::unordered_map, backend::EnumClassHash> sid_workspace; // This is a Map @@ -717,15 +702,15 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa // Initialize the mapping from all storage identifiers to workspace sizes, // the amount of device io, and the device constants. for (const auto& kv : storage_info_map) { - backend::StorageInfo storage_info = kv.second; - std::vector storage_ids = storage_info->storage_ids; - std::vector devices = storage_info->device_types; - - CHECK_EQ(storage_ids.size(), devices.size()); - for (uint32_t i = 0; i < devices.size(); i++) { - sid_workspace[devices[i]][storage_ids[i]] = 0; - device_io[devices[i]] = 0; - device_consts[devices[i]] = 0; + const backend::StorageInfo& storage_info = kv.second; + const std::vector& storage_ids = storage_info->storage_ids; + const std::vector& se_scopes = storage_info->se_scopes; + CHECK_EQ(storage_ids.size(), se_scopes.size()); + for (uint32_t i = 0; i < se_scopes.size(); i++) { + DLDeviceType device_type = se_scopes[i]->device_type(); + sid_workspace[device_type][storage_ids[i]] = 0; + device_io[device_type] = 0; + device_consts[device_type] = 0; } } @@ -754,18 +739,20 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa << PrettyPrint(expr->checked_type()) << std::endl << "has size " << size_bytes << " and storage info:" << std::endl << storage_info; - std::vector storage_ids = storage_info->storage_ids; - std::vector devices = storage_info->device_types; + const std::vector& storage_ids = storage_info->storage_ids; + const std::vector& se_scopes = storage_info->se_scopes; if (expr->IsInstance()) { - for (const auto& dev : devices) { - ICHECK_EQ(device_consts.count(dev), 1); - device_consts[dev] += size_bytes; + for (const auto& se_scope : se_scopes) { + DLDeviceType device_type = se_scope->device_type(); + ICHECK_EQ(device_consts.count(device_type), 1); + device_consts[device_type] += size_bytes; } } else if (expr->IsInstance() || expr.same_as(func->body)) { - CHECK_GE(devices.size(), 1) << "must be at least one device"; - for (const auto& dev : devices) { - device_io[dev] += size_bytes; + CHECK_GE(se_scopes.size(), 1) << "must be at least one device"; + for (const auto& se_scope : se_scopes) { + DLDeviceType device_type = se_scope->device_type(); + device_io[device_type] += size_bytes; } } else { // TODO(@electriclilies): This code is never being called which means sid_workspace is not @@ -775,8 +762,9 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa // Here we record the largest size of the tensor // that share the same storage id, because storage_id will // be shared between multiple tensors that are not live simultaneously. - if (size_bytes > sid_workspace[devices[i]][storage_ids[i]]) { - sid_workspace[devices[i]][storage_ids[i]] = size_bytes; + DLDeviceType device_type = se_scopes[i]->device_type(); + if (size_bytes > sid_workspace[device_type][storage_ids[i]]) { + sid_workspace[device_type][storage_ids[i]] = size_bytes; } } } @@ -821,8 +809,9 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa constant_sizes.Set(tgt, dev_and_size.second); } - backend::FunctionInfo func_info(workspace_sizes, io_sizes, constant_sizes, tir_primfuncs, - relay_primfuncs); + backend::FunctionInfo func_info(std::move(workspace_sizes), std::move(io_sizes), + std::move(constant_sizes), std::move(tir_primfuncs), + std::move(relay_primfuncs)); VLOG(1) << "func_info: " << func_info; return std::move(func_info); } @@ -880,6 +869,7 @@ void UpdateFunctionMetadata(Function relay_func, workspace_sizes.Set(prim_fn_target, workspace_size); // Calculating size for I/O + // TODO(mbs): See also the other three utils for calculating tensor bytesize. for (auto const& param : prim_fn->params) { auto p_shape = prim_fn->buffer_map[param]->shape; int num_of_elements = 1; @@ -900,8 +890,9 @@ void UpdateFunctionMetadata(Function relay_func, relay_primfuncs.Set(prim_fn_target, relay_func); } - backend::FunctionInfo fi = backend::FunctionInfo(workspace_sizes, io_sizes, constant_sizes, - tir_primfuncs, relay_primfuncs); + backend::FunctionInfo fi = backend::FunctionInfo( + std::move(workspace_sizes), std::move(io_sizes), std::move(constant_sizes), + std::move(tir_primfuncs), std::move(relay_primfuncs)); VLOG(1) << "FunctionInfo: " << prim_fn_var.value()->name_hint << " = " << PrettyPrint(fi); @@ -910,11 +901,11 @@ void UpdateFunctionMetadata(Function relay_func, function_metadata.Set(prim_fn_var.value()->name_hint, fi); } -IRModule LowerTE(const IRModule& module, TargetMap targets, const String& module_name, +IRModule LowerTE(const IRModule& module, const String& module_name, std::function process_fn) { TECompiler compiler; - auto updated_module = LowerTensorExpr(targets, module_name, compiler, process_fn)(module); + auto updated_module = LowerTensorExpr(module_name, compiler, process_fn)(module); backend::UpdateAutoSchedulerOpWeights(compiler); @@ -959,12 +950,9 @@ Map GetPerTargetModules(IRModule mod) { return per_target_modules; } -Pass LowerTEPass(TargetMap targets, const String& module_name, - std::function process_fn) { - runtime::TypedPackedFunc pass_func = [=](IRModule module, - PassContext ctx) { - return LowerTE(module, targets, module_name, process_fn); - }; +Pass LowerTEPass(const String& module_name, std::function process_fn) { + runtime::TypedPackedFunc pass_func = + [=](IRModule module, PassContext ctx) { return LowerTE(module, module_name, process_fn); }; return tvm::transform::Sequential({tvm::relay::transform::RelayToTIRTargetHook(), tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index d0401e9605f7f..da7333d64d463 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -173,7 +173,6 @@ Map GetPerTargetModules(IRModule mod); * to TE expressions, schedules them, and then to TIR. * * \param module The IRModule. - * \param targets The mapping for devices to targets. * \param memory_plan The memory plan used during lowering * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process @@ -181,8 +180,8 @@ Map GetPerTargetModules(IRModule mod); * \return The lowered module, see above. */ IRModule LowerTE( - const IRModule& module, TargetMap targets, backend::StaticMemoryPlan memory_plan, - const String& module_name, ProcessFn process_fn = [](Function f) {}); + const IRModule& module, backend::StaticMemoryPlan memory_plan, const String& module_name, + ProcessFn process_fn = [](Function f) {}); /*! \brief Pass to lower an IRModule's primitive functions to TIR. * @@ -190,14 +189,12 @@ IRModule LowerTE( * to TE expressions, schedules them, and then to TIR. It annotates all functions * with their target. * - * \param targets The mapping for devices to targets. * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process * each function that we lower * \returns The pass which lowers primitive functions to TIR */ -transform::Pass LowerTEPass(TargetMap targets, const String& module_name, - std::function process_fn); +transform::Pass LowerTEPass(const String& module_name, std::function process_fn); } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 02caf56c66e65..9a1c428482e2d 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -34,49 +34,56 @@ namespace backend { TVM_REGISTER_NODE_TYPE(StorageInfoNode); -StorageInfo::StorageInfo(std::vector storage_ids, std::vector device_types, - std::vector storage_sizes_in_bytes) { - auto n = make_object(); - n->storage_ids = std::move(storage_ids); - n->device_types = std::move(device_types); - n->storage_sizes_in_bytes = std::move(storage_sizes_in_bytes); - data_ = std::move(n); -} - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { const auto* node = ref.as(); - p->stream << "StorageInfoNode(\n" - << " storage_ids=["; + p->stream << "StorageInfoNode(" + << "storage_ids=["; for (auto id : node->storage_ids) { - p->stream << id << ", "; + p->stream << id << ","; } - p->stream << "],\n device_types=["; - for (auto device_type : node->device_types) { - p->stream << device_type << ", "; + p->stream << "], se_scopes=["; + for (const auto& se_scope : node->se_scopes) { + p->stream << se_scope << ","; } - p->stream << "],\n storage_size_in_bytes=["; + p->stream << "], storage_size_in_bytes=["; for (auto bytes : node->storage_sizes_in_bytes) { - p->stream << bytes << ", "; + p->stream << bytes << ","; } p->stream << "])"; }); +StorageInfo::StorageInfo(std::vector storage_ids, std::vector se_scopes, + std::vector storage_sizes_in_bytes) { + ICHECK_EQ(storage_ids.size(), se_scopes.size()); + ICHECK_EQ(storage_ids.size(), storage_sizes_in_bytes.size()); + auto node = make_object(); + node->storage_ids = std::move(storage_ids); + node->se_scopes = std::move(se_scopes); + node->storage_sizes_in_bytes = std::move(storage_sizes_in_bytes); + data_ = std::move(node); +} + +// This is the legacy interface for devices as DLDeviceTypes (represented by integers) TVM_REGISTER_GLOBAL("relay.ir.StorageInfo") - .set_body_typed([](const Array& sids, const Array& dev_types, + .set_body_typed([](const Array& sids, const Array& device_types, const Array& sizes_in_bytes) { - std::vector sids_v, sizes_v; - std::vector dev_types_v; + std::vector sids_v; + sids_v.reserve(sids.size()); for (auto s : sids) { sids_v.push_back(s); } - for (auto d : dev_types) { - dev_types_v.push_back(static_cast(static_cast(d))); + std::vector se_scopes_v; + se_scopes_v.reserve(device_types.size()); + for (const auto& device_type : device_types) { + se_scopes_v.emplace_back(SEScope::ForDeviceType(device_type)); } + std::vector size_in_bytes_v; + size_in_bytes_v.reserve(sizes_in_bytes.size()); for (auto s : sizes_in_bytes) { - sizes_v.push_back(s); + size_in_bytes_v.push_back(s); } - return StorageInfo(sids_v, dev_types_v, sizes_v); + return StorageInfo(std::move(sids_v), std::move(se_scopes_v), std::move(size_in_bytes_v)); }); TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageIds").set_body_typed([](StorageInfo si) { @@ -87,10 +94,11 @@ TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageIds").set_body_typed([](StorageI return ids; }); +// This is the legacy interface for devices as DLDeviceTypes (represented by integers) TVM_REGISTER_GLOBAL("relay.ir.StorageInfoDeviceTypes").set_body_typed([](StorageInfo si) { Array device_types; - for (auto id : si->device_types) { - device_types.push_back(id); + for (const auto& se_scope : si->se_scopes) { + device_types.push_back(se_scope->device_type()); } return device_types; }); @@ -116,7 +124,8 @@ TVM_REGISTER_GLOBAL("relay.ir.StaticMemoryPlan") return StaticMemoryPlan(expr_to_storage_info); }); -// TODO(mbs): Cf GetMemorySizeBytes in aot_executor_codegen.cc +// TODO(mbs): Cf GetMemorySizeBytes in aot_executor_codegen.cc, GetMemorySize in +// graph_plan_memory.cc int64_t CalculateRelayExprSizeBytes(const Type& expr_type) { if (expr_type->IsInstance()) { auto tuple_type = Downcast(expr_type); @@ -166,7 +175,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ",\n relay_primfuncs=" << node->relay_primfuncs << ")"; }); -Array GetPassPrefix(const Map& targets, bool is_vm) { +Array GetPassPrefix(bool is_homegeneous, bool is_vm) { Array pass_seqs; Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); @@ -175,7 +184,7 @@ Array GetPassPrefix(const Map& targets, bool is pass_seqs.push_back(relay::qnn::transform::Legalize()); // Legalize pass is restricted to homogeneous execution for now. - if (targets.size() == 1) { + if (is_homegeneous) { pass_seqs.push_back(transform::Legalize()); } @@ -217,8 +226,8 @@ Array GetPassPrefix(const Map& targets, bool is pass_seqs.push_back(transform::CanonicalizeCast()); pass_seqs.push_back(transform::CanonicalizeOps()); - // Alter layout transformation is only applied to homogeneous execution yet. - if (targets.size() == 1) { + // Alter layout transformation is currently only applied to homogeneous execution. + if (is_homegeneous) { if (!is_vm) { pass_seqs.push_back(transform::InferType()); } diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 16cbe0e8dbcae..4224a99c26285 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -57,15 +58,17 @@ namespace backend { using Pass = tvm::transform::Pass; /*! - * \brief The static storage information produced by memory planning. + * \brief The static storage information for each Tensor in the result of a Relay expression + * (as per relay::FlattenTupleType). */ class StorageInfoNode : public Object { public: + // TODO(mbs): Switch from struct-of-array to array-of-struct repr throughout. /*! \brief The set of storage ids where the expression is stored. */ std::vector storage_ids; - /* \brief The type of "virtual devices" these expressions are stored on. */ - std::vector device_types; - /* \brief The sizes of each storage element. */ + /* \brief The SEScopes these expressions are stored within. */ + std::vector se_scopes; + /* \brief The sizes of each storage element, in bytes. */ std::vector storage_sizes_in_bytes; // TODO(@jroesch): expose the fields @@ -78,7 +81,7 @@ class StorageInfoNode : public Object { /*! \brief The storage information for a single expression. */ class StorageInfo : public ObjectRef { public: - StorageInfo(std::vector storage_ids, std::vector device_types, + StorageInfo(std::vector storage_ids, std::vector se_scopes, std::vector storage_sizes_in_bytes); TVM_DEFINE_OBJECT_REF_METHODS(StorageInfo, ObjectRef, StorageInfoNode); }; @@ -442,11 +445,11 @@ inline bool IsMetaScheduleEnabled() { * difference. This function unifies the shared optimization pass prefix between vm and graph * runtime, and returns the pass prefix given the backend type. * - * \param targets The device type to `Target` mapping. - * \param is_vm A boolean indicating if the passes are used for vm or graph runtime. + * \param is_homogenous True if all primitives are to be executed on the same device and target. + * \param is_vm True if passes are to be used for the vm executor. * \return An array of passes. */ -Array GetPassPrefix(const TargetMap& targets, bool is_vm); +Array GetPassPrefix(bool is_homogenous, bool is_vm); /*! \brief Target hash function */ struct TargetStrHash { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index c4c50c6c5646d..be352177879a9 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -81,6 +81,9 @@ using namespace tvm::runtime; using namespace tvm::runtime::vm; using namespace relay::transform; +/*! \brief The host device is always stored at device index 0. */ +constexpr Index kHostDeviceIndex = 0; + // (@jroesch): VM passes, eventually declare as passes. bool IsClosure(const Function& func); @@ -93,9 +96,9 @@ using MatchValuePtr = std::shared_ptr; // A runtime object that resides in a register struct RegisterValue : MatchValue { // The register num - RegName rergister_num; + RegName register_num; - explicit RegisterValue(RegName reg) : rergister_num(reg) {} + explicit RegisterValue(RegName reg) : register_num(reg) {} ~RegisterValue() {} }; @@ -227,44 +230,17 @@ std::vector ToAllocTensorShape(NDArray shape) { return raw_shape; } -/*! - * \brief Create a default type. - * \param device_type The device type index. - * \return the default target for the device. - */ -Target CreateDefaultTarget(int device_type) { - std::string name = runtime::DeviceName(device_type); - if (name == "cpu") return Target("llvm"); - if (name == "cuda") return Target("cuda"); - return Target(name); -} - -int GetFallbackDevice() { - transform::PassContext pass_ctx = PassContext::Current(); - Optional opt_fallback_dev = - pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast(kDLCPU))); - auto fallback_dev = opt_fallback_dev.value(); - ICHECK_GT(fallback_dev->value, 0U); - return fallback_dev->value; -} - class VMFunctionCompiler : DeviceAwareExprFunctor { public: - VMFunctionCompiler(VMCompilerContext* context, TargetMap targets, Target target_host) + VMFunctionCompiler(VMCompilerContext* context, SEScope host_se_scope) : DeviceAwareExprFunctor(context->module), last_register_(0), registers_num_(0), context_(context), - target_host_(target_host) { - CheckAndUpdateHostConsistency(&targets, &target_host); - for (const auto& it : targets) { - targets_[it.first->value] = it.second; - } - target_host_ = target_host; - } + host_se_scope_(std::move(host_se_scope)) {} VMFunction Compile(const GlobalVar& var, const Function& func) { - std::vector params_device_type; + std::vector param_device_indexes; if (IsClosure(func)) { // After lifting we'll have functions of the form: // fn(closure args) { fn(lifted function args) { body } } @@ -273,16 +249,21 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { // Do that flattening on-the-fly here. Function inner_func = Downcast(func->body); std::vector params; - std::vector param_device_types; + std::vector param_se_scopes; params.reserve(func->params.size() + inner_func->params.size()); - param_device_types.reserve(func->params.size() + inner_func->params.size()); + param_se_scopes.reserve(func->params.size() + inner_func->params.size()); + param_device_indexes.reserve(func->params.size() + inner_func->params.size()); for (size_t i = 0; i < func->params.size(); ++i) { params.emplace_back(func->params[i]); - params_device_type.push_back(GetFunctionParamDeviceType(func.get(), i)); + SEScope param_se_scope = GetFunctionParamSEScope(func.get(), i); + param_se_scopes.push_back(param_se_scope); + param_device_indexes.push_back(GetDeviceIndex(param_se_scope)); } for (size_t i = 0; i < inner_func->params.size(); ++i) { params.emplace_back(inner_func->params[i]); - params_device_type.push_back(GetFunctionParamDeviceType(inner_func.get(), i)); + SEScope param_se_scope = GetFunctionParamSEScope(inner_func.get(), i); + param_se_scopes.push_back(param_se_scope); + param_device_indexes.push_back(GetDeviceIndex(param_se_scope)); } std::vector type_params; type_params.reserve(func->type_params.size() + inner_func->type_params.size()); @@ -294,22 +275,17 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { } Function flattened_func = Function(params, inner_func->body, inner_func->ret_type, type_params, func->attrs, func->span); - VisitExpr(MaybeFunctionOnDevice(flattened_func, params_device_type, - GetFunctionResultDeviceType(inner_func.get()))); + VisitExpr(MaybeFunctionOnDevice(flattened_func, param_se_scopes, + GetFunctionResultSEScope(inner_func.get()))); } else { - params_device_type.reserve(func->params.size()); + param_device_indexes.reserve(func->params.size()); for (size_t i = 0; i < func->params.size(); ++i) { - params_device_type.push_back(GetFunctionParamDeviceType(func.get(), i)); + param_device_indexes.push_back(GetDeviceIndex(GetFunctionParamSEScope(func.get(), i))); } VisitExpr(func); } - std::vector params_device_type_index; - params_device_type_index.reserve(params_device_type.size()); - for (auto device_type : params_device_type) { - params_device_type_index.push_back(static_cast(device_type)); - } return VMFunction(var->name_hint, params_, instructions_, registers_num_, - params_device_type_index); + std::move(param_device_indexes)); } /*! \brief Attrs objects for each op. */ @@ -352,6 +328,48 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { instructions_.push_back(instr); } + /*! + * \brief Returns the "device index" to represent \p se_scope for primitives + * in emitted code. Note that the host device is always at index 0. + */ + Index GetDeviceIndex(const SEScope& se_scope) { + VLOG(2) << "getting device index for " << se_scope; + auto itr = std::find(context_->se_scopes_.begin(), context_->se_scopes_.end(), se_scope); + if (itr != context_->se_scopes_.end()) { + VLOG(2) << "reusing existing scope"; + return std::distance(context_->se_scopes_.begin(), itr); + } + + ICHECK_GT(context_->se_scopes_.size(), 0); + ICHECK_NE(se_scope, host_se_scope_); + + if (se_scope->device_type() == context_->se_scopes_.front()->device_type()) { + // It's ok if we see distinct scopes which share the host device type. This is because + // we allow the SEScope for the host to be different from the SEScope for primitive + // operations which happen to be, eg, on the CPU. + return 0; + } + + // However, otherwise we allow at most one SEScope per device type. + // TODO(mbs): This will eventually need to account for memory scopes somehow so device_copy + // instructions can do the right thing. + itr = std::find_if(context_->se_scopes_.begin() + 1, context_->se_scopes_.end(), + [&se_scope](const SEScope& existing_se_scope) { + return existing_se_scope->device_type() == se_scope->device_type(); + }); + CHECK(itr == context_->se_scopes_.end()) + << "The VM does not currently support using more than one device with the same device type " + "for primitives, however the program is using the distinct scopes " + << se_scope << " and " << *itr << " of device type " << se_scope->device_type(); + + ICHECK(se_scope != host_se_scope_); + Index index = context_->se_scopes_.size(); + VLOG(2) << "adding new scope"; + context_->se_scopes_.push_back(se_scope); + + return index; + } + using DeviceAwareExprFunctor::VisitExpr_; void VisitExpr_(const ConstantNode* const_node) final { @@ -359,7 +377,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { NDArray data = const_node->data; size_t konst_idx = context_->constants.size(); auto con = GetRef(const_node); - context_->const_device_type.push_back(GetInScopeDeviceType(con)); + context_->const_device_indexes.push_back(GetDeviceIndex(GetSEScope(con))); context_->constants.push_back(const_node->data); Emit(Instruction::LoadConst(konst_idx, NewRegister())); } @@ -463,7 +481,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { void EmitShapeFunc(Function func, Array inputs, Array outputs) { // Lower shape function - tec::CCacheKey key(func, target_host_); + tec::CCacheKey key(func, host_se_scope_->target); auto cfunc = context_->compiler->LowerShapeFunc(key); int op_index = -1; // pick the only function inside the context @@ -498,7 +516,8 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { argument_registers)); } - void EmitInvokeTVMOp(const Function& func, const Expr& inputs, const Expr& outputs) { + void EmitInvokeTVMOp(const Function& func, const Expr& inputs, const Expr& outputs, + SEScope se_scope) { std::vector argument_registers; ICHECK(func->HasNonzeroAttr(attr::kPrimitive)) @@ -531,13 +550,9 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { if (func->GetAttr(attr::kCompiler).defined()) { target = Target("ext_dev"); } else { - int dev_type = GetInScopeDeviceType(func); - if (targets_.count(dev_type) == 0) { - target = CreateDefaultTarget(dev_type); - } else { - target = targets_[dev_type]; - } + target = se_scope->target; } + ICHECK(target.defined()) << "No target for function:" << std::endl << PrettyPrint(func); tec::CCacheKey key(func, target); auto mangle_fn = [](String name) { return name; }; @@ -577,9 +592,11 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { OpMatch matcher; matcher .Match("vm.invoke_tvm_op", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + [this, call_node](const Array& args, const Attrs& attrs, + const Array& type_arg) { ICHECK_EQ(args.size(), 3); - EmitInvokeTVMOp(Downcast(args[0]), args[1], args[2]); + EmitInvokeTVMOp(Downcast(args[0]), args[1], args[2], + GetSEScope(GetRef(call_node))); }) .Match("memory.alloc_tensor", [this](const Array& args, const Attrs& attrs, const Array& type_arg) { @@ -639,7 +656,8 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { auto dtype = alloc_attrs->dtype; Emit(Instruction::AllocStorage(size_register, alignment, dtype, - alloc_attrs->device_type, NewRegister())); + GetDeviceIndex(alloc_attrs->se_scope), + NewRegister())); }) .Match("vm.shape_func", [this](const Array& args, const Attrs& attrs, const Array& type_arg) { @@ -671,17 +689,17 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { Emit(Instruction::ReshapeTensor(tensor_reg, shape_reg, NewRegister())); }) .Match("device_copy", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + [this, call_node](const Array& args, const Attrs& attrs, + const Array& type_arg) { ICHECK_EQ(args.size(), 1U); this->VisitExpr(args[0]); auto src_reg = last_register_; auto device_copy_attrs = attrs.as(); ICHECK(device_copy_attrs != nullptr) << "Must be the device copy attrs"; - Index src_device_type = device_copy_attrs->src_dev_type; - Index dst_device_type = device_copy_attrs->dst_dev_type; - Emit(Instruction::DeviceCopy(src_reg, src_device_type, dst_device_type, - NewRegister())); + Emit(Instruction::DeviceCopy( + src_reg, GetDeviceIndex(device_copy_attrs->src_se_scope), + GetDeviceIndex(device_copy_attrs->dst_se_scope), NewRegister())); }) .Match("memory.kill", [](const Array& args, const Attrs& attrs, const Array& type_arg) { @@ -781,7 +799,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { RegName CompileMatchValue(MatchValuePtr val) { if (std::dynamic_pointer_cast(val)) { auto r = std::dynamic_pointer_cast(val); - return r->rergister_num; + return r->register_num; } else { auto path = std::dynamic_pointer_cast(val); auto p = CompileMatchValue(path->parent); @@ -858,18 +876,15 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { size_t registers_num_; /*! \brief Global shared meta data */ VMCompilerContext* context_; - /*! \brief Target devices. */ - std::unordered_map targets_; - /*! \brief Host target. */ - Target target_host_; + /*! \brief SEScope for data and computation which must reside on a CPU. */ + SEScope host_se_scope_; }; PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "lower") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.num_args, 3); - IRModule mod = args[0]; - this->Lower(mod, args[1], args[2]); + this->Lower(args[0], args[1], args[2]); }); } else if (name == "codegen") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -909,15 +924,16 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } -void VMCompiler::Lower(IRModule mod, const tvm::TargetMap& targets, - const tvm::Target& target_host) { +void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host) { exec_ = make_object(); - targets_ = targets; - target_host_ = target_host; - CheckAndUpdateHostConsistency(&targets_, &target_host_); + config_ = CompilationConfig(PassContext::Current(), std::move(targets), std::move(target_host)); + + // The first device is always for the host. + CHECK(context_.se_scopes_.empty()); + context_.se_scopes_.push_back(config_->host_se_scope); // Run the optimizations necessary to target the VM. - context_.module = OptimizeModule(mod, targets_, target_host_); + context_.module = OptimizeModuleImpl(std::move(mod)); // Populate the global map. // @@ -933,7 +949,7 @@ void VMCompiler::Lower(IRModule mod, const tvm::TargetMap& targets, auto gvar = named_func.first; if (auto* n = named_func.second.as()) { auto func = GetRef(n); - VMFunctionCompiler func_compiler(&context_, targets_, target_host_); + VMFunctionCompiler func_compiler(&context_, config_->host_se_scope); auto vm_func = func_compiler.Compile(gvar, func); size_t func_index = context_.global_map.at(gvar); @@ -947,17 +963,27 @@ void VMCompiler::Lower(IRModule mod, const tvm::TargetMap& targets, } } + // Populate virtual devices and the host device index. + for (const auto& se_scope : context_.se_scopes_) { + ICHECK(!se_scope->IsFullyUnconstrained()); + ICHECK_GT(se_scope->device_type(), 0); + // TODO(mbs): We forget the memory scope. + exec_->virtual_devices.push_back( + Device{/*device_type=*/se_scope->device_type(), /*device_id=*/se_scope->virtual_device_id}); + } + exec_->host_device_index = kHostDeviceIndex; + // populate constants - for (auto data : context_.constants) { + for (const auto& data : context_.constants) { exec_->constants.push_back(data); } - for (auto i : context_.const_device_type) { - exec_->const_device_type.push_back(i); + for (auto index : context_.const_device_indexes) { + exec_->const_device_indexes.push_back(index); } // update global function map - for (auto gv : context_.global_map) { + for (const auto& gv : context_.global_map) { exec_->global_map.insert({gv.first->name_hint, gv.second}); } @@ -967,22 +993,21 @@ void VMCompiler::Lower(IRModule mod, const tvm::TargetMap& targets, exec_->primitive_map.insert({cfunc->prim_fn_var->name_hint, primitive_index++}); } -#if USE_RELAY_DEBUG - for (const auto& vm_func : exec_->functions) { - VLOG(1) << vm_func << "-------------"; - } -#endif // USE_RELAY_DEBUG + VLOG(1) << std::endl + << "-------------------------------------------------" << std::endl + << exec_->GetVirtualDevices() << exec_->GetConstants() << exec_->GetBytecode() + << "-------------------------------------------------"; backend::UpdateAutoSchedulerOpWeights(context_.compiler); } -transform::Sequential MemoryOpt(tvm::Target host_target, tvm::TargetMap targets) { +transform::Sequential MemoryOpt(const SEScope& cpu_se_scope) { Array pass_seqs; // Remove unused functions Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); // Manifest the allocations. - pass_seqs.push_back(transform::ManifestAlloc(host_target, targets)); + pass_seqs.push_back(transform::ManifestAlloc(cpu_se_scope)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); @@ -991,7 +1016,7 @@ transform::Sequential MemoryOpt(tvm::Target host_target, tvm::TargetMap targets) pass_seqs.push_back(transform::FuseOps()); // Manifest the allocations needed for the shape functions. - pass_seqs.push_back(transform::ManifestAlloc(host_target, targets)); + pass_seqs.push_back(transform::ManifestAlloc(cpu_se_scope)); // Fuse the shape functions. pass_seqs.push_back(transform::FuseOps()); @@ -1009,7 +1034,7 @@ transform::Sequential MemoryOpt(tvm::Target host_target, tvm::TargetMap targets) pass_seqs.push_back(transform::FuseOps()); // Create allocations for math introduced by dynamic region math. - pass_seqs.push_back(transform::ManifestAlloc(host_target, targets)); + pass_seqs.push_back(transform::ManifestAlloc(cpu_se_scope)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); @@ -1019,15 +1044,22 @@ transform::Sequential MemoryOpt(tvm::Target host_target, tvm::TargetMap targets) // instructions need to access to constant // pass_seqs.push_back(transform::LiftConstants()); - return transform::Sequential(pass_seqs); + return transform::Sequential(std::move(pass_seqs)); } -IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets_arg, - const Target& target_host_arg) { +IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets, + const Target& target_host) { + config_ = CompilationConfig(PassContext::Current(), targets, target_host); + // The first device always corresponds to the host. + CHECK(context_.se_scopes_.empty()); + context_.se_scopes_.push_back(config_->host_se_scope); + // TODO(mbs): exec_ is not allocated. What is the API here? + CHECK(exec_ == nullptr); + return OptimizeModuleImpl(std::move(mod)); +} + +IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { VLOG_CONTEXT << "VMCompiler::OptimizeModule"; - TargetMap targets = targets_arg; - Target target_host = target_host_arg; - CheckAndUpdateHostConsistency(&targets, &target_host); if (params_.size()) { BaseFunc base_func = mod->Lookup("main"); ICHECK(base_func->IsInstance()) @@ -1037,29 +1069,24 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets_arg, mod->Add(gvar, f); } - Array pass_seqs = relay::backend::GetPassPrefix(targets, true); + Array pass_seqs = relay::backend::GetPassPrefix( + /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/true); - // TODO(mbs): Reconcile with relay/backend/build_module.cc - DLDeviceType default_device_type; - if (targets_arg.size() == 1) { - default_device_type = - static_cast(static_cast((*targets_arg.begin()).first->value)); - } else { - default_device_type = static_cast(GetFallbackDevice()); - } - pass_seqs.push_back(PlanDevices(default_device_type)); + // Always plan devices so the remaining passes don't need to distinguish homogeneous vs + // hetrogeneous execution. + pass_seqs.push_back(transform::PlanDevices(config_)); pass_seqs.push_back(transform::FuseOps()); // Do layout rewrite for auto-scheduler. transform::PassContext pass_ctx = PassContext::Current(); - if (backend::IsAutoSchedulerEnabled() && targets.size() == 1) { - const auto& target = (*targets.begin()).second; + if (backend::IsAutoSchedulerEnabled() && config_->optional_homogeneous_target.defined()) { Pass major_pass = transform::AutoSchedulerLayoutRewrite(); bool enable_layout_rewrite_targets = - target->kind->device_type == kDLCPU || target->GetAttr("device", "") == "mali"; + config_->optional_homogeneous_target->kind->device_type == kDLCPU || + config_->optional_homogeneous_target->GetAttr("device", "") == "mali"; if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) { - With tctx(target); + With tctx(config_->optional_homogeneous_target); pass_seqs.push_back(major_pass); // Defuse ops to fold constants, then fuse them again pass_seqs.push_back(transform::DefuseOps()); @@ -1080,18 +1107,18 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets_arg, // external codegen. pass_seqs.push_back(transform::Inline()); - pass_seqs.push_back(MemoryOpt(target_host, targets)); + pass_seqs.push_back(MemoryOpt(config_->host_se_scope)); pass_seqs.push_back(transform::InferType()); pass_seqs.push_back(transform::LabelOps()); transform::Sequential seq(pass_seqs); tvm::With ctx(pass_ctx); - if (targets.size() == 1) { - const auto& it = targets.begin(); - With tctx((*it).second); - return seq(mod); + if (config_->optional_homogeneous_target.defined()) { + With tctx(config_->optional_homogeneous_target); + return seq(std::move(mod)); + } else { + return seq(std::move(mod)); } - return seq(mod); } void VMCompiler::PopulateGlobalMap() { @@ -1138,13 +1165,14 @@ void VMCompiler::Codegen() { runtime::Module lib; if (funcs.size() > 0) { - lib = tvm::build(funcs, target_host_); + lib = tvm::build(funcs, config_->host_target); } else { // There is no function handled by TVM. We create a virtual main module // to make sure a DSO module will be also available. lib = codegen::CSourceModuleCreate(";", "", Array{}); } - lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_, runtime::Metadata()); + lib = codegen::CreateMetadataModule(params_, lib, ext_mods, config_->host_target, + runtime::Metadata()); exec_->SetLib(lib); } diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 5b51d7821d78b..2edec70d5c3be 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -78,17 +78,21 @@ struct VMCompilerContext { tec::TECompiler compiler; // List of constants std::vector constants; - // Device type for constants - std::vector const_device_type; + // Device indexes for constants + std::vector const_device_indexes; // List of cached functions std::vector cached_funcs; // The functions that have been lowered. std::unordered_map seen_funcs; + // The SEScopes corresponding to each device index. The first device always corresponds + // to the host device, and all remaining devices are for the primitive operations. + std::vector se_scopes_; }; class VMCompiler : public runtime::ModuleNode { public: - virtual ~VMCompiler() {} + VMCompiler() = default; + virtual ~VMCompiler() = default; virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); @@ -110,7 +114,7 @@ class VMCompiler : public runtime::ModuleNode { * to target mapping. For homogeneous compilation, it is a singleton build target. * \param target_host Host compilation target, if target is device. */ - void Lower(IRModule mod, const TargetMap& targets, const tvm::Target& target_host); + void Lower(IRModule mod, TargetMap targets, Target target_host); /*! \brief Generate the machine code for lowered functions. */ void Codegen(); @@ -128,6 +132,8 @@ class VMCompiler : public runtime::ModuleNode { */ IRModule OptimizeModule(IRModule mod, const TargetMap& targets, const Target& target_host); + IRModule OptimizeModuleImpl(IRModule mod); + /*! * \brief Populate the global function names in a map where the value is used * as the index by the VMFunctions. @@ -135,10 +141,8 @@ class VMCompiler : public runtime::ModuleNode { void PopulateGlobalMap(); protected: - /*! \brief Target devices. */ - TargetMap targets_; - /*! \brief Target host device. */ - tvm::Target target_host_; + /*! \brief Targets and scopes needed for compilation. */ + CompilationConfig config_; /*! \brief Global shared meta data */ VMCompilerContext context_; /*! \brief Compiled executable. */ diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index d9a2b8b91fa35..ffd0e466eb24e 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -112,7 +112,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { auto free_type_vars = FreeTypeVars(func, module_); Array captured_vars; - std::vector captured_var_device_types; + std::vector captured_var_se_scopes; bool recursive = false; for (const auto& var : free_vars) { if (!letrec_.empty() && var == letrec_.back()) { @@ -120,7 +120,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { continue; } captured_vars.push_back(var); - captured_var_device_types.push_back(GetInScopeDeviceType(var)); + captured_var_se_scopes.push_back(GetSEScope(var)); } // Freshen all the captured vars. @@ -132,7 +132,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { rebinding_map.Set(free_var, var); } - DLDeviceType result_device_type = GetInScopeDeviceType(func_node->body); + SEScope result_se_scope = GetSEScope(func_node->body); if (recursive) { if (!captured_vars.empty()) { @@ -195,8 +195,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { lifted_func = Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(), free_type_vars, /*attrs=*/{}, func->span); - lifted_func = - MaybeFunctionOnDevice(lifted_func, captured_var_device_types, result_device_type); + lifted_func = MaybeFunctionOnDevice(lifted_func, captured_var_se_scopes, result_se_scope); lifted_func = MarkClosure(lifted_func); } diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 9a2297a759626..e9441f1b3e58e 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -32,6 +32,7 @@ #include #include "../op/annotation/annotation.h" +#include "../op/memory/on_device.h" namespace tvm { namespace relay { @@ -529,11 +530,11 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { if (const FunctionNode* func = expr.as()) { Expr new_body = ExprBinder(args_map).VisitExpr(func->body); Array new_params; - std::vector new_param_device_types; + std::vector new_param_se_scopes; for (size_t i = 0; i < func->params.size(); ++i) { if (!args_map.count(func->params[i])) { new_params.push_back(func->params[i]); - new_param_device_types.push_back(GetFunctionParamDeviceType(func, i)); + new_param_se_scopes.push_back(GetFunctionParamSEScope(func, i)); } } if (new_body.same_as(func->body) && new_params.size() == func->params.size()) { @@ -541,7 +542,7 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { } auto ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); - ret = MaybeFunctionOnDevice(ret, new_param_device_types, GetFunctionResultDeviceType(func)); + ret = MaybeFunctionOnDevice(ret, new_param_se_scopes, GetFunctionResultSEScope(func)); std::unordered_set set; for (const auto& v : FreeVars(expr)) { set.insert(v); @@ -549,19 +550,19 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { for (const auto& v : FreeVars(ret)) { if (set.count(v) == 0) { new_params.push_back(v); - if (GetFunctionResultDeviceType(func) != kInvalidDeviceType) { + if (!GetFunctionResultSEScope(func)->IsFullyUnconstrained()) { // TODO(mbs): The function has been annotated with a device, which means we are supposed // to be preserving device annotations on every transformation. However there's no // such context for the free vars in args_map. LOG(WARNING) << "introduced free var '" << PrettyPrint(v) << "' into function body but no device is known for it"; } - new_param_device_types.push_back(kInvalidDeviceType); + new_param_se_scopes.push_back(SEScope::FullyUnconstrained()); } } ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); - ret = MaybeFunctionOnDevice(ret, new_param_device_types, GetFunctionResultDeviceType(func)); + ret = MaybeFunctionOnDevice(ret, new_param_se_scopes, GetFunctionResultSEScope(func)); ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); return std::move(ret); } else { diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index 27b61333c9eb4..bd3162dfde869 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -38,158 +38,6 @@ namespace tvm { namespace relay { -TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); - -const Op& OnDeviceOp() { - static const Op& op = Op::Get("on_device"); - return op; -} - -Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { - auto attrs = make_object(); - attrs->device_type = device_type; - attrs->is_fixed = is_fixed; - Span span = expr->span; - return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, span); -} - -Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { - if (device_type == kInvalidDeviceType) { - // Undefined signals no annotation is required. - return expr; - } - if (expr->IsInstance() || expr->IsInstance()) { - // These operators are device polymorphic so no annotation is required. - // TODO(mbs): The device planning pass does NOT currently support device polymorphism for - // constructors, so we could remove them from this condition. However most constructors - // accept type parameters, and it is not well-formed Relay to simply wrap such a - // constructor in an "on_device" call. So we'll pretend they are device polymorphic to - // avoid that difficultly. Overall ADTs need more work to be fully supported. - return expr; - } - if (expr->IsInstance() || expr->IsInstance()) { - // The device can be recovered from the binding site of the global or local variable. - return expr; - } - if (expr->IsInstance()) { - // If a primitive function then it is device polymorphic. Otherwise the device is captured - // by the function's attributes. - return expr; - } - OnDeviceProps props = GetOnDeviceProps(expr); - if (props.body.defined()) { - // Don't nest on_devices. - // If the inner and outer device types differ then we need to be careful: - // - If the inner on_device is_fixed then it disagrees with the outer. - // - If the outer on_device is_fixed then it implies a hidden device_copy - // Otherwise just use the inner device type and ignore the outer. - ICHECK(props.device_type == device_type || (!is_fixed && !props.is_fixed)); - return OnDevice(props.body, device_type, is_fixed || props.is_fixed); - } - return OnDevice(expr, device_type, is_fixed); -} - -TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device") - .set_body_typed([](Expr expr, int device_type, bool is_fixed) { - return OnDevice(expr, static_cast(device_type), is_fixed); - }); - -RELAY_REGISTER_OP("on_device") - .describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE) - .set_num_inputs(1) - .add_argument("data", "Tensor", "The input data.") - .set_support_level(10) - .add_type_rel("Identity", IdentityRel) - .set_attrs_type_key("relay.attrs.OnDeviceAttrs") - .set_attr("TOpPattern", kOpaque) - .set_attr("TOpIsStateful", false) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("TNonComputational", true); - -OnDeviceProps GetOnDeviceProps(const CallNode* call_node) { - if (call_node->op == OnDeviceOp()) { - ICHECK_EQ(call_node->args.size(), 1) << "on_device expects one argument"; - ICHECK(call_node->attrs.defined()) << "on_device requires attributes"; - const auto* on_device_attrs = call_node->attrs.as(); - ICHECK(on_device_attrs != nullptr) << "on_device requires OnDeviceAttrs"; - auto device_type = static_cast(on_device_attrs->device_type); - // Follow nesting: - // on_device(on_device(expr, device_type=1), device_type=2) == {expr, 1} - auto inner = GetOnDeviceProps(call_node->args[0]); - if (inner.body.defined()) { - return {inner.body, inner.device_type, on_device_attrs->is_fixed || inner.is_fixed}; - } else { - return {call_node->args[0], device_type, on_device_attrs->is_fixed}; - } - } - return {}; -} - -OnDeviceProps GetOnDeviceProps(const Expr& expr) { - if (const auto* call_node = expr.as()) { - return GetOnDeviceProps(call_node); - } - return {}; -} - -Function FunctionOnDevice(Function function, Array param_device_types, - Integer result_device_type) { - return WithAttrs(std::move(function), {{tvm::attr::kParamDeviceTypes, param_device_types}, - {tvm::attr::kResultDeviceType, result_device_type}}); -} - -Function FunctionOnDevice(Function function, const std::vector& param_device_types, - DLDeviceType result_device_type) { - Array arr; - arr.reserve(param_device_types.size()); - for (const auto device_type : param_device_types) { - arr.push_back(static_cast(device_type)); - } - return FunctionOnDevice(std::move(function), std::move(arr), - static_cast(result_device_type)); -} - -Function MaybeFunctionOnDevice(Function function, - const std::vector& param_device_types, - DLDeviceType result_device_type) { - if (std::all_of(param_device_types.begin(), param_device_types.end(), - [](DLDeviceType type) { return type == kInvalidDeviceType; }) && - result_device_type == kInvalidDeviceType) { - return function; - } - return FunctionOnDevice(function, param_device_types, result_device_type); -} - -TVM_REGISTER_GLOBAL("relay.op.annotation._make.function_on_device") - .set_body_typed([](Function function, Array param_device_types, - int result_device_type) { - return FunctionOnDevice(function, param_device_types, - static_cast(result_device_type)); - }); - -DLDeviceType GetFunctionResultDeviceType(const FunctionNode* function_node) { - auto opt_integer = function_node->GetAttr(tvm::attr::kResultDeviceType); - if (!opt_integer) { - // No annotation. - return kInvalidDeviceType; - } - return static_cast(opt_integer.value()->value); -} - -DLDeviceType GetFunctionParamDeviceType(const FunctionNode* function_node, size_t i) { - ICHECK_LT(i, function_node->params.size()) - << "param index " << i << " out of range for function of arity " - << function_node->params.size(); - auto opt_array = function_node->GetAttr>(tvm::attr::kParamDeviceTypes); - if (!opt_array) { - // No annotation. - return kInvalidDeviceType; - } - ICHECK_EQ(opt_array.value().size(), function_node->params.size()) - << "annotation parameters do not match function arity"; - return static_cast(opt_array.value()[i]->value); -} - Expr StopFusion(Expr data) { static const Op& op = Op::Get("annotation.stop_fusion"); return Call(op, {data}, Attrs{}, {}); diff --git a/src/relay/op/annotation/annotation.h b/src/relay/op/annotation/annotation.h index d772df9b023a3..1675b7281ebb6 100644 --- a/src/relay/op/annotation/annotation.h +++ b/src/relay/op/annotation/annotation.h @@ -34,112 +34,6 @@ namespace tvm { namespace relay { -/*! \brief Returns the "on_device" operator. */ -const Op& OnDeviceOp(); - -/*! - * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed. - * - * See \p OnDeviceAttrs for an overview. - */ -Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); - -/*! - * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed if the - * device for \p expr cannot otherwise be recovered by the lexical scoping convention. This means - * we will NOT wrap if: - * - \p device_type is \p kInvalidDeviceType, which signals there are no device annotations - * already in play. - * - \p expr is an operator or primitive function literal. These are device polymorphic. - * - \p expr is a non-primitive function literal. The device is captured by the - * "result_device_type" attribute on the function itself. - * - \p expr is a global var. The device is on the function attributes the global is bound to. - * - \p expr is a local var. The device is tracked by the device aware visitors for us. - * - \p expr is a constructor. These should eventually be device polymorphic but are currently - * in an in-between state at the moment. - */ -Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); - -/*! \brief Result of \p GetOnDeviceProps. */ -struct OnDeviceProps { - Expr body; // = null - DLDeviceType device_type = kInvalidDeviceType; - bool is_fixed = false; - - OnDeviceProps() = default; - - OnDeviceProps(const Expr& body, DLDeviceType deviceType, bool isFixed) - : body(body), device_type(deviceType), is_fixed(isFixed) {} -}; - -/*! - * \brief Returns the body expression, device type and is_fixed field for \p call_node if it is - * an "on_device" CallNode. Otherwise returns the null expression, \p kInvalidDeviceType and \p - * false. - */ -OnDeviceProps GetOnDeviceProps(const CallNode* call_node); - -/*! - * \brief Returns the body expression, device type and is_fixed field for \p expr if it is an - * "on_device" CallNode. Otherwise returns the null expression, \p kInvalidDeviceType and \p false. - */ -OnDeviceProps GetOnDeviceProps(const Expr& expr); - -/*! - * \brief Returns the body of \p expr if it is an "on_device" annotation, otherwise returns - * \p expr directly. - */ -inline Expr IgnoreOnDevice(const Expr& expr) { - OnDeviceProps props = GetOnDeviceProps(expr); - return props.body.defined() ? props.body : expr; -} - -/*! - * \brief Returns \p expr as \p NodeType, or null if it is not of that type. Looks through - * any "on_device" annotations. - */ -template -const NodeType* AsIgnoringOnDevice(const Expr& expr) { - const auto* node = expr.as(); - if (node != nullptr) { - return node; - } - OnDeviceProps props = GetOnDeviceProps(expr); - if (!props.body.defined()) { - return nullptr; - } - return props.body.as(); -} - -/*! - * \brief Returns \p function annotated with "param_device_types" and "result_device_type" - * attributes capturing parameter and result devices types respectively. - */ -Function FunctionOnDevice(Function function, Array param_device_types, - Integer body_device_type); -Function FunctionOnDevice(Function function, const std::vector& param_device_types, - DLDeviceType body_device_type); - -/*! - * \brief As for \p FunctionOnDevice, but returns \p function unchanged if all parameters and - * result device types are \p kInvalidDeviceType. - */ -Function MaybeFunctionOnDevice(Function function, - const std::vector& param_device_types, - DLDeviceType result_device_type); - -/*! - * \brief Returns the device type for the resut of \p function_node, or \p kInvalidDeviceType - * if function does not have "result_device_type" annotation. - */ -DLDeviceType GetFunctionResultDeviceType(const FunctionNode* function_node); - -/*! - * \brief Returns the device type for the \p i'th parameter of \p function_node, or - * \p kInvalidDeviceType if function does not have "param_device_types" annotation. - */ -DLDeviceType GetFunctionParamDeviceType(const FunctionNode* function_node, size_t i); - /*! \brief Wraps \p data in a "stop_fusion" annotation. */ Expr StopFusion(Expr data); diff --git a/src/relay/op/memory/device_copy.cc b/src/relay/op/memory/device_copy.cc index dce89aa91b65a..d086eb1dc1840 100644 --- a/src/relay/op/memory/device_copy.cc +++ b/src/relay/op/memory/device_copy.cc @@ -30,6 +30,8 @@ #include #include +#include + #include "../../transforms/infer_layout_utils.h" #include "../type_relations.h" @@ -44,29 +46,27 @@ const Op& DeviceCopyOp() { return op; } -Expr DeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type) { +Expr DeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope) { + ICHECK(!src_se_scope->IsFullyUnconstrained()); + ICHECK(!dst_se_scope->IsFullyUnconstrained()); auto attrs = make_object(); - attrs->src_dev_type = src_dev_type; - attrs->dst_dev_type = dst_dev_type; + attrs->src_se_scope = std::move(src_se_scope); + attrs->dst_se_scope = std::move(dst_se_scope); Span span = expr->span; - return Call(DeviceCopyOp(), {std::move(expr)}, Attrs(attrs), /*type_args=*/{}, span); + return Call(DeviceCopyOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, + std::move(span)); } -Expr OptDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type) { - if (src_dev_type == dst_dev_type) { +TVM_REGISTER_GLOBAL("relay.op._make.DeviceCopy").set_body_typed(DeviceCopy); + +Expr MaybeDeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope) { + if (src_se_scope == dst_se_scope) { + // No copy needed. return expr; } - ICHECK_NE(src_dev_type, kInvalidDeviceType); - ICHECK_NE(dst_dev_type, kInvalidDeviceType); - return DeviceCopy(expr, src_dev_type, dst_dev_type); + return DeviceCopy(std::move(expr), std::move(src_se_scope), std::move(dst_se_scope)); } -TVM_REGISTER_GLOBAL("relay.op._make.device_copy") - .set_body_typed([](Expr expr, int src_dev_type, int dst_dev_type) { - return DeviceCopy(expr, static_cast(src_dev_type), - static_cast(dst_dev_type)); - }); - RELAY_REGISTER_OP("device_copy") .describe(R"code( Copy data from one tensor to another. The source and destination might be @@ -92,16 +92,14 @@ DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) { ICHECK(call_node->attrs.defined()) << "device_copy requires attributes"; const auto* device_copy_attrs = call_node->attrs.as(); ICHECK(device_copy_attrs != nullptr) << "device_copy requires DeviceCopyAttrs"; - auto src_dev_type = static_cast(device_copy_attrs->src_dev_type); - auto dst_dev_type = static_cast(device_copy_attrs->dst_dev_type); // Follow nesting: - // device_copy(device_copy(expr, src_dev_type=1, dst_dev_type=2), - // src_dev_type=2, dst_dev_type=3) ==> {expr, 1, 3} + // device_copy(device_copy(expr, src_se_scope=S, dst_se_scope=T), + // src_se_scope=T, dst_se_scope=U) ==> {expr, S, U} auto inner = GetDeviceCopyProps(call_node->args[0]); if (inner.body.defined()) { - return {inner.body, inner.src_dev_type, inner.dst_dev_type}; + return {inner.body, inner.src_se_scope, device_copy_attrs->dst_se_scope}; } else { - return {call_node->args[0], src_dev_type, dst_dev_type}; + return {call_node->args[0], device_copy_attrs->src_se_scope, device_copy_attrs->dst_se_scope}; } } return {}; diff --git a/src/relay/op/memory/device_copy.h b/src/relay/op/memory/device_copy.h index d21fdb6abe198..3b40f410e53ba 100644 --- a/src/relay/op/memory/device_copy.h +++ b/src/relay/op/memory/device_copy.h @@ -28,6 +28,8 @@ #include #include +#include + namespace tvm { namespace relay { @@ -35,41 +37,43 @@ namespace relay { const Op& DeviceCopyOp(); /*! - * \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated on - * a device of type \p src_dev_type but then copied to a device of type \p dst_dev_type. + * \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated and + * stored at \p src_se_scope but then copied to \p dst_se_scope. */ -Expr DeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type); +Expr DeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope); /*! - * \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated on - * a device of type \p src_dev_type but then copied to a device of type \p dst_dev_type. - * However, return \p expr directly if \p src_dev_type equals \p dst_dev_type. + * \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated and + * stored at \p src_se_scope but then copied to \p dst_se_scope.However, return \p expr + * directly if \p src_se_scope and \p dst_se_scope are (structurally) the same. */ -Expr MaybeDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type); +Expr MaybeDeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope); /*! \brief Result of \p GetDeviceCopyProps. */ struct DeviceCopyProps { Expr body; // = null - DLDeviceType src_dev_type = kInvalidDeviceType; - DLDeviceType dst_dev_type = kInvalidDeviceType; + SEScope src_se_scope = SEScope::FullyUnconstrained(); + SEScope dst_se_scope = SEScope::FullyUnconstrained(); DeviceCopyProps() = default; - DeviceCopyProps(const Expr& body, DLDeviceType srcDevType, DLDeviceType dstDevType) - : body(body), src_dev_type(srcDevType), dst_dev_type(dstDevType) {} + DeviceCopyProps(Expr body, SEScope src_se_scope, SEScope dst_se_scope) + : body(std::move(body)), + src_se_scope(std::move(src_se_scope)), + dst_se_scope(std::move(dst_se_scope)) {} }; /*! - * \brief Returns the body expression, source, and destination device types for \p call_node if it - * is a "device_copy" CallNode. Otherwise returns the null expression and \p kInvalidDeviceType - * device types. + * \brief Returns the body expression, source, and destination \p SEScopes for \p call_node + * if it is a "device_copy" CallNode. Otherwise returns the null expression and unconstrained + * device and scopes. */ DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node); /*! - * \brief Returns the body expression, source, and destination device types for \p expr if it - * is a "device_copy" CallNode. Otherwise returns the null expression and \p kInvalidDeviceType - * device types. + * \brief Returns the body expression, source, and destination \p SEScopes for \p expr if it + * is a "device_copy" Call. Otherwise returns the null expression and unconstrained device and + * scopes. */ DeviceCopyProps GetDeviceCopyProps(const Expr& expr); diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 08e92b31965e4..0574fd50f4b67 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -32,12 +32,14 @@ #include #include +#include #include #include "../../transforms/infer_layout_utils.h" #include "../annotation/annotation.h" #include "../op_common.h" #include "../type_relations.h" +#include "on_device.h" namespace tvm { namespace relay { @@ -48,13 +50,12 @@ TVM_REGISTER_NODE_TYPE(AllocTensorAttrs); // The passing value in attrs and args doesn't seem super great. // We should consider a better solution, i.e the type relation // being able to see the arguments as well? -Expr AllocStorage(Expr size, Expr alignment, Device dev, DataType dtype_hint) { +Expr AllocStorage(Expr size, Expr alignment, SEScope se_scope, DataType dtype_hint) { auto attrs = make_object(); attrs->dtype = dtype_hint; - attrs->device_id = dev.device_id; - attrs->device_type = dev.device_type; + attrs->se_scope = std::move(se_scope); static const Op& op = Op::Get("memory.alloc_storage"); - return Call(op, {size, alignment}, Attrs(attrs), {}); + return Call(op, {std::move(size), std::move(alignment)}, Attrs(std::move(attrs)), {}); } TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage").set_body_typed(AllocStorage); diff --git a/src/relay/op/memory/memory.h b/src/relay/op/memory/memory.h index 558c409782f57..618044a9f2ca3 100644 --- a/src/relay/op/memory/memory.h +++ b/src/relay/op/memory/memory.h @@ -25,6 +25,8 @@ #ifndef TVM_RELAY_OP_MEMORY_MEMORY_H_ #define TVM_RELAY_OP_MEMORY_MEMORY_H_ +#include + #include #include "tvm/relay/expr.h" @@ -32,7 +34,7 @@ namespace tvm { namespace relay { -Expr AllocStorage(Expr size, Expr alignment, Device dev, DataType dtype_hint); +Expr AllocStorage(Expr size, Expr alignment, SEScope se_scope, DataType dtype_hint); Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype, Array assert_shape); Expr ToTupleType(const Type& ty, const std::vector& exprs); diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc new file mode 100644 index 0000000000000..9541d4122a2f9 --- /dev/null +++ b/src/relay/op/memory/on_device.cc @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file src/relay/op/memory/on_device.cc + * \brief Helpers for working with the "on_device" 'annotation' call. + */ + +#include "./on_device.h" + +#include +#include +#include +#include +#include +#include + +#include "../../transforms/infer_layout_utils.h" +#include "../type_relations.h" + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); + +const Op& OnDeviceOp() { + static const Op& op = Op::Get("on_device"); + return op; +} + +Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed) { + ICHECK(!se_scope->IsFullyUnconstrained()); + auto attrs = make_object(); + attrs->se_scope = std::move(se_scope); + attrs->is_fixed = is_fixed; + Span span = expr->span; + return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, + std::move(span)); +} + +TVM_REGISTER_GLOBAL("relay.op.annotation._make.OnDevice").set_body_typed(OnDevice); + +Expr MaybeOnDevice(Expr expr, SEScope se_scope, bool is_fixed) { + if (se_scope->IsFullyUnconstrained()) { + // Nothing to annotate with. + return expr; + } + if (expr->IsInstance() || expr->IsInstance()) { + // These operators are device polymorphic so no annotation is required. + return expr; + } + if (expr->IsInstance() || expr->IsInstance()) { + // The device can be recovered from the binding site of the global or local variable. + return expr; + } + if (expr->IsInstance()) { + // If a primitive function then it is device polymorphic. Otherwise the device is captured + // by the function's "result_se_scope" attribute. + return expr; + } + OnDeviceProps props = GetOnDeviceProps(expr); + if (props.body.defined()) { + // Don't nest on_devices. + // If the inner and outer device types differ then we need to be careful: + // - If the inner on_device is_fixed then it disagrees with the outer. + // - If the outer on_device is_fixed then it implies a hidden device_copy + // Otherwise just use the inner device type and ignore the outer. + ICHECK(props.se_scope == se_scope || (!is_fixed && !props.is_fixed)); + return OnDevice(props.body, se_scope, is_fixed || props.is_fixed); + } + return OnDevice(expr, std::move(se_scope), is_fixed); +} + +RELAY_REGISTER_OP("on_device") + .describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input data.") + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attrs_type_key("relay.attrs.OnDeviceAttrs") + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("TNonComputational", true); + +OnDeviceProps GetOnDeviceProps(const CallNode* call_node) { + if (call_node->op == OnDeviceOp()) { + ICHECK_EQ(call_node->args.size(), 1) << "on_device expects one argument"; + ICHECK(call_node->attrs.defined()) << "on_device requires attributes"; + const auto* on_device_attrs = call_node->attrs.as(); + ICHECK(on_device_attrs != nullptr) << "on_device requires OnDeviceAttrs"; + // Follow nesting: + // on_device(on_device(expr, se_scope=S), se_scope=T) == {expr, S} + auto inner = GetOnDeviceProps(call_node->args[0]); + if (inner.body.defined()) { + return {inner.body, inner.se_scope, on_device_attrs->is_fixed || inner.is_fixed}; + } else { + return {call_node->args[0], on_device_attrs->se_scope, on_device_attrs->is_fixed}; + } + } + return {}; +} + +OnDeviceProps GetOnDeviceProps(const Expr& expr) { + if (const auto* call_node = expr.as()) { + return GetOnDeviceProps(call_node); + } + return {}; +} + +Function FunctionOnDevice(Function function, Array param_se_scopes, + SEScope result_se_scope) { + return WithAttrs(std::move(function), {{tvm::attr::kParamSEScopes, std::move(param_se_scopes)}, + {tvm::attr::kResultSEScope, std::move(result_se_scope)}}); +} + +TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice); + +Function MaybeFunctionOnDevice(Function function, Array param_se_scopes, + SEScope result_se_scope) { + if (std::all_of(param_se_scopes.begin(), param_se_scopes.end(), + [](const SEScope& se_scope) { return se_scope->IsFullyUnconstrained(); }) && + result_se_scope->IsFullyUnconstrained()) { + // Nothing to annotate. + return function; + } + return FunctionOnDevice(function, std::move(param_se_scopes), std::move(result_se_scope)); +} + +SEScope GetFunctionResultSEScope(const FunctionNode* function_node) { + auto opt_se_scope = function_node->GetAttr(tvm::attr::kResultSEScope); + return opt_se_scope.value_or(SEScope::FullyUnconstrained()); +} + +SEScope GetFunctionParamSEScope(const FunctionNode* function_node, size_t i) { + ICHECK_LT(i, function_node->params.size()) + << "param index " << i << " out of range for function of arity " + << function_node->params.size(); + auto opt_array = function_node->GetAttr>(tvm::attr::kParamSEScopes); + if (!opt_array) { + // No annotation. + return SEScope::FullyUnconstrained(); + } + ICHECK_EQ(opt_array.value().size(), function_node->params.size()) + << "annotation parameters do not match function arity"; + return opt_array.value()[i]; +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/memory/on_device.h b/src/relay/op/memory/on_device.h new file mode 100644 index 0000000000000..a7b6cb7cf52a5 --- /dev/null +++ b/src/relay/op/memory/on_device.h @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/op/memory/on_device.h + * \brief Helpers for working with the "on_device" 'annotation' call. + */ +#ifndef TVM_RELAY_OP_MEMORY_ON_DEVICE_H_ +#define TVM_RELAY_OP_MEMORY_ON_DEVICE_H_ + +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief Returns the "on_device" operator. */ +const Op& OnDeviceOp(); + +/*! + * \brief Wraps \p expr in an "on_device" CallNode for \p se_scope and \p is_fixed. + * + * See \p OnDeviceAttrs for an overview. + */ +Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed); + +/*! + * \brief Wraps \p expr in an "on_device" CallNode for \p se_scope and \p is_fixed if the + * \p SEScope for \p expr cannot otherwise be recovered by the lexical scoping convention. + * This means we will NOT wrap if: + * - \p se_scope is full unconstrained, which signals there are no device annotations + * already in play. + * - \p expr is an operator or primitive function literal. These are device polymorphic. + * - \p expr is a non-primitive function literal. The device is captured by the + * "result_se_scope" attribute on the function itself. + * - \p expr is a global var. The device is on the function attributes the global is bound to. + * - \p expr is a local var. The device is tracked by the device aware visitors for us. + * - \p expr is a constructor. These are device polymorphic. + * + */ +Expr MaybeOnDevice(Expr expr, SEScope se_scope, bool is_fixed); + +/*! \brief Result of \p GetOnDeviceProps. */ +struct OnDeviceProps { + Expr body; // = null + SEScope se_scope = SEScope::FullyUnconstrained(); + bool is_fixed = false; + + OnDeviceProps() = default; + + OnDeviceProps(Expr body, SEScope se_scope, bool isFixed) + : body(std::move(body)), se_scope(std::move(se_scope)), is_fixed(isFixed) {} +}; + +/*! + * \brief Returns the body expression, \p SEScope, and is_fixed field for \p call_node if it + * is an "on_device" CallNode. Otherwise returns the null expression, the unconstrained + * \p SEScope, and false. + */ +OnDeviceProps GetOnDeviceProps(const CallNode* call_node); + +/*! + * \brief Returns the body expression, \p SEScope, and is_fixed field for \p expr if it is an + * "on_device" CallNode. Otherwise returns the null expression, the unconstrained \p SEScope, + * and \p false. + */ +OnDeviceProps GetOnDeviceProps(const Expr& expr); + +/*! + * \brief Returns the body of \p expr if it is an "on_device" annotation, otherwise returns + * \p expr directly. + */ +inline Expr IgnoreOnDevice(const Expr& expr) { + OnDeviceProps props = GetOnDeviceProps(expr); + return props.body.defined() ? props.body : expr; +} + +/*! + * \brief Returns \p expr as \p NodeType, or null if it is not of that type. Looks through + * any "on_device" annotations. + */ +template +const NodeType* AsIgnoringOnDevice(const Expr& expr) { + const auto* node = expr.as(); + if (node != nullptr) { + return node; + } + OnDeviceProps props = GetOnDeviceProps(expr); + if (!props.body.defined()) { + return nullptr; + } + return props.body.as(); +} + +/*! + * \brief Returns \p function annotated with "param_se_scopes" and "result_se_scope" + * attributes capturing parameter and result \p SEScopes respectively. + */ +Function FunctionOnDevice(Function function, Array param_se_scopes, SEScope body_se_scope); + +/*! + * \brief As for \p FunctionOnDevice, but returns \p function unchanged if all parameters and + * result \p SEScopes are unconstrained. + */ +Function MaybeFunctionOnDevice(Function function, Array param_se_scopes, + SEScope result_se_scope); + +/*! + * \brief Returns the \p SEScope for the resut of \p function_node, or the unconstrained + * \p SEScope if function does not have the "result_se_scope" annotation. + */ +SEScope GetFunctionResultSEScope(const FunctionNode* function_node); + +/*! + * \brief Returns the \p SEScope for the \p i'th parameter of \p function_node, or + * the unconstrained \p SEScope if function does not have the "param_se_scopes" annotation. + */ +SEScope GetFunctionParamSEScope(const FunctionNode* function_node, size_t i); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_OP_MEMORY_ON_DEVICE_H_ diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc index 38c3305d31941..e3d5a821c58e4 100644 --- a/src/relay/transforms/device_aware_visitors.cc +++ b/src/relay/transforms/device_aware_visitors.cc @@ -38,41 +38,51 @@ LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional& maybe_mod) if (maybe_mod) { for (const auto& pair : maybe_mod.value()->functions) { if (const auto* function_node = pair.second.as()) { - DLDeviceType device_type = GetFunctionResultDeviceType(function_node); - if (device_type != kInvalidDeviceType) { - global_var_device_types_.emplace(pair.first, device_type); + SEScope se_scope = GetFunctionResultSEScope(function_node); + if (!se_scope->IsFullyUnconstrained()) { + global_var_se_scopes_.emplace(pair.first, se_scope); } } } } } -DLDeviceType LexicalOnDeviceMixin::GetInScopeDeviceType(const Expr& expr) const { - auto props = GetOnDeviceProps(expr); +SEScope LexicalOnDeviceMixin::GetSEScope(const Expr& expr) const { + OnDeviceProps props = GetOnDeviceProps(expr); if (props.body.defined() && props.is_fixed) { - return props.device_type; + return props.se_scope; } else if (const auto* var_node = expr.as()) { // Lookup variable binding. - auto itr = var_device_types_.find(GetRef(var_node)); - if (itr != var_device_types_.end()) { + auto itr = var_se_scopes_.find(GetRef(var_node)); + if (itr != var_se_scopes_.end()) { return itr->second; } - // else: fallthrough to unknown + // else: fallthrough to unconstrained } else if (const auto* global_var_node = expr.as()) { // Lookup global variable. - auto itr = global_var_device_types_.find(GetRef(global_var_node)); - if (itr != global_var_device_types_.end()) { + auto itr = global_var_se_scopes_.find(GetRef(global_var_node)); + if (itr != global_var_se_scopes_.end()) { return itr->second; } - // else: fallthrough to unknown + // else: fallthrough to unconstrained + } else if (const auto* function_node = expr.as()) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + if (!expr_se_scopes_.empty()) { + // Use the currently in-scope device type. + return expr_se_scopes_.back(); + } + // else: fallthrough to unconstrained + } else { + return GetFunctionResultSEScope(function_node); + } } else { - if (!expr_device_types_.empty()) { + if (!expr_se_scopes_.empty()) { // Use the currently in-scope device type. - return expr_device_types_.back(); + return expr_se_scopes_.back(); } - // else: fallthrough to unknown + // else: fallthrough to unconstrained } - return kInvalidDeviceType; + return SEScope::FullyUnconstrained(); } void LexicalOnDeviceMixin::EnterFunctionBody() { ++function_nesting_; } @@ -82,34 +92,34 @@ void LexicalOnDeviceMixin::ExitFunctionBody() { --function_nesting_; } -void LexicalOnDeviceMixin::PushDeviceType(DLDeviceType device_type) { - if (device_type == kInvalidDeviceType) { +void LexicalOnDeviceMixin::PushSEScope(const SEScope& se_scope) { + if (se_scope->IsFullyUnconstrained()) { return; } - expr_device_types_.emplace_back(device_type); + expr_se_scopes_.emplace_back(se_scope); } -void LexicalOnDeviceMixin::PopDeviceType() { - if (expr_device_types_.empty()) { +void LexicalOnDeviceMixin::PopSEScope() { + if (expr_se_scopes_.empty()) { return; } - expr_device_types_.pop_back(); + expr_se_scopes_.pop_back(); } -void LexicalOnDeviceMixin::PushBoundVar(Var var, DLDeviceType device_type) { - if (device_type == kInvalidDeviceType) { +void LexicalOnDeviceMixin::PushBoundVar(Var var, const SEScope& se_scope) { + if (se_scope->IsFullyUnconstrained()) { return; } - ICHECK(var_device_types_.find(var) == var_device_types_.end()); - var_device_types_.emplace(std::move(var), device_type); + ICHECK(var_se_scopes_.find(var) == var_se_scopes_.end()); + var_se_scopes_.emplace(std::move(var), se_scope); } void LexicalOnDeviceMixin::PopBoundVar(const Var& var) { - auto itr = var_device_types_.find(var); - if (itr == var_device_types_.end()) { + auto itr = var_se_scopes_.find(var); + if (itr == var_se_scopes_.end()) { return; } - var_device_types_.erase(itr); + var_se_scopes_.erase(itr); } // TODO(mbs): We'd probably have less tedious code duplication if we redefined the memoizing @@ -122,17 +132,17 @@ void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) { } else { // Function parameters come into scope. for (size_t i = 0; i < function_node->params.size(); ++i) { - PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i)); } // Entering scope of function body. - PushDeviceType(GetFunctionResultDeviceType(function_node)); + PushSEScope(GetFunctionResultSEScope(function_node)); EnterFunctionBody(); DeviceAwareVisitExpr_(function_node); // Leaving scope of function body. ExitFunctionBody(); - PopDeviceType(); + PopSEScope(); // Function parameters go out of scope. for (size_t i = 0; i < function_node->params.size(); ++i) { PopBoundVar(function_node->params[i]); @@ -147,7 +157,7 @@ void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) { while (const auto* inner_let_node = expr.as()) { // Let-bound var (in pre visited version) goes into scope. // (We'll just assume this is a letrec). - PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + PushBoundVar(inner_let_node->var, GetSEScope(inner_let_node->value)); PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); bindings.emplace_back(inner_let_node); expr = inner_let_node->body; @@ -164,13 +174,13 @@ void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) { } void DeviceAwareExprVisitor::VisitExpr_(const CallNode* call_node) { - auto props = GetOnDeviceProps(call_node); + OnDeviceProps props = GetOnDeviceProps(call_node); if (props.body.defined() && props.is_fixed) { // Entering lexical scope of fixed "on_device" call. - PushDeviceType(props.device_type); + PushSEScope(props.se_scope); VisitExpr(props.body); // Leaving lexical scope of "on_device" call. - PopDeviceType(); + PopSEScope(); } else { DeviceAwareVisitExpr_(call_node); } @@ -208,17 +218,17 @@ Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) { } else { // Function parameters come into scope. for (size_t i = 0; i < function_node->params.size(); ++i) { - PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i)); } // Entering scope of function body. - PushDeviceType(GetFunctionResultDeviceType(function_node)); + PushSEScope(GetFunctionResultSEScope(function_node)); EnterFunctionBody(); Expr result = DeviceAwareVisitExpr_(function_node); // Leaving scope of function body. ExitFunctionBody(); - PopDeviceType(); + PopSEScope(); // Function parameters go out of scope. for (size_t i = 0; i < function_node->params.size(); ++i) { PopBoundVar(function_node->params[i]); @@ -235,7 +245,7 @@ Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) { while (const auto* inner_let_node = expr.as()) { // Let-bound var (in pre visited version) goes into scope. // (We'll just assume this is a letrec.) - PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + PushBoundVar(inner_let_node->var, GetSEScope(inner_let_node->value)); std::pair pair = PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); bindings.emplace_back(pair.first, pair.second, inner_let_node->span, inner_let_node); expr = inner_let_node->body; @@ -255,14 +265,14 @@ Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) { } Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) { - auto props = GetOnDeviceProps(call_node); + OnDeviceProps props = GetOnDeviceProps(call_node); if (props.body.defined() && props.is_fixed) { // Entering lexical scope of fixed "on_device" call. - PushDeviceType(props.device_type); + PushSEScope(props.se_scope); Expr expr = VisitExpr(props.body); // Leaving lexical scope of "on_device" call. - PopDeviceType(); - return MaybeOnDevice(expr, props.device_type, props.is_fixed); + PopSEScope(); + return MaybeOnDevice(expr, props.se_scope, props.is_fixed); } else { return DeviceAwareVisitExpr_(call_node); } diff --git a/src/relay/transforms/device_aware_visitors.h b/src/relay/transforms/device_aware_visitors.h index 3f4c5c24481e7..8cdf0db74ebd3 100644 --- a/src/relay/transforms/device_aware_visitors.h +++ b/src/relay/transforms/device_aware_visitors.h @@ -35,13 +35,14 @@ #include #include "../op/annotation/annotation.h" +#include "../op/memory/on_device.h" namespace tvm { namespace relay { namespace transform { /*! - * \brief Helper class for expression transformers which need to keep track of the device + * \brief Helper class for expression transformers which need to keep track of the \p SEScope * holding the results of expressions. This is recovered from function attributes and "on_device" * CallNodes added by the PlanDevices pass. * @@ -52,11 +53,11 @@ class LexicalOnDeviceMixin { explicit LexicalOnDeviceMixin(const Optional& maybe_mod); /*! - * \brief Returns the device type on which the result of \p expr should/will be stored, assuming - * Push/Pop DeviceType/BoundVar have been correctly called. May return \p kInvalidDeviceType if - * the device planning pass has not been run. + * \brief Returns the \p SEScope on which the result of \p expr should/will be stored, assuming + * {Push,Pop}{SEScope,BoundVar} have been correctly called. May return the unconstrained + * \p SEScope if the device planning pass has not been run. */ - DLDeviceType GetInScopeDeviceType(const Expr& expr) const; + SEScope GetSEScope(const Expr& expr) const; /*! \brief Indicate a function body is being entered. */ void EnterFunctionBody(); @@ -64,19 +65,19 @@ class LexicalOnDeviceMixin { /*! \brief Indicate a function body has been processed. */ void ExitFunctionBody(); - /*! \brief Push a device type onto the lexical device stack. Ignore if \p kInvalidDeviceType. */ - void PushDeviceType(DLDeviceType device_type); + /*! \brief Push an \p SEScope onto the lexical SEScope stack. Ignore if unconstrained. */ + void PushSEScope(const SEScope& se_scope); - /*! \brief Pop a device type from the lexical device stack. Ignore if stack is empty. */ - void PopDeviceType(); + /*! \brief Pop an \p SEScope from the lexical SEScope stack. Ignore if stack is empty. */ + void PopSEScope(); - /*! \brief Remember that \p var will be stored on \p device_type. Ignore if \p kInvalidDeviceType. + /*! \brief Remember that \p var will be stored at \p se_scope. Ignore if unconstrained. * * CAUTION: Despite the name we don't support re-entering the same function body. */ - void PushBoundVar(Var var, DLDeviceType device_type); + void PushBoundVar(Var var, const SEScope& se_scope); - /*! \brief Remove the binding for \p var to it's device type. Ignore if var is not bound. */ + /*! \brief Remove the binding for \p var to its \p SEScope. Ignore if var is not bound. */ void PopBoundVar(const Var& var); /*! @@ -92,36 +93,36 @@ class LexicalOnDeviceMixin { int function_nesting_ = 0; /*! - * \brief The stack of lexically enclosing "on_device" devices types, from outermost to innermost. - * When visiting an expression other than a variable we can assume the expression's result is to - * be stored on device_type_.back(). + * \brief The stack of lexically enclosing "on_device" \p SEScopes, from outermost to + * innermost. When visiting an expression other than a variable we can assume the expression's + * result is to be stored on \p expr_se_scopes.back(). */ - std::vector expr_device_types_; + std::vector expr_se_scopes_; /*! - * \brief A map from in-scope local variables to their device types. We may assume the variable is - * only ever bound to a value stored on this device at runtime. + * \brief A map from in-scope local variables to their \p SEScopes. We may assume the variable is + * only ever bound to a value stored on this \p SEScope at runtime. * * Note: We're playing it safe and keying by object refs here just in case the Relay expression * being rewritten has no module or other global to keep it alive. */ - std::unordered_map - var_device_types_; + std::unordered_map var_se_scopes_; /*! - * \brief A map from global variables to their device types, ie the "result_device_type" of the - * function they are bound to in the module we are working on. We calculate this explicitly so - * that we don't neeed to hold on to any module, which is often in the process of being rewritten. + * \brief A map from global variables to their \p SEScopes, ie the "result_se_scope" of the + * function they are bound to in the module we are working on. We calculate and store this + * explicitly so that we don't need to hold on to any module, which is often in the process of + * being rewritten. */ - std::unordered_map - global_var_device_types_; + std::unordered_map + global_var_se_scopes_; }; template class DeviceAwareExprFunctor; /*! - * \brief ExprFunctor which tracks devices. We only support 'visitor' style implementation + * \brief ExprFunctor which tracks \p SEScopes. We only support 'visitor' style implementation * with no additional arguments, thus this is equivalent to \p DeviceAwareExprVisitor without * any memoization. */ @@ -142,17 +143,17 @@ class DeviceAwareExprFunctor : public ExprFunctorparams.size(); ++i) { - PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i)); } // Entering scope of function body. - PushDeviceType(GetFunctionResultDeviceType(function_node)); + PushSEScope(GetFunctionResultSEScope(function_node)); EnterFunctionBody(); DeviceAwareVisitExpr_(function_node); // Leaving scope of function body. ExitFunctionBody(); - PopDeviceType(); + PopSEScope(); // Function parameters go out of scope. for (size_t i = 0; i < function_node->params.size(); ++i) { PopBoundVar(function_node->params[i]); @@ -167,7 +168,7 @@ class DeviceAwareExprFunctor : public ExprFunctor()) { // Let-bound var (in pre visited version) goes into scope. // (We'll just assume this is a letrec.) - PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + PushBoundVar(inner_let_node->var, GetSEScope(inner_let_node->value)); PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); bindings.emplace_back(inner_let_node); expr = inner_let_node->body; @@ -185,20 +186,20 @@ class DeviceAwareExprFunctor : public ExprFunctor : public ExprFunctor& maybe_mod) @@ -255,7 +256,7 @@ class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin { void VisitExpr_(const CallNode* call_node) final; /*! - * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * \brief These are as for VisitExpr_. \p SEScopes for expressions and function parameters will be * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For * functions the function_nesting count will already include that of \p function_node. */ @@ -269,7 +270,7 @@ class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin { virtual void PreVisitLetBlock_(const LetNode* let_node); /*! - * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * \brief Visit a let-bound expression before the let body has been visited. \p SEScopes for the * let-bound variable will be tracked automatically. Default implementation just visits var and * value. */ @@ -288,7 +289,7 @@ class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin { virtual void PostVisitLetBlock_(const LetNode* let_node); }; -/*! \brief ExprMutator which tracks devices. */ +/*! \brief ExprMutator which tracks \p SEScopes. */ class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin { public: explicit DeviceAwareExprMutator(const Optional& maybe_mod) @@ -299,7 +300,7 @@ class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin { Expr VisitExpr_(const CallNode* call_node) final; /*! - * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * \brief These are as for VisitExpr_. \p SEScopes for expressions and function parameters will be * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For * functions the function_nesting count will already include that of \p function_node. */ @@ -313,7 +314,7 @@ class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin { virtual void PreVisitLetBlock_(const LetNode* let_node); /*! - * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * \brief Visit a let-bound expression before the let body has been visited. \p SEScopes for the * let-bound variable will be tracked automatically. Default implementation just visits var and * value. */ diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc index 15784856edbf5..305ee3dddbc48 100644 --- a/src/relay/transforms/device_domains.cc +++ b/src/relay/transforms/device_domains.cc @@ -28,6 +28,7 @@ #include "../op/annotation/annotation.h" #include "../op/memory/device_copy.h" +#include "../op/memory/on_device.h" namespace tvm { namespace relay { @@ -35,11 +36,6 @@ namespace transform { namespace { -// Ye olde boost hash mixer. -constexpr size_t mix(size_t h1, size_t h2) { - return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2)); -} - /*! * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather * than the original "device_copy" operator. @@ -51,77 +47,57 @@ DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) { if (tir_call_attrs == nullptr) { return {}; } - if (tir_call_attrs->metadata.count("source_device") != 1 || - tir_call_attrs->metadata.count("dst_device") != 1) { + if (tir_call_attrs->metadata.count("src_se_scope") != 1 || + tir_call_attrs->metadata.count("dst_se_scope") != 1) { return {}; } ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1"; - return { - call_node->args[0], - static_cast( - Downcast(tir_call_attrs->metadata["source_device"])->value), - static_cast(Downcast(tir_call_attrs->metadata["dst_device"])->value)}; + return {call_node->args[0], Downcast(tir_call_attrs->metadata["src_se_scope"]), + Downcast(tir_call_attrs->metadata["dst_se_scope"])}; } } // namespace -// The following hash and equality helpers give each free first-order domain pointer its own -// distinct identity. - -size_t DeviceDomainHash::operator()(const DeviceDomainPtr& domain) const { - if (domain->is_free()) { - // Give each free first-order domain its own identity. - return static_cast(reinterpret_cast(domain.get())); - } else { - size_t h = domain->args_and_result_.size(); - h = mix(h, std::hash()(static_cast(domain->device_type_))); - for (const auto& sub_domain_ptr : domain->args_and_result_) { - h = mix(h, DeviceDomainHash()(sub_domain_ptr)); - } - return h; - } +DeviceDomains::DeviceDomains(CompilationConfig config) : config_(std::move(config)) { + host_domain_ = MakeFirstOrderDomain(config_->host_se_scope); } -bool DeviceDomainEqual::operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const { - if (lhs->args_and_result_.size() != rhs->args_and_result_.size()) { - // Mismatched arities are never equal. - // (Though we'll never ask to do such a comparison explicitly, the hash map - // may do so implicitly due to hash collisions.) - return false; - } - if (lhs->is_free() && rhs->is_free()) { - // Compare first-order free domains by their address. - return lhs.get() == rhs.get(); - } - if (lhs->args_and_result_.empty()) { - // Compare first-order domains by their device type -- free vs bound will compare as false. - return lhs->device_type_ == rhs->device_type_; - } else { - // Compare higher-order domains pointwise. - for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) { - if (!(*this)(lhs->args_and_result_[i], rhs->args_and_result_[i])) { - return false; - } +DeviceDomainPtr DeviceDomains::MakeFirstOrderDomain(const SEScope& se_scope) { + if (se_scope->IsFullyConstrained()) { + auto itr = fully_constrained_se_scope_to_domain_.find(se_scope); + if (itr != fully_constrained_se_scope_to_domain_.end()) { + return itr->second; } - return true; + DeviceDomainPtr domain = std::make_shared(se_scope); + fully_constrained_se_scope_to_domain_.emplace(se_scope, domain); + return domain; + } else { + return std::make_shared(se_scope); } } -/* static */ -DeviceDomainPtr DeviceDomains::MakeDomain(const Type& type, DLDeviceType device_type) { +DeviceDomainPtr DeviceDomains::MakeDomain(const Type& type, const SEScope& se_scope) { if (const auto* func_type_node = type.as()) { std::vector args_and_result; args_and_result.reserve(func_type_node->arg_types.size() + 1); for (const auto& arg_type : func_type_node->arg_types) { - args_and_result.emplace_back(MakeDomain(arg_type, kInvalidDeviceType)); + args_and_result.emplace_back(MakeDomain(arg_type, SEScope::FullyUnconstrained())); } - args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, device_type)); + args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, se_scope)); return std::make_shared(std::move(args_and_result)); } else { - return std::make_shared(device_type); + return MakeFirstOrderDomain(se_scope); } } +DeviceDomainPtr DeviceDomains::ForSEScope(const Type& type, const SEScope& non_canonical_se_scope) { + // Generally se_scope will have come from an annotation so resolve it to ensure we have + // its canonical representation. + SEScope se_scope = config_->CanonicalSEScope(non_canonical_se_scope); + ICHECK(!se_scope->IsFullyUnconstrained()); + return MakeDomain(type, se_scope); +} + DeviceDomainPtr DeviceDomains::Lookup(DeviceDomainPtr domain) { DeviceDomainPtr root = domain; while (true) { @@ -144,56 +120,82 @@ DeviceDomainPtr DeviceDomains::Lookup(DeviceDomainPtr domain) { return root; } -DeviceDomainPtr DeviceDomains::Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { - // TODO(mbs): Proper diagnostics. +DeviceDomainPtr DeviceDomains::JoinOrNull(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { + if (lhs == rhs) { + return lhs; + } ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size()) << "Device domains:" << std::endl << ToString(lhs) << std::endl << "and" << std::endl << ToString(rhs) << std::endl << "do not have the same kind and can't be unified."; - if (rhs->is_free()) { - return lhs; - } else if (lhs->is_free()) { - return rhs; - } else if (lhs->args_and_result_.empty()) { - // Must have consistent device types for first order domains. - if (lhs->device_type_ != rhs->device_type_) { - // TODO(mbs): Proper diagnostics. - std::ostringstream os; - os << "Inconsistent device types " << lhs->device_type_ << " and " << rhs->device_type_; - throw Error(os.str()); + if (lhs->args_and_result_.empty()) { + // Directly compare first-order. + if (rhs->se_scope_->IsFullyUnconstrained()) { + return lhs; } - return lhs; + if (lhs->se_scope_->IsFullyUnconstrained()) { + return rhs; + } + Optional joined_se_scope = SEScope::Join(lhs->se_scope_, rhs->se_scope_); + if (!joined_se_scope) { + return nullptr; + } + return MakeFirstOrderDomain(config_->CanonicalSEScope(joined_se_scope.value())); } else { // Recurse for higher-order. std::vector args_and_result; args_and_result.reserve(lhs->args_and_result_.size()); for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) { - args_and_result.emplace_back(Unify(lhs->args_and_result_[i], rhs->args_and_result_[i])); + DeviceDomainPtr joined_domain = + UnifyOrNull(lhs->args_and_result_[i], rhs->args_and_result_[i]); + if (joined_domain == nullptr) { + return nullptr; + } + args_and_result.emplace_back(std::move(joined_domain)); } - return MakeDomain(std::move(args_and_result)); + return MakeHigherOrderDomain(std::move(args_and_result)); } } -DeviceDomainPtr DeviceDomains::Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) { +DeviceDomainPtr DeviceDomains::UnifyOrNull(DeviceDomainPtr lhs, DeviceDomainPtr rhs) { + ICHECK_NOTNULL(lhs); + ICHECK_NOTNULL(rhs); lhs = Lookup(lhs); rhs = Lookup(rhs); - auto joined_domain = Join(lhs, rhs); - if (!DeviceDomainEqual()(lhs, joined_domain)) { + DeviceDomainPtr joined_domain = JoinOrNull(lhs, rhs); + if (joined_domain == nullptr) { + return nullptr; + } + if (lhs != joined_domain) { domain_to_equiv_.emplace(lhs, joined_domain); } - if (!DeviceDomainEqual()(rhs, joined_domain)) { + if (rhs != joined_domain) { domain_to_equiv_.emplace(rhs, joined_domain); } return joined_domain; } -void DeviceDomains::UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { - if (!lhs->is_higher_order() && rhs->is_higher_order()) { - Collapse(lhs, rhs); +bool DeviceDomains::CollapseOrFalse(const DeviceDomainPtr& first_order_domain, + const DeviceDomainPtr& higher_order_domain) { + ICHECK(!first_order_domain->is_higher_order()); + ICHECK(higher_order_domain->is_higher_order()); + for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) { + if (UnifyOrNull(higher_order_domain->function_param(i), first_order_domain) == nullptr) { + return false; + } + } + return UnifyOrNull(higher_order_domain->function_result(), first_order_domain) != nullptr; +} + +bool DeviceDomains::UnifyCollapsedOrFalse(const DeviceDomainPtr& lhs_first_order, + const DeviceDomainPtr& rhs_maybe_higher_order) { + ICHECK(!lhs_first_order->is_higher_order()); + if (rhs_maybe_higher_order->is_higher_order()) { + return CollapseOrFalse(lhs_first_order, rhs_maybe_higher_order); } else { - Unify(lhs, rhs); + return UnifyOrNull(lhs_first_order, rhs_maybe_higher_order) != nullptr; } } @@ -215,49 +217,49 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { } std::vector args_and_result; - auto on_device_props = GetOnDeviceProps(call.get()); - auto device_copy_props = GetDeviceCopyProps(call.get()); + OnDeviceProps on_device_props = GetOnDeviceProps(call.get()); + DeviceCopyProps device_copy_props = GetDeviceCopyProps(call.get()); if (!device_copy_props.body.defined()) { + // Special case for the TIR-ified version of "device_copy". device_copy_props = GetPrimitiveDeviceCopyProps(call.get()); } if (on_device_props.body.defined()) { - // on_device(expr, device_type=, is_fixed=false) + // on_device(expr, se_scope=, is_fixed=false) // on_device : fn():?x? // - // on_device(expr, device_type=, is_fixed=true) + // on_device(expr, se_scope=, is_fixed=true) // on_device: fn(): args_and_result.emplace_back( - ForDeviceType(on_device_props.body->checked_type(), on_device_props.device_type)); + ForSEScope(on_device_props.body->checked_type(), on_device_props.se_scope)); if (on_device_props.is_fixed) { args_and_result.emplace_back(args_and_result.front()); } else { args_and_result.emplace_back(Free(on_device_props.body->checked_type())); } } else if (device_copy_props.body.defined()) { - // device_copy(expr, src_dev_type=, dst_dev_type=) + // device_copy(expr, src_se_scope=, dst_se_scope=) // device_copy: fn(): args_and_result.emplace_back( - ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.src_dev_type)); + ForSEScope(device_copy_props.body->checked_type(), device_copy_props.src_se_scope)); args_and_result.emplace_back( - ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.dst_dev_type)); + ForSEScope(device_copy_props.body->checked_type(), device_copy_props.dst_se_scope)); } else if (call->op == alloc_storage_op) { ICHECK_EQ(call->args.size(), 2U); - // alloc_storage(size, alignment, device_type=) + // alloc_storage(size, alignment, se_scope=) // alloc_storage: fn(, ): const auto* attrs = call->attrs.as(); - args_and_result.emplace_back(cpu_domain_); - args_and_result.emplace_back(cpu_domain_); - args_and_result.emplace_back( - ForDeviceType(call->checked_type(), static_cast(attrs->device_type))); + args_and_result.emplace_back(host_domain_); + args_and_result.emplace_back(host_domain_); + args_and_result.emplace_back(ForSEScope(call->checked_type(), attrs->se_scope)); } else if (call->op == alloc_tensor_op) { ICHECK_EQ(call->args.size(), 3U); // alloc_tensor(storage, offset, shape) // alloc_tensor: fn(?x?, , ):?x? auto free_domain = Free(call->checked_type()); args_and_result.emplace_back(free_domain); - args_and_result.emplace_back(cpu_domain_); - args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(host_domain_); + args_and_result.emplace_back(host_domain_); args_and_result.emplace_back(free_domain); } else if (call->op == shape_func_op) { ICHECK_EQ(call->args.size(), 3U); @@ -267,15 +269,15 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { args_and_result.emplace_back(Free(call->args[0]->checked_type())); // TODO(mbs): I think this should be on the cpu only when is_input = [false], but // what do we do when we have multiple arguments with different is_input values? - args_and_result.emplace_back(cpu_domain_); - args_and_result.emplace_back(cpu_domain_); - args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(host_domain_); + args_and_result.emplace_back(host_domain_); + args_and_result.emplace_back(host_domain_); } else if (call->op == shape_of_op) { ICHECK_EQ(call->args.size(), 1U); // shape_of(tensor) // shape_of: fn(?x?): args_and_result.emplace_back(Free(call->args[0]->checked_type())); - args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(host_domain_); } else if (call->op == invoke_tvm_op) { ICHECK_EQ(call->args.size(), 3U); // invoke_tvm_op(op, inputs, outputs) @@ -292,13 +294,13 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { // reshape_tensor: fn(?x?, ):?x? auto free_domain = Free(call->checked_type()); args_and_result.emplace_back(free_domain); - args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(host_domain_); args_and_result.emplace_back(free_domain); } else if (call->op->IsInstance()) { // (arg1, ..., argn) // : fn(?x?, ..., ?x?):?x? // (all args and result must be first-order). - auto free_domain = Free(arb_); + auto free_domain = MakeFirstOrderDomain(SEScope::FullyUnconstrained()); for (size_t i = 0; i < call->args.size(); ++i) { args_and_result.emplace_back(free_domain); } @@ -314,8 +316,9 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { ICHECK_EQ(func_type_node->arg_types.size(), call->args.size()); auto result_domain = Free(func_type_node->ret_type); // first-order for (const auto& arg_type : func_type_node->arg_types) { - auto param_domain = Free(arg_type); // possibly higher-order - UnifyCollapsed(result_domain, param_domain); // collapse if required + auto param_domain = Free(arg_type); // possibly higher-order + bool success = UnifyCollapsedOrFalse(result_domain, param_domain); // collapse if required + ICHECK(success); args_and_result.emplace_back(param_domain); } args_and_result.emplace_back(result_domain); @@ -323,7 +326,7 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { // Defer to normal case where op can be an arbitrary expression. return DomainFor(call->op); } - auto domain = MakeDomain(std::move(args_and_result)); + auto domain = MakeHigherOrderDomain(std::move(args_and_result)); call_to_callee_domain_.emplace(call.get(), domain); return domain; } @@ -331,111 +334,104 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { void DeviceDomains::UnifyExprExact(const Expr& lhs, const Expr& rhs) { auto lhs_domain = DomainFor(lhs); auto rhs_domain = DomainFor(rhs); - try { - Unify(lhs_domain, rhs_domain); - } catch (const Error& e) { + if (UnifyOrNull(lhs_domain, rhs_domain) == nullptr) { // TODO(mbs): Proper diagnostics. - LOG(FATAL) << "Incompatible devices for expressions:" << std::endl + LOG(FATAL) << "Incompatible SEScopes for expressions:" << std::endl << PrettyPrint(lhs) << std::endl - << "with device:" << std::endl + << "with scope:" << std::endl << ToString(lhs_domain) << "and:" << std::endl << PrettyPrint(rhs) << std::endl - << "with device:" << std::endl - << ToString(rhs_domain) << std::endl - << e.what(); + << "with scope:" << std::endl + << ToString(rhs_domain); } } void DeviceDomains::UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain) { auto actual_domain = DomainFor(expr); - try { - Unify(actual_domain, expected_domain); - } catch (const Error& e) { + if (UnifyOrNull(actual_domain, expected_domain) == nullptr) { // TODO(mbs): Proper diagnostics. - LOG(FATAL) << "Incompatible devices for expression:" << std::endl + LOG(FATAL) << "Incompatible SEScopes for expression:" << std::endl << PrettyPrint(expr) << std::endl - << "with actual device:" << std::endl + << "with actual scope:" << std::endl << ToString(actual_domain) << std::endl - << "and expected device:" << std::endl - << ToString(expected_domain) << std::endl - << e.what(); + << "and expected scope:" << std::endl + << ToString(expected_domain); } } -void DeviceDomains::UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain) { - auto actual_domain = DomainFor(expr); - try { - UnifyCollapsed(actual_domain, expected_domain); - } catch (const Error& e) { +void DeviceDomains::UnifyExprCollapsed(const Expr& expr_first_order, + const DeviceDomainPtr& expected_domain_maybe_higher_order) { + auto actual_domain_first_order = DomainFor(expr_first_order); + if (!UnifyCollapsedOrFalse(actual_domain_first_order, expected_domain_maybe_higher_order)) { // TODO(mbs): Proper diagnostics. - LOG(FATAL) << "Incompatible devices for expression:" << std::endl - << PrettyPrint(expr) << std::endl - << "with actual device:" << std::endl - << ToString(actual_domain) << std::endl - << "and expected device:" << std::endl - << ToString(expected_domain) << std::endl - << e.what(); + LOG(FATAL) << "Incompatible SEScopes for expression:" << std::endl + << PrettyPrint(expr_first_order) << std::endl + << "with actual scope:" << std::endl + << ToString(actual_domain_first_order) << std::endl + << "and expected scope:" << std::endl + << ToString(expected_domain_maybe_higher_order); } } -bool DeviceDomains::AnyFree(DeviceDomainPtr domain) { +bool DeviceDomains::IsFullyConstrained(DeviceDomainPtr domain) { domain = Lookup(domain); - if (domain->is_free()) { - return true; - } - for (const auto& sub_domain : domain->args_and_result_) { - if (AnyFree(sub_domain)) { - return true; - } - } - return false; -} - -void DeviceDomains::Collapse(const DeviceDomainPtr& first_order_domain, - const DeviceDomainPtr& higher_order_domain) { - for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) { - Unify(higher_order_domain->function_param(i), first_order_domain); + if (domain->args_and_result_.empty()) { + // First-order. + return domain->se_scope_->IsFullyConstrained(); + } else { + // Higher-order. + return std::all_of( + domain->args_and_result_.begin(), domain->args_and_result_.end(), + [this](const DeviceDomainPtr& sub_domain) { return IsFullyConstrained(sub_domain); }); } - Unify(higher_order_domain->function_result(), first_order_domain); } -void DeviceDomains::SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type) { - ICHECK_NE(default_device_type, kInvalidDeviceType); +void DeviceDomains::SetDefault(DeviceDomainPtr domain, const SEScope& default_se_scope) { + ICHECK(!default_se_scope->IsFullyUnconstrained()); domain = Lookup(domain); - if (domain->is_free()) { - // Will never throw since lhs is free. - Unify(domain, std::make_shared(default_device_type)); - } else if (!domain->args_and_result_.empty()) { + if (domain->args_and_result_.empty()) { + DeviceDomainPtr defaulted_domain_ptr = + UnifyOrNull(domain, MakeFirstOrderDomain(config_->CanonicalSEScope( + SEScope::Default(domain->se_scope_, default_se_scope)))); + ICHECK_NOTNULL(defaulted_domain_ptr); + } else { for (const auto& sub_domain : domain->args_and_result_) { - SetDefault(sub_domain, default_device_type); + SetDefault(sub_domain, default_se_scope); } } } -void DeviceDomains::SetResultDefaultThenParams(const DeviceDomainPtr& domain, - DLDeviceType default_device_type) { - if (!domain->is_higher_order()) { - SetDefault(domain, default_device_type); - return; +void DeviceDomains::SetResultDefaultThenParams(const DeviceDomainPtr& domain_maybe_higher_order, + const SEScope& default_se_scope) { + if (domain_maybe_higher_order->args_and_result_.empty()) { + SetDefault(domain_maybe_higher_order, default_se_scope); + } else { + // First set default for result domain. + SetDefault(ResultDomain(domain_maybe_higher_order), default_se_scope); + // Then use current result domain as default for everything else. + SetDefault(domain_maybe_higher_order, ResultSEScope(domain_maybe_higher_order)); } - DLDeviceType result_device_type = ResultDeviceType(domain); - if (result_device_type == kInvalidDeviceType) { - // If the function result device is still free use the given default. - result_device_type = default_device_type; +} + +DeviceDomainPtr DeviceDomains::ResultDomain(DeviceDomainPtr domain) { + domain = Lookup(domain); + while (!domain->args_and_result_.empty()) { + domain = Lookup(domain->args_and_result_.back()); } - // Default any remaining free parameters to the function result device. - SetDefault(domain, result_device_type); + return domain; } std::string DeviceDomains::ToString(DeviceDomainPtr domain) { domain = Lookup(domain); std::ostringstream os; - if (domain->is_free()) { - // first-order free - os << "?" << static_cast(reinterpret_cast(domain.get())) << "?"; - } else if (domain->args_and_result_.empty()) { - // first-order bound - os << "<" << domain->device_type_ << ">"; + if (domain->args_and_result_.empty()) { + // First-order. + if (!domain->se_scope_->IsFullyConstrained()) { + os << "?" << static_cast(reinterpret_cast(domain.get())) << "?"; + } + if (!domain->se_scope_->IsFullyUnconstrained()) { + os << domain->se_scope_; + } } else { // higher-order os << "fn("; @@ -469,14 +465,6 @@ std::string DeviceDomains::ToString() { return os.str(); } -DeviceDomainPtr DeviceDomains::ResultDomain(DeviceDomainPtr domain) { - domain = Lookup(domain); - while (!domain->args_and_result_.empty()) { - domain = Lookup(domain->args_and_result_.back()); - } - return domain; -} - } // namespace transform } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/device_domains.h b/src/relay/transforms/device_domains.h index a29370a0e8077..f3f31e790983b 100644 --- a/src/relay/transforms/device_domains.h +++ b/src/relay/transforms/device_domains.h @@ -29,6 +29,8 @@ #include #include #include +#include +#include #include #include @@ -42,13 +44,14 @@ namespace transform { class DeviceDomain; using DeviceDomainPtr = std::shared_ptr; +class DeviceDomains; /*! * \brief Represents the domain over which we collect equality constraints. * * \code * D ::= ?x? -- first order, free - * | -- first order, bound + * | -- first order, bound to specific device and memory scope * | fn(D1, ..., Dn):Dr -- higher order * \endcode * @@ -56,44 +59,46 @@ using DeviceDomainPtr = std::shared_ptr; * a notion of the 'result domain' of a domain: * \code * result_domain(?x?) = ?x? - * result_domain() = + * result_domain() = * result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr) * \endcode */ class DeviceDomain { public: /*! - * \brief Constructs a first-order domain of \p device_type, which may be - * \p kInvalidDeviceType to indicate the domain is free. + * \brief Constructs a first-order domain for \p se_scope, which may be + * fully free (ie se_scope is unconstrained), partially free (ie se_scope has at least on + * of its target, device id or memory scopes known), or fully fixed (ie se_scope has its target, + * device id and memory scopes set). + * + * CAUTION: Use DeviceDomains::MakeFirstOrderDomain instead of this ctor. */ - explicit DeviceDomain(DLDeviceType device_type) : device_type_(device_type) {} + explicit DeviceDomain(SEScope se_scope) : se_scope_(std::move(se_scope)) {} /*! * \brief Constructs a higher-order domain, where \p args_and_result contain the * function argument and result domains in order. + * + * CAUTION: Use DeviceDomains::MakeHigherOrderDomain instead of this ctor. */ explicit DeviceDomain(std::vector args_and_result) - : device_type_(kInvalidDeviceType), args_and_result_(std::move(args_and_result)) {} + : se_scope_(SEScope::FullyUnconstrained()), args_and_result_(std::move(args_and_result)) {} - /*! \brief Returns true if domain is first-order and free. */ - bool is_free() const { return device_type_ == kInvalidDeviceType && args_and_result_.empty(); } - - /*! \brief Returns true if domain is higher-order. */ bool is_higher_order() const { return !args_and_result_.empty(); } - DLDeviceType first_order_device_type() const { - ICHECK(args_and_result_.empty()); - return device_type_; + SEScope first_order_se_scope() const { + ICHECK(args_and_result_.empty()) << "expecting domain to be first-order"; + return se_scope_; } size_t function_arity() const { - ICHECK(!args_and_result_.empty()); + ICHECK(!args_and_result_.empty()) << "expecting domain to be higher-order"; return args_and_result_.size() - 1UL; } DeviceDomainPtr function_param(size_t i) const { - ICHECK(!args_and_result_.empty()); - ICHECK_LT(i + 1, args_and_result_.size()); + ICHECK(!args_and_result_.empty()) << "expecting domain to be higher-order"; + ICHECK_LT(i + 1, args_and_result_.size()) << "parameter index is out of range"; return args_and_result_[i]; } @@ -104,11 +109,12 @@ class DeviceDomain { private: /*! - * \brief If this is a function domain then always kInvalidDevice. Otherwise will be - * kInvalidDevice if the domain is still free, or the specific concrete device if the domain is - * bound. + * \brief If this is a function domain then always fully unconstrained. Otherwise will be + * fully unconstrained (the domain is still completely free), partially constrained + * (for example, the \p target and \p device_type are constrained but the \p virtual_device_id and + * \p memory_scope are still unconstrained), or fully constrained (everything is known). */ - const DLDeviceType device_type_; + const SEScope se_scope_; /*! * \brief If this is a function domain then the sub-domains for each of the function's @@ -116,81 +122,92 @@ class DeviceDomain { */ const std::vector args_and_result_; - friend struct DeviceDomainHash; - friend struct DeviceDomainEqual; friend class DeviceDomains; }; -// The following hash and equality helpers give each free first-order domain pointer its own -// distinct identity. -struct DeviceDomainHash { - size_t operator()(const DeviceDomainPtr& domain) const; -}; - -struct DeviceDomainEqual { - public: - bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const; -}; - /*! * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation - * built up by calls to \p Unify. + * built up by calls to \p UnifyOrNull. */ class DeviceDomains { public: - DeviceDomains() = default; + explicit DeviceDomains(CompilationConfig config); + + const CompilationConfig& config() const { return config_; } /*! - * \brief Returns a domain appropriate for \p type who's result domain is bound - * to \p device_type. If \p device_type is \p kInvalidDeviceType then the entire domain - * will be free. + * \brief Returns the domain representing \p se_scope. If \p se_scope is fully constrained + * then the domain will be unique that \p se_scope. */ - static DeviceDomainPtr MakeDomain(const Type& type, DLDeviceType device_type); + DeviceDomainPtr MakeFirstOrderDomain(const SEScope& se_scope); /*! * \brief Returns a higher-order domain with \p args_and_results. */ - static DeviceDomainPtr MakeDomain(std::vector arg_and_results) { + DeviceDomainPtr MakeHigherOrderDomain(std::vector arg_and_results) { return std::make_shared(std::move(arg_and_results)); } - /*! \brief Returns a domain with the given result device type appropriate \p device_type. */ - static DeviceDomainPtr ForDeviceType(const Type& type, DLDeviceType device_type) { - ICHECK_NE(device_type, kInvalidDeviceType); - return MakeDomain(type, device_type); - } + /*! + * \brief Returns a domain appropriate for \p type who's result domain is bound to \p se_scope. + * If \p type is a function then all parameter domains will be completely free. It is valid for + * \p se_scope to be fully unconstrained. + */ + DeviceDomainPtr MakeDomain(const Type& type, const SEScope& se_scope); + + /*! + * \brief Returns a domain with the given result appropriate \p non_canonical_se_scope, + * which cannot be fully unconstrained. We first canonicalize the scope to unsure it has + * a target and is unique. + */ + DeviceDomainPtr ForSEScope(const Type& type, const SEScope& non_canonical_se_scope); /*! \brief Returns a free domain appropriate for \p type. */ - static DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, kInvalidDeviceType); } + DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, SEScope::FullyUnconstrained()); } /*! \brief Returns the domain representing the equivalence class containing \p domain. */ DeviceDomainPtr Lookup(DeviceDomainPtr domain); /*! - * \brief Returns the domain accounting for all bound devices in \p lhs and \p rhs. - * - * Throws \p Error on failure. + * \brief Returns the most constrained domain which agrees with both \p lhs and \p rhs. Returns + * null if no such domain exists, ie some first-order component of \p lhs is constrained + * differently than the corresponding component of \p rhs. */ - DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); + DeviceDomainPtr JoinOrNull(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); /*! - * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Fails if \p lhs and \p - * rhs disagree on bound device type. - * - * Throws \p Error on failure. + * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Returns null if + * \p lhs and \p rhs are not unifiable. */ // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but // given we have refs to functions I'm prepared to be surprised. - DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs); + DeviceDomainPtr UnifyOrNull(DeviceDomainPtr lhs, DeviceDomainPtr rhs); + + /* + * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain. + * This can be used to handle functions within tuples, references and ADTs since we don't + * attempt to track anything beyond 'the device' for expressions of those first-order types. + * + * Returns false if any unification fails. + */ + bool CollapseOrFalse(const DeviceDomainPtr& first_order_domain, + const DeviceDomainPtr& higher_order_domain); /*! - * \brief Unifies \p lhs and \p rhs. If \p lhs is first-order and \p rhs is higher-order, - * require all arguments and result of \p rhs to unify with \p lhs. Otherwise same as - * \p Unify. + * \brief Unifies \p lhs_first_order and \p rhs_maybe_higher_order. If \p rhs_maybe_higher_order + * is indeed higher-order, require all of its arguments and result to unify with + * \p lhs_first_order. Otherwise same as \p Unify. Returns false if unification is not possible. * - * Throws \p Error on failure. + * In an expression such as: + * \code + * (fn(...) {...}, ...).0 + * \endcode + * we need to force all the devices of the inner function to be the same as the device for the + * overall tuple since the device domain does not understand tuples. Similarly for references + * and ADTs. */ - void UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); + bool UnifyCollapsedOrFalse(const DeviceDomainPtr& lhs_first_order, + const DeviceDomainPtr& rhs_maybe_higher_order); /*! \brief Returns true if a domain is known for \p expr. */ bool contains(const Expr& expr) const { return expr_to_domain_.count(expr.get()); } @@ -204,7 +221,8 @@ class DeviceDomains { * DomainFor(call->op). * * This special handling is needed: - * - To handle the "on_device" and "device_copy" ops which constrain devices to the given devices. + * - To handle the "on_device" and "device_copy" ops which constrain devices to the given + * devices. * - To handle some special ops which constrain devices to the CPU. * - To allow the same primitive to be called on different devices at different call sites. * Since each call to the op can have a different domain we index the ops by the call expression @@ -212,11 +230,17 @@ class DeviceDomains { */ DeviceDomainPtr DomainForCallee(const Call& call); - /*! \brief Unifies the domains for expressions \p lhs and \p rhs. */ + /*! + * \brief Unifies the domains for expressions \p lhs and \p rhs. + * + * Aborts if unification fails. + */ void UnifyExprExact(const Expr& lhs, const Expr& rhs); /*! * \brief Unifies the domain for \p expr with \p expected_domain. + * + * Aborts if unification fails. */ void UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain); @@ -224,37 +248,25 @@ class DeviceDomains { * \brief Unifies the domain for \p expr with \p expected_domain. * If \p expected_domain is higher-order but \p expr is first-order, require all arguments * and the result of \p expected_domain to have the same domain as for \p expr. - */ - void UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain); - - /*! \brief Returns true if \p domain contains any free sub-domains. */ - bool AnyFree(DeviceDomainPtr domain); - - /* - * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain. - * This can be used to handle functions within tuples, references and ADTs since we don't - * attempt to track anything beyond 'the device' for expressions of those first-order types. * - * Throws \p Error on failure. + * Aborts if unification fails. */ - void Collapse(const DeviceDomainPtr& first_order_domain, - const DeviceDomainPtr& higher_order_domain); + void UnifyExprCollapsed(const Expr& expr_first_order, + const DeviceDomainPtr& expected_domain_maybe_higher_order); - /*! \brief Force all free domains in \p domain to default to \p default_device_type. */ - void SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type); + /*! \brief Returns true if \p domain is fully constrainted. */ + bool IsFullyConstrained(DeviceDomainPtr domain); + + /*! \brief Force all \p SEScopes in \p domain to default to \p default_se_scope. */ + void SetDefault(DeviceDomainPtr domain, const SEScope& default_se_scope); /*! - * \brief If \p domain is higher-order and its result domain is free, force it to - * \p default_device_type. Then force any remaining free domains to the result domain - * (freshly defaulted or original). If \p domain is first-order same as \p SetDefault. + * \brief If \p domain is higher-order default it's result domain to \p default_se_scope. + * Then force all remaining \p SEScopes to the result domain (freshly defaulted or original). + * If \p domain is first-order same as \p SetDefault. */ - void SetResultDefaultThenParams(const DeviceDomainPtr& domain, DLDeviceType default_device_type); - - /*! \brief Returns one-line description of \p domain for debugging. */ - std::string ToString(DeviceDomainPtr domain); - - /*! \brief Returns description of entire system of constraints for debugging */ - std::string ToString(); + void SetResultDefaultThenParams(const DeviceDomainPtr& domain_maybe_higher_order, + const SEScope& default_se_scope); /*! * \brief Returns the result domain for \p domain (see defn in DeviceDomain comment). @@ -262,13 +274,19 @@ class DeviceDomains { DeviceDomainPtr ResultDomain(DeviceDomainPtr domain); /*! - * \brief Returns the result (possibly free) device type for \p domain (see defn in DeviceDomain - * comment). + * \brief Returns the result \p SEScope (possibly unconstrained) for \p domain + * (see defn in DeviceDomain comment). */ - DLDeviceType ResultDeviceType(const DeviceDomainPtr& domain) { - return ResultDomain(domain)->first_order_device_type(); + SEScope ResultSEScope(const DeviceDomainPtr& domain) { + return ResultDomain(domain)->first_order_se_scope(); } + /*! \brief Returns one-line description of \p domain for debugging. */ + std::string ToString(DeviceDomainPtr domain); + + /*! \brief Returns description of entire system of constraints for debugging */ + std::string ToString(); + private: /*! \brief Intrinsics we need to handle specially. */ const Op& alloc_storage_op = Op::Get("memory.alloc_storage"); @@ -277,12 +295,14 @@ class DeviceDomains { const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op"); const Op& shape_func_op = Op::Get("vm.shape_func"); const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor"); - /*! \brief The CPU device type for special operators such as dynamic shape functions. */ - const DLDeviceType cpu_device_type_ = kDLCPU; - /*! \brief Placeholder for any first-order type. */ - Type arb_ = TupleType(); - /*! \brief The domain for first-order expressions on the CPU. */ - DeviceDomainPtr cpu_domain_ = ForDeviceType(arb_, cpu_device_type_); + + CompilationConfig config_; + + /*! + * \brief The domain for first-order expressions of non-tensor type, such as shapes and + * buffer dimensions. Generally this will be a CPU. + */ + DeviceDomainPtr host_domain_; /*! \brief Maps expressions to their domains as determined during analysis. */ std::unordered_map expr_to_domain_; @@ -293,8 +313,19 @@ class DeviceDomains { std::unordered_map call_to_callee_domain_; /*! \brief Maps device domains to their equivalent domains as determined during unification. */ - std::unordered_map - domain_to_equiv_; + std::unordered_map domain_to_equiv_; + + /*! + * \brief Maps fully constrained \p SEScopes to their corresponding domains. By sharing those + * domains we can ensure: + * + * \code + * domain0 != domain1 && domain0 fully constrained && domain1 fully constrained + * ==> domain0 and domain1 are incompatible + * \endcode + */ + std::unordered_map + fully_constrained_se_scope_to_domain_; }; } // namespace transform diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index 83429a9e616f0..d6ab566a336ee 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -19,22 +19,22 @@ /*! * \file src/relay/transforms/device_planner.cc - * \brief Determines a unique device to hold the result of every Relay sub-expression. + * \brief Determines a unique \p SEScope to hold the result of every Relay sub-expression. * * We say a Relay expression E is 'on device D' if the result of executing E is stored on D. - * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the - * specific target associated with D (this is recovered independently via a TargetMap), and we - * do not track the storage scope within D (this is yet to be implemented). + * We represent D by an \p SEScope, which means we can track anywhere from an arbitrary device + * of some \p DLDeviceType to a specific memory scope on a specific (virtual) \p Device who's + * code is compiled with a specific \p Target. * * Note that 'stored on device D' is almost but not quite the same as 'executes on device D', * see below. * * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes: - * - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and - * 'dst_dev_type' device type, which constrain the argument and context of the call + * - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_se_scope' and + * 'dst_se_scope' \p SEScopes, which constrain the argument and context of the call * respectively. It is ok if source and destination devices are the same, such no-op copies * will be removed after accounting for the device preference. - * - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which + * - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify an 'se_scope', which * constrains the argument of the call, but (usually, see below) leaves the context * unconstrained. These are called 'annotations' in the rest of the code, have no operational * significance by themselves, but may trigger the insertion of a new "device_copy". @@ -63,15 +63,16 @@ * ------- * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see * below) to all other Relay sub-expressions. (For idempotence we also respect any existing - * "param_device_types" and "result_device_type" function attributes we introduce below.) + * "param_se_scopes" and "result_se_scope" function attributes we introduce below.) * * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the * same device. However each call site can use a different device. In other words primitives are - * 'device polymorphic' since we compile and execute them for each required device. + * 'device polymorphic' since we compile and execute them for each required device. ADT constructors + * are similarly polymorphic. * * For most Relay expressions the device for the overall expression is the same as the device - * for it's sub-expressions. E.g. each field of a tuple must be on the same device as the tuple - * itself, the condition and arms of an \p if must all be on the same device as the overall if, + * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple + * itself, the condition and arms of an \p if must all be on the same device as the overall \p if, * and so on. * * Some special ops (or 'dialects') are handled: @@ -91,18 +92,18 @@ * * Phase 2 * ------- - * After flowing constraints we apply some defaulting heuristics (using a global default device) + * After flowing constraints we apply some defaulting heuristics (using a global default \p SEScope) * to fix the device for any as-yet unconstrained sub-expressions. * - Unconstrained function result devices default to the global default device. * - Unconstrained function parameters devices default to the device for the function result. * - Unconstrained let-bound expression devices default to the device for the overall let. - * TODO(mbs): I may have over-innovated here and we simply want to bind all free domaints to - * the global default device. Worth a design doc with motivating examples I think. + * TODO(mbs): These are very simple minded heuristics, and ultimately we'd like to treat the + * assignment of the remaining unconstrained sub-expressions as an optimiziation problem in itself. * * Phase 3 * ------- * Finally, the result of this analysis is reified into the result as: - * - Additional "param_device_types" (an Array) and "result_device_type" (Integer) + * - Additional "param_se_scopes" (an \p Array) and "result_se_scope" (an \p SEScope) * attributes for every function (both top-level and local). These describe the devices for * the function's parameters and the result. * - Additional "device_copy" CallNodes where a copy is required in order to respect the @@ -124,14 +125,15 @@ * passes must preserve the lexical scoping of the "on_device" CallNodes. E.g. conversion * to ANF must respect the lexical scoping convention: * \code - * f(on_device(g(h(a, b), c), device_type=CPU)) + * f(on_device(g(h(a, b), c), se_scope=CPU)) * ==> - * let %x0 = on_device(h(a, b), device_type=CPU) - * let %x1 = on_device(g(%x0), device-type=CPU) - * f(on_device(%x1, device_type=CPU)) + * let %x0 = on_device(h(a, b), se_scope=CPU) + * let %x1 = on_device(g(%x0), se_scope=CPU) + * f(on_device(%x1, se_scope=CPU)) * \endcode * * This pass can be run before FuseOps it can use device-specific fusion rules. + * TODO(mbs): We also need to support running after FuseOps. * * 'Stored on' vs 'Executes on' * ---------------------------- @@ -147,7 +149,7 @@ * pass, but we'd like to fold that into device planning here to ensure everything is consistent. * * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay - * expression (eg an if expression) on one device even though the tensor data resides on + * expression (eg an \p if expression) on one device even though the tensor data resides on * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on' * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just * compile the function body for the function's result device. @@ -157,7 +159,7 @@ * minimize cross-device calls by moving device copies out of functions. E.g.: * \code * def @f() { // execute on CPU - * let x = on_device(...GPU computation..., device_type=GPU); + * let x = on_device(...GPU computation..., se_scope=GPU); * device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU) * } * def @main() { @@ -189,7 +191,7 @@ * \code * let f = fn(x, y) { ... } * let g = fn(f, z) { f(z, z) } - * g(f, on_device(..., device_type=CPU)) + * g(f, on_device(..., se_scope=CPU)) * \endcode * the parameters \p x and \p y will be on the CPU. * @@ -226,28 +228,16 @@ * `-- Mark's stamp of completeness :-) * * TODO(mbs): - * * Though on_device is the identity for all types we can't wrap it around functions/constructors - * taking type args (or at least not without changing type_infer.cc to see through them). - * This is not currently handled generally. * * Proper diagnostics for unification failure using spans. - * * Make sure the pass is idempotent even after FuseOps etc. - * * Support application of constructors properly. Are they device polymorphic? - * * Replace DLDeviceType with TargetDevice, and unify 'target annotation' with 'device planning'. * * Support running the pass post FuseOps (so need to understand primitive functions, both * outlines and lined) and post the VM transforms (probably need to support more intrinsic * forms?). * * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default * device for primitives vs the default device for the rest of Relay. - * * We'll probably need some support for partial 'device polymorphism' for functions once we - * incorporate targets and memory scopes into the domain. For example it's ok for the function - * body to be executed on different device ids provided they have the same target and memory - * scope. - * * Might be simpler to just let every type have a device annotation rather than work in - * a separate domain? + * * We may want some 'device polymorphism' for Relay functions. Eg it's ok for the function + * to be called with params/result on different (virtual) device ids provided the target and + * memory scopes are consistent. * * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies. - * * The original device_annotation.cc RewriteAnnotatedOps removed all "on_device" calls - * in tuples at the top level of function bodies or main expression, irrespective of the - * "on_device" body. What's up with that? */ #include @@ -267,6 +257,7 @@ #include "../op/annotation/annotation.h" #include "../op/memory/device_copy.h" +#include "../op/memory/on_device.h" #include "./device_domains.h" namespace tvm { @@ -283,11 +274,11 @@ namespace { * \brief Rewrites "on_device" calls to handle some special cases. * * \code - * let %x = on_device(e, device_type=d) - * ==> let %x = on_device(e, device_type=d, is_fixed=True) + * let %x = on_device(e, se_scope=d) + * ==> let %x = on_device(e, se_scope=d, is_fixed=True) * - * fn(%x) { on_device(e, device_type=d) } - * ==> fn(%x) { on_device(e, device_type=d, is_fixed=True) + * fn(%x) { on_device(e, se_scope=d) } + * ==> fn(%x) { on_device(e, se_scope=d, is_fixed=True) * * on_device(e).0 * ==> on_device(e.0) @@ -303,12 +294,12 @@ class RewriteOnDevices : public ExprMutator { // TODO(mbs): Avoid copy. Expr tuple_get_item = TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span); - auto props = GetOnDeviceProps(tuple); + OnDeviceProps props = GetOnDeviceProps(tuple); if (props.body.defined() && !props.is_fixed) { - VLOG(1) << "wrapping tuple get item:" << std::endl + VLOG(2) << "wrapping tuple get item:" << std::endl << PrettyPrint(GetRef(tuple_get_item_node)) << std::endl - << "with \"on_device\" for device " << props.device_type; - return OnDevice(tuple_get_item, props.device_type, /*is_fixed=*/false); + << "with \"on_device\" for SEScope " << props.se_scope; + return OnDevice(tuple_get_item, props.se_scope, /*is_fixed=*/false); } else { return tuple_get_item; } @@ -320,12 +311,12 @@ class RewriteOnDevices : public ExprMutator { while (const auto* inner_let_node = expr.as()) { Expr inner_let = GetRef(inner_let_node); Expr value = VisitExpr(inner_let_node->value); - auto props = GetOnDeviceProps(value); + OnDeviceProps props = GetOnDeviceProps(value); if (props.body.defined() && !props.is_fixed) { - VLOG(1) << "revising let-bound expression of let:" << std::endl + VLOG(2) << "revising let-bound expression of let:" << std::endl << PrettyPrint(expr) << std::endl - << "to be fixed to device " << props.device_type; - value = OnDevice(props.body, props.device_type, /*is_fixed=*/true); + << "to be fixed to SEScope " << props.se_scope; + value = OnDevice(props.body, props.se_scope, /*is_fixed=*/true); } bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); expr = inner_let_node->body; @@ -341,12 +332,12 @@ class RewriteOnDevices : public ExprMutator { Expr VisitExpr_(const FunctionNode* function_node) final { Expr body = VisitExpr(function_node->body); - auto props = GetOnDeviceProps(body); + OnDeviceProps props = GetOnDeviceProps(body); if (props.body.defined() && !props.is_fixed) { - VLOG(1) << "revising body of function:" << std::endl + VLOG(2) << "revising body of function:" << std::endl << PrettyPrint(GetRef(function_node)) << std::endl - << "to be fixed to device " << props.device_type; - body = OnDevice(props.body, props.device_type, /*is_fixed=*/true); + << "to be fixed to SEScope " << props.se_scope; + body = OnDevice(props.body, props.se_scope, /*is_fixed=*/true); } // TODO(mbs): Avoid copy return Function(function_node->params, body, function_node->ret_type, @@ -363,12 +354,12 @@ class RewriteOnDevices : public ExprMutator { * It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter. * * Eg from \code add(%x, %y) \endcode we know \p %x and \p %y must be on the same device. Later, - * from \code on_device(%x, device_type=d) \endcode we know \p %x must be on device \p d, and thus + * from \code on_device(%x, se_scope=d) \endcode we know \p %x must be on device \p d, and thus * so must \p %y. * * Constraints can flow in interesting ways. E.g. in: * \code - * let %f = fn(%x, %y) { add(%x, on_device(%y, device_type=d)) } + * let %f = fn(%x, %y) { add(%x, on_device(%y, se_scope=d)) } * let %g = fn(%f, %x, %y) { %f(%x, %y) } * %g(%f, %a, %b) * \endcode @@ -376,8 +367,8 @@ class RewriteOnDevices : public ExprMutator { */ class DeviceAnalyzer : public ExprVisitor { public: - explicit DeviceAnalyzer(IRModule mod) - : mod_(std::move(mod)), domains_(std::make_unique()) {} + DeviceAnalyzer(IRModule mod, CompilationConfig config) + : mod_(std::move(mod)), domains_(std::make_unique(std::move(config))) {} /*! * \brief Returns the expression-to-device-domain map for all expressions in all the global @@ -387,7 +378,7 @@ class DeviceAnalyzer : public ExprVisitor { std::unique_ptr Analyze() { VLOG_CONTEXT << "DeviceAnalyzer"; for (const auto& pair : mod_->functions) { - VLOG(1) << "collecting constraints for '" << PrettyPrint(pair.first) << "'"; + VLOG(2) << "collecting constraints for '" << PrettyPrint(pair.first) << "'"; domains_->UnifyExprExact(pair.first, pair.second); VisitExpr(pair.second); } @@ -413,9 +404,9 @@ class DeviceAnalyzer : public ExprVisitor { } args_and_result_domains.emplace_back(domains_->DomainFor(call)); auto implied_domain = - DeviceDomains::MakeDomain(std::move(args_and_result_domains)); // higher-order + domains_->MakeHigherOrderDomain(std::move(args_and_result_domains)); // higher-order - VLOG(1) << "initial call function domain:" << std::endl + VLOG(2) << "initial call function domain:" << std::endl << domains_->ToString(func_domain) << std::endl << "and implied domain:" << std::endl << domains_->ToString(implied_domain) << std::endl @@ -423,21 +414,18 @@ class DeviceAnalyzer : public ExprVisitor { << PrettyPrint(call); // The above must match. - try { - domains_->Unify(func_domain, implied_domain); // higher-order - } catch (const Error& e) { + if (domains_->UnifyOrNull(func_domain, implied_domain) == nullptr) { // higher-order // TODO(mbs): Proper diagnostics. - LOG(FATAL) << "Function parameters and result devices do not match those of call. Call:" + LOG(FATAL) << "Function parameters and result SEScopes do not match those of call. Call:" << std::endl << PrettyPrint(call) << std::endl - << "with function devices:" << std::endl + << "with function scopes:" << std::endl << domains_->ToString(func_domain) << std::endl - << "and implied call devices:" << std::endl - << domains_->ToString(implied_domain) << std::endl - << e.what(); + << "and implied call scopes:" << std::endl + << domains_->ToString(implied_domain); } - VLOG(1) << "final call function domain:" << std::endl + VLOG(2) << "final call function domain:" << std::endl << domains_->ToString(func_domain) << std::endl << "for call:" << std::endl << PrettyPrint(call); @@ -477,7 +465,7 @@ class DeviceAnalyzer : public ExprVisitor { domains_->UnifyExprExact(function_node->body, func_domain->function_result()); // may be higher-order - VLOG(1) << "initial function domain:" << std::endl + VLOG(2) << "initial function domain:" << std::endl << domains_->ToString(func_domain) << std::endl << "and function body domain:" << std::endl << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl @@ -492,37 +480,33 @@ class DeviceAnalyzer : public ExprVisitor { VisitExpr(function_node->params[i]); } - // If the function already has device attributes then we can further constrain the + // If the function already has SEScope attributes then we can further constrain the // function's domain to match them. - if (GetFunctionResultDeviceType(function_node) != kInvalidDeviceType) { + if (!GetFunctionResultSEScope(function_node)->IsFullyUnconstrained()) { std::vector args_and_result; for (size_t i = 0; i < function_node->params.size(); ++i) { - args_and_result.emplace_back( - domains_->ForDeviceType(function_node->params[i]->checked_type(), - GetFunctionParamDeviceType(function_node, i))); + args_and_result.emplace_back(domains_->ForSEScope( + function_node->params[i]->checked_type(), GetFunctionParamSEScope(function_node, i))); } - args_and_result.emplace_back(domains_->ForDeviceType( - function_node->body->checked_type(), GetFunctionResultDeviceType(function_node))); - auto annotation_domain = domains_->MakeDomain(std::move(args_and_result)); - try { - domains_->Unify(func_domain, annotation_domain); // higher-order - } catch (const Error& e) { + args_and_result.emplace_back(domains_->ForSEScope(function_node->body->checked_type(), + GetFunctionResultSEScope(function_node))); + auto annotation_domain = domains_->MakeHigherOrderDomain(std::move(args_and_result)); + if (domains_->UnifyOrNull(func_domain, annotation_domain) == nullptr) { // higher-order // TODO(mbs): Proper diagnostics. LOG(FATAL) - << "Function devices are incompatible with its \"on_device\" annotation. Function:" + << "Function SEScopes are incompatible with its \"on_device\" annotation. Function:" << std::endl << PrettyPrint(function) << std::endl - << "with function devices:" << std::endl + << "with function scopes:" << std::endl << domains_->ToString(func_domain) << std::endl - << "and annotation devices:" << std::endl - << domains_->ToString(annotation_domain) << std::endl - << e.what(); + << "and annotation scopes:" << std::endl + << domains_->ToString(annotation_domain); } } VisitExpr(function_node->body); - VLOG(1) << "final function domain:" << std::endl + VLOG(2) << "final function domain:" << std::endl << domains_->ToString(func_domain) << std::endl << "and function body domain:" << std::endl << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl @@ -652,7 +636,7 @@ class DeviceAnalyzer : public ExprVisitor { * \code * def @main(%x, %y, %z) { * let %a = add(%x, %y); - * multiply(%a, on_device(%z, device_type=d)) + * multiply(%a, on_device(%z, se_scope=d)) * \endcode * we know the parameter \p %z must be on device \p d, but the devices for \p %x and \p %y, * and the device for the function result, are still 'free'. The global 'default' device type @@ -664,17 +648,14 @@ class DeviceAnalyzer : public ExprVisitor { */ class DeviceDefaulter : public ExprVisitor { public: - DeviceDefaulter(IRModule mod, std::unique_ptr domains, - DLDeviceType default_device_type) - : mod_(std::move(mod)), - domains_(std::move(domains)), - default_device_type_(default_device_type) {} + DeviceDefaulter(IRModule mod, std::unique_ptr domains) + : mod_(std::move(mod)), domains_(std::move(domains)) {} std::unique_ptr Default() { VLOG_CONTEXT << "DeviceDefaulter"; - VLOG(0) << "using default device type " << default_device_type_; + VLOG(0) << "defaulting to SEScope " << domains_->config()->default_primitive_se_scope; for (const auto& pair : mod_->functions) { - VLOG(1) << "defaulting devices for '" << PrettyPrint(pair.first) << "'"; + VLOG(2) << "defaulting devices for '" << PrettyPrint(pair.first) << "'"; VisitExpr(pair.second); } return std::move(domains_); @@ -689,10 +670,11 @@ class DeviceDefaulter : public ExprVisitor { auto function = GetRef(function_node); auto func_domain = domains_->DomainFor(function); // higher-order ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); - if (domains_->AnyFree(func_domain)) { - VLOG(1) << "before defaulting function:" << std::endl << domains_->ToString(func_domain); - domains_->SetResultDefaultThenParams(func_domain, default_device_type_); - VLOG(1) << "after defaulting function:" << std::endl << domains_->ToString(func_domain); + if (!domains_->IsFullyConstrained(func_domain)) { + VLOG(2) << "before defaulting function:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, + domains_->config()->default_primitive_se_scope); + VLOG(2) << "after defaulting function:" << std::endl << domains_->ToString(func_domain); } VisitExpr(function_node->body); } @@ -701,13 +683,14 @@ class DeviceDefaulter : public ExprVisitor { auto call = GetRef(call_node); auto func_domain = domains_->DomainForCallee(call); // higher-order ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); - if (domains_->AnyFree(func_domain)) { + if (!domains_->IsFullyConstrained(func_domain)) { // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*) // above. But for calls to primitives we may still need to force free domains to be // defaulted. - VLOG(1) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain); - domains_->SetResultDefaultThenParams(func_domain, default_device_type_); - VLOG(1) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain); + VLOG(2) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, + domains_->config()->default_primitive_se_scope); + VLOG(2) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain); } return ExprVisitor::VisitExpr_(call_node); } @@ -719,13 +702,13 @@ class DeviceDefaulter : public ExprVisitor { Let let = Downcast(expr); // If the let-var device is still free force it to match the overall let. auto let_domain = domains_->DomainFor(let); // may be higher-order - DLDeviceType let_device_type = domains_->ResultDeviceType(let_domain); - ICHECK_NE(let_device_type, kInvalidDeviceType); + SEScope let_se_scope = domains_->ResultSEScope(let_domain); + ICHECK(!let_se_scope->IsFullyUnconstrained()); auto let_var_domain = domains_->DomainFor(let->var); // may be higher-order - if (domains_->AnyFree(let_var_domain)) { - VLOG(1) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); - domains_->SetDefault(let_var_domain, let_device_type); - VLOG(1) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + if (!domains_->IsFullyConstrained(let_var_domain)) { + VLOG(2) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + domains_->SetDefault(let_var_domain, let_se_scope); + VLOG(2) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); } VisitExpr(let->var); VisitExpr(let->value); @@ -738,8 +721,6 @@ class DeviceDefaulter : public ExprVisitor { IRModule mod_; /*! \brief The domains for all expressions. */ std::unique_ptr domains_; - /*! \brief The default device type. */ - DLDeviceType default_device_type_; }; /****** @@ -754,7 +735,7 @@ class DeviceDefaulter : public ExprVisitor { * - Discard any existing "on_device" CallNodes since their job is done. Similarly, discard * any existing "device_copy" CallNodes which are no-ops. * - * - Functions are given "param_device_types" and "result_device_type" attributes to capture + * - Functions are given "param_se_scopes" and "result_se_scope" attributes to capture * the device type for its parameters and result. * * - Additional "device_copy" CallNodes are inserted wherever there's a transition between @@ -773,10 +754,10 @@ class DeviceDefaulter : public ExprVisitor { * * For example, we'll end up with programs that look like: * \code - * def @main(%x, %y, param_device_types=[...], result_device_type=...) { - * let %a = on_device(..., device_type=..., is_fixed=True) - * @f(%a, device_copy(on_device(..., device_type=..., is_fixed=True), - * src_device_type=..., dst_device_type=...)) + * def @main(%x, %y, param_se_scopes=[...], result_se_scope=...) { + * let %a = on_device(..., se_scope=..., is_fixed=True) + * @f(%a, device_copy(on_device(..., se_scope=..., is_fixed=True), + * src_se_scope=..., dst_se_scope=...)) * } * \endcode */ @@ -789,7 +770,7 @@ class DeviceCapturer : public ExprMutator { VLOG_CONTEXT << "CaptureDevices"; IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map); for (const auto& pair : mod_->functions) { - VLOG(1) << "capturing devices for '" << PrettyPrint(pair.first) << "'"; + VLOG(2) << "capturing devices for '" << PrettyPrint(pair.first) << "'"; result->Add(pair.first, Downcast(Mutate(pair.second))); } return result; @@ -816,39 +797,39 @@ class DeviceCapturer : public ExprMutator { auto function = GetRef(function_node); auto func_domain = domains_->DomainFor(function); // higher-order - VLOG(1) << "capturing function:" << std::endl + VLOG(2) << "capturing function:" << std::endl << PrettyPrint(function) << std::endl << "with domain:" << std::endl << domains_->ToString(func_domain); // Gather the parameter and result device types for the function attributes. ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); - DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain); - ICHECK_NE(result_device_type, kInvalidDeviceType); - Array param_device_types; - param_device_types.reserve(function_node->params.size()); + SEScope result_se_scope = domains_->ResultSEScope(func_domain); + ICHECK(!result_se_scope->IsFullyUnconstrained()); + Array param_se_scopes; + param_se_scopes.reserve(function_node->params.size()); for (size_t i = 0; i < function_node->params.size(); ++i) { - DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i)); - ICHECK_NE(param_device_type, kInvalidDeviceType); - param_device_types.push_back(param_device_type); + SEScope param_se_scope = domains_->ResultSEScope(func_domain->function_param(i)); + ICHECK(!param_se_scope->IsFullyUnconstrained()); + param_se_scopes.push_back(param_se_scope); } // Rewrite the body. Note that the body may have begun with an "on_device" so // be prepared to insert a "device_copy". Expr body = VisitChild( - /*lexical_device_type=*/result_device_type, - /*expected_device_type=*/result_device_type, - /*child_device_type=*/GetDeviceType(function_node->body), function_node->body); + /*lexical_se_scope=*/result_se_scope, + /*expected_se_scope=*/result_se_scope, + /*child_se_scope=*/GetSEScope(function_node->body), function_node->body); // TODO(mbs): Avoid copy Function func = Function(function_node->params, body, function_node->ret_type, function_node->type_params, function_node->attrs, function_node->span); - return FunctionOnDevice(func, param_device_types, result_device_type); + return FunctionOnDevice(func, std::move(param_se_scopes), std::move(result_se_scope)); } Expr VisitExpr_(const CallNode* call_node) final { auto call = GetRef(call_node); - DLDeviceType call_device_type = GetDeviceType(call); + SEScope call_se_scope = GetSEScope(call); auto on_device_props = GetOnDeviceProps(call_node); if (on_device_props.body.defined()) { @@ -857,31 +838,36 @@ class DeviceCapturer : public ExprMutator { return VisitExpr(on_device_props.body); } - auto device_copy_props = GetDeviceCopyProps(call_node); + DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node); if (device_copy_props.body.defined()) { - DLDeviceType src_device_type = device_copy_props.src_dev_type; - ICHECK_EQ(call_device_type, device_copy_props.dst_dev_type); - if (call_device_type == src_device_type) { + SEScope src_se_scope = domains_->config()->CanonicalSEScope(device_copy_props.src_se_scope); + SEScope dst_se_scope = domains_->config()->CanonicalSEScope(device_copy_props.dst_se_scope); + ICHECK_EQ(call_se_scope, dst_se_scope); + if (src_se_scope == dst_se_scope) { // We can pinch out existing "device_copy" CallNodes if their source and destinations // match. return VisitExpr(device_copy_props.body); + } else { + return VisitChild(/*lexical_se_scope=*/dst_se_scope, + /*expected_se_scope=*/dst_se_scope, + /*child_se_scope=*/src_se_scope, device_copy_props.body); } - // else: handle as for any other call. } + // Generic call. auto func_domain = domains_->DomainForCallee(call); // higher-order - VLOG(1) << "considering call:" << std::endl + VLOG(2) << "considering call:" << std::endl << PrettyPrint(call) << std::endl - << "on device " << call_device_type << " with function domain:" << std::endl + << "in scope " << call_se_scope << " with function domain:" << std::endl << domains_->ToString(func_domain); - DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain); - ICHECK_NE(result_device_type, kInvalidDeviceType); + SEScope result_se_scope = domains_->ResultSEScope(func_domain); + ICHECK(!result_se_scope->IsFullyUnconstrained()); // The callee is on the current device. Expr op = VisitChild( - /*lexical_device_type=*/call_device_type, - /*expected_device_type=*/call_device_type, - /*child_device_type=*/result_device_type, call_node->op); + /*lexical_se_scope=*/call_se_scope, + /*expected_se_scope=*/call_se_scope, + /*child_se_scope=*/result_se_scope, call_node->op); // Each argument can be on the device for the corresponding function parameter. However if // any of those differ from the overall call device then wrap them in an "on_device" to @@ -890,13 +876,13 @@ class DeviceCapturer : public ExprMutator { args.reserve(call_node->args.size()); ICHECK_EQ(func_domain->function_arity(), call->args.size()); for (size_t i = 0; i < call_node->args.size(); ++i) { - DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i)); - ICHECK_NE(param_device_type, kInvalidDeviceType) + SEScope param_se_scope = domains_->ResultSEScope(func_domain->function_param(i)); + ICHECK(!param_se_scope->IsFullyUnconstrained()) << "for parameter " << i << " for call:" << std::endl << PrettyPrint(call); - args.push_back(VisitChild(/*lexical_device_type=*/call_device_type, - /*expected_device_type=*/param_device_type, - /*child_device_type=*/GetDeviceType(call_node->args[i]), + args.push_back(VisitChild(/*lexical_se_scope=*/call_se_scope, + /*expected_se_scope=*/param_se_scope, + /*child_se_scope=*/GetSEScope(call_node->args[i]), call_node->args[i])); } // TODO(mbs): Avoid copy @@ -907,27 +893,27 @@ class DeviceCapturer : public ExprMutator { Expr VisitExpr_(const LetNode* let_node) final { Expr expr = GetRef(let_node); // Iterate through chained lets, provided they all agree on their device type. - DLDeviceType let_device_type = GetDeviceType(expr); + SEScope let_se_scope = GetSEScope(expr); std::vector> bindings; while (const auto* inner_let_node = expr.as()) { Expr inner_let = GetRef(inner_let_node); - if (GetDeviceType(inner_let) != let_device_type) { + if (GetSEScope(inner_let) != let_se_scope) { // We have a device transition which needs to be handled. break; } // The let-bound value can be on a different device than the overall let. However if those // devices don't agree wrap the let-bound value in an "on_device" to help downstream // transforms track devices lexically. - Expr value = VisitChild(/*lexical_device_type=*/let_device_type, - /*expected_device_type=*/GetDeviceType(inner_let_node->var), - /*child_device_type=*/GetDeviceType(inner_let_node->value), - inner_let_node->value); + Expr value = + VisitChild(/*lexical_se_scope=*/let_se_scope, + /*expected_se_scope=*/GetSEScope(inner_let_node->var), + /*child_se_scope=*/GetSEScope(inner_let_node->value), inner_let_node->value); bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); expr = inner_let_node->body; } - Expr body = VisitChild(/*lexical_device_type=*/let_device_type, - /*expected_device_type=*/let_device_type, - /*child_device_type=*/GetDeviceType(expr), expr); + Expr body = VisitChild(/*lexical_se_scope=*/let_se_scope, + /*expected_se_scope=*/let_se_scope, + /*child_se_scope=*/GetSEScope(expr), expr); for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { body = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), body, /*span=*/std::get<2>(*itr)); @@ -987,69 +973,69 @@ class DeviceCapturer : public ExprMutator { return Match(data, std::move(clauses), match_node->complete, match_node->span); } - DLDeviceType GetDeviceType(const Expr& expr) { + SEScope GetSEScope(const Expr& expr) { // Look through any "on_device" CallNodes, to mimic how we will be pinching them out. - auto props = GetOnDeviceProps(expr); + OnDeviceProps props = GetOnDeviceProps(expr); Expr true_expr = props.body.defined() ? props.body : expr; ICHECK(domains_->contains(true_expr)); - // If expr is higher order we'll return only the result domain's device type. - DLDeviceType device_type = domains_->ResultDeviceType(domains_->DomainFor(true_expr)); - ICHECK_NE(device_type, kInvalidDeviceType) - << "no device type was determined for expression:" << std::endl + // If expr is higher order we'll return only the result domain's SEScope. + SEScope se_scope = domains_->ResultSEScope(domains_->DomainFor(true_expr)); + ICHECK(!se_scope->IsFullyUnconstrained()) + << "no SEScope was determined for expression:" << std::endl << PrettyPrint(true_expr); - return device_type; + return std::move(se_scope); } /*! - * \brief Reconcile the \p child_device_type for \p child with both the \p expected_device_type - * (as required by the expression context the \p child is in) and the \p lexical_device_type + * \brief Reconcile the \p child_se_scope for \p child with both the \p expected_se_scope + * (as required by the expression context the \p child is in) and the \p lexical_se_scope * (as a downstream transform would infer based only on lexically enclosing "on_device" - * CallNodes and function attributes.) Generally \p lexical_device_type and \p - * expected_device_type are the same by definition, but may differ in arguments to functions + * CallNodes and function attributes.) Generally \p lexical_se_scope and \p + * expected_se_scope are the same by definition, but may differ in arguments to functions * and let-bound expressions. * - * If \p child_device_type differs from \p expected_device_type, wrap it as: + * If \p child_se_scope differs from \p expected_se_scope, wrap it as: * \code - * device_copy(on_device(child', device_type=child_device_type), - * src_dev_type=child_device_type, dst_dev_type=expected_device_type) + * device_copy(on_device(child', se_scope=child_se_scope), + * src_dev_type=child_se_scope, dst_dev_type=expected_se_scope) * \endcode * (where child is rewritten to child'). Note the pedantic spelling out of "on_device" on the * child. * - * If \p expected_device_type differs from \p lexical_device_type, then (also) wrap + * If \p expected_se_scope differs from \p lexical_se_scope, then (also) wrap * the expression as: * \code - * on_device(..., device_type=expected_device_type) + * on_device(..., se_scope=expected_se_scope) * \endcode * * TODO(mbs): There's no attempt at sharing here. If usage of child's node could be wrapped * by a "device_copy", even though those copies will generally all be to the same destination * device. */ - Expr VisitChild(DLDeviceType lexical_device_type, DLDeviceType expected_device_type, - DLDeviceType child_device_type, const Expr& child) { - ICHECK_NE(lexical_device_type, kInvalidDeviceType); - ICHECK_NE(expected_device_type, kInvalidDeviceType); - if (child->IsInstance()) { - // Primitive operators don't need to be rewritten and can have a different domain for - // each call site. + Expr VisitChild(const SEScope& lexical_se_scope, const SEScope& expected_se_scope, + const SEScope& child_se_scope, const Expr& child) { + ICHECK(!lexical_se_scope->IsFullyUnconstrained()); + ICHECK(!expected_se_scope->IsFullyUnconstrained()); + if (child->IsInstance() || child->IsInstance()) { + // Primitive operators and contructors don't need to be rewritten and can have a + // different domain at each call site. return child; } Expr result = VisitExpr(child); - if (child_device_type != expected_device_type) { - VLOG(1) << "creating " << DeviceCopyOp()->name << " from device type " << child_device_type - << " to device type " << expected_device_type << " for:" << std::endl + if (child_se_scope != expected_se_scope) { + VLOG(2) << "creating " << DeviceCopyOp()->name << " from scope " << child_se_scope + << " to scope " << expected_se_scope << " for:" << std::endl << PrettyPrint(result); // Also wrap the child in an "on_device" so downstream transforms can track devices // lexically. - result = MaybeOnDevice(result, child_device_type, /*is_fixed=*/true); - result = DeviceCopy(result, child_device_type, expected_device_type); + result = MaybeOnDevice(result, child_se_scope, /*is_fixed=*/true); + result = DeviceCopy(result, child_se_scope, expected_se_scope); } - if (expected_device_type != lexical_device_type) { - VLOG(1) << "creating " << OnDeviceOp()->name << " for device type " << expected_device_type + if (expected_se_scope != lexical_se_scope) { + VLOG(2) << "creating " << OnDeviceOp()->name << " for scope " << expected_se_scope << " for:" << std::endl << PrettyPrint(result); - result = MaybeOnDevice(result, expected_device_type, /*is_fixed=*/true); + result = MaybeOnDevice(result, expected_se_scope, /*is_fixed=*/true); } return result; } @@ -1059,9 +1045,9 @@ class DeviceCapturer : public ExprMutator { * is expected to be on the same device as the \p parent. */ Expr VisitChild(const Expr& parent, const Expr& child) { - DLDeviceType expected_device_type = GetDeviceType(parent); - DLDeviceType child_device_type = GetDeviceType(child); - return VisitChild(expected_device_type, expected_device_type, child_device_type, child); + SEScope expected_se_scope = GetSEScope(parent); + SEScope child_se_scope = GetSEScope(child); + return VisitChild(expected_se_scope, expected_se_scope, child_se_scope, child); } /*! \brief Module we are rewriting, so we can lookup global variables. */ @@ -1079,21 +1065,22 @@ tvm::transform::Pass Rewrite() { } /*! \brief Run the remaining phases. */ -tvm::transform::Pass PlanDevicesCore(DLDeviceType default_device_type) { +tvm::transform::Pass PlanDevicesCore(CompilationConfig config) { return tvm::transform::CreateModulePass( - [=](IRModule mod, tvm::transform::PassContext pass_cnxt) -> IRModule { + [config = std::move(config)](IRModule mod, + tvm::transform::PassContext pass_cnxt) -> IRModule { // Collect the system of constraints for every sub-expression using existing "on_device" // and "device_copy" calls. - std::unique_ptr domains = DeviceAnalyzer(mod).Analyze(); - VLOG(1) << "Domains after analysis:" << std::endl << domains->ToString(); + std::unique_ptr domains = DeviceAnalyzer(mod, config).Analyze(); + VLOG(3) << "Domains after analysis:" << std::endl << domains->ToString(); // Choose sensible default devices for every sub-expression if otherwise unconstrained // by existing "on_device" or "device_copy" calls. - domains = DeviceDefaulter(mod, std::move(domains), default_device_type).Default(); - VLOG(1) << "Domains after defaulting: " << std::endl << domains->ToString(); + domains = DeviceDefaulter(mod, std::move(domains)).Default(); + VLOG(3) << "Domains after defaulting: " << std::endl << domains->ToString(); // Insert "device_copy" and "on_device" CallNodes where needed to unambiguously capture - // the above map, and attach additional "param_device_types" and "result_device_type" + // the above map, and attach additional "param_se_scopes" and "result_se_scope" // attributes to all function definitions. return DeviceCapturer(mod, std::move(domains)).Capture(); }, @@ -1107,17 +1094,14 @@ tvm::transform::Pass PlanDevicesCore(DLDeviceType default_device_type) { *******/ // This function is declared in the public . -TVM_DLL tvm::transform::Pass PlanDevices(DLDeviceType default_device_type) { +tvm::transform::Pass PlanDevices(CompilationConfig config) { std::vector passes; passes.emplace_back(Rewrite()); - passes.emplace_back(PlanDevicesCore(default_device_type)); - return tvm::transform::Sequential(std::move(passes), "PlanDevices"); + passes.emplace_back(PlanDevicesCore(std::move(config))); + return tvm::transform::Sequential(passes, "PlanDevices"); } -TVM_REGISTER_GLOBAL("relay._transform.PlanDevices") - .set_body_typed([](const Device& default_device) { - return PlanDevices(default_device.device_type); - }); +TVM_REGISTER_GLOBAL("relay._transform.PlanDevices").set_body_typed(PlanDevices); } // namespace transform } // namespace relay diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index c48a9b30967c6..05ee9d5ad5921 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -31,8 +31,7 @@ #include #include -#include "../op/annotation/annotation.h" -#include "./device_aware_visitors.h" +#include "../op/memory/on_device.h" #include "./pattern_utils.h" namespace tvm { @@ -42,7 +41,7 @@ namespace transform { namespace { /*! * \brief Returns whether \p expr is a literal \p Constant, optionally wrapped by an "on_device" - * annotation CallNode (which serves only to associate a device to the constant and has no + * annotation CallNode (which serves only to associate an \p SEScope to the constant and has no * operational effect). */ bool IsSimpleConstant(const Expr& expr) { @@ -87,19 +86,19 @@ class ConstantFolder : public MixedModeMutator { // the variable. // // We need to retain any "on_device" annotation so that downstream 'device aware' - // passes can still retrieve the device for the constant in its new position(s). Eg: - // def @f(..., result_device_type=D) { - // let %x = on_device(... something we eval to a constant..., device_type=E) + // passes can still retrieve the \p SEScope for the constant in its new position(s). Eg: + // def @f(..., result_se_scope=D) { + // let %x = on_device(... something we eval to a constant..., se_scope=E) // @f(..., %x, ...) // } - // Here the default device is D, whereas the argument %x to @f is on E (and @f expects + // Here the default scope is D, whereas the argument %x to @f is on E (and @f expects // that). No on_device annotation is required in the call according to the convention used // by the device-aware visitors. // // However once we've inlined the constant we need to insert an on_device, again to // respect the convention used by the device-aware visitors. - // def @f(..., result_device_type=D) { - // @f(..., on_device(...the constant..., device_type=E), ...) + // def @f(..., result_se_scope=D) { + // @f(..., on_device(...the constant..., se_scope=E), ...) // } VLOG(1) << "Replacing let-binding for " << op->var->name_hint() << " with constant:" << std::endl @@ -215,8 +214,8 @@ class ConstantFolder : public MixedModeMutator { Expr result = tuple_node->fields[tuple_get_item_node->index]; OnDeviceProps props = GetOnDeviceProps(post_tuple_get_item_node->tuple); if (props.body.defined()) { - // (on_device((x, y, z), device_type=D).1 ==> on_device(y, device_type=D) - return MaybeOnDevice(result, props.device_type, props.is_fixed); + // (on_device((x, y, z), se_scope=D).1 ==> on_device(y, se_scope=D) + return MaybeOnDevice(result, props.se_scope, props.is_fixed); } else { return result; } @@ -248,19 +247,15 @@ class ConstantFolder : public MixedModeMutator { VLOG(1) << "Evaluating :" << std::endl << PrettyPrint(expr); // We'll invoke the interpreter using the generic CPU device and target. Technically there's - // no guarantee the results we bitwise equal what we'd get on the true device, however to + // no guarantee the results will be bitwise equal what we'd get on the true device, however to // support cross-compilation we don't want to assume the true device is available. - Device dev; - dev.device_type = kDLCPU; - dev.device_id = 0; - Target target = Target("llvm"); // Use a fresh build context in case we are already in a build context. // needed for both execution and creation(due to JIT) With fresh_build_ctx(transform::PassContext::Create()); - Expr result = - ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), dev, target)); + Expr result = ObjectToExpr( + Eval(expr, module_->type_definitions, module_->Imports(), eval_cpu_dev_, eval_cpu_target_)); VLOG(1) << "Evaluated to constant:" << std::endl << PrettyPrint(result); return result; } @@ -288,17 +283,14 @@ class ConstantFolder : public MixedModeMutator { } // Get the constant shape - Device dev; - dev.device_type = kDLCPU; - dev.device_id = 0; runtime::NDArray value; DLDataType cdtype = DataType::Int(32); if (ishape.empty()) { - value = runtime::NDArray::Empty({}, cdtype, dev); + value = runtime::NDArray::Empty({}, cdtype, eval_cpu_dev_); } else { ICHECK_NE(ishape.size(), 0); std::vector cshape = {static_cast(ishape.size())}; - value = runtime::NDArray::Empty(cshape, cdtype, dev); + value = runtime::NDArray::Empty(cshape, cdtype, eval_cpu_dev_); auto* dims = static_cast(value->data); using ::tvm::tir::IntImmNode; for (size_t i = 0; i < ishape.size(); ++i) { @@ -313,7 +305,7 @@ class ConstantFolder : public MixedModeMutator { Constant shape = Downcast(ObjectToExpr(value)); if (shape->data.Shape().empty() && GetScalarFromConstant(shape) == 0) { - auto ndarray = runtime::NDArray::Empty({}, cdtype, dev); + auto ndarray = runtime::NDArray::Empty({}, cdtype, eval_cpu_dev_); shape = Constant(ndarray); } @@ -342,12 +334,9 @@ class ConstantFolder : public MixedModeMutator { } // Get the constant size - Device dev; - dev.device_type = kDLCPU; - dev.device_id = 0; runtime::NDArray value; DLDataType cdtype = DataType::Int(32); - value = runtime::NDArray::Empty({}, cdtype, dev); + value = runtime::NDArray::Empty({}, cdtype, eval_cpu_dev_); auto* data = static_cast(value->data); if (ishape.empty()) { *data = 0; @@ -390,6 +379,13 @@ class ConstantFolder : public MixedModeMutator { // Module IRModule module_; + // The kDLCPU device assumed to be available to the compiler. Used only when evaluating + // sub-expressions. + Device eval_cpu_dev_{kDLCPU, /*device_id=*/0}; + // The target for the above device assumed to be available to the compiler. Used only when + // evaluating sub-expressions. + Target eval_cpu_target_{"llvm"}; + // Cache the following ops for equivalence checking in this pass. const Op& device_copy_op_; const Op& shape_of_op_; diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 81d704e2be8ed..a651a063d4182 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -57,17 +57,6 @@ using namespace tvm::runtime; namespace tvm { namespace relay { -inline Constant MakeConstant(const std::vector& value) { - return MakeConstantTensor(DataType::Int(64), {static_cast(value.size())}, value); -} - -inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype, - Array assert_shape, DLDeviceType offset_device_type) { - auto offset = - OnDevice(MakeConstantScalar(DataType::Int(64), 0), offset_device_type, /*is_fixed=*/true); - return AllocTensor(storage, offset, shape, dtype, assert_shape); -} - // Check if the primitive function contains only reshape ops. bool IsReshapeOnly(const Expr& expr) { if (const FunctionNode* func = expr.as()) { @@ -87,11 +76,13 @@ bool IsReshapeOnly(const Expr& expr) { class DialectRewriter : public transform::DeviceAwareExprMutator { public: - DialectRewriter(IRModule mod, const Target& target_host) - : transform::DeviceAwareExprMutator(std::move(mod)), target_host_(target_host) {} + DialectRewriter(IRModule mod, SEScope host_se_scope) + : transform::DeviceAwareExprMutator(std::move(mod)), + host_se_scope_(std::move(host_se_scope)) {} Function Rewrite(const Function& expr) { return Downcast(Mutate(expr)); } + private: Expr VisitExpr_(const TupleNode* tn) final { LetList& scope = scopes_.back(); Array new_fields; @@ -130,7 +121,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { Expr DeviceAwareVisitExpr_(const CallNode* cn) final { Call call = GetRef(cn); - DLDeviceType device_type = GetInScopeDeviceType(call); + SEScope se_scope = GetSEScope(call); if (IsPrimitive(cn)) { // Because we are in ANF we do not need to visit the arguments. // TODO(mbs): But does so anyway... @@ -162,26 +153,21 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { } const DeviceCopyAttrs* copy_attr = attr.as(); CHECK(copy_attr); - return DeviceCopy(new_args[0], copy_attr->src_dev_type, copy_attr->dst_dev_type); + return DeviceCopy(new_args[0], copy_attr->src_se_scope, copy_attr->dst_se_scope); } else if (IsDynamic(ret_type)) { Function func = Downcast(cn->op); - // TODO(mbs): Device id is always zero. - Device device{device_type, /*device_id=*/0}; - return DynamicInvoke(&scope, func, ins, new_args, out_types, ret_type, device); + return DynamicInvoke(&scope, func, ins, new_args, out_types, ret_type, se_scope); } else { // Handle the static case Array outs; for (size_t i = 0; i < out_types.size(); ++i) { - DLDeviceType device_type = GetInScopeDeviceType(GetRef(cn)); - // TODO(mbs): Device id is always zero. - Device device{device_type, /*device_id=*/0}; - auto out = MakeStaticAllocation(&scope, out_types[i], device, std::to_string(i)); + auto out = MakeStaticAllocation(&scope, out_types[i], se_scope, std::to_string(i)); outs.push_back(out); } Tuple output(outs); // TODO(mbs): Capture device in attributes. Expr invoke = InvokeTVMOp(cn->op, ins, output); - scope.Push(OnDevice(invoke, device_type, /*is_fixed=*/true)); + scope.Push(OnDevice(invoke, se_scope, /*is_fixed=*/true)); return ToTupleType(ret_type, std::vector(output->fields.begin(), output->fields.end())); } @@ -190,11 +176,26 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { } } - private: + /*! Returns the Relay Constant representing the 1d tensor with \p value. + * + * CAUTION: Make sure the constant ends up on the correct device. + */ + inline Constant MakeConstant(const std::vector& value) { + return MakeConstantTensor(DataType::Int(64), {static_cast(value.size())}, value); + } + + /*! Returns an \p alloc_tensor call for a tensor of \p shape and \p dtype over \p storage. */ + inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype, + Array assert_shape) { + Expr offset = OnDevice(MakeConstantScalar(DataType::Int(64), 0), host_se_scope_, + /*is_fixed=*/true); + return tvm::relay::AllocTensor(storage, std::move(offset), std::move(shape), dtype, + assert_shape); + } + // Insert a device copy node. - Expr DeviceCopy(const Expr& inp, int src_dev, int dst_dev) { - return Mutate(relay::DeviceCopy(inp, static_cast(src_dev), - static_cast(dst_dev))); + Expr DeviceCopy(const Expr& inp, SEScope src_se_scope, SEScope dst_se_scope) { + return Mutate(relay::DeviceCopy(inp, std::move(src_se_scope), std::move(dst_se_scope))); } // Check if a call invokes a primitive function. @@ -249,28 +250,28 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { } // Allocate a tensor with a statically known shape. - Var MakeStaticAllocation(LetList* scope, const TensorType& type, Device dev, String name_hint) { + Var MakeStaticAllocation(LetList* scope, const TensorType& type, const SEScope& se_scope, + String name_hint) { std::vector int_shape; for (auto it : type->shape) { const auto* imm = it.as(); CHECK(imm) << "expect static int shape"; int_shape.push_back(imm->value); } - Expr shape = OnDevice(MakeConstant(int_shape), cpu_device_.device_type, /*is_fixed=*/true); - Expr size = OnDevice(ComputeStorage(type), cpu_device_.device_type, /*is_fixed=*/true); + Expr shape = OnDevice(MakeConstant(int_shape), host_se_scope_, /*is_fixed=*/true); + Expr size = OnDevice(ComputeStorage(type), host_se_scope_, /*is_fixed=*/true); // Alignment is directly captured in the instruction rather than calculated, so we // don't want to wrap it with an "on_device". Expr alignment = ComputeAlignment(type->dtype); // Run type inference later to get the correct type. Var var("storage_" + name_hint, Type(nullptr)); - Expr value = OnDevice(AllocStorage(size, alignment, dev, type->dtype), dev.device_type, + Expr value = OnDevice(AllocStorage(size, alignment, se_scope, type->dtype), se_scope, /*is_fixed=*/true); auto sto = scope->Push(var, value); // TODO(@jroesch): There is a bug with typing based on the constant shape. - auto tensor = OnDevice( - AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape, cpu_device_.device_type), - dev.device_type, /*is_fixed=*/true); + auto tensor = OnDevice(AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape), + se_scope, /*is_fixed=*/true); Var tensor_var("tensor_" + name_hint, Type(nullptr)); return scope->Push(tensor_var, tensor); } @@ -282,7 +283,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { tec::TECompiler compiler; - tec::CCacheKey key(func, target_host_); + tec::CCacheKey key(func, host_se_scope_->target); auto cfunc = compiler->LowerShapeFunc(key); auto input_states = cfunc->shape_func_param_states; @@ -310,10 +311,10 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { is_inputs.push_back(0); } else if (state == tec::kNeedInputData) { auto new_arg = Mutate(arg); // already accounts for device - DLDeviceType device_type = GetInScopeDeviceType(arg); - if (device_type != cpu_device_.device_type) { - new_arg = OnDevice(DeviceCopy(new_arg, device_type, cpu_device_.device_type), - cpu_device_.device_type, /*is_fixed=*/true); + SEScope arg_se_scope = GetSEScope(arg); + if (arg_se_scope != host_se_scope_) { + new_arg = OnDevice(DeviceCopy(new_arg, arg_se_scope, host_se_scope_), host_se_scope_, + /*is_fixed=*/true); } Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr)); shape_func_ins.push_back(scope->Push(in_shape_var, new_arg)); @@ -331,14 +332,14 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { auto tt = TensorType(out->shape, out->dtype); // Put shape func on CPU. This also ensures that everything between // shape_of and shape_func are on CPU. - auto alloc = OnDevice(MakeStaticAllocation(scope, tt, cpu_device_, std::to_string(i)), - cpu_device_.device_type, /*is_fixed=*/true); + auto alloc = OnDevice(MakeStaticAllocation(scope, tt, host_se_scope_, std::to_string(i)), + host_se_scope_, /*is_fixed=*/true); Var shape_func_out_var("shape_func_out_" + std::to_string(i), Type(nullptr)); alloc = scope->Push(shape_func_out_var, alloc); out_shapes.push_back(alloc); } auto shape_call = OnDevice(ShapeFunc(func, Tuple(shape_func_ins), Tuple(out_shapes), is_inputs), - cpu_device_.device_type, /*is_fixed=*/true); + host_se_scope_, /*is_fixed=*/true); Var shape_func_var("shape_func", Type(nullptr)); scope->Push(shape_func_var, shape_call); return out_shapes; @@ -347,19 +348,19 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { // Generate the code for invoking a TVM op with a dynamic shape. Expr DynamicInvoke(LetList* scope, const Function& func, const Tuple& ins, const std::vector& new_args, const std::vector& out_types, - const Type& ret_type, Device dev) { + const Type& ret_type, const SEScope& se_scope) { auto out_shapes = EmitShapeFunc(scope, func, new_args); std::vector storages; CHECK_EQ(out_shapes.size(), out_types.size()); for (size_t i = 0; i < out_shapes.size(); ++i) { auto out_shape = out_shapes[i]; auto out_type = out_types[i]; - auto size = OnDevice(ComputeStorageInRelay(out_shape, out_type), cpu_device_.device_type, + auto size = OnDevice(ComputeStorageInRelay(out_shape, out_type), host_se_scope_, /*is_fixed=*/true); // Alignment is directly captured in the instruction so don't wrap in "on_device". auto alignment = ComputeAlignment(out_type->dtype); Var sto_var("storage_" + std::to_string(i), Type(nullptr)); - auto val = OnDevice(AllocStorage(size, alignment, dev, out_type->dtype), dev.device_type, + auto val = OnDevice(AllocStorage(size, alignment, se_scope, out_type->dtype), se_scope, /*is_fixed=*/true); storages.push_back(scope->Push(sto_var, val)); } @@ -369,15 +370,14 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { auto out_shape = out_shapes[i]; auto out_type = out_types[i]; auto storage = storages[i]; - auto alloc = OnDevice(AllocTensor(storage, out_shape, out_type->dtype, out_type->shape, - cpu_device_.device_type), - dev.device_type, /*is_fixed=*/true); + auto alloc = OnDevice(AllocTensor(storage, out_shape, out_type->dtype, out_type->shape), + se_scope, /*is_fixed=*/true); Var out_var("out_" + std::to_string(i), Type(nullptr)); outs.push_back(scope->Push(out_var, alloc)); } Tuple tuple_outs(outs); - auto invoke = OnDevice(InvokeTVMOp(func, ins, tuple_outs), dev.device_type, /*is_fixed=*/true); + auto invoke = OnDevice(InvokeTVMOp(func, ins, tuple_outs), se_scope, /*is_fixed=*/true); scope->Push(invoke); return ToTupleType(ret_type, std::vector(tuple_outs->fields.begin(), tuple_outs->fields.end())); @@ -397,27 +397,24 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { CHECK(imm) << "expect static int shape"; shape.push_back(imm->value); } - shape_expr = OnDevice(MakeConstant(shape), cpu_device_.device_type, /*is_fixed=*/true); + shape_expr = OnDevice(MakeConstant(shape), host_se_scope_, /*is_fixed=*/true); } return ReshapeTensor(new_args[0], shape_expr, ret_ty->shape); } private: const Op& device_copy_op_ = Op::Get("device_copy"); + runtime::DataType compute_dtype_ = runtime::DataType::Int(64); + SEScope host_se_scope_; - Target target_host_; std::vector scopes_; - - runtime::DataType compute_dtype_ = runtime::DataType::Int(64); - Device cpu_device_{kDLCPU, 0}; }; namespace transform { -Pass ManifestAlloc(Target target_host, Map targets) { - CheckAndUpdateHostConsistency(&targets, &target_host); +Pass ManifestAlloc(SEScope host_se_scope) { return tvm::transform::CreateModulePass( - [=](IRModule mod, const PassContext& pass_ctx) { + [host_se_scope](IRModule mod, const PassContext& pass_ctx) { // We need to mutate module, therefore making a copy of it. mod.CopyOnWrite(); mod->ImportFromStd("core.rly"); @@ -427,7 +424,7 @@ Pass ManifestAlloc(Target target_host, Map targets) { for (const auto& it : glob_funcs) { if (auto* func_node = it.second.as()) { auto func = GetRef(func_node); - auto rewriter = DialectRewriter(mod, target_host); + auto rewriter = DialectRewriter(mod, host_se_scope); auto updated_func = rewriter.Rewrite(func); mod->Update(it.first, updated_func); @@ -440,11 +437,7 @@ Pass ManifestAlloc(Target target_host, Map targets) { 0, "ManifestAlloc", {}); } -TVM_REGISTER_GLOBAL("relay.transform.ManifestAlloc") - .set_body_typed([](Target target_host, Map targets) { - CheckAndUpdateHostConsistency(&targets, &target_host); - return ManifestAlloc(target_host, targets); - }); +TVM_REGISTER_GLOBAL("relay.transform.ManifestAlloc").set_body_typed(ManifestAlloc); } // namespace transform diff --git a/src/relay/transforms/pass_utils.h b/src/relay/transforms/pass_utils.h index fd7f0a5594c2a..317ac17f83c86 100644 --- a/src/relay/transforms/pass_utils.h +++ b/src/relay/transforms/pass_utils.h @@ -37,6 +37,7 @@ #include "../analysis/dependency_graph.h" #include "../op/annotation/annotation.h" +#include "../op/memory/on_device.h" #include "./let_list.h" namespace tvm { diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index c767770a8be8c..0814e73ab73d4 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -211,14 +211,14 @@ class Fill : ExprFunctor, private transform::Lexi } Expr Atomic(const Expr& e, const Var& v) { - Expr annotated_expr = MaybeOnDevice(e, GetInScopeDeviceType(e), /*is_fixed=*/true); + Expr annotated_expr = MaybeOnDevice(e, GetSEScope(e), /*is_fixed=*/true); return v.defined() ? GetScope(e)->let_list->Push(v, annotated_expr) : annotated_expr; } // Bind expression `now` to var `v` if the original expression is in the include set, or if // v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly Expr Compound(const Expr& orig, const Expr& now, const Var& v) { - Expr annotated_expr = MaybeOnDevice(now, GetInScopeDeviceType(orig), /*is_fixed=*/true); + Expr annotated_expr = MaybeOnDevice(now, GetSEScope(orig), /*is_fixed=*/true); Var var = v.defined() ? v : Var(String("x"), Type()); bool not_included = include_set_ && include_set_->find(orig) == include_set_->end(); if (!v.defined() && not_included) { @@ -229,15 +229,15 @@ class Fill : ExprFunctor, private transform::Lexi } Expr VisitExpr_(const CallNode* c, const Var& v) final { - auto props = GetOnDeviceProps(c); + OnDeviceProps props = GetOnDeviceProps(c); if (props.body.defined() && props.is_fixed) { // Keep track of expression device type for lexically enclosing sub-expressions. - PushDeviceType(props.device_type); + PushSEScope(props.se_scope); Expr body = VisitExpr(props.body, v); // We are done with this sub-expression. - PopDeviceType(); + PopSEScope(); // Preserve the "on_device" annotations. - return OnDevice(body, props.device_type, props.is_fixed); + return OnDevice(body, props.se_scope, props.is_fixed); } Expr e = GetRef(c); @@ -292,9 +292,9 @@ class Fill : ExprFunctor, private transform::Lexi } else { // Keep track of expression and bound variable device types for lexically enclosing // sub-expressions. - PushDeviceType(GetFunctionResultDeviceType(f)); + PushSEScope(GetFunctionResultSEScope(f)); for (size_t i = 0; i < f->params.size(); ++i) { - PushBoundVar(f->params[i], GetFunctionParamDeviceType(f, i)); + PushBoundVar(f->params[i], GetFunctionParamSEScope(f, i)); } EnterFunctionBody(); ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type, @@ -304,7 +304,7 @@ class Fill : ExprFunctor, private transform::Lexi for (size_t i = 0; i < f->params.size(); ++i) { PopBoundVar(f->params[i]); } - PopDeviceType(); + PopSEScope(); } if (function_nesting() == 0) { ICHECK(!v.defined()); @@ -319,7 +319,7 @@ class Fill : ExprFunctor, private transform::Lexi Expr VisitExpr_(const LetNode* l, const Var& v) final { Expr e = GetRef(l); // Keep track of bound variable device types for lexically enclosing sub-expressions. - PushBoundVar(l->var, GetInScopeDeviceType(l->value)); + PushBoundVar(l->var, GetSEScope(l->value)); VisitExpr(l->value, l->var); Expr ret = GetSubScope(e, 0)->let_list->Get(VisitExpr(l->body)); // We are done with these sub-expressions. diff --git a/src/runtime/vm/bytecode.cc b/src/runtime/vm/bytecode.cc index 09b928fa1e392..f83e27d2c11d8 100644 --- a/src/runtime/vm/bytecode.cc +++ b/src/runtime/vm/bytecode.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -119,13 +120,10 @@ Instruction::Instruction(const Instruction& instr) { this->shape_of.tensor = instr.shape_of.tensor; return; case Opcode::ReshapeTensor: - this->reshape_tensor.tensor = instr.reshape_tensor.tensor; - this->reshape_tensor.newshape = instr.reshape_tensor.newshape; + this->reshape_tensor = instr.reshape_tensor; return; case Opcode::DeviceCopy: - this->src = instr.src; - this->src_device_type = instr.src_device_type; - this->dst_device_type = instr.dst_device_type; + this->device_copy = instr.device_copy; return; default: std::ostringstream out; @@ -225,13 +223,10 @@ Instruction& Instruction::operator=(const Instruction& instr) { this->shape_of.tensor = instr.shape_of.tensor; return *this; case Opcode::ReshapeTensor: - this->reshape_tensor.tensor = instr.reshape_tensor.tensor; - this->reshape_tensor.newshape = instr.reshape_tensor.newshape; + this->reshape_tensor = instr.reshape_tensor; return *this; case Opcode::DeviceCopy: - this->src = instr.src; - this->src_device_type = instr.src_device_type; - this->dst_device_type = instr.dst_device_type; + this->device_copy = instr.device_copy; return *this; default: std::ostringstream out; @@ -338,14 +333,14 @@ Instruction Instruction::AllocTensorReg(RegName storage, RegName offset, RegName } Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType dtype_hint, - Index device_type, RegName dst) { + Index device_index, RegName dst) { Instruction instr; instr.op = Opcode::AllocStorage; instr.dst = dst; instr.alloc_storage.allocation_size = size; instr.alloc_storage.alignment = alignment; instr.alloc_storage.dtype_hint = dtype_hint; - instr.alloc_storage.device_type = device_type; + instr.alloc_storage.device_index = device_index; return instr; } @@ -366,14 +361,14 @@ Instruction Instruction::ReshapeTensor(RegName tensor, RegName newshape, RegName return instr; } -Instruction Instruction::DeviceCopy(RegName src, Index src_device_type, Index dst_device_type, +Instruction Instruction::DeviceCopy(RegName src, Index src_device_index, Index dst_device_index, RegName dst) { Instruction instr; instr.op = Opcode::DeviceCopy; instr.dst = dst; - instr.src = src; - instr.src_device_type = src_device_type; - instr.dst_device_type = dst_device_type; + instr.device_copy.src = src; + instr.device_copy.src_device_index = src_device_index; + instr.device_copy.dst_device_index = dst_device_index; return instr; } @@ -609,7 +604,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { os << "alloc_storage $" << instr.dst << " $" << instr.alloc_storage.allocation_size << " " << instr.alloc_storage.alignment << " " << DLDataType2String(instr.alloc_storage.dtype_hint) << " " - << instr.alloc_storage.device_type; + << instr.alloc_storage.device_index; break; } case Opcode::ShapeOf: { @@ -622,8 +617,8 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::DeviceCopy: { - os << "device_copy $" << instr.dst << " $" << instr.src << " " << instr.dst_device_type << " " - << instr.src_device_type; + os << "device_copy $" << instr.dst << " $" << instr.device_copy.src << " " + << instr.device_copy.dst_device_index << " " << instr.device_copy.src_device_index; break; } default: diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 4d7ee457e1e66..4a044584dccd8 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -63,6 +63,8 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtrGetBytecode(); }); } else if (name == "get_constants") { return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetConstants(); }); + } else if (name == "get_virtual_devices") { + return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetVirtualDevices(); }); } else if (name == "get_stats") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); }); } else if (name == "save") { @@ -165,13 +167,21 @@ String ShapeString(const ShapeTuple& shape_tuple, DLDataType dtype) { std::string Executable::GetConstants() const { std::ostringstream oss; - for (size_t i = 0; i < constants.size(); ++i) { const auto& constant = constants[i]; auto ndarray = Downcast(constant); - DLDeviceType device_type = static_cast(const_device_type[i]); oss << "VM Constant[" << i << "]: has shape " << ShapeString(ndarray.Shape(), ndarray->dtype) - << " on device of type " << device_type << std::endl; + << " on device index " << const_device_indexes[i] << std::endl; + } + return oss.str(); +} + +std::string Executable::GetVirtualDevices() const { + std::ostringstream oss; + for (size_t i = 0; i < virtual_devices.size(); ++i) { + const auto& device = virtual_devices[i]; + oss << "VM VirtualDevice[" << i << "]: device type " << device.device_type << " and id " + << device.device_id << std::endl; } return oss.str(); } @@ -245,6 +255,9 @@ TVMByteArray Executable::Save() { // Save header SaveHeader(&strm); + // Save virtual devices section. + SaveVirtualDevicesSection(&strm); + // Global section. SaveGlobalSection(&strm); @@ -263,6 +276,11 @@ TVMByteArray Executable::Save() { return arr; } +void Executable::SaveVirtualDevicesSection(dmlc::Stream* strm) { + strm->Write(virtual_devices); + strm->Write(host_device_index); +} + void Executable::SaveGlobalSection(dmlc::Stream* strm) { std::vector> globals(this->global_map.begin(), this->global_map.end()); @@ -289,8 +307,8 @@ void Executable::SaveConstantSection(dmlc::Stream* strm) { runtime::SaveDLTensor(strm, it); } - // Save the const to device mapping. - strm->Write(this->const_device_type); + // Save the const to device index mapping. + strm->Write(this->const_device_indexes); } void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { @@ -338,7 +356,7 @@ void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { VMInstructionSerializer SerializeInstruction(const Instruction& instr) { std::vector fields; // Save the opcode. - VLOG(1) << "Serializing: " << instr << std::endl; + VLOG(2) << "Serializing: " << instr << std::endl; switch (instr.op) { case Opcode::Move: { // Number of fields = 2 @@ -407,7 +425,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.push_back(dtype.code); fields.push_back(dtype.bits); fields.push_back(dtype.lanes); - fields.push_back(instr.alloc_storage.device_type); + fields.push_back(instr.alloc_storage.device_index); fields.push_back(instr.dst); break; } @@ -487,7 +505,8 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { } case Opcode::DeviceCopy: { // Number of fields = 4 - fields.assign({instr.src, instr.src_device_type, instr.dst_device_type, instr.dst}); + fields.assign({instr.device_copy.src, instr.device_copy.src_device_index, + instr.device_copy.dst_device_index, instr.dst}); break; } default: @@ -504,7 +523,7 @@ void Executable::SaveCodeSection(dmlc::Stream* strm) { for (const auto& func : this->functions) { // Save the function info. VMFunctionSerializer func_format(func.name, func.register_file_size, func.instructions.size(), - func.params, func.params_device_type); + func.params, func.param_device_indexes); func_format.Save(strm); // Serialize each instruction. @@ -564,6 +583,9 @@ runtime::Module Executable::Load(const std::string& code, const runtime::Module // Load header. LoadHeader(&strm); + // Virtual devices section + exec->LoadVirtualDevicesSection(&strm); + // Global section. exec->LoadGlobalSection(&strm); @@ -579,6 +601,12 @@ runtime::Module Executable::Load(const std::string& code, const runtime::Module return runtime::Module(exec); } +void Executable::LoadVirtualDevicesSection(dmlc::Stream* strm) { + STREAM_CHECK(strm->Read(&virtual_devices), "virtual_device"); + STREAM_CHECK(strm->Read(&host_device_index), "virtual_device"); + ICHECK(host_device_index >= 0 && host_device_index < static_cast(virtual_devices.size())); +} + void Executable::LoadGlobalSection(dmlc::Stream* strm) { std::vector globals; STREAM_CHECK(strm->Read(&globals), "global"); @@ -597,14 +625,15 @@ void Executable::LoadConstantSection(dmlc::Stream* strm) { for (size_t i = 0; i < size; i++) { runtime::NDArray constant; STREAM_CHECK(constant.Load(strm), "constant"); - this->constants.push_back(constant); + this->constants.emplace_back(std::move(constant)); } - // Load the const to device mapping. - std::vector const_device_type; - STREAM_CHECK(strm->Read(&const_device_type), "constant"); - ICHECK_EQ(size, const_device_type.size()); - this->const_device_type = const_device_type; + // Load the const to device index mapping. + std::vector const_device_indexes; + const_device_indexes.reserve(size); + STREAM_CHECK(strm->Read(&const_device_indexes), "constant"); + ICHECK_EQ(size, const_device_indexes.size()); + this->const_device_indexes = std::move(const_device_indexes); } void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { @@ -846,8 +875,9 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) { } // Create the VM function. - VMFunction vm_func = VMFunction(loaded_func.name, loaded_func.params, instructions, - loaded_func.register_file_size, loaded_func.params_device_type); + VMFunction vm_func = + VMFunction(loaded_func.name, loaded_func.params, instructions, + loaded_func.register_file_size, loaded_func.param_device_indexes); auto it = this->global_map.find(loaded_func.name); ICHECK(it != this->global_map.end()); ICHECK_LE(it->second, this->global_map.size()); diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index cd2d1332580b4..e5afb0e4b1fcc 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -101,12 +101,10 @@ void VirtualMachineDebug::LoadExecutable(const Executable* exec) { void VirtualMachineDebug::OpStartHook(Instruction instr) { if (prof_ && prof_.operator*().IsRunning()) { if (instr.op == Opcode::LoadConst) { - Device dev = GetDevice(exec_->const_device_type[instr.const_index]); + Device dev = GetDevice(exec_->const_device_indexes[instr.const_index]); prof_.operator*().StartCall("VM::LoadConst", dev, {}); } else if (instr.op == Opcode::DeviceCopy) { - Device dst_dev; - dst_dev.device_type = static_cast(instr.dst_device_type); - dst_dev.device_id = 0; + Device dst_dev = GetDevice(instr.device_copy.dst_device_index); prof_.operator*().StartCall("VM::DeviceCopy", dst_dev, {}); } else if (instr.op == Opcode::ReshapeTensor) { prof_.operator*().StartCall("VM::ReshapeTensor", devices_[1], {}); @@ -124,7 +122,7 @@ void VirtualMachineDebug::OpStartHook(Instruction instr) { } else if (instr.op == Opcode::AllocTensorReg) { auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage); auto storage = Downcast(storage_obj); - Device cpu_dev = GetDevice(static_cast(kDLCPU)); + Device cpu_dev = GetDevice(exec_->host_device_index); auto shape_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); NDArray shape_tensor = Downcast(shape_obj).CopyTo(cpu_dev); prof_.operator*().StartCall( @@ -135,8 +133,8 @@ void VirtualMachineDebug::OpStartHook(Instruction instr) { auto size = LoadScalarInt(instr.alloc_storage.allocation_size); std::ostringstream shape; shape << DLDataType2String(instr.alloc_storage.dtype_hint) << "[" << size << "]"; - prof_.operator*().StartCall("VM::AllocStorage", - {static_cast(instr.alloc_storage.device_type), 0}, + Device dev = GetDevice(instr.alloc_storage.device_index); + prof_.operator*().StartCall("VM::AllocStorage", dev, {{"VM::Argument Shapes", String(shape.str())}}); } else { prof_.operator*().StartCall("VM::UnknownOp", devices_[1], {}); diff --git a/src/runtime/vm/serialize_utils.h b/src/runtime/vm/serialize_utils.h index b4a10806caaf5..04a79c9b0210d 100644 --- a/src/runtime/vm/serialize_utils.h +++ b/src/runtime/vm/serialize_utils.h @@ -58,19 +58,19 @@ struct VMFunctionSerializer { size_t num_instructions; /*! \brief The parameters of the VMFunction. */ std::vector params; - /*! \brief The device type of each parameter of the VMFunction. */ - std::vector params_device_type; + /*! \brief The index for the devices holding each parameter of the VMFunction. */ + std::vector param_device_indexes; VMFunctionSerializer() = default; VMFunctionSerializer(const std::string& name, Index register_file_size, size_t num_instructions, const std::vector& params, - const std::vector& params_device_type) + const std::vector& param_device_indexes) : name(name), register_file_size(register_file_size), num_instructions(num_instructions), params(params), - params_device_type(params_device_type) {} + param_device_indexes(param_device_indexes) {} /*! * \brief Load the serialized function header. @@ -87,7 +87,7 @@ struct VMFunctionSerializer { // Get the number of instructions. num_instructions = static_cast(std::stoll(func_info[2])); if (!strm->Read(¶ms)) return false; - if (!strm->Read(¶ms_device_type)) return false; + if (!strm->Read(¶m_device_indexes)) return false; return true; } @@ -102,7 +102,7 @@ struct VMFunctionSerializer { func_info.push_back(std::to_string(num_instructions)); strm->Write(func_info); strm->Write(params); - strm->Write(params_device_type); + strm->Write(param_device_indexes); } }; diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index b903f793d799f..05adf1d69e8d6 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -232,12 +232,11 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) { const auto& param_names = vm_func.params; ICHECK_EQ(args.size() - offset, param_names.size()) << "The number of provided parameters doesn't match the number of arguments"; - ICHECK_EQ(param_names.size(), vm_func.params_device_type.size()) + ICHECK_EQ(param_names.size(), vm_func.param_device_indexes.size()) << "The number of provided parameters doesn't match the number of assigned devices"; std::vector func_args(param_names.size()); for (int i = offset; i < args.size(); ++i) { - Index device_type = vm_func.params_device_type[i - offset]; - Device dev = GetDevice(device_type); + Device dev = GetDevice(vm_func.param_device_indexes[i - offset]); if (args[i].type_code() == kTVMDLTensorHandle) { // Automatically convert input DLTensors to NDArray @@ -258,13 +257,14 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) { inputs_.emplace(func_name, func_args); } -inline Device VirtualMachine::GetDevice(Index device_type) const { - ICHECK_GE(devices_.size(), device_type) << "devices_ doesn't contain device:" << device_type; +inline Device VirtualMachine::GetDevice(Index device_index) const { + ICHECK_GE(devices_.size(), device_index) << "invalid device index: " << device_index; + return devices_[device_index]; +} - auto dev = devices_[device_type]; - ICHECK_EQ(static_cast(dev.device_type), device_type) - << "device type " << device_type << " has not been initialized in the device list."; - return dev; +inline Allocator* VirtualMachine::GetAllocator(Index device_index) const { + ICHECK_GE(allocators_.size(), device_index) << "invalid device index: " << device_index; + return allocators_[device_index]; } void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) { @@ -297,7 +297,12 @@ void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector& args) { - VLOG(2) << "Executing Function: " << std::endl << func; + DLOG(INFO) << "Executing Function: " << std::endl << func; + for (int i = 0; i < static_cast(devices_.size()); ++i) { + DLOG(INFO) << "Device " << i << " has device type " << devices_[i].device_type + << " and device id " << devices_[i].device_id + << (i == exec_->host_device_index ? " (using as host device)" : ""); + } InvokeGlobal(func, args); RunLoop(); @@ -383,19 +388,31 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { } } -void VirtualMachine::Init(const std::vector& devs, +void VirtualMachine::Init(const std::vector& physical_devices, const std::vector& alloc_types) { - ICHECK_EQ(devs.size(), alloc_types.size()); - // Cache the device - for (size_t i = 0; i < devs.size(); i++) { - auto dev_type = static_cast(devs[i].device_type); - auto alloc = MemoryManager::GetOrCreateAllocator(devs[i], alloc_types[i]); - if (devices_.size() <= dev_type) { - devices_.resize(dev_type + 1); - allocators_.resize(dev_type + 1); - } - devices_[dev_type] = devs[i]; - allocators_[dev_type] = alloc; + ICHECK_EQ(physical_devices.size(), alloc_types.size()); + + // Find a physical device to represent each virtual device the VM code requires. + // (Recall the VM instructions refer to devices by "device index" into this vector of + // virtual devices.) + const size_t num_virtual_devices = exec_->virtual_devices.size(); + devices_.reserve(num_virtual_devices); + allocators_.reserve(num_virtual_devices); + + for (size_t device_index = 0; device_index < num_virtual_devices; ++device_index) { + // We'll retain the legacy behaviour and just match by device type. + // TODO(mbs): Generalize. + DLDeviceType virtual_device_type = exec_->virtual_devices[device_index].device_type; + auto itr = std::find_if(physical_devices.begin(), physical_devices.end(), + [virtual_device_type](const Device& physical_device) { + return physical_device.device_type == virtual_device_type; + }); + CHECK(itr != physical_devices.end()) + << "Unable to find a physical device (from among the " << physical_devices.size() + << " given) to match the virtual device with device type " << virtual_device_type; + const size_t i = std::distance(physical_devices.begin(), itr); + devices_.push_back(*itr); + allocators_.push_back(MemoryManager::GetOrCreateAllocator(*itr, alloc_types[i])); } } @@ -408,7 +425,7 @@ ObjectRef VirtualMachine::ReadRegister(Index r) const { return frames_.back().re int64_t VirtualMachine::LoadScalarInt(Index r) const { int64_t result = 0; const auto& obj = ReadRegister(r); - NDArray array = Downcast(CopyTo(obj, {kDLCPU, 0})); + NDArray array = Downcast(CopyTo(obj, GetDevice(exec_->host_device_index))); switch (array->dtype.bits) { case 1: { @@ -473,7 +490,7 @@ void VirtualMachine::RunLoop() { } if (!const_pool_[instr.const_index].defined()) { - Device dev = GetDevice(exec_->const_device_type[instr.const_index]); + Device dev = GetDevice(exec_->const_device_indexes[instr.const_index]); const_pool_[instr.const_index] = CopyTo(constant_obj, dev); } WriteRegister(instr.dst, const_pool_[instr.const_index]); @@ -484,7 +501,7 @@ void VirtualMachine::RunLoop() { goto main_loop; } case Opcode::LoadConsti: { - auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0}); + auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, GetDevice(exec_->host_device_index)); reinterpret_cast(tensor->data)[0] = instr.load_consti.val; WriteRegister(instr.dst, tensor); pc_++; @@ -544,7 +561,7 @@ void VirtualMachine::RunLoop() { auto object = ReadRegister(instr.get_tag.object); const auto& adt = Downcast(object); auto tag = adt.tag(); - auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0}); + auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, GetDevice(exec_->host_device_index)); reinterpret_cast(tag_tensor->data)[0] = tag; WriteRegister(instr.dst, tag_tensor); pc_++; @@ -600,7 +617,7 @@ void VirtualMachine::RunLoop() { } case Opcode::AllocTensorReg: { OpStartHook(instr); - Device cpu_dev = GetDevice(static_cast(kDLCPU)); + Device cpu_dev = GetDevice(exec_->host_device_index); auto shape_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); NDArray shape_tensor = Downcast(CopyTo(shape_obj, cpu_dev)); auto shape = ToShape(shape_tensor); @@ -637,16 +654,15 @@ void VirtualMachine::RunLoop() { OpStartHook(instr); auto size = LoadScalarInt(instr.alloc_storage.allocation_size); auto alignment = instr.alloc_storage.alignment; + auto storage_obj = SimpleObjAllocator().make_object(); - auto dev_type = instr.alloc_storage.device_type; - ICHECK_LT(static_cast(dev_type), allocators_.size()) - << "Memory allocator for device " << dev_type << " has not been initialized"; - auto* alloc = allocators_[dev_type]; - ICHECK(alloc) << "Did you forget to init the VirtualMachine with devices?"; + Allocator* allocator = GetAllocator(instr.alloc_storage.device_index); + ICHECK(allocator) << "Did you forget to init the VirtualMachine with devices?"; VLOG(2) << "AllocStorage: allocation_size=" << size << ", alignment=" << alignment << ", dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint) - << ", device_type=" << instr.alloc_storage.device_type; - storage_obj->buffer = alloc->Alloc(size, alignment, instr.alloc_storage.dtype_hint); + << ", device_index=" << instr.alloc_storage.device_index; + + storage_obj->buffer = allocator->Alloc(size, alignment, instr.alloc_storage.dtype_hint); Storage storage(storage_obj); WriteRegister(instr.dst, storage); OpStopHook(); @@ -657,7 +673,8 @@ void VirtualMachine::RunLoop() { auto input = ReadRegister(instr.shape_of.tensor); NDArray input_array = Downcast(input); int ndim = input_array->ndim; - auto out_tensor = NDArray::Empty({ndim}, {kDLInt, 64, 1}, {kDLCPU, 0}); + auto out_tensor = + NDArray::Empty({ndim}, {kDLInt, 64, 1}, GetDevice(exec_->host_device_index)); for (int i = 0; i < ndim; ++i) { reinterpret_cast(out_tensor->data)[i] = input_array->shape[i]; } @@ -682,7 +699,7 @@ void VirtualMachine::RunLoop() { } case Opcode::ReshapeTensor: { OpStartHook(instr); - Device cpu_dev = GetDevice(static_cast(kDLCPU)); + Device cpu_dev = GetDevice(exec_->host_device_index); auto tensor_obj = ReadRegister(instr.reshape_tensor.tensor); NDArray tensor_arr = Downcast(tensor_obj); // Read the shape from shape tensor @@ -703,14 +720,13 @@ void VirtualMachine::RunLoop() { } case Opcode::DeviceCopy: { OpStartHook(instr); - auto tensor_src = ReadRegister(instr.src); + auto tensor_src = ReadRegister(instr.device_copy.src); NDArray src_data = Downcast(tensor_src); - Device src_dev = src_data->device; - ICHECK_EQ(static_cast(src_dev.device_type), instr.src_device_type); - - Device dst_dev; - dst_dev.device_type = static_cast(instr.dst_device_type); - dst_dev.device_id = 0; + Device actual_src_dev = src_data->device; + Device inst_src_dev = GetDevice(instr.device_copy.src_device_index); + ICHECK_EQ(actual_src_dev.device_type, inst_src_dev.device_type); + ICHECK_EQ(actual_src_dev.device_id, inst_src_dev.device_id); + Device dst_dev = GetDevice(instr.device_copy.dst_device_index); NDArray dst_data = src_data.CopyTo(dst_dev); WriteRegister(instr.dst, dst_data); diff --git a/tests/cpp/relay/transforms/device_domains_test.cc b/tests/cpp/relay/transforms/device_domains_test.cc index 8f263c3b3273b..5df7984d003a9 100644 --- a/tests/cpp/relay/transforms/device_domains_test.cc +++ b/tests/cpp/relay/transforms/device_domains_test.cc @@ -45,24 +45,32 @@ IRModule TestModule() { } TEST(DeviceDomains, SmokeTest) { - DeviceDomains domains; + SEScope cpu = SEScope::ForDeviceType(kDLCPU); + SEScope cuda = SEScope::ForDeviceType(kDLCUDA); + TargetMap target_map; + target_map.Set(Integer(static_cast(kDLCPU)), Target("llvm")); + target_map.Set(Integer(static_cast(kDLCUDA)), Target("cuda")); + transform::PassContext ctxt = transform::PassContext::Create(); + CompilationConfig config(ctxt, target_map, /*optional_host_target=*/{}); + DeviceDomains domains(config); IRModule mod = TestModule(); Function f = Downcast(mod->Lookup("f")); DeviceDomainPtr actual_add_domain = domains.DomainForCallee(Downcast(f->body)); DeviceDomainPtr x_domain = domains.DomainFor(f->params[0]); DeviceDomainPtr y_domain = domains.DomainFor(f->params[1]); - DeviceDomainPtr result_domain = DeviceDomains::Free(f->ret_type); + DeviceDomainPtr result_domain = domains.Free(f->ret_type); std::vector arg_and_results; arg_and_results.push_back(x_domain); arg_and_results.push_back(y_domain); arg_and_results.push_back(result_domain); - DeviceDomainPtr implied_add_domain = DeviceDomains::MakeDomain(std::move(arg_and_results)); - domains.Unify(actual_add_domain, implied_add_domain); - domains.Unify(x_domain, DeviceDomains::ForDeviceType(f->params[0]->checked_type(), kDLCUDA)); + DeviceDomainPtr implied_add_domain = domains.MakeHigherOrderDomain(std::move(arg_and_results)); + EXPECT_FALSE(domains.UnifyOrNull(actual_add_domain, implied_add_domain) == nullptr); + EXPECT_FALSE(domains.UnifyOrNull( + x_domain, domains.ForSEScope(f->params[0]->checked_type(), cuda)) == nullptr); - EXPECT_EQ(domains.ResultDeviceType(y_domain), kDLCUDA); - EXPECT_EQ(domains.ResultDeviceType(result_domain), kDLCUDA); + EXPECT_EQ(domains.ResultSEScope(y_domain), config->CanonicalSEScope(cuda)); + EXPECT_EQ(domains.ResultSEScope(result_domain), config->CanonicalSEScope(cuda)); } } // namespace diff --git a/tests/python/relay/op/annotation/test_annotation.py b/tests/python/relay/op/annotation/test_annotation.py index 58e559eb96809..8ba91976523a1 100644 --- a/tests/python/relay/op/annotation/test_annotation.py +++ b/tests/python/relay/op/annotation/test_annotation.py @@ -26,14 +26,17 @@ def test_on_device_via_string(): assert isinstance(call, relay.Call) assert len(call.args) == 1 assert call.args[0] == x - assert call.attrs.device_type == 2 # ie kDLCUDA + assert call.attrs.se_scope.device_type_int == 2 # ie kDLCUDA + assert call.attrs.se_scope.virtual_device_id == 0 + assert call.attrs.se_scope.target is None + assert call.attrs.se_scope.memory_scope == "" assert not call.attrs.is_fixed def test_on_device_via_device(): x = relay.Var("x") - call = relay.annotation.on_device(x, tvm.device("llvm")) - assert call.attrs.device_type == 1 # ie kDLCPU + call = relay.annotation.on_device(x, tvm.device("cpu")) + assert call.attrs.se_scope.device_type_int == 1 # ie kDLCPU def test_on_device_invalid_device(): @@ -44,7 +47,7 @@ def test_on_device_invalid_device(): def test_on_device_is_fixed(): x = relay.Var("x") call = relay.annotation.on_device(x, "cuda", True) - assert call.attrs.device_type == 2 + assert call.attrs.se_scope.device_type_int == 2 # ie kDLCUDA assert call.attrs.is_fixed @@ -54,15 +57,13 @@ def test_function_on_device(): f = relay.Function([x, y], relay.add(x, y)) func = relay.annotation.function_on_device(f, ["cpu", "cuda"], "cuda") assert isinstance(func, relay.Function) - assert len(func.attrs["param_device_types"]) == 2 - assert func.attrs["param_device_types"][0] == 1 # ie kDLCPU - assert func.attrs["param_device_types"][1] == 2 # ie kDLCUDA - assert func.attrs["result_device_type"] == 2 # ie KDLCUDA + assert len(func.attrs["param_se_scopes"]) == 2 + assert func.attrs["param_se_scopes"][0].device_type_int == 1 # ie kDLCPU + assert func.attrs["param_se_scopes"][1].device_type_int == 2 # ie kDLCUDA + assert func.attrs["result_se_scope"].device_type_int == 2 # ie KDLCUDA if __name__ == "__main__": - test_on_device_via_string() - test_on_device_via_device() - test_on_device_invalid_device() - test_on_device_is_fixed() - test_function_on_device() + import sys + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/op/test_tensor.py b/tests/python/relay/op/test_tensor.py new file mode 100644 index 0000000000000..4d2c1766972ab --- /dev/null +++ b/tests/python/relay/op/test_tensor.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for tensor helpers.""" +import tvm +from tvm import relay +import pytest + + +def test_device_copy_via_string(): + x = relay.var("x") + call = relay.op.device_copy(x, "cuda", "cpu") + assert isinstance(call, relay.Call) + assert len(call.args) == 1 + assert call.args[0] == x + assert call.attrs.src_se_scope.device_type_int == 2 # ie kDLCUDA + assert call.attrs.src_se_scope.virtual_device_id == 0 + assert call.attrs.src_se_scope.target is None + assert call.attrs.src_se_scope.memory_scope == "" + assert call.attrs.dst_se_scope.device_type_int == 1 # ie kDLCPU + assert call.attrs.dst_se_scope.virtual_device_id == 0 + assert call.attrs.dst_se_scope.target is None + assert call.attrs.dst_se_scope.memory_scope == "" + + +def test_device_copy_via_device(): + x = relay.var("x") + call = relay.op.device_copy(x, tvm.device("cuda"), tvm.device("cpu")) + assert isinstance(call, relay.Call) + assert len(call.args) == 1 + assert call.args[0] == x + assert call.attrs.src_se_scope.device_type_int == 2 # ie kDLCUDA + assert call.attrs.dst_se_scope.device_type_int == 1 # ie kDLCPU + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py index e3218ab1a8299..37eb1a2d6456f 100644 --- a/tests/python/relay/test_pass_plan_devices.py +++ b/tests/python/relay/test_pass_plan_devices.py @@ -26,18 +26,36 @@ import tvm.testing import numpy as np -CPU = tvm.device("cpu") # device_type=1 -GPU = tvm.device("cuda") # device_type=2 +HOST_DEVICE = tvm.device("cpu") +HOST_TARGET = tvm.target.Target("llvm") + +CPU_DEVICE = tvm.device("cpu") +CPU_TARGET = tvm.target.Target("llvm").with_host(HOST_TARGET) + +GPU_DEVICE = tvm.device("cuda") +GPU_TARGET = tvm.target.Target("cuda").with_host(HOST_TARGET) + +TARGETS = { + tvm.tir.IntImm("int32", CPU_DEVICE.device_type): CPU_TARGET, + tvm.tir.IntImm("int32", GPU_DEVICE.device_type): GPU_TARGET, +} + +HOST = tvm.target.make_se_scope(HOST_DEVICE, HOST_TARGET) # device_type=1 +CPU = tvm.target.make_se_scope(CPU_DEVICE, CPU_TARGET) # device_type=1 +GPU = tvm.target.make_se_scope(GPU_DEVICE, GPU_TARGET) # device_type=2 DEFAULT = GPU +CTXT = tvm.transform.PassContext(config={"relay.fallback_device_type": DEFAULT.device_type_int}) + core = tvm.IRModule() core.import_from_std("core.rly") def rewrite_and_assert(in_mod, expected_mod): """Manually run the pass and assert it's structurally equals to the expected.""" + config = tvm.target.make_compilation_config(CTXT, TARGETS, HOST_TARGET) actual_mod = relay.transform.InferType()(in_mod) - actual_mod = relay.transform.PlanDevices(DEFAULT)(actual_mod) + actual_mod = relay.transform.PlanDevices(config)(actual_mod) actual_mod = relay.transform.InferType()(actual_mod) expected_mod = relay.transform.InferType()(expected_mod) if not tvm.ir.structural_equal(actual_mod, expected_mod, True): @@ -59,7 +77,9 @@ def eval_and_assert(in_mod: tvm.IRModule, reference_func, args): print("Not evaluating since GPU is not available") return with tvm.transform.PassContext(opt_level=3): - compiled = relay.create_executor("vm", mod=in_mod, device=GPU, target="cuda").evaluate() + compiled = relay.create_executor( + "vm", mod=in_mod, device=GPU_DEVICE, target=GPU_TARGET + ).evaluate() actual = compiled(*args).numpy() expected = reference_func(*args) tvm.testing.assert_allclose(actual, expected) @@ -85,9 +105,11 @@ def exercise(in_mod: tvm.IRModule, expected_mod: tvm.IRModule, reference_func, a def test_plain(): + metatable = {"SEScope": [CPU, GPU]} + # Everything defaults to GPU def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], @@ -96,21 +118,28 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %1 = add(%c, %d); subtract(%0, %1) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[2, 2, 2, 2], result_device_type=2) { + param_se_scopes=[meta[SEScope][1], meta[SEScope][1], meta[SEScope][1], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); %1 = add(%c, %d); subtract(%0, %1) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -120,35 +149,44 @@ def ref(a, b, c, d): def test_left_add_on_cpu(): + metatable = {"SEScope": [CPU, GPU]} + # Force some args to be on CPU, rest default to GPU. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { %0 = add(%a, %b); - %1 = on_device(%0, device_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][0]); %2 = add(%c, %d); subtract(%1, %2) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[1, 1, 2, 2], result_device_type=2) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, device_type=1, is_fixed=True); - %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %3 = add(%c, %d); subtract(%2, %3) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -158,35 +196,44 @@ def ref(a, b, c, d): def test_left_add_on_cpu_via_copy(): + metatable = {"SEScope": [CPU, GPU]} + # As for test_left_add_on_cpu, but with an explicit device_copy. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { %0 = add(%a, %b); - %1 = device_copy(%0, src_dev_type=1, dst_dev_type=2); + %1 = device_copy(%0, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %2 = add(%c, %d); subtract(%1, %2) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[1, 1, 2, 2], result_device_type=2) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, device_type=1, is_fixed=True); - %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %3 = add(%c, %d); subtract(%2, %3) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -196,37 +243,46 @@ def ref(a, b, c, d): def test_both_adds_on_cpu(): + metatable = {"SEScope": [CPU, GPU]} + def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { %0 = add(%a, %b); %1 = add(%c, %d); - %2 = on_device(%0, device_type=1); - %3 = on_device(%1, device_type=1); + %2 = on_device(%0, se_scope=meta[SEScope][0]); + %3 = on_device(%1, se_scope=meta[SEScope][0]); subtract(%2, %3) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[1, 1, 1, 1], result_device_type=2) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], + result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, device_type=1, is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); %2 = add(%c, %d); - %3 = on_device(%2, device_type=1, is_fixed=True); - %4 = device_copy(%1, src_dev_type=1, dst_dev_type=2); - %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + %3 = on_device(%2, se_scope=meta[SEScope][0], is_fixed=True); + %4 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%4, %5) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -236,34 +292,42 @@ def ref(a, b, c, d): def test_sharing(): + metatable = {"SEScope": [CPU, GPU]} + # The same add sub-expression is annotated twice. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { %0 = add(%a, %b); - %1 = on_device(%0, device_type=1); - %2 = on_device(%0, device_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][0]); + %2 = on_device(%0, se_scope=meta[SEScope][0]); subtract(%1, %2) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - param_device_types=[1, 1], result_device_type=2) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, device_type=1, is_fixed=True); - %2 = on_device(%0, device_type=1, is_fixed=True); - %3 = device_copy(%1, src_dev_type=1, dst_dev_type=2); - %4 = device_copy(%2, src_dev_type=1, dst_dev_type=2); + %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %2 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %3 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %4 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%3, %4) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b): @@ -274,35 +338,44 @@ def ref(a, b): def test_let_on_cpu(): + metatable = {"SEScope": [CPU, GPU]} + # The device for a let-bound expression can flow from uses of the let-bound var. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { let %l = add(%a, %b); let %r = add(%c, %d); - %0 = on_device(%l, device_type=1); + %0 = on_device(%l, se_scope=meta[SEScope][0]); subtract(%0, %r) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[1, 1, 2, 2], result_device_type=2) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - let %l = on_device(%0, device_type=1, is_fixed=True); + let %l = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); let %r = add(%c, %d); - %1 = device_copy(%l, src_dev_type=1, dst_dev_type=2); + %1 = device_copy(%l, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%1, %r) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -312,39 +385,49 @@ def ref(a, b, c, d): def test_func_param_on_cpu(): + metatable = {"SEScope": [CPU, GPU]} + # Devices for function parameters flow to call sites. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { let %f = fn (%x, %y) { %0 = add(%x, %y); - on_device(%0, device_type=1) + on_device(%0, se_scope=meta[SEScope][0]) }; %1 = %f(%a, %b); %2 = add(%c, %d); subtract(%1, %2) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[1, 1, 1, 1], result_device_type=1) { - let %f = fn (%x, %y, param_device_types=[1, 1], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], + result_se_scope=meta[SEScope][0]) { + let %f = fn (%x, %y, + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { add(%x, %y) }; %0 = %f(%a, %b); %1 = add(%c, %d); subtract(%0, %1) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -354,9 +437,11 @@ def ref(a, b, c, d): def test_func_result_on_cpu(): + metatable = {"SEScope": [CPU, GPU]} + # Devices for call sites flow to function results. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], @@ -365,30 +450,38 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], add(%x, %y) }; %0 = %f(%a, %b); - %1 = on_device(%0, device_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][0]); %2 = add(%c, %d); subtract(%1, %2) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[1, 1, 2, 2], result_device_type=2) { - let %f = fn (%x, %y, param_device_types=[1, 1], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) { + let %f = fn (%x, %y, + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { add(%x, %y) }; %1 = %f(%a, %b); - %2 = on_device(%1, device_type=1, is_fixed=True); - %3 = device_copy(%2, src_dev_type=1, dst_dev_type=2); + %2 = on_device(%1, se_scope=meta[SEScope][0], is_fixed=True); + %3 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %4 = add(%c, %d); subtract(%3, %4) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -398,15 +491,17 @@ def ref(a, b, c, d): def test_higher_order(): + metatable = {"SEScope": [CPU, GPU]} + # The constraint on %a flows back to %y via %f and %h def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { let %f = fn (%g) { fn (%a) { - %0 = on_device(%a, device_type=1); + %0 = on_device(%a, se_scope=meta[SEScope][0]); %1 = %g(%0); add(%1, %x) } @@ -418,30 +513,36 @@ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { %3 = %2(%y); subtract(%x, %3) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_device_types=[2, 1], result_device_type=2) { - let %f = fn (%g, param_device_types=[2], result_device_type=2) { - fn (%a, param_device_types=[1], result_device_type=2) { - %0 = device_copy(%a, src_dev_type=1, dst_dev_type=2); + param_se_scopes=[meta[SEScope][1], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { + let %f = fn (%g, param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { + fn (%a, param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { + %0 = device_copy(%a, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %1 = %g(%0); add(%1, %x) } }; - let %h = fn (%b, param_device_types=[2], result_device_type=2) { + let %h = fn (%b, param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { negative(%b) }; %2 = %f(%h); %3 = %2(%y); subtract(%x, %3) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y): @@ -457,14 +558,16 @@ def h(b): def test_function_in_tuple(): + metatable = {"SEScope": [CPU, GPU]} + # Since %f ends up in a tuple its argument and result is forced to be on the CPU def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { - %0 = on_device(%b, device_type=1); + %0 = on_device(%b, se_scope=meta[SEScope][0]); add(%a, %0) }; let %t = (%f, %x); @@ -472,17 +575,20 @@ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { %2 = %t.0; %2(%1, %y) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_device_types=[1, 1], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - param_device_types=[1, 1], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { add(%a, %b) }; let %t = (%f, %x); @@ -490,7 +596,10 @@ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %1 = %t.0; %1(%0, %y) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y): @@ -501,14 +610,14 @@ def ref(x, y): def test_device_copy(): const = rand((5, 7)) - metatable = {"relay.Constant": [relay.const(const)]} + metatable = {"SEScope": [CPU, GPU], "relay.Constant": [relay.const(const)]} def input(): return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32]) { - %0 = device_copy(%x, src_dev_type=1, dst_dev_type=2); + %0 = device_copy(%x, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); add(%0, meta[relay.Constant][0]) } """, @@ -521,8 +630,9 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(5, 7), float32], param_device_types=[1], result_device_type=2) { - %0 = device_copy(%x, src_dev_type=1, dst_dev_type=2); + def @main(%x: Tensor[(5, 7), float32], + param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { + %0 = device_copy(%x, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); add(%0, meta[relay.Constant][0]) } """, @@ -538,31 +648,37 @@ def ref(x): def test_shape_func(): + metatable = {"SEScope": [HOST, GPU]} + def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64]) { %0 = fn (%y: Tensor[(?), float32]) { nn.relu(%y) }; - let %p = on_device(%0, device_type=2, is_fixed=True); - %1 = on_device(%x, device_type=2, is_fixed=True); + let %p = on_device(%0, se_scope=meta[SEScope][1], is_fixed=True); + %1 = on_device(%x, se_scope=meta[SEScope][1], is_fixed=True); %2 = vm.shape_of(%1, dtype="int64"); %3 = (%2,); %4 = (%s,); vm.shape_func(%p, %3, %4, is_input=[False]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64], - param_device_types=[2, 1], result_device_type=1) { - let %p = fn (%y: Tensor[(?), float32], param_device_types=[2], result_device_type=2) { + param_se_scopes=[meta[SEScope][1], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { + let %p = fn (%y: Tensor[(?), float32], + param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { nn.relu(%y) }; %1 = vm.shape_of(%x, dtype="int64"); @@ -570,7 +686,10 @@ def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64], %3 = (%s,); vm.shape_func(%p, %2, %3, is_input=[False]) } - """ + """, + "from_string", + None, + metatable, ) # Don't try to execute, too fiddly to setup. @@ -578,28 +697,37 @@ def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64], def test_shape_of(): + metatable = {"SEScope": [HOST, GPU]} + # We need to use is_fixed=True in the on_device call so that the tensor will be on the GPU. Otherwise the # result defaults to the result device for @main which is the CPU, thus forcing a copy. # TODO(mbs): Perhaps the defaulting heuristics are being too clever? def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(?, ?), float32]) { - %0 = on_device(%x, device_type=2, is_fixed=True); + %0 = on_device(%x, se_scope=meta[SEScope][1], is_fixed=True); vm.shape_of(%0, dtype="int64") } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(?, ?), float32], param_device_types=[2], result_device_type=1) { + def @main(%x: Tensor[(?, ?), float32], + param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][0]) { vm.shape_of(%x, dtype="int64") } - """ + """, + "from_string", + None, + metatable, ) def ref(x): @@ -609,28 +737,33 @@ def ref(x): def test_alloc_storage(): + metatable = {"SEScope": [HOST, GPU]} + def input(): return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%size: int64, %alignment: int64) { - memory.alloc_storage(%size, %alignment, device_id=0, device_type=2) + memory.alloc_storage(%size, %alignment, se_scope=meta[SEScope][1]) } """, "from_string", core, + metatable, ) def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%size: int64, %alignment: int64, param_device_types=[1, 1], result_device_type=2) { - memory.alloc_storage(%size, %alignment, device_id=0, device_type=2) + def @main(%size: int64, %alignment: int64, + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { + memory.alloc_storage(%size, %alignment, se_scope=meta[SEScope][1]) } """, "from_string", core, + metatable, ) # Don't try to execute, too fiddly to setup. @@ -639,7 +772,7 @@ def @main(%size: int64, %alignment: int64, param_device_types=[1, 1], result_dev def test_alloc_tensor(): shape = np.array([3, 2]) - metatable = {"relay.Constant": [relay.const(shape, dtype="int64")]} + metatable = {"SEScope": [HOST, GPU], "relay.Constant": [relay.const(shape, dtype="int64")]} def input(): return tvm.parser.parse( @@ -659,9 +792,9 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%sto: Storage[], param_device_types=[2], result_device_type=2) { - %0 = on_device(0, device_type=1, is_fixed=True); - %1 = on_device(meta[relay.Constant][0], device_type=1, is_fixed=True); + def @main(%sto: Storage[], param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { + %0 = on_device(0, se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], is_fixed=True); memory.alloc_tensor(%sto, %0, %1, const_shape=meta[relay.Constant][0], assert_shape=[]) } """, @@ -676,7 +809,7 @@ def @main(%sto: Storage[], param_device_types=[2], result_device_type=2) { def test_reshape_tensor(): newshape = [2, 4, 2] - metatable = {"relay.Constant": [relay.const(newshape, dtype="int64")]} + metatable = {"SEScope": [HOST, GPU], "relay.Constant": [relay.const(newshape, dtype="int64")]} def input(): return tvm.parser.parse( @@ -695,8 +828,9 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(2, 8), float32], param_device_types=[2], result_device_type=2) { - %0 = on_device(meta[relay.Constant][0], device_type=1, is_fixed=True); + def @main(%x: Tensor[(2, 8), float32], + param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { + %0 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], is_fixed=True); vm.reshape_tensor(%x, %0, newshape=[2, 4, 2]) } """, @@ -712,26 +846,34 @@ def ref(x): def test_dynamic_input(): + metatable = {"SEScope": [GPU]} + # There's nothing special about inferring devices for partially unknown types. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x0: Tensor[(?, ?), float32], %x1: Tensor[(?, ?), float32]) { add(%x0, %x1) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x0: Tensor[(?, ?), float32], %x1: Tensor[(?, ?), float32], - param_device_types=[2, 2], result_device_type=2) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { add(%x0, %x1) } - """ + """, + "from_string", + None, + metatable, ) def ref(x0, x1): @@ -741,35 +883,44 @@ def ref(x0, x1): def test_redundant_annotation(): + metatable = {"SEScope": [CPU, GPU]} + def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { %0 = add(%x, %y); - %1 = on_device(%0, device_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][0]); %2 = subtract(%1, %z); - %3 = on_device(%0, device_type=1); + %3 = on_device(%0, se_scope=meta[SEScope][0]); add(%2, %3) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - param_device_types=[1, 1, 2], result_device_type=2) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) { %0 = add(%x, %y); - %1 = on_device(%0, device_type=1, is_fixed=True); - %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); - %3 = on_device(%0, device_type=1, is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %3 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); %4 = subtract(%2, %z); - %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); add(%4, %5) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y, z): @@ -780,31 +931,40 @@ def ref(x, y, z): def test_annotate_expr(): + metatable = {"SEScope": [CPU, GPU]} + def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { %0 = add(%x, %y); - %1 = on_device(%0, device_type=2); + %1 = on_device(%0, se_scope=meta[SEScope][1]); %2 = subtract(%1, %z); - on_device(%2, device_type=1) + on_device(%2, se_scope=meta[SEScope][0]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - param_device_types=[2, 2, 1], result_device_type=1) { + param_se_scopes=[meta[SEScope][1], meta[SEScope][1], meta[SEScope][0]], + result_se_scope=meta[SEScope][0]) { %0 = add(%x, %y); - %1 = on_device(%0, device_type=2, is_fixed=True); - %2 = device_copy(%1, src_dev_type=2, dst_dev_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][1], is_fixed=True); + %2 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); subtract(%2, %z) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y, z): @@ -814,17 +974,22 @@ def ref(x, y, z): def test_annotate_all(): + metatable = {"SEScope": [CPU, GPU]} + def input(): return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { %0 = add(%x, %y); - %1 = on_device(%0, device_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][0]); %2 = subtract(%1, %z); - on_device(%2, device_type=1) + on_device(%2, se_scope=meta[SEScope][0]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): @@ -832,11 +997,15 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - param_device_types=[1, 1, 1], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], + result_se_scope=meta[SEScope][0]) { %0 = add(%x, %y); subtract(%0, %z) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y, z): @@ -858,43 +1027,52 @@ def test_conv_network(): | <--- CPU """ + metatable = {"SEScope": [CPU, GPU]} def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 56), float32], %weight: Tensor[(64, 64, 3, 3), float32]) { %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); %1 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - %2 = on_device(%0, device_type=1); - %3 = on_device(%1, device_type=1); + %2 = on_device(%0, se_scope=meta[SEScope][0]); + %3 = on_device(%1, se_scope=meta[SEScope][0]); %4 = add(%2, %3); - %5 = on_device(%4, device_type=2); + %5 = on_device(%4, se_scope=meta[SEScope][1]); %6 = nn.conv2d(%5, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - on_device(%6, device_type=1) + on_device(%6, se_scope=meta[SEScope][0]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 56), float32], - %weight: Tensor[(64, 64, 3, 3), float32], param_device_types=[1, 1, 1], result_device_type=1) { + %weight: Tensor[(64, 64, 3, 3), float32], + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], + result_se_scope=meta[SEScope][0]) { %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - %1 = on_device(%0, device_type=1, is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); %2 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - %3 = on_device(%2, device_type=1, is_fixed=True); - %4 = device_copy(%1, src_dev_type=1, dst_dev_type=2); - %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + %3 = on_device(%2, se_scope=meta[SEScope][0], is_fixed=True); + %4 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %6 = add(%4, %5); - %7 = on_device(%6, device_type=2, is_fixed=True); - %8 = device_copy(%7, src_dev_type=2, dst_dev_type=1); + %7 = on_device(%6, se_scope=meta[SEScope][1], is_fixed=True); + %8 = device_copy(%7, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); nn.conv2d(%8, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) } - """ + """, + "from_string", + None, + metatable, ) # Don't try to execute, we don't have a reference conv2d @@ -902,40 +1080,49 @@ def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 5 def test_tuple_get_item(): + metatable = {"SEScope": [CPU, GPU]} + # Note that the device copy should be placed after projection rather than before. This is handled by # a heuristic in the pass. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(3, 3, 4), float32]) { let %t = split(%x, indices_or_sections=3); - %0 = on_device(%t, device_type=1); - %1 = on_device(%t, device_type=1); + %0 = on_device(%t, se_scope=meta[SEScope][0]); + %1 = on_device(%t, se_scope=meta[SEScope][0]); %2 = %0.0; %3 = %1.1; %4 = subtract(%2, %3); - on_device(%4, device_type=2) + on_device(%4, se_scope=meta[SEScope][1]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(3, 3, 4), float32], param_device_types=[1], result_device_type=2) { + def @main(%x: Tensor[(3, 3, 4), float32], + param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { %0 = split(%x, indices_or_sections=3); - let %t = on_device(%0, device_type=1, is_fixed=True); + let %t = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); %1 = %t.0; - %2 = on_device(%1, device_type=1, is_fixed=True); + %2 = on_device(%1, se_scope=meta[SEScope][0], is_fixed=True); %3 = %t.1; - %4 = on_device(%3, device_type=1, is_fixed=True); - %5 = device_copy(%2, src_dev_type=1, dst_dev_type=2); - %6 = device_copy(%4, src_dev_type=1, dst_dev_type=2); + %4 = on_device(%3, se_scope=meta[SEScope][0], is_fixed=True); + %5 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %6 = device_copy(%4, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%5, %6) } - """ + """, + "from_string", + None, + metatable, ) def ref(x): @@ -959,45 +1146,53 @@ def test_propogation(): | <--- CPU """ + metatable = {"SEScope": [CPU, GPU]} def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32]) { %0 = negative(%x); - %1 = on_device(%0, device_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][0]); %2 = negative(%1); - %3 = on_device(%0, device_type=1); + %3 = on_device(%0, se_scope=meta[SEScope][0]); %4 = negative(%3); - %5 = on_device(%2, device_type=2); - %6 = on_device(%4, device_type=2); + %5 = on_device(%2, se_scope=meta[SEScope][1]); + %6 = on_device(%4, se_scope=meta[SEScope][1]); %7 = add(%5, %6); - %8 = on_device(%7, device_type=2); + %8 = on_device(%7, se_scope=meta[SEScope][1]); %9 = negative(%8); - on_device(%9, device_type=1) + on_device(%9, se_scope=meta[SEScope][0]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(5, 7), float32], param_device_types=[1], result_device_type=1) { + def @main(%x: Tensor[(5, 7), float32], + param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { %0 = negative(%x); - %1 = on_device(%0, device_type=1, is_fixed=True); - %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); - %3 = on_device(%0, device_type=1, is_fixed=True); - %4 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %3 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %4 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %5 = negative(%2); %6 = negative(%4); %7 = add(%5, %6); - %8 = on_device(%7, device_type=2, is_fixed=True); - %9 = device_copy(%8, src_dev_type=2, dst_dev_type=1); + %8 = on_device(%7, se_scope=meta[SEScope][1], is_fixed=True); + %9 = device_copy(%8, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); negative(%9) } - """ + """, + "from_string", + None, + metatable, ) def ref(x): @@ -1023,43 +1218,51 @@ def test_fusible_network(): | <--- CPU """ + metatable = {"SEScope": [CPU, GPU]} def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { %0 = add(%x, %y); - %1 = on_device(%0, device_type=2); + %1 = on_device(%0, se_scope=meta[SEScope][1]); %2 = negative(%1); - %3 = on_device(%2, device_type=1); + %3 = on_device(%2, se_scope=meta[SEScope][0]); %4 = negative(%0); %5 = add(%3, %4); - %6 = on_device(%5, device_type=2); + %6 = on_device(%5, se_scope=meta[SEScope][1]); %7 = negative(%6); - on_device(%7, device_type=1) + on_device(%7, se_scope=meta[SEScope][0]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], param_device_types=[2, 2], result_device_type=1) { + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_se_scopes=[meta[SEScope][1], meta[SEScope][1]], result_se_scope=meta[SEScope][0]) { %0 = add(%x, %y); - %1 = on_device(%0, device_type=2, is_fixed=True); - %2 = device_copy(%1, src_dev_type=2, dst_dev_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][1], is_fixed=True); + %2 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); %3 = negative(%2); - %4 = on_device(%3, device_type=1, is_fixed=True); - %5 = device_copy(%4, src_dev_type=1, dst_dev_type=2); + %4 = on_device(%3, se_scope=meta[SEScope][0], is_fixed=True); + %5 = device_copy(%4, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %6 = negative(%0); %7 = add(%5, %6); - %8 = on_device(%7, device_type=2, is_fixed=True); - %9 = device_copy(%8, src_dev_type=2, dst_dev_type=1); + %8 = on_device(%7, se_scope=meta[SEScope][1], is_fixed=True); + %9 = device_copy(%8, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); negative(%9) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y): @@ -1083,37 +1286,45 @@ def test_unpropagatable_graph(): | <--- CPU """ + metatable = {"SEScope": [CPU, GPU]} def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { %0 = add(%a, %b); %1 = multiply(%c, %d); - %2 = on_device(%0, device_type=1); - %3 = on_device(%1, device_type=2); + %2 = on_device(%0, se_scope=meta[SEScope][0]); + %3 = on_device(%1, se_scope=meta[SEScope][1]); %4 = subtract(%2, %3); - on_device(%4, device_type=1) + on_device(%4, se_scope=meta[SEScope][0]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[1, 1, 2, 2], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], + result_se_scope=meta[SEScope][0]) { %0 = multiply(%c, %d); - %1 = on_device(%0, device_type=2, is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][1], is_fixed=True); %2 = add(%a, %b); - %3 = device_copy(%1, src_dev_type=2, dst_dev_type=1); + %3 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); subtract(%2, %3) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -1123,14 +1334,16 @@ def ref(a, b, c, d): def test_conditional(): + metatable = {"SEScope": [CPU, GPU]} + # The conditional is over a function type, thus exercising the first-order/higher-order domain handling. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { let %f = fn (%a) { - %0 = on_device(%y, device_type=1, is_fixed=True); + %0 = on_device(%y, se_scope=meta[SEScope][0], is_fixed=True); add(%a, %0) }; let %g = fn (%a1) { @@ -1143,19 +1356,23 @@ def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { }; %h(%z) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - param_device_types=[1, 1, 1], result_device_type=1) { - let %f = fn (%a, param_device_types=[1], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], + result_se_scope=meta[SEScope][0]) { + let %f = fn (%a, param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { add(%a, %y) }; - let %g = fn (%a1, param_device_types=[1], result_device_type=1) { + let %g = fn (%a1, param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { subtract(%a1, %y) }; let %h = if (%x) { @@ -1165,7 +1382,10 @@ def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], }; %h(%z) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y, z): @@ -1182,36 +1402,46 @@ def g(a): def test_global(): + metatable = {"SEScope": [CPU, GPU]} + def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { - %0 = on_device(%b, device_type=1); + %0 = on_device(%b, se_scope=meta[SEScope][0]); add(%a, %0) } def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { @f(%y, %x) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - param_device_types=[2, 1], result_device_type=2) -> Tensor[(5, 7), float32] { - %0 = device_copy(%b, src_dev_type=1, dst_dev_type=2); + param_se_scopes=[meta[SEScope][1], meta[SEScope][0]], + result_se_scope=meta[SEScope][1]) -> Tensor[(5, 7), float32] { + %0 = device_copy(%b, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); add(%a, %0) } def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_device_types=[1, 2], result_device_type=2) -> Tensor[(5, 7), float32] { + param_se_scopes=[meta[SEScope][0], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) -> Tensor[(5, 7), float32] { @f(%y, %x) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y): @@ -1224,33 +1454,41 @@ def f(a, b): def test_ref(): + metatable = {"SEScope": [CPU, GPU]} + def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { let %r = ref(%x); - %0 = on_device(%y, device_type=1); + %0 = on_device(%y, se_scope=meta[SEScope][0]); ref_write(%r, %0); %1 = ref_read(%r); add(%x, %1) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_device_types=[2, 1], result_device_type=2) { + param_se_scopes=[meta[SEScope][1], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { let %r = ref(%x); - %0 = device_copy(%y, src_dev_type=1, dst_dev_type=2); + %0 = device_copy(%y, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); ref_write(%r, %0); %1 = ref_read(%r); add(%x, %1) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y): @@ -1263,8 +1501,10 @@ def ref(x, y): def test_adt(): + metatable = {"SEScope": [CPU, GPU]} + def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] type List[A] { @@ -1272,7 +1512,7 @@ def input(): Nil, } def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32]) { - %0 = on_device(%y, device_type=1, is_fixed=True); + %0 = on_device(%y, se_scope=meta[SEScope][0], is_fixed=True); %1 = Nil; %2 = Cons(%0, %1); let %l = Cons(%x, %2); @@ -1280,11 +1520,14 @@ def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32]) { Cons(%z, _) => %z } } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] type List[A] { @@ -1292,7 +1535,7 @@ def expected(): Nil, } def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], - param_device_types=[1, 1], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { %0 = Nil; %1 = Cons(%y, %0); let %l = Cons(%x, %1); @@ -1300,7 +1543,10 @@ def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], Cons(%z, _) => %z } } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y): diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 79979747dfd8a..0961278b64020 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -853,11 +853,11 @@ def check_remote(server): # Get a handle to remote Executable. rexec = remote.load_module("vm_library.so") - ctx = remote.cpu() + device = remote.cpu() # Build a VM out of the executable and context. - vm_factory = runtime.vm.VirtualMachine(rexec, ctx) + vm_factory = runtime.vm.VirtualMachine(rexec, device) np_input = np.random.uniform(size=(10, 1)).astype("float32") - input_tensor = tvm.nd.array(np_input, ctx) + input_tensor = tvm.nd.array(np_input, device) # Invoke its "main" function. out = vm_factory.invoke("main", input_tensor) # Check the result. @@ -1003,7 +1003,10 @@ def test_shape_func_nested_function(): def test_storage_size_and_offset_on_cpu(): """Tests allocations place sizes and offsets on the CPU host even if the rest of the computation is on a different device type.""" + # TODO(mbs): Better would be to test ManifestAlloc independently. + # And/or move this to C++ and test the VM executable in it's C++ instead of + # pretty-printed form. # CPU = device type 1 # GPU = device type 2 @@ -1027,15 +1030,19 @@ def @main(%a: Tensor[(5, 7), float32], # - The size of the tensor's storage (first arg) to alloc_storage # - The offset of the tensor within the storage (second arg) to alloc_tensor # Both should be on the CPU - assert not "on device of type 2" in exe.constants - assert "on device of type 1" in exe.constants + assert "VirtualDevice[0]: device type 1" in exe.virtual_devices + assert "Constant[0]: has shape int64[] on device index 0" in exe.constants + assert "Constant[1]: has shape int64[] on device index 0" in exe.constants @tvm.testing.requires_cuda def test_reshape_shape_on_cpu(): """Tests the argument to a reshape places the shape on the CPU host even if the rest of the computation is on a different device type.""" + # TODO(mbs): Better would be to test ManifestAlloc independently. + # And/or move this to C++ and test the VM executable in it's C++ instead of + # pretty-printed form. # CPU = device type 1 # GPU = device type 2 @@ -1056,8 +1063,44 @@ def @main(%x: Tensor[(2, 8), float32], ) # The newshape annotation should have been turned into a constant on the CPU. - assert not "on device of type 2" in exe.constants - assert "on device of type 1" in exe.constants + assert "VirtualDevice[0]: device type 1" in exe.virtual_devices + assert "Constant[0]: has shape int64[3] on device index 0" in exe.constants + + +@tvm.testing.requires_cuda +def test_multi_targets(): + # Build an IRModule. + n = 10 + x = relay.var("x", shape=(n,)) + y = relay.var("y", shape=(n,)) + z = relay.var("z", shape=(n,)) + f = relay.Function([x, y, z], x + relay.op.annotation.on_device(y + z, tvm.cpu())) + mod = IRModule.from_expr(f) + + # Compile to VMExecutable. + with tvm.transform.PassContext( + opt_level=3, config={"relay.fallback_device_type": tvm.cuda().device_type} + ): + exe = relay.vm.compile( + mod, target={"cpu": tvm.target.Target("llvm"), "cuda": tvm.target.Target("cuda")} + ) + + # Run + vm = runtime.vm.VirtualMachine(exe, [tvm.cuda(), tvm.cpu()]) + x_data = np.random.rand( + n, + ).astype("float32") + y_data = np.random.rand( + n, + ).astype("float32") + z_data = np.random.rand( + n, + ).astype("float32") + actual_result = vm.invoke("main", x_data, y_data, z_data) + + # Test + expected_result = x_data + y_data + z_data + tvm.testing.assert_allclose(actual_result.numpy(), expected_result) if __name__ == "__main__": diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 9eae3dd336727..04879573bd6a5 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -424,16 +424,7 @@ def foo(): if __name__ == "__main__": - test_record_split_reorder_fuse_annotation() - test_record_compute_at_root_inline_cache_read_write() - test_record_follow_split_follow_fused_split() - test_record_pragma_storage_align_rfactor() - test_recover_measure_input() - test_workload_dis_factor() - test_measure_local_builder_runner() - test_dag_measure_local_builder_runner() - test_workload_serialization() - test_measure_local_builder_rpc_runner() - test_measure_target_host() - test_measure_special_inputs_map_by_name_local_runner() - test_measure_special_inputs_map_by_name_rpc_runner() + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index 92c1174e728c2..e2ce442e0d885 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -18,7 +18,6 @@ import datetime import json import os -import sys import tarfile import numpy @@ -411,4 +410,7 @@ def test_export_byoc_c_module(): if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) + import sys + + # sys.exit(pytest.main([__file__] + sys.argv[1:])) + test_export_operator_model_library_format() diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py index b67142b423588..4e777435429b0 100644 --- a/tests/python/unittest/test_runtime_profiling.py +++ b/tests/python/unittest/test_runtime_profiling.py @@ -196,4 +196,7 @@ def test_report_serialization(): if __name__ == "__main__": - test_papi("llvm", tvm.cpu()) + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_runtime_vm_profiler.py b/tests/python/unittest/test_runtime_vm_profiler.py index 75b61d281840f..0499a3e6c65a3 100644 --- a/tests/python/unittest/test_runtime_vm_profiler.py +++ b/tests/python/unittest/test_runtime_vm_profiler.py @@ -39,4 +39,7 @@ def test_basic(dev, target): if __name__ == "__main__": - test_basic(tvm.cpu(), tvm.target.Target("llvm")) + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:]))