Commit
Merge branch 'PaddlePaddle:develop' into C37
Liyulingyue authored Mar 2, 2024
2 parents 2cc3bac + 6fccb8f commit 8f84ff6
Showing 75 changed files with 1,008 additions and 422 deletions.
86 changes: 73 additions & 13 deletions paddle/cinn/ast_gen_ius/ast_gen.cc
@@ -22,6 +22,7 @@
#include "paddle/cinn/optim/replace_var_with_expr.h"

PD_DECLARE_bool(cinn_new_group_scheduler);
PD_DECLARE_bool(group_schedule_tiling_first);
PD_DECLARE_bool(cinn_bucket_compile);

namespace cinn {
@@ -93,9 +94,21 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
std::vector<ir::Expr> iter_values;
// reduce body and reduce init schedule block should have different objects
// for same axis so we re-create objects
VLOG(4) << "FLAGS_group_schedule_tiling_first = "
<< FLAGS_group_schedule_tiling_first;
std::vector<Var> axis_vars = cinn::common::GenDefaultAxis(axis_len);
const std::vector<ir::Var>& reduce_axis = tensor->reduce_axis;
VLOG(4) << "ast gen: tensor init_body is " << init_body;
for (int i = 0; i < shape.size(); ++i) {
if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
bool is_keep_dim = axis[i]->is_keepdim;
if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
// If tiling first, replace the keepdim (reduced) axes with 0, but leave
// the non-reduce axes untouched.
optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0));
continue;
}
if (!FLAGS_group_schedule_tiling_first &&
FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0));
continue;
}
@@ -105,29 +118,41 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
/*is_reduce = */ false));
optim::ReplaceVarWithExpr(&init_body, axis[i], block_vars.back());
axis_vars[i]->is_reduce_axis = false;
if (shape[i] == Expr(1)) {
if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) {
iter_values.push_back(Expr(0));
} else {
iter_values.push_back(axis_vars[i]);
}
}
VLOG(4) << "iter_value.size() and block_vars.size() is "
<< iter_values.size() << " " << block_vars.size();
init_body = ir::ScheduleBlockRealize::Make(
iter_values,
ir::ScheduleBlock::Make(
block_vars, {}, {}, reduce_init_name, init_body));

// For the remaining reduce axis, make reduce body
const std::vector<ir::Var>& reduce_axis = tensor->reduce_axis;
ir::Expr reduce_body =
ConvertReduceBody(tensor->body(), tensor, axis_exprs);

VLOG(4) << "ast gen: reduce body is " << reduce_body;

// create schedule block itervars, i0,i1...
std::vector<ir::Var> reduce_block_vars;
std::vector<ir::Expr> reduce_iter_values;
// reduce body and reduce init schedule block should have different objects
// for same axis so we re-create objects
std::vector<Var> reduce_axis_vars = cinn::common::GenDefaultAxis(axis_len);
for (int i = 0; i < shape.size(); ++i) {
if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
bool is_keep_dim = axis[i]->is_keepdim;
if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
// If tiling first, replace the keepdim (reduced) axes with 0, but leave
// the non-reduce axes untouched.
optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0));
continue;
}
if (!FLAGS_group_schedule_tiling_first &&
FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0));
continue;
}
@@ -136,12 +161,13 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
cinn::UniqName("i" + std::to_string(i)),
/*is_reduce = */ false));
reduce_axis_vars[i]->is_reduce_axis = false;
if (shape[i] == Expr(1)) {
if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) {
reduce_iter_values.push_back(Expr(0));
} else {
reduce_iter_values.push_back(axis_vars[i]);
}
}
VLOG(4) << "ast gen: reduce body is after replace 0" << reduce_body;
for (int i = 0; i < reduce_axis.size(); ++i) {
int count = shape.size() + i;
reduce_block_vars.push_back(
@@ -155,14 +181,43 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
}

int non_zero_axis_size = 0;
for (int i = 0; i < axis.size(); ++i) {
if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
continue;
if (FLAGS_group_schedule_tiling_first) {
std::vector<ir::Var> non_reduce_axis_vars = [&]() {
std::vector<ir::Var> res;
for (int i = 0; i < shape.size(); ++i) {
bool is_keep_dim = axis[i]->is_keepdim;
if (!is_keep_dim) {
res.push_back(axis[i]);
}
}
return res;
}();
for (int i = 0; i < non_reduce_axis_vars.size(); ++i) {
optim::ReplaceVarWithExpr(
&reduce_body, non_reduce_axis_vars[i], reduce_block_vars[i]);
++non_zero_axis_size;
}
optim::ReplaceVarWithExpr(
&reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]);
++non_zero_axis_size;
} else {
for (int i = 0; i < axis.size(); ++i) {
if (!FLAGS_group_schedule_tiling_first &&
FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
continue;
}
optim::ReplaceVarWithExpr(
&reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]);
++non_zero_axis_size;
}
}

VLOG(4) << "to replace : " << non_zero_axis_size << " "
<< reduce_block_vars.size();
for (auto i = 0; i < reduce_block_vars.size(); i++) {
VLOG(4) << "reduce_block_vars[" << i << "] = " << reduce_block_vars[i];
}
for (auto i = 0; i < reduce_axis.size(); i++) {
VLOG(4) << "reduce_axis[" << i << "] = " << reduce_axis[i];
}
VLOG(4) << "before replace body: " << reduce_body;
for (int i = non_zero_axis_size; i < reduce_block_vars.size(); ++i) {
optim::ReplaceVarWithExpr(&reduce_body,
reduce_axis[i - non_zero_axis_size],
@@ -185,7 +240,12 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
// Put the two parts together
ir::Expr body = ir::Block::Make({init_body, reduce_body});
for (int i = static_cast<int>(axis_len) - 1; i >= 0; --i) {
if (!FLAGS_cinn_bucket_compile && shape[i] == Expr(1)) {
bool is_keep_dim = axis[i]->is_keepdim;
if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
continue;
}
if (!FLAGS_group_schedule_tiling_first && !FLAGS_cinn_bucket_compile &&
shape[i] == Expr(1)) {
continue;
}
ir::Var loop_var = axis[i];
@@ -210,7 +270,7 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
Expr(0), shape[i], cinn::UniqName("i" + std::to_string(i)), false));
optim::ReplaceVarWithExpr(&body, axis[i], block_vars[i]);
axis_vars[i]->is_reduce_axis = false;
if (shape[i] == Expr(1)) {
if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) {
iter_values.push_back(Expr(0));
} else {
iter_values.push_back(axis_vars[i]);
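
The change above applies one pattern three times: each axis of the schedule block is classified before lowering. Under the new tiling-first strategy (FLAGS_group_schedule_tiling_first), only keepdim axes -- the size-1 dimensions a reduction leaves behind -- are substituted with 0; on the older new-group-scheduler path, any size-1 axis is substituted with 0; everything else becomes a block var. A minimal sketch of that gate, with a hypothetical Axis struct and Action enum standing in for CINN's ir::Var and the ReplaceVarWithExpr call:

    struct Axis {
      bool is_keepdim = false;  // set for reduced dims kept as size 1
      int extent = 0;           // shape[i]; 1 means a unit dimension
    };

    enum class Action { kReplaceWithZero, kMakeBlockVar };

    // Mirrors only the control flow of the diff, not CINN's real types.
    Action ClassifyAxis(const Axis& axis,
                        bool tiling_first,           // FLAGS_group_schedule_tiling_first
                        bool new_group_scheduler) {  // FLAGS_cinn_new_group_scheduler
      if (tiling_first && axis.is_keepdim) {
        return Action::kReplaceWithZero;  // keepdim axes collapse to 0
      }
      if (!tiling_first && new_group_scheduler && axis.extent == 1) {
        return Action::kReplaceWithZero;  // legacy path: any unit axis collapses
      }
      return Action::kMakeBlockVar;
    }
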
5 changes: 5 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h
@@ -71,6 +71,11 @@ struct GroupInfoAttributeStorage : public pir::AttributeStorage {
static std::size_t HashValue(const ParamKey& key) {
size_t hash_value = std::hash<std::string>{}(key.group_id);

for (auto op : key.ops) {
hash_value =
pir::detail::hash_combine(hash_value, std::hash<void*>()(op));
}

for (auto d : key.loop_ranges) {
hash_value =
pir::detail::hash_combine(hash_value, std::hash<int64_t>()(d));
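
The new loop folds each op pointer into the storage hash exactly the way the loop ranges below it are folded. The sketch here assumes pir::detail::hash_combine follows the common boost-style combiner; the actual implementation may differ:

    #include <cstddef>
    #include <functional>
    #include <vector>

    // Assumed boost-style combiner (not pir's actual source).
    inline std::size_t HashCombine(std::size_t seed, std::size_t value) {
      return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
    }

    // Fold a list of op pointers into a running hash, as HashValue now does.
    std::size_t HashOps(std::size_t seed, const std::vector<void*>& ops) {
      for (void* op : ops) {
        seed = HashCombine(seed, std::hash<void*>()(op));
      }
      return seed;
    }
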
@@ -252,7 +252,7 @@ cinn::dialect::GroupInfo BuildGroupInfo(
const GroupClusterNode& node,
const std::unordered_map<::pir::Operation*, std::vector<ScheduleInfoNode>>&
new_align_info) {
cinn::dialect::GroupInfo group_info({});
cinn::dialect::GroupInfo group_info(vec_new_op_list);
group_info.group_id = BuildGroupId(vec_new_op_list);
group_info.loop_ranges = node.loop_ranges;
group_info.reduce_axis = node.reduce_axis;
8 changes: 8 additions & 0 deletions paddle/cinn/hlir/pe/reduction.cc
@@ -166,6 +166,14 @@ Tensor DoReduce(const Tensor& tensor,
int indice_cnt = 0;
int reduce_cnt = 0;

// Set keepdim flags of indices.
if (tensor->shape.size() == indices.size()) {
for (const auto& i : real_axes) {
VLOG(4) << "Set is_keepdim = true for var(" << i << ")";
indices[i].as_var_ref()->is_keepdim = true;
}
}

for (size_t i = 0; i < tensor->shape.size(); ++i) {
bool squeeze_i = std::find(squeeze_axes.begin(), squeeze_axes.end(), i) !=
squeeze_axes.end();
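
The guard tensor->shape.size() == indices.size() holds exactly when the reduction kept its rank (keep_dim semantics), so every reduced axis still owns a size-1 index var, and those are the vars tagged is_keepdim. A hypothetical mini-model of that tagging, with illustrative shapes rather than CINN's real types:

    #include <cstddef>
    #include <vector>

    struct IndexVar { bool is_keepdim = false; };

    // E.g. reducing a [4, 8] tensor over axis 1 with keep_dim=true yields
    // [4, 1]: ranks match, so the axis-1 index var is tagged.
    void TagKeepdim(std::size_t input_rank,
                    std::vector<IndexVar>* indices,          // output index vars
                    const std::vector<int>& reduced_axes) {  // real_axes in the diff
      if (input_rank != indices->size()) return;  // rank was squeezed away
      for (int axis : reduced_axes) {
        (*indices)[axis].is_keepdim = true;
      }
    }
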
5 changes: 4 additions & 1 deletion paddle/cinn/ir/ir.cc
@@ -218,11 +218,13 @@ Expr _Var_::Make(Expr lower_bound,
Expr upper_bound,
const std::string &name,
bool is_reduce_axis,
bool is_symbolic_constant) {
bool is_symbolic_constant,
bool is_keepdim) {
auto *n = make_shared<_Var_>();
n->lower_bound = lower_bound;
n->upper_bound = upper_bound;
n->is_reduce_axis = is_reduce_axis;
n->is_keepdim = is_keepdim;
n->is_symbolic_constant = is_symbolic_constant;
n->name = name;
n->set_type(lower_bound.type());
@@ -233,6 +235,7 @@ Expr _Var_::Copy() const {
auto *n = make_shared<_Var_>();
n->name = name;
n->is_reduce_axis = is_reduce_axis;
n->is_keepdim = is_keepdim;
n->lower_bound = lower_bound;
n->upper_bound = upper_bound;
n->set_type(type());
15 changes: 10 additions & 5 deletions paddle/cinn/ir/ir.h
@@ -381,6 +381,7 @@ struct _Var_ : public ExprNode<_Var_> {
std::string name;

bool is_reduce_axis{false};
bool is_keepdim{false};
bool is_symbolic_constant{false};
//! Lower bound and upper bound of a axis.
// @{
@@ -401,7 +402,8 @@ struct _Var_ : public ExprNode<_Var_> {
Expr upper_bound,
const std::string& name,
bool is_reduce,
bool is_symbolic_constant = false);
bool is_symbolic_constant = false,
bool is_keepdim = false);

void Verify() const override;

@@ -419,12 +421,14 @@ struct Var : public IrNodeRef {
Var(Expr lower_bound,
Expr upper_bound,
const std::string& name,
bool is_reduce = false)
: Var(_Var_::Make(lower_bound, upper_bound, name, is_reduce)) {}
bool is_reduce = false,
bool is_keepdim = false)
: Var(_Var_::Make(
lower_bound, upper_bound, name, is_reduce, false, is_keepdim)) {}
Var(int upper_bound, const std::string& name)
: Var(_Var_::Make(Expr(0), Expr(upper_bound), name, false)) {}
: Var(_Var_::Make(Expr(0), Expr(upper_bound), name, false, false)) {}
Var(Expr upper_bound, const std::string& name)
: Var(_Var_::Make(Expr(0), upper_bound, name, false)) {}
: Var(_Var_::Make(Expr(0), upper_bound, name, false, false)) {}

operator Expr() { return Expr(get()); }
operator Expr() const {
@@ -977,6 +981,7 @@ struct ScheduleBlock : public ExprNode<ScheduleBlock> {
std::map<std::string, attr_t> attrs;
std::string name;
Expr body;
int32_t reduce_type{-1}; // 0 for warp reduce, 1 for block reduce

static Expr Make(const std::vector<Var>& iter_vars,
const std::vector<Expr>& read_buffers,
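
Taken together, the ir.cc and ir.h changes thread a new is_keepdim flag through both ways of constructing an axis var, defaulting to false so existing call sites keep their behavior. A sketch against the signatures shown above (the bounds and name are illustrative values for a unit keepdim axis, not taken from the diff):

    // Through the Var convenience wrapper:
    ir::Var keepdim_axis(/*lower_bound=*/ir::Expr(0),
                         /*upper_bound=*/ir::Expr(1),
                         /*name=*/"i1",
                         /*is_reduce=*/false,
                         /*is_keepdim=*/true);

    // Through the node factory, where is_symbolic_constant now precedes
    // is_keepdim:
    ir::Expr node = ir::_Var_::Make(ir::Expr(0),
                                    ir::Expr(1),
                                    "i1",
                                    /*is_reduce_axis=*/false,
                                    /*is_symbolic_constant=*/false,
                                    /*is_keepdim=*/true);
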
7 changes: 7 additions & 0 deletions paddle/cinn/lang/compute.cc
@@ -187,6 +187,13 @@ ir::Tensor Compute(const std::vector<Expr> &domain,
domain_without_reduce_axis,
op,
reduce_axis);
const auto set_keep_dim_for_tensor = [&]() {
for (int i = 0; i < _axis.size(); ++i) {
const auto &axis_var = _axis.at(i);
tensor->axis_[i]->is_keepdim = axis_var.as_var_ref()->is_keepdim;
}
};
set_keep_dim_for_tensor();
return tensor;
}

1 change: 1 addition & 0 deletions paddle/cinn/pybind/ir/ir_api.cc
@@ -383,6 +383,7 @@ void BindIrIr(py::module *m) {
ir::Expr,
const std::string &,
bool,
bool,
bool>(&ir::_Var_::Make))
.def("copy", &ir::_Var_::Copy);

4 changes: 4 additions & 0 deletions paddle/cinn/runtime/flags.cc
@@ -69,6 +69,10 @@ PD_DEFINE_bool(cinn_bucket_compile,
BoolFromEnv("FLAGS_cinn_bucket_compile", false),
"Whether to enable bucket compile for dynamic shape.");

PD_DEFINE_bool(group_schedule_tiling_first,
BoolFromEnv("FLAGS_group_schedule_tiling_first", false),
"Whether to enable new group scheduler tiling first strategy.");

PD_DEFINE_bool(cinn_use_common_subexpression_elimination,
BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination",
false),
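
The new flag follows the define-once, declare-anywhere pattern already used by cinn_bucket_compile: flags.cc owns the single PD_DEFINE_bool (seeded from the environment via BoolFromEnv), and each consumer adds a PD_DECLARE_bool, as ast_gen.cc does at the top of this commit. Schematically (the consumer function is hypothetical):

    // flags.cc: one definition, overridable through the
    // FLAGS_group_schedule_tiling_first environment variable.
    PD_DEFINE_bool(group_schedule_tiling_first,
                   BoolFromEnv("FLAGS_group_schedule_tiling_first", false),
                   "Whether to enable the new group scheduler's "
                   "tiling-first strategy.");

    // any_consumer.cc: declare, then branch on the generated global.
    PD_DECLARE_bool(group_schedule_tiling_first);

    void AnyConsumer() {
      if (FLAGS_group_schedule_tiling_first) {
        // tiling-first code path
      }
    }
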
@@ -49,8 +49,8 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
/*disable_setting_default_stream_for_allocator=*/true,
/*stream_priority=*/0);
if (ir::IsTopologySortOperationsUnique(*graph_)) {
VLOG(10)
<< "Change thread number to 1 because the toposort order is unique";
VLOG(10) << "Change thread number to 1 because the topology sort order is "
"unique";
strategy_.num_threads_ = 1;
traced_ops_.clear();
for (auto *op_node : TopologySortOperations(*graph_)) {
2 changes: 1 addition & 1 deletion paddle/fluid/framework/details/fetch_op_handle.cc
@@ -39,7 +39,7 @@ FetchOpHandle::~FetchOpHandle() = default;

void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
PADDLE_THROW(platform::errors::PermissionDenied(
"No nodes need to wait FetchOp. Unexpceted Error."));
"No nodes need to wait FetchOp. Unexpected Error."));
}

static void CheckDims(const framework::DDim &tensor_dims,
10 changes: 5 additions & 5 deletions paddle/fluid/framework/operator.cc
@@ -2038,7 +2038,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
phi::KernelContext phi_kernel_context;
if (enable_cache_runtime_context_ && !need_prepare_phi_data_ &&
!need_prepare_data_) {
// TODO(inference): Now we only suppor dense_tensor cache, we may be
// TODO(inference): Now we only support dense_tensor cache, we may be
// support ScalarTensor, SparseTensor in future.
bool all_dense_tensor_input_{true};
for (auto& iter : Inputs()) {
@@ -2573,7 +2573,7 @@ Scope* OperatorWithKernel::PrepareData(
// for some situation like InferShape().
// In this situation We cannot skip Var analysis, as
// oneDNN shape of Var may differ from kNHWC Var
// In such situation corressponding resized Var
// In such situation corresponding resized Var
// has to be created and registered
if ((tensor_in->layout() == DataLayout::ONEDNN) &&
(var->IsType<phi::DenseTensor>() == true) &&
@@ -3193,7 +3193,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
for (size_t i = 0; i < input_names.size(); ++i) {
auto it = ctx.inputs.find(input_names[i]);

// calcute the start and end index of the input tensors
// calculate the start and end index of the input tensors
size_t start_idx =
(i == 0 ? 0 : phi_kernel_context->InputRangeAt(i - 1).second);
// deal with optional here
@@ -3399,7 +3399,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
attr_iter,
Attrs().end(),
platform::errors::NotFound("(%s) is not found in AttributeMap when "
"buildind static KernelContext.",
"building static KernelContext.",
attr_names[i]));
switch (AttrTypeID(attr_iter->second)) {
case proto::AttrType::INTS: {
@@ -3473,7 +3473,7 @@
RuntimeAttrs().end(),
platform::errors::NotFound(
"(%s) is not found in AttributeMap when "
"buildind static KernelContext.",
"building static KernelContext.",
attr_names[i]));
}

10 changes: 5 additions & 5 deletions paddle/fluid/framework/parallel_executor.cc
@@ -639,15 +639,15 @@ void InitP2P(const std::vector<platform::Place> &places) {
for (int i = 0; i < count; ++i) {
for (int j = 0; j < count; ++j) {
if (devices[i] == devices[j]) continue;
int can_acess = -1;
int can_access = -1;
#ifdef PADDLE_WITH_HIP
hipError_t ret =
hipDeviceCanAccessPeer(&can_acess, devices[i], devices[j]);
if (ret != hipSuccess || can_acess != 1) {
hipDeviceCanAccessPeer(&can_access, devices[i], devices[j]);
if (ret != hipSuccess || can_access != 1) {
#else
cudaError_t ret =
cudaDeviceCanAccessPeer(&can_acess, devices[i], devices[j]);
if (ret != cudaSuccess || can_acess != 1) {
cudaDeviceCanAccessPeer(&can_access, devices[i], devices[j]);
if (ret != cudaSuccess || can_access != 1) {
#endif
LOG(WARNING) << "Cannot enable P2P access from " << devices[i]
<< " to " << devices[j];
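
For context on the renamed variable: cudaDeviceCanAccessPeer (and its HIP twin hipDeviceCanAccessPeer) writes 1 through its out-parameter when the first device can map the second device's memory. A minimal standalone sketch of the probe InitP2P runs for every ordered device pair, with device ids 0 and 1 as illustrative values:

    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
      int can_access = -1;
      cudaError_t ret =
          cudaDeviceCanAccessPeer(&can_access, /*device=*/0, /*peerDevice=*/1);
      if (ret != cudaSuccess || can_access != 1) {
        std::printf("Cannot enable P2P access from 0 to 1\n");
      }
      return 0;
    }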