Commit 9680232

save

redo max test

save

address comment

fix
MarisaKirisame committed Oct 1, 2019
1 parent 5cc1764 commit 9680232
Showing 8 changed files with 124 additions and 7 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/op/_reduce.py
@@ -36,3 +36,4 @@ def _schedule_reduce(_, outs, target):
_reg.register_schedule("prod", _schedule_reduce)
_reg.register_schedule("mean", _schedule_reduce)
_reg.register_schedule("variance", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy", _schedule_reduce)
23 changes: 22 additions & 1 deletion python/tvm/relay/op/_tensor_grad.py
@@ -25,14 +25,26 @@
from . import nn as _nn
from .op import register_gradient
from .reduce import sum as _sum
from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like, equal
from .tensor import (
cos,
exp,
less,
negative,
ones_like,
power,
sin,
zeros_like,
equal,
shape_of,
log)
from .transform import (
broadcast_to_like,
collapse_sum_like,
cast_like,
reshape,
reshape_like,
strided_slice,
take,
tile,
transpose,
where,
@@ -353,3 +365,12 @@ def sum_grad(orig, grad):
"""Returns grad broadcasted to data dims"""
data = orig.args[0]
return [broadcast_to_like(grad, data)]


@register_gradient("nn.cross_entropy")
def cross_entropy_grad(orig, grad):
    """Gradient of cross_entropy w.r.t. predictions x and targets y."""
    x, y = orig.args
    shape = shape_of(x)
    batch_size = take(shape, const(0, dtype='int32'), axis=0)
    grad = grad / batch_size.astype('float32')
    return [-grad * y / x, -grad * log(x)]
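The formulas above follow from the compute rule registered in _nn.py below, CE(x, y) = -sum(log(x) * y) / batch_size, which gives d/dx = -y / (x * batch_size) and d/dy = -log(x) / batch_size. A standalone numpy sketch (not part of this commit; all names are illustrative) that checks both expressions against a two-sided finite difference:

import numpy as np

def cross_entropy_ref(x, y):
    # mirrors the registered compute: -sum(log(x) * y) / batch_size
    return -np.sum(np.log(x) * y) / x.shape[0]

def analytic_grads(x, y, out_grad=1.0):
    g = out_grad / x.shape[0]
    return -g * y / x, -g * np.log(x)

def numeric_grad(f, a, eps=1e-6):
    # two-sided finite-difference gradient of scalar-valued f at array a
    g = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        hi, lo = a.copy(), a.copy()
        hi[idx] += eps
        lo[idx] -= eps
        g[idx] = (f(hi) - f(lo)) / (2 * eps)
    return g

x = np.random.uniform(0.5, 1.5, (4, 5))
y = np.random.uniform(0.5, 1.5, (4, 5))
gx, gy = analytic_grads(x, y)
assert np.allclose(gx, numeric_grad(lambda a: cross_entropy_ref(a, y), x), atol=1e-4)
assert np.allclose(gy, numeric_grad(lambda a: cross_entropy_ref(x, a), y), atol=1e-4)
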
9 changes: 9 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -745,3 +745,12 @@ def schedule_bitserial_dense(attrs, outputs, target):


reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)


@reg.register_compute("nn.cross_entropy")
def compute_cross_entropy(attrs, inputs, out_dtype, target):
    """Compute definition of cross_entropy: -sum(log(x) * y) / batch_size."""
    x, y = inputs
    return [-topi.sum(topi.log(x) * y / x.shape[0])]
4 changes: 4 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -1691,3 +1691,7 @@ def bitserial_dense(data,
"""
return _make.bitserial_dense(data, weight, units, data_bits, weight_bits,
pack_dtype, out_dtype, unipolar)


def cross_entropy(predictions, targets):
    """Computes cross entropy between predictions and targets."""
    return _make.cross_entropy(predictions, targets)
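A short usage sketch (not part of the commit; assumes a TVM build that includes this change) showing the new Python wrapper together with the type relation defined in nn.cc below, which infers a scalar result:

from tvm import relay
from tvm.relay.testing import run_infer_type

x = relay.var("x", shape=(2, 5))
y = relay.var("y", shape=(2, 5))
f = run_infer_type(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)))
# CrossEntropyRel assigns a 0-d tensor with x's dtype, so the return type
# should print as a scalar float32 tensor.
print(f.ret_type)
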
16 changes: 12 additions & 4 deletions python/tvm/relay/testing/__init__.py
@@ -56,11 +56,11 @@ def run_infer_type(expr):
return run_opt_pass(expr, transform.InferType())


def _np_randn_from_type(t, scale=1):
return (scale * np.random.randn(*(int(d) for d in t.shape))).astype(t.dtype)
def _np_randn_from_type(t, scale=1, mean=0):
return (mean + (scale * np.random.randn(*(int(d) for d in t.shape)))).astype(t.dtype)


def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3):
def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, mean=0):
"""Perform numerical gradient checking given a relay function.
Compare analytical gradients to numerical gradients derived from two-sided approximation. Note
@@ -86,15 +86,23 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3):
The relative tolerance on difference between numerical and analytical gradients. Note that
this needs to be scaled appropriately relative to the chosen eps.
scale: float
The standard deviation of the inputs.
mean: float
The mean of the inputs.
"""

fwd_func = run_infer_type(func)
bwd_func = run_infer_type(gradient(fwd_func))

if scale is None:
scale = 10 * eps

if inputs is None:
params = fwd_func.params
# Generate random inputs on the same scale as epsilon to avoid numerical precision loss.
inputs = [_np_randn_from_type(x.checked_type, scale=(10 * eps)) for x in params]
inputs = [_np_randn_from_type(x.checked_type, scale=scale, mean=mean) for x in params]

for target, ctx in ctx_list():
intrp = relay.create_executor(ctx=ctx, target=target)
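The new scale and mean arguments let callers move the random inputs away from the default 10 * eps neighborhood of zero, which matters for ops that are only defined (or numerically stable) on part of the real line. A hedged sketch of how a caller might use them, here with relay.log, which likewise needs strictly positive inputs:

from tvm import relay
from tvm.relay.testing import check_grad

x = relay.var("x", shape=(3, 4))
# log is undefined at and below zero, so sample inputs around 1 instead of near 0.
check_grad(relay.Function([x], relay.log(x)), scale=0.1, mean=1)
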
46 changes: 46 additions & 0 deletions src/relay/op/nn/nn.cc
@@ -747,5 +747,51 @@ are data in batch.
.add_type_rel("BatchMatmul", BatchMatmulRel);


// relay.nn.cross_entropy
bool CrossEntropyRel(const Array<Type>& types,
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* x = types[0].as<TensorTypeNode>();
  const auto* y = types[1].as<TensorTypeNode>();
  if (x == nullptr || y == nullptr) return false;
  CHECK(x->shape.size() == 2 && y->shape.size() == 2)
      << "CrossEntropy: x and y must be 2-D tensors, "
      << "x shape=" << x->shape
      << ", y shape=" << y->shape;
  CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
      << "CrossEntropy: shapes of x and y are inconsistent, "
      << "x shape=" << x->shape
      << ", y shape=" << y->shape;
  CHECK(reporter->AssertEQ(x->shape[1], y->shape[1]))
      << "CrossEntropy: shapes of x and y are inconsistent, "
      << "x shape=" << x->shape
      << ", y shape=" << y->shape;
  // assign output type: a scalar with the dtype of x
  reporter->Assign(types[2], TensorTypeNode::make({}, x->dtype));
  return true;
}

// Positional relay function to create cross_entropy operator used by frontend FFI.
Expr MakeCrossEntropy(Expr predictions, Expr targets) {
  static const Op& op = Op::Get("nn.cross_entropy");
  return CallNode::make(op, {predictions, targets}, Attrs(), {});
}


TVM_REGISTER_API("relay.op.nn._make.cross_entropy")
.set_body_typed(MakeCrossEntropy);


RELAY_REGISTER_OP("nn.cross_entropy")
.describe(R"code(Computes cross entropy given predictions and targets.)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("x", "2D Tensor", "Predictions.")
.add_argument("y", "2D Tensor", "Targets.")
.set_support_level(10)
.add_type_rel("CrossEntropy", CrossEntropyRel);


} // namespace relay
} // namespace tvm
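A quick, hedged way to confirm the registration from Python (assuming a build that includes this commit; the expected values mirror the registration above):

from tvm import relay

op = relay.op.get("nn.cross_entropy")
print(op.name)           # nn.cross_entropy
print(op.num_inputs)     # 2
print(op.support_level)  # 10
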
28 changes: 28 additions & 0 deletions tests/python/relay/test_op_grad_level10.py
@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm import relay
from tvm.relay.testing import check_grad


def test_cross_entropy_grad():
    # x feeds log(x) inside the op, so draw inputs around mean=1 with a small
    # scale to keep them strictly positive for the numerical check.
    x = relay.var("x", shape=(1, 5))
    y = relay.var("y", shape=(1, 5))
    check_grad(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)),
               eps=0.01, scale=0.1, mean=1)


if __name__ == "__main__":
test_cross_entropy_grad()
4 changes: 2 additions & 2 deletions tests/python/relay/test_op_grad_level4.py
@@ -32,14 +32,14 @@ def test_sum_grad():


def test_max_grad():
s = (5, 10)
s = (10, 10)
t = relay.TensorType(s)
x = relay.var("x", t)
axis = 0
z = relay.max(x, axis)

fwd_func = relay.Function([x], z)
check_grad(fwd_func, eps=1e-7, rtol=1)
check_grad(fwd_func, scale=1e-3)


if __name__ == "__main__":
