diff --git a/.gitignore b/.gitignore index 2f124d950b74..068cb87484a0 100644 --- a/.gitignore +++ b/.gitignore @@ -65,7 +65,7 @@ docs/_build/ docs/gen_modules # PyBuilder -target/ +/target/ # IPython Notebook .ipynb_checkpoints diff --git a/CMakeLists.txt b/CMakeLists.txt index 2f6641761323..235d88ece2f0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -128,6 +128,7 @@ assign_source_group("Include" ${GROUP_INCLUDE}) file(GLOB COMPILER_SRCS src/node/*.cc src/ir/*.cc + src/target/*.cc src/api/*.cc src/arithmetic/*.cc src/autotvm/*.cc diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 8b49fb78d6c3..891918864d44 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -24,157 +24,20 @@ #ifndef TVM_BUILD_MODULE_H_ #define TVM_BUILD_MODULE_H_ +#include + #include #include #include #include #include + #include "runtime/packed_func.h" #include "schedule_pass.h" #include "lowered_func.h" namespace tvm { -/*! -* \brief Container for target device information. -* Use target::llvm, target::cuda etc functions instead of constructing directly. -*/ -class TargetNode : public Object { - public: - /*! \brief The name of the target device */ - std::string target_name; - /*! \brief The name of the target device */ - std::string device_name; - /*! \brief The type of the target device */ - int device_type; - /*! \brief The maximum threads that a schedule should use for this device */ - int max_num_threads = 1; - /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */ - int thread_warp_size = 1; - /*! \brief Keys for this target */ - Array keys_array; - /*! \brief Options for this target */ - Array options_array; - /*! \brief Collection of imported libs */ - Array libs_array; - - /*! \return the full device string to pass to codegen::Build */ - TVM_DLL const std::string& str() const; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("target_name", &target_name); - v->Visit("device_name", &device_name); - v->Visit("device_type", &device_type); - v->Visit("max_num_threads", &max_num_threads); - v->Visit("thread_warp_size", &thread_warp_size); - v->Visit("keys_array", &keys_array); - v->Visit("options_array", &options_array); - v->Visit("libs_array", &libs_array); - } - - /*! \brief Get the keys for this target as a vector of string */ - TVM_DLL std::vector keys() const; - - /*! \brief Get the options for this target as a vector of string */ - TVM_DLL std::vector options() const; - - /*! \brief Get the keys for this target as an unordered_set of string */ - TVM_DLL std::unordered_set libs() const; - - static constexpr const char* _type_key = "Target"; - TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object); - - private: - /*! \brief Internal string repr. */ - mutable std::string str_repr_; -}; - -/*! \brief reference cpass to the target. */ -class Target : public ObjectRef { - public: - Target() {} - explicit Target(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief Create a Target given a string - * \param target_str the string to parse - */ - TVM_DLL static Target Create(const std::string& target_str); - /*! - * \brief Get the current target context from thread local storage. - * \param allow_not_defined If the context stack is empty and this is set to true, an - * undefined Target will be returned. Otherwise, an empty context stack will cause a - * runtime error. - * \return The target that is the current context. The target may not be defined if - * allow_not_defined is true. - */ - TVM_DLL static tvm::Target Current(bool allow_not_defined = true); - - const TargetNode* operator->() const { - return static_cast(get()); - } - - using ContainerType = TargetNode; - class Internal; - private: - // enable with syntax. - friend class Internal; - friend class With; - /*! - * \brief Push a new target context onto the thread local stack. - * The Target on top of the stack is used to determine which - * specialization to use when invoking a GenericFunc. - */ - TVM_DLL void EnterWithScope(); - /*! - * \brief Pop a target off the thread local context stack, - * restoring the previous target as the current context. - */ - TVM_DLL void ExitWithScope(); -}; - -/*! \brief This namespace provides functions to construct Target instances */ -namespace target { -/*! \return A target for LLVM */ -TVM_DLL Target llvm(const std::vector& options = - std::vector()); - -/*! \return A target for CUDA */ -TVM_DLL Target cuda(const std::vector& options = - std::vector()); - -/*! \return A target for ROCm */ -TVM_DLL Target rocm(const std::vector& options = - std::vector()); - -/*! \return A target for OpenCL */ -TVM_DLL Target opencl(const std::vector& options = - std::vector()); - -/*! \return A target for Metal */ -TVM_DLL Target metal(const std::vector& options = - std::vector()); - -/*! \return A target for rasp */ -TVM_DLL Target rasp(const std::vector& options = - std::vector()); - -/*! \return A target for Mali */ -TVM_DLL Target mali(const std::vector& options = - std::vector()); - -/*! \return A target for Intel Graphics */ -TVM_DLL Target intel_graphics(const std::vector& options = - std::vector()); - -/*! \return A target for stackvm */ -TVM_DLL Target stackvm(const std::vector& options = - std::vector()); - -/*! \return A target for external device */ -TVM_DLL Target ext_dev(const std::vector& options = - std::vector()); -} // namespace target - /*! * \brief Container for build configuration options */ diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h new file mode 100644 index 000000000000..fd8ab68c6c17 --- /dev/null +++ b/include/tvm/target/target.h @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/target.h + * \brief Compilation target object. + */ +#ifndef TVM_TARGET_TARGET_H_ +#define TVM_TARGET_TARGET_H_ + +#include +#include +#include + +#include +#include +#include + +namespace tvm { +/*! + * \brief Compilation target. + * \note Use target::llvm, target::cuda etc functions. + * \sa Target + */ +class TargetNode : public Object { + public: + /*! \brief The name of the target device */ + std::string target_name; + /*! \brief The name of the target device */ + std::string device_name; + /*! \brief The type of the target device */ + int device_type; + /*! \brief The maximum threads that a schedule should use for this device */ + int max_num_threads = 1; + /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */ + int thread_warp_size = 1; + /*! \brief Keys for this target */ + Array keys_array; + /*! \brief Options for this target */ + Array options_array; + /*! \brief Collection of imported libs */ + Array libs_array; + + /*! \return the full device string to pass to codegen::Build */ + TVM_DLL const std::string& str() const; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("target_name", &target_name); + v->Visit("device_name", &device_name); + v->Visit("device_type", &device_type); + v->Visit("max_num_threads", &max_num_threads); + v->Visit("thread_warp_size", &thread_warp_size); + v->Visit("keys_array", &keys_array); + v->Visit("options_array", &options_array); + v->Visit("libs_array", &libs_array); + } + + /*! \brief Get the keys for this target as a vector of string */ + TVM_DLL std::vector keys() const; + + /*! \brief Get the options for this target as a vector of string */ + TVM_DLL std::vector options() const; + + /*! \brief Get the keys for this target as an unordered_set of string */ + TVM_DLL std::unordered_set libs() const; + + static constexpr const char* _type_key = "Target"; + TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object); + + private: + /*! \brief Internal string repr. */ + mutable std::string str_repr_; +}; + +/*! + * \brief Managed reference class to TargetNode. + * \sa TargetNode + */ +class Target : public ObjectRef { + public: + Target() {} + explicit Target(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief Create a Target given a string + * \param target_str the string to parse + */ + TVM_DLL static Target Create(const std::string& target_str); + /*! + * \brief Get the current target context from thread local storage. + * \param allow_not_defined If the context stack is empty and this is set to true, an + * undefined Target will be returned. Otherwise, an empty context stack will cause a + * runtime error. + * \return The target that is the current context. The target may not be defined if + * allow_not_defined is true. + */ + TVM_DLL static tvm::Target Current(bool allow_not_defined = true); + + const TargetNode* operator->() const { + return static_cast(get()); + } + + using ContainerType = TargetNode; + class Internal; + private: + // enable with syntax. + friend class Internal; + friend class With; + /*! + * \brief Push a new target context onto the thread local stack. + * The Target on top of the stack is used to determine which + * specialization to use when invoking a GenericFunc. + */ + TVM_DLL void EnterWithScope(); + /*! + * \brief Pop a target off the thread local context stack, + * restoring the previous target as the current context. + */ + TVM_DLL void ExitWithScope(); +}; + +/*! \brief This namespace provides functions to construct Target instances */ +namespace target { + +/*! \return A target for LLVM */ +TVM_DLL Target llvm(const std::vector& options = + std::vector()); + +/*! \return A target for CUDA */ +TVM_DLL Target cuda(const std::vector& options = + std::vector()); + +/*! \return A target for ROCm */ +TVM_DLL Target rocm(const std::vector& options = + std::vector()); + +/*! \return A target for OpenCL */ +TVM_DLL Target opencl(const std::vector& options = + std::vector()); + +/*! \return A target for Metal */ +TVM_DLL Target metal(const std::vector& options = + std::vector()); + +/*! \return A target for rasp */ +TVM_DLL Target rasp(const std::vector& options = + std::vector()); + +/*! \return A target for Mali */ +TVM_DLL Target mali(const std::vector& options = + std::vector()); + +/*! \return A target for Intel Graphics */ +TVM_DLL Target intel_graphics(const std::vector& options = + std::vector()); + +/*! \return A target for stackvm */ +TVM_DLL Target stackvm(const std::vector& options = + std::vector()); + +/*! \return A target for external device */ +TVM_DLL Target ext_dev(const std::vector& options = + std::vector()); +} // namespace target +} // namespace tvm +#endif // TVM_TARGET_TARGET_H_ diff --git a/include/tvm/target_info.h b/include/tvm/target/target_info.h similarity index 90% rename from include/tvm/target_info.h rename to include/tvm/target/target_info.h index 0a42a76a1b2e..4466476a18de 100644 --- a/include/tvm/target_info.h +++ b/include/tvm/target/target_info.h @@ -18,14 +18,14 @@ */ /*! - * \file tvm/target_info.h + * \file tvm/target/target_info.h * \brief Various information about target. */ -#ifndef TVM_TARGET_INFO_H_ -#define TVM_TARGET_INFO_H_ +#ifndef TVM_TARGET_TARGET_INFO_H_ +#define TVM_TARGET_TARGET_INFO_H_ +#include #include -#include "expr.h" namespace tvm { @@ -33,7 +33,8 @@ namespace tvm { * \brief Memory information of special memory region. * Use MemoryInfo as its container type */ -struct MemoryInfoNode : public Object { +class MemoryInfoNode : public Object { + public: /*! \brief The addressable unit */ int unit_bits; /*! \brief Maximum number of bits supported in the memory */ @@ -71,4 +72,4 @@ class MemoryInfo : public ObjectRef { TVM_DLL MemoryInfo GetMemoryInfo(const std::string& scope); } // namespace tvm -#endif // TVM_TARGET_INFO_H_ +#endif // TVM_TARGET_TARGET_INFO_H_ diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 9f793424d233..771583b66a74 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -38,288 +38,8 @@ using runtime::TVMArgs; using runtime::TVMRetValue; using runtime::PackedFunc; -TVM_REGISTER_NODE_TYPE(TargetNode); TVM_REGISTER_NODE_TYPE(GenericFuncNode); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->str(); - }); - - -/*! -* \brief Construct a Target node from the given name and options. -* \param target_name The major target name. Should be one of -* {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hybrid", "llvm", "metal", -* "nvptx", "opencl", "opengl", "rocm", "sdaccel", "stackvm", "vulkan"} -* \param options Additional options appended to the target -* \return The constructed Target -*/ -Target CreateTarget(const std::string& target_name, - const std::vector& options) { - auto t = make_object(); - t->target_name = target_name; - - std::string libs_flag = "-libs="; - std::string device_flag = "-device="; - std::string keys_flag = "-keys="; - for (auto& item : options) { - t->options_array.push_back(ir::StringImmNode::make(item)); - - if (item.find(libs_flag) == 0) { - std::stringstream ss(item.substr(libs_flag.length())); - std::string lib_item; - while (std::getline(ss, lib_item, ',')) { - t->libs_array.push_back(ir::StringImmNode::make(lib_item)); - } - } else if (item.find(device_flag) == 0) { - t->device_name = item.substr(device_flag.length()); - t->keys_array.push_back(ir::StringImmNode::make(t->device_name)); - } else if (item.find(keys_flag) == 0) { - std::stringstream ss(item.substr(keys_flag.length())); - std::string key_item; - while (std::getline(ss, key_item, ',')) { - t->keys_array.push_back(ir::StringImmNode::make(key_item)); - } - } - } - - if (t->device_name.length() > 0) { - t->keys_array.push_back(ir::StringImmNode::make(t->device_name)); - } - t->device_type = kDLCPU; - t->thread_warp_size = 1; - if (target_name == "c" && t->device_name == "micro_dev") { - t->device_type = kDLMicroDev; - } else if (target_name == "c" || target_name == "llvm") { - t->keys_array.push_back(ir::StringImmNode::make("cpu")); - } else if (target_name == "cuda" || target_name == "nvptx") { - t->device_type = kDLGPU; - t->keys_array.push_back(ir::StringImmNode::make("cuda")); - t->keys_array.push_back(ir::StringImmNode::make("gpu")); - t->max_num_threads = 1024; - t->thread_warp_size = 32; - } else if (target_name == "rocm" || target_name == "opencl") { - // For now assume rocm schedule for opencl - if (target_name == "opencl") { - t->device_type = kDLOpenCL; - } else { - t->device_type = kDLROCM; - } - t->keys_array.push_back(ir::StringImmNode::make(target_name)); - t->keys_array.push_back(ir::StringImmNode::make("gpu")); - t->max_num_threads = 256; - if (t->device_name == "intel_graphics") { - t->thread_warp_size = 16; - } - } else if (target_name == "metal" || target_name == "vulkan") { - if (target_name == "metal") { - t->device_type = kDLMetal; - } else { - t->device_type = kDLVulkan; - } - t->keys_array.push_back(ir::StringImmNode::make(target_name)); - t->keys_array.push_back(ir::StringImmNode::make("gpu")); - t->max_num_threads = 256; - } else if (target_name == "sdaccel") { - t->device_type = kDLOpenCL; - t->keys_array.push_back(ir::StringImmNode::make("sdaccel")); - t->keys_array.push_back(ir::StringImmNode::make("hls")); - } else if (target_name == "aocl" || target_name == "aocl_sw_emu") { - t->device_type = kDLAOCL; - t->keys_array.push_back(ir::StringImmNode::make("aocl")); - t->keys_array.push_back(ir::StringImmNode::make("hls")); - } else if (target_name == "opengl") { - t->device_type = kOpenGL; - t->keys_array.push_back(ir::StringImmNode::make("opengl")); - } else if (target_name == "stackvm") { - t->device_type = kDLCPU; - } else if (target_name == "ext_dev") { - t->device_type = kDLExtDev; - } else if (target_name == "hybrid") { - t->device_type = kDLCPU; - } else { - LOG(ERROR) << "Unknown target name " << target_name; - return target::stackvm(); - } - - return Target(t); -} - -TVM_REGISTER_GLOBAL("_TargetCreate") -.set_body([](TVMArgs args, TVMRetValue* ret) { - std::string target_name = args[0]; - std::vector options; - for (int i = 1; i < args.num_args; ++i) { - std::string arg = args[i]; - options.push_back(arg); - } - - *ret = CreateTarget(target_name, options); - }); - -TVM_REGISTER_GLOBAL("_TargetFromString") -.set_body([](TVMArgs args, TVMRetValue* ret) { - std::string target_str = args[0]; - *ret = Target::Create(target_str); - }); - -std::vector TargetNode::keys() const { - std::vector result; - for (auto& expr : keys_array) { - result.push_back(expr.as()->value); - } - return result; -} - -std::vector TargetNode::options() const { - std::vector result; - for (auto& expr : options_array) { - result.push_back(expr.as()->value); - } - return result; -} - -std::unordered_set TargetNode::libs() const { - std::unordered_set result; - for (auto& expr : libs_array) { - result.insert(expr.as()->value); - } - return result; -} - -const std::string& TargetNode::str() const { - if (str_repr_.length() != 0) return str_repr_; - std::ostringstream result; - result << target_name; - for (const auto &x : options()) { - result << " " << x; - } - str_repr_ = result.str(); - return str_repr_; -} - - -bool StartsWith(const std::string& str, const std::string& pattern) { - return str.compare(0, pattern.length(), pattern) == 0; -} - -std::string GetDeviceName(const std::string& target_str) { - std::istringstream ss(target_str); - std::string target_name; - ss >> target_name; - - std::string item; - while (ss >> item) { - if (StartsWith(item, "-device=")) { - return item.substr(std::string("-device=").length()); - } - } - - return ""; -} - -Target Target::Create(const std::string& target_str) { - if (target_str.length() == 0) { - LOG(ERROR) << "target_str must not be empty"; - } - - std::istringstream ss(target_str); - std::string target_name; - - ss >> target_name; - auto device_name = GetDeviceName(target_str); - - std::vector options; - std::string item; - while (ss >> item) { - options.push_back(item); - } - - return CreateTarget(target_name, options); -} - -/*! \brief Entry to hold the Target context stack. */ -struct TVMTargetThreadLocalEntry { - /*! \brief The current target context */ - std::stack context_stack; -}; - -/*! \brief Thread local store to hold the Target context stack. */ -typedef dmlc::ThreadLocalStore TVMTargetThreadLocalStore; - -void Target::EnterWithScope() { - TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); - entry->context_stack.push(*this); -} - -void Target::ExitWithScope() { - TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); - CHECK(!entry->context_stack.empty()); - CHECK(entry->context_stack.top().same_as(*this)); - entry->context_stack.pop(); -} - -tvm::Target Target::Current(bool allow_not_defined) { - TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); - if (entry->context_stack.size() > 0) { - return entry->context_stack.top(); - } - CHECK(allow_not_defined) - << "Target context required. Please set it by constructing a TargetContext"; - - return Target(); -} - -namespace target { -std::vector MergeOptions(std::vector opts, - const std::vector& new_opts) { - opts.insert(opts.end(), new_opts.begin(), new_opts.end()); - return opts; -} - -Target llvm(const std::vector& options) { - return CreateTarget("llvm", options); -} - -Target cuda(const std::vector& options) { - return CreateTarget("cuda", options); -} - -Target rocm(const std::vector& options) { - return CreateTarget("rocm", options); -} - -Target opencl(const std::vector& options) { - return CreateTarget("opencl", options); -} - -Target metal(const std::vector& options) { - return CreateTarget("metal", options); -} - -Target mali(const std::vector& options) { - return CreateTarget("opencl", MergeOptions(options, { - "-device=mali" - })); -} - -Target intel_graphics(const std::vector& options) { - return CreateTarget("opencl", MergeOptions(options, { - "-device=intel_graphics" - })); -} - -Target stackvm(const std::vector& options) { - return CreateTarget("stackvm", options); -} - -Target ext_dev(const std::vector& options) { - return CreateTarget("ext_dev", options); -} -} // namespace target - bool LLVMEnabled() { const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.build_llvm"); return pf != nullptr; diff --git a/src/pass/storage_access.cc b/src/pass/storage_access.cc index d98299f24160..da153fc12804 100644 --- a/src/pass/storage_access.cc +++ b/src/pass/storage_access.cc @@ -21,7 +21,7 @@ * \file storage_access.cc */ #include -#include +#include #include #include #include "ir_util.h" diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index 08c61aafbc0c..a6d83a825458 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -30,7 +30,7 @@ #include #include #include -#include +#include #include #include #include "ir_util.h" diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 7a4b13cb2cf5..49084209c07a 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/pass/tensor_core.cc b/src/pass/tensor_core.cc index 956f27c9319d..002e42297faa 100644 --- a/src/pass/tensor_core.cc +++ b/src/pass/tensor_core.cc @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/target/target.cc b/src/target/target.cc new file mode 100644 index 000000000000..014d3f9ff09a --- /dev/null +++ b/src/target/target.cc @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Compile executable modules. + * \file src/target/target.cc + */ +#include + +#include +#include +#include + +#include + +#include +#include + +namespace tvm { + +using runtime::TVMArgs; +using runtime::TVMRetValue; +using runtime::PackedFunc; + +TVM_REGISTER_NODE_TYPE(TargetNode); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->str(); + }); + +/*! +* \brief Construct a Target node from the given name and options. +* \param target_name The major target name. Should be one of +* {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hybrid", "llvm", "metal", +* "nvptx", "opencl", "opengl", "rocm", "sdaccel", "stackvm", "vulkan"} +* \param options Additional options appended to the target +* \return The constructed Target +*/ +Target CreateTarget(const std::string& target_name, + const std::vector& options) { + auto t = make_object(); + t->target_name = target_name; + + std::string libs_flag = "-libs="; + std::string device_flag = "-device="; + std::string keys_flag = "-keys="; + for (auto& item : options) { + t->options_array.push_back(ir::StringImmNode::make(item)); + + if (item.find(libs_flag) == 0) { + std::stringstream ss(item.substr(libs_flag.length())); + std::string lib_item; + while (std::getline(ss, lib_item, ',')) { + t->libs_array.push_back(ir::StringImmNode::make(lib_item)); + } + } else if (item.find(device_flag) == 0) { + t->device_name = item.substr(device_flag.length()); + t->keys_array.push_back(ir::StringImmNode::make(t->device_name)); + } else if (item.find(keys_flag) == 0) { + std::stringstream ss(item.substr(keys_flag.length())); + std::string key_item; + while (std::getline(ss, key_item, ',')) { + t->keys_array.push_back(ir::StringImmNode::make(key_item)); + } + } + } + + if (t->device_name.length() > 0) { + t->keys_array.push_back(ir::StringImmNode::make(t->device_name)); + } + t->device_type = kDLCPU; + t->thread_warp_size = 1; + if (target_name == "c" && t->device_name == "micro_dev") { + t->device_type = kDLMicroDev; + } else if (target_name == "c" || target_name == "llvm") { + t->keys_array.push_back(ir::StringImmNode::make("cpu")); + } else if (target_name == "cuda" || target_name == "nvptx") { + t->device_type = kDLGPU; + t->keys_array.push_back(ir::StringImmNode::make("cuda")); + t->keys_array.push_back(ir::StringImmNode::make("gpu")); + t->max_num_threads = 1024; + t->thread_warp_size = 32; + } else if (target_name == "rocm" || target_name == "opencl") { + // For now assume rocm schedule for opencl + if (target_name == "opencl") { + t->device_type = kDLOpenCL; + } else { + t->device_type = kDLROCM; + } + t->keys_array.push_back(ir::StringImmNode::make(target_name)); + t->keys_array.push_back(ir::StringImmNode::make("gpu")); + t->max_num_threads = 256; + if (t->device_name == "intel_graphics") { + t->thread_warp_size = 16; + } + } else if (target_name == "metal" || target_name == "vulkan") { + if (target_name == "metal") { + t->device_type = kDLMetal; + } else { + t->device_type = kDLVulkan; + } + t->keys_array.push_back(ir::StringImmNode::make(target_name)); + t->keys_array.push_back(ir::StringImmNode::make("gpu")); + t->max_num_threads = 256; + } else if (target_name == "sdaccel") { + t->device_type = kDLOpenCL; + t->keys_array.push_back(ir::StringImmNode::make("sdaccel")); + t->keys_array.push_back(ir::StringImmNode::make("hls")); + } else if (target_name == "aocl" || target_name == "aocl_sw_emu") { + t->device_type = kDLAOCL; + t->keys_array.push_back(ir::StringImmNode::make("aocl")); + t->keys_array.push_back(ir::StringImmNode::make("hls")); + } else if (target_name == "opengl") { + t->device_type = kOpenGL; + t->keys_array.push_back(ir::StringImmNode::make("opengl")); + } else if (target_name == "stackvm") { + t->device_type = kDLCPU; + } else if (target_name == "ext_dev") { + t->device_type = kDLExtDev; + } else if (target_name == "hybrid") { + t->device_type = kDLCPU; + } else { + LOG(ERROR) << "Unknown target name " << target_name; + return target::stackvm(); + } + + return Target(t); +} + +TVM_REGISTER_GLOBAL("_TargetCreate") +.set_body([](TVMArgs args, TVMRetValue* ret) { + std::string target_name = args[0]; + std::vector options; + for (int i = 1; i < args.num_args; ++i) { + std::string arg = args[i]; + options.push_back(arg); + } + + *ret = CreateTarget(target_name, options); + }); + +TVM_REGISTER_GLOBAL("_TargetFromString") +.set_body([](TVMArgs args, TVMRetValue* ret) { + std::string target_str = args[0]; + *ret = Target::Create(target_str); + }); + +std::vector TargetNode::keys() const { + std::vector result; + for (auto& expr : keys_array) { + result.push_back(expr.as()->value); + } + return result; +} + +std::vector TargetNode::options() const { + std::vector result; + for (auto& expr : options_array) { + result.push_back(expr.as()->value); + } + return result; +} + +std::unordered_set TargetNode::libs() const { + std::unordered_set result; + for (auto& expr : libs_array) { + result.insert(expr.as()->value); + } + return result; +} + +const std::string& TargetNode::str() const { + if (str_repr_.length() != 0) return str_repr_; + std::ostringstream result; + result << target_name; + for (const auto &x : options()) { + result << " " << x; + } + str_repr_ = result.str(); + return str_repr_; +} + + +bool StartsWith(const std::string& str, const std::string& pattern) { + return str.compare(0, pattern.length(), pattern) == 0; +} + +std::string GetDeviceName(const std::string& target_str) { + std::istringstream ss(target_str); + std::string target_name; + ss >> target_name; + + std::string item; + while (ss >> item) { + if (StartsWith(item, "-device=")) { + return item.substr(std::string("-device=").length()); + } + } + + return ""; +} + +Target Target::Create(const std::string& target_str) { + if (target_str.length() == 0) { + LOG(ERROR) << "target_str must not be empty"; + } + + std::istringstream ss(target_str); + std::string target_name; + + ss >> target_name; + auto device_name = GetDeviceName(target_str); + + std::vector options; + std::string item; + while (ss >> item) { + options.push_back(item); + } + + return CreateTarget(target_name, options); +} + +/*! \brief Entry to hold the Target context stack. */ +struct TVMTargetThreadLocalEntry { + /*! \brief The current target context */ + std::stack context_stack; +}; + +/*! \brief Thread local store to hold the Target context stack. */ +typedef dmlc::ThreadLocalStore TVMTargetThreadLocalStore; + +void Target::EnterWithScope() { + TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); + entry->context_stack.push(*this); +} + +void Target::ExitWithScope() { + TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); + CHECK(!entry->context_stack.empty()); + CHECK(entry->context_stack.top().same_as(*this)); + entry->context_stack.pop(); +} + +tvm::Target Target::Current(bool allow_not_defined) { + TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); + if (entry->context_stack.size() > 0) { + return entry->context_stack.top(); + } + CHECK(allow_not_defined) + << "Target context required. Please set it by constructing a TargetContext"; + + return Target(); +} + +namespace target { +std::vector MergeOptions(std::vector opts, + const std::vector& new_opts) { + opts.insert(opts.end(), new_opts.begin(), new_opts.end()); + return opts; +} + +Target llvm(const std::vector& options) { + return CreateTarget("llvm", options); +} + +Target cuda(const std::vector& options) { + return CreateTarget("cuda", options); +} + +Target rocm(const std::vector& options) { + return CreateTarget("rocm", options); +} + +Target opencl(const std::vector& options) { + return CreateTarget("opencl", options); +} + +Target metal(const std::vector& options) { + return CreateTarget("metal", options); +} + +Target mali(const std::vector& options) { + return CreateTarget("opencl", MergeOptions(options, { + "-device=mali" + })); +} + +Target intel_graphics(const std::vector& options) { + return CreateTarget("opencl", MergeOptions(options, { + "-device=intel_graphics" + })); +} + +Target stackvm(const std::vector& options) { + return CreateTarget("stackvm", options); +} + +Target ext_dev(const std::vector& options) { + return CreateTarget("ext_dev", options); +} +} // namespace target +} // namespace tvm diff --git a/src/lang/target_info.cc b/src/target/target_info.cc similarity index 75% rename from src/lang/target_info.cc rename to src/target/target_info.cc index 6bdcf8800967..6c332e77b9ba 100644 --- a/src/lang/target_info.cc +++ b/src/target/target_info.cc @@ -18,22 +18,22 @@ */ /*! - * \file target_info.cc + * \file target/target_info.cc */ #include -#include -#include +#include +#include namespace tvm { TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "mem-info(" - << "unit_bits=" << op->unit_bits << ", " - << "max_num_bits=" << op->max_num_bits << ", " - << "max_simd_bits=" << op->max_simd_bits << ", " - << "head_address=" << op->head_address << ")"; + auto* op = static_cast(node.get()); + p->stream << "mem-info(" + << "unit_bits=" << op->unit_bits << ", " + << "max_num_bits=" << op->max_num_bits << ", " + << "max_simd_bits=" << op->max_simd_bits << ", " + << "head_address=" << op->head_address << ")"; }); TVM_REGISTER_NODE_TYPE(MemoryInfoNode);