From 7084d5c3907640cfe7d67f4cb0e060af04f08612 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Sat, 28 Mar 2020 19:06:51 +0000 Subject: [PATCH] fix --- include/tvm/ir/env_func.h | 3 +- include/tvm/ir/type_relation.h | 1 + tests/cpp/relay_pass_alpha_equal.cc | 67 ------------------------- tests/cpp/relay_pass_type_infer_test.cc | 3 +- tests/cpp/relay_transform_sequential.cc | 3 +- 5 files changed, 7 insertions(+), 70 deletions(-) delete mode 100644 tests/cpp/relay_pass_alpha_equal.cc diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index 1064fd1462ded..a0e94e5af916d 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -52,7 +52,8 @@ class EnvFuncNode : public Object { } bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const { - return this == other; + // name uniquely identifies the env function. + return name == other->name; } static constexpr const char* _type_key = "EnvFunc"; diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h index 5c87c0757b7ac..592bf25a7270c 100644 --- a/include/tvm/ir/type_relation.h +++ b/include/tvm/ir/type_relation.h @@ -203,6 +203,7 @@ class TypeRelationNode : public TypeConstraintNode { bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const { return + equal(func, other->func) && equal(args, other->args) && equal(num_inputs, other->num_inputs) && equal(attrs, other->attrs); diff --git a/tests/cpp/relay_pass_alpha_equal.cc b/tests/cpp/relay_pass_alpha_equal.cc deleted file mode 100644 index 0207fca00cf76..0000000000000 --- a/tests/cpp/relay_pass_alpha_equal.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include -#include - -using namespace tvm; - -class TestAlphaEquals { - runtime::PackedFunc *_packed_func; - public: - TestAlphaEquals(const char* func_name) { - _packed_func = new runtime::PackedFunc(); - TVMFuncGetGlobal(func_name, reinterpret_cast(&_packed_func)); - } - - void UpdatePackedFunc(const char* func_name) { - TVMFuncGetGlobal(func_name, reinterpret_cast(&_packed_func)); - } - - bool operator()(ObjectRef input_1, ObjectRef input_2) { - TVMRetValue rv; - std::vector values(2); - std::vector codes(2); - runtime::TVMArgsSetter setter(values.data(), codes.data()); - setter(0, input_1); - setter(1, input_2); - _packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv); - return bool(rv); - }; - -}; - -TEST(Relay, AlphaTestEmptyTypeNodes) { - auto x = TypeVar("x", kTypeData); - auto y = TypeVar(); - EXPECT_FALSE(relay::AlphaEqual(x, y)); - - TestAlphaEquals test_equals("relay._make._alpha_equal"); - EXPECT_FALSE(test_equals(x, y)); -} - -int main(int argc, char ** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index f951a8f386a68..3c416918e4414 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -18,6 +18,7 @@ */ #include +#include #include #include #include @@ -38,7 +39,7 @@ TEST(Relay, SelfReference) { auto type_fx = mod->Lookup("main"); auto expected = relay::FuncType(tvm::Array{ tensor_type }, tensor_type, {}, {}); - CHECK(relay::AlphaEqual(type_fx->checked_type(), expected)); + CHECK(tvm::StructuralEqual()(type_fx->checked_type(), expected)); } int main(int argc, char ** argv) { diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index 756468c9b110f..d974f023d74b6 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -102,7 +103,7 @@ TEST(Relay, Sequential) { auto mod1 = IRModule::FromExpr(expected_func); mod1 = relay::transform::InferType()(mod1); auto expected = mod1->Lookup("main"); - CHECK(relay::AlphaEqual(f, expected)); + CHECK(tvm::StructuralEqual()(f, expected)); } int main(int argc, char** argv) {