-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add varbasecopy func to fix the ParamBase type bug in layers.to API #32789
Changes from all commits
e02c7bb
21a7989
18b6d1c
9e9f8c5
2ecfe9d
21ab8fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -469,6 +469,62 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, | |
if (!PyTuple_Check(_index)) Py_DecRef(index); | ||
} | ||
|
||
template <typename P> | ||
static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src, | ||
imperative::VarBase &dst, const P &dst_device, | ||
const bool blocking) { | ||
if (dst.SharedVar()->IsEmpty()) { | ||
VLOG(3) << "deep copy Variable from " << src->Name() << " to " | ||
<< dst.Name(); | ||
dst.SetPersistable(src->Persistable()); | ||
dst.SetDataType(src->DataType()); | ||
dst.SetType(src->Type()); | ||
dst.SetOverridedStopGradient(src->OverridedStopGradient()); | ||
if (!src->SharedVar()->IsEmpty()) { | ||
if (src->Var().IsType<framework::LoDTensor>()) { | ||
auto &src_tensor = src->Var().Get<framework::LoDTensor>(); | ||
auto *dst_tensor = dst.MutableVar()->GetMutable<framework::LoDTensor>(); | ||
dst_tensor->set_lod(src_tensor.lod()); | ||
framework::TensorCopy(src_tensor, dst_device, dst_tensor); | ||
if (blocking) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if (!blocking), do we need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added |
||
platform::DeviceContextPool::Instance().Get(dst_device)->Wait(); | ||
auto src_device = src_tensor.place(); | ||
if (!(src_device == dst_device)) { | ||
platform::DeviceContextPool::Instance().Get(src_device)->Wait(); | ||
} | ||
} | ||
} else if (src->Var().IsType<framework::SelectedRows>()) { | ||
auto &src_selected_rows = src->Var().Get<framework::SelectedRows>(); | ||
auto *dst_selected_rows = | ||
dst.MutableVar()->GetMutable<framework::SelectedRows>(); | ||
dst_selected_rows->set_height(src_selected_rows.height()); | ||
dst_selected_rows->set_rows(src_selected_rows.rows()); | ||
framework::TensorCopy(src_selected_rows.value(), dst_device, | ||
dst_selected_rows->mutable_value()); | ||
if (blocking) { | ||
platform::DeviceContextPool::Instance().Get(dst_device)->Wait(); | ||
auto src_device = src_selected_rows.value().place(); | ||
if (!(src_device == dst_device)) { | ||
platform::DeviceContextPool::Instance().Get(src_device)->Wait(); | ||
} | ||
} | ||
} | ||
|
||
if (!blocking) { | ||
IncreaseVarbaseReferenceCountUntilCopyComplete(src, dst_device); | ||
} | ||
|
||
} else { | ||
PADDLE_THROW(platform::errors::InvalidArgument( | ||
"The source Tensor(%s) can not copy when it is empty.", src->Name())); | ||
} | ||
} else { | ||
PADDLE_THROW(platform::errors::InvalidArgument( | ||
"The destion Tensor(%s) can not copy when it is not empty.", | ||
dst.Name())); | ||
} | ||
} | ||
|
||
// Bind Methods | ||
void BindImperative(py::module *m_ptr) { | ||
auto &m = *m_ptr; | ||
|
@@ -1639,6 +1695,11 @@ void BindImperative(py::module *m_ptr) { | |
self.nrings_ = nrings; | ||
}); | ||
|
||
m.def("varbase_copy", &VarBaseCopy<platform::Place>); | ||
m.def("varbase_copy", &VarBaseCopy<platform::CPUPlace>); | ||
m.def("varbase_copy", &VarBaseCopy<platform::CUDAPlace>); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need XPUPlace here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
m.def("varbase_copy", &VarBaseCopy<platform::XPUPlace>); | ||
|
||
m.def( | ||
"dygraph_partial_grad", | ||
[](const std::vector<std::shared_ptr<imperative::VarBase>> &input_targets, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,7 @@ | |
from paddle.fluid import framework | ||
from ..param_attr import ParamAttr | ||
from paddle.fluid.executor import Executor, global_scope | ||
from paddle.fluid.framework import in_dygraph_mode | ||
from paddle.fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_ | ||
from paddle.fluid.framework import _current_expected_place as _get_device | ||
from paddle.fluid.dygraph import no_grad | ||
import paddle.utils.deprecated as deprecated | ||
|
@@ -1427,8 +1427,19 @@ def transform(t, device, dtype, blocking): | |
dtype = t.dtype | ||
|
||
new_t = t._copy_to(device, blocking) | ||
if dtype is not None and dtype != t.dtype: | ||
new_t = new_t.cast(dtype=dtype) | ||
if isinstance(t, framework.ParamBase): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
if dtype is not None and dtype != t.dtype: | ||
framework._dygraph_tracer().trace_op( | ||
type='cast', | ||
inputs={'X': new_t}, | ||
outputs={'Out': new_t}, | ||
attrs={ | ||
'in_dtype': t.dtype, | ||
'out_dtype': convert_np_dtype_to_dtype_(dtype) | ||
}) | ||
else: | ||
if dtype is not None and dtype != t.dtype: | ||
new_t = new_t.cast(dtype=dtype) | ||
|
||
return new_t | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add else branch and throw error to avoid bug
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done