Skip to content

Commit

Permalink
fix(torch): errors in input connector are caught correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed May 13, 2024
1 parent 9bdf946 commit ac09c52
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 24 deletions.
105 changes: 81 additions & 24 deletions src/backends/torch/torchinputconns.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@
#include "utils/utils.hpp"
#include "utils/oatpp.hpp"

// Macro to catch an exception thrown in an omp parallel block
#define CATCH_PARALLEL_EXCEPTION(EXPR, EPTR) \
try \
{ \
EXPR; \
} \
catch (...) \
{ \
_Pragma("omp critical") \
{ \
EPTR = std::current_exception(); \
} \
}

namespace dd
{

Expand Down Expand Up @@ -306,6 +320,7 @@ namespace dd
= _db
&& TorchInputInterface::has_to_create_db(data_vec, _test_split);
bool shouldLoad = !_db || createDb;
std::exception_ptr eptr;

if (shouldLoad)
{
Expand Down Expand Up @@ -370,7 +385,12 @@ namespace dd
// Read data
#pragma omp parallel for ordered schedule(static, 1)
for (const std::pair<std::string, int> &lfile : lfiles)
_dataset.add_image_file(lfile.first, lfile.second);
CATCH_PARALLEL_EXCEPTION(
_dataset.add_image_file(lfile.first, lfile.second),
eptr);

dd_utils::rethrow_exception<InputConnectorBadParamException>(
eptr, this->_logger);

if (!_db)
// in case of db, test sets are already allocated in
Expand Down Expand Up @@ -403,8 +423,12 @@ namespace dd
#pragma omp parallel for ordered schedule(static, 1)
for (const std::pair<std::string, int> &lfile :
tests_lfiles[i])
_test_datasets[i].add_image_file(lfile.first,
lfile.second);
CATCH_PARALLEL_EXCEPTION(_test_datasets[i].add_image_file(
lfile.first, lfile.second),
eptr);

dd_utils::rethrow_exception<InputConnectorBadParamException>(
eptr, this->_logger);

// Write corresp file
std::ofstream correspf(_model_repo + "/" + _correspname,
Expand Down Expand Up @@ -454,9 +478,12 @@ namespace dd

#pragma omp parallel for ordered schedule(static, 1)
for (const std::pair<std::string, std::string> &lfile : lfiles)
{
_dataset.add_image_bbox_file(lfile.first, lfile.second);
}
CATCH_PARALLEL_EXCEPTION(
_dataset.add_image_bbox_file(lfile.first, lfile.second),
eptr);

dd_utils::rethrow_exception<InputConnectorBadParamException>(
eptr, this->_logger);

// in case of db, alloc of test sets already done in
// has_to_create_db
Expand All @@ -469,8 +496,13 @@ namespace dd
#pragma omp parallel for ordered schedule(static, 1)
for (const std::pair<std::string, std::string> &lfile :
tests_lfiles[i])
_test_datasets[i].add_image_bbox_file(lfile.first,
lfile.second);
CATCH_PARALLEL_EXCEPTION(
_test_datasets[i].add_image_bbox_file(lfile.first,
lfile.second),
eptr);

dd_utils::rethrow_exception<InputConnectorBadParamException>(
eptr, this->_logger);
}
else if (_segmentation) // expects a file list of image filepath
// and target image filepath
Expand Down Expand Up @@ -509,9 +541,12 @@ namespace dd

#pragma omp parallel for ordered schedule(static, 1)
for (const std::pair<std::string, std::string> &lfile : lfiles)
{
_dataset.add_image_image_file(lfile.first, lfile.second);
}
CATCH_PARALLEL_EXCEPTION(
_dataset.add_image_image_file(lfile.first, lfile.second),
eptr);

dd_utils::rethrow_exception<InputConnectorBadParamException>(
eptr, this->_logger);

// in case of db, alloc of test sets already done in
// has_to_create_db
Expand All @@ -524,8 +559,13 @@ namespace dd
#pragma omp parallel for ordered schedule(static, 1)
for (const std::pair<std::string, std::string> &lfile :
tests_lfiles[i])
_test_datasets[i].add_image_image_file(lfile.first,
lfile.second);
CATCH_PARALLEL_EXCEPTION(
_test_datasets[i].add_image_image_file(lfile.first,
lfile.second),
eptr);

dd_utils::rethrow_exception<InputConnectorBadParamException>(
eptr, this->_logger);
}
else if (_ctc)
{
Expand Down Expand Up @@ -576,10 +616,13 @@ namespace dd

#pragma omp parallel for ordered schedule(static, 1)
for (const std::pair<std::string, std::string> &lfile : lfiles)
{
_dataset.add_image_text_file(lfile.first, lfile.second,
alphabet, max_ocr_length);
}
CATCH_PARALLEL_EXCEPTION(
_dataset.add_image_text_file(lfile.first, lfile.second,
alphabet, max_ocr_length),
eptr);

dd_utils::rethrow_exception<InputConnectorBadParamException>(
eptr, this->_logger);

// in case of db, alloc of test sets already done in
// has_to_create_db
Expand All @@ -592,8 +635,14 @@ namespace dd
#pragma omp parallel for ordered schedule(static, 1)
for (const std::pair<std::string, std::string> &lfile :
tests_lfiles[i])
_test_datasets[i].add_image_text_file(
lfile.first, lfile.second, alphabet, max_ocr_length);
CATCH_PARALLEL_EXCEPTION(
_test_datasets[i].add_image_text_file(
lfile.first, lfile.second, alphabet,
max_ocr_length),
eptr);

dd_utils::rethrow_exception<InputConnectorBadParamException>(
eptr, this->_logger);

// Write corresp file
std::ofstream correspf(_model_repo + "/" + _correspname,
Expand Down Expand Up @@ -654,9 +703,12 @@ namespace dd
#pragma omp parallel for ordered schedule(static, 1)
for (const std::pair<std::string, std::vector<double>>
&lfile : lfiles)
{
_dataset.add_image_file(lfile.first, lfile.second);
}
CATCH_PARALLEL_EXCEPTION(
_dataset.add_image_file(lfile.first, lfile.second),
eptr);

dd_utils::rethrow_exception<
InputConnectorBadParamException>(eptr, this->_logger);

// in case of db, alloc of test sets already done in
// has_to_create_db
Expand All @@ -665,8 +717,13 @@ namespace dd
#pragma omp parallel for ordered schedule(static, 1)
for (const std::pair<std::string, std::vector<double>>
&lfile : tests_lfiles[i])
_test_datasets[i].add_image_file(lfile.first,
lfile.second);
CATCH_PARALLEL_EXCEPTION(
_test_datasets[i].add_image_file(lfile.first,
lfile.second),
eptr);

dd_utils::rethrow_exception<
InputConnectorBadParamException>(eptr, this->_logger);
}
else
{
Expand Down
19 changes: 19 additions & 0 deletions src/utils/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

#include "apidata.h"
#include "dd_types.h"
#include "dd_spdlog.h"

namespace dd
{
Expand Down Expand Up @@ -169,6 +170,24 @@ namespace dd
return false;
}

template <class ExceptCls>
inline void rethrow_exception(std::exception_ptr &eptr,
std::shared_ptr<spdlog::logger> logger)
{
try
{
if (eptr)
{
std::rethrow_exception(eptr);
}
}
catch (const std::exception &e)
{
logger->error(std::string("Caught error: ") + e.what());
throw ExceptCls(std::string("Caught error: ") + e.what());
}
}

#ifdef WIN32
inline int my_hardware_concurrency()
{
Expand Down

0 comments on commit ac09c52

Please sign in to comment.