Skip to content

Commit

Permalink
[Relay] crossentropy_with_logits and its gradient (apache#4075)
Browse files Browse the repository at this point in the history
* save

* lint
  • Loading branch information
MarisaKirisame authored and kevinthesun committed Oct 30, 2019
1 parent 19f105f commit 3d0af15
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 4 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/op/_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ def _schedule_reduce(_, outs, target):
_reg.register_schedule("mean", _schedule_reduce)
_reg.register_schedule("variance", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce)
9 changes: 9 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,3 +449,12 @@ def cross_entropy_grad(orig, grad):
batch_size = take(shape, const(0, dtype='int32'), axis=0)
grad = grad / batch_size.astype('float32')
return [-grad * y / x, -grad * log(x)]


@register_gradient("nn.cross_entropy_with_logits")
def cross_entropy_with_logits_grad(orig, grad):
x, y = orig.args
shape = shape_of(x)
batch_size = take(shape, const(0, dtype='int32'), axis=0)
grad = grad / batch_size.astype('float32')
return [-grad * y, -grad * x]
9 changes: 9 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,3 +770,12 @@ def schedule_bitserial_dense(attrs, outputs, target):
def compute_cross_entropy(attrs, inputs, out_dtype, target):
x, y = inputs
return [-topi.sum(topi.log(x) * y) / x.shape[0]]


reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE)


@reg.register_compute("nn.cross_entropy_with_logits")
def compute_cross_entropy_with_logits(attrs, inputs, out_dtype, target):
x, y = inputs
return [-topi.sum(x * y) / x.shape[0]]
19 changes: 19 additions & 0 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1807,3 +1807,22 @@ def cross_entropy(predictions, targets):
The computed result.
"""
return _make.cross_entropy(predictions, targets)


def cross_entropy_with_logits(predictions, targets):
"""CrossEntropy with logits.
Parameters
----------
predictions : tvm.relay.Expr
The predictions.
targets : tvm.relay.Expr
The targets.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.cross_entropy_with_logits(predictions, targets)
25 changes: 24 additions & 1 deletion src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,7 @@ bool CrossEntropyRel(const Array<Type>& types,
return true;
}

// Positional relay function to create batch_matmul operator used by frontend FFI.
// Positional relay function to create cross_entropy operator used by frontend FFI.
Expr MakeCrossEntropy(Expr predictions, Expr targets) {
static const Op& op = Op::Get("nn.cross_entropy");
return CallNode::make(op, {predictions, targets}, Attrs(), {});
Expand All @@ -933,5 +933,28 @@ Do log on the data - do not accept logits.
.add_type_rel("CrossEntropy", CrossEntropyRel);


// Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) {
static const Op& op = Op::Get("nn.cross_entropy_with_logits");
return CallNode::make(op, {predictions, targets}, Attrs(), {});
}


TVM_REGISTER_API("relay.op.nn._make.cross_entropy_with_logits")
.set_body_typed(MakeCrossEntropyWithLogits);


RELAY_REGISTER_OP("nn.cross_entropy_with_logits")
.describe(R"code(
Computes cross entropy given predictions and targets.
Accept logits.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("x", "1D Tensor", "Predictions.")
.add_argument("y", "1D Tensor", "Targets.")
.set_support_level(10)
.add_type_rel("CrossEntropy", CrossEntropyRel);


} // namespace relay
} // namespace tvm
14 changes: 11 additions & 3 deletions tests/python/relay/test_op_grad_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,23 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

from tvm import relay
from tvm.relay.testing import check_grad


def test_cross_entropy_grad():
x = relay.var("x", shape=(1, 5))
y = relay.var("y", shape=(1, 5))
x = relay.var("x", shape=(2, 5))
y = relay.var("y", shape=(2, 5))
check_grad(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)), eps=0.01, scale=0.1, mean=1)


def test_cross_entropy_with_logits_grad():
x = relay.var("x", shape=(2, 5))
y = relay.var("y", shape=(2, 5))
check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1)


if __name__ == "__main__":
test_cross_entropy_grad()
pytest.main([__file__])

0 comments on commit 3d0af15

Please sign in to comment.