-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add new Metadata classes and base implementation.
* These were autogenerated in the original PR, but checking them in as plain code until we can revisit the auto-generator approach.
- Loading branch information
Showing
7 changed files
with
886 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tvm/runtime/metadata.h | ||
* \brief Defines types which can be used in Metadata. | ||
*/ | ||
#ifndef TVM_RUNTIME_METADATA_H_ | ||
#define TVM_RUNTIME_METADATA_H_ | ||
|
||
#include <inttypes.h> | ||
#ifdef __cplusplus | ||
#include <memory> | ||
#include <string> | ||
#include <vector> | ||
#endif | ||
#include <tvm/runtime/c_runtime_api.h> | ||
#ifdef __cplusplus | ||
#include <tvm/runtime/metadata_base.h> | ||
#endif | ||
#include <tvm/support/span.h> | ||
|
||
// Version number recorded in emitted artifacts for runtime checking. | ||
#define TVM_METADATA_VERSION 1 | ||
static const constexpr int64_t kMetadataVersion = TVM_METADATA_VERSION; | ||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
struct TVMMetadata { | ||
int64_t version; | ||
const struct TVMTensorInfo* inputs; | ||
int64_t num_inputs; | ||
const struct TVMTensorInfo* outputs; | ||
int64_t num_outputs; | ||
const char* mod_name; | ||
}; | ||
|
||
struct TVMTensorInfo { | ||
const char* name; | ||
const int64_t* shape; | ||
int64_t num_shape; | ||
DLDataType dtype; | ||
}; | ||
#ifdef __cplusplus | ||
} // extern "C" | ||
#include <tvm/runtime/object.h> | ||
namespace tvm { | ||
namespace runtime { | ||
namespace metadata { | ||
|
||
class Metadata; | ||
class TensorInfo; | ||
|
||
class MetadataNode : public MetadataBaseNode { | ||
public: | ||
explicit MetadataNode(const struct ::TVMMetadata* data) : data_{data} {} | ||
static constexpr const char* _type_key = "metadata.MetadataNode"; | ||
std::string get_name() override; | ||
inline int64_t version() const { return int64_t(data_->version); } | ||
inline int64_t num_inputs() const { return data_->num_inputs; } | ||
ArrayAccessor<struct TVMTensorInfo, TensorInfo> inputs(); | ||
inline int64_t num_outputs() const { return data_->num_outputs; } | ||
ArrayAccessor<struct TVMTensorInfo, TensorInfo> outputs(); | ||
inline ::tvm::runtime::String mod_name() const { return ::tvm::runtime::String(data_->mod_name); } | ||
const struct ::TVMMetadata* data() const { return data_; } | ||
TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, MetadataBaseNode); | ||
|
||
private: | ||
const struct ::TVMMetadata* data_; | ||
}; | ||
|
||
class Metadata : public MetadataBase { | ||
public: | ||
explicit Metadata(const struct ::TVMMetadata* data); | ||
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Metadata, MetadataBase, MetadataNode); | ||
}; | ||
|
||
class TensorInfoNode : public MetadataBaseNode { | ||
public: | ||
explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {} | ||
static constexpr const char* _type_key = "metadata.TensorInfoNode"; | ||
std::string get_name() override; | ||
inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); } | ||
inline int64_t num_shape() const { return data_->num_shape; } | ||
inline ::tvm::support::Span<const int64_t, int64_t> shape() const { | ||
return ::tvm::support::Span<const int64_t, int64_t>(data_->shape, | ||
data_->shape + data_->num_shape); | ||
} | ||
inline ::tvm::runtime::DataType dtype() const { return ::tvm::runtime::DataType(data_->dtype); } | ||
const struct ::TVMTensorInfo* data() const { return data_; } | ||
TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, MetadataBaseNode); | ||
|
||
private: | ||
const struct ::TVMTensorInfo* data_; | ||
}; | ||
|
||
class TensorInfo : public MetadataBase { | ||
public: | ||
explicit TensorInfo(const struct ::TVMTensorInfo* data); | ||
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorInfo, MetadataBase, TensorInfoNode); | ||
}; | ||
|
||
} // namespace metadata | ||
} // namespace runtime | ||
} // namespace tvm | ||
#endif // defined(__cplusplus) | ||
|
||
#endif // TVM_RUNTIME_METADATA_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tvm/runtime/metadata_base.h | ||
* \brief Defines types which can be used in Metadata. | ||
*/ | ||
#ifndef TVM_RUNTIME_METADATA_BASE_H_ | ||
#define TVM_RUNTIME_METADATA_BASE_H_ | ||
|
||
#include <tvm/ir/expr.h> | ||
#include <tvm/runtime/object.h> | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
namespace tvm { | ||
namespace runtime { | ||
namespace metadata { | ||
|
||
class MetadataBaseNode : public ::tvm::runtime::Object { | ||
public: | ||
virtual std::string get_name() = 0; | ||
|
||
static constexpr const char* _type_key = "metadata.MetadataBaseNode"; | ||
TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object); | ||
}; | ||
|
||
class MetadataBase : public ::tvm::runtime::ObjectRef { | ||
public: | ||
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataBase, ::tvm::runtime::ObjectRef, MetadataBaseNode); | ||
}; | ||
|
||
template <typename C, class Ref> | ||
class ArrayAccessor; | ||
|
||
template <typename C, class Ref> | ||
class ArrayIterator { | ||
public: | ||
ArrayIterator(size_t index, const ArrayAccessor<C, Ref>* parent) | ||
: index_{index}, parent_{parent} {} | ||
|
||
inline Ref operator*() { return (*parent_)[index_]; } | ||
|
||
inline ArrayIterator<C, Ref>& operator++() { | ||
if (index_ < parent_->size()) { | ||
index_++; | ||
} | ||
|
||
return *this; | ||
} | ||
|
||
inline bool operator==(const ArrayIterator<C, Ref>& other) const { | ||
return parent_ == other.parent_ && index_ == other.index_; | ||
} | ||
|
||
inline bool operator!=(const ArrayIterator<C, Ref>& other) const { return !operator==(other); } | ||
|
||
// private: | ||
size_t index_; | ||
const ArrayAccessor<C, Ref>* parent_; | ||
}; | ||
|
||
template <typename C, class Ref> | ||
class ArrayAccessor { | ||
public: | ||
using value_type = Ref; | ||
using iterator = ArrayIterator<C, Ref>; | ||
using const_iterator = iterator; | ||
|
||
template <typename T = typename std::enable_if<std::is_base_of<ObjectRef, Ref>::value>::type> | ||
ArrayAccessor(const C* data, size_t num_data) : data_{data}, num_data_{num_data} {} | ||
|
||
inline size_t size() const { return num_data_; } | ||
|
||
inline Ref operator[](size_t index) const { | ||
if (index >= num_data_) { | ||
throw std::runtime_error("Index out of range"); | ||
} | ||
|
||
return Ref(&data_[index]); | ||
} | ||
|
||
inline ArrayIterator<C, Ref> begin() const { return ArrayIterator<C, Ref>{0, this}; } | ||
|
||
inline ArrayIterator<C, Ref> end() const { return ArrayIterator<C, Ref>{num_data_, this}; } | ||
|
||
private: | ||
const C* data_; | ||
size_t num_data_; | ||
}; | ||
|
||
template <> | ||
class ArrayAccessor<const char*, ::tvm::runtime::String> { | ||
public: | ||
using value_type = ::tvm::runtime::String; | ||
using iterator = ArrayIterator<const char*, ::tvm::runtime::String>; | ||
using const_iterator = iterator; | ||
|
||
ArrayAccessor(const char** data, size_t num_data) : data_{data}, num_data_{num_data} {} | ||
|
||
inline size_t size() const { return num_data_; } | ||
|
||
inline ::tvm::runtime::String operator[](size_t index) const { | ||
if (index >= num_data_) { | ||
throw std::runtime_error("Index out of range"); | ||
} | ||
return ::tvm::runtime::String(data_[index]); | ||
} | ||
|
||
inline ArrayIterator<const char*, ::tvm::runtime::String> begin() const { | ||
return ArrayIterator<const char*, ::tvm::runtime::String>{0, this}; | ||
} | ||
|
||
inline ArrayIterator<const char*, ::tvm::runtime::String> end() const { | ||
return ArrayIterator<const char*, ::tvm::runtime::String>{num_data_, this}; | ||
} | ||
|
||
private: | ||
const char** data_; | ||
size_t num_data_; | ||
}; | ||
|
||
enum MetadataTypeIndex : uint8_t { | ||
kUint64 = 0, | ||
kInt64 = 1, | ||
kBool = 2, | ||
kString = 3, | ||
kHandle = 4, | ||
kMetadata = 5, | ||
}; | ||
|
||
class MetadataArrayNode : public MetadataBaseNode { | ||
public: | ||
MetadataArrayNode(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name) | ||
: array(::std::move(array)), type_index{type_index}, struct_name{struct_name} {} | ||
|
||
std::string get_name() override; | ||
|
||
Array<ObjectRef> array; | ||
MetadataTypeIndex type_index; | ||
const char* struct_name; | ||
static constexpr const char* _type_key = "metadata.MetadataArrayNode"; | ||
TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode); | ||
}; | ||
|
||
class MetadataArray : public MetadataBase { | ||
public: | ||
MetadataArray(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name); | ||
|
||
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode); | ||
}; | ||
|
||
} // namespace metadata | ||
} // namespace runtime | ||
} // namespace tvm | ||
|
||
#endif // TVM_RUNTIME_METADATA_BASE_H_ |
Oops, something went wrong.