Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WAR for __has_include #124

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions jitify.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,73 @@ inline bool load_source(
comment;
}

// WAR where nvrtc can fail to correctly return if an include is present
size_t has_include_start = cleanline.find("__has_include");
if (has_include_start != std::string::npos) {
// find subsequent opening and closing braces
size_t open = cleanline.find("(", has_include_start);
size_t close = cleanline.find(")", open);
if (!(open == std::string::npos || close == std::string::npos)) {
std::string has_name = cleanline.substr(open + 1, close - open - 1);

if (has_name.find("<") == std::string::npos && has_name.find("\"") == std::string::npos)
throw std::runtime_error("Malformed __has_include statement (" +
filename + ":" + std::to_string(linenum) +
")");
// are we using quote includes or angle brackets?
// (test using angle brackets, since quotes are valid around angle brackets)
bool quote_include = has_name.find("<") == std::string::npos;
size_t header_start = (quote_include ? has_name.find("\"") : has_name.find("<")) + 1;
size_t header_count = has_name.find(quote_include ? "\"" : ">", header_start) - header_start;
if (has_name.find(quote_include ? "\"" : ">", header_start) == std::string::npos)
throw std::runtime_error("Malformed __has_include statement (" +
filename + ":" + std::to_string(linenum) +
")");

if (header_count != 0) {
std::string has_include_name =
has_name.substr(header_start, header_count);

#if JITIFY_PRINT_HEADER_PATHS
std::cout << "Found #if __has_include(" << has_name << ")"
<< " from " << filename << ":" << linenum << std::endl;
#endif
// Try loading from filesystem
bool found_file = false;
std::string has_include_fullpath =
path_join(current_dir, has_include_name);
if (quote_include) {
file_stream.open(has_include_fullpath.c_str());
if (file_stream) found_file = true;
}
// Search include directories
if (!found_file) {
for (int i = 0; i < (int)include_paths.size(); ++i) {
has_include_fullpath =
path_join(include_paths[i], has_include_name);
file_stream.open(has_include_fullpath.c_str());
if (file_stream) {
found_file = true;
break;
}
}
if (!found_file) {
// Try loading from builtin headers
has_include_fullpath =
path_join("__jitify_builtin", has_include_name);
auto it = get_jitsafe_headers_map().find(has_include_name);
if (it != get_jitsafe_headers_map().end()) {
found_file = true;
}
}
}

line = cleanline.substr(0, has_include_start) + (found_file ? "(1)" : "(0)") +
cleanline.substr(close + 1);
}
}
}

source += line + "\n";
}
// HACK TESTING (WAR for cub)
Expand Down
58 changes: 56 additions & 2 deletions jitify_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -961,8 +961,6 @@ static const char* const builtin_numeric_cuda_std_limits_program_source =
"builtin_numeric_cuda_std_limits_program\n"
"#include <climits>\n"
"#include <limits>\n"
"#include <cuda/std/climits>\n" // test fails without this explicit include
"#include <cuda/std/limits>\n"
"struct MyType {};\n"
"namespace cuda {\n"
"namespace std {\n"
Expand Down Expand Up @@ -1138,6 +1136,62 @@ TEST(JitifyTest, EnvVarOptions) {
setenv("JITIFY_OPTIONS", "", true);
}

static const char* const has_include_source = R"(
#if __has_include(<limits>)
#else
#error __has_include failed
#endif

#if 1 && __has_include(<limits>)
#else
#error __has_include failed
#endif

#if __has_include(<limits>) && 0
#error __has_include failed
#else
#endif

#if __has_include("<limits>")
#else
#error __has_include failed
#endif

#if __has_include("limits")
#else
#error __has_include failed
#endif

#if __has_include("example_headers/my_header1.cuh")
#else
#error __has_include failed
#endif

// check we don't touch these
#if defined(__has_include)
#else
#error __has_include failed
#endif

#if !defined(__has_include)
#error __has_include failed
#endif

__global__ void has_include_kernel() { }
)";

TEST(JitifyTest, HasInclude) {
// Checks that cassert works as expected
jitify::JitCache kernel_cache;
auto program = kernel_cache.program(has_include_source);
dim3 grid(1);
dim3 block(1);
CHECK_CUDA((program.kernel("has_include_kernel")
.instantiate<>()
.configure(grid, block)
.launch()));
}

// NOTE: This MUST be the last test in the file, due to sticky CUDA error.
TEST(JitifyTest, AssertHeader) {
// Checks that cassert works as expected
Expand Down