-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a VarBaseCopy function to fix the ParamBase type bug in the layers.to API #32789
Changes from 1 commit
e02c7bb
21a7989
18b6d1c
9e9f8c5
2ecfe9d
21ab8fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -469,6 +469,39 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, | |
if (!PyTuple_Check(_index)) Py_DecRef(index); | ||
} | ||
|
||
template <typename P> | ||
static void VarBaseCopy(const imperative::VarBase &src, | ||
imperative::VarBase &dst, const P &dst_device, | ||
const bool blocking) { | ||
if (dst.SharedVar()->IsEmpty()) { | ||
VLOG(3) << "deep copy Variable from " << src.Name() << " to " << dst.Name(); | ||
dst.SetPersistable(src.Persistable()); | ||
dst.SetDataType(src.DataType()); | ||
dst.SetType(src.Type()); | ||
dst.SetOverridedStopGradient(src.OverridedStopGradient()); | ||
if (!src.SharedVar()->IsEmpty()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above: add an else branch and throw an error to avoid a bug; you can use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
// const platform::Place& place = src.Place(); | ||
if (src.Var().IsType<framework::LoDTensor>()) { | ||
auto &src_tensor = src.Var().Get<framework::LoDTensor>(); | ||
auto *dst_tensor = dst.MutableVar()->GetMutable<framework::LoDTensor>(); | ||
dst_tensor->set_lod(src_tensor.lod()); | ||
framework::TensorCopy(src_tensor, dst_device, dst_tensor); | ||
} else if (src.Var().IsType<framework::SelectedRows>()) { | ||
auto &src_selected_rows = src.Var().Get<framework::SelectedRows>(); | ||
auto *dst_selected_rows = | ||
dst.MutableVar()->GetMutable<framework::SelectedRows>(); | ||
dst_selected_rows->set_height(src_selected_rows.height()); | ||
dst_selected_rows->set_rows(src_selected_rows.rows()); | ||
framework::TensorCopy(src_selected_rows.value(), dst_device, | ||
dst_selected_rows->mutable_value()); | ||
} | ||
if (blocking) { | ||
platform::DeviceContextPool::Instance().Get(dst_device)->Wait(); | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Bind Methods | ||
void BindImperative(py::module *m_ptr) { | ||
auto &m = *m_ptr; | ||
|
@@ -1639,6 +1672,10 @@ void BindImperative(py::module *m_ptr) { | |
self.nrings_ = nrings; | ||
}); | ||
|
||
m.def("varbase_copy", &VarBaseCopy<platform::Place>); | ||
m.def("varbase_copy", &VarBaseCopy<platform::CPUPlace>); | ||
m.def("varbase_copy", &VarBaseCopy<platform::CUDAPlace>); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need XPUPlace here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
m.def( | ||
"dygraph_partial_grad", | ||
[](const std::vector<std::shared_ptr<imperative::VarBase>> &input_targets, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,7 @@ | |
from paddle.fluid import framework | ||
from ..param_attr import ParamAttr | ||
from paddle.fluid.executor import Executor, global_scope | ||
from paddle.fluid.framework import in_dygraph_mode | ||
from paddle.fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_ | ||
from paddle.fluid.framework import _current_expected_place as _get_device | ||
from paddle.fluid.dygraph import no_grad | ||
import paddle.utils.deprecated as deprecated | ||
|
@@ -1426,11 +1426,28 @@ def transform(t, device, dtype, blocking): | |
if dtype is None: | ||
dtype = t.dtype | ||
|
||
new_t = t._copy_to(device, blocking) | ||
if dtype is not None and dtype != t.dtype: | ||
new_t = new_t.cast(dtype=dtype) | ||
if isinstance(t, framework.ParamBase): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
state = copy.deepcopy(t.__dict__) | ||
new_param = framework.ParamBase(t.shape, dtype, **state) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 能否给ParamBase添加一个_copy_to的方法覆盖原来的? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已经添加了 |
||
core.varbase_copy(t, new_param, device, blocking) | ||
|
||
if dtype is not None and dtype != t.dtype: | ||
framework._dygraph_tracer().trace_op( | ||
type='cast', | ||
inputs={'X': new_param}, | ||
outputs={'Out': new_param}, | ||
attrs={ | ||
'in_dtype': t.dtype, | ||
'out_dtype': convert_np_dtype_to_dtype_(dtype) | ||
}) | ||
|
||
return new_param | ||
else: | ||
new_t = t._copy_to(device, blocking) | ||
if dtype is not None and dtype != t.dtype: | ||
new_t = new_t.cast(dtype=dtype) | ||
|
||
return new_t | ||
return new_t | ||
|
||
self._apply(transform, device, dtype, blocking) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add an else branch and throw an error to avoid a bug.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done