Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CONTAINER] Struct Hash/Equal and JSON support for ShapeTuple #13671

Merged
merged 1 commit into from
Dec 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions src/node/structural_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,50 @@ TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
return ::tvm::runtime::make_object<ArrayNode>();
});

struct ShapeTupleObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;

static void SHashReduce(const ShapeTupleObj* self, SHashReducer hash_reduce) {
hash_reduce(self->size);
for (size_t i = 0; i < self->size; ++i) {
hash_reduce(self->data[i]);
}
}

static bool SEqualReduce(const ShapeTupleObj* lhs, const ShapeTupleObj* rhs,
SEqualReducer equal) {
if (lhs->size != rhs->size) return false;
for (size_t i = 0; i < lhs->size; ++i) {
if (!equal(lhs->data[i], rhs->data[i])) return false;
}
return true;
}
};

TVM_REGISTER_REFLECTION_VTABLE(ShapeTupleObj, ShapeTupleObjTrait)
.set_creator([](const std::string& blob) {
// Store shape tuple in blob to avoid large integer overflow in JSON.
dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
support::Base64InStream b64strm(&mstrm);
b64strm.InitPosition();
uint64_t size;
b64strm.Read<uint64_t>(&size);
std::vector<int64_t> data(size);
b64strm.ReadArray(data.data(), size);
ShapeTuple shape(data);
return RefToObjectPtr::Get(shape);
})
.set_repr_bytes([](const Object* n) -> std::string {
std::string blob;
dmlc::MemoryStringStream mstrm(&blob);
support::Base64OutStream b64strm(&mstrm);
const auto* shape = static_cast<const runtime::ShapeTupleObj*>(n);
b64strm.Write<uint64_t>(shape->size);
b64strm.WriteArray(shape->data, shape->size);
b64strm.Finish();
return blob;
});

struct MapNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;

Expand Down
9 changes: 7 additions & 2 deletions src/support/base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,10 @@ class Base64InStream : public dmlc::Stream {
}
/*! \brief whether current position is end of a base64 stream */
bool IsEOF(void) const { return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); }

using dmlc::Stream::Read;
// override read function.
virtual size_t Read(void* ptr, size_t size) {
size_t Read(void* ptr, size_t size) final {
using base64::DecodeTable;
if (size == 0) return 0;
// use tlen to record left size
Expand Down Expand Up @@ -224,7 +226,10 @@ class Base64InStream : public dmlc::Stream {
class Base64OutStream : public dmlc::Stream {
public:
explicit Base64OutStream(dmlc::Stream* fp) : fp_(fp) {}
virtual void Write(const void* ptr, size_t size) {

using dmlc::Stream::Write;

void Write(const void* ptr, size_t size) final {
using base64::EncodeTable;
size_t tlen = size;
const unsigned char* cptr = static_cast<const unsigned char*>(ptr);
Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_container_structural_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ def test_array_structural_equal_to_self(contents):
assert get_first_mismatch_ensure_symmetry(a, b) is None


@pytest.mark.parametrize(
"contents",
[
[],
[1],
[1, 2, 3],
],
)
def test_shape_tuple_structural_equal_to_self(contents):
a = tvm.runtime.ShapeTuple(list(contents))
b = tvm.runtime.ShapeTuple(list(contents))
assert get_first_mismatch_ensure_symmetry(a, b) is None


@pytest.mark.parametrize(
"a, b, expected_a_path, expected_b_path",
[
Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_runtime_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ def test_shape_tuple():
# ShapleTuple vs. ShapeTuple
assert stuple == _container.ShapeTuple(shape)

# test pickle
z = pickle.loads(pickle.dumps(stuple))
assert isinstance(z, tvm.runtime.ShapeTuple)
assert stuple == z


if __name__ == "__main__":
test_string()
Expand Down