Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
* clarify test file brief
* add test case for running status
* add driver stack reference to WaitStatus class

Change-Id: I792742892b761534904816135ae2ffcb3f028b2c
  • Loading branch information
lhutton1 committed Oct 12, 2022
1 parent 3eff57a commit 6154813
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 38 deletions.
27 changes: 15 additions & 12 deletions src/runtime/contrib/ethosn/ethosn_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace ethosn {

namespace dl = ::ethosn::driver_library;

WaitStatus WaitForInference(dl::Inference* inference, int timeout) {
InferenceWaitStatus WaitForInference(dl::Inference* inference, int timeout) {
// Wait for inference to complete
int fd = inference->GetFileDescriptor();
struct pollfd fds;
Expand All @@ -62,26 +62,29 @@ WaitStatus WaitForInference(dl::Inference* inference, int timeout) {
int poll_error_code = errno;

if (poll_result < 0) {
return WaitStatus(WaitErrorCode::Error, "Error while waiting for the inference to complete (" +
std::string(strerror(poll_error_code)) + ")");
return InferenceWaitStatus(InferenceWaitErrorCode::kError,
"Error while waiting for the inference to complete (" +
std::string(strerror(poll_error_code)) + ")");
} else if (poll_result == 0) {
return WaitStatus(WaitErrorCode::Timeout,
"Timed out while waiting for the inference to complete.");
return InferenceWaitStatus(InferenceWaitErrorCode::kTimeout,
"Timed out while waiting for the inference to complete.");
}

// poll_result > 0
dl::InferenceResult npu_result;
if (read(fd, &npu_result, sizeof(npu_result)) != static_cast<ssize_t>(sizeof(npu_result))) {
return WaitStatus(WaitErrorCode::Error, "Failed to read inference result status (" +
std::string(strerror(poll_error_code)) + ")");
return InferenceWaitStatus(
InferenceWaitErrorCode::kError,
"Failed to read inference result status (" + std::string(strerror(poll_error_code)) + ")");
}

if (npu_result != dl::InferenceResult::Completed) {
return WaitStatus(WaitErrorCode::Error, "Inference failed with status " +
std::to_string(static_cast<uint32_t>(npu_result)));
return InferenceWaitStatus(
InferenceWaitErrorCode::kError,
"Inference failed with status " + std::to_string(static_cast<uint32_t>(npu_result)));
}

return WaitStatus(WaitErrorCode::Success);
return InferenceWaitStatus(InferenceWaitErrorCode::kSuccess);
}

void CreateBuffers(std::vector<std::shared_ptr<dl::Buffer>>* fm,
Expand Down Expand Up @@ -135,9 +138,9 @@ bool Inference(tvm::runtime::TVMArgs args, dl::Network* npu,
// Execute the inference.
std::unique_ptr<dl::Inference> inference(
npu->ScheduleInference(ifm_raw, n_inputs, ofm_raw, n_outputs));
WaitStatus result = WaitForInference(inference.get(), 60);
InferenceWaitStatus result = WaitForInference(inference.get(), 60);

if (result.GetErrorCode() != WaitErrorCode::Success) {
if (result.GetErrorCode() != InferenceWaitErrorCode::kSuccess) {
LOG(FATAL) << "An error has occured waiting for the inference of a sub-graph on the NPU: "
<< result.GetErrorDescription();
}
Expand Down
27 changes: 14 additions & 13 deletions src/runtime/contrib/ethosn/ethosn_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,32 +110,33 @@ class EthosnModule : public ModuleNode {
/*!
* \brief Error codes for evaluating the result of inference on the NPU.
*/
enum class WaitErrorCode { Success = 0, Timeout = 1, Error = 2 };
enum class InferenceWaitErrorCode { kSuccess = 0, kTimeout = 1, kError = 2 };

/*!
* \brief A helper class holding the status of inference on the NPU and
* associated error message(s) if any occurred.
*
* Similar to the implementation of 'WaitStatus' in the driver stack:
* https://github.com/ARM-software/ethos-n-driver-stack/blob/22.08/armnn-ethos-n-backend/workloads/EthosNPreCompiledWorkload.cpp#L48
*/
class WaitStatus {
class InferenceWaitStatus {
public:
WaitStatus() : error_code_(WaitErrorCode::Success), error_description_("") {}
InferenceWaitStatus() : error_code_(InferenceWaitErrorCode::kSuccess), error_description_("") {}

explicit WaitStatus(WaitErrorCode errorCode, std::string errorDescription = "")
explicit InferenceWaitStatus(InferenceWaitErrorCode errorCode, std::string errorDescription = "")
: error_code_(errorCode), error_description_(errorDescription) {}

WaitStatus(const WaitStatus&) = default;
WaitStatus(WaitStatus&&) = default;
WaitStatus& operator=(const WaitStatus&) = default;
WaitStatus& operator=(WaitStatus&&) = default;

explicit operator bool() const noexcept { return error_code_ == WaitErrorCode::Success; }

WaitErrorCode GetErrorCode() const { return error_code_; }
InferenceWaitStatus(const InferenceWaitStatus&) = default;
InferenceWaitStatus(InferenceWaitStatus&&) = default;
InferenceWaitStatus& operator=(const InferenceWaitStatus&) = default;
InferenceWaitStatus& operator=(InferenceWaitStatus&&) = default;

explicit operator bool() const { return error_code_ == InferenceWaitErrorCode::kSuccess; }
InferenceWaitErrorCode GetErrorCode() const { return error_code_; }
std::string GetErrorDescription() const { return error_description_; }

private:
WaitErrorCode error_code_;
InferenceWaitErrorCode error_code_;
std::string error_description_;
};

Expand Down
40 changes: 27 additions & 13 deletions tests/cpp/runtime/contrib/ethosn/inference_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

/*!
* \file tests/cpp/runtime/contrib/ethosn/inference_test.cc
* \brief Tests to check runtime components used during inference.
* \brief Tests to check Arm(R) Ethos(TM)-N runtime components used during inference.
*/

#ifdef ETHOSN_HW
Expand All @@ -32,27 +32,41 @@ namespace tvm {
namespace runtime {
namespace ethosn {

TEST(WaitForInference, FailedResultRead) {
const int inference_error = 3;
TEST(WaitForInference, InferenceScheduled) {
const int inference_result = 0 /* Scheduled */;
const int timeout = 0;
dl::Inference inference = dl::Inference(inference_error);
WaitStatus result = WaitForInference(&inference, timeout);

ASSERT_EQ(result.GetErrorCode(), WaitErrorCode::Error);
ICHECK_EQ(result.GetErrorDescription(),
"Failed to read inference result status (No such file or directory)");
dl::Inference inference = dl::Inference(inference_result);
InferenceWaitStatus result = WaitForInference(&inference, timeout);

ASSERT_EQ(result.GetErrorCode(), InferenceWaitErrorCode::kTimeout);
ICHECK_EQ(result.GetErrorDescription(), "Timed out while waiting for the inference to complete.");
}

TEST(WaitForInference, InferenceTimeout) {
const int inference_scheduled = 0;
TEST(WaitForInference, InferenceRunning) {
const int inference_result = 1 /* Running */;
const int timeout = 0;
dl::Inference inference = dl::Inference(inference_scheduled);
WaitStatus result = WaitForInference(&inference, timeout);

ASSERT_EQ(result.GetErrorCode(), WaitErrorCode::Timeout);
dl::Inference inference = dl::Inference(inference_result);
InferenceWaitStatus result = WaitForInference(&inference, timeout);

ASSERT_EQ(result.GetErrorCode(), InferenceWaitErrorCode::kTimeout);
std::cout << result.GetErrorDescription() << std::endl;
ICHECK_EQ(result.GetErrorDescription(), "Timed out while waiting for the inference to complete.");
}

TEST(WaitForInference, InferenceError) {
const int inference_result = 3 /* Error */;
const int timeout = 0;

dl::Inference inference = dl::Inference(inference_result);
InferenceWaitStatus result = WaitForInference(&inference, timeout);

ASSERT_EQ(result.GetErrorCode(), InferenceWaitErrorCode::kError);
ICHECK_EQ(result.GetErrorDescription(),
"Failed to read inference result status (No such file or directory)");
}

} // namespace ethosn
} // namespace runtime
} // namespace tvm
Expand Down

0 comments on commit 6154813

Please sign in to comment.