Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pyg.cuda_version() #4

Merged
merged 2 commits into from
Mar 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions pyg_lib/csrc/library.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include "library.h"

#ifdef USE_PYTHON
#include <Python.h>
#endif

#ifdef WITH_CUDA
#include <cuda.h>
#endif

#include <torch/library.h>

// If we are in a Windows environment, we need to define
// initialization functions for the _custom_ops extension.
// For PyMODINIT_FUNC to work, we need to include Python.h
#ifdef _WIN32
#ifdef USE_PYTHON
PyMODINIT_FUNC PyInit__C(void) { return NULL; }
#endif // USE_PYTHON
#endif // _WIN32

namespace pyg {

int64_t cuda_version() {
#ifdef WITH_CUDA
return CUDA_VERSION;
#else
return -1;
#endif
}

TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def("cuda_version", &cuda_version); }

} // namespace pyg
17 changes: 17 additions & 0 deletions pyg_lib/csrc/library.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once

#include "macros.h"

namespace pyg {

PYG_API int64_t cuda_version();

namespace detail {

extern "C" PYG_INLINE_VARIABLE auto _register_ops = &cuda_version;
#ifdef HINT_MSVC_LINKER_INCLUDE_SYMBOL
#pragma comment(linker, "/include:_register_ops")
#endif

} // namespace detail
} // namespace pyg
22 changes: 22 additions & 0 deletions pyg_lib/csrc/macros.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once

#ifdef _WIN32
#if defined(pyg_EXPORTS)
#define PYG_API __declspec(dllexport)
#else
#define PYG_API __declspec(dllimport)
#endif
#else
#define PYG_API
#endif

#if (defined __cpp_inline_variables) || __cplusplus >= 201703L
#define PYG_INLINE_VARIABLE inline
#else
#ifdef _MSC_VER
#define PYG_INLINE_VARIABLE __declspec(selectany)
#define HINT_MSVC_LINKER_INCLUDE_SYMBOL
#else
#define PYG_INLINE_VARIABLE __attribute__((weak))
#endif
#endif