# Patch summary: add the Relay operator "nn.cross_entropy_with_logits".
#
# What each file change does:
#   python/tvm/relay/op/_reduce.py
#       Registers _schedule_reduce as the schedule for the new op.
#   python/tvm/relay/op/_tensor_grad.py
#       Registers the gradient. The incoming grad is divided by the first
#       entry of shape_of(x) (cast to float32) and the result is
#       [-grad * y, -grad * x].
#   python/tvm/relay/op/nn/_nn.py
#       Registers the OPAQUE pattern and a compute that returns
#       -topi.sum(x * y) / x.shape[0].
#   python/tvm/relay/op/nn/nn.py
#       Adds the Python wrapper, which forwards to
#       _make.cross_entropy_with_logits.
#   src/relay/op/nn/nn.cc
#       Adds MakeCrossEntropyWithLogits, exposes it as
#       relay.op.nn._make.cross_entropy_with_logits, and registers the op
#       with two inputs, support level 10 and the CrossEntropyRel type
#       relation. Also corrects the comment above MakeCrossEntropy, which
#       said "batch_matmul".
#   tests/python/relay/test_op_grad_level10.py
#       Changes the test input shapes from (1, 5) to (2, 5), adds
#       test_cross_entropy_with_logits_grad, and runs the file through
#       pytest.main.
#
# NOTE(review): nn.cc documents both arguments as "1D Tensor", while the
# tests pass shape (2, 5) and the compute divides by x.shape[0] -- confirm
# the intended rank and update the argument descriptions if needed.
# NOTE(review): the op name says "logits", but the compute applies no
# softmax or log to the predictions -- confirm callers are expected to pass
# log-probabilities.
#
diff --git a/python/tvm/relay/op/_reduce.py b/python/tvm/relay/op/_reduce.py
index f6b699f1e9cc..845ec4b9ba87 100644
--- a/python/tvm/relay/op/_reduce.py
+++ b/python/tvm/relay/op/_reduce.py
@@ -37,3 +37,4 @@ def _schedule_reduce(_, outs, target):
 _reg.register_schedule("mean", _schedule_reduce)
 _reg.register_schedule("variance", _schedule_reduce)
 _reg.register_schedule("nn.cross_entropy", _schedule_reduce)
+_reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce)
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 1c94162d87d9..d55cad7c7a2d 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -449,3 +449,12 @@ def cross_entropy_grad(orig, grad):
     batch_size = take(shape, const(0, dtype='int32'), axis=0)
     grad = grad / batch_size.astype('float32')
     return [-grad * y / x, -grad * log(x)]
+
+
+@register_gradient("nn.cross_entropy_with_logits")
+def cross_entropy_with_logits_grad(orig, grad):
+    x, y = orig.args
+    shape = shape_of(x)
+    batch_size = take(shape, const(0, dtype='int32'), axis=0)
+    grad = grad / batch_size.astype('float32')
+    return [-grad * y, -grad * x]
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 0043ffae0f61..5786c228abc0 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -770,3 +770,12 @@ def schedule_bitserial_dense(attrs, outputs, target):
 def compute_cross_entropy(attrs, inputs, out_dtype, target):
     x, y = inputs
     return [-topi.sum(topi.log(x) * y) / x.shape[0]]
+
+
+reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE)
+
+
+@reg.register_compute("nn.cross_entropy_with_logits")
+def compute_cross_entropy_with_logits(attrs, inputs, out_dtype, target):
+    x, y = inputs
+    return [-topi.sum(x * y) / x.shape[0]]
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 9ddb3ece4ce2..1f289d1bd27a 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -1807,3 +1807,22 @@ def cross_entropy(predictions, targets):
       The computed result.
     """
     return _make.cross_entropy(predictions, targets)
+
+
+def cross_entropy_with_logits(predictions, targets):
+    """Cross entropy computed as -sum(predictions * targets) / predictions.shape[0].
+
+    Parameters
+    ----------
+    predictions : tvm.relay.Expr
+      The predictions; used as-is (no log is taken, unlike nn.cross_entropy).
+
+    targets : tvm.relay.Expr
+      The targets, multiplied elementwise with the predictions.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+      The negated sum of predictions * targets divided by predictions.shape[0].
+    """
+    return _make.cross_entropy_with_logits(predictions, targets)
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index dd1b4e532185..416a0d7b543f 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -910,7 +910,7 @@ bool CrossEntropyRel(const Array& types,
   return true;
 }
 
-// Positional relay function to create batch_matmul operator used by frontend FFI.
+// Positional relay function to create cross_entropy operator used by frontend FFI.
 Expr MakeCrossEntropy(Expr predictions, Expr targets) {
   static const Op& op = Op::Get("nn.cross_entropy");
   return CallNode::make(op, {predictions, targets}, Attrs(), {});
@@ -933,5 +933,28 @@ Do log on the data - do not accept logits.
 .add_type_rel("CrossEntropy", CrossEntropyRel);
 
 
+// Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
+Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) {
+  static const Op& op = Op::Get("nn.cross_entropy_with_logits");
+  return CallNode::make(op, {predictions, targets}, Attrs(), {});
+}
+
+
+TVM_REGISTER_API("relay.op.nn._make.cross_entropy_with_logits")
+.set_body_typed(MakeCrossEntropyWithLogits);
+
+
+RELAY_REGISTER_OP("nn.cross_entropy_with_logits")
+.describe(R"code(
+Computes cross entropy given predictions and targets.
+Accept logits.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("x", "1D Tensor", "Predictions.")
+.add_argument("y", "1D Tensor", "Targets.")
+.set_support_level(10)
+.add_type_rel("CrossEntropy", CrossEntropyRel);
+
+
 } // namespace relay
 } // namespace tvm
diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py
index 2592d181240a..7aa9e0bc135f 100644
--- a/tests/python/relay/test_op_grad_level10.py
+++ b/tests/python/relay/test_op_grad_level10.py
@@ -14,15 +14,23 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
+
 from tvm import relay
 from tvm.relay.testing import check_grad
 
 
 def test_cross_entropy_grad():
-    x = relay.var("x", shape=(1, 5))
-    y = relay.var("y", shape=(1, 5))
+    x = relay.var("x", shape=(2, 5))
+    y = relay.var("y", shape=(2, 5))
     check_grad(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)), eps=0.01, scale=0.1, mean=1)
 
 
+def test_cross_entropy_with_logits_grad():
+    x = relay.var("x", shape=(2, 5))
+    y = relay.var("y", shape=(2, 5))
+    check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1)
+
+
 if __name__ == "__main__":
-    test_cross_entropy_grad()
+    pytest.main([__file__])