Skip to content

Commit

Permalink
【NewIR】add add_n_grad and split_with_num_grad (Split op ) (#56873)
Browse files Browse the repository at this point in the history
* add reference of lbfgs

* add reference of lbfgs

* tmp

* split gen modify

* fix conflict

* add split

* fix bug

* fix bug

* test split

* add meta tensor

* refine code

* fix bug

* fix bug

* fix comflict

* Call _C_ops.sum in new ir

* modify concat kernel choose

* modify ci

* modify sum zero_dim optest

* modify split_with_num api

* modify split -1

* modify split test

* fix bug

* xxx

* delete extra modify

* add add_n

* tmp

* add split_with_num_grad

* modify split grad num bug

* modify ci

* modify ci

* clear code

* modify

* recover

* add add_n stop_gradient infer

* modify opreslut to value

* fix conflict

* recover to aviod conflict

* recover to aviod conflict

* modify opreslut to value

* recover complex tanh

* modify add_n optest

* skip bfp16

* modify split bf16

* fix conflict

* delete print

---------

Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
Co-authored-by: 0x45f <wangzhen45@baidu.com>
  • Loading branch information
3 people authored Sep 19, 2023
1 parent c792e47 commit 0029a24
Show file tree
Hide file tree
Showing 17 changed files with 239 additions and 38 deletions.
7 changes: 6 additions & 1 deletion paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,12 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
'bool': 'pir::BoolAttribute',
}

PD_MANUAL_OP_LIST = {'add_n', 'add_n_', 'add_n_with_kernel', 'split_grad'}
PD_MANUAL_OP_LIST = {
'add_n',
'add_n_',
'add_n_with_kernel',
'split_grad',
}


def to_phi_and_fluid_op_name(op_item):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"add",
"concat",
"split",
"split_with_num",
"gelu",
"matmul",
"erf",
Expand All @@ -45,6 +46,7 @@
'layer_norm',
'reshape',
'cast',
"scale",
'softmax',
'silu',
'elementwise_pow',
Expand All @@ -53,7 +55,6 @@
'slice',
'transpose',
'slice_double',
'scale',
]
vjp_interface_implementation_gen_op_list = [
"tanh",
Expand All @@ -63,6 +64,7 @@
"add",
"concat",
"split",
"split_with_num",
"gelu",
"matmul",
"erf",
Expand All @@ -77,6 +79,7 @@
'layer_norm',
'reshape',
'cast',
"scale",
'softmax',
'silu',
'elementwise_pow',
Expand All @@ -85,5 +88,4 @@
'slice',
'transpose',
'slice_double',
'scale',
]
30 changes: 30 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ pir::OpResult builtin_combine(const std::vector<pir::Value>& x) {
return combine_op.out();
}

std::vector<pir::OpResult> add_n_grad(std::vector<pir::Value> inputs,
pir::Value out_grad) {
std::vector<pir::OpResult> inputs_grad;
for (size_t i = 0; i < inputs.size(); i++) {
paddle::dialect::ScaleOp scale_op =
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::ScaleOp>(
out_grad, 1.0, 0.0, true);
inputs_grad.push_back(scale_op.result(0));
}
return inputs_grad;
}

pir::OpResult zeros_like(pir::Value x,
phi::DataType dtype,
const Place& place) {
Expand Down Expand Up @@ -76,5 +88,23 @@ pir::OpResult embedding_grad(pir::Value x,
}
}

pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, int axis) {
auto out_grad_combine_op =
APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(out_grad);
paddle::dialect::SplitGradOp split_grad_op =
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SplitGradOp>(
out_grad_combine_op.out(), axis);
return split_grad_op.result(0);
}

pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad,
pir::Value axis) {
auto out_grad_combine_op =
APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(out_grad);
paddle::dialect::SplitGradOp split_grad_op =
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SplitGradOp>(
out_grad_combine_op.out(), axis);
return split_grad_op.result(0);
}
} // namespace dialect
} // namespace paddle
7 changes: 7 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ namespace dialect {

pir::OpResult builtin_combine(const std::vector<pir::Value>& x);

std::vector<pir::OpResult> add_n_grad(std::vector<pir::Value> inputs,
pir::Value out_grad);

pir::OpResult zeros_like(pir::Value x,
phi::DataType dtype = phi::DataType::UNDEFINED,
const Place& place = {});
Expand All @@ -41,5 +44,9 @@ pir::OpResult embedding_grad(pir::Value x,
int64_t padding_idx = -1,
bool sparse = false);

pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad, int axis);

