From f30eff7d864b4fbcd6f3eb4e4feeb1f2c49d2725 Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Thu, 21 Mar 2024 20:55:37 +0100 Subject: [PATCH] more fun with DALI_ENFORCE Signed-off-by: Krzysztof Lecki --- dali/c_api/c_api.cc | 10 +++++----- dali/pipeline/operator/op_spec.h | 1 + include/dali/core/error_handling.h | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/dali/c_api/c_api.cc b/dali/c_api/c_api.cc index 5116ae87a43..aed1c5e3f0a 100644 --- a/dali/c_api/c_api.cc +++ b/dali/c_api/c_api.cc @@ -97,7 +97,7 @@ int PopCurrBatchSize(batch_size_map_t *batch_size_map, int max_batch_size, */ dali::InputOperatorNoCopyMode GetExternalSourceCopyMode(unsigned int flags) { dali::InputOperatorNoCopyMode no_copy_mode = dali::InputOperatorNoCopyMode::DEFAULT; - DALI_ENFORCE(!((flags & DALI_ext_force_copy) && (flags & DALI_ext_force_no_copy)), + dali::DALI_ENFORCE(!((flags & DALI_ext_force_copy) && (flags & DALI_ext_force_no_copy)), "External Source cannot be forced to use DALI_ext_force_copy and " "DALI_ext_force_no_copy at the same time."); if (flags & DALI_ext_force_copy) { @@ -722,7 +722,7 @@ void daliLoadLibrary(const char* lib_path) { void daliGetReaderMetadata(daliPipelineHandle_t pipe_handle, const char *reader_name, daliReaderMetadata* meta) { - DALI_ENFORCE(meta, "Provided pointer to meta cannot be NULL."); + dali::DALI_ENFORCE(meta, "Provided pointer to meta cannot be NULL."); dali::Pipeline* pipeline = (*pipe_handle)->pipeline.get(); dali::ReaderMeta returned_meta = pipeline->GetReaderMeta(reader_name); meta->epoch_size = returned_meta.epoch_size; @@ -827,7 +827,7 @@ void daliGetSerializedCheckpoint( daliPipelineHandle_t pipe_handle, const daliExternalContextCheckpoint *external_context, char **checkpoint, size_t *n) { - DALI_ENFORCE(external_context, "Provided pointer to external context cannot be NULL."); + dali::DALI_ENFORCE(external_context, "Provided pointer to external context cannot be NULL."); auto &pipeline = (*pipe_handle)->pipeline; dali::ExternalContextCheckpoint ctx{}; if (external_context->pipeline_data.data) { @@ -845,7 +845,7 @@ void daliGetSerializedCheckpoint( std::string cpt = pipeline->SerializedCheckpoint(ctx); *n = cpt.size(); *checkpoint = reinterpret_cast(daliAlloc(cpt.size())); - DALI_ENFORCE(*checkpoint, "Failed to allocate memory"); + dali::DALI_ENFORCE(*checkpoint, "Failed to allocate memory"); memcpy(*checkpoint, cpt.c_str(), *n); } @@ -862,7 +862,7 @@ void daliRestoreFromSerializedCheckpoint( daliPipelineHandle *pipe_handle, const char *checkpoint, size_t n, daliExternalContextCheckpoint *external_context) { - DALI_ENFORCE(external_context != nullptr, + dali::DALI_ENFORCE(external_context != nullptr, "Null external context provided."); auto &pipeline = (*pipe_handle)->pipeline; auto ctx = pipeline->RestoreFromSerializedCheckpoint({checkpoint, n}); diff --git a/dali/pipeline/operator/op_spec.h b/dali/pipeline/operator/op_spec.h index 2c524316f67..3a9c8a325e5 100644 --- a/dali/pipeline/operator/op_spec.h +++ b/dali/pipeline/operator/op_spec.h @@ -188,6 +188,7 @@ class DLL_PUBLIC OpSpec { */ DLL_PUBLIC inline void EnforceNoAliasWithDeprecated(const string& arg_name) { auto set_through = set_through_deprecated_arguments_.find(arg_name); + // TODO: OMG - lazy evaluation of when this is macro ): DALI_ENFORCE( set_through == set_through_deprecated_arguments_.end(), make_string("Operator ", SchemaName(), " got an unexpected '", set_through->second, diff --git a/include/dali/core/error_handling.h b/include/dali/core/error_handling.h index b91ead3c647..43c4dc5aa23 100644 --- a/include/dali/core/error_handling.h +++ b/include/dali/core/error_handling.h @@ -231,7 +231,7 @@ inline dali::string GetStacktrace() { // #define DALI_ENFORCE(...) GET_MACRO(__VA_ARGS__, ENFRC_2, ENFRC_1)(__VA_ARGS__) template -void DALI_ENFORCE(T condition, const std::string &error_string = "", +void DALI_ENFORCE(const T &condition, const std::string &error_string = "", source_location loc = source_location::current()) { if (!condition) { throw DALIException(error_string, make_string("[", loc.source_file(), ":", loc.line(), "]"),