From 71986be5438e0f61f78065d7da50e5179934cd31 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 2 Oct 2019 09:35:41 -0700 Subject: [PATCH] address comment --- python/tvm/relay/op/nn/_nn.py | 2 +- src/relay/op/nn/nn.cc | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 12f5dfdf19e8d..8c09390b4deb3 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -753,4 +753,4 @@ def schedule_bitserial_dense(attrs, outputs, target): @reg.register_compute("nn.cross_entropy") def compute_cross_entropy(attrs, inputs, out_dtype, target): x, y = inputs - return [-topi.sum(topi.log(x) * y / x.shape[0])] + return [-topi.sum(topi.log(x) * y) / x.shape[0]] diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index d13d6f6f46385..91e9ed1f66bd0 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -758,16 +758,16 @@ bool CrossEntropyRel(const Array& types, if (x == nullptr || y == nullptr) return false; CHECK(x->shape.size() == 2 && y->shape.size() == 2) << "CrossEntropy: shapes of x and y is inconsistent, " - << "x shape=, " << x->shape - << "y shape=" << y->shape; + << "x shape = " << x->shape << ", " + << "y shape = " << y->shape; CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) << "CrossEntropy: shapes of x and y is inconsistent, " - << "x shape=, " << x->shape - << "y shape=" << y->shape; + << "x shape = " << x->shape << ", " + << "y shape = " << y->shape; CHECK(reporter->AssertEQ(x->shape[1], y->shape[1])) << "CrossEntropy: shapes of x and y is inconsistent, " - << "x shape=, " << x->shape - << "y shape=" << y->shape; + << "x shape = " << x->shape << ", " + << "y shape = " << y->shape; // assign output type reporter->Assign(types[2], TensorTypeNode::make({}, x->dtype)); return true; @@ -785,7 +785,7 @@ TVM_REGISTER_API("relay.op.nn._make.cross_entropy") RELAY_REGISTER_OP("nn.cross_entropy") -.describe(R"code(Computes cross entropy given preditions and targets.)code" TVM_ADD_FILELINE) +.describe(R"code(Computes cross entropy given predictions and targets.)code" TVM_ADD_FILELINE) .set_num_inputs(2) .add_argument("x", "1D Tensor", "Predictions.") .add_argument("y", "1D Tensor", "Targets.")