diff --git a/include/tvm/relay/executor.h b/include/tvm/relay/executor.h new file mode 100644 index 000000000000..4f779e1dc0a4 --- /dev/null +++ b/include/tvm/relay/executor.h @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/executor.h + * \brief Object representation of Executor configuration and registry + */ +#ifndef TVM_RELAY_EXECUTOR_H_ +#define TVM_RELAY_EXECUTOR_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { + +template +class AttrRegistry; + +namespace relay { + +/*! + * \brief Executor information. + * + * This data structure stores the meta-data + * about executors which can be used to pass around information. + * + * \sa Executor + */ +class ExecutorNode : public Object { + public: + /*! \brief name of the Executor */ + String name; + /* \brief Additional attributes storing meta-data about the Executor. */ + DictAttrs attrs; + + /*! + * \brief Get an attribute. + * + * \param attr_key The attribute key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TObjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef + * + * \code + * + * void GetAttrExample(const Executor& executor) { + * auto value = executor->GetAttr("AttrKey", 0); + * } + * + * \endcode + */ + template + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { + return attrs.GetAttr(attr_key, default_value); + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("attrs", &attrs); + } + + bool SEqualReduce(const ExecutorNode* other, SEqualReducer equal) const { + return name == other->name && equal.DefEqual(attrs, other->attrs); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(attrs); + } + + static constexpr const char* _type_key = "Executor"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExecutorNode, Object); +}; + +/*! + * \brief Managed reference class to ExecutorNode. + * \sa ExecutorNode + */ +class Executor : public ObjectRef { + public: + /*! + * \brief Create a new Executor object using the registry + * \throws Error if name is not registered + * \param name The name of the executor. + * \param attrs Attributes for the executor. + * \return the new Executor object. + */ + TVM_DLL static Executor Create(String name, Map attrs); + + /*! + * \brief List all registered Executors + * \return the list of Executors + */ + TVM_DLL static Array ListExecutors(); + + /*! + * \brief List all options for a specific Executor + * \param name The name of the Executor + * \return Map of option name to type + */ + TVM_DLL static Map ListExecutorOptions(const String& name); + + /*! \brief specify container node */ + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Executor, ObjectRef, ExecutorNode); + + private: + /*! + * \brief Private Constructor + * \param name The executor name + * \param attrs Attributes to apply to this Executor node + */ + TVM_DLL Executor(String name, DictAttrs attrs) { + auto n = make_object(); + n->name = std::move(name); + n->attrs = std::move(attrs); + data_ = std::move(n); + } +}; + +/*! + * \brief Helper structure to register Executors + * \sa TVM_REGISTER_EXECUTOR + */ +class ExecutorRegEntry { + public: + /*! \brief Set name of the Executor to be the same as registry if it is empty */ + inline ExecutorRegEntry& set_name(); + + /*! + * \brief Register a valid configuration option and its ValueType for validation + * \param key The configuration key + * \tparam ValueType The value type to be registered + */ + template + inline ExecutorRegEntry& add_attr_option(const String& key); + + /*! + * \brief Register a valid configuration option and its ValueType for validation + * \param key The configuration key + * \param default_value The default value of the key + * \tparam ValueType The value type to be registered + */ + template + inline ExecutorRegEntry& add_attr_option(const String& key, ObjectRef default_value); + + /*! + * \brief Register or get a new entry. + * \param name The name of the operator. + * \return the corresponding entry. + */ + TVM_DLL static ExecutorRegEntry& RegisterOrGet(const String& name); + + private: + /*! \brief Internal storage of value types */ + struct ValueTypeInfo { + std::string type_key; + uint32_t type_index; + }; + std::unordered_map key2vtype_; + /*! \brief A hash table that stores the default value of each attr */ + std::unordered_map key2default_; + + /*! \brief Index used for internal lookup of attribute registry */ + uint32_t index_; + + // the name + std::string name; + + /*! \brief Return the index stored in attr registry */ + uint32_t AttrRegistryIndex() const { return index_; } + /*! \brief Return the name stored in attr registry */ + String AttrRegistryName() const { return name; } + + /*! \brief private constructor */ + explicit ExecutorRegEntry(uint32_t reg_index) : index_(reg_index) {} + + // friend class + template + friend class AttrRegistryMapContainerMap; + template + friend class tvm::AttrRegistry; + friend class Executor; +}; + +inline ExecutorRegEntry& ExecutorRegEntry::set_name() { + if (name.empty()) { + name = name; + } + return *this; +} + +template +inline ExecutorRegEntry& ExecutorRegEntry::add_attr_option(const String& key) { + ICHECK(!key2vtype_.count(key)) << "AttributeError: add_attr_option failed because '" << key + << "' has been set once"; + + using ValueNodeType = typename ValueType::ContainerType; + // NOTE: we could further update the function later. + uint32_t value_type_index = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); + + ValueTypeInfo info; + info.type_index = value_type_index; + info.type_key = runtime::Object::TypeIndex2Key(value_type_index); + key2vtype_[key] = info; + return *this; +} + +template +inline ExecutorRegEntry& ExecutorRegEntry::add_attr_option(const String& key, + ObjectRef default_value) { + add_attr_option(key); + key2default_[key] = default_value; + return *this; +} + +// internal macros to make executor entries +#define TVM_EXECUTOR_REGISTER_VAR_DEF \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::ExecutorRegEntry& __make_##Executor + +/*! + * \def TVM_REGISTER_EXECUTOR + * \brief Register a new executor, or set attribute of the corresponding executor. + * + * \param ExecutorName The name of registry + * + * \code + * + * TVM_REGISTER_EXECUTOR("aot") + * .add_attr_option("my_option"); + * .add_attr_option("my_option_default", String("default")); + * + * \endcode + */ +#define TVM_REGISTER_EXECUTOR(ExecutorName) \ + TVM_STR_CONCAT(TVM_EXECUTOR_REGISTER_VAR_DEF, __COUNTER__) = \ + ::tvm::relay::ExecutorRegEntry::RegisterOrGet(ExecutorName).set_name() +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_EXECUTOR_H_ diff --git a/include/tvm/relay/runtime.h b/include/tvm/relay/runtime.h new file mode 100644 index 000000000000..cc2ea4193ff2 --- /dev/null +++ b/include/tvm/relay/runtime.h @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/runtime.h + * \brief Object representation of Runtime configuration and registry + */ +#ifndef TVM_RELAY_RUNTIME_H_ +#define TVM_RELAY_RUNTIME_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { + +template +class AttrRegistry; + +namespace relay { + +/*! + * \brief Runtime information. + * + * This data structure stores the meta-data + * about Runtimes which can be used to pass around information. + * + * \sa Runtime + */ +class RuntimeNode : public Object { + public: + /*! \brief name of the Runtime */ + String name; + /* \brief Additional attributes storing meta-data about the Runtime. */ + DictAttrs attrs; + + /*! + * \brief Get an attribute. + * + * \param attr_key The attribute key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TObjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef + * + * \code + * + * void GetAttrExample(const Runtime& runtime) { + * auto value = runtime->GetAttr("AttrKey", 0); + * } + * + * \endcode + */ + template + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { + return attrs.GetAttr(attr_key, default_value); + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("attrs", &attrs); + } + + bool SEqualReduce(const RuntimeNode* other, SEqualReducer equal) const { + return name == other->name && equal.DefEqual(attrs, other->attrs); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(attrs); + } + + static constexpr const char* _type_key = "Runtime"; + TVM_DECLARE_FINAL_OBJECT_INFO(RuntimeNode, Object); +}; + +/*! + * \brief Managed reference class to RuntimeNode. + * \sa RuntimeNode + */ +class Runtime : public ObjectRef { + public: + /*! + * \brief Create a new Runtime object using the registry + * \throws Error if name is not registered + * \param name The name of the Runtime. + * \param attrs Attributes for the Runtime. + * \return the new Runtime object. + */ + TVM_DLL static Runtime Create(String name, Map attrs); + + /*! + * \brief List all registered Runtimes + * \return the list of Runtimes + */ + TVM_DLL static Array ListRuntimes(); + + /*! + * \brief List all options for a specific Runtime + * \param name The name of the Runtime + * \return Map of option name to type + */ + TVM_DLL static Map ListRuntimeOptions(const String& name); + + /*! \brief specify container node */ + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Runtime, ObjectRef, RuntimeNode); + + private: + /*! + * \brief Private Constructor + * \param name The Runtime name + * \param attrs Attributes to apply to this Runtime node + */ + TVM_DLL Runtime(String name, DictAttrs attrs) { + auto n = make_object(); + n->name = std::move(name); + n->attrs = std::move(attrs); + data_ = std::move(n); + } +}; + +/*! + * \brief Helper structure to register Runtimes + * \sa TVM_REGISTER_Runtime + */ +class RuntimeRegEntry { + public: + /*! \brief Set name of the Runtime to be the same as registry if it is empty */ + inline RuntimeRegEntry& set_name(); + + /*! + * \brief Register a valid configuration option and its ValueType for validation + * \param key The configuration key + * \tparam ValueType The value type to be registered + */ + template + inline RuntimeRegEntry& add_attr_option(const String& key); + + /*! + * \brief Register a valid configuration option and its ValueType for validation + * \param key The configuration key + * \param default_value The default value of the key + * \tparam ValueType The value type to be registered + */ + template + inline RuntimeRegEntry& add_attr_option(const String& key, ObjectRef default_value); + + /*! + * \brief Register or get a new entry. + * \param name The name of the operator. + * \return the corresponding entry. + */ + TVM_DLL static RuntimeRegEntry& RegisterOrGet(const String& name); + + private: + /*! \brief Internal storage of value types */ + struct ValueTypeInfo { + std::string type_key; + uint32_t type_index; + }; + std::unordered_map key2vtype_; + /*! \brief A hash table that stores the default value of each attr */ + std::unordered_map key2default_; + + /*! \brief Index used for internal lookup of attribute registry */ + uint32_t index_; + + // the name + std::string name; + + /*! \brief Return the index stored in attr registry */ + uint32_t AttrRegistryIndex() const { return index_; } + /*! \brief Return the name stored in attr registry */ + String AttrRegistryName() const { return name; } + + /*! \brief private constructor */ + explicit RuntimeRegEntry(uint32_t reg_index) : index_(reg_index) {} + + // friend class + template + friend class AttrRegistryMapContainerMap; + template + friend class tvm::AttrRegistry; + friend class Runtime; +}; + +inline RuntimeRegEntry& RuntimeRegEntry::set_name() { + if (name.empty()) { + name = name; + } + return *this; +} + +template +inline RuntimeRegEntry& RuntimeRegEntry::add_attr_option(const String& key) { + ICHECK(!key2vtype_.count(key)) << "AttributeError: add_attr_option failed because '" << key + << "' has been set once"; + + using ValueNodeType = typename ValueType::ContainerType; + // NOTE: we could further update the function later. + uint32_t value_type_index = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); + + ValueTypeInfo info; + info.type_index = value_type_index; + info.type_key = runtime::Object::TypeIndex2Key(value_type_index); + key2vtype_[key] = info; + return *this; +} + +template +inline RuntimeRegEntry& RuntimeRegEntry::add_attr_option(const String& key, + ObjectRef default_value) { + add_attr_option(key); + key2default_[key] = default_value; + return *this; +} + +// internal macros to make Runtime entries +#define TVM_RUNTIME_REGISTER_VAR_DEF \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::RuntimeRegEntry& __make_##Runtime + +/*! + * \def TVM_REGISTER_RUNTIME + * \brief Register a new Runtime, or set attribute of the corresponding Runtime. + * + * \param RuntimeName The name of registry + * + * \code + * + * TVM_REGISTER_RUNTIME("c") + * .add_attr_option("my_option"); + * .add_attr_option("my_option_default", String("default")); + * + * \endcode + */ +#define TVM_REGISTER_RUNTIME(RuntimeName) \ + TVM_STR_CONCAT(TVM_RUNTIME_REGISTER_VAR_DEF, __COUNTER__) = \ + ::tvm::relay::RuntimeRegEntry::RegisterOrGet(RuntimeName).set_name() +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_RUNTIME_H_ diff --git a/python/tvm/relay/backend/__init__.py b/python/tvm/relay/backend/__init__.py index d76459236515..b6a402b0f30f 100644 --- a/python/tvm/relay/backend/__init__.py +++ b/python/tvm/relay/backend/__init__.py @@ -16,3 +16,5 @@ # under the License. """Backend codegen modules for relay.""" from . import te_compiler +from .executor import Executor +from .runtime import Runtime diff --git a/python/tvm/relay/backend/executor.py b/python/tvm/relay/backend/executor.py new file mode 100644 index 000000000000..b3af565fe69e --- /dev/null +++ b/python/tvm/relay/backend/executor.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=len-as-condition,no-else-return,invalid-name +"""Executor configuration""" + +import tvm +from tvm.runtime import Object + +from . import _backend + + +@tvm._ffi.register_object +class Executor(Object): + """Executor configuration""" + + def __init__(self, name, options=None) -> None: + if options is None: + options = {} + self.__init_handle_by_constructor__(_backend.CreateExecutor, name, options) + self._attrs = _backend.GetExecutorAttrs(self) + + def __contains__(self, name): + return name in self._attrs + + def __getitem__(self, name): + return self._attrs[name] + + @staticmethod + def list_executors(): + """Returns a list of possible executors""" + return list(_backend.ListExecutors()) + + @staticmethod + def list_executor_options(executor): + """Returns the dict of available option names and types""" + return dict(_backend.ListExecutorOptions(str(executor))) diff --git a/python/tvm/relay/backend/runtime.py b/python/tvm/relay/backend/runtime.py new file mode 100644 index 000000000000..81779a245dde --- /dev/null +++ b/python/tvm/relay/backend/runtime.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=len-as-condition,no-else-return,invalid-name +"""Runtime configuration""" + +import tvm +from tvm.runtime import Object + +from . import _backend + + +@tvm._ffi.register_object +class Runtime(Object): + """Runtime configuration""" + + def __init__(self, name, options=None) -> None: + if options is None: + options = {} + self.__init_handle_by_constructor__(_backend.CreateRuntime, name, options) + self._attrs = _backend.GetRuntimeAttrs(self) + + def __contains__(self, name): + return name in self._attrs + + def __getitem__(self, name): + return self._attrs[name] + + @staticmethod + def list_runtimes(): + """Returns a list of possible runtimes""" + return list(_backend.ListRuntimes()) + + @staticmethod + def list_runtime_options(runtime): + """Returns the dict of available option names and types""" + return dict(_backend.ListRuntimeOptions(str(runtime))) diff --git a/src/relay/backend/executor.cc b/src/relay/backend/executor.cc new file mode 100644 index 000000000000..3f5c2f4cb00f --- /dev/null +++ b/src/relay/backend/executor.cc @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/executor.cc + * \brief Executor Registry + */ + +#include + +#include "../../node/attr_registry.h" +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(ExecutorNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { + const Executor& executor = Downcast(obj); + p->stream << executor->name; + }); + +/********** Registry-related code **********/ + +using ExecutorRegistry = AttrRegistry; + +Executor Executor::Create(String name, Map attrs) { + const ExecutorRegEntry* reg = ExecutorRegistry::Global()->Get(name); + if (reg == nullptr) { + throw Error("Executor \"" + name + "\" is not defined"); + } + + for (const auto& kv : attrs) { + if (!reg->key2vtype_.count(kv.first)) { + throw Error("Attribute \"" + kv.first + "\" is not available on this Executor"); + } + std::string expected_type = reg->key2vtype_.at(kv.first).type_key; + std::string actual_type = kv.second->GetTypeKey(); + if (expected_type != actual_type) { + throw Error("Attribute \"" + kv.first + "\" should have type \"" + expected_type + + "\" but instead found \"" + actual_type + "\""); + } + } + + for (const auto& kv : reg->key2default_) { + if (!attrs.count(kv.first)) { + attrs.Set(kv.first, kv.second); + } + } + + return Executor(name, DictAttrs(attrs)); +} + +Array Executor::ListExecutors() { return ExecutorRegistry::Global()->ListAllNames(); } + +Map Executor::ListExecutorOptions(const String& name) { + Map options; + const ExecutorRegEntry* reg = ExecutorRegistry::Global()->Get(name); + if (reg == nullptr) { + throw Error("Executor \"" + name + "\" is not defined"); + } + for (const auto& kv : reg->key2vtype_) { + options.Set(kv.first, kv.second.type_key); + } + return options; +} + +ExecutorRegEntry& ExecutorRegEntry::RegisterOrGet(const String& name) { + return ExecutorRegistry::Global()->RegisterOrGet(name); +} + +/********** Register Executors and options **********/ + +TVM_REGISTER_EXECUTOR("aot") + .add_attr_option("unpacked-api") + .add_attr_option("interface-api"); + +TVM_REGISTER_EXECUTOR("graph").add_attr_option("link-params", Bool(false)); + +TVM_REGISTER_EXECUTOR("vm"); + +/********** Registry **********/ + +TVM_REGISTER_GLOBAL("relay.backend.CreateExecutor").set_body_typed(Executor::Create); +TVM_REGISTER_GLOBAL("relay.backend.GetExecutorAttrs").set_body_typed([](const Executor& executor) { + return executor->attrs->dict; +}); + +TVM_REGISTER_GLOBAL("relay.backend.ListExecutors").set_body_typed(Executor::ListExecutors); +TVM_REGISTER_GLOBAL("relay.backend.ListExecutorOptions") + .set_body_typed(Executor::ListExecutorOptions); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/runtime.cc b/src/relay/backend/runtime.cc new file mode 100644 index 000000000000..1c08cbd29d1e --- /dev/null +++ b/src/relay/backend/runtime.cc @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/runtime.cc + * \brief Runtime Registry + */ + +#include + +#include "../../node/attr_registry.h" + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(RuntimeNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { + const Runtime& runtime = Downcast(obj); + p->stream << runtime->name; + }); + +/********** Registry-related code **********/ + +using RuntimeRegistry = AttrRegistry; + +Runtime Runtime::Create(String name, Map attrs) { + const RuntimeRegEntry* reg = RuntimeRegistry::Global()->Get(name); + if (reg == nullptr) { + throw Error("Runtime \"" + name + "\" is not defined"); + } + + for (const auto& kv : attrs) { + if (!reg->key2vtype_.count(kv.first)) { + throw Error("Attribute \"" + kv.first + "\" is not available on this Runtime"); + } + std::string expected_type = reg->key2vtype_.at(kv.first).type_key; + std::string actual_type = kv.second->GetTypeKey(); + if (expected_type != actual_type) { + throw Error("Attribute \"" + kv.first + "\" should have type \"" + expected_type + + "\" but instead found \"" + actual_type + "\""); + } + } + + for (const auto& kv : reg->key2default_) { + if (!attrs.count(kv.first)) { + attrs.Set(kv.first, kv.second); + } + } + + return Runtime(name, DictAttrs(attrs)); +} + +Array Runtime::ListRuntimes() { return RuntimeRegistry::Global()->ListAllNames(); } + +Map Runtime::ListRuntimeOptions(const String& name) { + Map options; + const RuntimeRegEntry* reg = RuntimeRegistry::Global()->Get(name); + if (reg == nullptr) { + throw Error("Runtime \"" + name + "\" is not defined"); + } + for (const auto& kv : reg->key2vtype_) { + options.Set(kv.first, kv.second.type_key); + } + return options; +} + +RuntimeRegEntry& RuntimeRegEntry::RegisterOrGet(const String& name) { + return RuntimeRegistry::Global()->RegisterOrGet(name); +} + +/********** Register Runtimes and options **********/ + +TVM_REGISTER_RUNTIME("c").add_attr_option("system-lib"); + +TVM_REGISTER_RUNTIME("cpp"); + +/********** Registry **********/ + +TVM_REGISTER_GLOBAL("relay.backend.CreateRuntime").set_body_typed(Runtime::Create); +TVM_REGISTER_GLOBAL("relay.backend.GetRuntimeAttrs").set_body_typed([](const Runtime& runtime) { + return runtime->attrs->dict; +}); + +TVM_REGISTER_GLOBAL("relay.backend.ListRuntimes").set_body_typed(Runtime::ListRuntimes); +TVM_REGISTER_GLOBAL("relay.backend.ListRuntimeOptions").set_body_typed(Runtime::ListRuntimeOptions); + +} // namespace relay +} // namespace tvm diff --git a/tests/cpp/relay/backend/executor_test.cc b/tests/cpp/relay/backend/executor_test.cc new file mode 100644 index 000000000000..3367390b27f2 --- /dev/null +++ b/tests/cpp/relay/backend/executor_test.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include +#include +namespace tvm { +namespace relay { + +TVM_REGISTER_EXECUTOR("TestExecutor") + .add_attr_option("my_bool") + .add_attr_option>("your_names") + .add_attr_option("another_option") + .add_attr_option("defaulty_the_default_option", Bool(false)); + +TEST(Executor, Create) { + Map attrs = {{"my_bool", Bool(true)}}; + Executor my_exec = Executor::Create("TestExecutor", attrs); + ASSERT_EQ(my_exec->GetAttr("my_bool"), true); + ASSERT_EQ(my_exec->GetAttr>("your_names").defined(), false); + ASSERT_EQ(my_exec->GetAttr("defaulty_the_default_option"), false); +} + +TEST(Executor, UnknownAttr) { + Map attrs = {{"woofles", Bool(true)}}; + ASSERT_THROW(Executor::Create("TestExecutor", attrs), Error); +} + +TEST(Executor, IncorrectAttrType) { + Map attrs = {{"my_bool", String("snuck_in")}}; + ASSERT_THROW(Executor::Create("TestExecutor", attrs), Error); +} + +TEST(Executor, UnregisteredName) { + Map attrs = {}; + ASSERT_THROW(Executor::Create("NeverNameAnExecutorThis", attrs), Error); +} + +TEST(ExecutorRegistry, ListExecutors) { + Array names = Executor::ListExecutors(); + ICHECK_EQ(names.empty(), false); + ICHECK_EQ(std::count(std::begin(names), std::end(names), "TestExecutor"), 1); +} + +TEST(ExecutorRegistry, ListExecutorOptions) { + Map attrs = Executor::ListExecutorOptions("TestExecutor"); + + ICHECK_EQ(attrs.empty(), false); + ICHECK_EQ(attrs["my_bool"], "IntImm"); + ICHECK_EQ(attrs["your_names"], "Array"); + ICHECK_EQ(attrs["another_option"], "runtime.String"); +} + +TEST(ExecutorRegistry, ListExecutorOptionsNoExecutor) { + ASSERT_THROW(Executor::ListExecutorOptions("NeverNameAnExecutorThis"), Error); +} + +} // namespace relay +} // namespace tvm diff --git a/tests/cpp/relay/backend/runtime_test.cc b/tests/cpp/relay/backend/runtime_test.cc new file mode 100644 index 000000000000..53ea7e39ed59 --- /dev/null +++ b/tests/cpp/relay/backend/runtime_test.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include +#include +namespace tvm { +namespace relay { + +TVM_REGISTER_RUNTIME("TestRuntime") + .add_attr_option("my_bool") + .add_attr_option>("your_names") + .add_attr_option("another_option") + .add_attr_option("defaulty_the_default_option", Bool(false)); + +TEST(Runtime, Create) { + Map attrs = {{"my_bool", Bool(true)}}; + Runtime my_runtime = Runtime::Create("TestRuntime", attrs); + ASSERT_EQ(my_runtime->GetAttr("my_bool"), true); + ASSERT_EQ(my_runtime->GetAttr>("your_names").defined(), false); + ASSERT_EQ(my_runtime->GetAttr("defaulty_the_default_option"), false); +} + +TEST(Runtime, UnknownAttr) { + Map attrs = {{"woofles", Bool(true)}}; + ASSERT_THROW(Runtime::Create("TestRuntime", attrs), Error); +} + +TEST(Runtime, IncorrectAttrType) { + Map attrs = {{"my_bool", String("snuck_in")}}; + ASSERT_THROW(Runtime::Create("TestRuntime", attrs), Error); +} + +TEST(Runtime, UnregisteredName) { + Map attrs = {}; + ASSERT_THROW(Runtime::Create("NeverNameAnRuntimeThis", attrs), Error); +} + +TEST(RuntimeRegistry, ListRuntimes) { + Array names = Runtime::ListRuntimes(); + ICHECK_EQ(names.empty(), false); + ICHECK_EQ(std::count(std::begin(names), std::end(names), "TestRuntime"), 1); +} + +TEST(RuntimeRegistry, ListRuntimeOptions) { + Map attrs = Runtime::ListRuntimeOptions("TestRuntime"); + + ICHECK_EQ(attrs.empty(), false); + ICHECK_EQ(attrs["my_bool"], "IntImm"); + ICHECK_EQ(attrs["your_names"], "Array"); + ICHECK_EQ(attrs["another_option"], "runtime.String"); +} + +TEST(RuntimeRegistry, ListRuntimeOptionsNoRuntime) { + ASSERT_THROW(Runtime::ListRuntimeOptions("NeverNameAnRuntimeThis"), Error); +} + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_executor.py b/tests/python/relay/test_executor.py new file mode 100644 index 000000000000..ebda4ff47cac --- /dev/null +++ b/tests/python/relay/test_executor.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from tvm import TVMError +from tvm.relay.backend import Executor + + +def test_create_executor(): + executor = Executor("aot") + assert str(executor) == "aot" + + +def test_create_executor_with_options(): + executor = Executor("aot", {"interface-api": "c"}) + assert str(executor) == "aot" + assert executor["interface-api"] == "c" + + +def test_create_executor_with_default(): + executor = Executor("graph") + assert not executor["link-params"] + + +def test_attr_check(): + executor = Executor("aot", {"interface-api": "c"}) + assert "woof" not in executor + assert "interface-api" in executor + + +def test_create_executor_not_found(): + with pytest.raises(TVMError, match='Executor "woof" is not defined'): + Executor("woof", {}) + + +def test_create_executor_attr_not_found(): + with pytest.raises(TVMError, match='Attribute "woof" is not available on this Executor'): + Executor("aot", {"woof": "bark"}) + + +def test_create_executor_attr_type_incorrect(): + with pytest.raises( + TVMError, + match='Attribute "interface-api" should have type "runtime.String"' + ' but instead found "IntImm"', + ): + Executor("aot", {"interface-api": True}) + + +def test_list_executors(): + assert "aot" in Executor.list_executors() + + +@pytest.mark.parametrize("executor", [Executor("aot"), "aot"]) +def test_list_executor_options(executor): + aot_options = Executor.list_executor_options(executor) + assert "interface-api" in aot_options + assert aot_options["interface-api"] == "runtime.String" + + +def test_list_executor_options_not_found(): + with pytest.raises(TVMError, match='Executor "woof" is not defined'): + Executor.list_executor_options("woof") diff --git a/tests/python/relay/test_runtime.py b/tests/python/relay/test_runtime.py new file mode 100644 index 000000000000..d78b822411bc --- /dev/null +++ b/tests/python/relay/test_runtime.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from tvm import TVMError +from tvm.relay.backend import Runtime + + +def test_create(): + runtime = Runtime("cpp") + assert str(runtime) == "cpp" + + +def test_create_runtime_with_options(): + runtime = Runtime("c", {"system-lib": True}) + assert str(runtime) == "c" + assert runtime["system-lib"] + + +def test_attr_check(): + runtime = Runtime("c", {"system-lib": True}) + assert "woof" not in runtime + assert "system-lib" in runtime + + +def test_create_runtime_not_found(): + with pytest.raises(TVMError, match='Runtime "woof" is not defined'): + Runtime("woof", {}) + + +def test_create_runtime_attr_not_found(): + with pytest.raises(TVMError, match='Attribute "woof" is not available on this Runtime'): + Runtime("c", {"woof": "bark"}) + + +def test_create_runtime_attr_type_incorrect(): + with pytest.raises( + TVMError, + match='Attribute "system-lib" should have type "IntImm"' + ' but instead found "runtime.String"', + ): + Runtime("c", {"system-lib": "woof"}) + + +def test_list_runtimes(): + assert "c" in Runtime.list_runtimes() + + +@pytest.mark.parametrize("runtime", [Runtime("c"), "c"]) +def test_list_runtime_options(runtime): + aot_options = Runtime.list_runtime_options(runtime) + assert "system-lib" in aot_options + assert aot_options["system-lib"] == "IntImm" + + +def test_list_runtime_options_not_found(): + with pytest.raises(TVMError, match='Runtime "woof" is not defined'): + Runtime.list_runtime_options("woof")