From 9206d247544dd34d91d33faacdd2adfe3fb83cbb Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Thu, 21 Mar 2024 20:30:35 +0100 Subject: [PATCH 1/3] Backport std::source_location with compiler builtins Signed-off-by: Krzysztof Lecki --- dali/core/source_location_test.cc | 63 +++++++++++++++++++++ dali/core/source_location_test.cu | 86 +++++++++++++++++++++++++++++ include/dali/core/source_location.h | 78 ++++++++++++++++++++++++++ 3 files changed, 227 insertions(+) create mode 100644 dali/core/source_location_test.cc create mode 100644 dali/core/source_location_test.cu create mode 100644 include/dali/core/source_location.h diff --git a/dali/core/source_location_test.cc b/dali/core/source_location_test.cc new file mode 100644 index 00000000000..2704854bade --- /dev/null +++ b/dali/core/source_location_test.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "dali/core/source_location.h" + + + +namespace dali { +namespace test { + + +source_location get(source_location loc = source_location::current()) { + return loc; +} + +source_location indirect() { + return get(); +} + +void TestFunction() { + source_location current_loc = source_location::current(); + + std::cout << current_loc.source_file() << " " << current_loc.function_name() << " " + << current_loc.line() << std::endl; + + ASSERT_NE(strstr(current_loc.source_file(), "source_location_test.cc"), nullptr); + ASSERT_NE(strstr(current_loc.function_name(), "TestFunction"), nullptr); + ASSERT_NE(current_loc.line(), 0); + + auto returned_loc = indirect(); + + std::cout << returned_loc.source_file() << " " << returned_loc.function_name() << " " + << returned_loc.line() << std::endl; + + ASSERT_STREQ(returned_loc.source_file(), current_loc.source_file()); + ASSERT_NE(strstr(returned_loc.function_name(), "indirect"), nullptr); + ASSERT_GT(current_loc.line(), returned_loc.line()); +} + +TEST(SourceLocation, CurrentLocationTest) { + source_location default_loc; + + ASSERT_STREQ(default_loc.source_file(), ""); + ASSERT_STREQ(default_loc.function_name(), ""); + ASSERT_EQ(default_loc.line(), 0); + + TestFunction(); +} + +} // namespace test +} // namespace dali diff --git a/dali/core/source_location_test.cu b/dali/core/source_location_test.cu new file mode 100644 index 00000000000..fa3cdedd8b5 --- /dev/null +++ b/dali/core/source_location_test.cu @@ -0,0 +1,86 @@ +// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "dali/core/source_location.h" +#include "dali/test/device_test.h" + + +namespace dali { +namespace test { + +DALI_HOST_DEV int naive_strlen(const char *str) { + int i = 0; + while (str[i] != '\0') { + i++; + } + return i; +} + +DALI_HOST_DEV bool compare(const char *str1, const char *str2, int len) { + for (int i = 0; i < len; i++) { + if (str1[i] != str2[i]) { + return false; + } + } + return true; +} + +DALI_HOST_DEV bool naive_contains(const char *haystack, const char *needle) { + int haystack_len = naive_strlen(haystack); + int needle_len = naive_strlen(needle); + int search_start_pos = 0; + for (int start = 0; start <= haystack_len - needle_len; start++) { + if (compare(haystack + start, needle, needle_len)) { + return true; + } + } + return false; +} + +DALI_HOST_DEV source_location GetDev(source_location loc = source_location::current()) { + return loc; +} + +DALI_HOST_DEV source_location IndirectDev() { + return GetDev(); +} + +DEVICE_TEST(SourceLocationDev, CurrentLocationDevTest, dim3(1), dim3(1)) { + source_location default_loc; + printf("\"%s\":%d in \"%s\"\n", default_loc.source_file(), default_loc.line(), + default_loc.function_name()); + DEV_ASSERT_EQ(default_loc.source_file()[0], '\0'); + DEV_ASSERT_EQ(default_loc.function_name()[0], '\0'); + DEV_ASSERT_EQ(default_loc.line(), 0); + + + source_location current_loc = source_location::current(); + printf("\"%s\":%d in \"%s\"\n", current_loc.source_file(), current_loc.line(), + current_loc.function_name()); + DEV_ASSERT_TRUE(naive_contains(current_loc.source_file(), "source_location_test.cu")); + DEV_ASSERT_TRUE(naive_contains(current_loc.function_name(), "CurrentLocationDevTest")); + DEV_ASSERT_NE(current_loc.line(), 0); + + + auto returned_loc = IndirectDev(); + printf("\"%s\":%d in \"%s\"\n", returned_loc.source_file(), returned_loc.line(), + returned_loc.function_name()); + DEV_ASSERT_TRUE(naive_contains(current_loc.source_file(), returned_loc.source_file())); + DEV_ASSERT_TRUE(naive_contains(returned_loc.function_name(), "IndirectDev")); + DEV_ASSERT_GT(current_loc.line(), returned_loc.line()); +} + +} // namespace test +} // namespace dali diff --git a/include/dali/core/source_location.h b/include/dali/core/source_location.h new file mode 100644 index 00000000000..991a63ad7f1 --- /dev/null +++ b/include/dali/core/source_location.h @@ -0,0 +1,78 @@ +// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_CORE_SOURCE_LOCATION_H_ +#define DALI_CORE_SOURCE_LOCATION_H_ + +#include +#include +#include +#include "dali/core/host_dev.h" + +namespace dali { + +/** + * @brief Backport of std::source_location from C++20, using compiler __builtins present in earlier + * versions. Allows to replace __FILE__ and __LINE__ macros with function calls. + * + * Call `source_location::current()` to get the current source location. + * In most cases it should be used as a default argument to a function, to record the source + * location of a function call, like so: + * void foo(source_location loc = source_location::current()) { ... } + * + */ +class source_location { + public: + DALI_HOST_DEV constexpr source_location() = default; + DALI_HOST_DEV constexpr source_location(const source_location &) = default; + DALI_HOST_DEV constexpr source_location &operator=(const source_location &) = default; + DALI_HOST_DEV constexpr source_location(source_location &&) = default; + DALI_HOST_DEV constexpr source_location &operator=(source_location &&) = default; + + + DALI_HOST_DEV constexpr const char *source_file() const { + return source_file_; + } + + DALI_HOST_DEV constexpr const char *function_name() const { + return function_name_; + } + + DALI_HOST_DEV constexpr int line() const { + return line_; + } + + /** + * @brief Get the current source location. + * The caller of this function should not override the default arguments. + */ + DALI_HOST_DEV constexpr static source_location current( + const char *source_file = __builtin_FILE(), const char *function_name = __builtin_FUNCTION(), + int line_ = __builtin_LINE()) { + return {source_file, function_name, line_}; + } + + private: + DALI_HOST_DEV constexpr source_location(const char *source_file, const char *function_name, + int line) + : source_file_(source_file), function_name_(function_name), line_(line) {} + const char *source_file_ = ""; + const char *function_name_ = ""; + int line_ = 0; +}; + + +} // namespace dali + +#endif // DALI_CORE_SOURCE_LOCATION_H_ From d63f66e253f611898191b1d56a70e92371f842bb Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Thu, 21 Mar 2024 20:48:05 +0100 Subject: [PATCH 2/3] Try DALI_ENFORCE as a function Signed-off-by: Krzysztof Lecki --- include/dali/core/error_handling.h | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/include/dali/core/error_handling.h b/include/dali/core/error_handling.h index 7ba9cdfdf5d..b91ead3c647 100644 --- a/include/dali/core/error_handling.h +++ b/include/dali/core/error_handling.h @@ -38,6 +38,7 @@ #include #include "dali/core/common.h" +#include "dali/core/source_location.h" namespace dali { @@ -227,7 +228,16 @@ inline dali::string GetStacktrace() { } \ } while (0) -#define DALI_ENFORCE(...) GET_MACRO(__VA_ARGS__, ENFRC_2, ENFRC_1)(__VA_ARGS__) +// #define DALI_ENFORCE(...) GET_MACRO(__VA_ARGS__, ENFRC_2, ENFRC_1)(__VA_ARGS__) + +template +void DALI_ENFORCE(T condition, const std::string &error_string = "", + source_location loc = source_location::current()) { + if (!condition) { + throw DALIException(error_string, make_string("[", loc.source_file(), ":", loc.line(), "]"), + dali::GetStacktrace()); + } +} // Enforces that the value of 'var' is in the range [lower, upper) #define DALI_ENFORCE_IN_RANGE(var, lower, upper) \ From f30eff7d864b4fbcd6f3eb4e4feeb1f2c49d2725 Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Thu, 21 Mar 2024 20:55:37 +0100 Subject: [PATCH 3/3] more fun with DALI_ENFORCE Signed-off-by: Krzysztof Lecki --- dali/c_api/c_api.cc | 10 +++++----- dali/pipeline/operator/op_spec.h | 1 + include/dali/core/error_handling.h | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/dali/c_api/c_api.cc b/dali/c_api/c_api.cc index 5116ae87a43..aed1c5e3f0a 100644 --- a/dali/c_api/c_api.cc +++ b/dali/c_api/c_api.cc @@ -97,7 +97,7 @@ int PopCurrBatchSize(batch_size_map_t *batch_size_map, int max_batch_size, */ dali::InputOperatorNoCopyMode GetExternalSourceCopyMode(unsigned int flags) { dali::InputOperatorNoCopyMode no_copy_mode = dali::InputOperatorNoCopyMode::DEFAULT; - DALI_ENFORCE(!((flags & DALI_ext_force_copy) && (flags & DALI_ext_force_no_copy)), + dali::DALI_ENFORCE(!((flags & DALI_ext_force_copy) && (flags & DALI_ext_force_no_copy)), "External Source cannot be forced to use DALI_ext_force_copy and " "DALI_ext_force_no_copy at the same time."); if (flags & DALI_ext_force_copy) { @@ -722,7 +722,7 @@ void daliLoadLibrary(const char* lib_path) { void daliGetReaderMetadata(daliPipelineHandle_t pipe_handle, const char *reader_name, daliReaderMetadata* meta) { - DALI_ENFORCE(meta, "Provided pointer to meta cannot be NULL."); + dali::DALI_ENFORCE(meta, "Provided pointer to meta cannot be NULL."); dali::Pipeline* pipeline = (*pipe_handle)->pipeline.get(); dali::ReaderMeta returned_meta = pipeline->GetReaderMeta(reader_name); meta->epoch_size = returned_meta.epoch_size; @@ -827,7 +827,7 @@ void daliGetSerializedCheckpoint( daliPipelineHandle_t pipe_handle, const daliExternalContextCheckpoint *external_context, char **checkpoint, size_t *n) { - DALI_ENFORCE(external_context, "Provided pointer to external context cannot be NULL."); + dali::DALI_ENFORCE(external_context, "Provided pointer to external context cannot be NULL."); auto &pipeline = (*pipe_handle)->pipeline; dali::ExternalContextCheckpoint ctx{}; if (external_context->pipeline_data.data) { @@ -845,7 +845,7 @@ void daliGetSerializedCheckpoint( std::string cpt = pipeline->SerializedCheckpoint(ctx); *n = cpt.size(); *checkpoint = reinterpret_cast(daliAlloc(cpt.size())); - DALI_ENFORCE(*checkpoint, "Failed to allocate memory"); + dali::DALI_ENFORCE(*checkpoint, "Failed to allocate memory"); memcpy(*checkpoint, cpt.c_str(), *n); } @@ -862,7 +862,7 @@ void daliRestoreFromSerializedCheckpoint( daliPipelineHandle *pipe_handle, const char *checkpoint, size_t n, daliExternalContextCheckpoint *external_context) { - DALI_ENFORCE(external_context != nullptr, + dali::DALI_ENFORCE(external_context != nullptr, "Null external context provided."); auto &pipeline = (*pipe_handle)->pipeline; auto ctx = pipeline->RestoreFromSerializedCheckpoint({checkpoint, n}); diff --git a/dali/pipeline/operator/op_spec.h b/dali/pipeline/operator/op_spec.h index 2c524316f67..3a9c8a325e5 100644 --- a/dali/pipeline/operator/op_spec.h +++ b/dali/pipeline/operator/op_spec.h @@ -188,6 +188,7 @@ class DLL_PUBLIC OpSpec { */ DLL_PUBLIC inline void EnforceNoAliasWithDeprecated(const string& arg_name) { auto set_through = set_through_deprecated_arguments_.find(arg_name); + // TODO: OMG - lazy evaluation of when this is macro ): DALI_ENFORCE( set_through == set_through_deprecated_arguments_.end(), make_string("Operator ", SchemaName(), " got an unexpected '", set_through->second, diff --git a/include/dali/core/error_handling.h b/include/dali/core/error_handling.h index b91ead3c647..43c4dc5aa23 100644 --- a/include/dali/core/error_handling.h +++ b/include/dali/core/error_handling.h @@ -231,7 +231,7 @@ inline dali::string GetStacktrace() { // #define DALI_ENFORCE(...) GET_MACRO(__VA_ARGS__, ENFRC_2, ENFRC_1)(__VA_ARGS__) template -void DALI_ENFORCE(T condition, const std::string &error_string = "", +void DALI_ENFORCE(const T &condition, const std::string &error_string = "", source_location loc = source_location::current()) { if (!condition) { throw DALIException(error_string, make_string("[", loc.source_file(), ":", loc.line(), "]"),