forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Relay] Introduce Executor and Runtime representations with associate…
…d registries (apache#9246) This is the underpinning Executor/Runtime objects for apache/tvm-rfcs#29, this doesn't change the wiring as yet since that's a pretty big change in itself. Most of this patch is TVM boilerplate, which I've tried to minimize - there's maybe some future work to decide how to more easily define some of these things.
- Loading branch information
Showing
11 changed files
with
1,175 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,276 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tvm/relay/executor.h | ||
* \brief Object representation of Executor configuration and registry | ||
*/ | ||
#ifndef TVM_RELAY_EXECUTOR_H_ | ||
#define TVM_RELAY_EXECUTOR_H_ | ||
|
||
#include <dmlc/registry.h> | ||
#include <tvm/ir/attrs.h> | ||
#include <tvm/ir/expr.h> | ||
#include <tvm/ir/type.h> | ||
#include <tvm/ir/type_relation.h> | ||
#include <tvm/node/attr_registry_map.h> | ||
#include <tvm/runtime/registry.h> | ||
|
||
#include <string> | ||
#include <unordered_map> | ||
#include <utility> | ||
#include <vector> | ||
|
||
namespace tvm { | ||
|
||
template <typename, typename> | ||
class AttrRegistry; | ||
|
||
namespace relay { | ||
|
||
/*! | ||
* \brief Executor information. | ||
* | ||
* This data structure stores the meta-data | ||
* about executors which can be used to pass around information. | ||
* | ||
* \sa Executor | ||
*/ | ||
class ExecutorNode : public Object { | ||
public: | ||
/*! \brief name of the Executor */ | ||
String name; | ||
/* \brief Additional attributes storing meta-data about the Executor. */ | ||
DictAttrs attrs; | ||
|
||
/*! | ||
* \brief Get an attribute. | ||
* | ||
* \param attr_key The attribute key. | ||
* \param default_value The default value if the key does not exist, defaults to nullptr. | ||
* | ||
* \return The result | ||
* | ||
* \tparam TObjectRef the expected object type. | ||
* \throw Error if the key exists but the value does not match TObjectRef | ||
* | ||
* \code | ||
* | ||
* void GetAttrExample(const Executor& executor) { | ||
* auto value = executor->GetAttr<Integer>("AttrKey", 0); | ||
* } | ||
* | ||
* \endcode | ||
*/ | ||
template <typename TObjectRef> | ||
Optional<TObjectRef> GetAttr( | ||
const std::string& attr_key, | ||
Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const { | ||
return attrs.GetAttr(attr_key, default_value); | ||
} | ||
// variant that uses TObjectRef to enable implicit conversion to default value. | ||
template <typename TObjectRef> | ||
Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const { | ||
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value)); | ||
} | ||
|
||
void VisitAttrs(AttrVisitor* v) { | ||
v->Visit("name", &name); | ||
v->Visit("attrs", &attrs); | ||
} | ||
|
||
bool SEqualReduce(const ExecutorNode* other, SEqualReducer equal) const { | ||
return name == other->name && equal.DefEqual(attrs, other->attrs); | ||
} | ||
|
||
void SHashReduce(SHashReducer hash_reduce) const { | ||
hash_reduce(name); | ||
hash_reduce(attrs); | ||
} | ||
|
||
static constexpr const char* _type_key = "Executor"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(ExecutorNode, Object); | ||
}; | ||
|
||
/*! | ||
* \brief Managed reference class to ExecutorNode. | ||
* \sa ExecutorNode | ||
*/ | ||
class Executor : public ObjectRef { | ||
public: | ||
/*! | ||
* \brief Create a new Executor object using the registry | ||
* \throws Error if name is not registered | ||
* \param name The name of the executor. | ||
* \param attrs Attributes for the executor. | ||
* \return the new Executor object. | ||
*/ | ||
TVM_DLL static Executor Create(String name, Map<String, ObjectRef> attrs); | ||
|
||
/*! | ||
* \brief List all registered Executors | ||
* \return the list of Executors | ||
*/ | ||
TVM_DLL static Array<String> ListExecutors(); | ||
|
||
/*! | ||
* \brief List all options for a specific Executor | ||
* \param name The name of the Executor | ||
* \return Map of option name to type | ||
*/ | ||
TVM_DLL static Map<String, String> ListExecutorOptions(const String& name); | ||
|
||
/*! \brief specify container node */ | ||
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Executor, ObjectRef, ExecutorNode); | ||
|
||
private: | ||
/*! | ||
* \brief Private Constructor | ||
* \param name The executor name | ||
* \param attrs Attributes to apply to this Executor node | ||
*/ | ||
TVM_DLL Executor(String name, DictAttrs attrs) { | ||
auto n = make_object<ExecutorNode>(); | ||
n->name = std::move(name); | ||
n->attrs = std::move(attrs); | ||
data_ = std::move(n); | ||
} | ||
}; | ||
|
||
/*! | ||
* \brief Helper structure to register Executors | ||
* \sa TVM_REGISTER_EXECUTOR | ||
*/ | ||
class ExecutorRegEntry { | ||
public: | ||
/*! \brief Set name of the Executor to be the same as registry if it is empty */ | ||
inline ExecutorRegEntry& set_name(); | ||
|
||
/*! | ||
* \brief Register a valid configuration option and its ValueType for validation | ||
* \param key The configuration key | ||
* \tparam ValueType The value type to be registered | ||
*/ | ||
template <typename ValueType> | ||
inline ExecutorRegEntry& add_attr_option(const String& key); | ||
|
||
/*! | ||
* \brief Register a valid configuration option and its ValueType for validation | ||
* \param key The configuration key | ||
* \param default_value The default value of the key | ||
* \tparam ValueType The value type to be registered | ||
*/ | ||
template <typename ValueType> | ||
inline ExecutorRegEntry& add_attr_option(const String& key, ObjectRef default_value); | ||
|
||
/*! | ||
* \brief Register or get a new entry. | ||
* \param name The name of the operator. | ||
* \return the corresponding entry. | ||
*/ | ||
TVM_DLL static ExecutorRegEntry& RegisterOrGet(const String& name); | ||
|
||
private: | ||
/*! \brief Internal storage of value types */ | ||
struct ValueTypeInfo { | ||
std::string type_key; | ||
uint32_t type_index; | ||
}; | ||
std::unordered_map<std::string, ValueTypeInfo> key2vtype_; | ||
/*! \brief A hash table that stores the default value of each attr */ | ||
std::unordered_map<String, ObjectRef> key2default_; | ||
|
||
/*! \brief Index used for internal lookup of attribute registry */ | ||
uint32_t index_; | ||
|
||
// the name | ||
std::string name; | ||
|
||
/*! \brief Return the index stored in attr registry */ | ||
uint32_t AttrRegistryIndex() const { return index_; } | ||
/*! \brief Return the name stored in attr registry */ | ||
String AttrRegistryName() const { return name; } | ||
|
||
/*! \brief private constructor */ | ||
explicit ExecutorRegEntry(uint32_t reg_index) : index_(reg_index) {} | ||
|
||
// friend class | ||
template <typename> | ||
friend class AttrRegistryMapContainerMap; | ||
template <typename, typename> | ||
friend class tvm::AttrRegistry; | ||
friend class Executor; | ||
}; | ||
|
||
inline ExecutorRegEntry& ExecutorRegEntry::set_name() { | ||
if (name.empty()) { | ||
name = name; | ||
} | ||
return *this; | ||
} | ||
|
||
template <typename ValueType> | ||
inline ExecutorRegEntry& ExecutorRegEntry::add_attr_option(const String& key) { | ||
ICHECK(!key2vtype_.count(key)) << "AttributeError: add_attr_option failed because '" << key | ||
<< "' has been set once"; | ||
|
||
using ValueNodeType = typename ValueType::ContainerType; | ||
// NOTE: we could further update the function later. | ||
uint32_t value_type_index = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); | ||
|
||
ValueTypeInfo info; | ||
info.type_index = value_type_index; | ||
info.type_key = runtime::Object::TypeIndex2Key(value_type_index); | ||
key2vtype_[key] = info; | ||
return *this; | ||
} | ||
|
||
template <typename ValueType> | ||
inline ExecutorRegEntry& ExecutorRegEntry::add_attr_option(const String& key, | ||
ObjectRef default_value) { | ||
add_attr_option<ValueType>(key); | ||
key2default_[key] = default_value; | ||
return *this; | ||
} | ||
|
||
// internal macros to make executor entries | ||
#define TVM_EXECUTOR_REGISTER_VAR_DEF \ | ||
static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::ExecutorRegEntry& __make_##Executor | ||
|
||
/*! | ||
* \def TVM_REGISTER_EXECUTOR | ||
* \brief Register a new executor, or set attribute of the corresponding executor. | ||
* | ||
* \param ExecutorName The name of registry | ||
* | ||
* \code | ||
* | ||
* TVM_REGISTER_EXECUTOR("aot") | ||
* .add_attr_option<String>("my_option"); | ||
* .add_attr_option<String>("my_option_default", String("default")); | ||
* | ||
* \endcode | ||
*/ | ||
#define TVM_REGISTER_EXECUTOR(ExecutorName) \ | ||
TVM_STR_CONCAT(TVM_EXECUTOR_REGISTER_VAR_DEF, __COUNTER__) = \ | ||
::tvm::relay::ExecutorRegEntry::RegisterOrGet(ExecutorName).set_name() | ||
} // namespace relay | ||
} // namespace tvm | ||
|
||
#endif // TVM_RELAY_EXECUTOR_H_ |
Oops, something went wrong.