Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add varbasecopy func to fix the ParamBase type bug in layers.to API #32789

Merged

Conversation

MingMingShangTian
Copy link
Contributor

PR types

Bug fixes

PR changes

APIs

Describe

Fix the bug that the layers.to API will not keep the parameters type. It will change the type from paddle.fluid.framework.ParamBase to paddle.Tensor and lose the attribution of raw type.

@paddle-bot-old
Copy link

paddle-bot-old bot commented May 7, 2021

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@MingMingShangTian MingMingShangTian changed the title add varbasecopy func to fix the paraBase type bug in layers.to API add varbasecopy func to fix the ParamBase type bug in layers.to API May 8, 2021
static void VarBaseCopy(const imperative::VarBase &src,
imperative::VarBase &dst, const P &dst_device,
const bool blocking) {
if (dst.SharedVar()->IsEmpty()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add else branch and throw error to avoid bug

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

dst.SetDataType(src.DataType());
dst.SetType(src.Type());
dst.SetOverridedStopGradient(src.OverridedStopGradient());
if (!src.SharedVar()->IsEmpty()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same above, add else branch and throw error to avoid bug, can use PADDLE_THROW

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -1639,6 +1672,10 @@ void BindImperative(py::module *m_ptr) {
self.nrings_ = nrings;
});

m.def("varbase_copy", &VarBaseCopy<platform::Place>);
m.def("varbase_copy", &VarBaseCopy<platform::CPUPlace>);
m.def("varbase_copy", &VarBaseCopy<platform::CUDAPlace>);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need XPUPlace here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

new_t = t._copy_to(device, blocking)
if dtype is not None and dtype != t.dtype:
new_t = new_t.cast(dtype=dtype)
if isinstance(t, framework.ParamBase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the ParamBase._copy_to is still error, whether override the ParamBase._copy_to method to fix this bug?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

new_t = new_t.cast(dtype=dtype)
if isinstance(t, framework.ParamBase):
state = copy.deepcopy(t.__dict__)
new_param = framework.ParamBase(t.shape, dtype, **state)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

能否给ParamBase添加一个_copy_to的方法覆盖原来的?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经添加了

auto *dst_tensor = dst.MutableVar()->GetMutable<framework::LoDTensor>();
dst_tensor->set_lod(src_tensor.lod());
framework::TensorCopy(src_tensor, dst_device, dst_tensor);
if (blocking) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (!blocking), do we need IncreaseVarbaseReferenceCountUntilCopyComplete?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@MingMingShangTian MingMingShangTian merged commit 067f558 into PaddlePaddle:develop May 12, 2021
@MingMingShangTian MingMingShangTian deleted the fix_layer_to_bug_v2 branch May 12, 2021 02:43
@MingMingShangTian MingMingShangTian restored the fix_layer_to_bug_v2 branch May 13, 2021 02:38
@MingMingShangTian MingMingShangTian deleted the fix_layer_to_bug_v2 branch May 13, 2021 06:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants