Skip to content

Commit

Permalink
Add new Metadata classes and base implementation.
Browse files Browse the repository at this point in the history
 * These were autogenerated in the original PR, but checking them in
   as plain code until we can revisit the auto-generator approach.
  • Loading branch information
areusch committed Feb 17, 2022
1 parent 0009a30 commit 2f64ab0
Show file tree
Hide file tree
Showing 7 changed files with 886 additions and 0 deletions.
125 changes: 125 additions & 0 deletions include/tvm/runtime/metadata.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/runtime/metadata.h
* \brief Defines types which can be used in Metadata.
*/
#ifndef TVM_RUNTIME_METADATA_H_
#define TVM_RUNTIME_METADATA_H_

#include <inttypes.h>
#ifdef __cplusplus
#include <memory>
#include <string>
#include <vector>
#endif
#include <tvm/runtime/c_runtime_api.h>
#ifdef __cplusplus
#include <tvm/runtime/metadata_base.h>
#endif
#include <tvm/support/span.h>

// Version number recorded in emitted artifacts for runtime checking.
#define TVM_METADATA_VERSION 1
static const constexpr int64_t kMetadataVersion = TVM_METADATA_VERSION;
#ifdef __cplusplus
extern "C" {
#endif

struct TVMMetadata {
int64_t version;
const struct TVMTensorInfo* inputs;
int64_t num_inputs;
const struct TVMTensorInfo* outputs;
int64_t num_outputs;
const char* mod_name;
};

struct TVMTensorInfo {
const char* name;
const int64_t* shape;
int64_t num_shape;
DLDataType dtype;
};
#ifdef __cplusplus
} // extern "C"
#include <tvm/runtime/object.h>
namespace tvm {
namespace runtime {
namespace metadata {

class Metadata;
class TensorInfo;

class MetadataNode : public MetadataBaseNode {
public:
explicit MetadataNode(const struct ::TVMMetadata* data) : data_{data} {}
static constexpr const char* _type_key = "metadata.MetadataNode";
std::string get_name() override;
inline int64_t version() const { return int64_t(data_->version); }
inline int64_t num_inputs() const { return data_->num_inputs; }
ArrayAccessor<struct TVMTensorInfo, TensorInfo> inputs();
inline int64_t num_outputs() const { return data_->num_outputs; }
ArrayAccessor<struct TVMTensorInfo, TensorInfo> outputs();
inline ::tvm::runtime::String mod_name() const { return ::tvm::runtime::String(data_->mod_name); }
const struct ::TVMMetadata* data() const { return data_; }
TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, MetadataBaseNode);

private:
const struct ::TVMMetadata* data_;
};

class Metadata : public MetadataBase {
public:
explicit Metadata(const struct ::TVMMetadata* data);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Metadata, MetadataBase, MetadataNode);
};

class TensorInfoNode : public MetadataBaseNode {
public:
explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {}
static constexpr const char* _type_key = "metadata.TensorInfoNode";
std::string get_name() override;
inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); }
inline int64_t num_shape() const { return data_->num_shape; }
inline ::tvm::support::Span<const int64_t, int64_t> shape() const {
return ::tvm::support::Span<const int64_t, int64_t>(data_->shape,
data_->shape + data_->num_shape);
}
inline ::tvm::runtime::DataType dtype() const { return ::tvm::runtime::DataType(data_->dtype); }
const struct ::TVMTensorInfo* data() const { return data_; }
TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, MetadataBaseNode);

private:
const struct ::TVMTensorInfo* data_;
};

class TensorInfo : public MetadataBase {
public:
explicit TensorInfo(const struct ::TVMTensorInfo* data);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorInfo, MetadataBase, TensorInfoNode);
};

} // namespace metadata
} // namespace runtime
} // namespace tvm
#endif // defined(__cplusplus)

#endif // TVM_RUNTIME_METADATA_H_
176 changes: 176 additions & 0 deletions include/tvm/runtime/metadata_base.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/runtime/metadata_base.h
* \brief Defines types which can be used in Metadata.
*/
#ifndef TVM_RUNTIME_METADATA_BASE_H_
#define TVM_RUNTIME_METADATA_BASE_H_

