Skip to content

Commit

Permalink
address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame committed Oct 2, 2019
1 parent 9680232 commit 71986be
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,4 +753,4 @@ def schedule_bitserial_dense(attrs, outputs, target):
@reg.register_compute("nn.cross_entropy")
def compute_cross_entropy(attrs, inputs, out_dtype, target):
x, y = inputs
return [-topi.sum(topi.log(x) * y / x.shape[0])]
return [-topi.sum(topi.log(x) * y) / x.shape[0]]
14 changes: 7 additions & 7 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -758,16 +758,16 @@ bool CrossEntropyRel(const Array<Type>& types,
if (x == nullptr || y == nullptr) return false;
CHECK(x->shape.size() == 2 && y->shape.size() == 2)
<< "CrossEntropy: shapes of x and y is inconsistent, "
<< "x shape=, " << x->shape
<< "y shape=" << y->shape;
<< "x shape = " << x->shape << ", "
<< "y shape = " << y->shape;
CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
<< "CrossEntropy: shapes of x and y is inconsistent, "
<< "x shape=, " << x->shape
<< "y shape=" << y->shape;
<< "x shape = " << x->shape << ", "
<< "y shape = " << y->shape;
CHECK(reporter->AssertEQ(x->shape[1], y->shape[1]))
<< "CrossEntropy: shapes of x and y is inconsistent, "
<< "x shape=, " << x->shape
<< "y shape=" << y->shape;
<< "x shape = " << x->shape << ", "
<< "y shape = " << y->shape;
// assign output type
reporter->Assign(types[2], TensorTypeNode::make({}, x->dtype));
return true;
Expand All @@ -785,7 +785,7 @@ TVM_REGISTER_API("relay.op.nn._make.cross_entropy")


RELAY_REGISTER_OP("nn.cross_entropy")
.describe(R"code(Computes cross entropy given preditions and targets.)code" TVM_ADD_FILELINE)
.describe(R"code(Computes cross entropy given predictions and targets.)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("x", "1D Tensor", "Predictions.")
.add_argument("y", "1D Tensor", "Targets.")
Expand Down

0 comments on commit 71986be

Please sign in to comment.