-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay] Introduce Executor and Runtime representations with associated registries #9246
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,276 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tvm/relay/executor.h | ||
* \brief Object representation of Executor configuration and registry | ||
*/ | ||
#ifndef TVM_RELAY_EXECUTOR_H_ | ||
#define TVM_RELAY_EXECUTOR_H_ | ||
|
||
#include <dmlc/registry.h> | ||
#include <tvm/ir/attrs.h> | ||
#include <tvm/ir/expr.h> | ||
#include <tvm/ir/type.h> | ||
#include <tvm/ir/type_relation.h> | ||
#include <tvm/node/attr_registry_map.h> | ||
#include <tvm/runtime/registry.h> | ||
|
||
#include <string> | ||
#include <unordered_map> | ||
#include <utility> | ||
#include <vector> | ||
|
||
namespace tvm { | ||
|
||
template <typename, typename> | ||
class AttrRegistry; | ||
|
||
namespace relay { | ||
|
||
/*! | ||
* \brief Executor information. | ||
* | ||
* This data structure stores the meta-data | ||
* about executors which can be used to pass around information. | ||
* | ||
* \sa Executor | ||
*/ | ||
class ExecutorNode : public Object { | ||
public: | ||
/*! \brief name of the Executor */ | ||
String name; | ||
/* \brief Additional attributes storing meta-data about the Executor. */ | ||
DictAttrs attrs; | ||
|
||
/*! | ||
* \brief Get an attribute. | ||
* | ||
* \param attr_key The attribute key. | ||
* \param default_value The default value if the key does not exist, defaults to nullptr. | ||
* | ||
* \return The result | ||
* | ||
* \tparam TObjectRef the expected object type. | ||
* \throw Error if the key exists but the value does not match TObjectRef | ||
* | ||
* \code | ||
* | ||
* void GetAttrExample(const Executor& executor) { | ||
* auto value = executor->GetAttr<Integer>("AttrKey", 0); | ||
* } | ||
* | ||
* \endcode | ||
*/ | ||
template <typename TObjectRef> | ||
Optional<TObjectRef> GetAttr( | ||
const std::string& attr_key, | ||
Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const { | ||
return attrs.GetAttr(attr_key, default_value); | ||
} | ||
// variant that uses TObjectRef to enable implicit conversion to default value. | ||
template <typename TObjectRef> | ||
Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const { | ||
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value)); | ||
} | ||
|
||
void VisitAttrs(AttrVisitor* v) { | ||
v->Visit("name", &name); | ||
v->Visit("attrs", &attrs); | ||
} | ||
|
||
bool SEqualReduce(const ExecutorNode* other, SEqualReducer equal) const { | ||
return name == other->name && equal.DefEqual(attrs, other->attrs); | ||
} | ||
|
||
void SHashReduce(SHashReducer hash_reduce) const { | ||
hash_reduce(name); | ||
hash_reduce(attrs); | ||
} | ||
|
||
static constexpr const char* _type_key = "Executor"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(ExecutorNode, Object); | ||
}; | ||
|
||
/*! | ||
* \brief Managed reference class to ExecutorNode. | ||
* \sa ExecutorNode | ||
*/ | ||
class Executor : public ObjectRef { | ||
public: | ||
/*! | ||
* \brief Create a new Executor object using the registry | ||
* \throws Error if name is not registered | ||
* \param name The name of the executor. | ||
* \param attrs Attributes for the executor. | ||
* \return the new Executor object. | ||
*/ | ||
TVM_DLL static Executor Create(String name, Map<String, ObjectRef> attrs); | ||
|
||
/*! | ||
* \brief List all registered Executors | ||
* \return the list of Executors | ||
*/ | ||
TVM_DLL static Array<String> ListExecutors(); | ||
|
||
/*! | ||
* \brief List all options for a specific Executor | ||
* \param name The name of the Executor | ||
* \return Map of option name to type | ||
*/ | ||
TVM_DLL static Map<String, String> ListExecutorOptions(const String& name); | ||
|
||
/*! \brief specify container node */ | ||
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Executor, ObjectRef, ExecutorNode); | ||
|
||
private: | ||
/*! | ||
* \brief Private Constructor | ||
* \param name The executor name | ||
* \param attrs Attributes to apply to this Executor node | ||
*/ | ||
TVM_DLL Executor(String name, DictAttrs attrs) { | ||
auto n = make_object<ExecutorNode>(); | ||
n->name = std::move(name); | ||
n->attrs = std::move(attrs); | ||
data_ = std::move(n); | ||
} | ||
}; | ||
|
||
/*! | ||
* \brief Helper structure to register Executors | ||
* \sa TVM_REGISTER_EXECUTOR | ||
*/ | ||
class ExecutorRegEntry { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any hope of templating away all the gobs of boilerplate? Forgive me if there's an obvious catch I can't see. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I started down this path but realised that if we wanted to do this it'd make sense to do it in the context of:
I started doing this and ended up with what was essentially a templated mess - so I'd rather take it later and in isolation 😸 |
||
public: | ||
/*! \brief Set name of the Executor to be the same as registry if it is empty */ | ||
inline ExecutorRegEntry& set_name(); | ||
|
||
/*! | ||
* \brief Register a valid configuration option and its ValueType for validation | ||
* \param key The configuration key | ||
* \tparam ValueType The value type to be registered | ||
*/ | ||
template <typename ValueType> | ||
inline ExecutorRegEntry& add_attr_option(const String& key); | ||
|
||
/*! | ||
* \brief Register a valid configuration option and its ValueType for validation | ||
* \param key The configuration key | ||
* \param default_value The default value of the key | ||
* \tparam ValueType The value type to be registered | ||
*/ | ||
template <typename ValueType> | ||
inline ExecutorRegEntry& add_attr_option(const String& key, ObjectRef default_value); | ||
|
||
/*! | ||
* \brief Register or get a new entry. | ||
* \param name The name of the operator. | ||
* \return the corresponding entry. | ||
*/ | ||
TVM_DLL static ExecutorRegEntry& RegisterOrGet(const String& name); | ||
|
||
private: | ||
/*! \brief Internal storage of value types */ | ||
struct ValueTypeInfo { | ||
std::string type_key; | ||
uint32_t type_index; | ||
}; | ||
std::unordered_map<std::string, ValueTypeInfo> key2vtype_; | ||
/*! \brief A hash table that stores the default value of each attr */ | ||
std::unordered_map<String, ObjectRef> key2default_; | ||
|
||
/*! \brief Index used for internal lookup of attribute registry */ | ||
uint32_t index_; | ||
|
||
// the name | ||
std::string name; | ||
|
||
/*! \brief Return the index stored in attr registry */ | ||
uint32_t AttrRegistryIndex() const { return index_; } | ||
/*! \brief Return the name stored in attr registry */ | ||
String AttrRegistryName() const { return name; } | ||
|
||
/*! \brief private constructor */ | ||
explicit ExecutorRegEntry(uint32_t reg_index) : index_(reg_index) {} | ||
|
||
// friend class | ||
template <typename> | ||
friend class AttrRegistryMapContainerMap; | ||
template <typename, typename> | ||
friend class tvm::AttrRegistry; | ||
friend class Executor; | ||
}; | ||
|
||
inline ExecutorRegEntry& ExecutorRegEntry::set_name() { | ||
if (name.empty()) { | ||
name = name; | ||
} | ||
return *this; | ||
} | ||
|
||
template <typename ValueType> | ||
inline ExecutorRegEntry& ExecutorRegEntry::add_attr_option(const String& key) { | ||
ICHECK(!key2vtype_.count(key)) << "AttributeError: add_attr_option failed because '" << key | ||
<< "' has been set once"; | ||
|
||
using ValueNodeType = typename ValueType::ContainerType; | ||
// NOTE: we could further update the function later. | ||
uint32_t value_type_index = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); | ||
|
||
ValueTypeInfo info; | ||
info.type_index = value_type_index; | ||
info.type_key = runtime::Object::TypeIndex2Key(value_type_index); | ||
key2vtype_[key] = info; | ||
return *this; | ||
} | ||
|
||
template <typename ValueType> | ||
inline ExecutorRegEntry& ExecutorRegEntry::add_attr_option(const String& key, | ||
ObjectRef default_value) { | ||
add_attr_option<ValueType>(key); | ||
key2default_[key] = default_value; | ||
return *this; | ||
} | ||
|
||
// internal macros to make executor entries | ||
#define TVM_EXECUTOR_REGISTER_VAR_DEF \ | ||
static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::ExecutorRegEntry& __make_##Executor | ||
|
||
/*! | ||
* \def TVM_REGISTER_EXECUTOR | ||
* \brief Register a new executor, or set attribute of the corresponding executor. | ||
* | ||
* \param ExecutorName The name of registry | ||
* | ||
* \code | ||
* | ||
* TVM_REGISTER_EXECUTOR("aot") | ||
* .add_attr_option<String>("my_option"); | ||
* .add_attr_option<String>("my_option_default", String("default")); | ||
* | ||
* \endcode | ||
*/ | ||
#define TVM_REGISTER_EXECUTOR(ExecutorName) \ | ||
TVM_STR_CONCAT(TVM_EXECUTOR_REGISTER_VAR_DEF, __COUNTER__) = \ | ||
::tvm::relay::ExecutorRegEntry::RegisterOrGet(ExecutorName).set_name() | ||
} // namespace relay | ||
} // namespace tvm | ||
|
||
#endif // TVM_RELAY_EXECUTOR_H_ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Intended to be sub-classable right? Ie I don't have to reflect everything I care about into the generic attrs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At the moment all the options are enclosed in the attributes, but I left it open for extension if we want to do
class AOTExecutorNode : ExecutorNode
or similar - I see this being more of a thing withRuntime
which may well represent the runtime itself at some point.