Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable operator gpu unittest #3050

Merged
merged 19 commits into from
Aug 2, 2017
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cmake/flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ function(CheckCompilerCXX11Flag)
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
endif()
# TODO(qijun) Builds with gcc 4.9 or later raise SEGV because of an optimization problem.
# Use Debug mode instead for now.
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 4.9)
set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "" FORCE)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With gcc 4.8 and nvcc 8.0, both Debug and Release builds are fine.
With gcc 5.4 and nvcc 8.0, Debug is fine, but Release causes a segmentation fault.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we constrain the GCC version in the Dockerfile, rather than checking here?

I know checking in CMake provides a guarantee, but I am afraid that adding too many such constraints would complicate our build system too much.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will constrain the GCC version in the Dockerfile.
But others may compile Paddle in a gcc 5.4 environment, in which case the test_framework unit test will fail.

endif()
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
# Apple Clang is a different compiler than upstream Clang and has different version numbers.
Expand Down
7 changes: 4 additions & 3 deletions paddle/framework/detail/tensor-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/memory/memcpy.h"

namespace paddle {
Expand Down Expand Up @@ -62,9 +61,11 @@ inline T* Tensor::mutable_data(platform::Place place) {
if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), size));
} else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_ONLY_CPU
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove }

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The { is outside of the macro, so the } is needed inside the macro.

#ifndef PADDLE_ONLY_CPU
else if (platform::is_gpu_place(place)) {
#else
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), size));
}
Expand Down
1 change: 1 addition & 0 deletions paddle/operators/add_op.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#define EIGEN_USE_GPU
Copy link
Member Author

@QiJune QiJune Aug 1, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This macro has to be added; otherwise, it will cause a segmentation fault.

#include "paddle/framework/op_registry.h"
#include "paddle/operators/add_op.h"

Expand Down
1 change: 1 addition & 0 deletions paddle/operators/cross_entropy_op.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#define EIGEN_USE_GPU
#include "paddle/operators/cross_entropy_op.h"

REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
Expand Down
1 change: 1 addition & 0 deletions paddle/operators/mul_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/mul_op.h"

REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>);
1 change: 1 addition & 0 deletions paddle/operators/rowwise_add_op.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#define EIGEN_USE_GPU
#include "paddle/operators/rowwise_add_op.h"

REGISTER_OP_GPU_KERNEL(rowwise_add,
Expand Down
1 change: 1 addition & 0 deletions paddle/operators/sgd_op.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#define EIGEN_USE_GPU
#include "paddle/operators/sgd_op.h"

REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>);
1 change: 1 addition & 0 deletions paddle/operators/sigmoid_op.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#define EIGEN_USE_GPU
#include "paddle/operators/sigmoid_op.h"

REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
1 change: 1 addition & 0 deletions paddle/operators/softmax_op.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h"
#include "paddle/operators/softmax_op.h"

Expand Down
12 changes: 6 additions & 6 deletions paddle/platform/enforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,12 @@ inline void throw_on_error(T e) {
throw_on_error(e, "");
}

