
Commit e292842

Revert "[Cherry-pick] fix set_value with scalar grad (PaddlePaddle#60930)"

This reverts commit 1aa5f4b.
hanhaowen-mt committed May 13, 2024
1 parent 4430ca2 commit e292842
Showing 13 changed files with 118 additions and 530 deletions.
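
For context, the reverted cherry-pick targeted the gradient of Tensor.__setitem__ when the assigned value is a Python scalar. A minimal dygraph sketch of the scenario (illustrative only; the expected values follow from set_value's gradient semantics, under which the overwritten region receives zero gradient):

import paddle

x = paddle.ones([4], dtype='float32')
x.stop_gradient = False

y = x * 2      # y depends on x at every position
y[1] = 0.0     # scalar __setitem__, lowered to the set_value op
y.sum().backward()

# The overwritten slot should contribute no gradient:
# expected x.grad == [2., 0., 2., 2.]
print(x.grad)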
44 changes: 25 additions & 19 deletions paddle/fluid/operators/set_value_op.cc
@@ -151,26 +151,32 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
 
  protected:
   void Apply(GradOpPtr<T> op) const override {
-    op->SetType("set_value_grad");
-    op->SetInput("ValueTensor", this->Input("ValueTensor"));
-    op->SetOutput(framework::GradVarName("ValueTensor"),
-                  this->InputGrad("ValueTensor"));
-
-    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-
-    if (this->HasInput("StartsTensorList")) {
-      op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
-    }
-    if (this->HasInput("EndsTensorList")) {
-      op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
+    if (this->HasInput("ValueTensor")) {
+      op->SetType("set_value_grad");
+
+      op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+      op->SetInput("ValueTensor", this->Input("ValueTensor"));
+      if (this->HasInput("StartsTensorList")) {
+        op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
+      }
+      if (this->HasInput("EndsTensorList")) {
+        op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
+      }
+      if (this->HasInput("StepsTensorList")) {
+        op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
+      }
+
+      op->SetAttrMap(this->Attrs());
+
+      op->SetOutput(framework::GradVarName("ValueTensor"),
+                    this->InputGrad("ValueTensor"));
+      op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
+
+    } else {
+      op->SetType("assign");
+      op->SetInput("X", this->OutputGrad("Out"));
+      op->SetOutput("Out", this->InputGrad("Input"));
     }
-    if (this->HasInput("StepsTensorList")) {
-      op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
-    }
-
-    op->SetAttrMap(this->Attrs());
-
-    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
   }
 };
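The hunk above restores the two-path gradient maker: set_value_grad when a ValueTensor input is present, and a plain assign of the output gradient otherwise. A hypothetical NumPy illustration (not Paddle API) of what the two paths compute for a 4-element input whose position 1 was overwritten:

import numpy as np

# Gradient of y = set_value(x, value) w.r.t. x is the upstream gradient with
# the assigned region masked out; the assign fallback skips the masking.
upstream = np.ones(4, dtype=np.float32)
mask = np.array([1., 0., 1., 1.], dtype=np.float32)  # position 1 was assigned

dx_set_value_grad = upstream * mask   # [1., 0., 1., 1.]
dx_assign_fallback = upstream.copy()  # [1., 1., 1., 1.]: gradient leaks into the set region

print(dx_set_value_grad, dx_assign_fallback)
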
108 changes: 45 additions & 63 deletions paddle/fluid/pybind/eager_method.cc
@@ -1375,7 +1375,7 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
 
   // step3: Dealing with advanced indexing
   std::vector<paddle::Tensor> transed_index;
-  std::vector<int> trans_back_dim, trans_dim;
+  std::vector<int> trans_back_dim;
   int pos_of_new_dim = INT_MAX, rank_of_new_dim = 1;
 
   paddle::Tensor transed_tensor = dealWithAdvancedIndex(out,
@@ -1385,8 +1385,7 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
                                                         &transed_index,
                                                         &trans_back_dim,
                                                         &pos_of_new_dim,
-                                                        &rank_of_new_dim,
-                                                        &trans_dim);
+                                                        &rank_of_new_dim);
 
   if (transed_index.size() == 1 &&
       transed_index[0].dtype() == phi::DataType::BOOL) {
@@ -1608,70 +1607,58 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
                     &use_strided_slice);
 
   // step2: Parse values
-  std::vector<phi::Scalar> values;
+  PADDLE_ENFORCE(
+      PyCheckTensor(value_obj),
+      platform::errors::InvalidArgument("The value must be a Tensor"));
+
   paddle::Tensor value_tensor =
-      dealWithValues(tensor, value_obj, &values, has_advanced_index);
+      reinterpret_cast<TensorObject*>(value_obj)->tensor;
 
   if (!has_advanced_index) {
     // use set_value OP if there is no advanced index
 
     // Release gil and do tracing
     py::gil_scoped_release release;
-    // use inplace set_value_ operator
-    if (value_tensor.initialized()) {
-      if (egr::Controller::Instance().GetAMPLevel() !=
-          paddle::imperative::AmpLevel::O0) {
-        paddle::small_vector<std::vector<paddle::Tensor>,
-                             egr::kSlotSmallVectorSize>
-            tmps = {{self->tensor}, {value_tensor}};
-        auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
-        self->tensor = egr::EagerAmpAutoCast(
-            self->tensor.name(), self->tensor, amp_dtype, "set_value");
-        value_tensor = egr::EagerAmpAutoCast(
-            value_tensor.name(), value_tensor, amp_dtype, "set_value");
-      }
-      if (self->tensor.dtype() != value_tensor.dtype()) {
-        value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
-      }
-
-      // step3.1: Only basic indexing, use OP set_value.
-      const phi::distributed::ProcessMesh* mesh = nullptr;
-      if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) {
-        ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor);
-      }
-      self->tensor = set_value_with_tensor__ad_func(self->tensor,
-                                                    value_tensor,
-                                                    slice_starts,
-                                                    slice_ends,
-                                                    slice_strides,
-                                                    slice_axes,
-                                                    decrease_axis,
-                                                    none_axes);
-      if (PyCheckTensor(value_obj)) {
-        // pass the stop_gradient from value to tensor.
-        // pass stop gradient should be done after CheckInplace in
-        // set_value__dygraph_function.
-        if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
-            egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
-          egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
-        }
-      }
-    } else {
-      const phi::distributed::ProcessMesh* mesh = nullptr;
-      if (InputsContainDistTensor(&mesh, self->tensor)) {
-        ConvertAllInputsToDistTensor(mesh, self->tensor);
-      }
-      self->tensor = set_value__ad_func(self->tensor,
-                                        slice_starts,
-                                        slice_ends,
-                                        slice_strides,
-                                        slice_axes,
-                                        decrease_axis,
-                                        none_axes,
-                                        {1},
-                                        values);
+    if (value_tensor.initialized() &&
+        (self->tensor.dtype() != value_tensor.dtype())) {
+      if (egr::Controller::Instance().GetAMPLevel() !=
+          paddle::imperative::AmpLevel::O0) {
+        paddle::small_vector<std::vector<paddle::Tensor>,
+                             egr::kSlotSmallVectorSize>
+            tmps = {{self->tensor}, {value_tensor}};
+        auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
+        self->tensor = egr::EagerAmpAutoCast(
+            self->tensor.name(), self->tensor, amp_dtype, "set_value");
+        value_tensor = egr::EagerAmpAutoCast(
+            value_tensor.name(), value_tensor, amp_dtype, "set_value");
+      }
+      if (self->tensor.dtype() != value_tensor.dtype()) {
+        value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
+      }
+    }
+
+    // step3.1: Only basic indexing, use OP set_value.
+    const phi::distributed::ProcessMesh* mesh = nullptr;
+    if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) {
+      ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor);
+    }
+    self->tensor = set_value_with_tensor__ad_func(self->tensor,
+                                                  value_tensor,
+                                                  slice_starts,
+                                                  slice_ends,
+                                                  slice_strides,
+                                                  slice_axes,
+                                                  decrease_axis,
+                                                  none_axes);
+    if (PyCheckTensor(value_obj)) {
+      // pass the stop_gradient from value to tensor.
+      // pass stop gradient should be done after CheckInplace in
+      // set_value__dygraph_function.
+      if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
+          egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
+        egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
+      }
     }
   } else {
     // step3.2: Case for there are advanced indexing.
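
The stop_gradient hand-off in the hunk above is observable from Python. A small sketch (assuming dygraph mode; freshly created tensors default to stop_gradient=True):

import paddle

x = paddle.zeros([3, 3])    # defaults to stop_gradient=True
v = paddle.ones([3])
v.stop_gradient = False

x[0] = v                    # tensor value: dispatched to set_value_with_tensor_
# The target inherits "needs grad" from the value, keeping the assignment
# differentiable with respect to v.
print(x.stop_gradient)      # expected: False
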
@@ -1692,9 +1679,9 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
                           &use_strided_slice);
 
     std::vector<paddle::Tensor> transed_index;
-    std::vector<int> trans_back_dim, trans_dim;
+    std::vector<int> trans_back_dim;
 
-    int pos_of_new_dim = INT_MAX, rank_of_new_dim = 1;
+    int pos_of_new_dim = 0, rank_of_new_dim = 0;
 
     paddle::Tensor transed_sub_tensor =
         dealWithAdvancedIndex(sub_tensor,
@@ -1704,8 +1691,7 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
                               &transed_index,
                               &trans_back_dim,
                               &pos_of_new_dim,
-                              &rank_of_new_dim,
-                              &trans_dim);
+                              &rank_of_new_dim);
 
     // Release gil and do tracing
     py::gil_scoped_release release;
@@ -1728,10 +1714,6 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
       }
     }
 
-    if (value_tensor.dims().size() > 1 && pos_of_new_dim != 0) {
-      value_tensor = transpose_ad_func(value_tensor, trans_dim);
-    }
-
     // TODO(zoooo0820) 1.Using inplace version index_put
     // 2.Remove following code after backward bug fixed.
     transed_sub_tensor = assign_ad_func(transed_sub_tensor);
(The diff for the remaining 11 changed files is not shown.)

0 comments on commit e292842
