Skip to content

Commit

Permalink
[Relay][Op] Make Type Relation catch more errors (#3899)
Browse files Browse the repository at this point in the history
* save

* init

* move type_relations
  • Loading branch information
MarisaKirisame authored and icemelon committed Sep 6, 2019
1 parent ca0292d commit 19f8c12
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/relay/op/op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <vector>
#include "type_relations.h"
#include "../pass/alter_op_layout.h"

namespace tvm {
Expand Down
13 changes: 6 additions & 7 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
* \file transform.cc
* \brief Transform operators.
*/
Expand Down Expand Up @@ -1541,15 +1541,14 @@ RELAY_REGISTER_OP("squeeze")
.set_attr<TOpPattern>("TOpPattern", kInjective);


// Have no idea how to assert the constraint.
// CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
bool CollapseSumLikeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
reporter->Assign(types[2], types[1]);
return true;
return BroadcastRel({types[0], types[1], types[0]}, 2, Attrs(), reporter);
}

Expr MakeCollapseSumLike(Expr data,
Expand Down Expand Up @@ -1593,7 +1592,7 @@ bool BroadCastToRel(const Array<Type>& types,
if (intt == nullptr) { return false; }
auto type = TensorTypeNode::make(ioattrs->shape, intt->dtype);
reporter->Assign(types[1], type);
return true;
return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
}

Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape) {
Expand Down Expand Up @@ -1632,7 +1631,7 @@ bool BroadCastToLikeRel(const Array<Type>& types,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
reporter->Assign(types[2], types[1]);
return true;
return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
}

Expr MakeBroadCastToLike(Expr data,
Expand Down Expand Up @@ -2493,9 +2492,9 @@ RELAY_REGISTER_OP("one_hot")
**off_value** Value to fill at all other positions besides indices.
**depth** Depth of the one-hot dimension.
**axis** Axis to fill.
**dtype**)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.OneHotAttrs")
.set_num_inputs(3)
Expand Down

0 comments on commit 19f8c12

Please sign in to comment.