#define PADDLE_THROW(...) \
do { \
throw ::paddle::platform::EnforceNotMet( \
std::make_exception_ptr( \
std::runtime_error(string::Sprintf(__VA_ARGS__))), \
__FILE__, __LINE__); \
#define PADDLE_THROW(...) \
do { \
throw ::paddle::platform::EnforceNotMet( \
std::make_exception_ptr( \
std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \
__FILE__, __LINE__); \
} while (0)

#define PADDLE_ENFORCE(...) \
Expand Down
58 changes: 49 additions & 9 deletions paddle/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "paddle/pybind/tensor_bind.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
Expand Down Expand Up @@ -54,6 +56,14 @@ static size_t UniqueIntegerGenerator() {
return generator.fetch_add(1);
}

// Reports whether this binary was built with GPU support.
// The answer is fixed at compile time: it is false when PADDLE_ONLY_CPU is
// defined and true otherwise.
bool IsCompileGPU() {
#ifndef PADDLE_ONLY_CPU
  constexpr bool kCompiledWithGpu = true;
#else
  constexpr bool kCompiledWithGpu = false;
#endif
  return kCompiledWithGpu;
}

PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of PaddlePaddle");

Expand All @@ -68,15 +78,27 @@ PYBIND11_PLUGIN(core) {
self.Resize(pd::make_ddim(dim));
})
.def("alloc_float",
[](pd::Tensor& self) {
self.mutable_data<float>(paddle::platform::CPUPlace());
[](pd::Tensor& self, paddle::platform::GPUPlace& place) {
self.mutable_data<float>(place);
})
.def("alloc_float",
[](pd::Tensor& self, paddle::platform::CPUPlace& place) {
self.mutable_data<float>(place);
})
.def("alloc_int",
[](pd::Tensor& self) {
self.mutable_data<int>(paddle::platform::CPUPlace());
[](pd::Tensor& self, paddle::platform::CPUPlace& place) {
self.mutable_data<int>(place);
})
.def("set", paddle::pybind::PyTensorSetFromArray<float>)
.def("set", paddle::pybind::PyTensorSetFromArray<int>)
.def("alloc_int",
[](pd::Tensor& self, paddle::platform::GPUPlace& place) {
self.mutable_data<int>(place);
})
.def("set", paddle::pybind::PyCPUTensorSetFromArray<float>)
.def("set", paddle::pybind::PyCPUTensorSetFromArray<int>)
#ifndef PADDLE_ONLY_CPU
.def("set", paddle::pybind::PyCUDATensorSetFromArray<float>)
.def("set", paddle::pybind::PyCUDATensorSetFromArray<int>)
#endif
.def("shape",
[](pd::Tensor& self) { return pd::vectorize(self.dims()); });

Expand Down Expand Up @@ -134,9 +156,25 @@ All parameter, weight, gradient are variables in Paddle.
.def("temp", pd::OperatorBase::TMP_VAR_NAME);

py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
.def_static("cpu_context", []() -> paddle::platform::DeviceContext* {
return new paddle::platform::CPUDeviceContext();
});
.def_static("create",
[](paddle::platform::CPUPlace& place)
-> paddle::platform::DeviceContext* {
return new paddle::platform::CPUDeviceContext();
})
.def_static(
"create",
[](paddle::platform::GPUPlace& place)
-> paddle::platform::DeviceContext* {
#ifdef PADDLE_ONLY_CPU
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#else
return new paddle::platform::CUDADeviceContext(place);
#endif
});

py::class_<paddle::platform::GPUPlace>(m, "GPUPlace").def(py::init<int>());

py::class_<paddle::platform::CPUPlace>(m, "CPUPlace").def(py::init<>());

py::class_<pd::OperatorBase, std::shared_ptr<pd::OperatorBase>> operator_base(
m, "Operator");
Expand Down Expand Up @@ -172,5 +210,7 @@ All parameter, weight, gradient are variables in Paddle.

m.def("unique_integer", UniqueIntegerGenerator);

m.def("is_compile_gpu", IsCompileGPU);

return m.ptr();
}
48 changes: 36 additions & 12 deletions paddle/pybind/tensor_bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
limitations under the License. */

#pragma once
#include <paddle/framework/tensor.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <string>
#include "paddle/framework/tensor.h"
#include "paddle/memory/memcpy.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"

namespace py = pybind11;

Expand All @@ -40,9 +42,6 @@ template <size_t I, typename... ARGS>
struct CastToPyBufferImpl<true, I, ARGS...> {
using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
py::buffer_info operator()(framework::Tensor &tensor) {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()),
"Only CPU tensor can cast to numpy array");

if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) {
auto dim_vec = framework::vectorize(tensor.dims());
std::vector<size_t> dims_outside;
Expand All @@ -56,12 +55,17 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
strides[i - 1] = sizeof(CUR_TYPE) * prod;
prod *= dims_outside[i - 1];
}

framework::Tensor dst_tensor;
if (paddle::platform::is_gpu_place(tensor.holder_->place())) {
dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace());
} else if (paddle::platform::is_cpu_place(tensor.holder_->place())) {
dst_tensor = tensor;
}
return py::buffer_info(
tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()),
dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.holder_->place()),
sizeof(CUR_TYPE),
py::format_descriptor<CUR_TYPE>::format(),
(size_t)framework::arity(tensor.dims()),
(size_t)framework::arity(dst_tensor.dims()),
dims_outside,
strides);
} else {
Expand All @@ -77,19 +81,39 @@ inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) {
}

