diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 5ee719f9964f..e466cde097ac 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -191,24 +191,24 @@ constexpr const char* kTarget = "target"; constexpr const char* kGlobalSymbol = "global_symbol"; /*! - * \brief The device type which will hold each of the functions parameters. + * \brief The SEScope which will hold each of the functions parameters. * * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but * may be included as an annotation on user programs. * - * Type: Array (but interpreted as Array) + * Type: Array */ -constexpr const char* kParamDeviceTypes = "param_device_types"; +constexpr const char* kParamSEScopes = "param_se_scopes"; /*! - * \brief The device type which will hold the function result. + * \brief The SEScope which will hold the function result. * * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but * may be included as an annotation on user programs. * - * Type: Integer (but interpreted as DLDeviceType) + * Type: SEScope */ -constexpr const char* kResultDeviceType = "result_device_type"; +constexpr const char* kResultSEScope = "result_se_scope"; } // namespace attr } // namespace tvm diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index f88ca8ef6380..79889ce9a790 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -31,68 +31,6 @@ namespace tvm { namespace relay { -/*! - * \brief Attributes for the "on_device" special operator. - * - * The Relay call (aka 'annotation'): - * \code - * on_device(sub_expr, device_type=2) - * \endcode - * constrains \p sub_expr to execute and store its result on a device with \p DLDeviceType \p 2 - * (i.e. a \p kDLCuda device). However the annotation itself may appear in an expression to be - * executed and stored on a different device. If so the compiler will automatically insert a - * "device_copy" call to mediate the transition between devices. - * - * E.g.: Assuming %x and %y reside on the GPU and %z on the CPU then: - * \code - * multiply(on_device(add(%x, %y), device_type=2), %z) - * \endcode - * indicates the \p add should execute on the GPU but the \p multiply should execute on the CPU. - * The compiler will rewrite this to: - * \code - * multiply(device_copy(add(%x, %y), src_dev_type=2, dst_dev_type=1), %z) - * \endcode - * - * The Relay call - * \code - * on_device(sub_expr, device_type=2, is_fixed=True) - * \endcode - * is similar to the above, however the annotation itself must appear in an expression on the - * same device. The compiler will check the devices are consistent, and will not insert any - * "device_copy" call. This form of annotation shouldn't be necessary in user programs. However - * it is needed by the \p PlanDevices pass to fully specify the results of device planning so that - * the pass is idempotent. - * - * E.g.: The following program is equivalent to the above: - * \code - * let %a = on_device(add(%x, %y), device_type=2, is_fixed=True) - * multiply(device_copy(%a, src_dev_type=2, dst_dev_type=1), %z) - * \endcode - * The "on_device" annotation with \p is_fixed=True indicates unambiguously that \p %a is stored - * on the GPU. - */ -struct OnDeviceAttrs : public tvm::AttrsNode { - // TODO(mbs): Replace device types with TargetDevice. - /*! \brief Device type on which argument expression should be evaluated. */ - int device_type = kInvalidDeviceType; - /*! - * \brief If true, the result device must also be \p device_type and device planning should - * not insert any "device_copy" calls to respect this annotation. - * - * This is used by the device planning pass itself when annotating the planned program. - */ - bool is_fixed = false; - - TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") { - TVM_ATTR_FIELD(device_type) - .describe("The type of the virtual device which should hold the expression result.") - .set_default(0); - TVM_ATTR_FIELD(is_fixed) - .describe("If true, do not insert a \"device_copy\" call to respect this annotation.") - .set_default(false); - } -}; - /*! * \brief Annotate an expression to be cast into specific data type. */ diff --git a/include/tvm/relay/attrs/device_copy.h b/include/tvm/relay/attrs/device_copy.h index f7b0a04f45fa..6d97ab79be4a 100644 --- a/include/tvm/relay/attrs/device_copy.h +++ b/include/tvm/relay/attrs/device_copy.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_DEVICE_COPY_H_ #include +#include #include @@ -35,17 +36,14 @@ namespace relay { * \brief Options for the device copy operators. */ struct DeviceCopyAttrs : public tvm::AttrsNode { - // TODO(mbs): Should be TargetDevice. - int dst_dev_type; - int src_dev_type; + SEScope src_se_scope = SEScope::FullyUnconstrained(); + SEScope dst_se_scope = SEScope::FullyUnconstrained(); TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") { - TVM_ATTR_FIELD(src_dev_type) - .describe("The virtual device/context type where the op copies data from.") - .set_default(0); - TVM_ATTR_FIELD(dst_dev_type) - .describe("The virtual device/context type where the op copies data to.") - .set_default(0); + TVM_ATTR_FIELD(src_se_scope) + .describe("The (virtual) device and scope where the op copies data from."); + TVM_ATTR_FIELD(dst_se_scope) + .describe("The (virtual) device and scope where the op copies data to."); } }; diff --git a/include/tvm/relay/attrs/memory.h b/include/tvm/relay/attrs/memory.h index 85462c087cee..952d4affc584 100644 --- a/include/tvm/relay/attrs/memory.h +++ b/include/tvm/relay/attrs/memory.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -42,15 +43,13 @@ Expr ToTupleType(const Type& t, const std::vector& exprs); */ struct AllocStorageAttrs : public tvm::AttrsNode { DataType dtype; - int device_id; - int device_type; + SEScope se_scope = SEScope::FullyUnconstrained(); TVM_DECLARE_ATTRS(AllocStorageAttrs, "relay.attrs.AllocStorageAttrs") { TVM_ATTR_FIELD(dtype) .describe("The dtype of the tensor to allocate.") .set_default(DataType::Float(32, 1)); - TVM_ATTR_FIELD(device_id).describe("The device id on which to allocate memory."); - TVM_ATTR_FIELD(device_type).describe("The device type on which to allocate memory."); + TVM_ATTR_FIELD(se_scope).describe("The SEScope on which to allocate memory."); } }; diff --git a/include/tvm/relay/attrs/on_device.h b/include/tvm/relay/attrs/on_device.h new file mode 100644 index 000000000000..405926e209c6 --- /dev/null +++ b/include/tvm/relay/attrs/on_device.h @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/attrs/on_device.h + * \brief Attribute for the on device annotation. + */ +#ifndef TVM_RELAY_ATTRS_ON_DEVICE_H_ +#define TVM_RELAY_ATTRS_ON_DEVICE_H_ + +#include +#include + +#include + +namespace tvm { +namespace relay { + +/*! + * \brief Attributes for the "on_device" special operator. + * + * The Relay call (aka 'annotation'): + * \code + * on_device(sub_expr, se_scope=S) + * \endcode + * constrains \p sub_expr to execute and store its result on the \p SEScope \p S. + * However the annotation itself may appear in an expression to be executed and stored on a + * different \p SEScope. If so the compiler will automatically insert a "device_copy" call to + * mediate the transition between \p SEScopes. + * + * E.g.: Assuming %x and %y reside on the GPU and %z on the CPU then: + * \code + * multiply(on_device(add(%x, %y), se_scope=GPU), %z) + * \endcode + * indicates the \p add should execute on the GPU but the \p multiply should execute on the CPU. + * The compiler will rewrite this to: + * \code + * multiply(device_copy(add(%x, %y), src_se_scope=GPU, dst_se_scope=CPU), %z) + * \endcode + * + * The Relay call + * \code + * on_device(sub_expr, se_scope=S, is_fixed=True) + * \endcode + * is similar to the above, however the annotation itself must appear in an expression on the + * same \p SEScope \p S. The compiler will check the \p SEScopes are consistent, and will not + * insert any "device_copy" call. This form of annotation shouldn't be necessary in user programs. + * However it is needed by the \p PlanDevices pass to fully specify the results of device planning + * so that the pass is idempotent. + * + * E.g.: The following program is equivalent to the above: + * \code + * let %a = on_device(add(%x, %y), se_scope=GPU, is_fixed=True) + * multiply(device_copy(%a, src_se_scope=GPU, dst_se_scope=CPU), %z) + * \endcode + * The "on_device" annotation with \p is_fixed=True indicates unambiguously that \p %a is stored + * on the GPU. + */ +struct OnDeviceAttrs : public tvm::AttrsNode { + /*! + * \brief (Virtual) \p SEScope on which the result of the argument expression should be stored. + */ + SEScope se_scope = SEScope::FullyUnconstrained(); + /*! + * \brief If true, the result \p SEScope must also be \p se_scope, and device planning should + * not insert any "device_copy" calls to respect this annotation. + * + * This is used by the device planning pass itself when annotating the planned program. + */ + bool is_fixed = false; + + TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") { + TVM_ATTR_FIELD(se_scope) + .describe("The (virtual) device and scope holding the expression result.") + .set_default(SEScope::FullyUnconstrained()); + TVM_ATTR_FIELD(is_fixed) + .describe("If true, do not insert a \"device_copy\" call to respect this annotation.") + .set_default(false); + } +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_ATTRS_ON_DEVICE_H_ diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index e740776d6d4f..aa9d3b41554c 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -30,6 +30,8 @@ #include #include #include +#include +#include #include #include @@ -437,23 +439,27 @@ TVM_DLL Pass RelayToTIRTargetHook(); * \brief A pass for manifesting explicit memory allocations and rewriting * specific dialects. * - * \param target_host The target used by the host for compilation. - * \param targets The device type and target pairs for compilation. + * \param cpu_se_scope SEScope for computations and data which must reside on a CPU, such as + * shapes and shape functions. * * \return The pass. */ -TVM_DLL Pass ManifestAlloc(Target target_host, Map targets); +TVM_DLL Pass ManifestAlloc(SEScope cpu_se_scope); /*! - * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the device on which - * every Relay sub-expression should run (and the result stored). Captures the result of that - * analysis using new "on_device" and "device_copy" CallNodes. See - * tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator} + * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the \p SEScope on which + * every Relay sub-expression should run and the result stored. Captures the result of that + * analysis using new "on_device" and "device_copy" CallNodes. + * + * See tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator} * for help recovering the device for an arbitrary sub-expression in downstream transformations. * - * \param default_device_type DLDeviceType for default device. + * \param config Describes the targets and default \p SEScope for all primitive operators and + * host sub-expressions. + * + * \return The pass. */ -TVM_DLL Pass PlanDevices(DLDeviceType default_device_type); +TVM_DLL Pass PlanDevices(CompilationConfig config); } // namespace transform diff --git a/include/tvm/runtime/vm/bytecode.h b/include/tvm/runtime/vm/bytecode.h index 72a557fa93b1..a2a64d76ce86 100644 --- a/include/tvm/runtime/vm/bytecode.h +++ b/include/tvm/runtime/vm/bytecode.h @@ -176,6 +176,7 @@ struct Instruction { RegName object; } get_tag; struct /* AllocADT Operands */ { + // TODO(mbs): Needs a DeviceAndScope. /*! \brief The datatype's constructor tag. */ Index constructor_tag; /*! \brief The number of fields to store in the datatype. */ @@ -184,6 +185,7 @@ struct Instruction { RegName* datatype_fields; }; struct /* AllocClosure Operands */ { + // TODO(mbs): Needs a DeviceAndScope. /*! \brief The index into the function table. */ Index clo_index; /*! \brief The number of free variables to capture. */ @@ -198,8 +200,8 @@ struct Instruction { Index alignment; /*! \brief The hint of the dtype. */ DLDataType dtype_hint; - /*! \brief The device type of the allocation. */ - Index device_type; + /*! \brief The index of the device on which the allocation will be made. */ + Index device_index; } alloc_storage; struct /* ShapeOf Operands */ { RegName tensor; @@ -210,11 +212,11 @@ struct Instruction { } reshape_tensor; struct /* DeviceCopy Operands */ { RegName src; - /*! \brief The source device type. */ - Index src_device_type; - /*! \brief The destination device type. */ - Index dst_device_type; - }; + /*! \brief The index of the source device to copy from. */ + Index src_device_index; + /*! \brief The index of the destination deviceto copy to. */ + Index dst_device_index; + } device_copy; }; /*! @@ -352,12 +354,12 @@ struct Instruction { * \param size The size of the allocation. * \param alignment The allocation's alignment. * \param dtype_hint The data type hint for the allocator. - * \param device_type The device type for the allocator. + * \param device_index The index of the device to allocate on. * \param dst The destination to place the storage. * \return The alloc storage instruction. */ static Instruction AllocStorage(RegName size, Index alignment, DLDataType dtype_hint, - Index device_type, RegName dst); + Index device_index, RegName dst); /*! * \brief Get the shape of an input tensor. * \param tensor The input tensor. @@ -376,12 +378,12 @@ struct Instruction { /*! * \brief Copy tensor cross different devices. * \param src The source register. - * \param src_device_type The device type of the tensor for the source register. - * \param dst_device_type The device type of the tensor ofr the destination register. + * \param src_device_index The index of the device holding the tensor in the source register. + * \param dst_device_index The index of the device to hold the tensor in the destination register. * \param dst The destination register to store the copied tensor. * \return The device copy instruction. */ - static Instruction DeviceCopy(RegName src, Index src_device_type, Index dst_device_type, + static Instruction DeviceCopy(RegName src, Index src_device_index, Index dst_device_index, RegName dst); Instruction(); diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index 6e564fd62380..311667904df6 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -132,12 +132,18 @@ class Executable : public ModuleNode { std::string GetBytecode() const; /*! - * \brief Returns a description of all the contants in the executable in human-readable + * \brief Returns a description of all the constants in the executable in human-readable * format. Not intended to be machine readable, but rather to help with debugging and * diffing generated code. */ std::string GetConstants() const; + /*! + * \brief Returns a description of all the (virtual) devices in the executable in human-readable + * format. + */ + std::string GetVirtualDevices() const; + /*! * \brief Print the detailed statistics of the given code, i.e. number of * globls and constants, etc. @@ -183,6 +189,16 @@ class Executable : public ModuleNode { const char* type_key() const final { return "VMExecutable"; } + /*! + * \brief The (compile-time, virtual) devices corresponding to each device index. + * Currently we only support at most one device per device type. + */ + std::vector virtual_devices; + /*! + * \brief The device index corresponding to the 'host' device. That will hold and evaluate + * shape-related data and code. + */ + int host_device_index = -1; /*! \brief The global constant pool. */ std::vector constants; /*! \brief A map from globals (as strings) to their index in the function map. */ @@ -195,38 +211,52 @@ class Executable : public ModuleNode { std::map> op_attrs; /*! \brief The virtual machine's function table. */ std::vector functions; - /*! \brief The device type for each constant. */ - std::vector const_device_type; + /*! \brief The index of the device holding each constant. */ + std::vector const_device_indexes; private: + /*! + * \brief Save the virtual devices + * + * /param strm The output stream. + */ + void SaveVirtualDevicesSection(dmlc::Stream* strm); + /*! * \brief Save the globals. * - * \param strm The input stream. + * \param strm The output stream. */ void SaveGlobalSection(dmlc::Stream* strm); /*! * \brief Save the constant pool. * - * \param strm The input stream. + * \param strm The output stream. */ void SaveConstantSection(dmlc::Stream* strm); /*! * \brief Save primitive op names. * - * \param strm The input stream. + * \param strm The output stream. */ void SavePrimitiveOpNames(dmlc::Stream* strm); /*! * \brief Save the vm functions. * - * \param strm The input stream. + * \param strm The output stream. */ void SaveCodeSection(dmlc::Stream* strm); + /*! + * \brief Load the virtual devices + * + * /param strm The input stream. + */ + void LoadVirtualDevicesSection(dmlc::Stream* strm); + /*! * \brief Load the globals. * diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index ece73fcfda34..604c97330d99 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -82,20 +82,20 @@ struct VMFunction { /*! \brief The instructions representing the function. */ std::vector instructions; /*! \brief The size of the frame for this function */ - Index register_file_size; - /*! \brief The device type of each parameter for this function. */ - std::vector params_device_type; - - VMFunction(const std::string& name, std::vector params, - const std::vector& instructions, Index register_file_size, - const std::vector params_device_type = {}) - : name(name), - params(params), - instructions(instructions), + Index register_file_size = 0; + /*! \brief The indexes for the device holding each function parameter. */ + std::vector param_device_indexes; + + VMFunction(std::string name, std::vector params, + std::vector instructions, Index register_file_size, + std::vector param_device_indexes) + : name(std::move(name)), + params(std::move(params)), + instructions(std::move(instructions)), register_file_size(register_file_size), - params_device_type(params_device_type) {} + param_device_indexes(std::move(param_device_indexes)) {} - VMFunction() {} + VMFunction() = default; friend std::ostream& operator<<(std::ostream& os, const VMFunction&); }; @@ -239,17 +239,19 @@ class VirtualMachine : public runtime::ModuleNode { Index output_size, const std::vector& args); /*! - * \brief Initialize the virtual machine for a set of devices. - * \param devices The set of TVM devices. + * \brief Initialize the virtual machine for a set of (physical) devices. + * \param physical_devices The set of TVM devices. * \param alloc_types The allocator types for each device. */ - void Init(const std::vector& devices, const std::vector& alloc_types); + void Init(const std::vector& physical_devices, + const std::vector& alloc_types); /*! \brief Run VM dispatch loop. */ void RunLoop(); - /*! \brief Get device from the device list based on a given device type. */ - Device GetDevice(Index device_type) const; + /*! \brief Get device from the device list based on a given device index. */ + Device GetDevice(Index device_index) const; + Allocator* GetAllocator(Index device_index) const; /*! * \brief Invoke a global setting up the VM state to execute. @@ -301,9 +303,13 @@ class VirtualMachine : public runtime::ModuleNode { const Executable* exec_; /*! \brief The function name to inputs mapping. */ std::unordered_map> inputs_; - /*! \brief The set of TVM devices the VM is currently executing on. */ + /*! + * \brief The "physical" devices the VM can execute primitives on. All "device indexes" + * are w.r.t. this vector. Each entry in this vector must match the corresponding entry + * in the executable's "virtual" devices vector. + */ std::vector devices_; - /*! \brief The cached memory allocators. */ + /*! \brief The cached memory allocators, one per device. */ std::vector allocators_; /*! * \brief The constant pool for runtime. It caches the device dependent diff --git a/python/tvm/micro/contrib/stm32/emitter.py b/python/tvm/micro/contrib/stm32/emitter.py index 8453ea78e012..aec5912871fd 100644 --- a/python/tvm/micro/contrib/stm32/emitter.py +++ b/python/tvm/micro/contrib/stm32/emitter.py @@ -44,7 +44,7 @@ def _fix_name(node_name): - """ Replace ':' with '_' in names like 'InputImg:0' """ + """Replace ':' with '_' in names like 'InputImg:0'""" return node_name.replace(":", "_") @@ -116,7 +116,7 @@ def _get_tensor_size_bytes(dims, dltype): def _preprocess_code(src): - """ Hack the C code implementing the model. """ + """Hack the C code implementing the model.""" dst = "#include \n" "#include \n\n" dst = dst + src return dst @@ -193,7 +193,7 @@ def __init__(self, include_activations=True, include_inputs=True, include_output self._quantization = {} def _extract_quantization_info(self, quantization): - """ Build dictionary with quantization infos.""" + """Build dictionary with quantization infos.""" for dl_tensor_name in self._input_data: if dl_tensor_name in quantization: @@ -258,7 +258,7 @@ def _get_tensor_from_node(self, nid, idx): return tensor def _compute_data_placement(self): - """ Compute inputs, outputs, weight, activation sizes""" + """Compute inputs, outputs, weight, activation sizes""" self._inputs = self._arg_nodes.copy() @@ -548,7 +548,7 @@ def parse_module(self, module, quantization=None): self._parse_model(quantization) def _emit_params_data(self, name, out_h, out_c): - """ Emits the network_data[c,h] files with parameters.""" + """Emits the network_data[c,h] files with parameters.""" name_upper = name.upper() @@ -674,7 +674,7 @@ def _emit_open(self, name, out_h, out_c): ) def _emit_close(self, name, out_h, out_c): - """ Emits the ai_model_info structure. """ + """Emits the ai_model_info structure.""" name_upper = name.upper() @@ -794,7 +794,7 @@ def _emit_tensor_quant(self, dl_tensor_name, out_c): return None def _emit_tensor_init(self, dl_tensor_name, tensor, out_c): - """ Emits the tensor instantiation code. """ + """Emits the tensor instantiation code.""" dltype = tensor["dltype"] dims = tensor["dims"] @@ -838,7 +838,7 @@ def _emit_tensor_init(self, dl_tensor_name, tensor, out_c): def _emit_activation_buffers(self, name, out_c): # pylint: disable=unused-argument - """ Emits activation tensors, including inputs/outputs.""" + """Emits activation tensors, including inputs/outputs.""" out_c.write( textwrap.dedent( @@ -905,7 +905,7 @@ def _emit_activation_buffers(self, name, out_c): out_c.write(f"\n") def _emit_params_buffers(self, name, out_c): - """ Emits all parameter tensors.""" + """Emits all parameter tensors.""" out_c.write( textwrap.dedent( @@ -922,7 +922,7 @@ def _emit_params_buffers(self, name, out_c): out_c.write(f"\n") def _emit_network(self, name, out_c): - """ Emits prototypes for the network operator functions.""" + """Emits prototypes for the network operator functions.""" out_c.write( textwrap.dedent( @@ -967,7 +967,7 @@ def _emit_tensor_activation(self, dl_tensor_name, tensor, out_c): ) def _emit_activation_init(self, name, out_c): - """ Emits buffer initialization code for activation tensors.""" + """Emits buffer initialization code for activation tensors.""" out_c.write( textwrap.dedent( @@ -1015,7 +1015,7 @@ def _emit_activation_init(self, name, out_c): ) def _emit_params_init(self, name, out_c): - """ Emits buffer initialization code for params tensors.""" + """Emits buffer initialization code for params tensors.""" out_c.write( textwrap.dedent( @@ -1063,13 +1063,13 @@ def _emit_params_init(self, name, out_c): ) def _emit_init(self, name, out_c): - """ Emits buffer initialization code.""" + """Emits buffer initialization code.""" self._emit_activation_init(name, out_c) self._emit_params_init(name, out_c) def _emit_run(self, name, out_h, out_c): - """ Emits the run function code.""" + """Emits the run function code.""" out_h.write( textwrap.dedent( @@ -1230,7 +1230,7 @@ def _emit_run(self, name, out_h, out_c): out_c.write(f"\n") def _emit_create_destroy(self, name, out_h, out_c): - """ Emits the create/destroy functions.""" + """Emits the create/destroy functions.""" out_h.write( textwrap.dedent( @@ -1296,7 +1296,7 @@ def _emit_create_destroy(self, name, out_h, out_c): ) def emit_code(self, dest_dir, model_name): - """ Emits the C code implementing the model. """ + """Emits the C code implementing the model.""" # Build the directory structure if os.path.exists(dest_dir): diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py index f5f8870ab015..cf70dc6e267e 100644 --- a/python/tvm/relay/op/annotation/annotation.py +++ b/python/tvm/relay/op/annotation/annotation.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Annotation operations.""" +from tvm import target from tvm.runtime import ndarray as _nd from tvm.runtime import Device as _Device @@ -22,11 +23,11 @@ from .. import op as reg -def _device_to_int(device): +def _make_se_scope(device): if isinstance(device, _Device): - return device.device_type + return target.make_se_scope(device) if isinstance(device, str): - return _nd.device(device).device_type + return target.make_se_scope(_nd.device(device)) raise ValueError("expecting a Device or device name, but received a %s" % (type(device))) @@ -39,7 +40,7 @@ def on_device(data, device, is_fixed=False): The expression to be annotated. device : Union[:py:class:`Device`, str] - The device to annotate with. Only the device's type is significant. + The device to annotate with. is_fixed : bool If false (the default), a device_copy @@ -52,7 +53,7 @@ def on_device(data, device, is_fixed=False): result : tvm.relay.Expr The annotated expression. """ - return _make.on_device(data, _device_to_int(device), is_fixed) + return _make.OnDevice(data, _make_se_scope(device), is_fixed) def function_on_device(function, param_devices, result_device): @@ -65,18 +66,18 @@ def function_on_device(function, param_devices, result_device): The function to be annotated. param_devices : Array[Union[:py:class:`Device`, str]] - The devices for each parameter. Only the device types are significant. + The devices for each parameter. result_device: Union[:py:class:`Device`, str] - The device for the function result. Only the device type is significant. + The device for the function result. Returns ------- - result : tvm.rleay.Function + result : tvm.relay.Function The annotated function. """ - return _make.function_on_device( - function, [_device_to_int(d) for d in param_devices], _device_to_int(result_device) + return _make.FunctionOnDevice( + function, [_make_se_scope(d) for d in param_devices], _make_se_scope(result_device) ) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index e615bbf21b86..d9847a453569 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -16,6 +16,7 @@ # under the License. """Basic tensor operations.""" # pylint: disable=redefined-builtin, unused-argument +from tvm import target from tvm.runtime import ndarray as _nd from tvm.runtime import Device as _Device from tvm.te.hybrid import script @@ -26,6 +27,14 @@ from . import op as reg +def _make_se_scope(device): + if isinstance(device, _Device): + return target.make_se_scope(device) + if isinstance(device, str): + return target.make_se_scope(_nd.device(device)) + raise ValueError("expecting a Device or device name, but received a %s" % (type(device))) + + # We create a wrapper function for each operator in the # python side to call into the positional _make.OpName function. # @@ -1181,7 +1190,7 @@ def copy_shape_func(attrs, inputs, _): return [_copy_shape_func(inputs[0])] -def device_copy(data, src_dev, dst_dev): +def device_copy(data, src_device, dst_device): """Copy data from the source device to the destination device. This operator helps data transferring between difference devices for heterogeneous execution. @@ -1191,10 +1200,10 @@ def device_copy(data, src_dev, dst_dev): data : tvm.relay.Expr The tensor to be copied. - src_dev : Union[:py:class:`Device`, str] + src_device : Union[:py:class:`Device`, str] The source device where the data is copied from. - dst_dev : Union[:py:class:`Device`, str] + dst_device : Union[:py:class:`Device`, str] The destination device where the data is copied to. Returns @@ -1202,26 +1211,7 @@ def device_copy(data, src_dev, dst_dev): result : tvm.relay.Expr The copied result. """ - if isinstance(src_dev, _Device): - src_dev = src_dev.device_type - elif isinstance(src_dev, str): - src_dev = _nd.device(src_dev).device_type - else: - raise ValueError( - "src_dev is expected to be the type of Device or " - "str, but received %s" % (type(src_dev)) - ) - - if isinstance(dst_dev, _Device): - dst_dev = dst_dev.device_type - elif isinstance(dst_dev, str): - dst_dev = _nd.device(dst_dev).device_type - else: - raise ValueError( - "dst_dev is expected to be the type of Device or " - "str, but received %s" % (type(dst_dev)) - ) - return _make.device_copy(data, src_dev, dst_dev) + return _make.DeviceCopy(data, _make_se_scope(src_device), _make_se_scope(dst_device)) def shape_of(data, dtype="int32"): diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 0dc07944836d..01473a82fb3a 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -891,6 +891,7 @@ def __init__(self, *args, **kwargs): # initialize handle in cass pass_cls creation failed.fg self.handle = None inst = pass_cls(*args, **kwargs) + # it is important not to capture self to # avoid a cyclic dependency def _pass_func(func, mod, ctx): @@ -1146,14 +1147,26 @@ def SimplifyExpr(): return _ffi_api.SimplifyExpr() -def PlanDevices(default_device): +def PlanDevices(config): """ - Uses existing "on_device" and "device_copy" CallNodes to infer the device on which - every Relay sub-expression should run (and the result stored). Captures the result of that - analysis using new "on_device" and "device_copy" CallNodes. Note that the device_id of - the default_device is ignored. + Uses existing "on_device" and "device_copy" CallNodes to infer the SEScope on which + every Relay sub-expression should run and the result stored. Captures the result of that + analysis using new "on_device" and "device_copy" CallNodes. Sub-expressions which are + not otherwise constrained are assigned to the default_primitive_se_scope. However data and + computations which must be hosted on a CPU (such as shapes and shape functions) use the + cpu_se_scope. + + Parameters + ---------- + config : tvm.CompilationConfig + The compilation configuration, specifying available targets and default devices. + + Returns + ------- + ret : tvm.transforms.Pass + The pass. """ - return _ffi_api.PlanDevices(default_device) + return _ffi_api.PlanDevices(config) def FoldExplicitPadding(): diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index c1cbc966acdc..365e38c6e06c 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -72,6 +72,7 @@ def __init__(self, mod): self._get_lib = self.mod["get_lib"] self._get_bytecode = self.mod["get_bytecode"] self._get_constants = self.mod["get_constants"] + self._get_virtual_devices = self.mod["get_virtual_devices"] self._get_stats = self.mod["get_stats"] self._get_function_arity = self.mod["get_function_arity"] self._get_function_param_name = self.mod["get_function_param_name"] @@ -251,6 +252,11 @@ def constants(self): Useful for debugging and diffing generated executables in unit tests.""" return self._get_constants() + @property + def virtual_devices(self): + """Returns a human-readable description of all the (virtual) devices in the executable.""" + return self._get_virtual_devices() + @property def globals(self): """Get the globals used by the Relay VM executable. @@ -295,7 +301,8 @@ class VirtualMachine(object): The VM executable. device : tvm.runtime.Device or List[tvm.runtime.Device] - The device to deploy the module + The device(s) on which the model will run. + Currently at most one device per device type is supported. memory_cfg : str or Dict[tvm.runtime.Device, str], optional Config the type of memory allocator. The allocator type can be ["naive", @@ -363,10 +370,7 @@ def _setup_device(self, dev, memory_cfg): devs = dev if not isinstance(dev, (list, tuple)): if not isinstance(dev, tvm.runtime.Device): - raise TypeError( - "dev is expected to be Device or \ - List[Device]" - ) + raise TypeError("dev is expected to be Device or List[Device]") devs = [dev] # CPU is required for executing shape functions diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 58bcccf90879..62bb7155a10e 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -131,18 +131,18 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { void VisitExpr_(const TupleNode* op) final { std::vector storage_ids; - std::vector device_types; + std::vector se_scopes; std::vector storage_sizes_in_bytes; Expr expr = GetRef(op); for (Expr field : op->fields) { auto sid = GetStorage(field); storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end()); - device_types.insert(device_types.end(), sid->device_types.begin(), sid->device_types.end()); + se_scopes.insert(se_scopes.end(), sid->se_scopes.begin(), sid->se_scopes.end()); storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(), sid->storage_sizes_in_bytes.begin(), sid->storage_sizes_in_bytes.end()); } - storage_device_map_[expr] = StorageInfo(storage_ids, device_types, storage_sizes_in_bytes); + storage_device_map_[expr] = StorageInfo(storage_ids, se_scopes, storage_sizes_in_bytes); AssignReturnSid(expr); } @@ -151,7 +151,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { auto sids = GetStorage(op->tuple); ICHECK_LT(static_cast(op->index), sids->storage_ids.size()); storage_device_map_[expr] = - StorageInfo({sids->storage_ids[op->index]}, {sids->device_types[op->index]}, + StorageInfo({sids->storage_ids[op->index]}, {sids->se_scopes[op->index]}, {sids->storage_sizes_in_bytes[op->index]}); AssignReturnSid(expr); } @@ -185,7 +185,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { * \param prototype The prototype token. * \return The required memory size. * - * TODO(mbs): Cf CalculateRelayExprSizeBytes in utils.cc + * TODO(mbs): Cf CalculateRelayExprSizeBytes in utils.cc, GetMemorySize is graph_plan_memory.cc */ size_t GetMemorySizeBytes(const TensorType& ttype) { size_t size = 1; @@ -217,24 +217,25 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { */ void CreateStorage(const ExprNode* op) { Expr expr = GetRef(op); - return CreateStorage(expr, GetInScopeDeviceType(expr)); + return CreateStorage(expr, GetSEScope(expr)); } /*! - * \brief Create storage to hold the result of evaluating \p expr on \p device_type. + * \brief Create storage to hold the result of evaluating \p expr in \p se_scope. */ - void CreateStorage(const Expr& expr, DLDeviceType device_type) { - ICHECK(device_type != kInvalidDeviceType) << "invalid device type for expr:" << std::endl + void CreateStorage(const Expr& expr, SEScope se_scope) { + ICHECK(!se_scope->IsFullyUnconstrained()) << "invalid SEScope for expr:" << std::endl << PrettyPrint(expr); std::vector storage_ids; - std::vector device_types; + std::vector se_scopes; std::vector storage_sizes_in_bytes; for (const auto& ttype : FlattenTupleType(expr->checked_type())) { storage_ids.push_back(next_available_sid_++); - device_types.push_back(device_type); + se_scopes.push_back(se_scope); storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype)); } - storage_device_map_[expr] = StorageInfo(storage_ids, device_types, storage_sizes_in_bytes); + storage_device_map_[expr] = StorageInfo(std::move(storage_ids), std::move(se_scopes), + std::move(storage_sizes_in_bytes)); } /*! \brief mapping of expression -> storageInfo */ @@ -622,7 +623,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { mod = WithAttr(mod, "main_func_info", func_info); } - IRModule lowered_mod = tec::LowerTEPass(targets_, mod_name, [this](Function func) { + IRModule lowered_mod = tec::LowerTEPass(mod_name, [this](Function func) { // We need to maintain the constant map for external // functions so we pass this processing function which // allows us to process each function as we lower it. diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 4dd12ad1d106..cd9c7d68366d 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include @@ -161,6 +162,8 @@ std::unique_ptr MakeExecutorCodegen(String executor_str) { */ class RelayBuildModule : public runtime::ModuleNode { public: + RelayBuildModule() = default; + /*! * \brief Get member function to front-end * \param name The name of the function. @@ -207,7 +210,7 @@ class RelayBuildModule : public runtime::ModuleNode { } else if (name == "optimize") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.num_args, 2); - *rv = this->Optimize(args[0], args[1], this->params_); + *rv = this->Optimize(args[0], args[1]); }); } else { LOG(FATAL) << "Unknown packed function: " << name; @@ -274,26 +277,16 @@ class RelayBuildModule : public runtime::ModuleNode { * \brief Build relay IRModule for graph executor * * \param mod Relay IRModule - * \param target Target device + * \param targets Target devices * \param target_host Host target device */ void Build(IRModule mod, const TargetMap& targets, const tvm::Target& target_host, const String executor, const String mod_name) { - for (const auto& pair : targets) { - VLOG(0) << "Build target " << pair.first << " = " << pair.second->str(); - } - if (target_host.defined()) { - VLOG(0) << "Build target_host = " << target_host->str(); - } - VLOG(0) << "Build executor = '" << executor << "'"; - VLOG(0) << "Build mod_name = '" << mod_name << "'"; - - // Create protected variable targets_ from ground up - targets_ = targets; - target_host_ = target_host; + VLOG_CONTEXT << "Build"; executor_ = executor; - CheckAndUpdateHostConsistency(&targets_, &target_host_); - BuildRelay(mod, params_, mod_name); + config_ = CompilationConfig(PassContext::Current(), targets, target_host); + + BuildRelay(std::move(mod), mod_name); } protected: @@ -302,95 +295,58 @@ class RelayBuildModule : public runtime::ModuleNode { * * \param relay_module The input IRModule where optmization will be applied on. * \param targets The device type to `Target` mapping. - * \param params The param name to value mapping. * * \return relay::IRModule The updated Relay IR module after optimization. */ - IRModule Optimize(IRModule relay_module, const TargetMap& targets, - const std::unordered_map& params) { - targets_ = targets; - // No target_host setup it seems. - return OptimizeImpl(relay_module, params); + IRModule Optimize(IRModule relay_module, const TargetMap& targets) { + VLOG_CONTEXT << "Optimize"; + // TODO(mbs): executor_ will be whatever was left over from last Build. Note that + // the empty executor string will CHECK fail, so how are folks using this API? + config_ = CompilationConfig(transform::PassContext::Current(), targets, + /*optional_host_target=*/Target()); + return OptimizeImpl(std::move(relay_module)); } - IRModule OptimizeImpl(IRModule relay_module, - const std::unordered_map& params) { + IRModule OptimizeImpl(IRModule relay_module) { ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler."; - if (params.size()) { + if (!params_.empty()) { ICHECK(relay_module->ContainGlobalVar("main")) << "Missing the main entry function"; GlobalVar main_glb_var = relay_module->GetGlobalVar("main"); Function main_func = Downcast(relay_module->Lookup(main_glb_var)); - auto new_main = BindParamsByName(main_func, params); + auto new_main = BindParamsByName(main_func, params_); IRModuleNode* relay_module_ptr = relay_module.CopyOnWrite(); relay_module_ptr->Update(main_glb_var, new_main); } - Array pass_seqs = GetPassPrefix(targets_, false); + Array pass_seqs = GetPassPrefix( + /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false); transform::PassContext pass_ctx = PassContext::Current(); - // TODO(mbs): Centralize this logic and reconcile with similar in relay/backend/vm/compiler.cc - DLDeviceType default_device_type; - if (targets_.size() == 1) { - // Homogenous execution. - default_device_type = static_cast((*targets_.begin()).first->value); - const auto& target = (*targets_.begin()).second; - - // This pass currently only supports the homogeneous case. - pass_seqs.push_back( - transform::SplitArgs(target->GetAttr("max_function_args", -1).value())); - } else { - // Heterogeneous execution. - Optional opt_fallback_dev = - pass_ctx->GetConfig("relay.fallback_device_type"); - if (opt_fallback_dev) { - default_device_type = static_cast(opt_fallback_dev.value()->value); - Integer integer(static_cast(default_device_type)); - CHECK_GT(default_device_type, 0U) - << "The 'relay.fallback_device_type' is set to an invalid device type."; - if (targets_.count(integer) == 0) { - LOG(WARNING) - << "The 'relay.fallback_device_type' has been set to " << default_device_type - << " however no target has been given for that device type in the targets map. " - "Creating an appropriate default target."; - targets_.Set(integer, CreateDefaultTarget(default_device_type)); - } - } else { - default_device_type = kDLCPU; - Integer integer(static_cast(default_device_type)); - if (targets_.count(integer) == 0) { - LOG(WARNING) << "Using the default device type of kDLCPU, however no target has been " - "given for that device type in the targets map. Creating an appropriate " - "default target."; - targets_.Set(integer, CreateDefaultTarget(default_device_type)); - } - } - } - // Always plan devices so the remaining passes don't need to distinguish homogeneous vs // hetrogenous execution. - pass_seqs.push_back(transform::PlanDevices(default_device_type)); + pass_seqs.push_back(transform::PlanDevices(config_)); // Fuse the operations if it is needed. pass_seqs.push_back(transform::FuseOps()); // Create a sequential pass and perform optimizations. transform::Pass seq = transform::Sequential(pass_seqs); - if (targets_.size() == 1) { - With tctx((*targets_.begin()).second); + if (config_->optional_homogeneous_target.defined()) { + With tctx(config_->optional_homogeneous_target); relay_module = seq(relay_module); } else { relay_module = seq(relay_module); } // Do layout rewrite for auto-scheduler. - if (backend::IsAutoSchedulerEnabled() && targets_.size() == 1) { - const auto& target = (*targets_.begin()).second; + if (backend::IsAutoSchedulerEnabled() && config_->optional_homogeneous_target.defined()) { Pass major_pass = transform::AutoSchedulerLayoutRewrite(); bool enable_layout_rewrite_targets = - target->kind->device_type == kDLCPU || target->GetAttr("device", "") == "mali"; + config_->optional_homogeneous_target->kind->device_type == kDLCPU || + config_->optional_homogeneous_target->GetAttr("device", "") == "mali"; if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) { - With tctx(target); + With tctx(config_->optional_homogeneous_target); relay_module = major_pass(relay_module); // Defuse ops to fold constants, then fuse them again relay_module = transform::DefuseOps()(relay_module); @@ -416,45 +372,22 @@ class RelayBuildModule : public runtime::ModuleNode { return relay_module; } - /*! - * \brief Returns a default target to represent \p device_type. - */ - static Target CreateDefaultTarget(DLDeviceType device_type) { - std::string name = runtime::DeviceName(device_type); - if (name == "cpu") { - return Target("llvm"); - } else { - return Target(name); - } - } - /*! * \brief Compile a Relay IR module to runtime module. * * \param relay_module The Relay IR module. * \param params The parameters. */ - void BuildRelay(IRModule relay_module, - const std::unordered_map& params, - const String mod_name) { - Target target_host = GetTargetHost(); - // If no target_host has been set, we choose a default one, which is - // llvm if "codegen.LLVMModuleCreate" is accessible. - const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); - if (!target_host.defined()) target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm"); - - // Update all the targets in the targets_ TargetMap - CheckAndUpdateHostConsistency(&targets_, &target_host); - + void BuildRelay(IRModule relay_module, const String& mod_name) { // Relay IRModule -> IRModule optimizations. - relay_module = OptimizeImpl(relay_module, params); + relay_module = OptimizeImpl(std::move(relay_module)); // Get the updated function. auto func = Downcast(relay_module->Lookup("main")); // Generate code for the updated function. executor_codegen_ = MakeExecutorCodegen(executor_); - executor_codegen_->Init(nullptr, targets_); + executor_codegen_->Init(nullptr, config_->legacy_target_map); executor_codegen_->Codegen(func, mod_name); executor_codegen_->UpdateOutput(&ret_); ret_.params = executor_codegen_->GetParams(); @@ -467,9 +400,12 @@ class RelayBuildModule : public runtime::ModuleNode { lowered_funcs.Set(ext_dev, IRModule()); } + const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); + // Generate a placeholder function that attaches linked params as its arguments. - if (target_host->GetAttr("link-params").value_or(Bool(false))) { - CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen."; + const Target& host_target = config_->host_se_scope->target; + if (host_target->GetAttr("link-params").value_or(Bool(false))) { + CHECK(pf != nullptr) << "Unable to link-params without llvm codegen."; auto param_ids = executor_codegen_->GetParamIds(); auto link_params = Map(); for (auto param : ret_.params) { @@ -482,18 +418,19 @@ class RelayBuildModule : public runtime::ModuleNode { DictAttrs attrs{dict}; auto prim = tir::PrimFunc(Array(), tir::SeqStmt(Array()), VoidType(), Map(), attrs); - if (lowered_funcs.find(target_host) == lowered_funcs.end()) { - lowered_funcs.Set(target_host, IRModule(Map({}))); + if (lowered_funcs.find(host_target) == lowered_funcs.end()) { + lowered_funcs.Set(host_target, IRModule(Map({}))); } - lowered_funcs[target_host]->Add(GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), + lowered_funcs[host_target]->Add(GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), prim); } // When there is no lowered_funcs due to reasons such as optimization. if (lowered_funcs.size() == 0) { - if (target_host.defined() && target_host->kind->name == "llvm") { + if (host_target->kind->name == "llvm") { + CHECK(pf != nullptr) << "Unable to create empty module for llvm without llvm codegen."; // If we can decide the target is LLVM, we then create an empty LLVM module. - ret_.mod = (*pf)(target_host->str(), "empty_module"); + ret_.mod = (*pf)(host_target->str(), "empty_module"); } else { // If we cannot decide the target is LLVM, we create an empty CSourceModule. // The code content is initialized with ";" to prevent complaining @@ -501,11 +438,11 @@ class RelayBuildModule : public runtime::ModuleNode { ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array{}); } } else { - ret_.mod = tvm::build(lowered_funcs, target_host); + ret_.mod = tvm::build(lowered_funcs, host_target); } auto ext_mods = executor_codegen_->GetExternalModules(); - ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, GetTargetHost(), + ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, host_target, executor_codegen_->GetMetadata()); // Remove external params which were stored in metadata module. for (tvm::runtime::Module mod : ext_mods) { @@ -522,26 +459,8 @@ class RelayBuildModule : public runtime::ModuleNode { } } - private: - Target GetTargetHost() { - Target target_host = target_host_; - if (!target_host_.defined()) { - for (const auto& it : targets_) { - if (it.second->kind->device_type == kDLCPU) { - target_host = it.second; - break; - } - } - } - return target_host; - } - protected: std::unique_ptr executor_codegen_; - /*! \brief target device */ - TargetMap targets_; - /*! \brief target host device */ - tvm::Target target_host_; /*! \brief parameters */ std::unordered_map params_; /*! \brief building output */ @@ -552,6 +471,8 @@ class RelayBuildModule : public runtime::ModuleNode { * - aot: use the aot executor */ String executor_; + /*! \brief Collects all the targets and scopes we need during compilation. */ + CompilationConfig config_; }; runtime::Module RelayBuildCreate() { diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index ac3c835ed648..1bab2c9afc7d 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -205,7 +205,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorattrs_["storage_id"] = std::move(storage_ids); // type std::vector device_types; - for (auto v : storage_info->device_types) { - device_types.push_back(static_cast(v)); + for (const auto& se_scope : storage_info->se_scopes) { + // TODO(mbs): Keeping only the device type. + ICHECK_GT(se_scope->device_type(), 0); + device_types.push_back(se_scope->device_type()); } size_t num_unknown_devices = std::count(device_types.begin(), device_types.end(), 0); if (num_unknown_devices != 0 && num_unknown_devices != device_types.size()) { @@ -468,7 +470,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator VisitExpr_(const CallNode* call_node) override { - auto props = GetOnDeviceProps(call_node); + relay::Call call = GetRef(call_node); + OnDeviceProps props = GetOnDeviceProps(call_node); if (props.body.defined()) { // See through "on_device" calls. return VisitExpr(props.body); @@ -485,6 +488,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatortuple); return {vtuple[op->index]}; } + std::vector VisitExpr_(const OpNode* op) override { LOG(FATAL) << "All OpNodes should have been expanded"; return {}; diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 4031dfdcd6e7..5fa36cc5b5db 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -53,21 +53,19 @@ struct StorageToken { size_t max_bytes{0}; /*! \brief The corresponding tensor type. */ TensorType ttype{nullptr}; - /*! \brief Device on which memory will reside. */ - Device device{kInvalidDeviceType, -1}; + /*! \brief SEScope on which the memory will reside. */ + SEScope se_scope = SEScope::FullyUnconstrained(); /*! \brief The storage id */ int64_t storage_id{-1}; - bool is_valid() const { return device.device_type != kInvalidDeviceType; } + bool is_valid() const { return !se_scope->IsFullyUnconstrained(); } - bool is_compatible(const StorageToken& that) const { - return device.device_type == that.device.device_type; - } + bool is_compatible(const StorageToken& that) const { return se_scope == that.se_scope; } std::string ToString() const { std::ostringstream os; - os << "{id: " << storage_id << ", bytes: " << max_bytes << ", type: " << PrettyPrint(ttype) - << ", device: " << device.device_type << "}"; + os << "{storage_id: " << storage_id << ", max_bytes: " << max_bytes + << ", ttype: " << PrettyPrint(ttype) << ", se_scope: " << se_scope << "}"; return os.str(); } }; @@ -169,15 +167,14 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor { * the result of evaluating \p op. */ void CreateToken(const ExprNode* expr_node, bool can_realloc) { - return CreateTokenOnDevice(expr_node, GetInScopeDeviceType(GetRef(expr_node)), - can_realloc); + return CreateTokenOnDevice(expr_node, GetSEScope(GetRef(expr_node)), can_realloc); } /*! * \brief Allocates (or reuses if \p can_realloc is true) a storage token for holding * the result of evaluating \p op on \p device_type. */ - virtual void CreateTokenOnDevice(const ExprNode* op, DLDeviceType device_type, + virtual void CreateTokenOnDevice(const ExprNode* op, const SEScope& se_scope, bool can_realloc) = 0; }; @@ -196,16 +193,13 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { protected: using StorageAllocaBaseVisitor::VisitExpr_; - void CreateTokenOnDevice(const ExprNode* op, DLDeviceType device_type, - bool can_realloc) override { + void CreateTokenOnDevice(const ExprNode* op, const SEScope& se_scope, bool can_realloc) override { ICHECK(!token_map_.count(op)); std::vector tokens; for (const auto& ttype : FlattenTupleType(op->checked_type())) { - StorageToken* token = arena_->make(); + auto* token = arena_->make(); token->ttype = ttype; - // TODO(mbs): Should be TargetDevice. - token->device.device_type = device_type; - token->device.device_id = 0; + token->se_scope = se_scope; tokens.push_back(token); } token_map_[op] = tokens; @@ -261,8 +255,11 @@ class StorageAllocator : public StorageAllocaBaseVisitor { for (const auto& kv : token_map_) { std::vector storage_ids; - std::vector device_types; + storage_ids.reserve(kv.second.size()); + std::vector se_scopes; + se_scopes.reserve(kv.second.size()); std::vector sid_sizes_byte; + sid_sizes_byte.reserve(kv.second.size()); for (StorageToken* tok : kv.second) { VLOG(1) << "token: " << tok->ToString(); @@ -271,10 +268,11 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } num_nodes++; storage_ids.push_back(tok->storage_id); - device_types.push_back(static_cast(tok->device.device_type)); + se_scopes.push_back(tok->se_scope); sid_sizes_byte.push_back(GetMemorySize(tok)); } - auto storage_info = backend::StorageInfo(storage_ids, device_types, sid_sizes_byte); + auto storage_info = backend::StorageInfo(std::move(storage_ids), std::move(se_scopes), + std::move(sid_sizes_byte)); smap.Set(GetRef(kv.first), storage_info); } // Either all or none of the nodes should be annotated. @@ -288,20 +286,20 @@ class StorageAllocator : public StorageAllocaBaseVisitor { protected: // override create token by getting token as prototype requirements. - void CreateTokenOnDevice(const ExprNode* op, DLDeviceType device_type, bool can_realloc) final { + void CreateTokenOnDevice(const ExprNode* op, const SEScope& se_scope, bool can_realloc) final { ICHECK(!token_map_.count(op)); auto it = prototype_.find(op); ICHECK(it != prototype_.end()); std::vector tokens; for (StorageToken* tok : it->second) { - ICHECK_EQ(tok->device.device_type, device_type); + ICHECK(tok->se_scope == se_scope); if (can_realloc) { tokens.push_back(Request(tok)); } else { // Allocate a new token, StorageToken* allocated_tok = Alloc(tok, GetMemorySize(tok)); - allocated_tok->device = tok->device; + allocated_tok->se_scope = tok->se_scope; // ensure it never get de-allocated. allocated_tok->ref_counter += 1; tokens.push_back(allocated_tok); @@ -372,7 +370,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { * \param size The original size. * \param word_size The element size. */ - static size_t DivRoundUp(size_t size, size_t word_size) { + static int64_t DivRoundUp(int64_t size, int64_t word_size) { return (size + word_size - 1) / word_size; } /*! @@ -398,16 +396,19 @@ class StorageAllocator : public StorageAllocaBaseVisitor { * \brief Get the memory requirement. * \param prototype The prototype token. * \return The required memory size. + * + * TODO(mbs): Gf GetMemorySizeBytes in aot_executor_codegen.cc, + * CalculateRelayExprSizeBytes in utils.cc */ - size_t GetMemorySize(StorageToken* prototype) { + static int64_t GetMemorySize(StorageToken* prototype) { TensorType ttype = prototype->ttype; ICHECK(ttype.defined()); - size_t size = 1; + int64_t size = 1; for (IndexExpr dim : ttype->shape) { const int64_t* pval = tir::as_const_int(dim); ICHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape " << ttype->shape; ICHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative shape" << *pval; - size *= static_cast(pval[0]); + size *= pval[0]; } size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8); return size; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 4835d7618a2e..6d2bd8fbec9d 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -922,16 +922,12 @@ class Interpreter : public ExprFunctor, * functions needed by the rewritten module. */ IRModule Prepare(IRModule mod, CompilationConfig config) { - tec::TargetMap tec_target_map; - for (const auto& pair : config->legacy_target_map) { - tec_target_map.emplace(static_cast(pair.first->value), pair.second); - } // Run minimal transforms on module to establish invariants needed by interpreter. transform::Sequential seq( {transform::SimplifyInference(), // Figure out which devices should be used to execute. // TODO(mbs): Should ignore all existing annotations when constant folding - transform::PlanDevices(config->default_primitive_se_scope->device_type()), + transform::PlanDevices(std::move(config)), // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' // attribute. transform::FuseOps(/*fuse_opt_level=*/0), @@ -941,8 +937,7 @@ IRModule Prepare(IRModule mod, CompilationConfig config) { transform::EtaExpand( /*expand_constructor=*/true, /*expand_global_var=*/false), transform::InferType(), - tec::LowerTEPass(tec_target_map, /*module_name=*/"intrp", - [](Function func) { /* no-op */ })}); + tec::LowerTEPass(/*module_name=*/"intrp", [](Function func) { /* no-op */ })}); transform::PassContext pass_ctx = transform::PassContext::Current(); With ctx(pass_ctx); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 915fc22b2052..99b7d648a6e3 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -45,6 +45,7 @@ #include "../op/annotation/annotation.h" #include "../op/call/call.h" +#include "../op/memory/device_copy.h" #include "../transforms/device_aware_visitors.h" #include "./te_compiler_cache.h" #include "./utils.h" @@ -361,21 +362,6 @@ TVM_REGISTER_GLOBAL("relay.backend._TECompilerListItems").set_body_typed([](TECo using AnalysisRemapping = std::unordered_map; -std::tuple IsDeviceCopy(const Function& func) { - if (auto call_node = func->body.as()) { - if (auto op_node = call_node->op.as()) { - if (op_node->name == "device_copy") { - auto attrs = call_node->attrs.as(); - auto dst = attrs->dst_dev_type; - auto src = attrs->src_dev_type; - return std::tuple(true, src, dst); - } - } - } - - return std::tuple(false, -1, -1); -} - /*! * \brief Rewrites call expressions to Relay functions marked as 'primitive' * to calls to the corresponding TIR primitive for the appropriate target. @@ -417,11 +403,10 @@ std::tuple IsDeviceCopy(const Function& func) { */ class LowerTensorExprMutator : public DeviceAwareExprMutator { public: - LowerTensorExprMutator(const IRModule& module, const TargetMap& targets, ProcessFn process_fn, - const String& module_name, TECompiler compiler) + LowerTensorExprMutator(const IRModule& module, ProcessFn process_fn, const String& module_name, + TECompiler compiler) : DeviceAwareExprMutator(module), module_(module), - targets_(targets), process_fn_(process_fn), module_name_(module_name), compiler_(compiler), @@ -488,7 +473,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } else { // Non-External Relay Function - VLOG(1) << "lowering to target '" << target->str() << "' for primitive:\n" + VLOG(1) << "lowering to target " << target->ToDebugString() << " for primitive:\n" << PrettyPrint(func); CCacheKey key = CCacheKey(func, target); CachedFunc lowered_func = compiler_->Lower(key, module_name_); @@ -519,14 +504,12 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); } - auto device_copy = IsDeviceCopy(func); - if (std::get<0>(device_copy)) { - // Record that device copy source and destination devices so the device planner can - // still follow along. - auto source_device = std::get<1>(device_copy); - auto dst_device = std::get<2>(device_copy); - call_lowered_attrs->metadata.Set("source_device", tvm::Integer(source_device)); - call_lowered_attrs->metadata.Set("dst_device", tvm::Integer(dst_device)); + DeviceCopyProps props = GetDeviceCopyProps(func); + if (props.body.defined()) { + // Record the device copy source and destination scopes so the device planner can + // still follow along even after lowering. + call_lowered_attrs->metadata.Set("src_se_scope", props.src_se_scope); + call_lowered_attrs->metadata.Set("dst_se_scope", props.dst_se_scope); } call_lowered_attrs->metadata.Set("relay_attrs", func->attrs); @@ -539,8 +522,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // on the host cpu irrespective of where the primitive runs. // TODO(mbs): Cleanup target handling. Target shape_target("llvm"); - VLOG(1) << "lowering to target '" << shape_target->str() - << "' for dynamic shape function for primitive"; + VLOG(1) << "lowering to target " << shape_target->ToDebugString() + << " for dynamic shape function for primitive"; CCacheKey shape_key(func, shape_target); CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); // Capture the shape function's global var and parameters 'states' in call @@ -635,9 +618,10 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { target = Target("ext_dev"); } else { // The target corresponding to the call_node expression's annotation. - DLDeviceType device_type = GetInScopeDeviceType(call); - // TODO(mbs): Replace device_type with target so this lookup is unnecessary. - target = GetTargetFromInteger(device_type, targets_); + SEScope se_scope = GetSEScope(call); + ICHECK(!se_scope->IsFullyUnconstrained()); + target = se_scope->target; + ICHECK(target.defined()); } Array visited_args; for (const auto& arg : call_node->args) { @@ -649,7 +633,6 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } IRModule module_; - TargetMap targets_; ProcessFn process_fn_; // Map from in-scope let-bound variables to Relay functions known to be // primitive. We'll rewrite these to the fresh global vars bound to the lowered @@ -695,11 +678,11 @@ Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets) { } } -Pass LowerTensorExpr(TargetMap targets, const String& module_name, TECompiler compiler, +Pass LowerTensorExpr(const String& module_name, TECompiler compiler, std::function process_fn) { runtime::TypedPackedFunc pass_func = [=](Function func, IRModule module, PassContext ctx) { - LowerTensorExprMutator lower_te(module, targets, process_fn, module_name, compiler); + LowerTensorExprMutator lower_te(module, process_fn, module_name, compiler); return Downcast(lower_te.Mutate(func)); }; return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {}); @@ -716,6 +699,7 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa } // This is a Map> + // TODO(mbs): Collapsing SEScopes to just device type. std::unordered_map, backend::EnumClassHash> sid_workspace; // This is a Map @@ -726,15 +710,15 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa // Initialize the mapping from all storage identifiers to workspace sizes, // the amount of device io, and the device constants. for (const auto& kv : storage_info_map) { - backend::StorageInfo storage_info = kv.second; - std::vector storage_ids = storage_info->storage_ids; - std::vector devices = storage_info->device_types; - - CHECK_EQ(storage_ids.size(), devices.size()); - for (uint32_t i = 0; i < devices.size(); i++) { - sid_workspace[devices[i]][storage_ids[i]] = 0; - device_io[devices[i]] = 0; - device_consts[devices[i]] = 0; + const backend::StorageInfo& storage_info = kv.second; + const std::vector& storage_ids = storage_info->storage_ids; + const std::vector& se_scopes = storage_info->se_scopes; + CHECK_EQ(storage_ids.size(), se_scopes.size()); + for (uint32_t i = 0; i < se_scopes.size(); i++) { + DLDeviceType device_type = se_scopes[i]->device_type(); + sid_workspace[device_type][storage_ids[i]] = 0; + device_io[device_type] = 0; + device_consts[device_type] = 0; } } @@ -763,18 +747,20 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa << PrettyPrint(expr->checked_type()) << std::endl << "has size " << size_bytes << " and storage info:" << std::endl << storage_info; - std::vector storage_ids = storage_info->storage_ids; - std::vector devices = storage_info->device_types; + const std::vector& storage_ids = storage_info->storage_ids; + const std::vector& se_scopes = storage_info->se_scopes; if (expr->IsInstance()) { - for (const auto& dev : devices) { - ICHECK_EQ(device_consts.count(dev), 1); - device_consts[dev] += size_bytes; + for (const auto& se_scope : se_scopes) { + DLDeviceType device_type = se_scope->device_type(); + ICHECK_EQ(device_consts.count(device_type), 1); + device_consts[device_type] += size_bytes; } } else if (expr->IsInstance() || expr.same_as(func->body)) { - CHECK_GE(devices.size(), 1) << "must be at least one device"; - for (const auto& dev : devices) { - device_io[dev] += size_bytes; + CHECK_GE(se_scopes.size(), 1) << "must be at least one device"; + for (const auto& se_scope : se_scopes) { + DLDeviceType device_type = se_scope->device_type(); + device_io[device_type] += size_bytes; } } else { // TODO(@electriclilies): This code is never being called which means sid_workspace is not @@ -784,8 +770,9 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa // Here we record the largest size of the tensor // that share the same storage id, because storage_id will // be shared between multiple tensors that are not live simultaneously. - if (size_bytes > sid_workspace[devices[i]][storage_ids[i]]) { - sid_workspace[devices[i]][storage_ids[i]] = size_bytes; + DLDeviceType device_type = se_scopes[i]->device_type(); + if (size_bytes > sid_workspace[device_type][storage_ids[i]]) { + sid_workspace[device_type][storage_ids[i]] = size_bytes; } } } @@ -830,8 +817,9 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa constant_sizes.Set(tgt, dev_and_size.second); } - backend::FunctionInfo func_info(workspace_sizes, io_sizes, constant_sizes, tir_primfuncs, - relay_primfuncs); + backend::FunctionInfo func_info(std::move(workspace_sizes), std::move(io_sizes), + std::move(constant_sizes), std::move(tir_primfuncs), + std::move(relay_primfuncs)); VLOG(1) << "func_info: " << func_info; return std::move(func_info); } @@ -889,6 +877,7 @@ void UpdateFunctionMetadata(Function relay_func, workspace_sizes.Set(prim_fn_target, workspace_size); // Calculating size for I/O + // TODO(mbs): See also the other three utils for calculating tensor bytesize. for (auto const& param : prim_fn->params) { auto p_shape = prim_fn->buffer_map[param]->shape; int num_of_elements = 1; @@ -909,8 +898,9 @@ void UpdateFunctionMetadata(Function relay_func, relay_primfuncs.Set(prim_fn_target, relay_func); } - backend::FunctionInfo fi = backend::FunctionInfo(workspace_sizes, io_sizes, constant_sizes, - tir_primfuncs, relay_primfuncs); + backend::FunctionInfo fi = backend::FunctionInfo( + std::move(workspace_sizes), std::move(io_sizes), std::move(constant_sizes), + std::move(tir_primfuncs), std::move(relay_primfuncs)); VLOG(1) << "FunctionInfo: " << prim_fn_var.value()->name_hint << " = " << PrettyPrint(fi); @@ -919,11 +909,11 @@ void UpdateFunctionMetadata(Function relay_func, function_metadata.Set(prim_fn_var.value()->name_hint, fi); } -IRModule LowerTE(const IRModule& module, TargetMap targets, const String& module_name, +IRModule LowerTE(const IRModule& module, const String& module_name, std::function process_fn) { TECompiler compiler; - auto updated_module = LowerTensorExpr(targets, module_name, compiler, process_fn)(module); + auto updated_module = LowerTensorExpr(module_name, compiler, process_fn)(module); backend::UpdateAutoSchedulerOpWeights(compiler); @@ -968,12 +958,9 @@ Map GetPerTargetModules(IRModule mod) { return per_target_modules; } -Pass LowerTEPass(TargetMap targets, const String& module_name, - std::function process_fn) { - runtime::TypedPackedFunc pass_func = [=](IRModule module, - PassContext ctx) { - return LowerTE(module, targets, module_name, process_fn); - }; +Pass LowerTEPass(const String& module_name, std::function process_fn) { + runtime::TypedPackedFunc pass_func = + [=](IRModule module, PassContext ctx) { return LowerTE(module, module_name, process_fn); }; return tvm::transform::Sequential({tvm::relay::transform::RelayToTIRTargetHook(), tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index d0401e9605f7..da7333d64d46 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -173,7 +173,6 @@ Map GetPerTargetModules(IRModule mod); * to TE expressions, schedules them, and then to TIR. * * \param module The IRModule. - * \param targets The mapping for devices to targets. * \param memory_plan The memory plan used during lowering * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process @@ -181,8 +180,8 @@ Map GetPerTargetModules(IRModule mod); * \return The lowered module, see above. */ IRModule LowerTE( - const IRModule& module, TargetMap targets, backend::StaticMemoryPlan memory_plan, - const String& module_name, ProcessFn process_fn = [](Function f) {}); + const IRModule& module, backend::StaticMemoryPlan memory_plan, const String& module_name, + ProcessFn process_fn = [](Function f) {}); /*! \brief Pass to lower an IRModule's primitive functions to TIR. * @@ -190,14 +189,12 @@ IRModule LowerTE( * to TE expressions, schedules them, and then to TIR. It annotates all functions * with their target. * - * \param targets The mapping for devices to targets. * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process * each function that we lower * \returns The pass which lowers primitive functions to TIR */ -transform::Pass LowerTEPass(TargetMap targets, const String& module_name, - std::function process_fn); +transform::Pass LowerTEPass(const String& module_name, std::function process_fn); } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 02caf56c66e6..9a1c428482e2 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -34,49 +34,56 @@ namespace backend { TVM_REGISTER_NODE_TYPE(StorageInfoNode); -StorageInfo::StorageInfo(std::vector storage_ids, std::vector device_types, - std::vector storage_sizes_in_bytes) { - auto n = make_object(); - n->storage_ids = std::move(storage_ids); - n->device_types = std::move(device_types); - n->storage_sizes_in_bytes = std::move(storage_sizes_in_bytes); - data_ = std::move(n); -} - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { const auto* node = ref.as(); - p->stream << "StorageInfoNode(\n" - << " storage_ids=["; + p->stream << "StorageInfoNode(" + << "storage_ids=["; for (auto id : node->storage_ids) { - p->stream << id << ", "; + p->stream << id << ","; } - p->stream << "],\n device_types=["; - for (auto device_type : node->device_types) { - p->stream << device_type << ", "; + p->stream << "], se_scopes=["; + for (const auto& se_scope : node->se_scopes) { + p->stream << se_scope << ","; } - p->stream << "],\n storage_size_in_bytes=["; + p->stream << "], storage_size_in_bytes=["; for (auto bytes : node->storage_sizes_in_bytes) { - p->stream << bytes << ", "; + p->stream << bytes << ","; } p->stream << "])"; }); +StorageInfo::StorageInfo(std::vector storage_ids, std::vector se_scopes, + std::vector storage_sizes_in_bytes) { + ICHECK_EQ(storage_ids.size(), se_scopes.size()); + ICHECK_EQ(storage_ids.size(), storage_sizes_in_bytes.size()); + auto node = make_object(); + node->storage_ids = std::move(storage_ids); + node->se_scopes = std::move(se_scopes); + node->storage_sizes_in_bytes = std::move(storage_sizes_in_bytes); + data_ = std::move(node); +} + +// This is the legacy interface for devices as DLDeviceTypes (represented by integers) TVM_REGISTER_GLOBAL("relay.ir.StorageInfo") - .set_body_typed([](const Array& sids, const Array& dev_types, + .set_body_typed([](const Array& sids, const Array& device_types, const Array& sizes_in_bytes) { - std::vector sids_v, sizes_v; - std::vector dev_types_v; + std::vector sids_v; + sids_v.reserve(sids.size()); for (auto s : sids) { sids_v.push_back(s); } - for (auto d : dev_types) { - dev_types_v.push_back(static_cast(static_cast(d))); + std::vector se_scopes_v; + se_scopes_v.reserve(device_types.size()); + for (const auto& device_type : device_types) { + se_scopes_v.emplace_back(SEScope::ForDeviceType(device_type)); } + std::vector size_in_bytes_v; + size_in_bytes_v.reserve(sizes_in_bytes.size()); for (auto s : sizes_in_bytes) { - sizes_v.push_back(s); + size_in_bytes_v.push_back(s); } - return StorageInfo(sids_v, dev_types_v, sizes_v); + return StorageInfo(std::move(sids_v), std::move(se_scopes_v), std::move(size_in_bytes_v)); }); TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageIds").set_body_typed([](StorageInfo si) { @@ -87,10 +94,11 @@ TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageIds").set_body_typed([](StorageI return ids; }); +// This is the legacy interface for devices as DLDeviceTypes (represented by integers) TVM_REGISTER_GLOBAL("relay.ir.StorageInfoDeviceTypes").set_body_typed([](StorageInfo si) { Array device_types; - for (auto id : si->device_types) { - device_types.push_back(id); + for (const auto& se_scope : si->se_scopes) { + device_types.push_back(se_scope->device_type()); } return device_types; }); @@ -116,7 +124,8 @@ TVM_REGISTER_GLOBAL("relay.ir.StaticMemoryPlan") return StaticMemoryPlan(expr_to_storage_info); }); -// TODO(mbs): Cf GetMemorySizeBytes in aot_executor_codegen.cc +// TODO(mbs): Cf GetMemorySizeBytes in aot_executor_codegen.cc, GetMemorySize in +// graph_plan_memory.cc int64_t CalculateRelayExprSizeBytes(const Type& expr_type) { if (expr_type->IsInstance()) { auto tuple_type = Downcast(expr_type); @@ -166,7 +175,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ",\n relay_primfuncs=" << node->relay_primfuncs << ")"; }); -Array GetPassPrefix(const Map& targets, bool is_vm) { +Array GetPassPrefix(bool is_homegeneous, bool is_vm) { Array pass_seqs; Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); @@ -175,7 +184,7 @@ Array GetPassPrefix(const Map& targets, bool is pass_seqs.push_back(relay::qnn::transform::Legalize()); // Legalize pass is restricted to homogeneous execution for now. - if (targets.size() == 1) { + if (is_homegeneous) { pass_seqs.push_back(transform::Legalize()); } @@ -217,8 +226,8 @@ Array GetPassPrefix(const Map& targets, bool is pass_seqs.push_back(transform::CanonicalizeCast()); pass_seqs.push_back(transform::CanonicalizeOps()); - // Alter layout transformation is only applied to homogeneous execution yet. - if (targets.size() == 1) { + // Alter layout transformation is currently only applied to homogeneous execution. + if (is_homegeneous) { if (!is_vm) { pass_seqs.push_back(transform::InferType()); } diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 16cbe0e8dbca..4224a99c2628 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -57,15 +58,17 @@ namespace backend { using Pass = tvm::transform::Pass; /*! - * \brief The static storage information produced by memory planning. + * \brief The static storage information for each Tensor in the result of a Relay expression + * (as per relay::FlattenTupleType). */ class StorageInfoNode : public Object { public: + // TODO(mbs): Switch from struct-of-array to array-of-struct repr throughout. /*! \brief The set of storage ids where the expression is stored. */ std::vector storage_ids; - /* \brief The type of "virtual devices" these expressions are stored on. */ - std::vector device_types; - /* \brief The sizes of each storage element. */ + /* \brief The SEScopes these expressions are stored within. */ + std::vector se_scopes; + /* \brief The sizes of each storage element, in bytes. */ std::vector storage_sizes_in_bytes; // TODO(@jroesch): expose the fields @@ -78,7 +81,7 @@ class StorageInfoNode : public Object { /*! \brief The storage information for a single expression. */ class StorageInfo : public ObjectRef { public: - StorageInfo(std::vector storage_ids, std::vector device_types, + StorageInfo(std::vector storage_ids, std::vector se_scopes, std::vector storage_sizes_in_bytes); TVM_DEFINE_OBJECT_REF_METHODS(StorageInfo, ObjectRef, StorageInfoNode); }; @@ -442,11 +445,11 @@ inline bool IsMetaScheduleEnabled() { * difference. This function unifies the shared optimization pass prefix between vm and graph * runtime, and returns the pass prefix given the backend type. * - * \param targets The device type to `Target` mapping. - * \param is_vm A boolean indicating if the passes are used for vm or graph runtime. + * \param is_homogenous True if all primitives are to be executed on the same device and target. + * \param is_vm True if passes are to be used for the vm executor. * \return An array of passes. */ -Array GetPassPrefix(const TargetMap& targets, bool is_vm); +Array GetPassPrefix(bool is_homogenous, bool is_vm); /*! \brief Target hash function */ struct TargetStrHash { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 02477d05673b..b6ecd7a4b7ca 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -81,6 +81,9 @@ using namespace tvm::runtime; using namespace tvm::runtime::vm; using namespace relay::transform; +/*! \brief The host device is always stored at device index 0. */ +constexpr Index kHostDeviceIndex = 0; + // (@jroesch): VM passes, eventually declare as passes. bool IsClosure(const Function& func); @@ -93,9 +96,9 @@ using MatchValuePtr = std::shared_ptr; // A runtime object that resides in a register struct RegisterValue : MatchValue { // The register num - RegName rergister_num; + RegName register_num; - explicit RegisterValue(RegName reg) : rergister_num(reg) {} + explicit RegisterValue(RegName reg) : register_num(reg) {} ~RegisterValue() {} }; @@ -227,44 +230,17 @@ std::vector ToAllocTensorShape(NDArray shape) { return raw_shape; } -/*! - * \brief Create a default type. - * \param device_type The device type index. - * \return the default target for the device. - */ -Target CreateDefaultTarget(int device_type) { - std::string name = runtime::DeviceName(device_type); - if (name == "cpu") return Target("llvm"); - if (name == "cuda") return Target("cuda"); - return Target(name); -} - -int GetFallbackDevice() { - transform::PassContext pass_ctx = PassContext::Current(); - Optional opt_fallback_dev = - pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast(kDLCPU))); - auto fallback_dev = opt_fallback_dev.value(); - ICHECK_GT(fallback_dev->value, 0U); - return fallback_dev->value; -} - class VMFunctionCompiler : DeviceAwareExprFunctor { public: - VMFunctionCompiler(VMCompilerContext* context, TargetMap targets, Target target_host) + VMFunctionCompiler(VMCompilerContext* context, SEScope host_se_scope) : DeviceAwareExprFunctor(context->module), last_register_(0), registers_num_(0), context_(context), - target_host_(target_host) { - CheckAndUpdateHostConsistency(&targets, &target_host); - for (const auto& it : targets) { - targets_[it.first->value] = it.second; - } - target_host_ = target_host; - } + host_se_scope_(std::move(host_se_scope)) {} VMFunction Compile(const GlobalVar& var, const Function& func) { - std::vector params_device_type; + std::vector param_device_indexes; if (IsClosure(func)) { // After lifting we'll have functions of the form: // fn(closure args) { fn(lifted function args) { body } } @@ -273,16 +249,21 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { // Do that flattening on-the-fly here. Function inner_func = Downcast(func->body); std::vector params; - std::vector param_device_types; + std::vector param_se_scopes; params.reserve(func->params.size() + inner_func->params.size()); - param_device_types.reserve(func->params.size() + inner_func->params.size()); + param_se_scopes.reserve(func->params.size() + inner_func->params.size()); + param_device_indexes.reserve(func->params.size() + inner_func->params.size()); for (size_t i = 0; i < func->params.size(); ++i) { params.emplace_back(func->params[i]); - params_device_type.push_back(GetFunctionParamDeviceType(func.get(), i)); + SEScope param_se_scope = GetFunctionParamSEScope(func.get(), i); + param_se_scopes.push_back(param_se_scope); + param_device_indexes.push_back(GetDeviceIndex(param_se_scope)); } for (size_t i = 0; i < inner_func->params.size(); ++i) { params.emplace_back(inner_func->params[i]); - params_device_type.push_back(GetFunctionParamDeviceType(inner_func.get(), i)); + SEScope param_se_scope = GetFunctionParamSEScope(inner_func.get(), i); + param_se_scopes.push_back(param_se_scope); + param_device_indexes.push_back(GetDeviceIndex(param_se_scope)); } std::vector type_params; type_params.reserve(func->type_params.size() + inner_func->type_params.size()); @@ -294,22 +275,17 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { } Function flattened_func = Function(params, inner_func->body, inner_func->ret_type, type_params, func->attrs, func->span); - VisitExpr(MaybeFunctionOnDevice(flattened_func, params_device_type, - GetFunctionResultDeviceType(inner_func.get()))); + VisitExpr(MaybeFunctionOnDevice(flattened_func, param_se_scopes, + GetFunctionResultSEScope(inner_func.get()))); } else { - params_device_type.reserve(func->params.size()); + param_device_indexes.reserve(func->params.size()); for (size_t i = 0; i < func->params.size(); ++i) { - params_device_type.push_back(GetFunctionParamDeviceType(func.get(), i)); + param_device_indexes.push_back(GetDeviceIndex(GetFunctionParamSEScope(func.get(), i))); } VisitExpr(func); } - std::vector params_device_type_index; - params_device_type_index.reserve(params_device_type.size()); - for (auto device_type : params_device_type) { - params_device_type_index.push_back(static_cast(device_type)); - } return VMFunction(var->name_hint, params_, instructions_, registers_num_, - params_device_type_index); + std::move(param_device_indexes)); } /*! \brief Attrs objects for each op. */ @@ -352,6 +328,48 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { instructions_.push_back(instr); } + /*! + * \brief Returns the "device index" to represent \p se_scope for primitives + * in emitted code. Note that the host device is always at index 0. + */ + Index GetDeviceIndex(const SEScope& se_scope) { + VLOG(2) << "getting device index for " << se_scope; + auto itr = std::find(context_->se_scopes_.begin(), context_->se_scopes_.end(), se_scope); + if (itr != context_->se_scopes_.end()) { + VLOG(2) << "reusing existing scope"; + return std::distance(context_->se_scopes_.begin(), itr); + } + + ICHECK_GT(context_->se_scopes_.size(), 0); + ICHECK_NE(se_scope, host_se_scope_); + + if (se_scope->device_type() == context_->se_scopes_.front()->device_type()) { + // It's ok if we see distinct scopes which share the host device type. This is because + // we allow the SEScope for the host to be different from the SEScope for primitive + // operations which happen to be, eg, on the CPU. + return 0; + } + + // However, otherwise we allow at most one SEScope per device type. + // TODO(mbs): This will eventually need to account for memory scopes somehow so device_copy + // instructions can do the right thing. + itr = std::find_if(context_->se_scopes_.begin() + 1, context_->se_scopes_.end(), + [&se_scope](const SEScope& existing_se_scope) { + return existing_se_scope->device_type() == se_scope->device_type(); + }); + CHECK(itr == context_->se_scopes_.end()) + << "The VM does not currently support using more than one device with the same device type " + "for primitives, however the program is using the distinct scopes " + << se_scope << " and " << *itr << " of device type " << se_scope->device_type(); + + ICHECK(se_scope != host_se_scope_); + Index index = context_->se_scopes_.size(); + VLOG(2) << "adding new scope"; + context_->se_scopes_.push_back(se_scope); + + return index; + } + using DeviceAwareExprFunctor::VisitExpr_; void VisitExpr_(const ConstantNode* const_node) final { @@ -359,7 +377,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { NDArray data = const_node->data; size_t konst_idx = context_->constants.size(); auto con = GetRef(const_node); - context_->const_device_type.push_back(GetInScopeDeviceType(con)); + context_->const_device_indexes.push_back(GetDeviceIndex(GetSEScope(con))); context_->constants.push_back(const_node->data); Emit(Instruction::LoadConst(konst_idx, NewRegister())); } @@ -463,7 +481,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { void EmitShapeFunc(Function func, Array inputs, Array outputs) { // Lower shape function - tec::CCacheKey key(func, target_host_); + tec::CCacheKey key(func, host_se_scope_->target); auto cfunc = context_->compiler->LowerShapeFunc(key); int op_index = -1; // pick the only function inside the context @@ -498,7 +516,8 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { argument_registers)); } - void EmitInvokeTVMOp(const Function& func, const Expr& inputs, const Expr& outputs) { + void EmitInvokeTVMOp(const Function& func, const Expr& inputs, const Expr& outputs, + SEScope se_scope) { std::vector argument_registers; ICHECK(func->HasNonzeroAttr(attr::kPrimitive)) @@ -531,13 +550,9 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { if (func->GetAttr(attr::kCompiler).defined()) { target = Target("ext_dev"); } else { - int dev_type = GetInScopeDeviceType(func); - if (targets_.count(dev_type) == 0) { - target = CreateDefaultTarget(dev_type); - } else { - target = targets_[dev_type]; - } + target = se_scope->target; } + ICHECK(target.defined()) << "No target for function:" << std::endl << PrettyPrint(func); tec::CCacheKey key(func, target); auto mangle_fn = [](String name) { return name; }; @@ -577,9 +592,11 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { OpMatch matcher; matcher .Match("vm.invoke_tvm_op", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + [this, call_node](const Array& args, const Attrs& attrs, + const Array& type_arg) { ICHECK_EQ(args.size(), 3); - EmitInvokeTVMOp(Downcast(args[0]), args[1], args[2]); + EmitInvokeTVMOp(Downcast(args[0]), args[1], args[2], + GetSEScope(GetRef(call_node))); }) .Match("memory.alloc_tensor", [this](const Array& args, const Attrs& attrs, const Array& type_arg) { @@ -638,7 +655,8 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { auto dtype = alloc_attrs->dtype; Emit(Instruction::AllocStorage(size_register, alignment, dtype, - alloc_attrs->device_type, NewRegister())); + GetDeviceIndex(alloc_attrs->se_scope), + NewRegister())); }) .Match("vm.shape_func", [this](const Array& args, const Attrs& attrs, const Array& type_arg) { @@ -670,17 +688,17 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { Emit(Instruction::ReshapeTensor(tensor_reg, shape_reg, NewRegister())); }) .Match("device_copy", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + [this, call_node](const Array& args, const Attrs& attrs, + const Array& type_arg) { ICHECK_EQ(args.size(), 1U); this->VisitExpr(args[0]); auto src_reg = last_register_; auto device_copy_attrs = attrs.as(); ICHECK(device_copy_attrs != nullptr) << "Must be the device copy attrs"; - Index src_device_type = device_copy_attrs->src_dev_type; - Index dst_device_type = device_copy_attrs->dst_dev_type; - Emit(Instruction::DeviceCopy(src_reg, src_device_type, dst_device_type, - NewRegister())); + Emit(Instruction::DeviceCopy( + src_reg, GetDeviceIndex(device_copy_attrs->src_se_scope), + GetDeviceIndex(device_copy_attrs->dst_se_scope), NewRegister())); }) .Match("memory.kill", [](const Array& args, const Attrs& attrs, const Array& type_arg) { @@ -780,7 +798,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { RegName CompileMatchValue(MatchValuePtr val) { if (std::dynamic_pointer_cast(val)) { auto r = std::dynamic_pointer_cast(val); - return r->rergister_num; + return r->register_num; } else { auto path = std::dynamic_pointer_cast(val); auto p = CompileMatchValue(path->parent); @@ -857,18 +875,15 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { size_t registers_num_; /*! \brief Global shared meta data */ VMCompilerContext* context_; - /*! \brief Target devices. */ - std::unordered_map targets_; - /*! \brief Host target. */ - Target target_host_; + /*! \brief SEScope for data and computation which must reside on a CPU. */ + SEScope host_se_scope_; }; PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "lower") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.num_args, 3); - IRModule mod = args[0]; - this->Lower(mod, args[1], args[2]); + this->Lower(args[0], args[1], args[2]); }); } else if (name == "codegen") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -908,15 +923,16 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } -void VMCompiler::Lower(IRModule mod, const tvm::TargetMap& targets, - const tvm::Target& target_host) { +void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host) { exec_ = make_object(); - targets_ = targets; - target_host_ = target_host; - CheckAndUpdateHostConsistency(&targets_, &target_host_); + config_ = CompilationConfig(PassContext::Current(), std::move(targets), std::move(target_host)); + + // The first device is always for the host. + CHECK(context_.se_scopes_.empty()); + context_.se_scopes_.push_back(config_->host_se_scope); // Run the optimizations necessary to target the VM. - context_.module = OptimizeModule(mod, targets_, target_host_); + context_.module = OptimizeModuleImpl(std::move(mod)); // Populate the global map. // @@ -932,7 +948,7 @@ void VMCompiler::Lower(IRModule mod, const tvm::TargetMap& targets, auto gvar = named_func.first; if (auto* n = named_func.second.as()) { auto func = GetRef(n); - VMFunctionCompiler func_compiler(&context_, targets_, target_host_); + VMFunctionCompiler func_compiler(&context_, config_->host_se_scope); auto vm_func = func_compiler.Compile(gvar, func); size_t func_index = context_.global_map.at(gvar); @@ -946,17 +962,27 @@ void VMCompiler::Lower(IRModule mod, const tvm::TargetMap& targets, } } + // Populate virtual devices and the host device index. + for (const auto& se_scope : context_.se_scopes_) { + ICHECK(!se_scope->IsFullyUnconstrained()); + ICHECK_GT(se_scope->device_type(), 0); + // TODO(mbs): We forget the memory scope. + exec_->virtual_devices.push_back( + Device{/*device_type=*/se_scope->device_type(), /*device_id=*/se_scope->virtual_device_id}); + } + exec_->host_device_index = kHostDeviceIndex; + // populate constants - for (auto data : context_.constants) { + for (const auto& data : context_.constants) { exec_->constants.push_back(data); } - for (auto i : context_.const_device_type) { - exec_->const_device_type.push_back(i); + for (auto index : context_.const_device_indexes) { + exec_->const_device_indexes.push_back(index); } // update global function map - for (auto gv : context_.global_map) { + for (const auto& gv : context_.global_map) { exec_->global_map.insert({gv.first->name_hint, gv.second}); } @@ -966,22 +992,21 @@ void VMCompiler::Lower(IRModule mod, const tvm::TargetMap& targets, exec_->primitive_map.insert({cfunc->prim_fn_var->name_hint, primitive_index++}); } -#if USE_RELAY_DEBUG - for (const auto& vm_func : exec_->functions) { - VLOG(1) << vm_func << "-------------"; - } -#endif // USE_RELAY_DEBUG + VLOG(1) << std::endl + << "-------------------------------------------------" << std::endl + << exec_->GetVirtualDevices() << exec_->GetConstants() << exec_->GetBytecode() + << "-------------------------------------------------"; backend::UpdateAutoSchedulerOpWeights(context_.compiler); } -transform::Sequential MemoryOpt(tvm::Target host_target, tvm::TargetMap targets) { +transform::Sequential MemoryOpt(const SEScope& cpu_se_scope) { Array pass_seqs; // Remove unused functions Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); // Manifest the allocations. - pass_seqs.push_back(transform::ManifestAlloc(host_target, targets)); + pass_seqs.push_back(transform::ManifestAlloc(cpu_se_scope)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); @@ -990,7 +1015,7 @@ transform::Sequential MemoryOpt(tvm::Target host_target, tvm::TargetMap targets) pass_seqs.push_back(transform::FuseOps()); // Manifest the allocations needed for the shape functions. - pass_seqs.push_back(transform::ManifestAlloc(host_target, targets)); + pass_seqs.push_back(transform::ManifestAlloc(cpu_se_scope)); // Fuse the shape functions. pass_seqs.push_back(transform::FuseOps()); @@ -1008,7 +1033,7 @@ transform::Sequential MemoryOpt(tvm::Target host_target, tvm::TargetMap targets) pass_seqs.push_back(transform::FuseOps()); // Create allocations for math introduced by dynamic region math. - pass_seqs.push_back(transform::ManifestAlloc(host_target, targets)); + pass_seqs.push_back(transform::ManifestAlloc(cpu_se_scope)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); @@ -1018,15 +1043,22 @@ transform::Sequential MemoryOpt(tvm::Target host_target, tvm::TargetMap targets) // instructions need to access to constant // pass_seqs.push_back(transform::LiftConstants()); - return transform::Sequential(pass_seqs); + return transform::Sequential(std::move(pass_seqs)); +} + +IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets, + const Target& target_host) { + config_ = CompilationConfig(PassContext::Current(), targets, target_host); + // The first device always corresponds to the host. + CHECK(context_.se_scopes_.empty()); + context_.se_scopes_.push_back(config_->host_se_scope); + // TODO(mbs): exec_ is not allocated. What is the API here? + CHECK(exec_ == nullptr); + return OptimizeModuleImpl(std::move(mod)); } -IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets_arg, - const Target& target_host_arg) { +IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { VLOG_CONTEXT << "VMCompiler::OptimizeModule"; - TargetMap targets = targets_arg; - Target target_host = target_host_arg; - CheckAndUpdateHostConsistency(&targets, &target_host); if (params_.size()) { BaseFunc base_func = mod->Lookup("main"); ICHECK(base_func->IsInstance()) @@ -1036,29 +1068,24 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets_arg, mod->Add(gvar, f); } - Array pass_seqs = relay::backend::GetPassPrefix(targets, true); + Array pass_seqs = relay::backend::GetPassPrefix( + /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/true); - // TODO(mbs): Reconcile with relay/backend/build_module.cc - DLDeviceType default_device_type; - if (targets_arg.size() == 1) { - default_device_type = - static_cast(static_cast((*targets_arg.begin()).first->value)); - } else { - default_device_type = static_cast(GetFallbackDevice()); - } - pass_seqs.push_back(PlanDevices(default_device_type)); + // Always plan devices so the remaining passes don't need to distinguish homogeneous vs + // hetrogeneous execution. + pass_seqs.push_back(transform::PlanDevices(config_)); pass_seqs.push_back(transform::FuseOps()); // Do layout rewrite for auto-scheduler. transform::PassContext pass_ctx = PassContext::Current(); - if (backend::IsAutoSchedulerEnabled() && targets.size() == 1) { - const auto& target = (*targets.begin()).second; + if (backend::IsAutoSchedulerEnabled() && config_->optional_homogeneous_target.defined()) { Pass major_pass = transform::AutoSchedulerLayoutRewrite(); bool enable_layout_rewrite_targets = - target->kind->device_type == kDLCPU || target->GetAttr("device", "") == "mali"; + config_->optional_homogeneous_target->kind->device_type == kDLCPU || + config_->optional_homogeneous_target->GetAttr("device", "") == "mali"; if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) { - With tctx(target); + With tctx(config_->optional_homogeneous_target); pass_seqs.push_back(major_pass); // Defuse ops to fold constants, then fuse them again pass_seqs.push_back(transform::DefuseOps()); @@ -1079,18 +1106,18 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets_arg, // external codegen. pass_seqs.push_back(transform::Inline()); - pass_seqs.push_back(MemoryOpt(target_host, targets)); + pass_seqs.push_back(MemoryOpt(config_->host_se_scope)); pass_seqs.push_back(transform::InferType()); pass_seqs.push_back(transform::LabelOps()); transform::Sequential seq(pass_seqs); tvm::With ctx(pass_ctx); - if (targets.size() == 1) { - const auto& it = targets.begin(); - With tctx((*it).second); - return seq(mod); + if (config_->optional_homogeneous_target.defined()) { + With tctx(config_->optional_homogeneous_target); + return seq(std::move(mod)); + } else { + return seq(std::move(mod)); } - return seq(mod); } void VMCompiler::PopulateGlobalMap() { @@ -1137,13 +1164,14 @@ void VMCompiler::Codegen() { runtime::Module lib; if (funcs.size() > 0) { - lib = tvm::build(funcs, target_host_); + lib = tvm::build(funcs, config_->host_target); } else { // There is no function handled by TVM. We create a virtual main module // to make sure a DSO module will be also available. lib = codegen::CSourceModuleCreate(";", "", Array{}); } - lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_, runtime::Metadata()); + lib = codegen::CreateMetadataModule(params_, lib, ext_mods, config_->host_target, + runtime::Metadata()); exec_->SetLib(lib); } diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 5b51d7821d78..2edec70d5c3b 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -78,17 +78,21 @@ struct VMCompilerContext { tec::TECompiler compiler; // List of constants std::vector constants; - // Device type for constants - std::vector const_device_type; + // Device indexes for constants + std::vector const_device_indexes; // List of cached functions std::vector cached_funcs; // The functions that have been lowered. std::unordered_map seen_funcs; + // The SEScopes corresponding to each device index. The first device always corresponds + // to the host device, and all remaining devices are for the primitive operations. + std::vector se_scopes_; }; class VMCompiler : public runtime::ModuleNode { public: - virtual ~VMCompiler() {} + VMCompiler() = default; + virtual ~VMCompiler() = default; virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); @@ -110,7 +114,7 @@ class VMCompiler : public runtime::ModuleNode { * to target mapping. For homogeneous compilation, it is a singleton build target. * \param target_host Host compilation target, if target is device. */ - void Lower(IRModule mod, const TargetMap& targets, const tvm::Target& target_host); + void Lower(IRModule mod, TargetMap targets, Target target_host); /*! \brief Generate the machine code for lowered functions. */ void Codegen(); @@ -128,6 +132,8 @@ class VMCompiler : public runtime::ModuleNode { */ IRModule OptimizeModule(IRModule mod, const TargetMap& targets, const Target& target_host); + IRModule OptimizeModuleImpl(IRModule mod); + /*! * \brief Populate the global function names in a map where the value is used * as the index by the VMFunctions. @@ -135,10 +141,8 @@ class VMCompiler : public runtime::ModuleNode { void PopulateGlobalMap(); protected: - /*! \brief Target devices. */ - TargetMap targets_; - /*! \brief Target host device. */ - tvm::Target target_host_; + /*! \brief Targets and scopes needed for compilation. */ + CompilationConfig config_; /*! \brief Global shared meta data */ VMCompilerContext context_; /*! \brief Compiled executable. */ diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index d9a2b8b91fa3..ffd0e466eb24 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -112,7 +112,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { auto free_type_vars = FreeTypeVars(func, module_); Array captured_vars; - std::vector captured_var_device_types; + std::vector captured_var_se_scopes; bool recursive = false; for (const auto& var : free_vars) { if (!letrec_.empty() && var == letrec_.back()) { @@ -120,7 +120,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { continue; } captured_vars.push_back(var); - captured_var_device_types.push_back(GetInScopeDeviceType(var)); + captured_var_se_scopes.push_back(GetSEScope(var)); } // Freshen all the captured vars. @@ -132,7 +132,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { rebinding_map.Set(free_var, var); } - DLDeviceType result_device_type = GetInScopeDeviceType(func_node->body); + SEScope result_se_scope = GetSEScope(func_node->body); if (recursive) { if (!captured_vars.empty()) { @@ -195,8 +195,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { lifted_func = Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(), free_type_vars, /*attrs=*/{}, func->span); - lifted_func = - MaybeFunctionOnDevice(lifted_func, captured_var_device_types, result_device_type); + lifted_func = MaybeFunctionOnDevice(lifted_func, captured_var_se_scopes, result_se_scope); lifted_func = MarkClosure(lifted_func); } diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 9a2297a75962..e9441f1b3e58 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -32,6 +32,7 @@ #include #include "../op/annotation/annotation.h" +#include "../op/memory/on_device.h" namespace tvm { namespace relay { @@ -529,11 +530,11 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { if (const FunctionNode* func = expr.as()) { Expr new_body = ExprBinder(args_map).VisitExpr(func->body); Array new_params; - std::vector new_param_device_types; + std::vector new_param_se_scopes; for (size_t i = 0; i < func->params.size(); ++i) { if (!args_map.count(func->params[i])) { new_params.push_back(func->params[i]); - new_param_device_types.push_back(GetFunctionParamDeviceType(func, i)); + new_param_se_scopes.push_back(GetFunctionParamSEScope(func, i)); } } if (new_body.same_as(func->body) && new_params.size() == func->params.size()) { @@ -541,7 +542,7 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { } auto ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); - ret = MaybeFunctionOnDevice(ret, new_param_device_types, GetFunctionResultDeviceType(func)); + ret = MaybeFunctionOnDevice(ret, new_param_se_scopes, GetFunctionResultSEScope(func)); std::unordered_set set; for (const auto& v : FreeVars(expr)) { set.insert(v); @@ -549,19 +550,19 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { for (const auto& v : FreeVars(ret)) { if (set.count(v) == 0) { new_params.push_back(v); - if (GetFunctionResultDeviceType(func) != kInvalidDeviceType) { + if (!GetFunctionResultSEScope(func)->IsFullyUnconstrained()) { // TODO(mbs): The function has been annotated with a device, which means we are supposed // to be preserving device annotations on every transformation. However there's no // such context for the free vars in args_map. LOG(WARNING) << "introduced free var '" << PrettyPrint(v) << "' into function body but no device is known for it"; } - new_param_device_types.push_back(kInvalidDeviceType); + new_param_se_scopes.push_back(SEScope::FullyUnconstrained()); } } ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); - ret = MaybeFunctionOnDevice(ret, new_param_device_types, GetFunctionResultDeviceType(func)); + ret = MaybeFunctionOnDevice(ret, new_param_se_scopes, GetFunctionResultSEScope(func)); ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); return std::move(ret); } else { diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index 27b61333c9eb..bd3162dfde86 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -38,158 +38,6 @@ namespace tvm { namespace relay { -TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); - -const Op& OnDeviceOp() { - static const Op& op = Op::Get("on_device"); - return op; -} - -Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { - auto attrs = make_object(); - attrs->device_type = device_type; - attrs->is_fixed = is_fixed; - Span span = expr->span; - return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, span); -} - -Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { - if (device_type == kInvalidDeviceType) { - // Undefined signals no annotation is required. - return expr; - } - if (expr->IsInstance() || expr->IsInstance()) { - // These operators are device polymorphic so no annotation is required. - // TODO(mbs): The device planning pass does NOT currently support device polymorphism for - // constructors, so we could remove them from this condition. However most constructors - // accept type parameters, and it is not well-formed Relay to simply wrap such a - // constructor in an "on_device" call. So we'll pretend they are device polymorphic to - // avoid that difficultly. Overall ADTs need more work to be fully supported. - return expr; - } - if (expr->IsInstance() || expr->IsInstance()) { - // The device can be recovered from the binding site of the global or local variable. - return expr; - } - if (expr->IsInstance()) { - // If a primitive function then it is device polymorphic. Otherwise the device is captured - // by the function's attributes. - return expr; - } - OnDeviceProps props = GetOnDeviceProps(expr); - if (props.body.defined()) { - // Don't nest on_devices. - // If the inner and outer device types differ then we need to be careful: - // - If the inner on_device is_fixed then it disagrees with the outer. - // - If the outer on_device is_fixed then it implies a hidden device_copy - // Otherwise just use the inner device type and ignore the outer. - ICHECK(props.device_type == device_type || (!is_fixed && !props.is_fixed)); - return OnDevice(props.body, device_type, is_fixed || props.is_fixed); - } - return OnDevice(expr, device_type, is_fixed); -} - -TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device") - .set_body_typed([](Expr expr, int device_type, bool is_fixed) { - return OnDevice(expr, static_cast(device_type), is_fixed); - }); - -RELAY_REGISTER_OP("on_device") - .describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE) - .set_num_inputs(1) - .add_argument("data", "Tensor", "The input data.") - .set_support_level(10) - .add_type_rel("Identity", IdentityRel) - .set_attrs_type_key("relay.attrs.OnDeviceAttrs") - .set_attr("TOpPattern", kOpaque) - .set_attr("TOpIsStateful", false) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("TNonComputational", true); - -OnDeviceProps GetOnDeviceProps(const CallNode* call_node) { - if (call_node->op == OnDeviceOp()) { - ICHECK_EQ(call_node->args.size(), 1) << "on_device expects one argument"; - ICHECK(call_node->attrs.defined()) << "on_device requires attributes"; - const auto* on_device_attrs = call_node->attrs.as(); - ICHECK(on_device_attrs != nullptr) << "on_device requires OnDeviceAttrs"; - auto device_type = static_cast(on_device_attrs->device_type); - // Follow nesting: - // on_device(on_device(expr, device_type=1), device_type=2) == {expr, 1} - auto inner = GetOnDeviceProps(call_node->args[0]); - if (inner.body.defined()) { - return {inner.body, inner.device_type, on_device_attrs->is_fixed || inner.is_fixed}; - } else { - return {call_node->args[0], device_type, on_device_attrs->is_fixed}; - } - } - return {}; -} - -OnDeviceProps GetOnDeviceProps(const Expr& expr) { - if (const auto* call_node = expr.as()) { - return GetOnDeviceProps(call_node); - } - return {}; -} - -Function FunctionOnDevice(Function function, Array param_device_types, - Integer result_device_type) { - return WithAttrs(std::move(function), {{tvm::attr::kParamDeviceTypes, param_device_types}, - {tvm::attr::kResultDeviceType, result_device_type}}); -} - -Function FunctionOnDevice(Function function, const std::vector& param_device_types, - DLDeviceType result_device_type) { - Array arr; - arr.reserve(param_device_types.size()); - for (const auto device_type : param_device_types) { - arr.push_back(static_cast(device_type)); - } - return FunctionOnDevice(std::move(function), std::move(arr), - static_cast(result_device_type)); -} - -Function MaybeFunctionOnDevice(Function function, - const std::vector& param_device_types, - DLDeviceType result_device_type) { - if (std::all_of(param_device_types.begin(), param_device_types.end(), - [](DLDeviceType type) { return type == kInvalidDeviceType; }) && - result_device_type == kInvalidDeviceType) { - return function; - } - return FunctionOnDevice(function, param_device_types, result_device_type); -} - -TVM_REGISTER_GLOBAL("relay.op.annotation._make.function_on_device") - .set_body_typed([](Function function, Array param_device_types, - int result_device_type) { - return FunctionOnDevice(function, param_device_types, - static_cast(result_device_type)); - }); - -DLDeviceType GetFunctionResultDeviceType(const FunctionNode* function_node) { - auto opt_integer = function_node->GetAttr(tvm::attr::kResultDeviceType); - if (!opt_integer) { - // No annotation. - return kInvalidDeviceType; - } - return static_cast(opt_integer.value()->value); -} - -DLDeviceType GetFunctionParamDeviceType(const FunctionNode* function_node, size_t i) { - ICHECK_LT(i, function_node->params.size()) - << "param index " << i << " out of range for function of arity " - << function_node->params.size(); - auto opt_array = function_node->GetAttr>(tvm::attr::kParamDeviceTypes); - if (!opt_array) { - // No annotation. - return kInvalidDeviceType; - } - ICHECK_EQ(opt_array.value().size(), function_node->params.size()) - << "annotation parameters do not match function arity"; - return static_cast(opt_array.value()[i]->value); -} - Expr StopFusion(Expr data) { static const Op& op = Op::Get("annotation.stop_fusion"); return Call(op, {data}, Attrs{}, {}); diff --git a/src/relay/op/annotation/annotation.h b/src/relay/op/annotation/annotation.h index d772df9b023a..1675b7281ebb 100644 --- a/src/relay/op/annotation/annotation.h +++ b/src/relay/op/annotation/annotation.h @@ -34,112 +34,6 @@ namespace tvm { namespace relay { -/*! \brief Returns the "on_device" operator. */ -const Op& OnDeviceOp(); - -/*! - * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed. - * - * See \p OnDeviceAttrs for an overview. - */ -Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); - -/*! - * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed if the - * device for \p expr cannot otherwise be recovered by the lexical scoping convention. This means - * we will NOT wrap if: - * - \p device_type is \p kInvalidDeviceType, which signals there are no device annotations - * already in play. - * - \p expr is an operator or primitive function literal. These are device polymorphic. - * - \p expr is a non-primitive function literal. The device is captured by the - * "result_device_type" attribute on the function itself. - * - \p expr is a global var. The device is on the function attributes the global is bound to. - * - \p expr is a local var. The device is tracked by the device aware visitors for us. - * - \p expr is a constructor. These should eventually be device polymorphic but are currently - * in an in-between state at the moment. - */ -Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); - -/*! \brief Result of \p GetOnDeviceProps. */ -struct OnDeviceProps { - Expr body; // = null - DLDeviceType device_type = kInvalidDeviceType; - bool is_fixed = false; - - OnDeviceProps() = default; - - OnDeviceProps(const Expr& body, DLDeviceType deviceType, bool isFixed) - : body(body), device_type(deviceType), is_fixed(isFixed) {} -}; - -/*! - * \brief Returns the body expression, device type and is_fixed field for \p call_node if it is - * an "on_device" CallNode. Otherwise returns the null expression, \p kInvalidDeviceType and \p - * false. - */ -OnDeviceProps GetOnDeviceProps(const CallNode* call_node); - -/*! - * \brief Returns the body expression, device type and is_fixed field for \p expr if it is an - * "on_device" CallNode. Otherwise returns the null expression, \p kInvalidDeviceType and \p false. - */ -OnDeviceProps GetOnDeviceProps(const Expr& expr); - -/*! - * \brief Returns the body of \p expr if it is an "on_device" annotation, otherwise returns - * \p expr directly. - */ -inline Expr IgnoreOnDevice(const Expr& expr) { - OnDeviceProps props = GetOnDeviceProps(expr); - return props.body.defined() ? props.body : expr; -} - -/*! - * \brief Returns \p expr as \p NodeType, or null if it is not of that type. Looks through - * any "on_device" annotations. - */ -template -const NodeType* AsIgnoringOnDevice(const Expr& expr) { - const auto* node = expr.as(); - if (node != nullptr) { - return node; - } - OnDeviceProps props = GetOnDeviceProps(expr); - if (!props.body.defined()) { - return nullptr; - } - return props.body.as(); -} - -/*! - * \brief Returns \p function annotated with "param_device_types" and "result_device_type" - * attributes capturing parameter and result devices types respectively. - */ -Function FunctionOnDevice(Function function, Array param_device_types, - Integer body_device_type); -Function FunctionOnDevice(Function function, const std::vector& param_device_types, - DLDeviceType body_device_type); - -/*! - * \brief As for \p FunctionOnDevice, but returns \p function unchanged if all parameters and - * result device types are \p kInvalidDeviceType. - */ -Function MaybeFunctionOnDevice(Function function, - const std::vector& param_device_types, - DLDeviceType result_device_type); - -/*! - * \brief Returns the device type for the resut of \p function_node, or \p kInvalidDeviceType - * if function does not have "result_device_type" annotation. - */ -DLDeviceType GetFunctionResultDeviceType(const FunctionNode* function_node); - -/*! - * \brief Returns the device type for the \p i'th parameter of \p function_node, or - * \p kInvalidDeviceType if function does not have "param_device_types" annotation. - */ -DLDeviceType GetFunctionParamDeviceType(const FunctionNode* function_node, size_t i); - /*! \brief Wraps \p data in a "stop_fusion" annotation. */ Expr StopFusion(Expr data); diff --git a/src/relay/op/memory/device_copy.cc b/src/relay/op/memory/device_copy.cc index 9106b95c9217..538264ce9688 100644 --- a/src/relay/op/memory/device_copy.cc +++ b/src/relay/op/memory/device_copy.cc @@ -31,6 +31,8 @@ #include #include +#include + #include "../../transforms/infer_layout_utils.h" #include "../annotation/annotation.h" #include "../call/call.h" @@ -47,29 +49,27 @@ const Op& DeviceCopyOp() { return op; } -Expr DeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type) { +Expr DeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope) { + ICHECK(!src_se_scope->IsFullyUnconstrained()); + ICHECK(!dst_se_scope->IsFullyUnconstrained()); auto attrs = make_object(); - attrs->src_dev_type = src_dev_type; - attrs->dst_dev_type = dst_dev_type; + attrs->src_se_scope = std::move(src_se_scope); + attrs->dst_se_scope = std::move(dst_se_scope); Span span = expr->span; - return Call(DeviceCopyOp(), {std::move(expr)}, Attrs(attrs), /*type_args=*/{}, span); + return Call(DeviceCopyOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, + std::move(span)); } -Expr OptDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type) { - if (src_dev_type == dst_dev_type) { +TVM_REGISTER_GLOBAL("relay.op._make.DeviceCopy").set_body_typed(DeviceCopy); + +Expr MaybeDeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope) { + if (src_se_scope == dst_se_scope) { + // No copy needed. return expr; } - ICHECK_NE(src_dev_type, kInvalidDeviceType); - ICHECK_NE(dst_dev_type, kInvalidDeviceType); - return DeviceCopy(expr, src_dev_type, dst_dev_type); + return DeviceCopy(std::move(expr), std::move(src_se_scope), std::move(dst_se_scope)); } -TVM_REGISTER_GLOBAL("relay.op._make.device_copy") - .set_body_typed([](Expr expr, int src_dev_type, int dst_dev_type) { - return DeviceCopy(expr, static_cast(src_dev_type), - static_cast(dst_dev_type)); - }); - RELAY_REGISTER_OP("device_copy") .describe(R"code( Copy data from one tensor to another. The source and destination might be @@ -96,16 +96,14 @@ DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) { ICHECK(call_node->attrs.defined()) << "device_copy requires attributes"; const auto* device_copy_attrs = call_node->attrs.as(); ICHECK(device_copy_attrs != nullptr) << "device_copy requires DeviceCopyAttrs"; - auto src_dev_type = static_cast(device_copy_attrs->src_dev_type); - auto dst_dev_type = static_cast(device_copy_attrs->dst_dev_type); // Follow nesting: - // device_copy(device_copy(expr, src_dev_type=1, dst_dev_type=2), - // src_dev_type=2, dst_dev_type=3) ==> {expr, 1, 3} + // device_copy(device_copy(expr, src_se_scope=S, dst_se_scope=T), + // src_se_scope=T, dst_se_scope=U) ==> {expr, S, U} auto inner = GetDeviceCopyProps(call_node->args[0]); if (inner.body.defined()) { - return {inner.body, inner.src_dev_type, inner.dst_dev_type}; + return {inner.body, inner.src_se_scope, device_copy_attrs->dst_se_scope}; } else { - return {call_node->args[0], src_dev_type, dst_dev_type}; + return {call_node->args[0], device_copy_attrs->src_se_scope, device_copy_attrs->dst_se_scope}; } } else if (call_node->op == CallLoweredOp()) { /* Get device props for a TIR function */ diff --git a/src/relay/op/memory/device_copy.h b/src/relay/op/memory/device_copy.h index d21fdb6abe19..3b40f410e53b 100644 --- a/src/relay/op/memory/device_copy.h +++ b/src/relay/op/memory/device_copy.h @@ -28,6 +28,8 @@ #include #include +#include + namespace tvm { namespace relay { @@ -35,41 +37,43 @@ namespace relay { const Op& DeviceCopyOp(); /*! - * \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated on - * a device of type \p src_dev_type but then copied to a device of type \p dst_dev_type. + * \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated and + * stored at \p src_se_scope but then copied to \p dst_se_scope. */ -Expr DeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type); +Expr DeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope); /*! - * \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated on - * a device of type \p src_dev_type but then copied to a device of type \p dst_dev_type. - * However, return \p expr directly if \p src_dev_type equals \p dst_dev_type. + * \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated and + * stored at \p src_se_scope but then copied to \p dst_se_scope.However, return \p expr + * directly if \p src_se_scope and \p dst_se_scope are (structurally) the same. */ -Expr MaybeDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type); +Expr MaybeDeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope); /*! \brief Result of \p GetDeviceCopyProps. */ struct DeviceCopyProps { Expr body; // = null - DLDeviceType src_dev_type = kInvalidDeviceType; - DLDeviceType dst_dev_type = kInvalidDeviceType; + SEScope src_se_scope = SEScope::FullyUnconstrained(); + SEScope dst_se_scope = SEScope::FullyUnconstrained(); DeviceCopyProps() = default; - DeviceCopyProps(const Expr& body, DLDeviceType srcDevType, DLDeviceType dstDevType) - : body(body), src_dev_type(srcDevType), dst_dev_type(dstDevType) {} + DeviceCopyProps(Expr body, SEScope src_se_scope, SEScope dst_se_scope) + : body(std::move(body)), + src_se_scope(std::move(src_se_scope)), + dst_se_scope(std::move(dst_se_scope)) {} }; /*! - * \brief Returns the body expression, source, and destination device types for \p call_node if it - * is a "device_copy" CallNode. Otherwise returns the null expression and \p kInvalidDeviceType - * device types. + * \brief Returns the body expression, source, and destination \p SEScopes for \p call_node + * if it is a "device_copy" CallNode. Otherwise returns the null expression and unconstrained + * device and scopes. */ DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node); /*! - * \brief Returns the body expression, source, and destination device types for \p expr if it - * is a "device_copy" CallNode. Otherwise returns the null expression and \p kInvalidDeviceType - * device types. + * \brief Returns the body expression, source, and destination \p SEScopes for \p expr if it + * is a "device_copy" Call. Otherwise returns the null expression and unconstrained device and + * scopes. */ DeviceCopyProps GetDeviceCopyProps(const Expr& expr); diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 08e92b31965e..0574fd50f4b6 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -32,12 +32,14 @@ #include #include +#include #include #include "../../transforms/infer_layout_utils.h" #include "../annotation/annotation.h" #include "../op_common.h" #include "../type_relations.h" +#include "on_device.h" namespace tvm { namespace relay { @@ -48,13 +50,12 @@ TVM_REGISTER_NODE_TYPE(AllocTensorAttrs); // The passing value in attrs and args doesn't seem super great. // We should consider a better solution, i.e the type relation // being able to see the arguments as well? -Expr AllocStorage(Expr size, Expr alignment, Device dev, DataType dtype_hint) { +Expr AllocStorage(Expr size, Expr alignment, SEScope se_scope, DataType dtype_hint) { auto attrs = make_object(); attrs->dtype = dtype_hint; - attrs->device_id = dev.device_id; - attrs->device_type = dev.device_type; + attrs->se_scope = std::move(se_scope); static const Op& op = Op::Get("memory.alloc_storage"); - return Call(op, {size, alignment}, Attrs(attrs), {}); + return Call(op, {std::move(size), std::move(alignment)}, Attrs(std::move(attrs)), {}); } TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage").set_body_typed(AllocStorage); diff --git a/src/relay/op/memory/memory.h b/src/relay/op/memory/memory.h index 558c409782f5..618044a9f2ca 100644 --- a/src/relay/op/memory/memory.h +++ b/src/relay/op/memory/memory.h @@ -25,6 +25,8 @@ #ifndef TVM_RELAY_OP_MEMORY_MEMORY_H_ #define TVM_RELAY_OP_MEMORY_MEMORY_H_ +#include + #include #include "tvm/relay/expr.h" @@ -32,7 +34,7 @@ namespace tvm { namespace relay { -Expr AllocStorage(Expr size, Expr alignment, Device dev, DataType dtype_hint); +Expr AllocStorage(Expr size, Expr alignment, SEScope se_scope, DataType dtype_hint); Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype, Array assert_shape); Expr ToTupleType(const Type& ty, const std::vector& exprs); diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc new file mode 100644 index 000000000000..9541d4122a2f --- /dev/null +++ b/src/relay/op/memory/on_device.cc @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file src/relay/op/memory/on_device.cc + * \brief Helpers for working with the "on_device" 'annotation' call. + */ + +#include "./on_device.h" + +#include +#include +#include +#include +#include +#include + +#include "../../transforms/infer_layout_utils.h" +#include "../type_relations.h" + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); + +const Op& OnDeviceOp() { + static const Op& op = Op::Get("on_device"); + return op; +} + +Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed) { + ICHECK(!se_scope->IsFullyUnconstrained()); + auto attrs = make_object(); + attrs->se_scope = std::move(se_scope); + attrs->is_fixed = is_fixed; + Span span = expr->span; + return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, + std::move(span)); +} + +TVM_REGISTER_GLOBAL("relay.op.annotation._make.OnDevice").set_body_typed(OnDevice); + +Expr MaybeOnDevice(Expr expr, SEScope se_scope, bool is_fixed) { + if (se_scope->IsFullyUnconstrained()) { + // Nothing to annotate with. + return expr; + } + if (expr->IsInstance() || expr->IsInstance()) { + // These operators are device polymorphic so no annotation is required. + return expr; + } + if (expr->IsInstance() || expr->IsInstance()) { + // The device can be recovered from the binding site of the global or local variable. + return expr; + } + if (expr->IsInstance()) { + // If a primitive function then it is device polymorphic. Otherwise the device is captured + // by the function's "result_se_scope" attribute. + return expr; + } + OnDeviceProps props = GetOnDeviceProps(expr); + if (props.body.defined()) { + // Don't nest on_devices. + // If the inner and outer device types differ then we need to be careful: + // - If the inner on_device is_fixed then it disagrees with the outer. + // - If the outer on_device is_fixed then it implies a hidden device_copy + // Otherwise just use the inner device type and ignore the outer. + ICHECK(props.se_scope == se_scope || (!is_fixed && !props.is_fixed)); + return OnDevice(props.body, se_scope, is_fixed || props.is_fixed); + } + return OnDevice(expr, std::move(se_scope), is_fixed); +} + +RELAY_REGISTER_OP("on_device") + .describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input data.") + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attrs_type_key("relay.attrs.OnDeviceAttrs") + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("TNonComputational", true); + +OnDeviceProps GetOnDeviceProps(const CallNode* call_node) { + if (call_node->op == OnDeviceOp()) { + ICHECK_EQ(call_node->args.size(), 1) << "on_device expects one argument"; + ICHECK(call_node->attrs.defined()) << "on_device requires attributes"; + const auto* on_device_attrs = call_node->attrs.as(); + ICHECK(on_device_attrs != nullptr) << "on_device requires OnDeviceAttrs"; + // Follow nesting: + // on_device(on_device(expr, se_scope=S), se_scope=T) == {expr, S} + auto inner = GetOnDeviceProps(call_node->args[0]); + if (inner.body.defined()) { + return {inner.body, inner.se_scope, on_device_attrs->is_fixed || inner.is_fixed}; + } else { + return {call_node->args[0], on_device_attrs->se_scope, on_device_attrs->is_fixed}; + } + } + return {}; +} + +OnDeviceProps GetOnDeviceProps(const Expr& expr) { + if (const auto* call_node = expr.as()) { + return GetOnDeviceProps(call_node); + } + return {}; +} + +Function FunctionOnDevice(Function function, Array param_se_scopes, + SEScope result_se_scope) { + return WithAttrs(std::move(function), {{tvm::attr::kParamSEScopes, std::move(param_se_scopes)}, + {tvm::attr::kResultSEScope, std::move(result_se_scope)}}); +} + +TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice); + +Function MaybeFunctionOnDevice(Function function, Array param_se_scopes, + SEScope result_se_scope) { + if (std::all_of(param_se_scopes.begin(), param_se_scopes.end(), + [](const SEScope& se_scope) { return se_scope->IsFullyUnconstrained(); }) && + result_se_scope->IsFullyUnconstrained()) { + // Nothing to annotate. + return function; + } + return FunctionOnDevice(function, std::move(param_se_scopes), std::move(result_se_scope)); +} + +SEScope GetFunctionResultSEScope(const FunctionNode* function_node) { + auto opt_se_scope = function_node->GetAttr(tvm::attr::kResultSEScope); + return opt_se_scope.value_or(SEScope::FullyUnconstrained()); +} + +SEScope GetFunctionParamSEScope(const FunctionNode* function_node, size_t i) { + ICHECK_LT(i, function_node->params.size()) + << "param index " << i << " out of range for function of arity " + << function_node->params.size(); + auto opt_array = function_node->GetAttr>(tvm::attr::kParamSEScopes); + if (!opt_array) { + // No annotation. + return SEScope::FullyUnconstrained(); + } + ICHECK_EQ(opt_array.value().size(), function_node->params.size()) + << "annotation parameters do not match function arity"; + return opt_array.value()[i]; +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/memory/on_device.h b/src/relay/op/memory/on_device.h new file mode 100644 index 000000000000..a7b6cb7cf52a --- /dev/null +++ b/src/relay/op/memory/on_device.h @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/op/memory/on_device.h + * \brief Helpers for working with the "on_device" 'annotation' call. + */ +#ifndef TVM_RELAY_OP_MEMORY_ON_DEVICE_H_ +#define TVM_RELAY_OP_MEMORY_ON_DEVICE_H_ + +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief Returns the "on_device" operator. */ +const Op& OnDeviceOp(); + +/*! + * \brief Wraps \p expr in an "on_device" CallNode for \p se_scope and \p is_fixed. + * + * See \p OnDeviceAttrs for an overview. + */ +Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed); + +/*! + * \brief Wraps \p expr in an "on_device" CallNode for \p se_scope and \p is_fixed if the + * \p SEScope for \p expr cannot otherwise be recovered by the lexical scoping convention. + * This means we will NOT wrap if: + * - \p se_scope is full unconstrained, which signals there are no device annotations + * already in play. + * - \p expr is an operator or primitive function literal. These are device polymorphic. + * - \p expr is a non-primitive function literal. The device is captured by the + * "result_se_scope" attribute on the function itself. + * - \p expr is a global var. The device is on the function attributes the global is bound to. + * - \p expr is a local var. The device is tracked by the device aware visitors for us. + * - \p expr is a constructor. These are device polymorphic. + * + */ +Expr MaybeOnDevice(Expr expr, SEScope se_scope, bool is_fixed); + +/*! \brief Result of \p GetOnDeviceProps. */ +struct OnDeviceProps { + Expr body; // = null + SEScope se_scope = SEScope::FullyUnconstrained(); + bool is_fixed = false; + + OnDeviceProps() = default; + + OnDeviceProps(Expr body, SEScope se_scope, bool isFixed) + : body(std::move(body)), se_scope(std::move(se_scope)), is_fixed(isFixed) {} +}; + +/*! + * \brief Returns the body expression, \p SEScope, and is_fixed field for \p call_node if it + * is an "on_device" CallNode. Otherwise returns the null expression, the unconstrained + * \p SEScope, and false. + */ +OnDeviceProps GetOnDeviceProps(const CallNode* call_node); + +/*! + * \brief Returns the body expression, \p SEScope, and is_fixed field for \p expr if it is an + * "on_device" CallNode. Otherwise returns the null expression, the unconstrained \p SEScope, + * and \p false. + */ +OnDeviceProps GetOnDeviceProps(const Expr& expr); + +/*! + * \brief Returns the body of \p expr if it is an "on_device" annotation, otherwise returns + * \p expr directly. + */ +inline Expr IgnoreOnDevice(const Expr& expr) { + OnDeviceProps props = GetOnDeviceProps(expr); + return props.body.defined() ? props.body : expr; +} + +/*! + * \brief Returns \p expr as \p NodeType, or null if it is not of that type. Looks through + * any "on_device" annotations. + */ +template +const NodeType* AsIgnoringOnDevice(const Expr& expr) { + const auto* node = expr.as(); + if (node != nullptr) { + return node; + } + OnDeviceProps props = GetOnDeviceProps(expr); + if (!props.body.defined()) { + return nullptr; + } + return props.body.as(); +} + +/*! + * \brief Returns \p function annotated with "param_se_scopes" and "result_se_scope" + * attributes capturing parameter and result \p SEScopes respectively. + */ +Function FunctionOnDevice(Function function, Array param_se_scopes, SEScope body_se_scope); + +/*! + * \brief As for \p FunctionOnDevice, but returns \p function unchanged if all parameters and + * result \p SEScopes are unconstrained. + */ +Function MaybeFunctionOnDevice(Function function, Array param_se_scopes, + SEScope result_se_scope); + +/*! + * \brief Returns the \p SEScope for the resut of \p function_node, or the unconstrained + * \p SEScope if function does not have the "result_se_scope" annotation. + */ +SEScope GetFunctionResultSEScope(const FunctionNode* function_node); + +/*! + * \brief Returns the \p SEScope for the \p i'th parameter of \p function_node, or + * the unconstrained \p SEScope if function does not have the "param_se_scopes" annotation. + */ +SEScope GetFunctionParamSEScope(const FunctionNode* function_node, size_t i); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_OP_MEMORY_ON_DEVICE_H_ diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc index 38c3305d3194..e3d5a821c58e 100644 --- a/src/relay/transforms/device_aware_visitors.cc +++ b/src/relay/transforms/device_aware_visitors.cc @@ -38,41 +38,51 @@ LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional& maybe_mod) if (maybe_mod) { for (const auto& pair : maybe_mod.value()->functions) { if (const auto* function_node = pair.second.as()) { - DLDeviceType device_type = GetFunctionResultDeviceType(function_node); - if (device_type != kInvalidDeviceType) { - global_var_device_types_.emplace(pair.first, device_type); + SEScope se_scope = GetFunctionResultSEScope(function_node); + if (!se_scope->IsFullyUnconstrained()) { + global_var_se_scopes_.emplace(pair.first, se_scope); } } } } } -DLDeviceType LexicalOnDeviceMixin::GetInScopeDeviceType(const Expr& expr) const { - auto props = GetOnDeviceProps(expr); +SEScope LexicalOnDeviceMixin::GetSEScope(const Expr& expr) const { + OnDeviceProps props = GetOnDeviceProps(expr); if (props.body.defined() && props.is_fixed) { - return props.device_type; + return props.se_scope; } else if (const auto* var_node = expr.as()) { // Lookup variable binding. - auto itr = var_device_types_.find(GetRef(var_node)); - if (itr != var_device_types_.end()) { + auto itr = var_se_scopes_.find(GetRef(var_node)); + if (itr != var_se_scopes_.end()) { return itr->second; } - // else: fallthrough to unknown + // else: fallthrough to unconstrained } else if (const auto* global_var_node = expr.as()) { // Lookup global variable. - auto itr = global_var_device_types_.find(GetRef(global_var_node)); - if (itr != global_var_device_types_.end()) { + auto itr = global_var_se_scopes_.find(GetRef(global_var_node)); + if (itr != global_var_se_scopes_.end()) { return itr->second; } - // else: fallthrough to unknown + // else: fallthrough to unconstrained + } else if (const auto* function_node = expr.as()) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + if (!expr_se_scopes_.empty()) { + // Use the currently in-scope device type. + return expr_se_scopes_.back(); + } + // else: fallthrough to unconstrained + } else { + return GetFunctionResultSEScope(function_node); + } } else { - if (!expr_device_types_.empty()) { + if (!expr_se_scopes_.empty()) { // Use the currently in-scope device type. - return expr_device_types_.back(); + return expr_se_scopes_.back(); } - // else: fallthrough to unknown + // else: fallthrough to unconstrained } - return kInvalidDeviceType; + return SEScope::FullyUnconstrained(); } void LexicalOnDeviceMixin::EnterFunctionBody() { ++function_nesting_; } @@ -82,34 +92,34 @@ void LexicalOnDeviceMixin::ExitFunctionBody() { --function_nesting_; } -void LexicalOnDeviceMixin::PushDeviceType(DLDeviceType device_type) { - if (device_type == kInvalidDeviceType) { +void LexicalOnDeviceMixin::PushSEScope(const SEScope& se_scope) { + if (se_scope->IsFullyUnconstrained()) { return; } - expr_device_types_.emplace_back(device_type); + expr_se_scopes_.emplace_back(se_scope); } -void LexicalOnDeviceMixin::PopDeviceType() { - if (expr_device_types_.empty()) { +void LexicalOnDeviceMixin::PopSEScope() { + if (expr_se_scopes_.empty()) { return; } - expr_device_types_.pop_back(); + expr_se_scopes_.pop_back(); } -void LexicalOnDeviceMixin::PushBoundVar(Var var, DLDeviceType device_type) { - if (device_type == kInvalidDeviceType) { +void LexicalOnDeviceMixin::PushBoundVar(Var var, const SEScope& se_scope) { + if (se_scope->IsFullyUnconstrained()) { return; } - ICHECK(var_device_types_.find(var) == var_device_types_.end()); - var_device_types_.emplace(std::move(var), device_type); + ICHECK(var_se_scopes_.find(var) == var_se_scopes_.end()); + var_se_scopes_.emplace(std::move(var), se_scope); } void LexicalOnDeviceMixin::PopBoundVar(const Var& var) { - auto itr = var_device_types_.find(var); - if (itr == var_device_types_.end()) { + auto itr = var_se_scopes_.find(var); + if (itr == var_se_scopes_.end()) { return; } - var_device_types_.erase(itr); + var_se_scopes_.erase(itr); } // TODO(mbs): We'd probably have less tedious code duplication if we redefined the memoizing @@ -122,17 +132,17 @@ void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) { } else { // Function parameters come into scope. for (size_t i = 0; i < function_node->params.size(); ++i) { - PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i)); } // Entering scope of function body. - PushDeviceType(GetFunctionResultDeviceType(function_node)); + PushSEScope(GetFunctionResultSEScope(function_node)); EnterFunctionBody(); DeviceAwareVisitExpr_(function_node); // Leaving scope of function body. ExitFunctionBody(); - PopDeviceType(); + PopSEScope(); // Function parameters go out of scope. for (size_t i = 0; i < function_node->params.size(); ++i) { PopBoundVar(function_node->params[i]); @@ -147,7 +157,7 @@ void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) { while (const auto* inner_let_node = expr.as()) { // Let-bound var (in pre visited version) goes into scope. // (We'll just assume this is a letrec). - PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + PushBoundVar(inner_let_node->var, GetSEScope(inner_let_node->value)); PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); bindings.emplace_back(inner_let_node); expr = inner_let_node->body; @@ -164,13 +174,13 @@ void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) { } void DeviceAwareExprVisitor::VisitExpr_(const CallNode* call_node) { - auto props = GetOnDeviceProps(call_node); + OnDeviceProps props = GetOnDeviceProps(call_node); if (props.body.defined() && props.is_fixed) { // Entering lexical scope of fixed "on_device" call. - PushDeviceType(props.device_type); + PushSEScope(props.se_scope); VisitExpr(props.body); // Leaving lexical scope of "on_device" call. - PopDeviceType(); + PopSEScope(); } else { DeviceAwareVisitExpr_(call_node); } @@ -208,17 +218,17 @@ Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) { } else { // Function parameters come into scope. for (size_t i = 0; i < function_node->params.size(); ++i) { - PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i)); } // Entering scope of function body. - PushDeviceType(GetFunctionResultDeviceType(function_node)); + PushSEScope(GetFunctionResultSEScope(function_node)); EnterFunctionBody(); Expr result = DeviceAwareVisitExpr_(function_node); // Leaving scope of function body. ExitFunctionBody(); - PopDeviceType(); + PopSEScope(); // Function parameters go out of scope. for (size_t i = 0; i < function_node->params.size(); ++i) { PopBoundVar(function_node->params[i]); @@ -235,7 +245,7 @@ Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) { while (const auto* inner_let_node = expr.as()) { // Let-bound var (in pre visited version) goes into scope. // (We'll just assume this is a letrec.) - PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + PushBoundVar(inner_let_node->var, GetSEScope(inner_let_node->value)); std::pair pair = PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); bindings.emplace_back(pair.first, pair.second, inner_let_node->span, inner_let_node); expr = inner_let_node->body; @@ -255,14 +265,14 @@ Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) { } Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) { - auto props = GetOnDeviceProps(call_node); + OnDeviceProps props = GetOnDeviceProps(call_node); if (props.body.defined() && props.is_fixed) { // Entering lexical scope of fixed "on_device" call. - PushDeviceType(props.device_type); + PushSEScope(props.se_scope); Expr expr = VisitExpr(props.body); // Leaving lexical scope of "on_device" call. - PopDeviceType(); - return MaybeOnDevice(expr, props.device_type, props.is_fixed); + PopSEScope(); + return MaybeOnDevice(expr, props.se_scope, props.is_fixed); } else { return DeviceAwareVisitExpr_(call_node); } diff --git a/src/relay/transforms/device_aware_visitors.h b/src/relay/transforms/device_aware_visitors.h index 3f4c5c24481e..8cdf0db74ebd 100644 --- a/src/relay/transforms/device_aware_visitors.h +++ b/src/relay/transforms/device_aware_visitors.h @@ -35,13 +35,14 @@ #include #include "../op/annotation/annotation.h" +#include "../op/memory/on_device.h" namespace tvm { namespace relay { namespace transform { /*! - * \brief Helper class for expression transformers which need to keep track of the device + * \brief Helper class for expression transformers which need to keep track of the \p SEScope * holding the results of expressions. This is recovered from function attributes and "on_device" * CallNodes added by the PlanDevices pass. * @@ -52,11 +53,11 @@ class LexicalOnDeviceMixin { explicit LexicalOnDeviceMixin(const Optional& maybe_mod); /*! - * \brief Returns the device type on which the result of \p expr should/will be stored, assuming - * Push/Pop DeviceType/BoundVar have been correctly called. May return \p kInvalidDeviceType if - * the device planning pass has not been run. + * \brief Returns the \p SEScope on which the result of \p expr should/will be stored, assuming + * {Push,Pop}{SEScope,BoundVar} have been correctly called. May return the unconstrained + * \p SEScope if the device planning pass has not been run. */ - DLDeviceType GetInScopeDeviceType(const Expr& expr) const; + SEScope GetSEScope(const Expr& expr) const; /*! \brief Indicate a function body is being entered. */ void EnterFunctionBody(); @@ -64,19 +65,19 @@ class LexicalOnDeviceMixin { /*! \brief Indicate a function body has been processed. */ void ExitFunctionBody(); - /*! \brief Push a device type onto the lexical device stack. Ignore if \p kInvalidDeviceType. */ - void PushDeviceType(DLDeviceType device_type); + /*! \brief Push an \p SEScope onto the lexical SEScope stack. Ignore if unconstrained. */ + void PushSEScope(const SEScope& se_scope); - /*! \brief Pop a device type from the lexical device stack. Ignore if stack is empty. */ - void PopDeviceType(); + /*! \brief Pop an \p SEScope from the lexical SEScope stack. Ignore if stack is empty. */ + void PopSEScope(); - /*! \brief Remember that \p var will be stored on \p device_type. Ignore if \p kInvalidDeviceType. + /*! \brief Remember that \p var will be stored at \p se_scope. Ignore if unconstrained. * * CAUTION: Despite the name we don't support re-entering the same function body. */ - void PushBoundVar(Var var, DLDeviceType device_type); + void PushBoundVar(Var var, const SEScope& se_scope); - /*! \brief Remove the binding for \p var to it's device type. Ignore if var is not bound. */ + /*! \brief Remove the binding for \p var to its \p SEScope. Ignore if var is not bound. */ void PopBoundVar(const Var& var); /*! @@ -92,36 +93,36 @@ class LexicalOnDeviceMixin { int function_nesting_ = 0; /*! - * \brief The stack of lexically enclosing "on_device" devices types, from outermost to innermost. - * When visiting an expression other than a variable we can assume the expression's result is to - * be stored on device_type_.back(). + * \brief The stack of lexically enclosing "on_device" \p SEScopes, from outermost to + * innermost. When visiting an expression other than a variable we can assume the expression's + * result is to be stored on \p expr_se_scopes.back(). */ - std::vector expr_device_types_; + std::vector expr_se_scopes_; /*! - * \brief A map from in-scope local variables to their device types. We may assume the variable is - * only ever bound to a value stored on this device at runtime. + * \brief A map from in-scope local variables to their \p SEScopes. We may assume the variable is + * only ever bound to a value stored on this \p SEScope at runtime. * * Note: We're playing it safe and keying by object refs here just in case the Relay expression * being rewritten has no module or other global to keep it alive. */ - std::unordered_map - var_device_types_; + std::unordered_map var_se_scopes_; /*! - * \brief A map from global variables to their device types, ie the "result_device_type" of the - * function they are bound to in the module we are working on. We calculate this explicitly so - * that we don't neeed to hold on to any module, which is often in the process of being rewritten. + * \brief A map from global variables to their \p SEScopes, ie the "result_se_scope" of the + * function they are bound to in the module we are working on. We calculate and store this + * explicitly so that we don't need to hold on to any module, which is often in the process of + * being rewritten. */ - std::unordered_map - global_var_device_types_; + std::unordered_map + global_var_se_scopes_; }; template class DeviceAwareExprFunctor; /*! - * \brief ExprFunctor which tracks devices. We only support 'visitor' style implementation + * \brief ExprFunctor which tracks \p SEScopes. We only support 'visitor' style implementation * with no additional arguments, thus this is equivalent to \p DeviceAwareExprVisitor without * any memoization. */ @@ -142,17 +143,17 @@ class DeviceAwareExprFunctor : public ExprFunctorparams.size(); ++i) { - PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i)); } // Entering scope of function body. - PushDeviceType(GetFunctionResultDeviceType(function_node)); + PushSEScope(GetFunctionResultSEScope(function_node)); EnterFunctionBody(); DeviceAwareVisitExpr_(function_node); // Leaving scope of function body. ExitFunctionBody(); - PopDeviceType(); + PopSEScope(); // Function parameters go out of scope. for (size_t i = 0; i < function_node->params.size(); ++i) { PopBoundVar(function_node->params[i]); @@ -167,7 +168,7 @@ class DeviceAwareExprFunctor : public ExprFunctor()) { // Let-bound var (in pre visited version) goes into scope. // (We'll just assume this is a letrec.) - PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + PushBoundVar(inner_let_node->var, GetSEScope(inner_let_node->value)); PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); bindings.emplace_back(inner_let_node); expr = inner_let_node->body; @@ -185,20 +186,20 @@ class DeviceAwareExprFunctor : public ExprFunctor : public ExprFunctor& maybe_mod) @@ -255,7 +256,7 @@ class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin { void VisitExpr_(const CallNode* call_node) final; /*! - * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * \brief These are as for VisitExpr_. \p SEScopes for expressions and function parameters will be * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For * functions the function_nesting count will already include that of \p function_node. */ @@ -269,7 +270,7 @@ class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin { virtual void PreVisitLetBlock_(const LetNode* let_node); /*! - * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * \brief Visit a let-bound expression before the let body has been visited. \p SEScopes for the * let-bound variable will be tracked automatically. Default implementation just visits var and * value. */ @@ -288,7 +289,7 @@ class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin { virtual void PostVisitLetBlock_(const LetNode* let_node); }; -/*! \brief ExprMutator which tracks devices. */ +/*! \brief ExprMutator which tracks \p SEScopes. */ class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin { public: explicit DeviceAwareExprMutator(const Optional& maybe_mod) @@ -299,7 +300,7 @@ class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin { Expr VisitExpr_(const CallNode* call_node) final; /*! - * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * \brief These are as for VisitExpr_. \p SEScopes for expressions and function parameters will be * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For * functions the function_nesting count will already include that of \p function_node. */ @@ -313,7 +314,7 @@ class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin { virtual void PreVisitLetBlock_(const LetNode* let_node); /*! - * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * \brief Visit a let-bound expression before the let body has been visited. \p SEScopes for the * let-bound variable will be tracked automatically. Default implementation just visits var and * value. */ diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc index b9fa0494d3b5..667379d7a9a0 100644 --- a/src/relay/transforms/device_domains.cc +++ b/src/relay/transforms/device_domains.cc @@ -30,6 +30,7 @@ #include "../op/annotation/annotation.h" #include "../op/call/call.h" #include "../op/memory/device_copy.h" +#include "../op/memory/on_device.h" namespace tvm { namespace relay { @@ -37,11 +38,6 @@ namespace transform { namespace { -// Ye olde boost hash mixer. -constexpr size_t mix(size_t h1, size_t h2) { - return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2)); -} - /*! * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather * than the original "device_copy" operator. @@ -66,63 +62,46 @@ DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) { } // namespace -// The following hash and equality helpers give each free first-order domain pointer its own -// distinct identity. - -size_t DeviceDomainHash::operator()(const DeviceDomainPtr& domain) const { - if (domain->is_free()) { - // Give each free first-order domain its own identity. - return static_cast(reinterpret_cast(domain.get())); - } else { - size_t h = domain->args_and_result_.size(); - h = mix(h, std::hash()(static_cast(domain->device_type_))); - for (const auto& sub_domain_ptr : domain->args_and_result_) { - h = mix(h, DeviceDomainHash()(sub_domain_ptr)); - } - return h; - } +DeviceDomains::DeviceDomains(CompilationConfig config) : config_(std::move(config)) { + host_domain_ = MakeFirstOrderDomain(config_->host_se_scope); } -bool DeviceDomainEqual::operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const { - if (lhs->args_and_result_.size() != rhs->args_and_result_.size()) { - // Mismatched arities are never equal. - // (Though we'll never ask to do such a comparison explicitly, the hash map - // may do so implicitly due to hash collisions.) - return false; - } - if (lhs->is_free() && rhs->is_free()) { - // Compare first-order free domains by their address. - return lhs.get() == rhs.get(); - } - if (lhs->args_and_result_.empty()) { - // Compare first-order domains by their device type -- free vs bound will compare as false. - return lhs->device_type_ == rhs->device_type_; - } else { - // Compare higher-order domains pointwise. - for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) { - if (!(*this)(lhs->args_and_result_[i], rhs->args_and_result_[i])) { - return false; - } +DeviceDomainPtr DeviceDomains::MakeFirstOrderDomain(const SEScope& se_scope) { + if (se_scope->IsFullyConstrained()) { + auto itr = fully_constrained_se_scope_to_domain_.find(se_scope); + if (itr != fully_constrained_se_scope_to_domain_.end()) { + return itr->second; } - return true; + DeviceDomainPtr domain = std::make_shared(se_scope); + fully_constrained_se_scope_to_domain_.emplace(se_scope, domain); + return domain; + } else { + return std::make_shared(se_scope); } } -/* static */ -DeviceDomainPtr DeviceDomains::MakeDomain(const Type& type, DLDeviceType device_type) { +DeviceDomainPtr DeviceDomains::MakeDomain(const Type& type, const SEScope& se_scope) { if (const auto* func_type_node = type.as()) { std::vector args_and_result; args_and_result.reserve(func_type_node->arg_types.size() + 1); for (const auto& arg_type : func_type_node->arg_types) { - args_and_result.emplace_back(MakeDomain(arg_type, kInvalidDeviceType)); + args_and_result.emplace_back(MakeDomain(arg_type, SEScope::FullyUnconstrained())); } - args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, device_type)); + args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, se_scope)); return std::make_shared(std::move(args_and_result)); } else { - return std::make_shared(device_type); + return MakeFirstOrderDomain(se_scope); } } +DeviceDomainPtr DeviceDomains::ForSEScope(const Type& type, const SEScope& non_canonical_se_scope) { + // Generally se_scope will have come from an annotation so resolve it to ensure we have + // its canonical representation. + SEScope se_scope = config_->CanonicalSEScope(non_canonical_se_scope); + ICHECK(!se_scope->IsFullyUnconstrained()); + return MakeDomain(type, se_scope); +} + DeviceDomainPtr DeviceDomains::Lookup(DeviceDomainPtr domain) { DeviceDomainPtr root = domain; while (true) { @@ -145,56 +124,82 @@ DeviceDomainPtr DeviceDomains::Lookup(DeviceDomainPtr domain) { return root; } -DeviceDomainPtr DeviceDomains::Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { - // TODO(mbs): Proper diagnostics. +DeviceDomainPtr DeviceDomains::JoinOrNull(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { + if (lhs == rhs) { + return lhs; + } ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size()) << "Device domains:" << std::endl << ToString(lhs) << std::endl << "and" << std::endl << ToString(rhs) << std::endl << "do not have the same kind and can't be unified."; - if (rhs->is_free()) { - return lhs; - } else if (lhs->is_free()) { - return rhs; - } else if (lhs->args_and_result_.empty()) { - // Must have consistent device types for first order domains. - if (lhs->device_type_ != rhs->device_type_) { - // TODO(mbs): Proper diagnostics. - std::ostringstream os; - os << "Inconsistent device types " << lhs->device_type_ << " and " << rhs->device_type_; - throw Error(os.str()); + if (lhs->args_and_result_.empty()) { + // Directly compare first-order. + if (rhs->se_scope_->IsFullyUnconstrained()) { + return lhs; } - return lhs; + if (lhs->se_scope_->IsFullyUnconstrained()) { + return rhs; + } + Optional joined_se_scope = SEScope::Join(lhs->se_scope_, rhs->se_scope_); + if (!joined_se_scope) { + return nullptr; + } + return MakeFirstOrderDomain(config_->CanonicalSEScope(joined_se_scope.value())); } else { // Recurse for higher-order. std::vector args_and_result; args_and_result.reserve(lhs->args_and_result_.size()); for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) { - args_and_result.emplace_back(Unify(lhs->args_and_result_[i], rhs->args_and_result_[i])); + DeviceDomainPtr joined_domain = + UnifyOrNull(lhs->args_and_result_[i], rhs->args_and_result_[i]); + if (joined_domain == nullptr) { + return nullptr; + } + args_and_result.emplace_back(std::move(joined_domain)); } - return MakeDomain(std::move(args_and_result)); + return MakeHigherOrderDomain(std::move(args_and_result)); } } -DeviceDomainPtr DeviceDomains::Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) { +DeviceDomainPtr DeviceDomains::UnifyOrNull(DeviceDomainPtr lhs, DeviceDomainPtr rhs) { + ICHECK_NOTNULL(lhs); + ICHECK_NOTNULL(rhs); lhs = Lookup(lhs); rhs = Lookup(rhs); - auto joined_domain = Join(lhs, rhs); - if (!DeviceDomainEqual()(lhs, joined_domain)) { + DeviceDomainPtr joined_domain = JoinOrNull(lhs, rhs); + if (joined_domain == nullptr) { + return nullptr; + } + if (lhs != joined_domain) { domain_to_equiv_.emplace(lhs, joined_domain); } - if (!DeviceDomainEqual()(rhs, joined_domain)) { + if (rhs != joined_domain) { domain_to_equiv_.emplace(rhs, joined_domain); } return joined_domain; } -void DeviceDomains::UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { - if (!lhs->is_higher_order() && rhs->is_higher_order()) { - Collapse(lhs, rhs); +bool DeviceDomains::CollapseOrFalse(const DeviceDomainPtr& first_order_domain, + const DeviceDomainPtr& higher_order_domain) { + ICHECK(!first_order_domain->is_higher_order()); + ICHECK(higher_order_domain->is_higher_order()); + for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) { + if (UnifyOrNull(higher_order_domain->function_param(i), first_order_domain) == nullptr) { + return false; + } + } + return UnifyOrNull(higher_order_domain->function_result(), first_order_domain) != nullptr; +} + +bool DeviceDomains::UnifyCollapsedOrFalse(const DeviceDomainPtr& lhs_first_order, + const DeviceDomainPtr& rhs_maybe_higher_order) { + ICHECK(!lhs_first_order->is_higher_order()); + if (rhs_maybe_higher_order->is_higher_order()) { + return CollapseOrFalse(lhs_first_order, rhs_maybe_higher_order); } else { - Unify(lhs, rhs); + return UnifyOrNull(lhs_first_order, rhs_maybe_higher_order) != nullptr; } } @@ -216,49 +221,49 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { } std::vector args_and_result; - auto on_device_props = GetOnDeviceProps(call.get()); - auto device_copy_props = GetDeviceCopyProps(call.get()); + OnDeviceProps on_device_props = GetOnDeviceProps(call.get()); + DeviceCopyProps device_copy_props = GetDeviceCopyProps(call.get()); if (!device_copy_props.body.defined()) { + // Special case for the TIR-ified version of "device_copy". device_copy_props = GetPrimitiveDeviceCopyProps(call.get()); } if (on_device_props.body.defined()) { - // on_device(expr, device_type=, is_fixed=false) + // on_device(expr, se_scope=, is_fixed=false) // on_device : fn():?x? // - // on_device(expr, device_type=, is_fixed=true) + // on_device(expr, se_scope=, is_fixed=true) // on_device: fn(): args_and_result.emplace_back( - ForDeviceType(on_device_props.body->checked_type(), on_device_props.device_type)); + ForSEScope(on_device_props.body->checked_type(), on_device_props.se_scope)); if (on_device_props.is_fixed) { args_and_result.emplace_back(args_and_result.front()); } else { args_and_result.emplace_back(Free(on_device_props.body->checked_type())); } } else if (device_copy_props.body.defined()) { - // device_copy(expr, src_dev_type=, dst_dev_type=) + // device_copy(expr, src_se_scope=, dst_se_scope=) // device_copy: fn(): args_and_result.emplace_back( - ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.src_dev_type)); + ForSEScope(device_copy_props.body->checked_type(), device_copy_props.src_se_scope)); args_and_result.emplace_back( - ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.dst_dev_type)); + ForSEScope(device_copy_props.body->checked_type(), device_copy_props.dst_se_scope)); } else if (call->op == alloc_storage_op) { ICHECK_EQ(call->args.size(), 2U); - // alloc_storage(size, alignment, device_type=) + // alloc_storage(size, alignment, se_scope=) // alloc_storage: fn(, ): const auto* attrs = call->attrs.as(); - args_and_result.emplace_back(cpu_domain_); - args_and_result.emplace_back(cpu_domain_); - args_and_result.emplace_back( - ForDeviceType(call->checked_type(), static_cast(attrs->device_type))); + args_and_result.emplace_back(host_domain_); + args_and_result.emplace_back(host_domain_); + args_and_result.emplace_back(ForSEScope(call->checked_type(), attrs->se_scope)); } else if (call->op == alloc_tensor_op) { ICHECK_EQ(call->args.size(), 3U); // alloc_tensor(storage, offset, shape) // alloc_tensor: fn(?x?, , ):?x? auto free_domain = Free(call->checked_type()); args_and_result.emplace_back(free_domain); - args_and_result.emplace_back(cpu_domain_); - args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(host_domain_); + args_and_result.emplace_back(host_domain_); args_and_result.emplace_back(free_domain); } else if (call->op == shape_func_op) { ICHECK_EQ(call->args.size(), 3U); @@ -268,15 +273,15 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { args_and_result.emplace_back(Free(call->args[0]->checked_type())); // TODO(mbs): I think this should be on the cpu only when is_input = [false], but // what do we do when we have multiple arguments with different is_input values? - args_and_result.emplace_back(cpu_domain_); - args_and_result.emplace_back(cpu_domain_); - args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(host_domain_); + args_and_result.emplace_back(host_domain_); + args_and_result.emplace_back(host_domain_); } else if (call->op == shape_of_op) { ICHECK_EQ(call->args.size(), 1U); // shape_of(tensor) // shape_of: fn(?x?): args_and_result.emplace_back(Free(call->args[0]->checked_type())); - args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(host_domain_); } else if (call->op == invoke_tvm_op) { ICHECK_EQ(call->args.size(), 3U); // invoke_tvm_op(op, inputs, outputs) @@ -293,13 +298,13 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { // reshape_tensor: fn(?x?, ):?x? auto free_domain = Free(call->checked_type()); args_and_result.emplace_back(free_domain); - args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(host_domain_); args_and_result.emplace_back(free_domain); } else if (call->op->IsInstance()) { // (arg1, ..., argn) // : fn(?x?, ..., ?x?):?x? // (all args and result must be first-order). - auto free_domain = Free(arb_); + auto free_domain = MakeFirstOrderDomain(SEScope::FullyUnconstrained()); for (size_t i = 0; i < call->args.size(); ++i) { args_and_result.emplace_back(free_domain); } @@ -315,8 +320,9 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { ICHECK_EQ(func_type_node->arg_types.size(), call->args.size()); auto result_domain = Free(func_type_node->ret_type); // first-order for (const auto& arg_type : func_type_node->arg_types) { - auto param_domain = Free(arg_type); // possibly higher-order - UnifyCollapsed(result_domain, param_domain); // collapse if required + auto param_domain = Free(arg_type); // possibly higher-order + bool success = UnifyCollapsedOrFalse(result_domain, param_domain); // collapse if required + ICHECK(success); args_and_result.emplace_back(param_domain); } args_and_result.emplace_back(result_domain); @@ -328,7 +334,7 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { // because the device planner runs before and after lowering. return DomainFor(call->op); } - auto domain = MakeDomain(std::move(args_and_result)); + auto domain = MakeHigherOrderDomain(std::move(args_and_result)); call_to_callee_domain_.emplace(call.get(), domain); return domain; } @@ -336,111 +342,104 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { void DeviceDomains::UnifyExprExact(const Expr& lhs, const Expr& rhs) { auto lhs_domain = DomainFor(lhs); auto rhs_domain = DomainFor(rhs); - try { - Unify(lhs_domain, rhs_domain); - } catch (const Error& e) { + if (UnifyOrNull(lhs_domain, rhs_domain) == nullptr) { // TODO(mbs): Proper diagnostics. - LOG(FATAL) << "Incompatible devices for expressions:" << std::endl + LOG(FATAL) << "Incompatible SEScopes for expressions:" << std::endl << PrettyPrint(lhs) << std::endl - << "with device:" << std::endl + << "with scope:" << std::endl << ToString(lhs_domain) << "and:" << std::endl << PrettyPrint(rhs) << std::endl - << "with device:" << std::endl - << ToString(rhs_domain) << std::endl - << e.what(); + << "with scope:" << std::endl + << ToString(rhs_domain); } } void DeviceDomains::UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain) { auto actual_domain = DomainFor(expr); - try { - Unify(actual_domain, expected_domain); - } catch (const Error& e) { + if (UnifyOrNull(actual_domain, expected_domain) == nullptr) { // TODO(mbs): Proper diagnostics. - LOG(FATAL) << "Incompatible devices for expression:" << std::endl + LOG(FATAL) << "Incompatible SEScopes for expression:" << std::endl << PrettyPrint(expr) << std::endl - << "with actual device:" << std::endl + << "with actual scope:" << std::endl << ToString(actual_domain) << std::endl - << "and expected device:" << std::endl - << ToString(expected_domain) << std::endl - << e.what(); + << "and expected scope:" << std::endl + << ToString(expected_domain); } } -void DeviceDomains::UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain) { - auto actual_domain = DomainFor(expr); - try { - UnifyCollapsed(actual_domain, expected_domain); - } catch (const Error& e) { +void DeviceDomains::UnifyExprCollapsed(const Expr& expr_first_order, + const DeviceDomainPtr& expected_domain_maybe_higher_order) { + auto actual_domain_first_order = DomainFor(expr_first_order); + if (!UnifyCollapsedOrFalse(actual_domain_first_order, expected_domain_maybe_higher_order)) { // TODO(mbs): Proper diagnostics. - LOG(FATAL) << "Incompatible devices for expression:" << std::endl - << PrettyPrint(expr) << std::endl - << "with actual device:" << std::endl - << ToString(actual_domain) << std::endl - << "and expected device:" << std::endl - << ToString(expected_domain) << std::endl - << e.what(); + LOG(FATAL) << "Incompatible SEScopes for expression:" << std::endl + << PrettyPrint(expr_first_order) << std::endl + << "with actual scope:" << std::endl + << ToString(actual_domain_first_order) << std::endl + << "and expected scope:" << std::endl + << ToString(expected_domain_maybe_higher_order); } } -bool DeviceDomains::AnyFree(DeviceDomainPtr domain) { +bool DeviceDomains::IsFullyConstrained(DeviceDomainPtr domain) { domain = Lookup(domain); - if (domain->is_free()) { - return true; - } - for (const auto& sub_domain : domain->args_and_result_) { - if (AnyFree(sub_domain)) { - return true; - } - } - return false; -} - -void DeviceDomains::Collapse(const DeviceDomainPtr& first_order_domain, - const DeviceDomainPtr& higher_order_domain) { - for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) { - Unify(higher_order_domain->function_param(i), first_order_domain); + if (domain->args_and_result_.empty()) { + // First-order. + return domain->se_scope_->IsFullyConstrained(); + } else { + // Higher-order. + return std::all_of( + domain->args_and_result_.begin(), domain->args_and_result_.end(), + [this](const DeviceDomainPtr& sub_domain) { return IsFullyConstrained(sub_domain); }); } - Unify(higher_order_domain->function_result(), first_order_domain); } -void DeviceDomains::SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type) { - ICHECK_NE(default_device_type, kInvalidDeviceType); +void DeviceDomains::SetDefault(DeviceDomainPtr domain, const SEScope& default_se_scope) { + ICHECK(!default_se_scope->IsFullyUnconstrained()); domain = Lookup(domain); - if (domain->is_free()) { - // Will never throw since lhs is free. - Unify(domain, std::make_shared(default_device_type)); - } else if (!domain->args_and_result_.empty()) { + if (domain->args_and_result_.empty()) { + DeviceDomainPtr defaulted_domain_ptr = + UnifyOrNull(domain, MakeFirstOrderDomain(config_->CanonicalSEScope( + SEScope::Default(domain->se_scope_, default_se_scope)))); + ICHECK_NOTNULL(defaulted_domain_ptr); + } else { for (const auto& sub_domain : domain->args_and_result_) { - SetDefault(sub_domain, default_device_type); + SetDefault(sub_domain, default_se_scope); } } } -void DeviceDomains::SetResultDefaultThenParams(const DeviceDomainPtr& domain, - DLDeviceType default_device_type) { - if (!domain->is_higher_order()) { - SetDefault(domain, default_device_type); - return; +void DeviceDomains::SetResultDefaultThenParams(const DeviceDomainPtr& domain_maybe_higher_order, + const SEScope& default_se_scope) { + if (domain_maybe_higher_order->args_and_result_.empty()) { + SetDefault(domain_maybe_higher_order, default_se_scope); + } else { + // First set default for result domain. + SetDefault(ResultDomain(domain_maybe_higher_order), default_se_scope); + // Then use current result domain as default for everything else. + SetDefault(domain_maybe_higher_order, ResultSEScope(domain_maybe_higher_order)); } - DLDeviceType result_device_type = ResultDeviceType(domain); - if (result_device_type == kInvalidDeviceType) { - // If the function result device is still free use the given default. - result_device_type = default_device_type; +} + +DeviceDomainPtr DeviceDomains::ResultDomain(DeviceDomainPtr domain) { + domain = Lookup(domain); + while (!domain->args_and_result_.empty()) { + domain = Lookup(domain->args_and_result_.back()); } - // Default any remaining free parameters to the function result device. - SetDefault(domain, result_device_type); + return domain; } std::string DeviceDomains::ToString(DeviceDomainPtr domain) { domain = Lookup(domain); std::ostringstream os; - if (domain->is_free()) { - // first-order free - os << "?" << static_cast(reinterpret_cast(domain.get())) << "?"; - } else if (domain->args_and_result_.empty()) { - // first-order bound - os << "<" << domain->device_type_ << ">"; + if (domain->args_and_result_.empty()) { + // First-order. + if (!domain->se_scope_->IsFullyConstrained()) { + os << "?" << static_cast(reinterpret_cast(domain.get())) << "?"; + } + if (!domain->se_scope_->IsFullyUnconstrained()) { + os << domain->se_scope_; + } } else { // higher-order os << "fn("; @@ -474,14 +473,6 @@ std::string DeviceDomains::ToString() { return os.str(); } -DeviceDomainPtr DeviceDomains::ResultDomain(DeviceDomainPtr domain) { - domain = Lookup(domain); - while (!domain->args_and_result_.empty()) { - domain = Lookup(domain->args_and_result_.back()); - } - return domain; -} - } // namespace transform } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/device_domains.h b/src/relay/transforms/device_domains.h index a29370a0e807..f3f31e790983 100644 --- a/src/relay/transforms/device_domains.h +++ b/src/relay/transforms/device_domains.h @@ -29,6 +29,8 @@ #include #include #include +#include +#include #include #include @@ -42,13 +44,14 @@ namespace transform { class DeviceDomain; using DeviceDomainPtr = std::shared_ptr; +class DeviceDomains; /*! * \brief Represents the domain over which we collect equality constraints. * * \code * D ::= ?x? -- first order, free - * | -- first order, bound + * | -- first order, bound to specific device and memory scope * | fn(D1, ..., Dn):Dr -- higher order * \endcode * @@ -56,44 +59,46 @@ using DeviceDomainPtr = std::shared_ptr; * a notion of the 'result domain' of a domain: * \code * result_domain(?x?) = ?x? - * result_domain() = + * result_domain() = * result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr) * \endcode */ class DeviceDomain { public: /*! - * \brief Constructs a first-order domain of \p device_type, which may be - * \p kInvalidDeviceType to indicate the domain is free. + * \brief Constructs a first-order domain for \p se_scope, which may be + * fully free (ie se_scope is unconstrained), partially free (ie se_scope has at least on + * of its target, device id or memory scopes known), or fully fixed (ie se_scope has its target, + * device id and memory scopes set). + * + * CAUTION: Use DeviceDomains::MakeFirstOrderDomain instead of this ctor. */ - explicit DeviceDomain(DLDeviceType device_type) : device_type_(device_type) {} + explicit DeviceDomain(SEScope se_scope) : se_scope_(std::move(se_scope)) {} /*! * \brief Constructs a higher-order domain, where \p args_and_result contain the * function argument and result domains in order. + * + * CAUTION: Use DeviceDomains::MakeHigherOrderDomain instead of this ctor. */ explicit DeviceDomain(std::vector args_and_result) - : device_type_(kInvalidDeviceType), args_and_result_(std::move(args_and_result)) {} + : se_scope_(SEScope::FullyUnconstrained()), args_and_result_(std::move(args_and_result)) {} - /*! \brief Returns true if domain is first-order and free. */ - bool is_free() const { return device_type_ == kInvalidDeviceType && args_and_result_.empty(); } - - /*! \brief Returns true if domain is higher-order. */ bool is_higher_order() const { return !args_and_result_.empty(); } - DLDeviceType first_order_device_type() const { - ICHECK(args_and_result_.empty()); - return device_type_; + SEScope first_order_se_scope() const { + ICHECK(args_and_result_.empty()) << "expecting domain to be first-order"; + return se_scope_; } size_t function_arity() const { - ICHECK(!args_and_result_.empty()); + ICHECK(!args_and_result_.empty()) << "expecting domain to be higher-order"; return args_and_result_.size() - 1UL; } DeviceDomainPtr function_param(size_t i) const { - ICHECK(!args_and_result_.empty()); - ICHECK_LT(i + 1, args_and_result_.size()); + ICHECK(!args_and_result_.empty()) << "expecting domain to be higher-order"; + ICHECK_LT(i + 1, args_and_result_.size()) << "parameter index is out of range"; return args_and_result_[i]; } @@ -104,11 +109,12 @@ class DeviceDomain { private: /*! - * \brief If this is a function domain then always kInvalidDevice. Otherwise will be - * kInvalidDevice if the domain is still free, or the specific concrete device if the domain is - * bound. + * \brief If this is a function domain then always fully unconstrained. Otherwise will be + * fully unconstrained (the domain is still completely free), partially constrained + * (for example, the \p target and \p device_type are constrained but the \p virtual_device_id and + * \p memory_scope are still unconstrained), or fully constrained (everything is known). */ - const DLDeviceType device_type_; + const SEScope se_scope_; /*! * \brief If this is a function domain then the sub-domains for each of the function's @@ -116,81 +122,92 @@ class DeviceDomain { */ const std::vector args_and_result_; - friend struct DeviceDomainHash; - friend struct DeviceDomainEqual; friend class DeviceDomains; }; -// The following hash and equality helpers give each free first-order domain pointer its own -// distinct identity. -struct DeviceDomainHash { - size_t operator()(const DeviceDomainPtr& domain) const; -}; - -struct DeviceDomainEqual { - public: - bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const; -}; - /*! * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation - * built up by calls to \p Unify. + * built up by calls to \p UnifyOrNull. */ class DeviceDomains { public: - DeviceDomains() = default; + explicit DeviceDomains(CompilationConfig config); + + const CompilationConfig& config() const { return config_; } /*! - * \brief Returns a domain appropriate for \p type who's result domain is bound - * to \p device_type. If \p device_type is \p kInvalidDeviceType then the entire domain - * will be free. + * \brief Returns the domain representing \p se_scope. If \p se_scope is fully constrained + * then the domain will be unique that \p se_scope. */ - static DeviceDomainPtr MakeDomain(const Type& type, DLDeviceType device_type); + DeviceDomainPtr MakeFirstOrderDomain(const SEScope& se_scope); /*! * \brief Returns a higher-order domain with \p args_and_results. */ - static DeviceDomainPtr MakeDomain(std::vector arg_and_results) { + DeviceDomainPtr MakeHigherOrderDomain(std::vector arg_and_results) { return std::make_shared(std::move(arg_and_results)); } - /*! \brief Returns a domain with the given result device type appropriate \p device_type. */ - static DeviceDomainPtr ForDeviceType(const Type& type, DLDeviceType device_type) { - ICHECK_NE(device_type, kInvalidDeviceType); - return MakeDomain(type, device_type); - } + /*! + * \brief Returns a domain appropriate for \p type who's result domain is bound to \p se_scope. + * If \p type is a function then all parameter domains will be completely free. It is valid for + * \p se_scope to be fully unconstrained. + */ + DeviceDomainPtr MakeDomain(const Type& type, const SEScope& se_scope); + + /*! + * \brief Returns a domain with the given result appropriate \p non_canonical_se_scope, + * which cannot be fully unconstrained. We first canonicalize the scope to unsure it has + * a target and is unique. + */ + DeviceDomainPtr ForSEScope(const Type& type, const SEScope& non_canonical_se_scope); /*! \brief Returns a free domain appropriate for \p type. */ - static DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, kInvalidDeviceType); } + DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, SEScope::FullyUnconstrained()); } /*! \brief Returns the domain representing the equivalence class containing \p domain. */ DeviceDomainPtr Lookup(DeviceDomainPtr domain); /*! - * \brief Returns the domain accounting for all bound devices in \p lhs and \p rhs. - * - * Throws \p Error on failure. + * \brief Returns the most constrained domain which agrees with both \p lhs and \p rhs. Returns + * null if no such domain exists, ie some first-order component of \p lhs is constrained + * differently than the corresponding component of \p rhs. */ - DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); + DeviceDomainPtr JoinOrNull(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); /*! - * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Fails if \p lhs and \p - * rhs disagree on bound device type. - * - * Throws \p Error on failure. + * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Returns null if + * \p lhs and \p rhs are not unifiable. */ // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but // given we have refs to functions I'm prepared to be surprised. - DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs); + DeviceDomainPtr UnifyOrNull(DeviceDomainPtr lhs, DeviceDomainPtr rhs); + + /* + * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain. + * This can be used to handle functions within tuples, references and ADTs since we don't + * attempt to track anything beyond 'the device' for expressions of those first-order types. + * + * Returns false if any unification fails. + */ + bool CollapseOrFalse(const DeviceDomainPtr& first_order_domain, + const DeviceDomainPtr& higher_order_domain); /*! - * \brief Unifies \p lhs and \p rhs. If \p lhs is first-order and \p rhs is higher-order, - * require all arguments and result of \p rhs to unify with \p lhs. Otherwise same as - * \p Unify. + * \brief Unifies \p lhs_first_order and \p rhs_maybe_higher_order. If \p rhs_maybe_higher_order + * is indeed higher-order, require all of its arguments and result to unify with + * \p lhs_first_order. Otherwise same as \p Unify. Returns false if unification is not possible. * - * Throws \p Error on failure. + * In an expression such as: + * \code + * (fn(...) {...}, ...).0 + * \endcode + * we need to force all the devices of the inner function to be the same as the device for the + * overall tuple since the device domain does not understand tuples. Similarly for references + * and ADTs. */ - void UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); + bool UnifyCollapsedOrFalse(const DeviceDomainPtr& lhs_first_order, + const DeviceDomainPtr& rhs_maybe_higher_order); /*! \brief Returns true if a domain is known for \p expr. */ bool contains(const Expr& expr) const { return expr_to_domain_.count(expr.get()); } @@ -204,7 +221,8 @@ class DeviceDomains { * DomainFor(call->op). * * This special handling is needed: - * - To handle the "on_device" and "device_copy" ops which constrain devices to the given devices. + * - To handle the "on_device" and "device_copy" ops which constrain devices to the given + * devices. * - To handle some special ops which constrain devices to the CPU. * - To allow the same primitive to be called on different devices at different call sites. * Since each call to the op can have a different domain we index the ops by the call expression @@ -212,11 +230,17 @@ class DeviceDomains { */ DeviceDomainPtr DomainForCallee(const Call& call); - /*! \brief Unifies the domains for expressions \p lhs and \p rhs. */ + /*! + * \brief Unifies the domains for expressions \p lhs and \p rhs. + * + * Aborts if unification fails. + */ void UnifyExprExact(const Expr& lhs, const Expr& rhs); /*! * \brief Unifies the domain for \p expr with \p expected_domain. + * + * Aborts if unification fails. */ void UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain); @@ -224,37 +248,25 @@ class DeviceDomains { * \brief Unifies the domain for \p expr with \p expected_domain. * If \p expected_domain is higher-order but \p expr is first-order, require all arguments * and the result of \p expected_domain to have the same domain as for \p expr. - */ - void UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain); - - /*! \brief Returns true if \p domain contains any free sub-domains. */ - bool AnyFree(DeviceDomainPtr domain); - - /* - * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain. - * This can be used to handle functions within tuples, references and ADTs since we don't - * attempt to track anything beyond 'the device' for expressions of those first-order types. * - * Throws \p Error on failure. + * Aborts if unification fails. */ - void Collapse(const DeviceDomainPtr& first_order_domain, - const DeviceDomainPtr& higher_order_domain); + void UnifyExprCollapsed(const Expr& expr_first_order, + const DeviceDomainPtr& expected_domain_maybe_higher_order); - /*! \brief Force all free domains in \p domain to default to \p default_device_type. */ - void SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type); + /*! \brief Returns true if \p domain is fully constrainted. */ + bool IsFullyConstrained(DeviceDomainPtr domain); + + /*! \brief Force all \p SEScopes in \p domain to default to \p default_se_scope. */ + void SetDefault(DeviceDomainPtr domain, const SEScope& default_se_scope); /*! - * \brief If \p domain is higher-order and its result domain is free, force it to - * \p default_device_type. Then force any remaining free domains to the result domain - * (freshly defaulted or original). If \p domain is first-order same as \p SetDefault. + * \brief If \p domain is higher-order default it's result domain to \p default_se_scope. + * Then force all remaining \p SEScopes to the result domain (freshly defaulted or original). + * If \p domain is first-order same as \p SetDefault. */ - void SetResultDefaultThenParams(const DeviceDomainPtr& domain, DLDeviceType default_device_type); - - /*! \brief Returns one-line description of \p domain for debugging. */ - std::string ToString(DeviceDomainPtr domain); - - /*! \brief Returns description of entire system of constraints for debugging */ - std::string ToString(); + void SetResultDefaultThenParams(const DeviceDomainPtr& domain_maybe_higher_order, + const SEScope& default_se_scope); /*! * \brief Returns the result domain for \p domain (see defn in DeviceDomain comment). @@ -262,13 +274,19 @@ class DeviceDomains { DeviceDomainPtr ResultDomain(DeviceDomainPtr domain); /*! - * \brief Returns the result (possibly free) device type for \p domain (see defn in DeviceDomain - * comment). + * \brief Returns the result \p SEScope (possibly unconstrained) for \p domain + * (see defn in DeviceDomain comment). */ - DLDeviceType ResultDeviceType(const DeviceDomainPtr& domain) { - return ResultDomain(domain)->first_order_device_type(); + SEScope ResultSEScope(const DeviceDomainPtr& domain) { + return ResultDomain(domain)->first_order_se_scope(); } + /*! \brief Returns one-line description of \p domain for debugging. */ + std::string ToString(DeviceDomainPtr domain); + + /*! \brief Returns description of entire system of constraints for debugging */ + std::string ToString(); + private: /*! \brief Intrinsics we need to handle specially. */ const Op& alloc_storage_op = Op::Get("memory.alloc_storage"); @@ -277,12 +295,14 @@ class DeviceDomains { const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op"); const Op& shape_func_op = Op::Get("vm.shape_func"); const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor"); - /*! \brief The CPU device type for special operators such as dynamic shape functions. */ - const DLDeviceType cpu_device_type_ = kDLCPU; - /*! \brief Placeholder for any first-order type. */ - Type arb_ = TupleType(); - /*! \brief The domain for first-order expressions on the CPU. */ - DeviceDomainPtr cpu_domain_ = ForDeviceType(arb_, cpu_device_type_); + + CompilationConfig config_; + + /*! + * \brief The domain for first-order expressions of non-tensor type, such as shapes and + * buffer dimensions. Generally this will be a CPU. + */ + DeviceDomainPtr host_domain_; /*! \brief Maps expressions to their domains as determined during analysis. */ std::unordered_map expr_to_domain_; @@ -293,8 +313,19 @@ class DeviceDomains { std::unordered_map call_to_callee_domain_; /*! \brief Maps device domains to their equivalent domains as determined during unification. */ - std::unordered_map - domain_to_equiv_; + std::unordered_map domain_to_equiv_; + + /*! + * \brief Maps fully constrained \p SEScopes to their corresponding domains. By sharing those + * domains we can ensure: + * + * \code + * domain0 != domain1 && domain0 fully constrained && domain1 fully constrained + * ==> domain0 and domain1 are incompatible + * \endcode + */ + std::unordered_map + fully_constrained_se_scope_to_domain_; }; } // namespace transform diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index 83429a9e616f..d6ab566a336e 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -19,22 +19,22 @@ /*! * \file src/relay/transforms/device_planner.cc - * \brief Determines a unique device to hold the result of every Relay sub-expression. + * \brief Determines a unique \p SEScope to hold the result of every Relay sub-expression. * * We say a Relay expression E is 'on device D' if the result of executing E is stored on D. - * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the - * specific target associated with D (this is recovered independently via a TargetMap), and we - * do not track the storage scope within D (this is yet to be implemented). + * We represent D by an \p SEScope, which means we can track anywhere from an arbitrary device + * of some \p DLDeviceType to a specific memory scope on a specific (virtual) \p Device who's + * code is compiled with a specific \p Target. * * Note that 'stored on device D' is almost but not quite the same as 'executes on device D', * see below. * * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes: - * - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and - * 'dst_dev_type' device type, which constrain the argument and context of the call + * - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_se_scope' and + * 'dst_se_scope' \p SEScopes, which constrain the argument and context of the call * respectively. It is ok if source and destination devices are the same, such no-op copies * will be removed after accounting for the device preference. - * - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which + * - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify an 'se_scope', which * constrains the argument of the call, but (usually, see below) leaves the context * unconstrained. These are called 'annotations' in the rest of the code, have no operational * significance by themselves, but may trigger the insertion of a new "device_copy". @@ -63,15 +63,16 @@ * ------- * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see * below) to all other Relay sub-expressions. (For idempotence we also respect any existing - * "param_device_types" and "result_device_type" function attributes we introduce below.) + * "param_se_scopes" and "result_se_scope" function attributes we introduce below.) * * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the * same device. However each call site can use a different device. In other words primitives are - * 'device polymorphic' since we compile and execute them for each required device. + * 'device polymorphic' since we compile and execute them for each required device. ADT constructors + * are similarly polymorphic. * * For most Relay expressions the device for the overall expression is the same as the device - * for it's sub-expressions. E.g. each field of a tuple must be on the same device as the tuple - * itself, the condition and arms of an \p if must all be on the same device as the overall if, + * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple + * itself, the condition and arms of an \p if must all be on the same device as the overall \p if, * and so on. * * Some special ops (or 'dialects') are handled: @@ -91,18 +92,18 @@ * * Phase 2 * ------- - * After flowing constraints we apply some defaulting heuristics (using a global default device) + * After flowing constraints we apply some defaulting heuristics (using a global default \p SEScope) * to fix the device for any as-yet unconstrained sub-expressions. * - Unconstrained function result devices default to the global default device. * - Unconstrained function parameters devices default to the device for the function result. * - Unconstrained let-bound expression devices default to the device for the overall let. - * TODO(mbs): I may have over-innovated here and we simply want to bind all free domaints to - * the global default device. Worth a design doc with motivating examples I think. + * TODO(mbs): These are very simple minded heuristics, and ultimately we'd like to treat the + * assignment of the remaining unconstrained sub-expressions as an optimiziation problem in itself. * * Phase 3 * ------- * Finally, the result of this analysis is reified into the result as: - * - Additional "param_device_types" (an Array) and "result_device_type" (Integer) + * - Additional "param_se_scopes" (an \p Array) and "result_se_scope" (an \p SEScope) * attributes for every function (both top-level and local). These describe the devices for * the function's parameters and the result. * - Additional "device_copy" CallNodes where a copy is required in order to respect the @@ -124,14 +125,15 @@ * passes must preserve the lexical scoping of the "on_device" CallNodes. E.g. conversion * to ANF must respect the lexical scoping convention: * \code - * f(on_device(g(h(a, b), c), device_type=CPU)) + * f(on_device(g(h(a, b), c), se_scope=CPU)) * ==> - * let %x0 = on_device(h(a, b), device_type=CPU) - * let %x1 = on_device(g(%x0), device-type=CPU) - * f(on_device(%x1, device_type=CPU)) + * let %x0 = on_device(h(a, b), se_scope=CPU) + * let %x1 = on_device(g(%x0), se_scope=CPU) + * f(on_device(%x1, se_scope=CPU)) * \endcode * * This pass can be run before FuseOps it can use device-specific fusion rules. + * TODO(mbs): We also need to support running after FuseOps. * * 'Stored on' vs 'Executes on' * ---------------------------- @@ -147,7 +149,7 @@ * pass, but we'd like to fold that into device planning here to ensure everything is consistent. * * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay - * expression (eg an if expression) on one device even though the tensor data resides on + * expression (eg an \p if expression) on one device even though the tensor data resides on * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on' * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just * compile the function body for the function's result device. @@ -157,7 +159,7 @@ * minimize cross-device calls by moving device copies out of functions. E.g.: * \code * def @f() { // execute on CPU - * let x = on_device(...GPU computation..., device_type=GPU); + * let x = on_device(...GPU computation..., se_scope=GPU); * device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU) * } * def @main() { @@ -189,7 +191,7 @@ * \code * let f = fn(x, y) { ... } * let g = fn(f, z) { f(z, z) } - * g(f, on_device(..., device_type=CPU)) + * g(f, on_device(..., se_scope=CPU)) * \endcode * the parameters \p x and \p y will be on the CPU. * @@ -226,28 +228,16 @@ * `-- Mark's stamp of completeness :-) * * TODO(mbs): - * * Though on_device is the identity for all types we can't wrap it around functions/constructors - * taking type args (or at least not without changing type_infer.cc to see through them). - * This is not currently handled generally. * * Proper diagnostics for unification failure using spans. - * * Make sure the pass is idempotent even after FuseOps etc. - * * Support application of constructors properly. Are they device polymorphic? - * * Replace DLDeviceType with TargetDevice, and unify 'target annotation' with 'device planning'. * * Support running the pass post FuseOps (so need to understand primitive functions, both * outlines and lined) and post the VM transforms (probably need to support more intrinsic * forms?). * * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default * device for primitives vs the default device for the rest of Relay. - * * We'll probably need some support for partial 'device polymorphism' for functions once we - * incorporate targets and memory scopes into the domain. For example it's ok for the function - * body to be executed on different device ids provided they have the same target and memory - * scope. - * * Might be simpler to just let every type have a device annotation rather than work in - * a separate domain? + * * We may want some 'device polymorphism' for Relay functions. Eg it's ok for the function + * to be called with params/result on different (virtual) device ids provided the target and + * memory scopes are consistent. * * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies. - * * The original device_annotation.cc RewriteAnnotatedOps removed all "on_device" calls - * in tuples at the top level of function bodies or main expression, irrespective of the - * "on_device" body. What's up with that? */ #include @@ -267,6 +257,7 @@ #include "../op/annotation/annotation.h" #include "../op/memory/device_copy.h" +#include "../op/memory/on_device.h" #include "./device_domains.h" namespace tvm { @@ -283,11 +274,11 @@ namespace { * \brief Rewrites "on_device" calls to handle some special cases. * * \code - * let %x = on_device(e, device_type=d) - * ==> let %x = on_device(e, device_type=d, is_fixed=True) + * let %x = on_device(e, se_scope=d) + * ==> let %x = on_device(e, se_scope=d, is_fixed=True) * - * fn(%x) { on_device(e, device_type=d) } - * ==> fn(%x) { on_device(e, device_type=d, is_fixed=True) + * fn(%x) { on_device(e, se_scope=d) } + * ==> fn(%x) { on_device(e, se_scope=d, is_fixed=True) * * on_device(e).0 * ==> on_device(e.0) @@ -303,12 +294,12 @@ class RewriteOnDevices : public ExprMutator { // TODO(mbs): Avoid copy. Expr tuple_get_item = TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span); - auto props = GetOnDeviceProps(tuple); + OnDeviceProps props = GetOnDeviceProps(tuple); if (props.body.defined() && !props.is_fixed) { - VLOG(1) << "wrapping tuple get item:" << std::endl + VLOG(2) << "wrapping tuple get item:" << std::endl << PrettyPrint(GetRef(tuple_get_item_node)) << std::endl - << "with \"on_device\" for device " << props.device_type; - return OnDevice(tuple_get_item, props.device_type, /*is_fixed=*/false); + << "with \"on_device\" for SEScope " << props.se_scope; + return OnDevice(tuple_get_item, props.se_scope, /*is_fixed=*/false); } else { return tuple_get_item; } @@ -320,12 +311,12 @@ class RewriteOnDevices : public ExprMutator { while (const auto* inner_let_node = expr.as()) { Expr inner_let = GetRef(inner_let_node); Expr value = VisitExpr(inner_let_node->value); - auto props = GetOnDeviceProps(value); + OnDeviceProps props = GetOnDeviceProps(value); if (props.body.defined() && !props.is_fixed) { - VLOG(1) << "revising let-bound expression of let:" << std::endl + VLOG(2) << "revising let-bound expression of let:" << std::endl << PrettyPrint(expr) << std::endl - << "to be fixed to device " << props.device_type; - value = OnDevice(props.body, props.device_type, /*is_fixed=*/true); + << "to be fixed to SEScope " << props.se_scope; + value = OnDevice(props.body, props.se_scope, /*is_fixed=*/true); } bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); expr = inner_let_node->body; @@ -341,12 +332,12 @@ class RewriteOnDevices : public ExprMutator { Expr VisitExpr_(const FunctionNode* function_node) final { Expr body = VisitExpr(function_node->body); - auto props = GetOnDeviceProps(body); + OnDeviceProps props = GetOnDeviceProps(body); if (props.body.defined() && !props.is_fixed) { - VLOG(1) << "revising body of function:" << std::endl + VLOG(2) << "revising body of function:" << std::endl << PrettyPrint(GetRef(function_node)) << std::endl - << "to be fixed to device " << props.device_type; - body = OnDevice(props.body, props.device_type, /*is_fixed=*/true); + << "to be fixed to SEScope " << props.se_scope; + body = OnDevice(props.body, props.se_scope, /*is_fixed=*/true); } // TODO(mbs): Avoid copy return Function(function_node->params, body, function_node->ret_type, @@ -363,12 +354,12 @@ class RewriteOnDevices : public ExprMutator { * It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter. * * Eg from \code add(%x, %y) \endcode we know \p %x and \p %y must be on the same device. Later, - * from \code on_device(%x, device_type=d) \endcode we know \p %x must be on device \p d, and thus + * from \code on_device(%x, se_scope=d) \endcode we know \p %x must be on device \p d, and thus * so must \p %y. * * Constraints can flow in interesting ways. E.g. in: * \code - * let %f = fn(%x, %y) { add(%x, on_device(%y, device_type=d)) } + * let %f = fn(%x, %y) { add(%x, on_device(%y, se_scope=d)) } * let %g = fn(%f, %x, %y) { %f(%x, %y) } * %g(%f, %a, %b) * \endcode @@ -376,8 +367,8 @@ class RewriteOnDevices : public ExprMutator { */ class DeviceAnalyzer : public ExprVisitor { public: - explicit DeviceAnalyzer(IRModule mod) - : mod_(std::move(mod)), domains_(std::make_unique()) {} + DeviceAnalyzer(IRModule mod, CompilationConfig config) + : mod_(std::move(mod)), domains_(std::make_unique(std::move(config))) {} /*! * \brief Returns the expression-to-device-domain map for all expressions in all the global @@ -387,7 +378,7 @@ class DeviceAnalyzer : public ExprVisitor { std::unique_ptr Analyze() { VLOG_CONTEXT << "DeviceAnalyzer"; for (const auto& pair : mod_->functions) { - VLOG(1) << "collecting constraints for '" << PrettyPrint(pair.first) << "'"; + VLOG(2) << "collecting constraints for '" << PrettyPrint(pair.first) << "'"; domains_->UnifyExprExact(pair.first, pair.second); VisitExpr(pair.second); } @@ -413,9 +404,9 @@ class DeviceAnalyzer : public ExprVisitor { } args_and_result_domains.emplace_back(domains_->DomainFor(call)); auto implied_domain = - DeviceDomains::MakeDomain(std::move(args_and_result_domains)); // higher-order + domains_->MakeHigherOrderDomain(std::move(args_and_result_domains)); // higher-order - VLOG(1) << "initial call function domain:" << std::endl + VLOG(2) << "initial call function domain:" << std::endl << domains_->ToString(func_domain) << std::endl << "and implied domain:" << std::endl << domains_->ToString(implied_domain) << std::endl @@ -423,21 +414,18 @@ class DeviceAnalyzer : public ExprVisitor { << PrettyPrint(call); // The above must match. - try { - domains_->Unify(func_domain, implied_domain); // higher-order - } catch (const Error& e) { + if (domains_->UnifyOrNull(func_domain, implied_domain) == nullptr) { // higher-order // TODO(mbs): Proper diagnostics. - LOG(FATAL) << "Function parameters and result devices do not match those of call. Call:" + LOG(FATAL) << "Function parameters and result SEScopes do not match those of call. Call:" << std::endl << PrettyPrint(call) << std::endl - << "with function devices:" << std::endl + << "with function scopes:" << std::endl << domains_->ToString(func_domain) << std::endl - << "and implied call devices:" << std::endl - << domains_->ToString(implied_domain) << std::endl - << e.what(); + << "and implied call scopes:" << std::endl + << domains_->ToString(implied_domain); } - VLOG(1) << "final call function domain:" << std::endl + VLOG(2) << "final call function domain:" << std::endl << domains_->ToString(func_domain) << std::endl << "for call:" << std::endl << PrettyPrint(call); @@ -477,7 +465,7 @@ class DeviceAnalyzer : public ExprVisitor { domains_->UnifyExprExact(function_node->body, func_domain->function_result()); // may be higher-order - VLOG(1) << "initial function domain:" << std::endl + VLOG(2) << "initial function domain:" << std::endl << domains_->ToString(func_domain) << std::endl << "and function body domain:" << std::endl << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl @@ -492,37 +480,33 @@ class DeviceAnalyzer : public ExprVisitor { VisitExpr(function_node->params[i]); } - // If the function already has device attributes then we can further constrain the + // If the function already has SEScope attributes then we can further constrain the // function's domain to match them. - if (GetFunctionResultDeviceType(function_node) != kInvalidDeviceType) { + if (!GetFunctionResultSEScope(function_node)->IsFullyUnconstrained()) { std::vector args_and_result; for (size_t i = 0; i < function_node->params.size(); ++i) { - args_and_result.emplace_back( - domains_->ForDeviceType(function_node->params[i]->checked_type(), - GetFunctionParamDeviceType(function_node, i))); + args_and_result.emplace_back(domains_->ForSEScope( + function_node->params[i]->checked_type(), GetFunctionParamSEScope(function_node, i))); } - args_and_result.emplace_back(domains_->ForDeviceType( - function_node->body->checked_type(), GetFunctionResultDeviceType(function_node))); - auto annotation_domain = domains_->MakeDomain(std::move(args_and_result)); - try { - domains_->Unify(func_domain, annotation_domain); // higher-order - } catch (const Error& e) { + args_and_result.emplace_back(domains_->ForSEScope(function_node->body->checked_type(), + GetFunctionResultSEScope(function_node))); + auto annotation_domain = domains_->MakeHigherOrderDomain(std::move(args_and_result)); + if (domains_->UnifyOrNull(func_domain, annotation_domain) == nullptr) { // higher-order // TODO(mbs): Proper diagnostics. LOG(FATAL) - << "Function devices are incompatible with its \"on_device\" annotation. Function:" + << "Function SEScopes are incompatible with its \"on_device\" annotation. Function:" << std::endl << PrettyPrint(function) << std::endl - << "with function devices:" << std::endl + << "with function scopes:" << std::endl << domains_->ToString(func_domain) << std::endl - << "and annotation devices:" << std::endl - << domains_->ToString(annotation_domain) << std::endl - << e.what(); + << "and annotation scopes:" << std::endl + << domains_->ToString(annotation_domain); } } VisitExpr(function_node->body); - VLOG(1) << "final function domain:" << std::endl + VLOG(2) << "final function domain:" << std::endl << domains_->ToString(func_domain) << std::endl << "and function body domain:" << std::endl << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl @@ -652,7 +636,7 @@ class DeviceAnalyzer : public ExprVisitor { * \code * def @main(%x, %y, %z) { * let %a = add(%x, %y); - * multiply(%a, on_device(%z, device_type=d)) + * multiply(%a, on_device(%z, se_scope=d)) * \endcode * we know the parameter \p %z must be on device \p d, but the devices for \p %x and \p %y, * and the device for the function result, are still 'free'. The global 'default' device type @@ -664,17 +648,14 @@ class DeviceAnalyzer : public ExprVisitor { */ class DeviceDefaulter : public ExprVisitor { public: - DeviceDefaulter(IRModule mod, std::unique_ptr domains, - DLDeviceType default_device_type) - : mod_(std::move(mod)), - domains_(std::move(domains)), - default_device_type_(default_device_type) {} + DeviceDefaulter(IRModule mod, std::unique_ptr domains) + : mod_(std::move(mod)), domains_(std::move(domains)) {} std::unique_ptr Default() { VLOG_CONTEXT << "DeviceDefaulter"; - VLOG(0) << "using default device type " << default_device_type_; + VLOG(0) << "defaulting to SEScope " << domains_->config()->default_primitive_se_scope; for (const auto& pair : mod_->functions) { - VLOG(1) << "defaulting devices for '" << PrettyPrint(pair.first) << "'"; + VLOG(2) << "defaulting devices for '" << PrettyPrint(pair.first) << "'"; VisitExpr(pair.second); } return std::move(domains_); @@ -689,10 +670,11 @@ class DeviceDefaulter : public ExprVisitor { auto function = GetRef(function_node); auto func_domain = domains_->DomainFor(function); // higher-order ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); - if (domains_->AnyFree(func_domain)) { - VLOG(1) << "before defaulting function:" << std::endl << domains_->ToString(func_domain); - domains_->SetResultDefaultThenParams(func_domain, default_device_type_); - VLOG(1) << "after defaulting function:" << std::endl << domains_->ToString(func_domain); + if (!domains_->IsFullyConstrained(func_domain)) { + VLOG(2) << "before defaulting function:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, + domains_->config()->default_primitive_se_scope); + VLOG(2) << "after defaulting function:" << std::endl << domains_->ToString(func_domain); } VisitExpr(function_node->body); } @@ -701,13 +683,14 @@ class DeviceDefaulter : public ExprVisitor { auto call = GetRef(call_node); auto func_domain = domains_->DomainForCallee(call); // higher-order ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); - if (domains_->AnyFree(func_domain)) { + if (!domains_->IsFullyConstrained(func_domain)) { // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*) // above. But for calls to primitives we may still need to force free domains to be // defaulted. - VLOG(1) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain); - domains_->SetResultDefaultThenParams(func_domain, default_device_type_); - VLOG(1) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain); + VLOG(2) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, + domains_->config()->default_primitive_se_scope); + VLOG(2) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain); } return ExprVisitor::VisitExpr_(call_node); } @@ -719,13 +702,13 @@ class DeviceDefaulter : public ExprVisitor { Let let = Downcast(expr); // If the let-var device is still free force it to match the overall let. auto let_domain = domains_->DomainFor(let); // may be higher-order - DLDeviceType let_device_type = domains_->ResultDeviceType(let_domain); - ICHECK_NE(let_device_type, kInvalidDeviceType); + SEScope let_se_scope = domains_->ResultSEScope(let_domain); + ICHECK(!let_se_scope->IsFullyUnconstrained()); auto let_var_domain = domains_->DomainFor(let->var); // may be higher-order - if (domains_->AnyFree(let_var_domain)) { - VLOG(1) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); - domains_->SetDefault(let_var_domain, let_device_type); - VLOG(1) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + if (!domains_->IsFullyConstrained(let_var_domain)) { + VLOG(2) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + domains_->SetDefault(let_var_domain, let_se_scope); + VLOG(2) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); } VisitExpr(let->var); VisitExpr(let->value); @@ -738,8 +721,6 @@ class DeviceDefaulter : public ExprVisitor { IRModule mod_; /*! \brief The domains for all expressions. */ std::unique_ptr domains_; - /*! \brief The default device type. */ - DLDeviceType default_device_type_; }; /****** @@ -754,7 +735,7 @@ class DeviceDefaulter : public ExprVisitor { * - Discard any existing "on_device" CallNodes since their job is done. Similarly, discard * any existing "device_copy" CallNodes which are no-ops. * - * - Functions are given "param_device_types" and "result_device_type" attributes to capture + * - Functions are given "param_se_scopes" and "result_se_scope" attributes to capture * the device type for its parameters and result. * * - Additional "device_copy" CallNodes are inserted wherever there's a transition between @@ -773,10 +754,10 @@ class DeviceDefaulter : public ExprVisitor { * * For example, we'll end up with programs that look like: * \code - * def @main(%x, %y, param_device_types=[...], result_device_type=...) { - * let %a = on_device(..., device_type=..., is_fixed=True) - * @f(%a, device_copy(on_device(..., device_type=..., is_fixed=True), - * src_device_type=..., dst_device_type=...)) + * def @main(%x, %y, param_se_scopes=[...], result_se_scope=...) { + * let %a = on_device(..., se_scope=..., is_fixed=True) + * @f(%a, device_copy(on_device(..., se_scope=..., is_fixed=True), + * src_se_scope=..., dst_se_scope=...)) * } * \endcode */ @@ -789,7 +770,7 @@ class DeviceCapturer : public ExprMutator { VLOG_CONTEXT << "CaptureDevices"; IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map); for (const auto& pair : mod_->functions) { - VLOG(1) << "capturing devices for '" << PrettyPrint(pair.first) << "'"; + VLOG(2) << "capturing devices for '" << PrettyPrint(pair.first) << "'"; result->Add(pair.first, Downcast(Mutate(pair.second))); } return result; @@ -816,39 +797,39 @@ class DeviceCapturer : public ExprMutator { auto function = GetRef(function_node); auto func_domain = domains_->DomainFor(function); // higher-order - VLOG(1) << "capturing function:" << std::endl + VLOG(2) << "capturing function:" << std::endl << PrettyPrint(function) << std::endl << "with domain:" << std::endl << domains_->ToString(func_domain); // Gather the parameter and result device types for the function attributes. ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); - DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain); - ICHECK_NE(result_device_type, kInvalidDeviceType); - Array param_device_types; - param_device_types.reserve(function_node->params.size()); + SEScope result_se_scope = domains_->ResultSEScope(func_domain); + ICHECK(!result_se_scope->IsFullyUnconstrained()); + Array param_se_scopes; + param_se_scopes.reserve(function_node->params.size()); for (size_t i = 0; i < function_node->params.size(); ++i) { - DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i)); - ICHECK_NE(param_device_type, kInvalidDeviceType); - param_device_types.push_back(param_device_type); + SEScope param_se_scope = domains_->ResultSEScope(func_domain->function_param(i)); + ICHECK(!param_se_scope->IsFullyUnconstrained()); + param_se_scopes.push_back(param_se_scope); } // Rewrite the body. Note that the body may have begun with an "on_device" so // be prepared to insert a "device_copy". Expr body = VisitChild( - /*lexical_device_type=*/result_device_type, - /*expected_device_type=*/result_device_type, - /*child_device_type=*/GetDeviceType(function_node->body), function_node->body); + /*lexical_se_scope=*/result_se_scope, + /*expected_se_scope=*/result_se_scope, + /*child_se_scope=*/GetSEScope(function_node->body), function_node->body); // TODO(mbs): Avoid copy Function func = Function(function_node->params, body, function_node->ret_type, function_node->type_params, function_node->attrs, function_node->span); - return FunctionOnDevice(func, param_device_types, result_device_type); + return FunctionOnDevice(func, std::move(param_se_scopes), std::move(result_se_scope)); } Expr VisitExpr_(const CallNode* call_node) final { auto call = GetRef(call_node); - DLDeviceType call_device_type = GetDeviceType(call); + SEScope call_se_scope = GetSEScope(call); auto on_device_props = GetOnDeviceProps(call_node); if (on_device_props.body.defined()) { @@ -857,31 +838,36 @@ class DeviceCapturer : public ExprMutator { return VisitExpr(on_device_props.body); } - auto device_copy_props = GetDeviceCopyProps(call_node); + DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node); if (device_copy_props.body.defined()) { - DLDeviceType src_device_type = device_copy_props.src_dev_type; - ICHECK_EQ(call_device_type, device_copy_props.dst_dev_type); - if (call_device_type == src_device_type) { + SEScope src_se_scope = domains_->config()->CanonicalSEScope(device_copy_props.src_se_scope); + SEScope dst_se_scope = domains_->config()->CanonicalSEScope(device_copy_props.dst_se_scope); + ICHECK_EQ(call_se_scope, dst_se_scope); + if (src_se_scope == dst_se_scope) { // We can pinch out existing "device_copy" CallNodes if their source and destinations // match. return VisitExpr(device_copy_props.body); + } else { + return VisitChild(/*lexical_se_scope=*/dst_se_scope, + /*expected_se_scope=*/dst_se_scope, + /*child_se_scope=*/src_se_scope, device_copy_props.body); } - // else: handle as for any other call. } + // Generic call. auto func_domain = domains_->DomainForCallee(call); // higher-order - VLOG(1) << "considering call:" << std::endl + VLOG(2) << "considering call:" << std::endl << PrettyPrint(call) << std::endl - << "on device " << call_device_type << " with function domain:" << std::endl + << "in scope " << call_se_scope << " with function domain:" << std::endl << domains_->ToString(func_domain); - DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain); - ICHECK_NE(result_device_type, kInvalidDeviceType); + SEScope result_se_scope = domains_->ResultSEScope(func_domain); + ICHECK(!result_se_scope->IsFullyUnconstrained()); // The callee is on the current device. Expr op = VisitChild( - /*lexical_device_type=*/call_device_type, - /*expected_device_type=*/call_device_type, - /*child_device_type=*/result_device_type, call_node->op); + /*lexical_se_scope=*/call_se_scope, + /*expected_se_scope=*/call_se_scope, + /*child_se_scope=*/result_se_scope, call_node->op); // Each argument can be on the device for the corresponding function parameter. However if // any of those differ from the overall call device then wrap them in an "on_device" to @@ -890,13 +876,13 @@ class DeviceCapturer : public ExprMutator { args.reserve(call_node->args.size()); ICHECK_EQ(func_domain->function_arity(), call->args.size()); for (size_t i = 0; i < call_node->args.size(); ++i) { - DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i)); - ICHECK_NE(param_device_type, kInvalidDeviceType) + SEScope param_se_scope = domains_->ResultSEScope(func_domain->function_param(i)); + ICHECK(!param_se_scope->IsFullyUnconstrained()) << "for parameter " << i << " for call:" << std::endl << PrettyPrint(call); - args.push_back(VisitChild(/*lexical_device_type=*/call_device_type, - /*expected_device_type=*/param_device_type, - /*child_device_type=*/GetDeviceType(call_node->args[i]), + args.push_back(VisitChild(/*lexical_se_scope=*/call_se_scope, + /*expected_se_scope=*/param_se_scope, + /*child_se_scope=*/GetSEScope(call_node->args[i]), call_node->args[i])); } // TODO(mbs): Avoid copy @@ -907,27 +893,27 @@ class DeviceCapturer : public ExprMutator { Expr VisitExpr_(const LetNode* let_node) final { Expr expr = GetRef(let_node); // Iterate through chained lets, provided they all agree on their device type. - DLDeviceType let_device_type = GetDeviceType(expr); + SEScope let_se_scope = GetSEScope(expr); std::vector> bindings; while (const auto* inner_let_node = expr.as()) { Expr inner_let = GetRef(inner_let_node); - if (GetDeviceType(inner_let) != let_device_type) { + if (GetSEScope(inner_let) != let_se_scope) { // We have a device transition which needs to be handled. break; } // The let-bound value can be on a different device than the overall let. However if those // devices don't agree wrap the let-bound value in an "on_device" to help downstream // transforms track devices lexically. - Expr value = VisitChild(/*lexical_device_type=*/let_device_type, - /*expected_device_type=*/GetDeviceType(inner_let_node->var), - /*child_device_type=*/GetDeviceType(inner_let_node->value), - inner_let_node->value); + Expr value = + VisitChild(/*lexical_se_scope=*/let_se_scope, + /*expected_se_scope=*/GetSEScope(inner_let_node->var), + /*child_se_scope=*/GetSEScope(inner_let_node->value), inner_let_node->value); bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); expr = inner_let_node->body; } - Expr body = VisitChild(/*lexical_device_type=*/let_device_type, - /*expected_device_type=*/let_device_type, - /*child_device_type=*/GetDeviceType(expr), expr); + Expr body = VisitChild(/*lexical_se_scope=*/let_se_scope, + /*expected_se_scope=*/let_se_scope, + /*child_se_scope=*/GetSEScope(expr), expr); for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { body = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), body, /*span=*/std::get<2>(*itr)); @@ -987,69 +973,69 @@ class DeviceCapturer : public ExprMutator { return Match(data, std::move(clauses), match_node->complete, match_node->span); } - DLDeviceType GetDeviceType(const Expr& expr) { + SEScope GetSEScope(const Expr& expr) { // Look through any "on_device" CallNodes, to mimic how we will be pinching them out. - auto props = GetOnDeviceProps(expr); + OnDeviceProps props = GetOnDeviceProps(expr); Expr true_expr = props.body.defined() ? props.body : expr; ICHECK(domains_->contains(true_expr)); - // If expr is higher order we'll return only the result domain's device type. - DLDeviceType device_type = domains_->ResultDeviceType(domains_->DomainFor(true_expr)); - ICHECK_NE(device_type, kInvalidDeviceType) - << "no device type was determined for expression:" << std::endl + // If expr is higher order we'll return only the result domain's SEScope. + SEScope se_scope = domains_->ResultSEScope(domains_->DomainFor(true_expr)); + ICHECK(!se_scope->IsFullyUnconstrained()) + << "no SEScope was determined for expression:" << std::endl << PrettyPrint(true_expr); - return device_type; + return std::move(se_scope); } /*! - * \brief Reconcile the \p child_device_type for \p child with both the \p expected_device_type - * (as required by the expression context the \p child is in) and the \p lexical_device_type + * \brief Reconcile the \p child_se_scope for \p child with both the \p expected_se_scope + * (as required by the expression context the \p child is in) and the \p lexical_se_scope * (as a downstream transform would infer based only on lexically enclosing "on_device" - * CallNodes and function attributes.) Generally \p lexical_device_type and \p - * expected_device_type are the same by definition, but may differ in arguments to functions + * CallNodes and function attributes.) Generally \p lexical_se_scope and \p + * expected_se_scope are the same by definition, but may differ in arguments to functions * and let-bound expressions. * - * If \p child_device_type differs from \p expected_device_type, wrap it as: + * If \p child_se_scope differs from \p expected_se_scope, wrap it as: * \code - * device_copy(on_device(child', device_type=child_device_type), - * src_dev_type=child_device_type, dst_dev_type=expected_device_type) + * device_copy(on_device(child', se_scope=child_se_scope), + * src_dev_type=child_se_scope, dst_dev_type=expected_se_scope) * \endcode * (where child is rewritten to child'). Note the pedantic spelling out of "on_device" on the * child. * - * If \p expected_device_type differs from \p lexical_device_type, then (also) wrap + * If \p expected_se_scope differs from \p lexical_se_scope, then (also) wrap * the expression as: * \code - * on_device(..., device_type=expected_device_type) + * on_device(..., se_scope=expected_se_scope) * \endcode * * TODO(mbs): There's no attempt at sharing here. If usage of child's node could be wrapped * by a "device_copy", even though those copies will generally all be to the same destination * device. */ - Expr VisitChild(DLDeviceType lexical_device_type, DLDeviceType expected_device_type, - DLDeviceType child_device_type, const Expr& child) { - ICHECK_NE(lexical_device_type, kInvalidDeviceType); - ICHECK_NE(expected_device_type, kInvalidDeviceType); - if (child->IsInstance()) { - // Primitive operators don't need to be rewritten and can have a different domain for - // each call site. + Expr VisitChild(const SEScope& lexical_se_scope, const SEScope& expected_se_scope, + const SEScope& child_se_scope, const Expr& child) { + ICHECK(!lexical_se_scope->IsFullyUnconstrained()); + ICHECK(!expected_se_scope->IsFullyUnconstrained()); + if (child->IsInstance() || child->IsInstance()) { + // Primitive operators and contructors don't need to be rewritten and can have a + // different domain at each call site. return child; } Expr result = VisitExpr(child); - if (child_device_type != expected_device_type) { - VLOG(1) << "creating " << DeviceCopyOp()->name << " from device type " << child_device_type - << " to device type " << expected_device_type << " for:" << std::endl + if (child_se_scope != expected_se_scope) { + VLOG(2) << "creating " << DeviceCopyOp()->name << " from scope " << child_se_scope + << " to scope " << expected_se_scope << " for:" << std::endl << PrettyPrint(result); // Also wrap the child in an "on_device" so downstream transforms can track devices // lexically. - result = MaybeOnDevice(result, child_device_type, /*is_fixed=*/true); - result = DeviceCopy(result, child_device_type, expected_device_type); + result = MaybeOnDevice(result, child_se_scope, /*is_fixed=*/true); + result = DeviceCopy(result, child_se_scope, expected_se_scope); } - if (expected_device_type != lexical_device_type) { - VLOG(1) << "creating " << OnDeviceOp()->name << " for device type " << expected_device_type + if (expected_se_scope != lexical_se_scope) { + VLOG(2) << "creating " << OnDeviceOp()->name << " for scope " << expected_se_scope << " for:" << std::endl << PrettyPrint(result); - result = MaybeOnDevice(result, expected_device_type, /*is_fixed=*/true); + result = MaybeOnDevice(result, expected_se_scope, /*is_fixed=*/true); } return result; } @@ -1059,9 +1045,9 @@ class DeviceCapturer : public ExprMutator { * is expected to be on the same device as the \p parent. */ Expr VisitChild(const Expr& parent, const Expr& child) { - DLDeviceType expected_device_type = GetDeviceType(parent); - DLDeviceType child_device_type = GetDeviceType(child); - return VisitChild(expected_device_type, expected_device_type, child_device_type, child); + SEScope expected_se_scope = GetSEScope(parent); + SEScope child_se_scope = GetSEScope(child); + return VisitChild(expected_se_scope, expected_se_scope, child_se_scope, child); } /*! \brief Module we are rewriting, so we can lookup global variables. */ @@ -1079,21 +1065,22 @@ tvm::transform::Pass Rewrite() { } /*! \brief Run the remaining phases. */ -tvm::transform::Pass PlanDevicesCore(DLDeviceType default_device_type) { +tvm::transform::Pass PlanDevicesCore(CompilationConfig config) { return tvm::transform::CreateModulePass( - [=](IRModule mod, tvm::transform::PassContext pass_cnxt) -> IRModule { + [config = std::move(config)](IRModule mod, + tvm::transform::PassContext pass_cnxt) -> IRModule { // Collect the system of constraints for every sub-expression using existing "on_device" // and "device_copy" calls. - std::unique_ptr domains = DeviceAnalyzer(mod).Analyze(); - VLOG(1) << "Domains after analysis:" << std::endl << domains->ToString(); + std::unique_ptr domains = DeviceAnalyzer(mod, config).Analyze(); + VLOG(3) << "Domains after analysis:" << std::endl << domains->ToString(); // Choose sensible default devices for every sub-expression if otherwise unconstrained // by existing "on_device" or "device_copy" calls. - domains = DeviceDefaulter(mod, std::move(domains), default_device_type).Default(); - VLOG(1) << "Domains after defaulting: " << std::endl << domains->ToString(); + domains = DeviceDefaulter(mod, std::move(domains)).Default(); + VLOG(3) << "Domains after defaulting: " << std::endl << domains->ToString(); // Insert "device_copy" and "on_device" CallNodes where needed to unambiguously capture - // the above map, and attach additional "param_device_types" and "result_device_type" + // the above map, and attach additional "param_se_scopes" and "result_se_scope" // attributes to all function definitions. return DeviceCapturer(mod, std::move(domains)).Capture(); }, @@ -1107,17 +1094,14 @@ tvm::transform::Pass PlanDevicesCore(DLDeviceType default_device_type) { *******/ // This function is declared in the public . -TVM_DLL tvm::transform::Pass PlanDevices(DLDeviceType default_device_type) { +tvm::transform::Pass PlanDevices(CompilationConfig config) { std::vector passes; passes.emplace_back(Rewrite()); - passes.emplace_back(PlanDevicesCore(default_device_type)); - return tvm::transform::Sequential(std::move(passes), "PlanDevices"); + passes.emplace_back(PlanDevicesCore(std::move(config))); + return tvm::transform::Sequential(passes, "PlanDevices"); } -TVM_REGISTER_GLOBAL("relay._transform.PlanDevices") - .set_body_typed([](const Device& default_device) { - return PlanDevices(default_device.device_type); - }); +TVM_REGISTER_GLOBAL("relay._transform.PlanDevices").set_body_typed(PlanDevices); } // namespace transform } // namespace relay diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index c48a9b30967c..05ee9d5ad592 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -31,8 +31,7 @@ #include #include -#include "../op/annotation/annotation.h" -#include "./device_aware_visitors.h" +#include "../op/memory/on_device.h" #include "./pattern_utils.h" namespace tvm { @@ -42,7 +41,7 @@ namespace transform { namespace { /*! * \brief Returns whether \p expr is a literal \p Constant, optionally wrapped by an "on_device" - * annotation CallNode (which serves only to associate a device to the constant and has no + * annotation CallNode (which serves only to associate an \p SEScope to the constant and has no * operational effect). */ bool IsSimpleConstant(const Expr& expr) { @@ -87,19 +86,19 @@ class ConstantFolder : public MixedModeMutator { // the variable. // // We need to retain any "on_device" annotation so that downstream 'device aware' - // passes can still retrieve the device for the constant in its new position(s). Eg: - // def @f(..., result_device_type=D) { - // let %x = on_device(... something we eval to a constant..., device_type=E) + // passes can still retrieve the \p SEScope for the constant in its new position(s). Eg: + // def @f(..., result_se_scope=D) { + // let %x = on_device(... something we eval to a constant..., se_scope=E) // @f(..., %x, ...) // } - // Here the default device is D, whereas the argument %x to @f is on E (and @f expects + // Here the default scope is D, whereas the argument %x to @f is on E (and @f expects // that). No on_device annotation is required in the call according to the convention used // by the device-aware visitors. // // However once we've inlined the constant we need to insert an on_device, again to // respect the convention used by the device-aware visitors. - // def @f(..., result_device_type=D) { - // @f(..., on_device(...the constant..., device_type=E), ...) + // def @f(..., result_se_scope=D) { + // @f(..., on_device(...the constant..., se_scope=E), ...) // } VLOG(1) << "Replacing let-binding for " << op->var->name_hint() << " with constant:" << std::endl @@ -215,8 +214,8 @@ class ConstantFolder : public MixedModeMutator { Expr result = tuple_node->fields[tuple_get_item_node->index]; OnDeviceProps props = GetOnDeviceProps(post_tuple_get_item_node->tuple); if (props.body.defined()) { - // (on_device((x, y, z), device_type=D).1 ==> on_device(y, device_type=D) - return MaybeOnDevice(result, props.device_type, props.is_fixed); + // (on_device((x, y, z), se_scope=D).1 ==> on_device(y, se_scope=D) + return MaybeOnDevice(result, props.se_scope, props.is_fixed); } else { return result; } @@ -248,19 +247,15 @@ class ConstantFolder : public MixedModeMutator { VLOG(1) << "Evaluating :" << std::endl << PrettyPrint(expr); // We'll invoke the interpreter using the generic CPU device and target. Technically there's - // no guarantee the results we bitwise equal what we'd get on the true device, however to + // no guarantee the results will be bitwise equal what we'd get on the true device, however to // support cross-compilation we don't want to assume the true device is available. - Device dev; - dev.device_type = kDLCPU; - dev.device_id = 0; - Target target = Target("llvm"); // Use a fresh build context in case we are already in a build context. // needed for both execution and creation(due to JIT) With fresh_build_ctx(transform::PassContext::Create()); - Expr result = - ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), dev, target)); + Expr result = ObjectToExpr( + Eval(expr, module_->type_definitions, module_->Imports(), eval_cpu_dev_, eval_cpu_target_)); VLOG(1) << "Evaluated to constant:" << std::endl << PrettyPrint(result); return result; } @@ -288,17 +283,14 @@ class ConstantFolder : public MixedModeMutator { } // Get the constant shape - Device dev; - dev.device_type = kDLCPU; - dev.device_id = 0; runtime::NDArray value; DLDataType cdtype = DataType::Int(32); if (ishape.empty()) { - value = runtime::NDArray::Empty({}, cdtype, dev); + value = runtime::NDArray::Empty({}, cdtype, eval_cpu_dev_); } else { ICHECK_NE(ishape.size(), 0); std::vector cshape = {static_cast(ishape.size())}; - value = runtime::NDArray::Empty(cshape, cdtype, dev); + value = runtime::NDArray::Empty(cshape, cdtype, eval_cpu_dev_); auto* dims = static_cast(value->data); using ::tvm::tir::IntImmNode; for (size_t i = 0; i < ishape.size(); ++i) { @@ -313,7 +305,7 @@ class ConstantFolder : public MixedModeMutator { Constant shape = Downcast(ObjectToExpr(value)); if (shape->data.Shape().empty() && GetScalarFromConstant(shape) == 0) { - auto ndarray = runtime::NDArray::Empty({}, cdtype, dev); + auto ndarray = runtime::NDArray::Empty({}, cdtype, eval_cpu_dev_); shape = Constant(ndarray); } @@ -342,12 +334,9 @@ class ConstantFolder : public MixedModeMutator { } // Get the constant size - Device dev; - dev.device_type = kDLCPU; - dev.device_id = 0; runtime::NDArray value; DLDataType cdtype = DataType::Int(32); - value = runtime::NDArray::Empty({}, cdtype, dev); + value = runtime::NDArray::Empty({}, cdtype, eval_cpu_dev_); auto* data = static_cast(value->data); if (ishape.empty()) { *data = 0; @@ -390,6 +379,13 @@ class ConstantFolder : public MixedModeMutator { // Module IRModule module_; + // The kDLCPU device assumed to be available to the compiler. Used only when evaluating + // sub-expressions. + Device eval_cpu_dev_{kDLCPU, /*device_id=*/0}; + // The target for the above device assumed to be available to the compiler. Used only when + // evaluating sub-expressions. + Target eval_cpu_target_{"llvm"}; + // Cache the following ops for equivalence checking in this pass. const Op& device_copy_op_; const Op& shape_of_op_; diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index a328eaa82aa2..acea12fb8560 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -59,17 +59,6 @@ using namespace tvm::runtime; namespace tvm { namespace relay { -inline Constant MakeConstant(const std::vector& value) { - return MakeConstantTensor(DataType::Int(64), {static_cast(value.size())}, value); -} - -inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype, - Array assert_shape, DLDeviceType offset_device_type) { - auto offset = - OnDevice(MakeConstantScalar(DataType::Int(64), 0), offset_device_type, /*is_fixed=*/true); - return AllocTensor(storage, offset, shape, dtype, assert_shape); -} - // Check if the primitive function contains only reshape ops. bool IsReshapeOnly(const Expr& expr) { if (const FunctionNode* func = expr.as()) { @@ -88,11 +77,13 @@ bool IsReshapeOnly(const Expr& expr) { class DialectRewriter : public transform::DeviceAwareExprMutator { public: - DialectRewriter(IRModule mod, const Target& target_host) - : transform::DeviceAwareExprMutator(std::move(mod)), target_host_(target_host) {} + DialectRewriter(IRModule mod, SEScope host_se_scope) + : transform::DeviceAwareExprMutator(std::move(mod)), + host_se_scope_(std::move(host_se_scope)) {} Function Rewrite(const Function& expr) { return Downcast(Mutate(expr)); } + private: Expr VisitExpr_(const TupleNode* tn) final { LetList& scope = scopes_.back(); Array new_fields; @@ -131,7 +122,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { Expr DeviceAwareVisitExpr_(const CallNode* cn) final { Call call = GetRef(cn); - DLDeviceType device_type = GetInScopeDeviceType(call); + SEScope se_scope = GetSEScope(call); if (IsPrimitive(cn)) { // Because we are in ANF we do not need to visit the arguments. // TODO(mbs): But does so anyway... @@ -163,26 +154,21 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { } const DeviceCopyAttrs* copy_attr = attr.as(); CHECK(copy_attr); - return DeviceCopy(new_args[0], copy_attr->src_dev_type, copy_attr->dst_dev_type); + return DeviceCopy(new_args[0], copy_attr->src_se_scope, copy_attr->dst_se_scope); } else if (IsDynamic(ret_type)) { Function func = Downcast(cn->op); - // TODO(mbs): Device id is always zero. - Device device{device_type, /*device_id=*/0}; - return DynamicInvoke(&scope, func, ins, new_args, out_types, ret_type, device); + return DynamicInvoke(&scope, func, ins, new_args, out_types, ret_type, se_scope); } else { // Handle the static case Array outs; for (size_t i = 0; i < out_types.size(); ++i) { - DLDeviceType device_type = GetInScopeDeviceType(GetRef(cn)); - // TODO(mbs): Device id is always zero. - Device device{device_type, /*device_id=*/0}; - auto out = MakeStaticAllocation(&scope, out_types[i], device, std::to_string(i)); + auto out = MakeStaticAllocation(&scope, out_types[i], se_scope, std::to_string(i)); outs.push_back(out); } Tuple output(outs); // TODO(mbs): Capture device in attributes. Expr invoke = InvokeTVMOp(cn->op, ins, output); - scope.Push(OnDevice(invoke, device_type, /*is_fixed=*/true)); + scope.Push(OnDevice(invoke, se_scope, /*is_fixed=*/true)); return ToTupleType(ret_type, std::vector(output->fields.begin(), output->fields.end())); } @@ -191,11 +177,26 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { } } - private: + /*! Returns the Relay Constant representing the 1d tensor with \p value. + * + * CAUTION: Make sure the constant ends up on the correct device. + */ + inline Constant MakeConstant(const std::vector& value) { + return MakeConstantTensor(DataType::Int(64), {static_cast(value.size())}, value); + } + + /*! Returns an \p alloc_tensor call for a tensor of \p shape and \p dtype over \p storage. */ + inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype, + Array assert_shape) { + Expr offset = OnDevice(MakeConstantScalar(DataType::Int(64), 0), host_se_scope_, + /*is_fixed=*/true); + return tvm::relay::AllocTensor(storage, std::move(offset), std::move(shape), dtype, + assert_shape); + } + // Insert a device copy node. - Expr DeviceCopy(const Expr& inp, int src_dev, int dst_dev) { - return Mutate(relay::DeviceCopy(inp, static_cast(src_dev), - static_cast(dst_dev))); + Expr DeviceCopy(const Expr& inp, SEScope src_se_scope, SEScope dst_se_scope) { + return Mutate(relay::DeviceCopy(inp, std::move(src_se_scope), std::move(dst_se_scope))); } // Check if a call invokes a primitive function. @@ -250,28 +251,28 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { } // Allocate a tensor with a statically known shape. - Var MakeStaticAllocation(LetList* scope, const TensorType& type, Device dev, String name_hint) { + Var MakeStaticAllocation(LetList* scope, const TensorType& type, const SEScope& se_scope, + String name_hint) { std::vector int_shape; for (auto it : type->shape) { const auto* imm = it.as(); CHECK(imm) << "expect static int shape"; int_shape.push_back(imm->value); } - Expr shape = OnDevice(MakeConstant(int_shape), cpu_device_.device_type, /*is_fixed=*/true); - Expr size = OnDevice(ComputeStorage(type), cpu_device_.device_type, /*is_fixed=*/true); + Expr shape = OnDevice(MakeConstant(int_shape), host_se_scope_, /*is_fixed=*/true); + Expr size = OnDevice(ComputeStorage(type), host_se_scope_, /*is_fixed=*/true); // Alignment is directly captured in the instruction rather than calculated, so we // don't want to wrap it with an "on_device". Expr alignment = ComputeAlignment(type->dtype); // Run type inference later to get the correct type. Var var("storage_" + name_hint, Type(nullptr)); - Expr value = OnDevice(AllocStorage(size, alignment, dev, type->dtype), dev.device_type, + Expr value = OnDevice(AllocStorage(size, alignment, se_scope, type->dtype), se_scope, /*is_fixed=*/true); auto sto = scope->Push(var, value); // TODO(@jroesch): There is a bug with typing based on the constant shape. - auto tensor = OnDevice( - AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape, cpu_device_.device_type), - dev.device_type, /*is_fixed=*/true); + auto tensor = OnDevice(AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape), + se_scope, /*is_fixed=*/true); Var tensor_var("tensor_" + name_hint, Type(nullptr)); return scope->Push(tensor_var, tensor); } @@ -283,7 +284,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { tec::TECompiler compiler; - tec::CCacheKey key(func, target_host_); + tec::CCacheKey key(func, host_se_scope_->target); auto cfunc = compiler->LowerShapeFunc(key); auto input_states = cfunc->shape_func_param_states; @@ -311,10 +312,10 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { is_inputs.push_back(0); } else if (state == tec::kNeedInputData) { auto new_arg = Mutate(arg); // already accounts for device - DLDeviceType device_type = GetInScopeDeviceType(arg); - if (device_type != cpu_device_.device_type) { - new_arg = OnDevice(DeviceCopy(new_arg, device_type, cpu_device_.device_type), - cpu_device_.device_type, /*is_fixed=*/true); + SEScope arg_se_scope = GetSEScope(arg); + if (arg_se_scope != host_se_scope_) { + new_arg = OnDevice(DeviceCopy(new_arg, arg_se_scope, host_se_scope_), host_se_scope_, + /*is_fixed=*/true); } Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr)); shape_func_ins.push_back(scope->Push(in_shape_var, new_arg)); @@ -332,14 +333,14 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { auto tt = TensorType(out->shape, out->dtype); // Put shape func on CPU. This also ensures that everything between // shape_of and shape_func are on CPU. - auto alloc = OnDevice(MakeStaticAllocation(scope, tt, cpu_device_, std::to_string(i)), - cpu_device_.device_type, /*is_fixed=*/true); + auto alloc = OnDevice(MakeStaticAllocation(scope, tt, host_se_scope_, std::to_string(i)), + host_se_scope_, /*is_fixed=*/true); Var shape_func_out_var("shape_func_out_" + std::to_string(i), Type(nullptr)); alloc = scope->Push(shape_func_out_var, alloc); out_shapes.push_back(alloc); } auto shape_call = OnDevice(ShapeFunc(func, Tuple(shape_func_ins), Tuple(out_shapes), is_inputs), - cpu_device_.device_type, /*is_fixed=*/true); + host_se_scope_, /*is_fixed=*/true); Var shape_func_var("shape_func", Type(nullptr)); scope->Push(shape_func_var, shape_call); return out_shapes; @@ -348,19 +349,19 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { // Generate the code for invoking a TVM op with a dynamic shape. Expr DynamicInvoke(LetList* scope, const Function& func, const Tuple& ins, const std::vector& new_args, const std::vector& out_types, - const Type& ret_type, Device dev) { + const Type& ret_type, const SEScope& se_scope) { auto out_shapes = EmitShapeFunc(scope, func, new_args); std::vector storages; CHECK_EQ(out_shapes.size(), out_types.size()); for (size_t i = 0; i < out_shapes.size(); ++i) { auto out_shape = out_shapes[i]; auto out_type = out_types[i]; - auto size = OnDevice(ComputeStorageInRelay(out_shape, out_type), cpu_device_.device_type, + auto size = OnDevice(ComputeStorageInRelay(out_shape, out_type), host_se_scope_, /*is_fixed=*/true); // Alignment is directly captured in the instruction so don't wrap in "on_device". auto alignment = ComputeAlignment(out_type->dtype); Var sto_var("storage_" + std::to_string(i), Type(nullptr)); - auto val = OnDevice(AllocStorage(size, alignment, dev, out_type->dtype), dev.device_type, + auto val = OnDevice(AllocStorage(size, alignment, se_scope, out_type->dtype), se_scope, /*is_fixed=*/true); storages.push_back(scope->Push(sto_var, val)); } @@ -370,16 +371,15 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { auto out_shape = out_shapes[i]; auto out_type = out_types[i]; auto storage = storages[i]; - auto alloc = OnDevice(AllocTensor(storage, out_shape, out_type->dtype, out_type->shape, - cpu_device_.device_type), - dev.device_type, /*is_fixed=*/true); + auto alloc = OnDevice(AllocTensor(storage, out_shape, out_type->dtype, out_type->shape), + se_scope, /*is_fixed=*/true); Var out_var("out_" + std::to_string(i), Type(nullptr)); outs.push_back(scope->Push(out_var, alloc)); } Tuple tuple_outs(outs); auto call = InvokeTVMOp(func, ins, tuple_outs); - auto invoke = OnDevice(call, dev.device_type, /*is_fixed=*/true); + auto invoke = OnDevice(call, se_scope, /*is_fixed=*/true); scope->Push(invoke); return ToTupleType(ret_type, std::vector(tuple_outs->fields.begin(), tuple_outs->fields.end())); @@ -399,27 +399,24 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { CHECK(imm) << "expect static int shape"; shape.push_back(imm->value); } - shape_expr = OnDevice(MakeConstant(shape), cpu_device_.device_type, /*is_fixed=*/true); + shape_expr = OnDevice(MakeConstant(shape), host_se_scope_, /*is_fixed=*/true); } return ReshapeTensor(new_args[0], shape_expr, ret_ty->shape); } private: const Op& device_copy_op_ = Op::Get("device_copy"); + runtime::DataType compute_dtype_ = runtime::DataType::Int(64); + SEScope host_se_scope_; - Target target_host_; std::vector scopes_; - - runtime::DataType compute_dtype_ = runtime::DataType::Int(64); - Device cpu_device_{kDLCPU, 0}; }; namespace transform { -Pass ManifestAlloc(Target target_host, Map targets) { - CheckAndUpdateHostConsistency(&targets, &target_host); +Pass ManifestAlloc(SEScope host_se_scope) { return tvm::transform::CreateModulePass( - [=](IRModule mod, const PassContext& pass_ctx) { + [host_se_scope](IRModule mod, const PassContext& pass_ctx) { // We need to mutate module, therefore making a copy of it. mod.CopyOnWrite(); mod->ImportFromStd("core.rly"); @@ -429,7 +426,7 @@ Pass ManifestAlloc(Target target_host, Map targets) { for (const auto& it : glob_funcs) { if (auto* func_node = it.second.as()) { auto func = GetRef(func_node); - auto rewriter = DialectRewriter(mod, target_host); + auto rewriter = DialectRewriter(mod, host_se_scope); auto updated_func = rewriter.Rewrite(func); mod->Update(it.first, updated_func); @@ -442,11 +439,7 @@ Pass ManifestAlloc(Target target_host, Map targets) { 0, "ManifestAlloc", {}); } -TVM_REGISTER_GLOBAL("relay.transform.ManifestAlloc") - .set_body_typed([](Target target_host, Map targets) { - CheckAndUpdateHostConsistency(&targets, &target_host); - return ManifestAlloc(target_host, targets); - }); +TVM_REGISTER_GLOBAL("relay.transform.ManifestAlloc").set_body_typed(ManifestAlloc); } // namespace transform diff --git a/src/relay/transforms/pass_utils.h b/src/relay/transforms/pass_utils.h index fd7f0a5594c2..317ac17f83c8 100644 --- a/src/relay/transforms/pass_utils.h +++ b/src/relay/transforms/pass_utils.h @@ -37,6 +37,7 @@ #include "../analysis/dependency_graph.h" #include "../op/annotation/annotation.h" +#include "../op/memory/on_device.h" #include "./let_list.h" namespace tvm { diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index c767770a8be8..0814e73ab73d 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -211,14 +211,14 @@ class Fill : ExprFunctor, private transform::Lexi } Expr Atomic(const Expr& e, const Var& v) { - Expr annotated_expr = MaybeOnDevice(e, GetInScopeDeviceType(e), /*is_fixed=*/true); + Expr annotated_expr = MaybeOnDevice(e, GetSEScope(e), /*is_fixed=*/true); return v.defined() ? GetScope(e)->let_list->Push(v, annotated_expr) : annotated_expr; } // Bind expression `now` to var `v` if the original expression is in the include set, or if // v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly Expr Compound(const Expr& orig, const Expr& now, const Var& v) { - Expr annotated_expr = MaybeOnDevice(now, GetInScopeDeviceType(orig), /*is_fixed=*/true); + Expr annotated_expr = MaybeOnDevice(now, GetSEScope(orig), /*is_fixed=*/true); Var var = v.defined() ? v : Var(String("x"), Type()); bool not_included = include_set_ && include_set_->find(orig) == include_set_->end(); if (!v.defined() && not_included) { @@ -229,15 +229,15 @@ class Fill : ExprFunctor, private transform::Lexi } Expr VisitExpr_(const CallNode* c, const Var& v) final { - auto props = GetOnDeviceProps(c); + OnDeviceProps props = GetOnDeviceProps(c); if (props.body.defined() && props.is_fixed) { // Keep track of expression device type for lexically enclosing sub-expressions. - PushDeviceType(props.device_type); + PushSEScope(props.se_scope); Expr body = VisitExpr(props.body, v); // We are done with this sub-expression. - PopDeviceType(); + PopSEScope(); // Preserve the "on_device" annotations. - return OnDevice(body, props.device_type, props.is_fixed); + return OnDevice(body, props.se_scope, props.is_fixed); } Expr e = GetRef(c); @@ -292,9 +292,9 @@ class Fill : ExprFunctor, private transform::Lexi } else { // Keep track of expression and bound variable device types for lexically enclosing // sub-expressions. - PushDeviceType(GetFunctionResultDeviceType(f)); + PushSEScope(GetFunctionResultSEScope(f)); for (size_t i = 0; i < f->params.size(); ++i) { - PushBoundVar(f->params[i], GetFunctionParamDeviceType(f, i)); + PushBoundVar(f->params[i], GetFunctionParamSEScope(f, i)); } EnterFunctionBody(); ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type, @@ -304,7 +304,7 @@ class Fill : ExprFunctor, private transform::Lexi for (size_t i = 0; i < f->params.size(); ++i) { PopBoundVar(f->params[i]); } - PopDeviceType(); + PopSEScope(); } if (function_nesting() == 0) { ICHECK(!v.defined()); @@ -319,7 +319,7 @@ class Fill : ExprFunctor, private transform::Lexi Expr VisitExpr_(const LetNode* l, const Var& v) final { Expr e = GetRef(l); // Keep track of bound variable device types for lexically enclosing sub-expressions. - PushBoundVar(l->var, GetInScopeDeviceType(l->value)); + PushBoundVar(l->var, GetSEScope(l->value)); VisitExpr(l->value, l->var); Expr ret = GetSubScope(e, 0)->let_list->Get(VisitExpr(l->body)); // We are done with these sub-expressions. diff --git a/src/runtime/vm/bytecode.cc b/src/runtime/vm/bytecode.cc index 09b928fa1e39..f83e27d2c11d 100644 --- a/src/runtime/vm/bytecode.cc +++ b/src/runtime/vm/bytecode.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -119,13 +120,10 @@ Instruction::Instruction(const Instruction& instr) { this->shape_of.tensor = instr.shape_of.tensor; return; case Opcode::ReshapeTensor: - this->reshape_tensor.tensor = instr.reshape_tensor.tensor; - this->reshape_tensor.newshape = instr.reshape_tensor.newshape; + this->reshape_tensor = instr.reshape_tensor; return; case Opcode::DeviceCopy: - this->src = instr.src; - this->src_device_type = instr.src_device_type; - this->dst_device_type = instr.dst_device_type; + this->device_copy = instr.device_copy; return; default: std::ostringstream out; @@ -225,13 +223,10 @@ Instruction& Instruction::operator=(const Instruction& instr) { this->shape_of.tensor = instr.shape_of.tensor; return *this; case Opcode::ReshapeTensor: - this->reshape_tensor.tensor = instr.reshape_tensor.tensor; - this->reshape_tensor.newshape = instr.reshape_tensor.newshape; + this->reshape_tensor = instr.reshape_tensor; return *this; case Opcode::DeviceCopy: - this->src = instr.src; - this->src_device_type = instr.src_device_type; - this->dst_device_type = instr.dst_device_type; + this->device_copy = instr.device_copy; return *this; default: std::ostringstream out; @@ -338,14 +333,14 @@ Instruction Instruction::AllocTensorReg(RegName storage, RegName offset, RegName } Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType dtype_hint, - Index device_type, RegName dst) { + Index device_index, RegName dst) { Instruction instr; instr.op = Opcode::AllocStorage; instr.dst = dst; instr.alloc_storage.allocation_size = size; instr.alloc_storage.alignment = alignment; instr.alloc_storage.dtype_hint = dtype_hint; - instr.alloc_storage.device_type = device_type; + instr.alloc_storage.device_index = device_index; return instr; } @@ -366,14 +361,14 @@ Instruction Instruction::ReshapeTensor(RegName tensor, RegName newshape, RegName return instr; } -Instruction Instruction::DeviceCopy(RegName src, Index src_device_type, Index dst_device_type, +Instruction Instruction::DeviceCopy(RegName src, Index src_device_index, Index dst_device_index, RegName dst) { Instruction instr; instr.op = Opcode::DeviceCopy; instr.dst = dst; - instr.src = src; - instr.src_device_type = src_device_type; - instr.dst_device_type = dst_device_type; + instr.device_copy.src = src; + instr.device_copy.src_device_index = src_device_index; + instr.device_copy.dst_device_index = dst_device_index; return instr; } @@ -609,7 +604,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { os << "alloc_storage $" << instr.dst << " $" << instr.alloc_storage.allocation_size << " " << instr.alloc_storage.alignment << " " << DLDataType2String(instr.alloc_storage.dtype_hint) << " " - << instr.alloc_storage.device_type; + << instr.alloc_storage.device_index; break; } case Opcode::ShapeOf: { @@ -622,8 +617,8 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::DeviceCopy: { - os << "device_copy $" << instr.dst << " $" << instr.src << " " << instr.dst_device_type << " " - << instr.src_device_type; + os << "device_copy $" << instr.dst << " $" << instr.device_copy.src << " " + << instr.device_copy.dst_device_index << " " << instr.device_copy.src_device_index; break; } default: diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 4d7ee457e1e6..4a044584dccd 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -63,6 +63,8 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtrGetBytecode(); }); } else if (name == "get_constants") { return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetConstants(); }); + } else if (name == "get_virtual_devices") { + return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetVirtualDevices(); }); } else if (name == "get_stats") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); }); } else if (name == "save") { @@ -165,13 +167,21 @@ String ShapeString(const ShapeTuple& shape_tuple, DLDataType dtype) { std::string Executable::GetConstants() const { std::ostringstream oss; - for (size_t i = 0; i < constants.size(); ++i) { const auto& constant = constants[i]; auto ndarray = Downcast(constant); - DLDeviceType device_type = static_cast(const_device_type[i]); oss << "VM Constant[" << i << "]: has shape " << ShapeString(ndarray.Shape(), ndarray->dtype) - << " on device of type " << device_type << std::endl; + << " on device index " << const_device_indexes[i] << std::endl; + } + return oss.str(); +} + +std::string Executable::GetVirtualDevices() const { + std::ostringstream oss; + for (size_t i = 0; i < virtual_devices.size(); ++i) { + const auto& device = virtual_devices[i]; + oss << "VM VirtualDevice[" << i << "]: device type " << device.device_type << " and id " + << device.device_id << std::endl; } return oss.str(); } @@ -245,6 +255,9 @@ TVMByteArray Executable::Save() { // Save header SaveHeader(&strm); + // Save virtual devices section. + SaveVirtualDevicesSection(&strm); + // Global section. SaveGlobalSection(&strm); @@ -263,6 +276,11 @@ TVMByteArray Executable::Save() { return arr; } +void Executable::SaveVirtualDevicesSection(dmlc::Stream* strm) { + strm->Write(virtual_devices); + strm->Write(host_device_index); +} + void Executable::SaveGlobalSection(dmlc::Stream* strm) { std::vector> globals(this->global_map.begin(), this->global_map.end()); @@ -289,8 +307,8 @@ void Executable::SaveConstantSection(dmlc::Stream* strm) { runtime::SaveDLTensor(strm, it); } - // Save the const to device mapping. - strm->Write(this->const_device_type); + // Save the const to device index mapping. + strm->Write(this->const_device_indexes); } void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { @@ -338,7 +356,7 @@ void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { VMInstructionSerializer SerializeInstruction(const Instruction& instr) { std::vector fields; // Save the opcode. - VLOG(1) << "Serializing: " << instr << std::endl; + VLOG(2) << "Serializing: " << instr << std::endl; switch (instr.op) { case Opcode::Move: { // Number of fields = 2 @@ -407,7 +425,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.push_back(dtype.code); fields.push_back(dtype.bits); fields.push_back(dtype.lanes); - fields.push_back(instr.alloc_storage.device_type); + fields.push_back(instr.alloc_storage.device_index); fields.push_back(instr.dst); break; } @@ -487,7 +505,8 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { } case Opcode::DeviceCopy: { // Number of fields = 4 - fields.assign({instr.src, instr.src_device_type, instr.dst_device_type, instr.dst}); + fields.assign({instr.device_copy.src, instr.device_copy.src_device_index, + instr.device_copy.dst_device_index, instr.dst}); break; } default: @@ -504,7 +523,7 @@ void Executable::SaveCodeSection(dmlc::Stream* strm) { for (const auto& func : this->functions) { // Save the function info. VMFunctionSerializer func_format(func.name, func.register_file_size, func.instructions.size(), - func.params, func.params_device_type); + func.params, func.param_device_indexes); func_format.Save(strm); // Serialize each instruction. @@ -564,6 +583,9 @@ runtime::Module Executable::Load(const std::string& code, const runtime::Module // Load header. LoadHeader(&strm); + // Virtual devices section + exec->LoadVirtualDevicesSection(&strm); + // Global section. exec->LoadGlobalSection(&strm); @@ -579,6 +601,12 @@ runtime::Module Executable::Load(const std::string& code, const runtime::Module return runtime::Module(exec); } +void Executable::LoadVirtualDevicesSection(dmlc::Stream* strm) { + STREAM_CHECK(strm->Read(&virtual_devices), "virtual_device"); + STREAM_CHECK(strm->Read(&host_device_index), "virtual_device"); + ICHECK(host_device_index >= 0 && host_device_index < static_cast(virtual_devices.size())); +} + void Executable::LoadGlobalSection(dmlc::Stream* strm) { std::vector globals; STREAM_CHECK(strm->Read(&globals), "global"); @@ -597,14 +625,15 @@ void Executable::LoadConstantSection(dmlc::Stream* strm) { for (size_t i = 0; i < size; i++) { runtime::NDArray constant; STREAM_CHECK(constant.Load(strm), "constant"); - this->constants.push_back(constant); + this->constants.emplace_back(std::move(constant)); } - // Load the const to device mapping. - std::vector const_device_type; - STREAM_CHECK(strm->Read(&const_device_type), "constant"); - ICHECK_EQ(size, const_device_type.size()); - this->const_device_type = const_device_type; + // Load the const to device index mapping. + std::vector const_device_indexes; + const_device_indexes.reserve(size); + STREAM_CHECK(strm->Read(&const_device_indexes), "constant"); + ICHECK_EQ(size, const_device_indexes.size()); + this->const_device_indexes = std::move(const_device_indexes); } void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { @@ -846,8 +875,9 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) { } // Create the VM function. - VMFunction vm_func = VMFunction(loaded_func.name, loaded_func.params, instructions, - loaded_func.register_file_size, loaded_func.params_device_type); + VMFunction vm_func = + VMFunction(loaded_func.name, loaded_func.params, instructions, + loaded_func.register_file_size, loaded_func.param_device_indexes); auto it = this->global_map.find(loaded_func.name); ICHECK(it != this->global_map.end()); ICHECK_LE(it->second, this->global_map.size()); diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index cd2d1332580b..e5afb0e4b1fc 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -101,12 +101,10 @@ void VirtualMachineDebug::LoadExecutable(const Executable* exec) { void VirtualMachineDebug::OpStartHook(Instruction instr) { if (prof_ && prof_.operator*().IsRunning()) { if (instr.op == Opcode::LoadConst) { - Device dev = GetDevice(exec_->const_device_type[instr.const_index]); + Device dev = GetDevice(exec_->const_device_indexes[instr.const_index]); prof_.operator*().StartCall("VM::LoadConst", dev, {}); } else if (instr.op == Opcode::DeviceCopy) { - Device dst_dev; - dst_dev.device_type = static_cast(instr.dst_device_type); - dst_dev.device_id = 0; + Device dst_dev = GetDevice(instr.device_copy.dst_device_index); prof_.operator*().StartCall("VM::DeviceCopy", dst_dev, {}); } else if (instr.op == Opcode::ReshapeTensor) { prof_.operator*().StartCall("VM::ReshapeTensor", devices_[1], {}); @@ -124,7 +122,7 @@ void VirtualMachineDebug::OpStartHook(Instruction instr) { } else if (instr.op == Opcode::AllocTensorReg) { auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage); auto storage = Downcast(storage_obj); - Device cpu_dev = GetDevice(static_cast(kDLCPU)); + Device cpu_dev = GetDevice(exec_->host_device_index); auto shape_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); NDArray shape_tensor = Downcast(shape_obj).CopyTo(cpu_dev); prof_.operator*().StartCall( @@ -135,8 +133,8 @@ void VirtualMachineDebug::OpStartHook(Instruction instr) { auto size = LoadScalarInt(instr.alloc_storage.allocation_size); std::ostringstream shape; shape << DLDataType2String(instr.alloc_storage.dtype_hint) << "[" << size << "]"; - prof_.operator*().StartCall("VM::AllocStorage", - {static_cast(instr.alloc_storage.device_type), 0}, + Device dev = GetDevice(instr.alloc_storage.device_index); + prof_.operator*().StartCall("VM::AllocStorage", dev, {{"VM::Argument Shapes", String(shape.str())}}); } else { prof_.operator*().StartCall("VM::UnknownOp", devices_[1], {}); diff --git a/src/runtime/vm/serialize_utils.h b/src/runtime/vm/serialize_utils.h index b4a10806caaf..04a79c9b0210 100644 --- a/src/runtime/vm/serialize_utils.h +++ b/src/runtime/vm/serialize_utils.h @@ -58,19 +58,19 @@ struct VMFunctionSerializer { size_t num_instructions; /*! \brief The parameters of the VMFunction. */ std::vector params; - /*! \brief The device type of each parameter of the VMFunction. */ - std::vector params_device_type; + /*! \brief The index for the devices holding each parameter of the VMFunction. */ + std::vector param_device_indexes; VMFunctionSerializer() = default; VMFunctionSerializer(const std::string& name, Index register_file_size, size_t num_instructions, const std::vector& params, - const std::vector& params_device_type) + const std::vector& param_device_indexes) : name(name), register_file_size(register_file_size), num_instructions(num_instructions), params(params), - params_device_type(params_device_type) {} + param_device_indexes(param_device_indexes) {} /*! * \brief Load the serialized function header. @@ -87,7 +87,7 @@ struct VMFunctionSerializer { // Get the number of instructions. num_instructions = static_cast(std::stoll(func_info[2])); if (!strm->Read(¶ms)) return false; - if (!strm->Read(¶ms_device_type)) return false; + if (!strm->Read(¶m_device_indexes)) return false; return true; } @@ -102,7 +102,7 @@ struct VMFunctionSerializer { func_info.push_back(std::to_string(num_instructions)); strm->Write(func_info); strm->Write(params); - strm->Write(params_device_type); + strm->Write(param_device_indexes); } }; diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index b903f793d799..05adf1d69e8d 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -232,12 +232,11 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) { const auto& param_names = vm_func.params; ICHECK_EQ(args.size() - offset, param_names.size()) << "The number of provided parameters doesn't match the number of arguments"; - ICHECK_EQ(param_names.size(), vm_func.params_device_type.size()) + ICHECK_EQ(param_names.size(), vm_func.param_device_indexes.size()) << "The number of provided parameters doesn't match the number of assigned devices"; std::vector func_args(param_names.size()); for (int i = offset; i < args.size(); ++i) { - Index device_type = vm_func.params_device_type[i - offset]; - Device dev = GetDevice(device_type); + Device dev = GetDevice(vm_func.param_device_indexes[i - offset]); if (args[i].type_code() == kTVMDLTensorHandle) { // Automatically convert input DLTensors to NDArray @@ -258,13 +257,14 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) { inputs_.emplace(func_name, func_args); } -inline Device VirtualMachine::GetDevice(Index device_type) const { - ICHECK_GE(devices_.size(), device_type) << "devices_ doesn't contain device:" << device_type; +inline Device VirtualMachine::GetDevice(Index device_index) const { + ICHECK_GE(devices_.size(), device_index) << "invalid device index: " << device_index; + return devices_[device_index]; +} - auto dev = devices_[device_type]; - ICHECK_EQ(static_cast(dev.device_type), device_type) - << "device type " << device_type << " has not been initialized in the device list."; - return dev; +inline Allocator* VirtualMachine::GetAllocator(Index device_index) const { + ICHECK_GE(allocators_.size(), device_index) << "invalid device index: " << device_index; + return allocators_[device_index]; } void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) { @@ -297,7 +297,12 @@ void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector& args) { - VLOG(2) << "Executing Function: " << std::endl << func; + DLOG(INFO) << "Executing Function: " << std::endl << func; + for (int i = 0; i < static_cast(devices_.size()); ++i) { + DLOG(INFO) << "Device " << i << " has device type " << devices_[i].device_type + << " and device id " << devices_[i].device_id + << (i == exec_->host_device_index ? " (using as host device)" : ""); + } InvokeGlobal(func, args); RunLoop(); @@ -383,19 +388,31 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { } } -void VirtualMachine::Init(const std::vector& devs, +void VirtualMachine::Init(const std::vector& physical_devices, const std::vector& alloc_types) { - ICHECK_EQ(devs.size(), alloc_types.size()); - // Cache the device - for (size_t i = 0; i < devs.size(); i++) { - auto dev_type = static_cast(devs[i].device_type); - auto alloc = MemoryManager::GetOrCreateAllocator(devs[i], alloc_types[i]); - if (devices_.size() <= dev_type) { - devices_.resize(dev_type + 1); - allocators_.resize(dev_type + 1); - } - devices_[dev_type] = devs[i]; - allocators_[dev_type] = alloc; + ICHECK_EQ(physical_devices.size(), alloc_types.size()); + + // Find a physical device to represent each virtual device the VM code requires. + // (Recall the VM instructions refer to devices by "device index" into this vector of + // virtual devices.) + const size_t num_virtual_devices = exec_->virtual_devices.size(); + devices_.reserve(num_virtual_devices); + allocators_.reserve(num_virtual_devices); + + for (size_t device_index = 0; device_index < num_virtual_devices; ++device_index) { + // We'll retain the legacy behaviour and just match by device type. + // TODO(mbs): Generalize. + DLDeviceType virtual_device_type = exec_->virtual_devices[device_index].device_type; + auto itr = std::find_if(physical_devices.begin(), physical_devices.end(), + [virtual_device_type](const Device& physical_device) { + return physical_device.device_type == virtual_device_type; + }); + CHECK(itr != physical_devices.end()) + << "Unable to find a physical device (from among the " << physical_devices.size() + << " given) to match the virtual device with device type " << virtual_device_type; + const size_t i = std::distance(physical_devices.begin(), itr); + devices_.push_back(*itr); + allocators_.push_back(MemoryManager::GetOrCreateAllocator(*itr, alloc_types[i])); } } @@ -408,7 +425,7 @@ ObjectRef VirtualMachine::ReadRegister(Index r) const { return frames_.back().re int64_t VirtualMachine::LoadScalarInt(Index r) const { int64_t result = 0; const auto& obj = ReadRegister(r); - NDArray array = Downcast(CopyTo(obj, {kDLCPU, 0})); + NDArray array = Downcast(CopyTo(obj, GetDevice(exec_->host_device_index))); switch (array->dtype.bits) { case 1: { @@ -473,7 +490,7 @@ void VirtualMachine::RunLoop() { } if (!const_pool_[instr.const_index].defined()) { - Device dev = GetDevice(exec_->const_device_type[instr.const_index]); + Device dev = GetDevice(exec_->const_device_indexes[instr.const_index]); const_pool_[instr.const_index] = CopyTo(constant_obj, dev); } WriteRegister(instr.dst, const_pool_[instr.const_index]); @@ -484,7 +501,7 @@ void VirtualMachine::RunLoop() { goto main_loop; } case Opcode::LoadConsti: { - auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0}); + auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, GetDevice(exec_->host_device_index)); reinterpret_cast(tensor->data)[0] = instr.load_consti.val; WriteRegister(instr.dst, tensor); pc_++; @@ -544,7 +561,7 @@ void VirtualMachine::RunLoop() { auto object = ReadRegister(instr.get_tag.object); const auto& adt = Downcast(object); auto tag = adt.tag(); - auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0}); + auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, GetDevice(exec_->host_device_index)); reinterpret_cast(tag_tensor->data)[0] = tag; WriteRegister(instr.dst, tag_tensor); pc_++; @@ -600,7 +617,7 @@ void VirtualMachine::RunLoop() { } case Opcode::AllocTensorReg: { OpStartHook(instr); - Device cpu_dev = GetDevice(static_cast(kDLCPU)); + Device cpu_dev = GetDevice(exec_->host_device_index); auto shape_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); NDArray shape_tensor = Downcast(CopyTo(shape_obj, cpu_dev)); auto shape = ToShape(shape_tensor); @@ -637,16 +654,15 @@ void VirtualMachine::RunLoop() { OpStartHook(instr); auto size = LoadScalarInt(instr.alloc_storage.allocation_size); auto alignment = instr.alloc_storage.alignment; + auto storage_obj = SimpleObjAllocator().make_object(); - auto dev_type = instr.alloc_storage.device_type; - ICHECK_LT(static_cast(dev_type), allocators_.size()) - << "Memory allocator for device " << dev_type << " has not been initialized"; - auto* alloc = allocators_[dev_type]; - ICHECK(alloc) << "Did you forget to init the VirtualMachine with devices?"; + Allocator* allocator = GetAllocator(instr.alloc_storage.device_index); + ICHECK(allocator) << "Did you forget to init the VirtualMachine with devices?"; VLOG(2) << "AllocStorage: allocation_size=" << size << ", alignment=" << alignment << ", dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint) - << ", device_type=" << instr.alloc_storage.device_type; - storage_obj->buffer = alloc->Alloc(size, alignment, instr.alloc_storage.dtype_hint); + << ", device_index=" << instr.alloc_storage.device_index; + + storage_obj->buffer = allocator->Alloc(size, alignment, instr.alloc_storage.dtype_hint); Storage storage(storage_obj); WriteRegister(instr.dst, storage); OpStopHook(); @@ -657,7 +673,8 @@ void VirtualMachine::RunLoop() { auto input = ReadRegister(instr.shape_of.tensor); NDArray input_array = Downcast(input); int ndim = input_array->ndim; - auto out_tensor = NDArray::Empty({ndim}, {kDLInt, 64, 1}, {kDLCPU, 0}); + auto out_tensor = + NDArray::Empty({ndim}, {kDLInt, 64, 1}, GetDevice(exec_->host_device_index)); for (int i = 0; i < ndim; ++i) { reinterpret_cast(out_tensor->data)[i] = input_array->shape[i]; } @@ -682,7 +699,7 @@ void VirtualMachine::RunLoop() { } case Opcode::ReshapeTensor: { OpStartHook(instr); - Device cpu_dev = GetDevice(static_cast(kDLCPU)); + Device cpu_dev = GetDevice(exec_->host_device_index); auto tensor_obj = ReadRegister(instr.reshape_tensor.tensor); NDArray tensor_arr = Downcast(tensor_obj); // Read the shape from shape tensor @@ -703,14 +720,13 @@ void VirtualMachine::RunLoop() { } case Opcode::DeviceCopy: { OpStartHook(instr); - auto tensor_src = ReadRegister(instr.src); + auto tensor_src = ReadRegister(instr.device_copy.src); NDArray src_data = Downcast(tensor_src); - Device src_dev = src_data->device; - ICHECK_EQ(static_cast(src_dev.device_type), instr.src_device_type); - - Device dst_dev; - dst_dev.device_type = static_cast(instr.dst_device_type); - dst_dev.device_id = 0; + Device actual_src_dev = src_data->device; + Device inst_src_dev = GetDevice(instr.device_copy.src_device_index); + ICHECK_EQ(actual_src_dev.device_type, inst_src_dev.device_type); + ICHECK_EQ(actual_src_dev.device_id, inst_src_dev.device_id); + Device dst_dev = GetDevice(instr.device_copy.dst_device_index); NDArray dst_data = src_data.CopyTo(dst_dev); WriteRegister(instr.dst, dst_data); diff --git a/tests/cpp/relay/transforms/device_domains_test.cc b/tests/cpp/relay/transforms/device_domains_test.cc index 8f263c3b3273..5df7984d003a 100644 --- a/tests/cpp/relay/transforms/device_domains_test.cc +++ b/tests/cpp/relay/transforms/device_domains_test.cc @@ -45,24 +45,32 @@ IRModule TestModule() { } TEST(DeviceDomains, SmokeTest) { - DeviceDomains domains; + SEScope cpu = SEScope::ForDeviceType(kDLCPU); + SEScope cuda = SEScope::ForDeviceType(kDLCUDA); + TargetMap target_map; + target_map.Set(Integer(static_cast(kDLCPU)), Target("llvm")); + target_map.Set(Integer(static_cast(kDLCUDA)), Target("cuda")); + transform::PassContext ctxt = transform::PassContext::Create(); + CompilationConfig config(ctxt, target_map, /*optional_host_target=*/{}); + DeviceDomains domains(config); IRModule mod = TestModule(); Function f = Downcast(mod->Lookup("f")); DeviceDomainPtr actual_add_domain = domains.DomainForCallee(Downcast(f->body)); DeviceDomainPtr x_domain = domains.DomainFor(f->params[0]); DeviceDomainPtr y_domain = domains.DomainFor(f->params[1]); - DeviceDomainPtr result_domain = DeviceDomains::Free(f->ret_type); + DeviceDomainPtr result_domain = domains.Free(f->ret_type); std::vector arg_and_results; arg_and_results.push_back(x_domain); arg_and_results.push_back(y_domain); arg_and_results.push_back(result_domain); - DeviceDomainPtr implied_add_domain = DeviceDomains::MakeDomain(std::move(arg_and_results)); - domains.Unify(actual_add_domain, implied_add_domain); - domains.Unify(x_domain, DeviceDomains::ForDeviceType(f->params[0]->checked_type(), kDLCUDA)); + DeviceDomainPtr implied_add_domain = domains.MakeHigherOrderDomain(std::move(arg_and_results)); + EXPECT_FALSE(domains.UnifyOrNull(actual_add_domain, implied_add_domain) == nullptr); + EXPECT_FALSE(domains.UnifyOrNull( + x_domain, domains.ForSEScope(f->params[0]->checked_type(), cuda)) == nullptr); - EXPECT_EQ(domains.ResultDeviceType(y_domain), kDLCUDA); - EXPECT_EQ(domains.ResultDeviceType(result_domain), kDLCUDA); + EXPECT_EQ(domains.ResultSEScope(y_domain), config->CanonicalSEScope(cuda)); + EXPECT_EQ(domains.ResultSEScope(result_domain), config->CanonicalSEScope(cuda)); } } // namespace diff --git a/tests/python/relay/op/annotation/test_annotation.py b/tests/python/relay/op/annotation/test_annotation.py index 58e559eb9680..8ba91976523a 100644 --- a/tests/python/relay/op/annotation/test_annotation.py +++ b/tests/python/relay/op/annotation/test_annotation.py @@ -26,14 +26,17 @@ def test_on_device_via_string(): assert isinstance(call, relay.Call) assert len(call.args) == 1 assert call.args[0] == x - assert call.attrs.device_type == 2 # ie kDLCUDA + assert call.attrs.se_scope.device_type_int == 2 # ie kDLCUDA + assert call.attrs.se_scope.virtual_device_id == 0 + assert call.attrs.se_scope.target is None + assert call.attrs.se_scope.memory_scope == "" assert not call.attrs.is_fixed def test_on_device_via_device(): x = relay.Var("x") - call = relay.annotation.on_device(x, tvm.device("llvm")) - assert call.attrs.device_type == 1 # ie kDLCPU + call = relay.annotation.on_device(x, tvm.device("cpu")) + assert call.attrs.se_scope.device_type_int == 1 # ie kDLCPU def test_on_device_invalid_device(): @@ -44,7 +47,7 @@ def test_on_device_invalid_device(): def test_on_device_is_fixed(): x = relay.Var("x") call = relay.annotation.on_device(x, "cuda", True) - assert call.attrs.device_type == 2 + assert call.attrs.se_scope.device_type_int == 2 # ie kDLCUDA assert call.attrs.is_fixed @@ -54,15 +57,13 @@ def test_function_on_device(): f = relay.Function([x, y], relay.add(x, y)) func = relay.annotation.function_on_device(f, ["cpu", "cuda"], "cuda") assert isinstance(func, relay.Function) - assert len(func.attrs["param_device_types"]) == 2 - assert func.attrs["param_device_types"][0] == 1 # ie kDLCPU - assert func.attrs["param_device_types"][1] == 2 # ie kDLCUDA - assert func.attrs["result_device_type"] == 2 # ie KDLCUDA + assert len(func.attrs["param_se_scopes"]) == 2 + assert func.attrs["param_se_scopes"][0].device_type_int == 1 # ie kDLCPU + assert func.attrs["param_se_scopes"][1].device_type_int == 2 # ie kDLCUDA + assert func.attrs["result_se_scope"].device_type_int == 2 # ie KDLCUDA if __name__ == "__main__": - test_on_device_via_string() - test_on_device_via_device() - test_on_device_invalid_device() - test_on_device_is_fixed() - test_function_on_device() + import sys + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/op/test_tensor.py b/tests/python/relay/op/test_tensor.py new file mode 100644 index 000000000000..4d2c1766972a --- /dev/null +++ b/tests/python/relay/op/test_tensor.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for tensor helpers.""" +import tvm +from tvm import relay +import pytest + + +def test_device_copy_via_string(): + x = relay.var("x") + call = relay.op.device_copy(x, "cuda", "cpu") + assert isinstance(call, relay.Call) + assert len(call.args) == 1 + assert call.args[0] == x + assert call.attrs.src_se_scope.device_type_int == 2 # ie kDLCUDA + assert call.attrs.src_se_scope.virtual_device_id == 0 + assert call.attrs.src_se_scope.target is None + assert call.attrs.src_se_scope.memory_scope == "" + assert call.attrs.dst_se_scope.device_type_int == 1 # ie kDLCPU + assert call.attrs.dst_se_scope.virtual_device_id == 0 + assert call.attrs.dst_se_scope.target is None + assert call.attrs.dst_se_scope.memory_scope == "" + + +def test_device_copy_via_device(): + x = relay.var("x") + call = relay.op.device_copy(x, tvm.device("cuda"), tvm.device("cpu")) + assert isinstance(call, relay.Call) + assert len(call.args) == 1 + assert call.args[0] == x + assert call.attrs.src_se_scope.device_type_int == 2 # ie kDLCUDA + assert call.attrs.dst_se_scope.device_type_int == 1 # ie kDLCPU + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py index e3218ab1a829..37eb1a2d6456 100644 --- a/tests/python/relay/test_pass_plan_devices.py +++ b/tests/python/relay/test_pass_plan_devices.py @@ -26,18 +26,36 @@ import tvm.testing import numpy as np -CPU = tvm.device("cpu") # device_type=1 -GPU = tvm.device("cuda") # device_type=2 +HOST_DEVICE = tvm.device("cpu") +HOST_TARGET = tvm.target.Target("llvm") + +CPU_DEVICE = tvm.device("cpu") +CPU_TARGET = tvm.target.Target("llvm").with_host(HOST_TARGET) + +GPU_DEVICE = tvm.device("cuda") +GPU_TARGET = tvm.target.Target("cuda").with_host(HOST_TARGET) + +TARGETS = { + tvm.tir.IntImm("int32", CPU_DEVICE.device_type): CPU_TARGET, + tvm.tir.IntImm("int32", GPU_DEVICE.device_type): GPU_TARGET, +} + +HOST = tvm.target.make_se_scope(HOST_DEVICE, HOST_TARGET) # device_type=1 +CPU = tvm.target.make_se_scope(CPU_DEVICE, CPU_TARGET) # device_type=1 +GPU = tvm.target.make_se_scope(GPU_DEVICE, GPU_TARGET) # device_type=2 DEFAULT = GPU +CTXT = tvm.transform.PassContext(config={"relay.fallback_device_type": DEFAULT.device_type_int}) + core = tvm.IRModule() core.import_from_std("core.rly") def rewrite_and_assert(in_mod, expected_mod): """Manually run the pass and assert it's structurally equals to the expected.""" + config = tvm.target.make_compilation_config(CTXT, TARGETS, HOST_TARGET) actual_mod = relay.transform.InferType()(in_mod) - actual_mod = relay.transform.PlanDevices(DEFAULT)(actual_mod) + actual_mod = relay.transform.PlanDevices(config)(actual_mod) actual_mod = relay.transform.InferType()(actual_mod) expected_mod = relay.transform.InferType()(expected_mod) if not tvm.ir.structural_equal(actual_mod, expected_mod, True): @@ -59,7 +77,9 @@ def eval_and_assert(in_mod: tvm.IRModule, reference_func, args): print("Not evaluating since GPU is not available") return with tvm.transform.PassContext(opt_level=3): - compiled = relay.create_executor("vm", mod=in_mod, device=GPU, target="cuda").evaluate() + compiled = relay.create_executor( + "vm", mod=in_mod, device=GPU_DEVICE, target=GPU_TARGET + ).evaluate() actual = compiled(*args).numpy() expected = reference_func(*args) tvm.testing.assert_allclose(actual, expected) @@ -85,9 +105,11 @@ def exercise(in_mod: tvm.IRModule, expected_mod: tvm.IRModule, reference_func, a def test_plain(): + metatable = {"SEScope": [CPU, GPU]} + # Everything defaults to GPU def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], @@ -96,21 +118,28 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %1 = add(%c, %d); subtract(%0, %1) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[2, 2, 2, 2], result_device_type=2) { + param_se_scopes=[meta[SEScope][1], meta[SEScope][1], meta[SEScope][1], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); %1 = add(%c, %d); subtract(%0, %1) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -120,35 +149,44 @@ def ref(a, b, c, d): def test_left_add_on_cpu(): + metatable = {"SEScope": [CPU, GPU]} + # Force some args to be on CPU, rest default to GPU. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { %0 = add(%a, %b); - %1 = on_device(%0, device_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][0]); %2 = add(%c, %d); subtract(%1, %2) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[1, 1, 2, 2], result_device_type=2) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, device_type=1, is_fixed=True); - %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %3 = add(%c, %d); subtract(%2, %3) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -158,35 +196,44 @@ def ref(a, b, c, d): def test_left_add_on_cpu_via_copy(): + metatable = {"SEScope": [CPU, GPU]} + # As for test_left_add_on_cpu, but with an explicit device_copy. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { %0 = add(%a, %b); - %1 = device_copy(%0, src_dev_type=1, dst_dev_type=2); + %1 = device_copy(%0, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %2 = add(%c, %d); subtract(%1, %2) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[1, 1, 2, 2], result_device_type=2) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, device_type=1, is_fixed=True); - %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %3 = add(%c, %d); subtract(%2, %3) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -196,37 +243,46 @@ def ref(a, b, c, d): def test_both_adds_on_cpu(): + metatable = {"SEScope": [CPU, GPU]} + def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { %0 = add(%a, %b); %1 = add(%c, %d); - %2 = on_device(%0, device_type=1); - %3 = on_device(%1, device_type=1); + %2 = on_device(%0, se_scope=meta[SEScope][0]); + %3 = on_device(%1, se_scope=meta[SEScope][0]); subtract(%2, %3) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[1, 1, 1, 1], result_device_type=2) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], + result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, device_type=1, is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); %2 = add(%c, %d); - %3 = on_device(%2, device_type=1, is_fixed=True); - %4 = device_copy(%1, src_dev_type=1, dst_dev_type=2); - %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + %3 = on_device(%2, se_scope=meta[SEScope][0], is_fixed=True); + %4 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%4, %5) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -236,34 +292,42 @@ def ref(a, b, c, d): def test_sharing(): + metatable = {"SEScope": [CPU, GPU]} + # The same add sub-expression is annotated twice. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { %0 = add(%a, %b); - %1 = on_device(%0, device_type=1); - %2 = on_device(%0, device_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][0]); + %2 = on_device(%0, se_scope=meta[SEScope][0]); subtract(%1, %2) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - param_device_types=[1, 1], result_device_type=2) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, device_type=1, is_fixed=True); - %2 = on_device(%0, device_type=1, is_fixed=True); - %3 = device_copy(%1, src_dev_type=1, dst_dev_type=2); - %4 = device_copy(%2, src_dev_type=1, dst_dev_type=2); + %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %2 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %3 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %4 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%3, %4) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b): @@ -274,35 +338,44 @@ def ref(a, b): def test_let_on_cpu(): + metatable = {"SEScope": [CPU, GPU]} + # The device for a let-bound expression can flow from uses of the let-bound var. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { let %l = add(%a, %b); let %r = add(%c, %d); - %0 = on_device(%l, device_type=1); + %0 = on_device(%l, se_scope=meta[SEScope][0]); subtract(%0, %r) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[1, 1, 2, 2], result_device_type=2) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - let %l = on_device(%0, device_type=1, is_fixed=True); + let %l = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); let %r = add(%c, %d); - %1 = device_copy(%l, src_dev_type=1, dst_dev_type=2); + %1 = device_copy(%l, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%1, %r) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -312,39 +385,49 @@ def ref(a, b, c, d): def test_func_param_on_cpu(): + metatable = {"SEScope": [CPU, GPU]} + # Devices for function parameters flow to call sites. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { let %f = fn (%x, %y) { %0 = add(%x, %y); - on_device(%0, device_type=1) + on_device(%0, se_scope=meta[SEScope][0]) }; %1 = %f(%a, %b); %2 = add(%c, %d); subtract(%1, %2) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[1, 1, 1, 1], result_device_type=1) { - let %f = fn (%x, %y, param_device_types=[1, 1], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], + result_se_scope=meta[SEScope][0]) { + let %f = fn (%x, %y, + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { add(%x, %y) }; %0 = %f(%a, %b); %1 = add(%c, %d); subtract(%0, %1) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -354,9 +437,11 @@ def ref(a, b, c, d): def test_func_result_on_cpu(): + metatable = {"SEScope": [CPU, GPU]} + # Devices for call sites flow to function results. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], @@ -365,30 +450,38 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], add(%x, %y) }; %0 = %f(%a, %b); - %1 = on_device(%0, device_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][0]); %2 = add(%c, %d); subtract(%1, %2) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[1, 1, 2, 2], result_device_type=2) { - let %f = fn (%x, %y, param_device_types=[1, 1], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) { + let %f = fn (%x, %y, + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { add(%x, %y) }; %1 = %f(%a, %b); - %2 = on_device(%1, device_type=1, is_fixed=True); - %3 = device_copy(%2, src_dev_type=1, dst_dev_type=2); + %2 = on_device(%1, se_scope=meta[SEScope][0], is_fixed=True); + %3 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %4 = add(%c, %d); subtract(%3, %4) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -398,15 +491,17 @@ def ref(a, b, c, d): def test_higher_order(): + metatable = {"SEScope": [CPU, GPU]} + # The constraint on %a flows back to %y via %f and %h def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { let %f = fn (%g) { fn (%a) { - %0 = on_device(%a, device_type=1); + %0 = on_device(%a, se_scope=meta[SEScope][0]); %1 = %g(%0); add(%1, %x) } @@ -418,30 +513,36 @@ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { %3 = %2(%y); subtract(%x, %3) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_device_types=[2, 1], result_device_type=2) { - let %f = fn (%g, param_device_types=[2], result_device_type=2) { - fn (%a, param_device_types=[1], result_device_type=2) { - %0 = device_copy(%a, src_dev_type=1, dst_dev_type=2); + param_se_scopes=[meta[SEScope][1], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { + let %f = fn (%g, param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { + fn (%a, param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { + %0 = device_copy(%a, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %1 = %g(%0); add(%1, %x) } }; - let %h = fn (%b, param_device_types=[2], result_device_type=2) { + let %h = fn (%b, param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { negative(%b) }; %2 = %f(%h); %3 = %2(%y); subtract(%x, %3) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y): @@ -457,14 +558,16 @@ def h(b): def test_function_in_tuple(): + metatable = {"SEScope": [CPU, GPU]} + # Since %f ends up in a tuple its argument and result is forced to be on the CPU def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { - %0 = on_device(%b, device_type=1); + %0 = on_device(%b, se_scope=meta[SEScope][0]); add(%a, %0) }; let %t = (%f, %x); @@ -472,17 +575,20 @@ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { %2 = %t.0; %2(%1, %y) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_device_types=[1, 1], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - param_device_types=[1, 1], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { add(%a, %b) }; let %t = (%f, %x); @@ -490,7 +596,10 @@ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %1 = %t.0; %1(%0, %y) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y): @@ -501,14 +610,14 @@ def ref(x, y): def test_device_copy(): const = rand((5, 7)) - metatable = {"relay.Constant": [relay.const(const)]} + metatable = {"SEScope": [CPU, GPU], "relay.Constant": [relay.const(const)]} def input(): return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32]) { - %0 = device_copy(%x, src_dev_type=1, dst_dev_type=2); + %0 = device_copy(%x, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); add(%0, meta[relay.Constant][0]) } """, @@ -521,8 +630,9 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(5, 7), float32], param_device_types=[1], result_device_type=2) { - %0 = device_copy(%x, src_dev_type=1, dst_dev_type=2); + def @main(%x: Tensor[(5, 7), float32], + param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { + %0 = device_copy(%x, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); add(%0, meta[relay.Constant][0]) } """, @@ -538,31 +648,37 @@ def ref(x): def test_shape_func(): + metatable = {"SEScope": [HOST, GPU]} + def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64]) { %0 = fn (%y: Tensor[(?), float32]) { nn.relu(%y) }; - let %p = on_device(%0, device_type=2, is_fixed=True); - %1 = on_device(%x, device_type=2, is_fixed=True); + let %p = on_device(%0, se_scope=meta[SEScope][1], is_fixed=True); + %1 = on_device(%x, se_scope=meta[SEScope][1], is_fixed=True); %2 = vm.shape_of(%1, dtype="int64"); %3 = (%2,); %4 = (%s,); vm.shape_func(%p, %3, %4, is_input=[False]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64], - param_device_types=[2, 1], result_device_type=1) { - let %p = fn (%y: Tensor[(?), float32], param_device_types=[2], result_device_type=2) { + param_se_scopes=[meta[SEScope][1], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { + let %p = fn (%y: Tensor[(?), float32], + param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { nn.relu(%y) }; %1 = vm.shape_of(%x, dtype="int64"); @@ -570,7 +686,10 @@ def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64], %3 = (%s,); vm.shape_func(%p, %2, %3, is_input=[False]) } - """ + """, + "from_string", + None, + metatable, ) # Don't try to execute, too fiddly to setup. @@ -578,28 +697,37 @@ def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64], def test_shape_of(): + metatable = {"SEScope": [HOST, GPU]} + # We need to use is_fixed=True in the on_device call so that the tensor will be on the GPU. Otherwise the # result defaults to the result device for @main which is the CPU, thus forcing a copy. # TODO(mbs): Perhaps the defaulting heuristics are being too clever? def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(?, ?), float32]) { - %0 = on_device(%x, device_type=2, is_fixed=True); + %0 = on_device(%x, se_scope=meta[SEScope][1], is_fixed=True); vm.shape_of(%0, dtype="int64") } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(?, ?), float32], param_device_types=[2], result_device_type=1) { + def @main(%x: Tensor[(?, ?), float32], + param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][0]) { vm.shape_of(%x, dtype="int64") } - """ + """, + "from_string", + None, + metatable, ) def ref(x): @@ -609,28 +737,33 @@ def ref(x): def test_alloc_storage(): + metatable = {"SEScope": [HOST, GPU]} + def input(): return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%size: int64, %alignment: int64) { - memory.alloc_storage(%size, %alignment, device_id=0, device_type=2) + memory.alloc_storage(%size, %alignment, se_scope=meta[SEScope][1]) } """, "from_string", core, + metatable, ) def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%size: int64, %alignment: int64, param_device_types=[1, 1], result_device_type=2) { - memory.alloc_storage(%size, %alignment, device_id=0, device_type=2) + def @main(%size: int64, %alignment: int64, + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { + memory.alloc_storage(%size, %alignment, se_scope=meta[SEScope][1]) } """, "from_string", core, + metatable, ) # Don't try to execute, too fiddly to setup. @@ -639,7 +772,7 @@ def @main(%size: int64, %alignment: int64, param_device_types=[1, 1], result_dev def test_alloc_tensor(): shape = np.array([3, 2]) - metatable = {"relay.Constant": [relay.const(shape, dtype="int64")]} + metatable = {"SEScope": [HOST, GPU], "relay.Constant": [relay.const(shape, dtype="int64")]} def input(): return tvm.parser.parse( @@ -659,9 +792,9 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%sto: Storage[], param_device_types=[2], result_device_type=2) { - %0 = on_device(0, device_type=1, is_fixed=True); - %1 = on_device(meta[relay.Constant][0], device_type=1, is_fixed=True); + def @main(%sto: Storage[], param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { + %0 = on_device(0, se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], is_fixed=True); memory.alloc_tensor(%sto, %0, %1, const_shape=meta[relay.Constant][0], assert_shape=[]) } """, @@ -676,7 +809,7 @@ def @main(%sto: Storage[], param_device_types=[2], result_device_type=2) { def test_reshape_tensor(): newshape = [2, 4, 2] - metatable = {"relay.Constant": [relay.const(newshape, dtype="int64")]} + metatable = {"SEScope": [HOST, GPU], "relay.Constant": [relay.const(newshape, dtype="int64")]} def input(): return tvm.parser.parse( @@ -695,8 +828,9 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(2, 8), float32], param_device_types=[2], result_device_type=2) { - %0 = on_device(meta[relay.Constant][0], device_type=1, is_fixed=True); + def @main(%x: Tensor[(2, 8), float32], + param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { + %0 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], is_fixed=True); vm.reshape_tensor(%x, %0, newshape=[2, 4, 2]) } """, @@ -712,26 +846,34 @@ def ref(x): def test_dynamic_input(): + metatable = {"SEScope": [GPU]} + # There's nothing special about inferring devices for partially unknown types. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x0: Tensor[(?, ?), float32], %x1: Tensor[(?, ?), float32]) { add(%x0, %x1) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x0: Tensor[(?, ?), float32], %x1: Tensor[(?, ?), float32], - param_device_types=[2, 2], result_device_type=2) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { add(%x0, %x1) } - """ + """, + "from_string", + None, + metatable, ) def ref(x0, x1): @@ -741,35 +883,44 @@ def ref(x0, x1): def test_redundant_annotation(): + metatable = {"SEScope": [CPU, GPU]} + def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { %0 = add(%x, %y); - %1 = on_device(%0, device_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][0]); %2 = subtract(%1, %z); - %3 = on_device(%0, device_type=1); + %3 = on_device(%0, se_scope=meta[SEScope][0]); add(%2, %3) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - param_device_types=[1, 1, 2], result_device_type=2) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) { %0 = add(%x, %y); - %1 = on_device(%0, device_type=1, is_fixed=True); - %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); - %3 = on_device(%0, device_type=1, is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %3 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); %4 = subtract(%2, %z); - %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); add(%4, %5) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y, z): @@ -780,31 +931,40 @@ def ref(x, y, z): def test_annotate_expr(): + metatable = {"SEScope": [CPU, GPU]} + def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { %0 = add(%x, %y); - %1 = on_device(%0, device_type=2); + %1 = on_device(%0, se_scope=meta[SEScope][1]); %2 = subtract(%1, %z); - on_device(%2, device_type=1) + on_device(%2, se_scope=meta[SEScope][0]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - param_device_types=[2, 2, 1], result_device_type=1) { + param_se_scopes=[meta[SEScope][1], meta[SEScope][1], meta[SEScope][0]], + result_se_scope=meta[SEScope][0]) { %0 = add(%x, %y); - %1 = on_device(%0, device_type=2, is_fixed=True); - %2 = device_copy(%1, src_dev_type=2, dst_dev_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][1], is_fixed=True); + %2 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); subtract(%2, %z) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y, z): @@ -814,17 +974,22 @@ def ref(x, y, z): def test_annotate_all(): + metatable = {"SEScope": [CPU, GPU]} + def input(): return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { %0 = add(%x, %y); - %1 = on_device(%0, device_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][0]); %2 = subtract(%1, %z); - on_device(%2, device_type=1) + on_device(%2, se_scope=meta[SEScope][0]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): @@ -832,11 +997,15 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - param_device_types=[1, 1, 1], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], + result_se_scope=meta[SEScope][0]) { %0 = add(%x, %y); subtract(%0, %z) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y, z): @@ -858,43 +1027,52 @@ def test_conv_network(): | <--- CPU """ + metatable = {"SEScope": [CPU, GPU]} def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 56), float32], %weight: Tensor[(64, 64, 3, 3), float32]) { %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); %1 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - %2 = on_device(%0, device_type=1); - %3 = on_device(%1, device_type=1); + %2 = on_device(%0, se_scope=meta[SEScope][0]); + %3 = on_device(%1, se_scope=meta[SEScope][0]); %4 = add(%2, %3); - %5 = on_device(%4, device_type=2); + %5 = on_device(%4, se_scope=meta[SEScope][1]); %6 = nn.conv2d(%5, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - on_device(%6, device_type=1) + on_device(%6, se_scope=meta[SEScope][0]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 56), float32], - %weight: Tensor[(64, 64, 3, 3), float32], param_device_types=[1, 1, 1], result_device_type=1) { + %weight: Tensor[(64, 64, 3, 3), float32], + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], + result_se_scope=meta[SEScope][0]) { %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - %1 = on_device(%0, device_type=1, is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); %2 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - %3 = on_device(%2, device_type=1, is_fixed=True); - %4 = device_copy(%1, src_dev_type=1, dst_dev_type=2); - %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + %3 = on_device(%2, se_scope=meta[SEScope][0], is_fixed=True); + %4 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %6 = add(%4, %5); - %7 = on_device(%6, device_type=2, is_fixed=True); - %8 = device_copy(%7, src_dev_type=2, dst_dev_type=1); + %7 = on_device(%6, se_scope=meta[SEScope][1], is_fixed=True); + %8 = device_copy(%7, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); nn.conv2d(%8, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) } - """ + """, + "from_string", + None, + metatable, ) # Don't try to execute, we don't have a reference conv2d @@ -902,40 +1080,49 @@ def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 5 def test_tuple_get_item(): + metatable = {"SEScope": [CPU, GPU]} + # Note that the device copy should be placed after projection rather than before. This is handled by # a heuristic in the pass. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(3, 3, 4), float32]) { let %t = split(%x, indices_or_sections=3); - %0 = on_device(%t, device_type=1); - %1 = on_device(%t, device_type=1); + %0 = on_device(%t, se_scope=meta[SEScope][0]); + %1 = on_device(%t, se_scope=meta[SEScope][0]); %2 = %0.0; %3 = %1.1; %4 = subtract(%2, %3); - on_device(%4, device_type=2) + on_device(%4, se_scope=meta[SEScope][1]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(3, 3, 4), float32], param_device_types=[1], result_device_type=2) { + def @main(%x: Tensor[(3, 3, 4), float32], + param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { %0 = split(%x, indices_or_sections=3); - let %t = on_device(%0, device_type=1, is_fixed=True); + let %t = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); %1 = %t.0; - %2 = on_device(%1, device_type=1, is_fixed=True); + %2 = on_device(%1, se_scope=meta[SEScope][0], is_fixed=True); %3 = %t.1; - %4 = on_device(%3, device_type=1, is_fixed=True); - %5 = device_copy(%2, src_dev_type=1, dst_dev_type=2); - %6 = device_copy(%4, src_dev_type=1, dst_dev_type=2); + %4 = on_device(%3, se_scope=meta[SEScope][0], is_fixed=True); + %5 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %6 = device_copy(%4, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%5, %6) } - """ + """, + "from_string", + None, + metatable, ) def ref(x): @@ -959,45 +1146,53 @@ def test_propogation(): | <--- CPU """ + metatable = {"SEScope": [CPU, GPU]} def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32]) { %0 = negative(%x); - %1 = on_device(%0, device_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][0]); %2 = negative(%1); - %3 = on_device(%0, device_type=1); + %3 = on_device(%0, se_scope=meta[SEScope][0]); %4 = negative(%3); - %5 = on_device(%2, device_type=2); - %6 = on_device(%4, device_type=2); + %5 = on_device(%2, se_scope=meta[SEScope][1]); + %6 = on_device(%4, se_scope=meta[SEScope][1]); %7 = add(%5, %6); - %8 = on_device(%7, device_type=2); + %8 = on_device(%7, se_scope=meta[SEScope][1]); %9 = negative(%8); - on_device(%9, device_type=1) + on_device(%9, se_scope=meta[SEScope][0]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(5, 7), float32], param_device_types=[1], result_device_type=1) { + def @main(%x: Tensor[(5, 7), float32], + param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { %0 = negative(%x); - %1 = on_device(%0, device_type=1, is_fixed=True); - %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); - %3 = on_device(%0, device_type=1, is_fixed=True); - %4 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %3 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %4 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %5 = negative(%2); %6 = negative(%4); %7 = add(%5, %6); - %8 = on_device(%7, device_type=2, is_fixed=True); - %9 = device_copy(%8, src_dev_type=2, dst_dev_type=1); + %8 = on_device(%7, se_scope=meta[SEScope][1], is_fixed=True); + %9 = device_copy(%8, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); negative(%9) } - """ + """, + "from_string", + None, + metatable, ) def ref(x): @@ -1023,43 +1218,51 @@ def test_fusible_network(): | <--- CPU """ + metatable = {"SEScope": [CPU, GPU]} def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { %0 = add(%x, %y); - %1 = on_device(%0, device_type=2); + %1 = on_device(%0, se_scope=meta[SEScope][1]); %2 = negative(%1); - %3 = on_device(%2, device_type=1); + %3 = on_device(%2, se_scope=meta[SEScope][0]); %4 = negative(%0); %5 = add(%3, %4); - %6 = on_device(%5, device_type=2); + %6 = on_device(%5, se_scope=meta[SEScope][1]); %7 = negative(%6); - on_device(%7, device_type=1) + on_device(%7, se_scope=meta[SEScope][0]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], param_device_types=[2, 2], result_device_type=1) { + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_se_scopes=[meta[SEScope][1], meta[SEScope][1]], result_se_scope=meta[SEScope][0]) { %0 = add(%x, %y); - %1 = on_device(%0, device_type=2, is_fixed=True); - %2 = device_copy(%1, src_dev_type=2, dst_dev_type=1); + %1 = on_device(%0, se_scope=meta[SEScope][1], is_fixed=True); + %2 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); %3 = negative(%2); - %4 = on_device(%3, device_type=1, is_fixed=True); - %5 = device_copy(%4, src_dev_type=1, dst_dev_type=2); + %4 = on_device(%3, se_scope=meta[SEScope][0], is_fixed=True); + %5 = device_copy(%4, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %6 = negative(%0); %7 = add(%5, %6); - %8 = on_device(%7, device_type=2, is_fixed=True); - %9 = device_copy(%8, src_dev_type=2, dst_dev_type=1); + %8 = on_device(%7, se_scope=meta[SEScope][1], is_fixed=True); + %9 = device_copy(%8, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); negative(%9) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y): @@ -1083,37 +1286,45 @@ def test_unpropagatable_graph(): | <--- CPU """ + metatable = {"SEScope": [CPU, GPU]} def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { %0 = add(%a, %b); %1 = multiply(%c, %d); - %2 = on_device(%0, device_type=1); - %3 = on_device(%1, device_type=2); + %2 = on_device(%0, se_scope=meta[SEScope][0]); + %3 = on_device(%1, se_scope=meta[SEScope][1]); %4 = subtract(%2, %3); - on_device(%4, device_type=1) + on_device(%4, se_scope=meta[SEScope][0]) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_device_types=[1, 1, 2, 2], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], + result_se_scope=meta[SEScope][0]) { %0 = multiply(%c, %d); - %1 = on_device(%0, device_type=2, is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][1], is_fixed=True); %2 = add(%a, %b); - %3 = device_copy(%1, src_dev_type=2, dst_dev_type=1); + %3 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); subtract(%2, %3) } - """ + """, + "from_string", + None, + metatable, ) def ref(a, b, c, d): @@ -1123,14 +1334,16 @@ def ref(a, b, c, d): def test_conditional(): + metatable = {"SEScope": [CPU, GPU]} + # The conditional is over a function type, thus exercising the first-order/higher-order domain handling. def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { let %f = fn (%a) { - %0 = on_device(%y, device_type=1, is_fixed=True); + %0 = on_device(%y, se_scope=meta[SEScope][0], is_fixed=True); add(%a, %0) }; let %g = fn (%a1) { @@ -1143,19 +1356,23 @@ def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { }; %h(%z) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - param_device_types=[1, 1, 1], result_device_type=1) { - let %f = fn (%a, param_device_types=[1], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], + result_se_scope=meta[SEScope][0]) { + let %f = fn (%a, param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { add(%a, %y) }; - let %g = fn (%a1, param_device_types=[1], result_device_type=1) { + let %g = fn (%a1, param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { subtract(%a1, %y) }; let %h = if (%x) { @@ -1165,7 +1382,10 @@ def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], }; %h(%z) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y, z): @@ -1182,36 +1402,46 @@ def g(a): def test_global(): + metatable = {"SEScope": [CPU, GPU]} + def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { - %0 = on_device(%b, device_type=1); + %0 = on_device(%b, se_scope=meta[SEScope][0]); add(%a, %0) } def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { @f(%y, %x) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - param_device_types=[2, 1], result_device_type=2) -> Tensor[(5, 7), float32] { - %0 = device_copy(%b, src_dev_type=1, dst_dev_type=2); + param_se_scopes=[meta[SEScope][1], meta[SEScope][0]], + result_se_scope=meta[SEScope][1]) -> Tensor[(5, 7), float32] { + %0 = device_copy(%b, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); add(%a, %0) } def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_device_types=[1, 2], result_device_type=2) -> Tensor[(5, 7), float32] { + param_se_scopes=[meta[SEScope][0], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) -> Tensor[(5, 7), float32] { @f(%y, %x) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y): @@ -1224,33 +1454,41 @@ def f(a, b): def test_ref(): + metatable = {"SEScope": [CPU, GPU]} + def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { let %r = ref(%x); - %0 = on_device(%y, device_type=1); + %0 = on_device(%y, se_scope=meta[SEScope][0]); ref_write(%r, %0); %1 = ref_read(%r); add(%x, %1) } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_device_types=[2, 1], result_device_type=2) { + param_se_scopes=[meta[SEScope][1], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { let %r = ref(%x); - %0 = device_copy(%y, src_dev_type=1, dst_dev_type=2); + %0 = device_copy(%y, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); ref_write(%r, %0); %1 = ref_read(%r); add(%x, %1) } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y): @@ -1263,8 +1501,10 @@ def ref(x, y): def test_adt(): + metatable = {"SEScope": [CPU, GPU]} + def input(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] type List[A] { @@ -1272,7 +1512,7 @@ def input(): Nil, } def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32]) { - %0 = on_device(%y, device_type=1, is_fixed=True); + %0 = on_device(%y, se_scope=meta[SEScope][0], is_fixed=True); %1 = Nil; %2 = Cons(%0, %1); let %l = Cons(%x, %2); @@ -1280,11 +1520,14 @@ def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32]) { Cons(%z, _) => %z } } - """ + """, + "from_string", + None, + metatable, ) def expected(): - return tvm.parser.fromtext( + return tvm.parser.parse( """ #[version = "0.0.5"] type List[A] { @@ -1292,7 +1535,7 @@ def expected(): Nil, } def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], - param_device_types=[1, 1], result_device_type=1) { + param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { %0 = Nil; %1 = Cons(%y, %0); let %l = Cons(%x, %1); @@ -1300,7 +1543,10 @@ def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], Cons(%z, _) => %z } } - """ + """, + "from_string", + None, + metatable, ) def ref(x, y): diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 79979747dfd8..0961278b6402 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -853,11 +853,11 @@ def check_remote(server): # Get a handle to remote Executable. rexec = remote.load_module("vm_library.so") - ctx = remote.cpu() + device = remote.cpu() # Build a VM out of the executable and context. - vm_factory = runtime.vm.VirtualMachine(rexec, ctx) + vm_factory = runtime.vm.VirtualMachine(rexec, device) np_input = np.random.uniform(size=(10, 1)).astype("float32") - input_tensor = tvm.nd.array(np_input, ctx) + input_tensor = tvm.nd.array(np_input, device) # Invoke its "main" function. out = vm_factory.invoke("main", input_tensor) # Check the result. @@ -1003,7 +1003,10 @@ def test_shape_func_nested_function(): def test_storage_size_and_offset_on_cpu(): """Tests allocations place sizes and offsets on the CPU host even if the rest of the computation is on a different device type.""" + # TODO(mbs): Better would be to test ManifestAlloc independently. + # And/or move this to C++ and test the VM executable in it's C++ instead of + # pretty-printed form. # CPU = device type 1 # GPU = device type 2 @@ -1027,15 +1030,19 @@ def @main(%a: Tensor[(5, 7), float32], # - The size of the tensor's storage (first arg) to alloc_storage # - The offset of the tensor within the storage (second arg) to alloc_tensor # Both should be on the CPU - assert not "on device of type 2" in exe.constants - assert "on device of type 1" in exe.constants + assert "VirtualDevice[0]: device type 1" in exe.virtual_devices + assert "Constant[0]: has shape int64[] on device index 0" in exe.constants + assert "Constant[1]: has shape int64[] on device index 0" in exe.constants @tvm.testing.requires_cuda def test_reshape_shape_on_cpu(): """Tests the argument to a reshape places the shape on the CPU host even if the rest of the computation is on a different device type.""" + # TODO(mbs): Better would be to test ManifestAlloc independently. + # And/or move this to C++ and test the VM executable in it's C++ instead of + # pretty-printed form. # CPU = device type 1 # GPU = device type 2 @@ -1056,8 +1063,44 @@ def @main(%x: Tensor[(2, 8), float32], ) # The newshape annotation should have been turned into a constant on the CPU. - assert not "on device of type 2" in exe.constants - assert "on device of type 1" in exe.constants + assert "VirtualDevice[0]: device type 1" in exe.virtual_devices + assert "Constant[0]: has shape int64[3] on device index 0" in exe.constants + + +@tvm.testing.requires_cuda +def test_multi_targets(): + # Build an IRModule. + n = 10 + x = relay.var("x", shape=(n,)) + y = relay.var("y", shape=(n,)) + z = relay.var("z", shape=(n,)) + f = relay.Function([x, y, z], x + relay.op.annotation.on_device(y + z, tvm.cpu())) + mod = IRModule.from_expr(f) + + # Compile to VMExecutable. + with tvm.transform.PassContext( + opt_level=3, config={"relay.fallback_device_type": tvm.cuda().device_type} + ): + exe = relay.vm.compile( + mod, target={"cpu": tvm.target.Target("llvm"), "cuda": tvm.target.Target("cuda")} + ) + + # Run + vm = runtime.vm.VirtualMachine(exe, [tvm.cuda(), tvm.cpu()]) + x_data = np.random.rand( + n, + ).astype("float32") + y_data = np.random.rand( + n, + ).astype("float32") + z_data = np.random.rand( + n, + ).astype("float32") + actual_result = vm.invoke("main", x_data, y_data, z_data) + + # Test + expected_result = x_data + y_data + z_data + tvm.testing.assert_allclose(actual_result.numpy(), expected_result) if __name__ == "__main__": diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 9eae3dd33672..04879573bd6a 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -424,16 +424,7 @@ def foo(): if __name__ == "__main__": - test_record_split_reorder_fuse_annotation() - test_record_compute_at_root_inline_cache_read_write() - test_record_follow_split_follow_fused_split() - test_record_pragma_storage_align_rfactor() - test_recover_measure_input() - test_workload_dis_factor() - test_measure_local_builder_runner() - test_dag_measure_local_builder_runner() - test_workload_serialization() - test_measure_local_builder_rpc_runner() - test_measure_target_host() - test_measure_special_inputs_map_by_name_local_runner() - test_measure_special_inputs_map_by_name_rpc_runner() + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index 92c1174e728c..e2ce442e0d88 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -18,7 +18,6 @@ import datetime import json import os -import sys import tarfile import numpy @@ -411,4 +410,7 @@ def test_export_byoc_c_module(): if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) + import sys + + # sys.exit(pytest.main([__file__] + sys.argv[1:])) + test_export_operator_model_library_format() diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py index b67142b42358..4e777435429b 100644 --- a/tests/python/unittest/test_runtime_profiling.py +++ b/tests/python/unittest/test_runtime_profiling.py @@ -196,4 +196,7 @@ def test_report_serialization(): if __name__ == "__main__": - test_papi("llvm", tvm.cpu()) + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_runtime_vm_profiler.py b/tests/python/unittest/test_runtime_vm_profiler.py index 75b61d281840..0499a3e6c65a 100644 --- a/tests/python/unittest/test_runtime_vm_profiler.py +++ b/tests/python/unittest/test_runtime_vm_profiler.py @@ -39,4 +39,7 @@ def test_basic(dev, target): if __name__ == "__main__": - test_basic(tvm.cpu(), tvm.target.Target("llvm")) + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:]))