
Commit

Merge branch 'develop' of github.com:PaddlePaddle/Paddle into infer
cyber-pioneer committed Jan 2, 2024
2 parents b576640 + a08580e commit 707caea
Showing 484 changed files with 13,992 additions and 5,864 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -108,6 +108,8 @@ paddle/fluid/pir/dialect/operator/ir/pd_api.*
paddle/fluid/pir/dialect/operator/ir/op_decomp.cc
paddle/fluid/pir/dialect/operator/ir/pd_op_vjp.cc
paddle/fluid/pir/dialect/operator/ir/pd_op.*
paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.*
paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.*
paddle/fluid/pir/dialect/operator/ir/pd_op_bwd.*
paddle/fluid/pir/dialect/operator/ir/pd_op_fused.*
paddle/fluid/pir/dialect/operator/ir/pd_op_fused_bwd.*
3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -54,7 +54,8 @@ option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF)
option(WITH_XPU_KP "Compile PaddlePaddle with BAIDU XPU compiler " OFF)
option(WITH_XPU_XFT "Compile PaddlePaddle with BAIDU XPU-XFT" OFF)
option(WITH_XPU_PLUGIN "Compile PaddlePaddle with BAIDU XPU plugin" OFF)
option(WITH_XPU_XHPC "Compile PaddlePaddle with BAIDU XPU-HPC library" OFF)
option(WITH_XPU_XHPC "Compile PaddlePaddle with BAIDU XPU-HPC library"
${WITH_XPU})
option(WITH_WIN_DUMP_DBG "Compile with windows core dump debug mode" OFF)
option(WITH_ROCM "Compile PaddlePaddle with ROCM platform" OFF)
option(WITH_IPU "Compile PaddlePaddle with Graphcore IPU" OFF)
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
@@ -29,7 +29,7 @@ if(NOT DEFINED XPU_BASE_DATE)
set(XPU_BASE_DATE "20231203")
endif()
if(NOT DEFINED XPU_XHPC_BASE_DATE)
set(XPU_XHPC_BASE_DATE "20231215")
set(XPU_XHPC_BASE_DATE "20231229")
endif()
set(XPU_XCCL_BASE_VERSION "1.1.8.1")
if(NOT DEFINED XPU_XFT_BASE_VERSION)
15 changes: 10 additions & 5 deletions cmake/inference_lib.cmake
@@ -237,6 +237,16 @@ copy_part_of_thrid_party(inference_lib_dist ${PADDLE_INFERENCE_INSTALL_DIR})

set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid")

if(WIN32)
set(paddle_common_lib ${PADDLE_BINARY_DIR}/paddle/common/common.*)
else()
set(paddle_common_lib ${PADDLE_BINARY_DIR}/paddle/common/libcommon.*)
endif()
copy(
inference_lib_dist
SRCS ${paddle_common_lib}
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)

if(WIN32)
if(WITH_STATIC_LIB)
set(paddle_inference_lib
@@ -268,11 +278,6 @@ else()
SRCS ${paddle_phi_lib}
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)
endif()
set(paddle_common_lib ${PADDLE_BINARY_DIR}/paddle/common/libcommon.*)
copy(
inference_lib_dist
SRCS ${paddle_common_lib}
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)
endif()

copy(
5 changes: 4 additions & 1 deletion paddle/cinn/common/CMakeLists.txt
@@ -22,7 +22,8 @@ gather_srcs(
python_interpreter_guard.cc
nvgpu_dev_info.cc
integer_set.cc
dim_expr_simplify.cc)
dim_expr_simplify.cc
dim_expr_converter.cc)

cinn_cc_test(test_equation_graph_topo_walker SRCS
equation_graph_topo_walker_test.cc DEPS gtest glog)
@@ -49,4 +50,6 @@ endif()
if(NOT CINN_ONLY)
cinn_cc_test(dim_expr_simplify_test SRCS dim_expr_simplify_test.cc DEPS
cinncore)
cinn_cc_test(dim_expr_converter_test SRCS dim_expr_converter_test.cc DEPS
cinncore)
endif()
101 changes: 101 additions & 0 deletions paddle/cinn/common/dim_expr_converter.cc
@@ -0,0 +1,101 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/common/dim_expr_converter.h"
#include "paddle/cinn/common/ir_util.h"

namespace cinn::common {
using namespace symbol; // NOLINT

namespace {

struct DimExprToIrExprVisitor {
ir::Expr ConvertToIrExpr(const DimExpr& dim_expr) {
return std::visit(*this, dim_expr.variant());
}

ir::Expr operator()(const int64_t& dim) { return ir::Expr(dim); }

ir::Expr operator()(const std::string& dim_expr) {
Var x = ir::_Var_::Make(dim_expr, Int(64));
return x;
}

ir::Expr operator()(const Negative<DimExpr>& dim_expr) {
const auto& [operand] = *dim_expr;
return ir::Sub::Make(ir::Expr(std::int64_t(0)), ConvertToIrExpr(operand));
}

ir::Expr operator()(const Reciprocal<DimExpr>& dim_expr) {
const auto& [operand] = *dim_expr;
return ir::Div::Make(ir::Expr(std::int64_t(1)), ConvertToIrExpr(operand));
}

ir::Expr operator()(const Add<DimExpr>& dim_expr) {
const auto& [operands] = dim_expr;
if (operands->empty()) {
return ir::Expr(std::int64_t(0));
}
ir::Expr sum = ConvertToIrExpr(operands->at(0));
for (std::size_t i = 1; i < operands->size(); ++i) {
sum = ir::Add::Make(sum, ConvertToIrExpr(operands->at(i)));
}
return sum;
}

ir::Expr operator()(const Mul<DimExpr>& dim_expr) {
const auto& [operands] = dim_expr;
if (operands->empty()) {
return ir::Expr(std::int64_t(1));
}
ir::Expr product = ConvertToIrExpr(operands->at(0));
for (std::size_t i = 1; i < operands->size(); ++i) {
product = ir::Mul::Make(product, ConvertToIrExpr(operands->at(i)));
}
return product;
}

ir::Expr operator()(const Max<DimExpr>& dim_expr) {
const auto& [operands] = dim_expr;
CHECK(!operands->empty());
ir::Expr max = ConvertToIrExpr(operands->at(0));
for (std::size_t i = 1; i < operands->size(); ++i) {
max = ir::Max::Make(max, ConvertToIrExpr(operands->at(i)));
}
return max;
}

ir::Expr operator()(const Min<DimExpr>& dim_expr) {
const auto& [operands] = dim_expr;
CHECK(!operands->empty());
ir::Expr min = ConvertToIrExpr(operands->at(0));
for (std::size_t i = 1; i < operands->size(); ++i) {
min = ir::Min::Make(min, ConvertToIrExpr(operands->at(i)));
}
return min;
}

ir::Expr operator()(const Broadcast<DimExpr>& dim_expr) {
LOG(FATAL)
<< "no support for converting from Broadcast<DimExpr> to ir::Expr";
}
};

} // namespace

ir::Expr DimExprConverter::ConvertToIrExpr(const DimExpr& dim_expr) const {
return DimExprToIrExprVisitor().ConvertToIrExpr(dim_expr);
}

} // namespace cinn::common
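
For orientation (not part of the commit itself): a minimal usage sketch of the new converter, mirroring the tests added further down in this diff. It builds the symbolic expression 4 + 5 + sym_0 and lowers it to a CINN ir::Expr; the free function LowerSum is a hypothetical wrapper, not a Paddle API.

// Minimal sketch, assuming the types behave as shown in dim_expr_converter_test.cc below.
#include "paddle/cinn/common/dim_expr_converter.h"

cinn::ir::Expr LowerSum() {
  using namespace symbol;  // NOLINT
  // Symbolic dimension expression: 4 + 5 + sym_0, where sym_0 is an unknown dimension.
  List<DimExpr> operands{DimExpr(4), DimExpr(5), DimExpr("sym_0")};
  DimExpr dim_expr{Add<DimExpr>{operands}};
  // Lower to CINN IR: constants become int64 literals, symbols become Int(64) variables.
  return cinn::common::DimExprConverter().ConvertToIrExpr(dim_expr);
}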
26 changes: 26 additions & 0 deletions paddle/cinn/common/dim_expr_converter.h
@@ -0,0 +1,26 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/ir/ir.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"

namespace cinn::common {

struct DimExprConverter final {
ir::Expr ConvertToIrExpr(const symbol::DimExpr&) const;
};

} // namespace cinn::common
79 changes: 79 additions & 0 deletions paddle/cinn/common/dim_expr_converter_test.cc
@@ -0,0 +1,79 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <sstream>

#include "gtest/gtest.h"

#include "paddle/cinn/common/dim_expr_converter.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_printer.h"

namespace cinn::common::test {

using namespace symbol; // NOLINT

TEST(Convert, AddExpr) {
List<DimExpr> num_lists{DimExpr(4), DimExpr(5), DimExpr("sym_0")};
DimExpr dim_expr{Add<DimExpr>{num_lists}};
ir::Expr src_expr = DimExprConverter().ConvertToIrExpr(dim_expr);

ir::Expr expr1 =
ir::Add::Make(ir::Expr(std::int64_t(4)), ir::Expr(std::int64_t(5)));
ir::Expr dst_expr = ir::Add::Make(expr1, ir::_Var_::Make("sym_0", Int(64)));
ASSERT_TRUE(MathEqual(src_expr, dst_expr));
}

TEST(Convert, SubExpr) {
DimExpr dim_expr = DimExpr(4) - DimExpr("sym_0");
ir::Expr src_expr = DimExprConverter().ConvertToIrExpr(dim_expr);

ir::Expr expr1 = ir::Sub::Make(ir::Expr(std::int64_t(0)),
ir::_Var_::Make("sym_0", Int(64)));
ir::Expr dst_expr = ir::Add::Make(ir::Expr(std::int64_t(4)), expr1);
ASSERT_TRUE(MathEqual(src_expr, dst_expr));
}

TEST(Convert, MulExpr) {
List<DimExpr> num_lists{DimExpr(4), DimExpr(5), DimExpr("sym_0")};
DimExpr dim_expr{Mul<DimExpr>{num_lists}};
ir::Expr src_expr = DimExprConverter().ConvertToIrExpr(dim_expr);

ir::Expr expr1 =
ir::Mul::Make(ir::Expr(std::int64_t(4)), ir::Expr(std::int64_t(5)));
ir::Expr dst_expr = ir::Mul::Make(expr1, ir::_Var_::Make("sym_0", Int(64)));
ASSERT_TRUE(MathEqual(src_expr, dst_expr));
}

TEST(Convert, MaxExpr) {
List<DimExpr> num_lists{DimExpr(4), DimExpr(5), DimExpr("sym_0")};
DimExpr dim_expr{Max<DimExpr>{num_lists}};
ir::Expr src_expr = DimExprConverter().ConvertToIrExpr(dim_expr);

std::ostringstream stream;
stream << src_expr;
ASSERT_EQ(stream.str(), "cinn_max(cinn_max(4ll, 5ll), sym_0)");
}

TEST(Convert, MinExpr) {
List<DimExpr> num_lists{DimExpr(4), DimExpr(5), DimExpr("sym_0")};
DimExpr dim_expr{Min<DimExpr>{num_lists}};
ir::Expr src_expr = DimExprConverter().ConvertToIrExpr(dim_expr);

std::ostringstream stream;
stream << src_expr;
ASSERT_EQ(stream.str(), "cinn_min(cinn_min(4ll, 5ll), sym_0)");
}

} // namespace cinn::common::test
1 change: 1 addition & 0 deletions paddle/cinn/hlir/dialect/operator/ir/ops.yaml
@@ -74,6 +74,7 @@
func : SliceRawInferMeta
kernel :
func : slice
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : uniform_random
args : (int64_t[] shape, float min, float max, int seed, DataType dtype, int diag_num = 0, int diag_step=0, float diag_val=1.0)
@@ -211,7 +211,7 @@ class GroupOpPattern : public pir::OpRewritePattern<cinn::dialect::GroupOp> {
}

private:
std::shared_ptr<pir::ShapeConstraintIRAnalysis> shape_analysis_;
std::shared_ptr<pir::ShapeConstraintIRAnalysis> shape_analysis_{nullptr};
};

class CinnGroupLoweringPass : public pir::PatternRewritePass {
@@ -237,7 +237,7 @@ class CinnGroupLoweringPass : public pir::PatternRewritePass {
}

private:
const std::shared_ptr<pir::ShapeConstraintIRAnalysis>& shape_analysis_;
std::shared_ptr<pir::ShapeConstraintIRAnalysis> shape_analysis_{nullptr};
};

} // namespace
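
As rendered, the two hunks above make shape_analysis_ an owning shared_ptr defaulted to nullptr; in CinnGroupLoweringPass it was previously held as a const shared_ptr reference. A hedged illustration of why that matters (toy types, not code from this commit): a reference member only aliases the caller's pointer and can dangle once that pointer goes out of scope, while a by-value shared_ptr member shares ownership and keeps the analysis alive.

// Toy sketch only; Analysis, RefHolder and ValueHolder are hypothetical stand-ins.
#include <memory>
#include <utility>

struct Analysis {};

struct RefHolder {
  explicit RefHolder(const std::shared_ptr<Analysis>& a) : analysis_(a) {}
  const std::shared_ptr<Analysis>& analysis_;  // dangles if the caller's shared_ptr is destroyed
};

struct ValueHolder {
  explicit ValueHolder(std::shared_ptr<Analysis> a) : analysis_(std::move(a)) {}
  std::shared_ptr<Analysis> analysis_{nullptr};  // shares ownership; defaults safely to null
};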
