Skip to content

Commit

Permalink
fix set value grad
Browse files Browse the repository at this point in the history
  • Loading branch information
zoooo0820 committed Nov 15, 2023
1 parent 29eaaa1 commit 8a5f39f
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 23 deletions.
35 changes: 15 additions & 20 deletions paddle/fluid/operators/set_value_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,31 +151,26 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {

protected:
void Apply(GradOpPtr<T> op) const override {
if (this->HasInput("ValueTensor")) {
op->SetType("set_value_grad");
op->SetType("set_value_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
if (this->HasInput("StartsTensorList")) {
op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
}
if (this->HasInput("EndsTensorList")) {
op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
}
if (this->HasInput("StepsTensorList")) {
op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
}

op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("ValueTensor", this->Input("ValueTensor"));
if (this->HasInput("StartsTensorList")) {
op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
}
if (this->HasInput("EndsTensorList")) {
op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
}
if (this->HasInput("StepsTensorList")) {
op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
}
op->SetAttrMap(this->Attrs());

op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));

if (this->HasInput("ValueTensor")) {
op->SetInput("ValueTensor", this->Input("ValueTensor"));
op->SetOutput(framework::GradVarName("ValueTensor"),
this->InputGrad("ValueTensor"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));

} else {
op->SetType("assign");
op->SetInput("X", this->OutputGrad("Out"));
op->SetOutput("Out", this->InputGrad("Input"));
}
}
};
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -612,14 +612,14 @@

- backward_op : set_value_grad
forward : set_value (Tensor x, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes, int64_t[] shape, Scalar[] values) -> Tensor(out)
args : (Tensor out_grad)
args : (Tensor out_grad, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes)
output : Tensor(x_grad)
infer_meta:
func: UnchangedInferMeta
param: [out_grad]
kernel:
func: assign
param: [out_grad]
func: set_value_with_scalar_grad
param: [out_grad, starts, ends, steps, axes, decrease_axes, none_axes]

- backward_op : set_value_with_tensor_grad
forward: set_value_with_tensor (Tensor x, Tensor values, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes) -> Tensor(out)
Expand Down
17 changes: 17 additions & 0 deletions paddle/phi/kernels/cpu/set_value_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,20 @@ PD_REGISTER_KERNEL(set_value_grad,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(set_value_with_scalar_grad,
CPU,
ALL_LAYOUT,
phi::SetValueWithScalarGradKernel,
float,
double,
int,
int64_t,
bool,
int16_t,
uint8_t,
int8_t,
phi::dtype::bfloat16,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
17 changes: 17 additions & 0 deletions paddle/phi/kernels/gpu/set_value_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,20 @@ PD_REGISTER_KERNEL(set_value_grad,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(set_value_with_scalar_grad,
GPU,
ALL_LAYOUT,
phi::SetValueWithScalarGradKernel,
float,
double,
int,
int64_t,
bool,
int16_t,
uint8_t,
int8_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
93 changes: 93 additions & 0 deletions paddle/phi/kernels/impl/set_value_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -341,4 +341,97 @@ void SetValueGradKernel(const Context& dev_ctx,
}
}

template <typename T, typename Context>
void SetValueWithScalarGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const IntArray& starts,
const IntArray& ends,
const IntArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* x_grad) {
const int rank = out_grad.dims().size();

switch (rank) {
case 1:
SetValueGradImpl<T, Context, 1>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 2:
SetValueGradImpl<T, Context, 2>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 3:
SetValueGradImpl<T, Context, 3>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 4:
SetValueGradImpl<T, Context, 4>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 5:
SetValueGradImpl<T, Context, 5>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 6:
SetValueGradImpl<T, Context, 6>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"The rank of set_value_with_scalar_grad's input should be less than "
"7, but "
"received %d.",
rank));
}
}
} // namespace phi
10 changes: 10 additions & 0 deletions paddle/phi/kernels/set_value_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,14 @@ void SetValueGradKernel(const Context& dev_ctx,
DenseTensor* x_grad,
DenseTensor* value_grad);

template <typename T, typename Context>
void SetValueWithScalarGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const IntArray& starts,
const IntArray& ends,
const IntArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* x_grad);
} // namespace phi
103 changes: 103 additions & 0 deletions paddle/phi/kernels/xpu/set_value_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,100 @@ void SetValueGradKernel(const Context& dev_ctx,
}
}

template <typename T, typename Context>
void SetValueWithScalarGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const IntArray& starts,
const IntArray& ends,
const IntArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* x_grad) {
const int rank = out_grad.dims().size();

switch (rank) {
case 1:
SetValueGradImpl<T, Context, 1>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 2:
SetValueGradImpl<T, Context, 2>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 3:
SetValueGradImpl<T, Context, 3>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 4:
SetValueGradImpl<T, Context, 4>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 5:
SetValueGradImpl<T, Context, 5>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 6:
SetValueGradImpl<T, Context, 6>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"The rank of set_value_with_scalar_grad's input should be less than "
"7, but "
"received %d.",
rank));
}
}

} // namespace phi

PD_REGISTER_KERNEL(set_value_grad,
Expand All @@ -407,3 +501,12 @@ PD_REGISTER_KERNEL(set_value_grad,
phi::dtype::float16,
int,
int64_t) {}

PD_REGISTER_KERNEL(set_value_with_scalar_grad,
XPU,
ALL_LAYOUT,
phi::SetValueWithScalarGradKernel,
float,
phi::dtype::float16,
int,
int64_t) {}

0 comments on commit 8a5f39f

Please sign in to comment.