Skip to content

Commit

Permalink
[CINN] Make Resize Buffer Safer (#59014)
Browse files Browse the repository at this point in the history
Make Resize Buffer Safer, the old buffer resize didn't consider load, current we add support for it

This PR also contain some code of safer UpdateBufferAxis of #59209

We will also clean it in the 59209 PR
  • Loading branch information
zhhsplendid authored Dec 5, 2023
1 parent 833f556 commit 5c70f3e
Show file tree
Hide file tree
Showing 16 changed files with 805 additions and 19 deletions.
7 changes: 5 additions & 2 deletions paddle/cinn/common/cas_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ TEST(CAS, SimplifySum) {
// z + 1 + y + 3 + x + 0 + zx
auto u4 = CasSimplify(
Sum::Make({z, Expr(1), y, Expr(3), x, Expr(0), Product::Make({z, x})}));
// (-1 * x) + x
auto u5 = CasSimplify(Sum::Make({Product::Make({Expr(-1), x}), x}));
// x2 + 3zy + -3*yz + -2x + 1
auto u5 = CasSimplify(Sum::Make({Product::Make({x, Expr(2)}),
auto u6 = CasSimplify(Sum::Make({Product::Make({x, Expr(2)}),
Product::Make({z, y, Expr(3)}),
Product::Make({Expr(-3), y, z}),
Product::Make({Expr(-2), x}),
Expand All @@ -86,7 +88,8 @@ TEST(CAS, SimplifySum) {
EXPECT_EQ(GetStreamCnt(CasSimplify(u2)), "(x + y + z)");
EXPECT_EQ(GetStreamCnt(u3), "(1 + x + y + z + (x * z))");
EXPECT_EQ(GetStreamCnt(u4), "(4 + x + y + z + (x * z))");
EXPECT_EQ(GetStreamCnt(u5), "1");
EXPECT_EQ(GetStreamCnt(u5), "0");
EXPECT_EQ(GetStreamCnt(u6), "1");
}

TEST(CAS, SimplifyProduct) {
Expand Down
15 changes: 15 additions & 0 deletions paddle/cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,21 @@ struct BindInfo {
return offset >= 0 && offset < 3 &&
(for_type == ForType::GPUThread || for_type == ForType::GPUBlock);
}

friend std::ostream& operator<<(std::ostream& os, const BindInfo& bind_info) {
CHECK(bind_info.valid()) << "Make invalid BindInfo to stream";
char axis_name = 'x' + bind_info.offset;
std::string prefix =
bind_info.for_type == ForType::GPUBlock ? "blockIdx." : "threadIdx.";
os << prefix + axis_name;
return os;
}

operator std::string() const {
std::ostringstream os;
os << *this;
return os.str();
}
};

struct ForBase {
Expand Down
27 changes: 25 additions & 2 deletions paddle/cinn/ir/utils/ir_replace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ using utils::GetStreamCnt;

namespace {

struct IrReplaceMutator : ir::IRMutator<Expr*> {
struct IrReplaceVarBroadcastMutator : ir::IRMutator<Expr*> {
std::set<ir::IrNodeTy> valid_nodetys{
{ir::IrNodeTy::Broadcast, ir::IrNodeTy::_Var_}};

IrReplaceMutator(ir::Expr from, Expr to)
IrReplaceVarBroadcastMutator(ir::Expr from, Expr to)
: from_(from), to_(to), from_repr_(GetStreamCnt(from)) {
CHECK(valid_nodetys.count(from->node_type()))
<< "Not valid node type got " << from->node_type();
Expand All @@ -59,8 +59,31 @@ struct IrReplaceMutator : ir::IRMutator<Expr*> {
Expr to_;
};

struct IrReplaceMutator : ir::IRMutator<Expr*> {
IrReplaceMutator(ir::Expr from, Expr to)
: from_(from), to_(to), from_repr_(GetStreamCnt(from)) {}

void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

void Visit(const Expr* op, Expr* expr) override {
ir::IRMutator<>::Visit(expr, expr);
if (from_repr_ == GetStreamCnt(*expr)) {
*expr = ir::ir_utils::IRCopy(to_);
}
}

std::string from_repr_;
ir::Expr from_;
Expr to_;
};

} // namespace

void IrReplaceVarBroadcast(ir::Expr* expr, ir::Expr from, ir::Expr to) {
CHECK(expr);
IrReplaceVarBroadcastMutator(from, to)(expr);
}

void IrReplace(ir::Expr* expr, ir::Expr from, ir::Expr to) {
CHECK(expr);
IrReplaceMutator(from, to)(expr);
Expand Down
6 changes: 5 additions & 1 deletion paddle/cinn/ir/utils/ir_replace.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ namespace cinn {
namespace ir {
namespace ir_utils {

//! Replace the variable \p v to expression \p e in expression \p expr.
//! Replace the variable \p from to expression \p to in expression \p expr.
void IrReplaceVarBroadcast(ir::Expr* expr, ir::Expr from, ir::Expr to);

//! Replace the Expr \p from to expression \p to in expression \p expr.
void IrReplace(ir::Expr* expr, ir::Expr from, ir::Expr to);

} // namespace ir_utils
} // namespace ir
} // namespace cinn
5 changes: 4 additions & 1 deletion paddle/cinn/optim/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ gather_srcs(
cast_bool_to_int8.cc
var_mod_simplify.cc
remove_schedule_block.cc
replace_cross_thread_reduction.cc)
replace_cross_thread_reduction.cc
replace_mod_to_max.cc
resize_buffer.cc
update_buffer_axis_pass.cc)

if(WITH_CUDA)
gather_srcs(cinnapi_src SRCS transform_gpu_forloop.cc)
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/optim/eliminate_broadcast_in_forloop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ struct EliminateBroadcastInForloop : public ir::IRMutator<Expr*> {
std::tie(let_expr, tmp) = CreateTmpLet(broadcast);
let_exprs.push_back(let_expr);

cinn::ir::ir_utils::IrReplace(expr, broadcast, tmp);
cinn::ir::ir_utils::IrReplaceVarBroadcast(expr, broadcast, tmp);
}

// insert the let expressions to the outer forloop.
Expand Down
43 changes: 43 additions & 0 deletions paddle/cinn/optim/replace_mod_to_max.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/optim/replace_mod_to_max.h"

#include <unordered_map>

#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"

namespace cinn {
namespace optim {

class ReplaceModToMaxMutator : public ir::IRMutator<> {
public:
void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

void Visit(const ir::Mod* op, ir::Expr* expr) override {
ir::Mod* node = expr->As<ir::Mod>();
Expr base = node->operand(1);
*expr = ir::Sub::Make(base, Expr(1));
}
};

void ReplaceModToMax(ir::Expr* expr) {
ReplaceModToMaxMutator mutator;
mutator(expr);
}

} // namespace optim
} // namespace cinn
33 changes: 33 additions & 0 deletions paddle/cinn/optim/replace_mod_to_max.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>

#include "paddle/cinn/ir/ir.h"

namespace cinn {
namespace optim {

/**
* Given Expr AST, analyze the range of N % M will return M - 1.
* This function is used to replace the mod operation with max.
*
* Note: the replacement will change the semantics of the AST.
* It is only used for analyze, not computing.
*/
void ReplaceModToMax(ir::Expr* expr);

} // namespace optim
} // namespace cinn
Loading

0 comments on commit 5c70f3e

Please sign in to comment.