add varbasecopy func to fix the ParamBase type bug in layers.to API #32789

Merged
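In short, this PR fixes Layer.to() handing back parameters that were no longer ParamBase instances: the copy/cast path produced plain VarBase objects. A minimal, hedged sketch of the user-facing behavior that the new tests below check, assuming a dygraph-mode Paddle build:

import paddle

linear = paddle.nn.Linear(10, 10)
linear.to(dtype='double')  # cast the layer's parameters in place

# With this fix, parameters keep their ParamBase type after .to()
for p in linear.parameters():
    assert isinstance(p, paddle.fluid.framework.ParamBase), type(p)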
61 changes: 61 additions & 0 deletions paddle/fluid/pybind/imperative.cc
@@ -469,6 +469,62 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
if (!PyTuple_Check(_index)) Py_DecRef(index);
}

template <typename P>
static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src,
imperative::VarBase &dst, const P &dst_device,
const bool blocking) {
if (dst.SharedVar()->IsEmpty()) {
Review (Contributor): add else branch and throw error to avoid bug
Reply (Author): Done

VLOG(3) << "deep copy Variable from " << src->Name() << " to "
<< dst.Name();
dst.SetPersistable(src->Persistable());
dst.SetDataType(src->DataType());
dst.SetType(src->Type());
dst.SetOverridedStopGradient(src->OverridedStopGradient());
if (!src->SharedVar()->IsEmpty()) {
if (src->Var().IsType<framework::LoDTensor>()) {
auto &src_tensor = src->Var().Get<framework::LoDTensor>();
auto *dst_tensor = dst.MutableVar()->GetMutable<framework::LoDTensor>();
dst_tensor->set_lod(src_tensor.lod());
framework::TensorCopy(src_tensor, dst_device, dst_tensor);
if (blocking) {
Review (Contributor): if (!blocking), do we need IncreaseVarbaseReferenceCountUntilCopyComplete?
Reply (Author): Added

platform::DeviceContextPool::Instance().Get(dst_device)->Wait();
auto src_device = src_tensor.place();
if (!(src_device == dst_device)) {
platform::DeviceContextPool::Instance().Get(src_device)->Wait();
}
}
} else if (src->Var().IsType<framework::SelectedRows>()) {
auto &src_selected_rows = src->Var().Get<framework::SelectedRows>();
auto *dst_selected_rows =
dst.MutableVar()->GetMutable<framework::SelectedRows>();
dst_selected_rows->set_height(src_selected_rows.height());
dst_selected_rows->set_rows(src_selected_rows.rows());
framework::TensorCopy(src_selected_rows.value(), dst_device,
dst_selected_rows->mutable_value());
if (blocking) {
platform::DeviceContextPool::Instance().Get(dst_device)->Wait();
auto src_device = src_selected_rows.value().place();
if (!(src_device == dst_device)) {
platform::DeviceContextPool::Instance().Get(src_device)->Wait();
}
}
}

if (!blocking) {
IncreaseVarbaseReferenceCountUntilCopyComplete(src, dst_device);
}

} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The source Tensor(%s) can not copy when it is empty.", src->Name()));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The destion Tensor(%s) can not copy when it is not empty.",
dst.Name()));
}
}

// Bind Methods
void BindImperative(py::module *m_ptr) {
auto &m = *m_ptr;
@@ -1639,6 +1695,11 @@ void BindImperative(py::module *m_ptr) {
self.nrings_ = nrings;
});

m.def("varbase_copy", &VarBaseCopy<platform::Place>);
m.def("varbase_copy", &VarBaseCopy<platform::CPUPlace>);
m.def("varbase_copy", &VarBaseCopy<platform::CUDAPlace>);
Review (Contributor): do we need XPUPlace here?
Reply (Author): Done

m.def("varbase_copy", &VarBaseCopy<platform::XPUPlace>);

m.def(
"dygraph_partial_grad",
[](const std::vector<std::shared_ptr<imperative::VarBase>> &input_targets,
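As a rough illustration of how the new varbase_copy binding is driven from Python (this mirrors the ParamBase._copy_to change to framework.py later in this diff; the place argument can be any of the bound types: Place, CPUPlace, CUDAPlace, or XPUPlace):

import copy
import paddle
from paddle.fluid import core, framework

src = paddle.nn.Linear(2, 2).weight                       # a ParamBase
state = copy.deepcopy(src.__dict__)
dst = framework.ParamBase(src.shape, src.dtype, **state)  # empty destination

# Blocking deep copy; per VarBaseCopy above, the destination must be empty
# and the non-empty source tensor is copied to the requested place.
core.varbase_copy(src, dst, paddle.CPUPlace(), True)

When blocking is False, the C++ side instead calls IncreaseVarbaseReferenceCountUntilCopyComplete, so the source stays alive until the asynchronous copy finishes and the Python caller does not need to hold a reference explicitly.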
17 changes: 14 additions & 3 deletions python/paddle/fluid/dygraph/layers.py
@@ -34,7 +34,7 @@
from paddle.fluid import framework
from ..param_attr import ParamAttr
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph import no_grad
import paddle.utils.deprecated as deprecated
@@ -1427,8 +1427,19 @@ def transform(t, device, dtype, blocking):
dtype = t.dtype

new_t = t._copy_to(device, blocking)
if dtype is not None and dtype != t.dtype:
new_t = new_t.cast(dtype=dtype)
if isinstance(t, framework.ParamBase):
Review (Contributor): the ParamBase._copy_to is still broken; should we override the ParamBase._copy_to method to fix this bug?
Reply (Author): Done

if dtype is not None and dtype != t.dtype:
framework._dygraph_tracer().trace_op(
type='cast',
inputs={'X': new_t},
outputs={'Out': new_t},
attrs={
'in_dtype': t.dtype,
'out_dtype': convert_np_dtype_to_dtype_(dtype)
})
else:
if dtype is not None and dtype != t.dtype:
new_t = new_t.cast(dtype=dtype)

return new_t

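The ParamBase branch above casts through a traced in-place 'cast' op rather than calling new_t.cast(), because cast() returns a fresh tensor and the result would no longer be the parameter object itself. A hedged sketch of the distinction (exact type names may vary across Paddle versions):

import paddle

w = paddle.nn.Linear(2, 2).weight   # a ParamBase parameter

casted = w.cast(dtype='float64')    # produces a new tensor, not a parameter
print(type(w).__name__)             # ParamBase
print(type(casted).__name__)        # a plain VarBase-backed tensor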
7 changes: 7 additions & 0 deletions python/paddle/fluid/framework.py
@@ -5855,6 +5855,13 @@ def __deepcopy__(self, memo):
new_param.copy_(self, True)
return new_param

def _copy_to(self, device, blocking):
print("in ParamBase copy_to func")
state = copy.deepcopy(self.__dict__)
new_param = ParamBase(self.shape, self.dtype, **state)
core.varbase_copy(self, new_param, device, blocking)
return new_param

__repr__ = __str__


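For completeness, a hedged usage sketch of the overridden ParamBase._copy_to (a private helper; paddle.CPUPlace() here is just one of the places the binding accepts):

import paddle

w = paddle.nn.Linear(4, 4).weight             # a ParamBase
w_copy = w._copy_to(paddle.CPUPlace(), True)  # blocking copy to the CPU

# The result is a new, independent ParamBase on the requested place.
assert isinstance(w_copy, paddle.fluid.framework.ParamBase)
assert w_copy.place.is_cpu_place()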
14 changes: 9 additions & 5 deletions python/paddle/fluid/tests/unittests/test_base_layer.py
@@ -341,7 +341,7 @@ def setUp(self):
self.linear.register_buffer("buf_name", buffer, persistable=True)

sublayer = paddle.nn.Conv1D(3, 2, 3)
self.linear.add_sublayer(1, sublayer)
self.linear.add_sublayer("1", sublayer)

def test_to_api(self):
self.linear.to(dtype='double')
@@ -351,8 +351,8 @@ def test_to_api(self):
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertTrue(
np.allclose(self.linear.weight.grad.numpy(), self.new_grad))
self.assertTrue(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertEqual(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64)

self.linear.to()
self.assertEqual(self.linear.weight.dtype,
@@ -361,8 +361,10 @@ def test_to_api(self):
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertTrue(
np.allclose(self.linear.weight.grad.numpy(), self.new_grad))
self.assertTrue(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertEqual(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
for p in self.linear.parameters():
self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase))

if paddle.fluid.is_compiled_with_cuda():
self.linear.to(device=paddle.CUDAPlace(0))
@@ -384,6 +386,8 @@ def test_to_api(self):
))
self.assertEqual(
self.linear.weight._grad_ivar().place.gpu_device_id(), 0)
for p in self.linear.parameters():
self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase))

self.linear.to(device=paddle.CPUPlace())
self.assertTrue(self.linear.weight.place.is_cpu_place())