Skip to content

Commit

Permalink
Merge pull request microsoft#23 from NonStatic2014/bohu/generate_prot…
Browse files Browse the repository at this point in the history
…obuf_status

Create a mapping from ONNX Runtime Status to Protobuf Status
  • Loading branch information
NonStatic2014 authored Apr 3, 2019
2 parents 094aca2 + f44e907 commit 05847b0
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 9 deletions.
4 changes: 3 additions & 1 deletion cmake/onnxruntime_hosting.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ file(GLOB_RECURSE onnxruntime_hosting_lib_srcs
"${ONNXRUNTIME_ROOT}/hosting/environment.cc"
"${ONNXRUNTIME_ROOT}/hosting/executor.cc"
"${ONNXRUNTIME_ROOT}/hosting/converter.cc"
)
"${ONNXRUNTIME_ROOT}/hosting/util.cc"
)
if(NOT WIN32)
if(HAS_UNUSED_PARAMETER)
set_source_files_properties(${ONNXRUNTIME_ROOT}/hosting/http/json_handling.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties(${ONNXRUNTIME_ROOT}/hosting/http/predict_request_handler.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties(${ONNXRUNTIME_ROOT}/hosting/executor.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties(${ONNXRUNTIME_ROOT}/hosting/converter.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties(${ONNXRUNTIME_ROOT}/hosting/util.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
endif()
endif()

Expand Down
14 changes: 6 additions & 8 deletions onnxruntime/hosting/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "converter.h"
#include "executor.h"
#include "util.h"

namespace onnxruntime {
namespace hosting {
Expand Down Expand Up @@ -53,8 +54,7 @@ protobufutil::Status Executor::Predict(const std::string& model_name, const std:
LOGS(*logger, ERROR) << "GetSizeInBytesFromTensorProto() FAILED! Input name: " << input_name
<< " Error code: " << status.Code()
<< ". Error Message: " << status.ErrorMessage();
return protobufutil::Status(static_cast<protobufutil::error::Code>(status.Code()),
"GetSizeInBytesFromTensorProto() FAILED: " + status.ErrorMessage());
return GenerateProtoBufStatus(status, "GetSizeInBytesFromTensorProto() FAILED: " + status.ErrorMessage());
}

std::unique_ptr<char[]> data(new char[cpu_tensor_length]);
Expand All @@ -74,8 +74,7 @@ protobufutil::Status Executor::Predict(const std::string& model_name, const std:
LOGS(*logger, ERROR) << "TensorProtoToMLValue() FAILED! Input name: " << input_name
<< " Error code: " << status.Code()
<< ". Error Message: " << status.ErrorMessage();
return protobufutil::Status(static_cast<protobufutil::error::Code>(status.Code()),
"TensorProtoToMLValue() FAILED: " + status.ErrorMessage());
return GenerateProtoBufStatus(status, "TensorProtoToMLValue() FAILED:" + status.ErrorMessage());
}

auto insertion_result = name_ml_value_map.insert(std::make_pair(input_name, ml_value));
Expand Down Expand Up @@ -104,19 +103,18 @@ protobufutil::Status Executor::Predict(const std::string& model_name, const std:
LOGS(*logger, ERROR) << "Run() FAILED!"
<< " Error code: " << status.Code()
<< ". Error Message: " << status.ErrorMessage();
return protobufutil::Status(static_cast<protobufutil::error::Code>(status.Code()),
"Run() FAILED!" + status.ErrorMessage());
return GenerateProtoBufStatus(status, "Run() FAILED: " + status.ErrorMessage());
}

// Build the response
for (size_t i = 0; i < outputs.size(); ++i) {
onnx::TensorProto output_tensor{};
status = MLValueToTensorProto(outputs[i], using_raw_data, std::move(logger), output_tensor);
if (!status.IsOK()) {
LOGS(*logger, ERROR) << "MLValue2TensorProto() FAILED! Output name: " << output_names[i]
LOGS(*logger, ERROR) << "MLValueToTensorProto() FAILED! Output name: " << output_names[i]
<< " Error code: " << status.Code()
<< ". Error Message: " << status.ErrorMessage();
return protobufutil::Status(static_cast<protobufutil::error::Code>(status.Code()), "MLValue2TensorProto() FAILED!");
return GenerateProtoBufStatus(status, "MLValueToTensorProto() FAILED: " + status.ErrorMessage());
}

response.mutable_outputs()->insert({output_names[i], output_tensor});
Expand Down
48 changes: 48 additions & 0 deletions onnxruntime/hosting/util.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <sstream>
#include <google/protobuf/stubs/status.h>

#include "core/common/status.h"
#include "util.h"

namespace onnxruntime {
namespace hosting {

namespace protobufutil = google::protobuf::util;

protobufutil::Status GenerateProtoBufStatus(const onnxruntime::common::Status& onnx_status, const std::string& message) {
protobufutil::error::Code code = protobufutil::error::Code::UNKNOWN;
switch (onnx_status.Code()) {
case onnxruntime::common::StatusCode::OK:
case onnxruntime::common::StatusCode::MODEL_LOADED:
code = protobufutil::error::Code::OK;
break;
case onnxruntime::common::StatusCode::INVALID_ARGUMENT:
case onnxruntime::common::StatusCode::INVALID_PROTOBUF:
case onnxruntime::common::StatusCode::INVALID_GRAPH:
case onnxruntime::common::StatusCode::SHAPE_INFERENCE_NOT_REGISTERED:
case onnxruntime::common::StatusCode::REQUIREMENT_NOT_REGISTERED:
case onnxruntime::common::StatusCode::NO_SUCHFILE:
case onnxruntime::common::StatusCode::NO_MODEL:
code = protobufutil::error::Code::INVALID_ARGUMENT;
break;
case onnxruntime::common::StatusCode::NOT_IMPLEMENTED:
code = protobufutil::error::Code::UNIMPLEMENTED;
break;
case onnxruntime::common::StatusCode::FAIL:
case onnxruntime::common::StatusCode::RUNTIME_EXCEPTION:
code = protobufutil::error::Code::INTERNAL;
break;
default:
code = protobufutil::error::Code::UNKNOWN;
}

std::ostringstream oss;
oss << "ONNX Runtime Status Code: " << onnx_status.Code() << ". " << message;
return protobufutil::Status(code, oss.str());
}

} // namespace hosting
} // namespace onnxruntime
20 changes: 20 additions & 0 deletions onnxruntime/hosting/util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifndef ONNXRUNTIME_HOSTING_UTIL_H
#define ONNXRUNTIME_HOSTING_UTIL_H

#include <google/protobuf/stubs/status.h>

#include "core/common/status.h"

namespace onnxruntime {
namespace hosting {

// Generate proper protobuf status from ONNX Runtime status
google::protobuf::util::Status GenerateProtoBufStatus(const onnxruntime::common::Status& onnx_status, const std::string& message);

} // namespace hosting
} // namespace onnxruntime

#endif //ONNXRUNTIME_HOSTING_UTIL_H

0 comments on commit 05847b0

Please sign in to comment.