Re-apply windows diff D4657831
Summary:
(Note: the previous revert was due to a race condition between D4657831 and
D4659953 that I failed to catch.)

After this, we should have contbuild guarding the Windows build both with
and without CUDA.

This includes a series of changes needed to make the Windows build work,
specifically:

(1) Various flags that are needed in the CMake system, especially those dealing
with /MD vs. /MT, CUDA, cuDNN, whole-archive static linking, etc.
(2) Contbuild scripts based on AppVeyor.
(3) For the Windows build, note that one will need to use "cmake --build" to
build, so that the build type is consistent between configuration and the
actual build. See scripts\build_windows.bat for details.
(4) In logging.h, ERROR is already defined by Windows. I don't have a good
solution for now, and as a result, LOG(ERROR) on Windows is going to be
LOG(INFO).
(5) Variable-length arrays are not supported by MSVC (and they are not part of
the C++ standard), so I replaced them with vectors; see the sketch after this
list.
(6) sched.h is not available on Windows, so akyrola's awesome simple
async net might encounter some slowdown because no CPU affinity is set on
Windows.
(7) MSVC has a bug handling template calls inside a templated function call,
a known issue that should be fixed in MSVC 2017. For now this means changes to
conv_op_impl.h and recurrent_net_op.h. No actual functionality is changed.
(8) std host function calls are not supported with CUDA 8 + MSVC, so I changed
lp_pool (and maybe a few others) to use CUDA device functions.
(9) The current Scale and Axpy have heavy templating that does not work
well with MSVC. As a result I reverted azzolini's changes to the Scale
and Axpy interface and moved the fixed-length versions to ScaleFixedSize and
AxpyFixedSize.
(10) CUDA + MSVC does not deal with Eigen well, so I guarded all Eigen
usage to the non-CUDA parts only.
(11) In conclusion, it is fun but painful to deal with Visual C++.
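
As a concrete illustration of point (5), the pattern used throughout this diff
replaces each variable-length array with a std::vector plus a raw-pointer
alias, so the surrounding code can stay unchanged. A minimal sketch (the
function and buffer names are illustrative, not taken from the diff):

#include <algorithm>
#include <string>
#include <vector>

void process(const std::string& data) {
  // Before (a VLA, rejected by MSVC): char buffer[data.size()];
  // After: a std::vector provides equivalent contiguous storage.
  std::vector<char> buffer_vec(data.size());
  char* buffer = buffer_vec.data();  // raw pointer for existing C-style APIs
  std::copy(data.begin(), data.end(), buffer);
}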

Differential Revision: D4666745

fbshipit-source-id: 3c9035083067bdb19a16d9c345c1ce66b6a86600
Yangqing Jia authored and facebook-github-bot committed Mar 7, 2017
1 parent 0889013 commit e0ee9e9
Showing 32 changed files with 372 additions and 253 deletions.
16 changes: 11 additions & 5 deletions CMakeLists.txt
@@ -83,13 +83,19 @@ endif()
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "binaries")

# ---[ Build flags
if (${CMAKE_CXX_COMPILER_ID} STREQUAL "MSVC")
message(WARNING "Develop note: when all errors are addressed, turn on warning.")
message(STATUS "Adding no warning argument to the compiler")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /w")
else()
if(NOT MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing")
else()
if (NOT ${BUILD_SHARED_LIBS})
foreach(flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)
if(${flag_var} MATCHES "/MD")
string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
endif(${flag_var} MATCHES "/MD")
endforeach(flag_var)
endif()
endif()

if (CAFFE2_CPU_FLAGS)
4 changes: 3 additions & 1 deletion README.md
@@ -8,7 +8,9 @@ Caffe2 is released under the [BSD 2-Clause license](https://github.com/Yangqing/

## Building Caffe2

[![Build Status](https://travis-ci.org/caffe2/caffe2.svg?branch=master)](https://travis-ci.org/caffe2/caffe2)
[![Travis Build Status](https://travis-ci.org/caffe2/caffe2.svg?branch=master)](https://travis-ci.org/caffe2/caffe2)

[![Windows Build status](https://ci.appveyor.com/api/projects/status/kec4ta779stuyb83?svg=true)](https://ci.appveyor.com/project/Yangqing/caffe2)

git clone --recursive https://github.com/caffe2/caffe2.git
cd caffe2
30 changes: 30 additions & 0 deletions appveyor.yml
@@ -0,0 +1,30 @@
version: '{build}'
clone_folder: c:\projects\caffe2
environment:
  matrix:
    - USE_CUDA: OFF
      CMAKE_BUILD_TYPE: Release

    - USE_CUDA: ON
      CMAKE_BUILD_TYPE: Release

    - USE_CUDA: OFF
      CMAKE_BUILD_TYPE: Debug

    # Currently, CUDA + Debug does not work due to an error of using
    # std::_Debug_lt in device code. Not sure where this comes from yet,
    # but it is probably safe to assume that very few are going to build
    # debug mode with CUDA and Windows.
    #- USE_CUDA: ON
    #  CMAKE_BUILD_TYPE: Debug

install:
- cmd: c:\projects\caffe2\scripts\appveyor\install.bat

build_script:
- cmd: >-
    cd c:\projects\caffe2

    git submodule update --init

    call scripts\build_windows.bat
2 changes: 1 addition & 1 deletion caffe2/CMakeLists.txt
@@ -74,7 +74,7 @@ list(APPEND Caffe2_MAIN_LIBS_ORDER Caffe2_CPU)
if (BUILD_SHARED_LIBS)
list(APPEND Caffe2_MAIN_LIBS Caffe2_CPU)
else()
caffe_add_whole_archive_flag(Caffe2_CPU tmp)
caffe_add_whole_archive_flag(Caffe2_CPU tmp)
list(APPEND Caffe2_MAIN_LIBS ${tmp})
endif()

3 changes: 2 additions & 1 deletion caffe2/binaries/convert_caffe_image_db.cc
@@ -43,7 +43,8 @@ int main(int argc, char** argv) {
} else {
// float data not supported right now.
CAFFE_ENFORCE_EQ(datum.float_data_size(), 0);
char buffer[datum.data().size()];
std::vector<char> buffer_vec(datum.data().size());
char* buffer = buffer_vec.data();
// swap order from CHW to HWC
int channels = datum.channels();
int size = datum.height() * datum.width();
23 changes: 23 additions & 0 deletions caffe2/core/common.h
@@ -84,6 +84,29 @@ private: \
#define CAFFE2_ALIGNED(x) __attribute__((aligned(x)))
#endif

/**
* Macro for marking functions as having public visibility.
* Ported from folly/CPortability.h
*/
#ifndef __GNUC_PREREQ
#if defined __GNUC__ && defined __GNUC_MINOR__
#define __GNUC_PREREQ(maj, min) \
((__GNUC__ << 16) + __GNUC_MINOR__ >= ((maj) << 16) + (min))
#else
#define __GNUC_PREREQ(maj, min) 0
#endif
#endif

#if defined(__GNUC__)
#if __GNUC_PREREQ(4, 9)
#define CAFFE2_EXPORT [[gnu::visibility("default")]]
#else
#define CAFFE2_EXPORT __attribute__((__visibility__("default")))
#endif
#else
#define CAFFE2_EXPORT
#endif

// make_unique is a C++14 feature. If we don't have 14, we will emulate
// its behavior. This is copied from folly/Memory.h
#if __cplusplus >= 201402L || \
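
A note on the new CAFFE2_EXPORT macro above (my reading, not text from the
commit): MSVC does not accept the GCC-style [[gnu::visibility("default")]]
attribute, so the attribute is wrapped in a macro that expands to nothing on
compilers that do not define __GNUC__ (such as MSVC); the TypeMeta::Id()
declaration in caffe2/core/typeid.h below is switched over to it. A minimal,
hypothetical usage sketch:

// Default visibility on GCC >= 4.9, the __attribute__ spelling on older GCC,
// and nothing at all on compilers without __GNUC__ (e.g. MSVC).
CAFFE2_EXPORT int get_magic_number();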
3 changes: 1 addition & 2 deletions caffe2/core/logging.h
@@ -159,7 +159,6 @@ struct EnforceOK {};
class EnforceFailMessage {
public:
constexpr /* implicit */ EnforceFailMessage(EnforceOK) : msg_(nullptr) {}

EnforceFailMessage(EnforceFailMessage&&) = default;
EnforceFailMessage(const EnforceFailMessage&) = delete;
EnforceFailMessage& operator=(EnforceFailMessage&&) = delete;
@@ -180,7 +179,7 @@ class EnforceFailMessage {
msg_ = new std::string(std::move(msg));
}
inline bool bad() const {
return msg_;
return msg_ != nullptr;
}
std::string get_message_and_free(std::string&& extra) const {
std::string r;
6 changes: 6 additions & 0 deletions caffe2/core/logging_is_not_google_glog.h
@@ -15,7 +15,13 @@

// Log severity level constants.
const int FATAL = 3;
#if !defined(_MSC_VER) || !defined(ERROR)
// Windows defines the ERROR macro already, and as a result we will
// simply use that one. The downside is that one will now mix LOG(INFO)
// and LOG(ERROR) because ERROR is defined to be zero. Anyway, the
// recommended way is to use glog so fixing this is a low-pri item.
const int ERROR = 2;
#endif
const int WARNING = 1;
const int INFO = 0;
const char CAFFE2_SEVERITY_PREFIX[] = "FEWIV";
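
For clarity (not part of the diff): <windows.h> defines ERROR as a macro that
expands to 0, so when the ERROR severity constant is skipped here, LOG(ERROR)
picks up that macro and logs at severity 0, the same level as LOG(INFO), which
is the behavior described in point (4) of the commit summary. For example:

// On a Windows build where windows.h's ERROR macro (0) is in effect:
LOG(ERROR) << "this is emitted at severity 0, i.e. as an INFO-level message";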
12 changes: 11 additions & 1 deletion caffe2/core/net_gpu.cc
@@ -4,7 +4,10 @@
#include <mutex>
#include <stack>

#if !defined(_MSC_VER)
#include <sched.h>
#endif

#include "caffe2/core/common_gpu.h"
#include "caffe2/core/flags.h"
#include "caffe2/core/operator.h"
@@ -258,6 +261,10 @@ void GPUExecutor::Release(int gpu) {
}

void GPUExecutor::set_affinity() {
// TODO: find a Windows-compatible affinity setting approach.
// Currently, set_affinity has no effect in Windows. The code is still
// correct with possible slowdowns.
#if !defined(_MSC_VER)
/* Set CPU affinity */
int num_cores = std::thread::hardware_concurrency();
if (num_cores > 0) {
@@ -269,6 +276,7 @@ void GPUExecutor::set_affinity() {
LOG(WARNING) << "Could not set CPU affinity";
}
}
#endif
}

// Worker that takes list of operators from the queue
Expand Down Expand Up @@ -363,7 +371,9 @@ class SingleThreadAsyncNet : public SimpleNet {
}

bool RunAsync() {
LOG(FATAL) << "RunAsync() not implemented for singlethread_async net";
CAFFE_THROW("RunAsync() not implemented for singlethread_async net");
// Just to suppress compiler warning.
return false;
}

private:
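
Regarding the TODO in set_affinity() above: one possible Windows counterpart
(my sketch, not part of this commit) would use the Win32 SetThreadAffinityMask
call. Assuming the surrounding file's includes, something along these lines
could pin the worker thread:

#if defined(_MSC_VER)
#include <windows.h>
// Hypothetical sketch: pin the calling thread to a single core chosen from
// the GPU id, analogous to the sched.h-based branch used on other platforms.
static void SetAffinityOnWindows(int gpu, int num_cores) {
  if (num_cores <= 0) {
    return;
  }
  DWORD_PTR mask = DWORD_PTR(1) << (gpu % num_cores);
  if (SetThreadAffinityMask(GetCurrentThread(), mask) == 0) {
    LOG(WARNING) << "Could not set CPU affinity";
  }
}
#endif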
2 changes: 1 addition & 1 deletion caffe2/core/typeid.h
@@ -169,7 +169,7 @@ class TypeMeta {
* is generated during run-time. Do NOT serialize the id for storage.
*/
template <typename T>
[[gnu::visibility("default")]] static CaffeTypeId Id();
CAFFE2_EXPORT static CaffeTypeId Id();

/**
* Returns the item size of the type. This is equivalent to sizeof(T).
15 changes: 8 additions & 7 deletions caffe2/cuda_rtc/common_rtc.h
@@ -48,22 +48,23 @@ class CudaRTCFunction {
if (compile_result != NVRTC_SUCCESS) {
size_t log_size;
NVRTC_CHECK(nvrtcGetProgramLogSize(prog, &log_size));
char nvrtc_log[log_size];
NVRTC_CHECK(nvrtcGetProgramLog(prog, nvrtc_log));
vector<char> nvrtc_log(log_size);
NVRTC_CHECK(nvrtcGetProgramLog(prog, nvrtc_log.data()));
LOG(FATAL) << "Compilation failure for nvrtc("
<< nvrtcGetErrorString(compile_result)
<< "): \n" << nvrtc_log;
<< nvrtcGetErrorString(compile_result) << "): \n"
<< nvrtc_log.data();
}
size_t ptx_size;
NVRTC_CHECK(nvrtcGetPTXSize(prog, &ptx_size));
char nvrtc_ptx[ptx_size];
NVRTC_CHECK(nvrtcGetPTX(prog, nvrtc_ptx));
vector<char> nvrtc_ptx(ptx_size);
NVRTC_CHECK(nvrtcGetPTX(prog, nvrtc_ptx.data()));
NVRTC_CHECK(nvrtcDestroyProgram(&prog));
// After compilation, load the module.
if (module_loaded_) {
CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(module_));
}
CUDA_DRIVERAPI_ENFORCE(cuModuleLoadDataEx(&module_, nvrtc_ptx, 0, 0, 0));
CUDA_DRIVERAPI_ENFORCE(
cuModuleLoadDataEx(&module_, nvrtc_ptx.data(), 0, 0, 0));
module_loaded_ = true;
CUDA_DRIVERAPI_ENFORCE(
cuModuleGetFunction(&kernel_, module_, name.c_str()));
3 changes: 2 additions & 1 deletion caffe2/cuda_rtc/elemenntwise_rtc_gpu.cc
@@ -84,7 +84,8 @@ class ElementwiseRTCOp final : public Operator<CUDAContext> {
static_assert(sizeof(void*) == sizeof(size_t),
"The argbuffer relies on the assumption that void* and "
"size_t have the same size.");
size_t argBuffer[InputSize() + OutputSize() + 1];
vector<size_t> argBuffer_vec(InputSize() + OutputSize() + 1);
size_t* argBuffer = argBuffer_vec.data();
CAFFE_ENFORCE(
Input(0).size() < std::numeric_limits<int>::max(),
"The kernel function currently only supports int index.");
9 changes: 6 additions & 3 deletions caffe2/operators/conv_op_cudnn.cc
@@ -572,15 +572,18 @@ bool CudnnConvGradientOp<T>::RunOnDevice() {
auto* dX =
Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
dX->ResizeLike(X);
const T* filter_data = filter.template data<T>();
const T* dYdata = dY.template data<T>();
T* dXdata = dX->template mutable_data<T>();
CUDNN_ENFORCE(cudnnFindConvolutionBackwardDataAlgorithmEx(
state->cudnn_handle(),
filter_desc_,
filter.template data<T>(),
filter_data,
top_desc_,
dY.template data<T>(),
dYdata,
conv_desc_,
bottom_desc_,
dX->template mutable_data<T>(),
dXdata,
kNUM_CUDNN_BWD_DATA_ALGS,
&returned_algo_count,
data_perf_stat.data(),
2 changes: 1 addition & 1 deletion caffe2/operators/instance_norm_op.cu
@@ -121,7 +121,7 @@ __global__ void InstanceNormGradientKernel(
output_grad_offset += dim_stride;
}

temp *= -std::pow(inv_stdev_data[i], 3.0) / dim;
temp *= -powf(inv_stdev_data[i], 3.0) / dim;

input_grad_offset = input_grad_data + n * N_stride + c * C_stride;
output_grad_offset = output_grad_data + n * N_stride + c * C_stride;
50 changes: 39 additions & 11 deletions caffe2/operators/lp_pool_op.cu
@@ -9,6 +9,32 @@ namespace {
class LpPool {};
} // namespace

namespace {
template <typename T>
inline __device__ T cuda_pow(T x, T y);

template <typename T>
inline __device__ T cuda_abs(T x);

template <>
inline __device__ float cuda_pow<float>(float x, float y) {
return powf(x, y);
}
template <>
inline __device__ double cuda_pow<double>(double x, double y) {
return pow(x, y);
}

template <>
inline __device__ float cuda_abs(float x) {
return fabsf(x);
}
template <>
inline __device__ double cuda_abs(double x) {
return fabs(x);
}
}

namespace {
template <typename T>
__global__ void LpPoolForwardNCHW(
@@ -46,11 +72,11 @@ __global__ void LpPoolForwardNCHW(
int bottom_offset = (n * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
top_data[index] +=
std::pow(std::abs(bottom_data[bottom_offset + h * width + w]), p);
top_data[index] += cuda_pow<T>(
cuda_abs(bottom_data[bottom_offset + h * width + w]), p);
}
}
top_data[index] = std::pow(top_data[index], 1.0 / p);
top_data[index] = cuda_pow<T>(top_data[index], 1.0 / p);
}
}

@@ -87,12 +113,12 @@ __global__ void LpPoolForwardNHWC(
int bottom_offset = n * height * width * channels + c;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
output += std::pow(
std::abs(bottom_data[bottom_offset + (h * width + w) * channels]),
output += cuda_pow<T>(
cuda_abs(bottom_data[bottom_offset + (h * width + w) * channels]),
p);
}
}
top_data[index] = std::pow(output, 1.0 / p);
top_data[index] = cuda_pow<T>(output, 1.0 / p);
}
}

@@ -143,8 +169,9 @@ __global__ void LpPoolBackwardNCHW(
hstart = max(hstart, 0);
wstart = max(wstart, 0);
gradient += top_diff_slice[ph * pooled_width + pw] *
bottom_data[index] * std::pow(std::abs(bottom_data[index]), p - 2) /
std::pow(top_data_slice[ph * pooled_width + pw], p - 1);
bottom_data[index] *
cuda_pow<T>(cuda_abs(bottom_data[index]), p - 2) /
cuda_pow<T>(top_data_slice[ph * pooled_width + pw], p - 1);
}
}
bottom_diff[index] = gradient;
@@ -197,9 +224,10 @@ __global__ void LpPoolBackwardNHWC(
hstart = max(hstart, 0);
wstart = max(wstart, 0);
gradient += top_diff_slice[(ph * pooled_width + pw) * channels] *
bottom_data[index] * std::pow(std::abs(bottom_data[index]), p - 2) /
std::pow(top_data_slice[(ph * pooled_width + pw) * channels],
p - 1);
bottom_data[index] *
cuda_pow<T>(cuda_abs(bottom_data[index]), p - 2) /
cuda_pow<T>(top_data_slice[(ph * pooled_width + pw) * channels],
p - 1);
}
}
bottom_diff[index] = gradient;
[Diffs for the remaining changed files are not shown in this view.]
