Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport std::source_location with compiler builtins and make DALI_ENFORCE a function #5389

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions dali/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ int PopCurrBatchSize(batch_size_map_t *batch_size_map, int max_batch_size,
*/
dali::InputOperatorNoCopyMode GetExternalSourceCopyMode(unsigned int flags) {
dali::InputOperatorNoCopyMode no_copy_mode = dali::InputOperatorNoCopyMode::DEFAULT;
DALI_ENFORCE(!((flags & DALI_ext_force_copy) && (flags & DALI_ext_force_no_copy)),
dali::DALI_ENFORCE(!((flags & DALI_ext_force_copy) && (flags & DALI_ext_force_no_copy)),
"External Source cannot be forced to use DALI_ext_force_copy and "
"DALI_ext_force_no_copy at the same time.");
if (flags & DALI_ext_force_copy) {
Expand Down Expand Up @@ -722,7 +722,7 @@ void daliLoadLibrary(const char* lib_path) {

void daliGetReaderMetadata(daliPipelineHandle_t pipe_handle, const char *reader_name,
daliReaderMetadata* meta) {
DALI_ENFORCE(meta, "Provided pointer to meta cannot be NULL.");
dali::DALI_ENFORCE(meta, "Provided pointer to meta cannot be NULL.");
dali::Pipeline* pipeline = (*pipe_handle)->pipeline.get();
dali::ReaderMeta returned_meta = pipeline->GetReaderMeta(reader_name);
meta->epoch_size = returned_meta.epoch_size;
Expand Down Expand Up @@ -827,7 +827,7 @@ void daliGetSerializedCheckpoint(
daliPipelineHandle_t pipe_handle,
const daliExternalContextCheckpoint *external_context,
char **checkpoint, size_t *n) {
DALI_ENFORCE(external_context, "Provided pointer to external context cannot be NULL.");
dali::DALI_ENFORCE(external_context, "Provided pointer to external context cannot be NULL.");
auto &pipeline = (*pipe_handle)->pipeline;
dali::ExternalContextCheckpoint ctx{};
if (external_context->pipeline_data.data) {
Expand All @@ -845,7 +845,7 @@ void daliGetSerializedCheckpoint(
std::string cpt = pipeline->SerializedCheckpoint(ctx);
*n = cpt.size();
*checkpoint = reinterpret_cast<char *>(daliAlloc(cpt.size()));
DALI_ENFORCE(*checkpoint, "Failed to allocate memory");
dali::DALI_ENFORCE(*checkpoint, "Failed to allocate memory");
memcpy(*checkpoint, cpt.c_str(), *n);
}

Expand All @@ -862,7 +862,7 @@ void daliRestoreFromSerializedCheckpoint(
daliPipelineHandle *pipe_handle,
const char *checkpoint, size_t n,
daliExternalContextCheckpoint *external_context) {
DALI_ENFORCE(external_context != nullptr,
dali::DALI_ENFORCE(external_context != nullptr,
"Null external context provided.");
auto &pipeline = (*pipe_handle)->pipeline;
auto ctx = pipeline->RestoreFromSerializedCheckpoint({checkpoint, n});
Expand Down
63 changes: 63 additions & 0 deletions dali/core/source_location_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
#include "dali/core/source_location.h"



namespace dali {
namespace test {


source_location get(source_location loc = source_location::current()) {
return loc;
}

source_location indirect() {
return get();
}

void TestFunction() {
source_location current_loc = source_location::current();

std::cout << current_loc.source_file() << " " << current_loc.function_name() << " "
<< current_loc.line() << std::endl;

ASSERT_NE(strstr(current_loc.source_file(), "source_location_test.cc"), nullptr);
ASSERT_NE(strstr(current_loc.function_name(), "TestFunction"), nullptr);
ASSERT_NE(current_loc.line(), 0);

auto returned_loc = indirect();

std::cout << returned_loc.source_file() << " " << returned_loc.function_name() << " "
<< returned_loc.line() << std::endl;

ASSERT_STREQ(returned_loc.source_file(), current_loc.source_file());
ASSERT_NE(strstr(returned_loc.function_name(), "indirect"), nullptr);
ASSERT_GT(current_loc.line(), returned_loc.line());
}

TEST(SourceLocation, CurrentLocationTest) {
source_location default_loc;

ASSERT_STREQ(default_loc.source_file(), "");
ASSERT_STREQ(default_loc.function_name(), "");
ASSERT_EQ(default_loc.line(), 0);

TestFunction();
}

} // namespace test
} // namespace dali
86 changes: 86 additions & 0 deletions dali/core/source_location_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
#include "dali/core/source_location.h"
#include "dali/test/device_test.h"


namespace dali {
namespace test {

DALI_HOST_DEV int naive_strlen(const char *str) {
int i = 0;
while (str[i] != '\0') {
i++;
}
return i;
}

DALI_HOST_DEV bool compare(const char *str1, const char *str2, int len) {
for (int i = 0; i < len; i++) {
if (str1[i] != str2[i]) {
return false;
}
}
return true;
}

DALI_HOST_DEV bool naive_contains(const char *haystack, const char *needle) {
int haystack_len = naive_strlen(haystack);
int needle_len = naive_strlen(needle);
int search_start_pos = 0;
for (int start = 0; start <= haystack_len - needle_len; start++) {
if (compare(haystack + start, needle, needle_len)) {
return true;
}
}
return false;
}

DALI_HOST_DEV source_location GetDev(source_location loc = source_location::current()) {
return loc;
}

DALI_HOST_DEV source_location IndirectDev() {
return GetDev();
}

DEVICE_TEST(SourceLocationDev, CurrentLocationDevTest, dim3(1), dim3(1)) {
source_location default_loc;
printf("\"%s\":%d in \"%s\"\n", default_loc.source_file(), default_loc.line(),
default_loc.function_name());
DEV_ASSERT_EQ(default_loc.source_file()[0], '\0');
DEV_ASSERT_EQ(default_loc.function_name()[0], '\0');
DEV_ASSERT_EQ(default_loc.line(), 0);


source_location current_loc = source_location::current();
printf("\"%s\":%d in \"%s\"\n", current_loc.source_file(), current_loc.line(),
current_loc.function_name());
DEV_ASSERT_TRUE(naive_contains(current_loc.source_file(), "source_location_test.cu"));
DEV_ASSERT_TRUE(naive_contains(current_loc.function_name(), "CurrentLocationDevTest"));
DEV_ASSERT_NE(current_loc.line(), 0);


auto returned_loc = IndirectDev();
printf("\"%s\":%d in \"%s\"\n", returned_loc.source_file(), returned_loc.line(),
returned_loc.function_name());
DEV_ASSERT_TRUE(naive_contains(current_loc.source_file(), returned_loc.source_file()));
DEV_ASSERT_TRUE(naive_contains(returned_loc.function_name(), "IndirectDev"));
DEV_ASSERT_GT(current_loc.line(), returned_loc.line());
}

} // namespace test
} // namespace dali
1 change: 1 addition & 0 deletions dali/pipeline/operator/op_spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ class DLL_PUBLIC OpSpec {
*/
DLL_PUBLIC inline void EnforceNoAliasWithDeprecated(const string& arg_name) {
auto set_through = set_through_deprecated_arguments_.find(arg_name);
// TODO: OMG - lazy evaluation of when this is macro ):
DALI_ENFORCE(
set_through == set_through_deprecated_arguments_.end(),
make_string("Operator ", SchemaName(), " got an unexpected '", set_through->second,
Expand Down
12 changes: 11 additions & 1 deletion include/dali/core/error_handling.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <utility>

#include "dali/core/common.h"
#include "dali/core/source_location.h"

namespace dali {

Expand Down Expand Up @@ -227,7 +228,16 @@ inline dali::string GetStacktrace() {
} \
} while (0)

#define DALI_ENFORCE(...) GET_MACRO(__VA_ARGS__, ENFRC_2, ENFRC_1)(__VA_ARGS__)
// #define DALI_ENFORCE(...) GET_MACRO(__VA_ARGS__, ENFRC_2, ENFRC_1)(__VA_ARGS__)

template <typename T>
void DALI_ENFORCE(const T &condition, const std::string &error_string = "",
source_location loc = source_location::current()) {
if (!condition) {
throw DALIException(error_string, make_string("[", loc.source_file(), ":", loc.line(), "]"),
dali::GetStacktrace());
}
}

// Enforces that the value of 'var' is in the range [lower, upper)
#define DALI_ENFORCE_IN_RANGE(var, lower, upper) \
Expand Down
78 changes: 78 additions & 0 deletions include/dali/core/source_location.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_CORE_SOURCE_LOCATION_H_
#define DALI_CORE_SOURCE_LOCATION_H_

#include <array>
#include <cstddef>
#include <type_traits>
#include "dali/core/host_dev.h"

namespace dali {

/**
* @brief Backport of std::source_location from C++20, using compiler __builtins present in earlier
* versions. Allows to replace __FILE__ and __LINE__ macros with function calls.
*
* Call `source_location::current()` to get the current source location.
* In most cases it should be used as a default argument to a function, to record the source
* location of a function call, like so:
* void foo(source_location loc = source_location::current()) { ... }
*
*/
class source_location {
public:
DALI_HOST_DEV constexpr source_location() = default;
DALI_HOST_DEV constexpr source_location(const source_location &) = default;
DALI_HOST_DEV constexpr source_location &operator=(const source_location &) = default;
DALI_HOST_DEV constexpr source_location(source_location &&) = default;
DALI_HOST_DEV constexpr source_location &operator=(source_location &&) = default;


DALI_HOST_DEV constexpr const char *source_file() const {
return source_file_;
}

DALI_HOST_DEV constexpr const char *function_name() const {
return function_name_;
}

DALI_HOST_DEV constexpr int line() const {
return line_;
}

/**
* @brief Get the current source location.
* The caller of this function should not override the default arguments.
*/
DALI_HOST_DEV constexpr static source_location current(
const char *source_file = __builtin_FILE(), const char *function_name = __builtin_FUNCTION(),
int line_ = __builtin_LINE()) {
return {source_file, function_name, line_};
}

private:
DALI_HOST_DEV constexpr source_location(const char *source_file, const char *function_name,
int line)
: source_file_(source_file), function_name_(function_name), line_(line) {}
const char *source_file_ = "";
const char *function_name_ = "";
int line_ = 0;
};


} // namespace dali

#endif // DALI_CORE_SOURCE_LOCATION_H_
Loading