Skip to content

Commit

Permalink
Add CUDAPinnedPlace
Browse files Browse the repository at this point in the history
  • Loading branch information
chengduoZH committed Mar 27, 2018
1 parent 158d6c4 commit ab601c1
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 10 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/framework/tensor_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ struct AnyVisitor : public boost::static_visitor<bool> {
const platform::CPUPlace& cpu) const {
return *out.data<bool>();
}

bool GetResult(const framework::Tensor& out,
const platform::CUDAPinnedPlace& cpu) const {
return *out.data<bool>();
}
};

template <typename Predicate>
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/memory/memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ size_t Usage::operator()(const platform::CUDAPlace& gpu) const {
}

size_t Usage::operator()(const platform::CUDAPinnedPlace& cuda_pinned) const {
#ifdef PADDLE_WITH_CUDA
return Used(cuda_pinned);
#else
PADDLE_THROW("'CUDAPinnedPlace' is not supported in CPU only device.");
#endif
}

size_t memory_usage(const platform::Place& p) {
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/operators/math/math_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,14 @@ void set_constant_with_place<platform::CPUPlace>(
TensorSetConstantCPU(tensor, value));
}

template <>
void set_constant_with_place<platform::CUDAPinnedPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor,
float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()),
TensorSetConstantCPU(tensor, value));
}

struct TensorSetConstantWithPlace : public boost::static_visitor<void> {
TensorSetConstantWithPlace(const platform::DeviceContext& context,
framework::Tensor* tensor, float value)
Expand Down
24 changes: 24 additions & 0 deletions paddle/fluid/platform/device_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ DeviceContextPool::DeviceContextPool(
PADDLE_THROW(
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
"option");
#endif
} else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
device_contexts_.emplace(
p,
PtrType(new CUDAPinnedDeviceContext(boost::get<CUDAPinnedPlace>(p))));
#else
PADDLE_THROW(
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
"option");
#endif
}
}
Expand Down Expand Up @@ -186,6 +196,20 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }

cudaStream_t CUDADeviceContext::stream() const { return stream_; }

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}

CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
: place_(place) {
eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
return eigen_device_.get();
}

Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#endif

#ifdef PADDLE_WITH_MKLDNN
Expand Down
27 changes: 17 additions & 10 deletions paddle/fluid/platform/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,24 @@ struct DefaultDeviceContextType<platform::CUDAPlace> {
};

// Currently, CUDAPinnedDeviceContext is only used to data copying.
// class CUDAPinnedDeviceContext : public DeviceContext {
// public:
// CUDAPinnedDeviceContext();
// explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);
//
// Place GetPlace() const override;
//
// private:
// CUDAPinnedPlace place_;
//};
class CUDAPinnedDeviceContext : public DeviceContext {
public:
CUDAPinnedDeviceContext();
explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);

Place GetPlace() const override;

Eigen::DefaultDevice* eigen_device() const;

private:
CUDAPinnedPlace place_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

template <>
struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
using TYPE = CUDAPinnedDeviceContext;
};
#endif

#ifdef PADDLE_WITH_MKLDNN
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/platform/place.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ struct PlaceVisitorWrapper

typename Visitor::result_type operator()(
const CUDAPinnedPlace &cuda_pinned) const {
#ifdef PADDLE_WITH_CUDA
return visitor_(cuda_pinned);
#else
PADDLE_THROW("Paddle is not compiled with CUDA. Cannot visit cuda_pinned");
return typename Visitor::result_type();
#endif
}
};

Expand Down

0 comments on commit ab601c1

Please sign in to comment.