diff --git a/include/tvm/target/target_device.h b/include/tvm/target/target_device.h
new file mode 100644
index 000000000000..13e2aba5db31
--- /dev/null
+++ b/include/tvm/target/target_device.h
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/target/target_device.h
+ * \brief A compile-time representation of a target device.
+ *
+ * This data structure pairs the compiler target with a virtual device: a
+ * tvm::Device where the identifier is virtual but the device type is concrete.
+ *
+ * Executors are responsible for mapping virtual device identifiers to
+ * physical device identifiers.
+ *
+ * The reason to introduce this data structure is that much of compilation
+ * requires knowledge of both the target we plan to compile the code for and
+ * the concrete device which is used to initiate copies and other device API
+ * actions.
+ *
+ * The idea is that we will carry TargetDevice structures around until device
+ * and target planning, at which time we can inject explicit virtual devices
+ * into the program and annotate explicit targets on the code to be generated.
+ *
+ * This will enable us to mix and match multiple devices of the same type with
+ * different targets or compilation options, and eventually resolve to a
+ * physical set of devices with code specialized using the correct target.
+ *
+ * For example, consider mobile SoCs which may contain two CPU types, a mobile
+ * GPU, and an NPU accelerator. In these cases it is important to correctly
+ * partition the code for each device type and apply different compilation
+ * strategies.
+ *
+ * Today the compiler maps each device "type" to a single target, which does
+ * not work when multiple kinds of CPUs, GPUs, or accelerators are attached.
+ */
+#ifndef TVM_TARGET_TARGET_DEVICE_H_
+#define TVM_TARGET_TARGET_DEVICE_H_
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+namespace tvm {
+
+class TargetDevice;
+
+/*!
+ * \brief A representation of both the compile-time and runtime data needed to describe a device.
+ * \sa TargetDevice
+ */
+class TargetDeviceNode : public Object {
+ public:
+  /*! \brief The compilation target to use for the device. */
+  Target target;
+
+  /*! \brief The virtual device identifier, which must be resolved to a physical device identifier before execution. */
+  int virtual_device_id;
+
+  /*! \brief The device type. */
+  DLDeviceType device_type;
+
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("target", &target);
+    v->Visit("virtual_device_id", &virtual_device_id);
+    DLDeviceType* ptr = &device_type;
+    v->Visit("device_type", reinterpret_cast<int*>(ptr));
+  }
+
+  static constexpr const char* _type_key = "TargetDevice";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TargetDeviceNode, Object);
+};
+
+/*!
+ * \brief Managed reference class to TargetDeviceNode.
+ * \sa TargetDeviceNode
+ *
+ * This data structure pairs the compiler target with a virtual device: a
+ * tvm::Device where the identifier is virtual but the device type is concrete.
+ */
+class TargetDevice : public ObjectRef {
+ public:
+  /*!
+   * \brief Construct a TargetDevice.
+   * \param target The target to compile for.
+   * \param virtual_device The virtual device to execute on.
+   * \return The TargetDevice.
+   */
+  TVM_DLL explicit TargetDevice(Target target, Device virtual_device);
+
+  TVM_DLL operator Device();
+
+  TVM_DEFINE_OBJECT_REF_METHODS(TargetDevice, ObjectRef, TargetDeviceNode);
+};
+
+}  // namespace tvm
+#endif  // TVM_TARGET_TARGET_DEVICE_H_
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 6294e7acea15..418d9251ffe9 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -728,13 +728,18 @@ def PartitionGraph(mod_name="default"):
 
 
 def AnnotateTarget(targets, include_non_call_ops=True):
-    """Annotate ops in an experession with a provied compiler/target and then
-    use it for codegen.
+    """
+    Annotate the operations in an expression with the provided compiler target,
+    allowing the annotated expressions to be lowered for that target.
+
+    For example, if you annotate with `tensorrt`, the offloaded operators will be
+    lowered and executed by TensorRT instead of the standard lowering and runtime.
 
     Parameters
     ----------
     targets : str or List[str]
         The list of target compilers used for codegen.
+
     include_non_call_ops : boolean
         If True then non-call ops also will be annotated with targets
         If False then non-call ops will not be processed
diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index 7056dfe79fee..d36f4a993e36 100644
--- a/src/relay/transforms/fold_scale_axis.cc
+++ b/src/relay/transforms/fold_scale_axis.cc
@@ -204,6 +204,10 @@ using FForwardRewrite = TypedPackedFunc<Expr(
   std::vector<Message> Prepare(const Expr& body) {
     this->Update(body, NullValue<Message>());
     this->VisitExpr(body);
@@ -243,7 +247,7 @@ class ForwardPrep : private MixedModeVisitor {
   }
   }
   // Visitor pattern override.
-  void VisitExpr_(const LetNode* op) {
+  void VisitExpr_(const LetNode* op) override {
     ExprVisitor::VisitExpr_(op);
     // do pass through condition
     // by assigning NullValue<Message>
@@ -256,13 +260,13 @@ class ForwardPrep : private MixedModeVisitor {
     flist_.push_back(flazy);
   }
 
-  void VisitExpr_(const FunctionNode* op) {
+  void VisitExpr_(const FunctionNode* op) override {
     ExprVisitor::VisitExpr_(op);
     auto flazy = [this, op] { this->Update(op->body, NullValue<Message>()); };
     flist_.push_back(flazy);
   }
 
-  void VisitExpr_(const CallNode* call) {
+  void VisitExpr_(const CallNode* call) override {
     ExprVisitor::VisitExpr_(call);
     // function to be lazily invoked
     auto flazy = [this, call]() {
@@ -292,7 +296,7 @@ class ForwardPrep : private MixedModeVisitor {
     flist_.push_back(flazy);
   }
 
-  void VisitExpr_(const TupleNode* op) {
+  void VisitExpr_(const TupleNode* op) override {
     ExprVisitor::VisitExpr_(op);
     // do not support pass scale through tuple for now.
     auto flazy = [this, op]() {
@@ -303,7 +307,7 @@ class ForwardPrep : private MixedModeVisitor {
     flist_.push_back(flazy);
   }
 
-  void VisitExpr_(const IfNode* op) {
+  void VisitExpr_(const IfNode* op) override {
     ExprVisitor::VisitExpr_(op);
     // do pass through condition
     // by assigning NullValue<Message>
diff --git a/src/target/target.cc b/src/target/target.cc
index e0b9539380d7..37370b041829 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -17,7 +17,7 @@
  * under the License.
  */
 /*!
- * Compile executable modules.
+ * \brief Implementation of methods and FFI interfaces for the compilation target object.
  * \file src/target/target.cc
  */
 #include
diff --git a/src/target/target_device.cc b/src/target/target_device.cc
new file mode 100644
index 000000000000..91f2d8609395
--- /dev/null
+++ b/src/target/target_device.cc
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \brief Implementation of the TargetDevice object, which pairs a compilation target with a virtual device.
+ * \file src/target/target_device.cc
+ */
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+namespace tvm {
+
+TVM_REGISTER_NODE_TYPE(TargetDeviceNode);
+
+TargetDevice::TargetDevice(Target target, Device virtual_device) {
+  auto object = make_object<TargetDeviceNode>();
+  object->target = target;
+  object->virtual_device_id = virtual_device.device_id;
+  object->device_type = virtual_device.device_type;
+  data_ = std::move(object);
+}
+
+TargetDevice::operator Device() {
+  return Device{.device_type = (*this)->device_type,
+                .device_id = (*this)->virtual_device_id};
+}
+
+}  // namespace tvm
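The API surface added above is small: a `TargetDevice` is constructed from a `Target` plus a `tvm::Device` whose `device_id` is treated as virtual, and the conversion operator turns it back into a `Device`. The sketch below is illustrative only and not part of the patch; it assumes the new header is installed as `tvm/target/target_device.h` and uses the existing `Target` string constructor and dlpack's `kDLCUDA` device type, and the `ExampleUsage` function is a hypothetical caller.

```cpp
#include <tvm/target/target.h>
#include <tvm/target/target_device.h>  // header introduced by this patch

using namespace tvm;

void ExampleUsage() {  // hypothetical caller, for illustration only
  // Compile for CUDA, but name the device by a virtual id; an executor is
  // expected to resolve the virtual id to a physical GPU later.
  Target target("cuda");
  Device virtual_dev{kDLCUDA, /*device_id=*/0};

  TargetDevice target_device(target, virtual_dev);

  // The conversion operator recovers a Device that still carries the
  // virtual id, so it can flow through code that expects a tvm::Device.
  Device dev = target_device;
  (void)dev;
}
```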
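The header comment states that executors are responsible for mapping virtual device identifiers to physical ones, but the patch does not prescribe how. Purely as an illustration of that contract, the sketch below uses a table-based lookup; the helper name `ResolveVirtualDevice` and the map argument are assumptions, not part of the patch.

```cpp
#include <unordered_map>

#include <tvm/target/target_device.h>  // header introduced by this patch

using namespace tvm;

// Hypothetical executor-side helper: rewrite the virtual id carried by the
// TargetDevice into a physical id chosen by the executor.
Device ResolveVirtualDevice(TargetDevice target_device,
                            const std::unordered_map<int, int>& virtual_to_physical) {
  Device dev = target_device;  // conversion operator yields the virtual id
  dev.device_id = virtual_to_physical.at(dev.device_id);
  return dev;
}
```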