Skip to content

Commit

Permalink
fix conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo9674 committed Sep 15, 2023
2 parents d38b297 + ac77cff commit b883667
Show file tree
Hide file tree
Showing 362 changed files with 8,413 additions and 6,738 deletions.
2 changes: 1 addition & 1 deletion .clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ cppcoreguidelines-avoid-c-arrays,
cppcoreguidelines-c-copy-assignment-signature,
cppcoreguidelines-explicit-virtual-functions,
-cppcoreguidelines-init-variables,
-cppcoreguidelines-narrowing-conversions,
cppcoreguidelines-narrowing-conversions,
-cppcoreguidelines-no-malloc,
-cppcoreguidelines-pro-type-const-cast,
-cppcoreguidelines-pro-type-member-init,
Expand Down
9 changes: 7 additions & 2 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
select = C,E,W
exclude =
./build,
# Exclude fluid directory
./python/paddle/base/**,
# Exclude third-party libraries
./third_party/**,
./python/paddle/utils/gast/**,
Expand All @@ -27,3 +25,10 @@ ignore =
per-file-ignores =
# These files need tabs for testing.
test/dygraph_to_static/test_error.py:E101,W191

# temp ignore base directory
python/paddle/base/*:
E713,
E712,
E266,
E714
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ repos:
- id: flake8
args: ["--config=.flake8"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.272
rev: v0.0.289
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --no-cache]
Expand Down
22 changes: 2 additions & 20 deletions paddle/cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,31 +285,13 @@ void CodeGenC::Visit(const ir::Select *op) {
void CodeGenC::Visit(const ir::IfThenElse *op) {
str_ += "if (";
IrPrinter::Visit(op->condition);
str_ += ") {\n";
str_ += ") ";

if (!op->true_case.As<ir::Block>()) IncIndent();
DoIndent();
IrPrinter::Visit(op->true_case);
if (!op->true_case.As<ir::Block>()) str_ += ";";
str_ += "\n";

if (!op->true_case.As<ir::Block>()) DecIndent();

DoIndent();
str_ += "}";

if (op->false_case.defined()) {
str_ += " else {\n";

if (!op->true_case.As<ir::Block>()) IncIndent();
DoIndent();
str_ += " else ";
IrPrinter::Visit(op->false_case);
if (!op->false_case.As<ir::Block>()) str_ += ";";
str_ += "\n";
if (!op->true_case.As<ir::Block>()) DecIndent();

DoIndent();
str_ += "}";
}
}
void CodeGenC::Visit(const ir::Block *op) {
Expand Down
4 changes: 0 additions & 4 deletions paddle/cinn/backends/ir_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -794,10 +794,8 @@ void test_simple_compute_at(void* _args, int32_t num_args)
for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) {
for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) {
if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) {
{
B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)];
C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)];
}
};
};
};
Expand Down Expand Up @@ -869,10 +867,8 @@ void test_compute_at0(void* _args, int32_t num_args)
for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) {
for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) {
if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) {
{
B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)];
C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)];
}
};
};
};
Expand Down
8 changes: 6 additions & 2 deletions paddle/cinn/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ Expr For::Make(Var loop_var,
node->min = min;
node->extent = extent;
node->device_api = device_api;
node->body = body;
node->body = body.As<ir::Block>() ? body : ir::Block::Make({body});
node->set_for_type(for_type);
node->set_vectorize_info(vector_info);
node->set_bind_info(bind_info);
Expand Down Expand Up @@ -346,6 +346,10 @@ std::vector<const Expr *> ScheduleBlockRealize::expr_fields() const {
}

Expr IfThenElse::Make(Expr condition, Expr true_case, Expr false_case) {
if (true_case.defined() && (!true_case.As<Block>()))
true_case = ir::Block::Make({true_case});
if (false_case.defined() && (!false_case.As<Block>()))
false_case = ir::Block::Make({false_case});
auto node = make_shared<IfThenElse>(condition, true_case, false_case);
return Expr(node);
}
Expand Down Expand Up @@ -513,7 +517,7 @@ Expr PolyFor::Make(Var iterator,
n->condition = condition;
n->inc = inc;
n->device_api = device_api;
n->body = body;
n->body = body.As<ir::Block>() ? body : ir::Block::Make({body});
n->set_for_type(for_type);
n->set_vectorize_info(vectorize_info);
n->set_bind_info(bind_info);
Expand Down
18 changes: 2 additions & 16 deletions paddle/cinn/ir/utils/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,26 +229,12 @@ void IrPrinter::Visit(const PolyFor *x) {
void IrPrinter::Visit(const IfThenElse *x) {
str_ += "if (";
Visit(x->condition);
str_ += ") {\n";
IncIndent();
DoIndent();
str_ += ") ";
Visit(x->true_case);
DecIndent();
str_ += "\n";
DoIndent();
str_ += "}";

if (x->false_case.defined()) {
str_ += " else {\n";
IncIndent();

DoIndent();
str_ += " else ";
Visit(x->false_case);
str_ += "\n";

DecIndent();
DoIndent();
str_ += "}";
}
}
void IrPrinter::Visit(const Block *x) {
Expand Down
17 changes: 0 additions & 17 deletions paddle/cinn/optim/ir_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,23 +306,6 @@ struct SimplifyBlocksMutator : public ir::IRMutator<> {
}
}

void Visit(const IfThenElse* op, Expr* expr) override {
auto* node = expr->As<IfThenElse>();
Visit(&node->condition, &node->condition);
if (node->true_case.As<Block>() &&
(node->true_case.As<Block>()->stmts.size() == 1)) {
node->true_case = node->true_case.As<Block>()->stmts[0];
}
Visit(&node->true_case, &node->true_case);
if (node->false_case.defined()) {
if (node->false_case.As<Block>() &&
(node->false_case.As<Block>()->stmts.size() == 1)) {
node->false_case = node->false_case.As<Block>()->stmts[0];
}
Visit(&node->false_case, &node->false_case);
}
}

void Visit(const ScheduleBlock* op, Expr* expr) override {
auto* node = expr->As<ScheduleBlock>();
CHECK(node);
Expand Down
5 changes: 4 additions & 1 deletion paddle/cinn/utils/attribute_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "paddle/cinn/utils/type_defs.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/builtin_type.h"

namespace cinn {
Expand Down Expand Up @@ -61,7 +62,9 @@ AttributeMap ConvertAttributes(const NewIR_AttributeMap& src_attrs) {
AttributeMap dst_attrs;
for (auto& item : src_attrs) {
VLOG(4) << "deal with " << item.first;
if (item.second.isa<paddle::dialect::PlaceAttribute>()) {
if (item.first == ::pir::kStopGradientAttrName) {
continue;
} else if (item.second.isa<paddle::dialect::PlaceAttribute>()) {
auto is_cpu =
item.second.dyn_cast<paddle::dialect::PlaceAttribute>().data() ==
phi::CPUPlace();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ std::vector<DimTrans*> MakeReshapeDimTrans(

if (tgt_splitted_shape.size() > 0) {
std::vector<DimTrans*> input_dims;
for (int64_t i = 0, n = src_dims.size(); i < n; i++) {
for (int i = 0, n = static_cast<int>(src_dims.size()); i < n; i++) {
int64_t in_dim = src_dims[i];
if (src_shape[in_dim] > 1) {
input_dims.emplace_back(new InputDim(in_dim));
Expand All @@ -141,7 +141,7 @@ paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward(
const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: Verify Input Args Based on Reshape Logic
int64_t ninputs = input_specs.size();
int64_t ninputs = static_cast<int64_t>(input_specs.size());
PADDLE_ENFORCE_EQ(
ninputs,
1,
Expand Down
7 changes: 4 additions & 3 deletions paddle/fluid/eager/auto_code_generator/eager_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -949,15 +949,15 @@ static bool CollectGradInformationFromOpInfo(
op_base_infos->resize(grad_node->size());
for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) {
// Each OpBase
int index = std::distance(grad_node->begin(), iter);
int index = static_cast<int>(std::distance(grad_node->begin(), iter));
paddle::imperative::OpBase& op_base = *iter;
(*op_base_infos)[index].SetOpBaseType(op_base.Type());
}

/* ------ Get Grad ins/outs/attrs ---- */
VLOG(6) << "In function size: " << grad_node->size();
for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) {
int index = std::distance(grad_node->begin(), iter);
int index = static_cast<int>(std::distance(grad_node->begin(), iter));
auto* op_base_grad_ins = (*op_base_infos)[index].GetMutableGradIns();
auto* op_base_grad_outs = (*op_base_infos)[index].GetMutableGradOuts();
auto* op_base_grad_attrs = (*op_base_infos)[index].GetMutableGradAttrs();
Expand Down Expand Up @@ -3160,7 +3160,8 @@ static void DygraphCodeGeneration(const std::string& output_dir,
op_info_map_need_gen.emplace(pair);
}

int each_cc_file_api_size = op_info_map_need_gen.size() / split_count;
int each_cc_file_api_size =
static_cast<int>(op_info_map_need_gen.size() / split_count);
if (op_info_map_need_gen.size() % split_count != 0) {
each_cc_file_api_size++;
}
Expand Down
25 changes: 24 additions & 1 deletion paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ class {} : public egr::GradNodeBase {{
// Prepare Grad function call
{}
// Runtime check if we need next grad
{}
// Set DistAttr of Out Tensor for semi-auto parallel
{}
// Inplace Check
{}
Expand Down Expand Up @@ -529,6 +531,12 @@ class {} : public egr::GradNodeBase {{
if( !{}.empty() ) {}_optional = paddle::make_optional<std::vector<paddle::Tensor>>({});
"""

SET_GRAD_OUT_DIST_ATTR_TEMPLATE = """
if (IsRunAutoParallel()) {{
egr::EagerUtils::SetGradOutputDistAttr(out_metas, {}, {});
}}
"""

CHECK_BACKWARD_INPLACE_TEMPLATE = """
bool can_be_inplaced = false;
if ({}.initialized()) {{
Expand Down Expand Up @@ -1088,7 +1096,7 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False):
for name, (ttype, pos) in forward_inputs_position_map.items():
if name in need_pre_contiguous_set:
pre_contiguous_list.append(
f"{indent}const auto& {name}_tmp = (require_any_grad && {name}.is_dense_tensor() && !std::dynamic_pointer_cast<phi::DenseTensor>({name}.impl())->meta().is_contiguous()) ? paddle::Tensor(std::make_shared<phi::DenseTensor>(std::move(paddle::experimental::Trans2Contiguous(*(std::dynamic_pointer_cast<phi::DenseTensor>({name}.impl())))))) : {name};"
f"{indent}const auto& {name}_tmp = (require_any_grad && {name}.is_dense_tensor() && !std::dynamic_pointer_cast<phi::DenseTensor>({name}.impl())->meta().is_contiguous()) ? paddle::Tensor(std::make_shared<phi::DenseTensor>(std::move(paddle::experimental::Trans2Contiguous(*(std::dynamic_pointer_cast<phi::DenseTensor>({name}.impl()))))), {name}.mutable_autograd_meta()) : {name};"
)
self.inputs_call_list_tmp[pos] = (
self.inputs_call_list_tmp[pos] + '_tmp'
Expand Down Expand Up @@ -2181,6 +2189,8 @@ def GenerateNodeDefinition(
)
grad_api_args = ["" for i in range(grad_api_args_len)]
get_grad_in_args_list = []
grad_api_out_args_list = []
fwd_positions_list = []

# Fill Grad Ins with Zero
fill_zero_str = ""
Expand Down Expand Up @@ -2388,6 +2398,8 @@ def GenerateNodeDefinition(
out_assign_str += f"{indent}*api_output_{out_index} = std::get<{out_index}>(api_output);\n"
else:
grad_api_args.append(f"api_output_{out_index}")
grad_api_out_args_list.append(f"api_output_{out_index}")
fwd_positions_list.append(f"{fwd_position}")
if inplace_grad_input_str in optional_inplace_var_name:
optional_inplace_str = "VLOG(6) << \"No Inplace should happend for wrappered input: {inplace_grad_input_str}\";"
else:
Expand Down Expand Up @@ -2433,6 +2445,16 @@ def GenerateNodeDefinition(
composite_grad_api_args_str = ", ".join(grad_api_args)
composite_template_name = "<paddle::Tensor>"

# Set DistAttr Func Construct
set_out_dist_attr_str = ""
if not is_invoke_forward_api:
fwd_positions_str = "{" + ", ".join(fwd_positions_list) + "}"
grad_api_out_args_str = ", ".join(grad_api_out_args_list)
set_out_dist_attr_str = SET_GRAD_OUT_DIST_ATTR_TEMPLATE.format(
fwd_positions_str,
grad_api_out_args_str,
)

if is_invoke_forward_api:
autograd_api_out = "auto"
if (
Expand Down Expand Up @@ -2600,6 +2622,7 @@ def GenerateNodeDefinition(
get_grad_in_args_str,
grad_function_prepare_str,
compute_require_next_grad_str,
set_out_dist_attr_str,
inplace_check_str,
inplace_for_grad_outs_str,
self.backward_api_name,
Expand Down
Loading

0 comments on commit b883667

Please sign in to comment.