diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b7b9d23129cd..0231fb640407 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2856,6 +2856,19 @@ def dot(self, inputs, _): lhs, rhs = inputs return _op.sum(_op.multiply(lhs, rhs)) + def mv(self, inputs, _): + lhs, rhs = inputs + + # Convert the 1D matrix (vector) into a 2D matrix with the extra + # dimension=1 + rhs_matrix = _op.transform.expand_dims(rhs, 0) + + # Run multiplication + dense_result = _op.nn.dense(lhs, rhs_matrix, units=None) + + # Chop off the extra result dimension + return _op.transform.squeeze(dense_result) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -3081,6 +3094,7 @@ def create_convert_map(self): "aten::roll": self.roll, "aten::einsum": self.einsum, "aten::dot": self.dot, + "aten::mv": self.mv, } def update_convert_map(self, custom_map): diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index b9677d198eba..336575a93e97 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -237,7 +237,7 @@ void ReportAt(const DiagnosticContext& context, std::ostream& out, const Span& s if (it == context->module->source_map->source_map.end()) { LOG(FATAL) << "The source maps are not populated for this module. " << "Please use `tvm.relay.transform.AnnotateSpans` to attach source maps for error " - "reporting. " + "reporting.\n" << "Error: " << diagnostic->message; } diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 1421906a3bbb..d40eb8a17c06 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -258,10 +258,12 @@ class TypeSolver::Unifier : public TypeFunctor { if (mismatches.size() != 0) { auto err = Diagnostic::Error(this->span); - err << "The Relay type checker is unable to show the following types match.\n"; - err << "In particular "; + err << "The Relay type checker is unable to show the following types match:\n" + << " " << PrettyPrint(tt1) << "\n" + << " " << PrettyPrint(tt2) << "\n"; + err << "In particular:\n"; for (auto mismatch : mismatches) { - err << "dimension " << std::get<0>(mismatch) << " conflicts: " << std::get<1>(mismatch) + err << " dimension " << std::get<0>(mismatch) << " conflicts: " << std::get<1>(mismatch) << " does not match " << std::get<2>(mismatch) << "."; } this->solver_->Emit(err); diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 07a6171deac5..1f69f20248bc 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4096,5 +4096,15 @@ def test_fn(x): verify_model(test_fn, [x]) +@tvm.testing.uses_gpu +def test_mv(): + def test_fn(m, v): + return m.mv(v) + + verify_model(test_fn, [torch.randn(4, 4), torch.randn(4)]) + verify_model(test_fn, [torch.randn(2, 2), torch.randn(2)]) + verify_model(test_fn, [torch.randn(3, 8), torch.randn(8)]) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 50ae3c7cfb0f..9ac04d4933a6 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -1418,10 +1418,14 @@ def _body(i, st): start = relay.var("start", shape=(), dtype="int32") body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1))) func = relay.Function([start], relay.TupleGetItem(body, 1)) + with DiagnosticTesting() as diagnostics: diagnostics.assert_message( - "The Relay type checker is unable to show the following types " - "match.\nIn particular dimension 0 conflicts: 2 does not match 1." + "The Relay type checker is unable to show the following types match:\n" + " Tensor[(2, 1), int32]\n" + " Tensor[(1, 1), int32]\n" + "In particular:\n" + " dimension 0 conflicts: 2 does not match 1." ) func = infer_type(func)