#include <tvm/ir/expr.h>
#include <tvm/runtime/object.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace tvm {
namespace runtime {
namespace metadata {

class MetadataBaseNode : public ::tvm::runtime::Object {
public:
virtual std::string get_name() = 0;

static constexpr const char* _type_key = "metadata.MetadataBaseNode";
TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object);
};

class MetadataBase : public ::tvm::runtime::ObjectRef {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataBase, ::tvm::runtime::ObjectRef, MetadataBaseNode);
};

template <typename C, class Ref>
class ArrayAccessor;

template <typename C, class Ref>
class ArrayIterator {
public:
ArrayIterator(size_t index, const ArrayAccessor<C, Ref>* parent)
: index_{index}, parent_{parent} {}

inline Ref operator*() { return (*parent_)[index_]; }

inline ArrayIterator<C, Ref>& operator++() {
if (index_ < parent_->size()) {
index_++;
}

return *this;
}

inline bool operator==(const ArrayIterator<C, Ref>& other) const {
return parent_ == other.parent_ && index_ == other.index_;
}

inline bool operator!=(const ArrayIterator<C, Ref>& other) const { return !operator==(other); }

// private:
size_t index_;
const ArrayAccessor<C, Ref>* parent_;
};

template <typename C, class Ref>
class ArrayAccessor {
public:
using value_type = Ref;
using iterator = ArrayIterator<C, Ref>;
using const_iterator = iterator;

template <typename T = typename std::enable_if<std::is_base_of<ObjectRef, Ref>::value>::type>
ArrayAccessor(const C* data, size_t num_data) : data_{data}, num_data_{num_data} {}

inline size_t size() const { return num_data_; }

inline Ref operator[](size_t index) const {
if (index >= num_data_) {
throw std::runtime_error("Index out of range");
}

return Ref(&data_[index]);
}

inline ArrayIterator<C, Ref> begin() const { return ArrayIterator<C, Ref>{0, this}; }

inline ArrayIterator<C, Ref> end() const { return ArrayIterator<C, Ref>{num_data_, this}; }

private:
const C* data_;
size_t num_data_;
};

template <>
class ArrayAccessor<const char*, ::tvm::runtime::String> {
public:
using value_type = ::tvm::runtime::String;
using iterator = ArrayIterator<const char*, ::tvm::runtime::String>;
using const_iterator = iterator;

ArrayAccessor(const char** data, size_t num_data) : data_{data}, num_data_{num_data} {}

inline size_t size() const { return num_data_; }

inline ::tvm::runtime::String operator[](size_t index) const {
if (index >= num_data_) {
throw std::runtime_error("Index out of range");
}
return ::tvm::runtime::String(data_[index]);
}

inline ArrayIterator<const char*, ::tvm::runtime::String> begin() const {
return ArrayIterator<const char*, ::tvm::runtime::String>{0, this};
}

inline ArrayIterator<const char*, ::tvm::runtime::String> end() const {
return ArrayIterator<const char*, ::tvm::runtime::String>{num_data_, this};
}

private:
const char** data_;
size_t num_data_;
};

enum MetadataTypeIndex : uint8_t {
kUint64 = 0,
kInt64 = 1,
kBool = 2,
kString = 3,
kHandle = 4,
kMetadata = 5,
};

class MetadataArrayNode : public MetadataBaseNode {
public:
MetadataArrayNode(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name)
: array(::std::move(array)), type_index{type_index}, struct_name{struct_name} {}

std::string get_name() override;

Array<ObjectRef> array;
MetadataTypeIndex type_index;
const char* struct_name;
static constexpr const char* _type_key = "metadata.MetadataArrayNode";
TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode);
};

class MetadataArray : public MetadataBase {
public:
MetadataArray(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode);
};

} // namespace metadata
} // namespace runtime
} // namespace tvm

#endif // TVM_RUNTIME_METADATA_BASE_H_
Loading

0 comments on commit 2f64ab0

Please sign in to comment.