pir::OpResult split_with_num_grad(std::vector<pir::Value> out_grad,
pir::Value axis);
} // namespace dialect
} // namespace paddle
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
Expand Down Expand Up @@ -148,6 +149,7 @@ void AddNOp::Build(pir::Builder &builder, // NOLINT
dense_out.offset());
argument_outputs.push_back(out_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);
}

void AddNOp::InferMeta(phi::InferMetaContext *infer_meta) {
Expand Down
17 changes: 9 additions & 8 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef GET_MANUAL_OP_LIST
#undef GET_MANUAL_OP_LIST
paddle::dialect::AddNOp, paddle::dialect::SplitGradOp, paddle::dialect::IfOp

#else

#pragma once
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
#include "paddle/fluid/pir/dialect/operator/trait/inplace.h"
#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
Expand All @@ -36,7 +31,10 @@ paddle::dialect::AddNOp, paddle::dialect::SplitGradOp, paddle::dialect::IfOp
namespace paddle {
namespace dialect {

class AddNOp : public pir::Op<AddNOp, OpYamlInfoInterface, InferMetaInterface> {
class AddNOp : public pir::Op<AddNOp,
paddle::dialect::OpYamlInfoInterface,
paddle::dialect::InferMetaInterface,
paddle::dialect::VjpInterface> {
public:
using Op::Op;
static const char *name() { return "pd_op.add_n"; }
Expand All @@ -51,6 +49,10 @@ class AddNOp : public pir::Op<AddNOp, OpYamlInfoInterface, InferMetaInterface> {
pir::Value inputs() { return operand_source(0); }
pir::OpResult out() { return result(0); }
static void InferMeta(phi::InferMetaContext *infer_meta);
static std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation *op,
const std::vector<std::vector<pir::Value>> &out_grads,
const std::vector<std::vector<bool>> &stop_gradients);
};

class AddN_Op : public pir::Op<AddN_Op,
Expand Down Expand Up @@ -201,4 +203,3 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
#endif
44 changes: 44 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,49 @@ namespace paddle {
namespace dialect {
using IntArray = paddle::experimental::IntArray;

std::vector<std::vector<pir::OpResult>> AddNOp::Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
AddNOp op_obj = op->dyn_cast<AddNOp>();

VLOG(6) << "Prepare inputs of add_n_grad";

pir::CombineOp combine_op_obj = op_obj.inputs()
.dyn_cast<pir::OpResult>()
.owner()
->dyn_cast<pir::CombineOp>();
std::vector<Tensor> inputs;
for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {
inputs.emplace_back(
std::make_shared<primitive::LazyTensor>(combine_op_obj.inputs()[idx]));
}

Tensor out_grad(std::make_shared<primitive::LazyTensor>(out_grads[0][0]));

VLOG(6) << "Vjp prepare Prepare attributes of add_n_grad";

VLOG(6) << "Vjp prepare call add_n's vjp inteface";

std::vector<std::vector<Tensor>> tensor_res =
primitive::add_n_vjp(inputs, out_grad, stop_gradients);

VLOG(6) << "Vjp prepare stop gradient of add_n_grad";

std::vector<std::vector<pir::OpResult>> res(tensor_res.size());
for (size_t i = 0; i < tensor_res.size(); ++i) {
res[i].resize(tensor_res[i].size());
for (size_t j = 0; j < tensor_res[i].size(); ++j) {
if (tensor_res[i][j].defined()) {
res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(
tensor_res[i][j].impl())
->value()
.dyn_cast<pir::OpResult>();
}
}
}
return res;
}

} // namespace dialect
} // namespace paddle
5 changes: 5 additions & 0 deletions paddle/fluid/primitive/backend/manual/manual_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <vector>

#include "paddle/phi/api/include/tensor.h"
#include "paddle/utils/optional.h"

namespace paddle {
namespace primitive {
Expand All @@ -28,6 +29,10 @@ using Scalar = paddle::experimental::Scalar;
using IntArray = paddle::experimental::IntArray;
using DataType = phi::DataType;

template <typename T>
std::vector<Tensor> add_n_grad(const std::vector<Tensor>& x,
const Tensor& out_grad);

} // namespace backend
} // namespace primitive
} // namespace paddle
22 changes: 22 additions & 0 deletions paddle/fluid/primitive/backend/manual/manual_static_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/dialect/operator/ir/manual_api.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h"
#include "paddle/fluid/primitive/backend/generated/generated_backend.h"
#include "paddle/fluid/primitive/backend/manual/manual_backend.h"
#include "paddle/fluid/primitive/primitive/primitive.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
Expand All @@ -22,6 +24,26 @@ namespace primitive {
namespace backend {

using LazyTensor = paddle::primitive::LazyTensor;
template <>
std::vector<Tensor> add_n_grad<LazyTensor>(const std::vector<Tensor>& x,
const Tensor& out_grad) {
std::vector<pir::Value> x_res(x.size());
std::transform(x.begin(), x.end(), x_res.begin(), [](const Tensor& t) {
return std::static_pointer_cast<LazyTensor>(t.impl())->value();
});
pir::Value out_grad_res =
std::static_pointer_cast<LazyTensor>(out_grad.impl())->value();
auto op_res = paddle::dialect::add_n_grad(x_res, out_grad_res);

std::vector<Tensor> x_grad(op_res.size());
std::transform(op_res.begin(),
op_res.end(),
x_grad.begin(),
[](const pir::OpResult& res) {
return Tensor(std::make_shared<LazyTensor>(res));
});
return x_grad;
}

} // namespace backend
} // namespace primitive
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/primitive/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
'sum_grad',
'concat_grad',
'split_grad',
'split_with_num_grad',
'gelu_grad',
'softmax_grad',
'silu_grad',
Expand Down Expand Up @@ -97,6 +98,7 @@
'sum_grad',
'concat_grad',
'split_grad',
'split_with_num_grad',
'gelu_grad',
'softmax_grad',
'silu_grad',
Expand Down Expand Up @@ -145,6 +147,7 @@
'slice',
'layer_norm_grad',
'embedding_grad',
'sqrt',
'uniform',
]

Expand Down
18 changes: 17 additions & 1 deletion paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,21 @@
#include "paddle/pir/core/operation.h"

namespace paddle {
namespace primitive {} // namespace primitive
namespace primitive {

std::vector<std::vector<paddle::Tensor>> add_n_vjp(
const std::vector<paddle::Tensor>& x,
const Tensor& out_grad,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res;
for (auto arg : stop_gradients) {
vjp_res.push_back(std::vector<paddle::Tensor>(arg.size()));
}
auto op_res = backend::add_n_grad<LazyTensor>(x, out_grad);
vjp_res[0] = op_res;
vjp_res = ConstructVjpResultByStopGradients(vjp_res, stop_gradients);
return vjp_res;
}

} // namespace primitive
} // namespace paddle
4 changes: 4 additions & 0 deletions paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,9 @@ namespace paddle {
namespace primitive {

using IntArray = paddle::experimental::IntArray;
std::vector<std::vector<paddle::Tensor>> add_n_vjp(
const std::vector<paddle::Tensor>& x,
const Tensor& out_grad,
const std::vector<std::vector<bool>>& stop_gradients);
} // namespace primitive
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
out->set_dims(in_dim);
}
out->share_lod(*x[0]);
out->set_dtype(x[0]->dtype());
}

// TODO(YuanRisheng) This InferMeta is used in Fluid
Expand Down
17 changes: 17 additions & 0 deletions python/paddle/autograd/ir_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,23 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
union_op_flags[i] = False
intersection_op_flags[i] = False

# some inputs in no_grad_set but its next op is effective,
# add their defining op here.
total_ops_list = list(total_ops)
for i, op in enumerate(total_ops_list):
if union_op_flags[i] is False:
for result in op.results():
if result.has_one_use():
next_op = result.first_use().owner()
if (
next_op in total_ops
and union_op_flags[total_ops_list.index(next_op)]
is True
):
union_op_flags[i] = True
else:
continue

effective_ops = [
total_ops[i] for i in range(len(total_ops)) if intersection_op_flags[i]
]
Expand Down
Loading

0 comments on commit 0029a24

Please sign in to comment.