template <typename T>
void PyTensorSetFromArray(
void PyCPUTensorSetFromArray(
framework::Tensor &self,
py::array_t<T, py::array::c_style | py::array::forcecast> array) {
py::array_t<T, py::array::c_style | py::array::forcecast> array,
paddle::platform::CPUPlace &place) {
std::vector<int> dims;
dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) {
dims.push_back((int)array.shape()[i]);
}

self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(paddle::platform::CPUPlace());
auto *dst = self.mutable_data<T>(place);
std::memcpy(dst, array.data(), sizeof(T) * array.size());
}

#ifndef PADDLE_ONLY_CPU
template <typename T>
void PyCUDATensorSetFromArray(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Partial specialization of function templates is not allowed, so just define another function here.

framework::Tensor &self,
py::array_t<T, py::array::c_style | py::array::forcecast> array,
paddle::platform::GPUPlace &place) {
std::vector<int> dims;
dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) {
dims.push_back((int)array.shape()[i]);
}

self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(place);
paddle::platform::GpuMemcpySync(
dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice);
}
#endif

} // namespace pybind
} // namespace paddle
1 change: 0 additions & 1 deletion python/paddle/v2/framework/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ add_python_test(test_framework
test_fc_op.py
test_add_two_op.py
test_sgd_op.py
test_cross_entropy_op.py
test_mul_op.py
test_sigmoid_op.py
test_softmax_op.py
Expand Down
61 changes: 34 additions & 27 deletions python/paddle/v2/framework/tests/op_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,41 +25,48 @@ def test_all(self):
self.assertIsNotNone(func)

scope = core.Scope(None)

kwargs = dict()

for in_name in func.all_input_args:
if hasattr(self, in_name):
kwargs[in_name] = in_name
var = scope.create_var(in_name).get_tensor()
arr = getattr(self, in_name)
var.set_dims(arr.shape)
var.set(arr)
else:
kwargs[in_name] = "@EMPTY@"
places = []
places.append(core.CPUPlace())
if core.is_compile_gpu():
places.append(core.GPUPlace(0))
Copy link
Member Author

@QiJune QiJune Aug 1, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Run CPU OpKernel first, and then GPU OpKernel


for place in places:
for in_name in func.all_input_args:
if hasattr(self, in_name):
kwargs[in_name] = in_name
var = scope.create_var(in_name).get_tensor()
arr = getattr(self, in_name)
var.set_dims(arr.shape)
var.set(arr, place)
else:
kwargs[in_name] = "@EMPTY@"

for out_name in func.all_output_args:
if hasattr(self, out_name):
kwargs[out_name] = out_name
scope.create_var(out_name).get_tensor()
for out_name in func.all_output_args:
if hasattr(self, out_name):
kwargs[out_name] = out_name
scope.create_var(out_name).get_tensor()

for attr_name in func.all_attr_args:
if hasattr(self, attr_name):
kwargs[attr_name] = getattr(self, attr_name)
for attr_name in func.all_attr_args:
if hasattr(self, attr_name):
kwargs[attr_name] = getattr(self, attr_name)

op = func(**kwargs)
op = func(**kwargs)

op.infer_shape(scope)
op.infer_shape(scope)

ctx = core.DeviceContext.cpu_context()
op.run(scope, ctx)
ctx = core.DeviceContext.create(place)
op.run(scope, ctx)

for out_name in func.all_output_args:
actual = numpy.array(scope.get_var(out_name).get_tensor())
expect = getattr(self, out_name)
# TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul
# has some diff, and could not pass unittest. So I set decimal 3 here.
# And I will check this in future.
numpy.testing.assert_almost_equal(actual, expect, decimal=3)
for out_name in func.all_output_args:
actual = numpy.array(scope.get_var(out_name).get_tensor())
expect = getattr(self, out_name)
# TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul
# has some diff, and could not pass unittest. So I set decimal 3 here.
# And I will check this in future.
numpy.testing.assert_almost_equal(actual, expect, decimal=3)

obj.test_all = test_all
return obj
4 changes: 2 additions & 2 deletions python/paddle/v2/framework/tests/test_add_two_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ class TestAddOp(unittest.TestCase):

def setUp(self):
self.type = "add_two"
self.X = numpy.random.random((342, 345)).astype("float32")
self.Y = numpy.random.random((342, 345)).astype("float32")
self.X = numpy.random.random((102, 105)).astype("float32")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sometimes there is not enough GPU memory for the unit test, so reduce the size.

self.Y = numpy.random.random((102, 105)).astype("float32")
self.Out = self.X + self.Y


Expand Down
Loading