-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
enable operator gpu unittest #3050
Changes from 16 commits
e2ba133
d510913
aa5ca8a
ff594fa
a71a9e6
2ddef13
358261f
4ecf68e
4cc4217
47d8bca
4a1f7bd
61f94f0
cf5ac58
db4d668
bc7be2a
edb5729
81cc7a3
341d188
043e983
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and | |
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "paddle/memory/memcpy.h" | ||
|
||
namespace paddle { | ||
|
@@ -62,9 +61,11 @@ inline T* Tensor::mutable_data(platform::Place place) { | |
if (platform::is_cpu_place(place)) { | ||
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>( | ||
boost::get<platform::CPUPlace>(place), size)); | ||
} else if (platform::is_gpu_place(place)) { | ||
#ifdef PADDLE_ONLY_CPU | ||
PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove } There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. { is outside of macro, so } is needed inside macro |
||
#ifndef PADDLE_ONLY_CPU | ||
else if (platform::is_gpu_place(place)) { | ||
#else | ||
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>( | ||
boost::get<platform::GPUPlace>(place), size)); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
#define EIGEN_USE_GPU | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have to add this macro, otherwise, will cause segment fault |
||
#include "paddle/framework/op_registry.h" | ||
#include "paddle/operators/add_op.h" | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
#define EIGEN_USE_GPU | ||
#include "paddle/operators/rowwise_add_op.h" | ||
|
||
REGISTER_OP_GPU_KERNEL(rowwise_add, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
#define EIGEN_USE_GPU | ||
#include "paddle/operators/sgd_op.h" | ||
|
||
REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
#define EIGEN_USE_GPU | ||
#include "paddle/operators/sigmoid_op.h" | ||
|
||
REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
#define EIGEN_USE_GPU | ||
#include "paddle/framework/op_registry.h" | ||
#include "paddle/operators/softmax_op.h" | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,9 +13,11 @@ | |
limitations under the License. */ | ||
|
||
#pragma once | ||
#include <paddle/framework/tensor.h> | ||
#include <pybind11/numpy.h> | ||
#include <pybind11/pybind11.h> | ||
#include <string> | ||
#include "paddle/framework/tensor.h" | ||
#include "paddle/memory/memcpy.h" | ||
#include "pybind11/numpy.h" | ||
#include "pybind11/pybind11.h" | ||
|
||
namespace py = pybind11; | ||
|
||
|
@@ -40,9 +42,6 @@ template <size_t I, typename... ARGS> | |
struct CastToPyBufferImpl<true, I, ARGS...> { | ||
using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type; | ||
py::buffer_info operator()(framework::Tensor &tensor) { | ||
PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()), | ||
"Only CPU tensor can cast to numpy array"); | ||
|
||
if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) { | ||
auto dim_vec = framework::vectorize(tensor.dims()); | ||
std::vector<size_t> dims_outside; | ||
|
@@ -56,12 +55,17 @@ struct CastToPyBufferImpl<true, I, ARGS...> { | |
strides[i - 1] = sizeof(CUR_TYPE) * prod; | ||
prod *= dims_outside[i - 1]; | ||
} | ||
|
||
framework::Tensor dst_tensor; | ||
if (paddle::platform::is_gpu_place(tensor.holder_->place())) { | ||
dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace()); | ||
} else if (paddle::platform::is_cpu_place(tensor.holder_->place())) { | ||
dst_tensor = tensor; | ||
} | ||
return py::buffer_info( | ||
tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()), | ||
dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.holder_->place()), | ||
sizeof(CUR_TYPE), | ||
py::format_descriptor<CUR_TYPE>::format(), | ||
(size_t)framework::arity(tensor.dims()), | ||
(size_t)framework::arity(dst_tensor.dims()), | ||
dims_outside, | ||
strides); | ||
} else { | ||
|
@@ -77,19 +81,39 @@ inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) { | |
} | ||
|
||
template <typename T> | ||
void PyTensorSetFromArray( | ||
void PyCPUTensorSetFromArray( | ||
framework::Tensor &self, | ||
py::array_t<T, py::array::c_style | py::array::forcecast> array) { | ||
py::array_t<T, py::array::c_style | py::array::forcecast> array, | ||
paddle::platform::CPUPlace &place) { | ||
std::vector<int> dims; | ||
dims.reserve(array.ndim()); | ||
for (size_t i = 0; i < array.ndim(); ++i) { | ||
dims.push_back((int)array.shape()[i]); | ||
} | ||
|
||
self.Resize(framework::make_ddim(dims)); | ||
auto *dst = self.mutable_data<T>(paddle::platform::CPUPlace()); | ||
auto *dst = self.mutable_data<T>(place); | ||
std::memcpy(dst, array.data(), sizeof(T) * array.size()); | ||
} | ||
|
||
#ifndef PADDLE_ONLY_CPU | ||
template <typename T> | ||
void PyCUDATensorSetFromArray( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. partial specialization of function template is not allowed. So just define another function here |
||
framework::Tensor &self, | ||
py::array_t<T, py::array::c_style | py::array::forcecast> array, | ||
paddle::platform::GPUPlace &place) { | ||
std::vector<int> dims; | ||
dims.reserve(array.ndim()); | ||
for (size_t i = 0; i < array.ndim(); ++i) { | ||
dims.push_back((int)array.shape()[i]); | ||
} | ||
|
||
self.Resize(framework::make_ddim(dims)); | ||
auto *dst = self.mutable_data<T>(place); | ||
paddle::platform::GpuMemcpySync( | ||
dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice); | ||
} | ||
#endif | ||
|
||
} // namespace pybind | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,41 +25,48 @@ def test_all(self): | |
self.assertIsNotNone(func) | ||
|
||
scope = core.Scope(None) | ||
|
||
kwargs = dict() | ||
|
||
for in_name in func.all_input_args: | ||
if hasattr(self, in_name): | ||
kwargs[in_name] = in_name | ||
var = scope.create_var(in_name).get_tensor() | ||
arr = getattr(self, in_name) | ||
var.set_dims(arr.shape) | ||
var.set(arr) | ||
else: | ||
kwargs[in_name] = "@EMPTY@" | ||
places = [] | ||
places.append(core.CPUPlace()) | ||
if core.is_compile_gpu(): | ||
places.append(core.GPUPlace(0)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Run CPU OpKernel first, and then GPU OpKernel |
||
|
||
for place in places: | ||
for in_name in func.all_input_args: | ||
if hasattr(self, in_name): | ||
kwargs[in_name] = in_name | ||
var = scope.create_var(in_name).get_tensor() | ||
arr = getattr(self, in_name) | ||
var.set_dims(arr.shape) | ||
var.set(arr, place) | ||
else: | ||
kwargs[in_name] = "@EMPTY@" | ||
|
||
for out_name in func.all_output_args: | ||
if hasattr(self, out_name): | ||
kwargs[out_name] = out_name | ||
scope.create_var(out_name).get_tensor() | ||
for out_name in func.all_output_args: | ||
if hasattr(self, out_name): | ||
kwargs[out_name] = out_name | ||
scope.create_var(out_name).get_tensor() | ||
|
||
for attr_name in func.all_attr_args: | ||
if hasattr(self, attr_name): | ||
kwargs[attr_name] = getattr(self, attr_name) | ||
for attr_name in func.all_attr_args: | ||
if hasattr(self, attr_name): | ||
kwargs[attr_name] = getattr(self, attr_name) | ||
|
||
op = func(**kwargs) | ||
op = func(**kwargs) | ||
|
||
op.infer_shape(scope) | ||
op.infer_shape(scope) | ||
|
||
ctx = core.DeviceContext.cpu_context() | ||
op.run(scope, ctx) | ||
ctx = core.DeviceContext.create(place) | ||
op.run(scope, ctx) | ||
|
||
for out_name in func.all_output_args: | ||
actual = numpy.array(scope.get_var(out_name).get_tensor()) | ||
expect = getattr(self, out_name) | ||
# TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul | ||
# has some diff, and could not pass unittest. So I set decimal 3 here. | ||
# And I will check this in future. | ||
numpy.testing.assert_almost_equal(actual, expect, decimal=3) | ||
for out_name in func.all_output_args: | ||
actual = numpy.array(scope.get_var(out_name).get_tensor()) | ||
expect = getattr(self, out_name) | ||
# TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul | ||
# has some diff, and could not pass unittest. So I set decimal 3 here. | ||
# And I will check this in future. | ||
numpy.testing.assert_almost_equal(actual, expect, decimal=3) | ||
|
||
obj.test_all = test_all | ||
return obj |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,8 +8,8 @@ class TestAddOp(unittest.TestCase): | |
|
||
def setUp(self): | ||
self.type = "add_two" | ||
self.X = numpy.random.random((342, 345)).astype("float32") | ||
self.Y = numpy.random.random((342, 345)).astype("float32") | ||
self.X = numpy.random.random((102, 105)).astype("float32") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sometimes, GPU memory is not enough for unit-test, so reduce the size |
||
self.Y = numpy.random.random((102, 105)).astype("float32") | ||
self.Out = self.X + self.Y | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With gcc 4.8 and nvcc 8.0, both debug and release builds work fine.
With gcc 5.4 and nvcc 8.0, the debug build is fine, but the release build causes a segmentation fault.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about we constrain the GCC version in the Dockerfile, rather than checking it here?
I know checking in CMake provides a guarantee, but I am afraid adding too many such constraints would complicate our build system too much.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I will constrain the GCC version in the Dockerfile.
But some others may compile Paddle in a gcc 5.4 environment, so the unittest test_framework will fail.