Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dynamic Shape] Add helper function MakeGenerateShapeOpAttribute #60512

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 235 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h"
#include <unordered_set>
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/builtin_attribute.h"

Expand Down Expand Up @@ -422,4 +423,238 @@ MakeGetterDimExpr4SymbolName(
};
}

namespace {

bool IsAtomicImpl(int64_t) { return true; }

bool IsAtomicImpl(const std::string&) { return true; }

bool IsAtomicImpl(const symbol::Negative<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Reciprocal<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Add<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Mul<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Max<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Min<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Broadcast<symbol::DimExpr>&) { return false; }

bool IsAtomic(const symbol::DimExpr& dim_expr) {
return std::visit([](const auto& impl) { return IsAtomicImpl(impl); },
dim_expr.variant());
}

bool InputDimExprsAllSupported(
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
const std::vector<pir::Value>& input_tensors) {
const auto& AllSupported =
[](const std::vector<symbol::DimExpr>& dim_exprs) -> bool {
for (const auto& dim_expr : dim_exprs) {
if (!IsAtomic(dim_expr)) return false;
}
return true;
};
for (const auto& input_tensor : input_tensors) {
const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor);
if (!AllSupported(dim_exprs.shape())) return false;
if (dim_exprs.data().has_value()) {
if (!AllSupported(dim_exprs.data().value())) return false;
}
}
return true;
}

void ConvertDimExprToAttributes(pir::IrContext* ir_context,
const std::vector<symbol::DimExpr>& dim_exprs,
std::vector<pir::Attribute>* attrs) {
attrs->clear();
attrs->reserve(dim_exprs.size());
for (const auto& dim_expr : dim_exprs) {
attrs->emplace_back(ConvertDimExprToAttribute(ir_context, dim_expr));
}
}

void CollectSymbolNames(const symbol::DimExpr& dim_expr,
std::set<std::string>* ret);

void CollectSymbolNamesImpl(const int64_t& dim_expr,
std::set<std::string>* ret) {
// do nothing.
}

void CollectSymbolNamesImpl(const std::string& dim_expr,
std::set<std::string>* ret) {
ret->insert(dim_expr);
}

template <typename T>
void CollectSymbolNamesImplForUnary(const T& dim_expr,
std::set<std::string>* ret) {
const auto& [operand] = *dim_expr;
CollectSymbolNames(operand, ret);
}

void CollectSymbolNamesImpl(const symbol::Negative<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForUnary(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Reciprocal<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForUnary(dim_expr, ret);
}

template <typename T>
void CollectSymbolNamesImplForVariadic(const T& dim_expr,
std::set<std::string>* ret) {
const auto& operands = *(dim_expr.operands);
for (const auto& operand : operands) {
CollectSymbolNames(operand, ret);
}
}

void CollectSymbolNamesImpl(const symbol::Add<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForVariadic(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Mul<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForVariadic(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Max<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForVariadic(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Min<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForVariadic(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Broadcast<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForVariadic(dim_expr, ret);
}

void CollectSymbolNames(const symbol::DimExpr& dim_expr,
std::set<std::string>* ret) {
return std::visit(
[&](const auto& impl) { return CollectSymbolNamesImpl(impl, ret); },
dim_expr.variant());
}

void CollectSymbolNames(const std::vector<symbol::DimExpr>& dim_exprs,
std::set<std::string>* ret) {
for (const auto& dim_expr : dim_exprs) {
CollectSymbolNames(dim_expr, ret);
}
}

template <typename SymbolBindingsT>
void AppendSymbolBindings(const std::vector<symbol::DimExpr>& dim_exprs,
const std::set<std::string>& symbol_names,
int in_tensor_idx,
GenerateShapeOp::SymbolBindings* symbol_bindings) {
for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size();
++in_tensor_dim_idx) {
const auto& dim_expr = dim_exprs.at(in_tensor_dim_idx);
CHECK(IsAtomic(dim_expr));
if (!dim_expr.isa<std::string>()) continue;
const auto& sym_name = dim_expr.dyn_cast<std::string>();
if (symbol_names.find(sym_name) == symbol_names.end()) continue;
symbol_bindings->emplace_back(SymbolBindingsT{
/*.symbol_name=*/sym_name,
/*.input_tensor_idx=*/in_tensor_idx,
/*.input_tensor_dim_idx=*/in_tensor_dim_idx,
});
}
}

void GenerateSymbolBindings(
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
const std::vector<pir::Value>& input_tensors,
const std::set<std::string>& symbol_names,
GenerateShapeOp::SymbolBindings* symbol_bindings) {
for (int i = 0; i < input_tensors.size(); ++i) {
const auto& input_tensor = input_tensors.at(i);
const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor);
AppendSymbolBindings<GenerateShapeOp::ShapeSymbolBinding>(
dim_exprs.shape(), symbol_names, i, symbol_bindings);
if (dim_exprs.data().has_value()) {
AppendSymbolBindings<GenerateShapeOp::DataSymbolBinding>(
dim_exprs.shape(), symbol_names, i, symbol_bindings);
}
}
}

std::vector<pir::Value> GetMinimalInputs(
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
const std::vector<pir::Value>& input_tensors) {
std::unordered_set<symbol::DimExpr> handdled_dim_exprs;
std::unordered_set<pir::Value> first_occurred_input_tensors;
auto TryCollectFirstOcurredInput_tensor =
[&](pir::Value input_tensor,
const std::vector<symbol::DimExpr>& dim_exprs) {
for (const auto& dim_expr : dim_exprs) {
if (dim_expr.isa<int64_t>()) continue;
if (!handdled_dim_exprs.insert(dim_expr).second) {
first_occurred_input_tensors.insert(input_tensor);
}
}
};
for (pir::Value input_tensor : input_tensors) {
const auto& shape_or_data_dim_exprs =
ShapeOrDataDimExprs4Value(input_tensor);
if (shape_or_data_dim_exprs.data().has_value()) {
TryCollectFirstOcurredInput_tensor(
input_tensor, shape_or_data_dim_exprs.data().value());
}
TryCollectFirstOcurredInput_tensor(input_tensor,
shape_or_data_dim_exprs.shape());
}
std::vector<pir::Value> ret{};
ret.reserve(input_tensors.size());
for (pir::Value input_tensor : input_tensors) {
if (first_occurred_input_tensors.count(input_tensor) > 0) {
ret.emplace_back(input_tensor);
}
}
return ret;
}

} // namespace

bool MakeGenerateShapeOpAttribute(
pir::IrContext* ir_context,
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
const std::vector<symbol::DimExpr>& out_dim_exprs,
const std::vector<pir::Value>& origin_inputs,
std::vector<pir::Value>* minial_inputs,
std::vector<pir::Attribute>* output_dim_expr_attrs,
GenerateShapeOp::SymbolBindings* symbol_bindings) {
*minial_inputs = GetMinimalInputs(ShapeOrDataDimExprs4Value, origin_inputs);
if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, *minial_inputs)) {
VLOG(4) << "input dim_exprs are not as simple as symbols, please make sure "
"they are handled by other passes";
return false;
}
// generate output_dim_expr_attrs
ConvertDimExprToAttributes(
ir_context, out_dim_exprs, /*out*/ output_dim_expr_attrs);
// generate symbol_bindings
std::set<std::string> symbol_names_in_out_dim_exprs{};
CollectSymbolNames(out_dim_exprs, &symbol_names_in_out_dim_exprs);
GenerateSymbolBindings(ShapeOrDataDimExprs4Value,
*minial_inputs,
symbol_names_in_out_dim_exprs,
/*out*/ symbol_bindings);
return true;
}

} // namespace cinn::dialect
15 changes: 15 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

#pragma once

#include <functional>
#include <optional>
#include <vector>
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"
Expand Down Expand Up @@ -46,4 +48,17 @@ MakeGetterDimExpr4SymbolName(
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim);

using ShapeOrDataDimExprs4ValueT =
std::function<const symbol::ShapeOrDataDimExprs&(pir::Value)>;

// Returns true if success.
bool MakeGenerateShapeOpAttribute(
pir::IrContext* ir_context,
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
const std::vector<symbol::DimExpr>& out_dim_exprs,
const std::vector<pir::Value>& origin_inputs,
std::vector<pir::Value>* minial_inputs,
std::vector<pir::Attribute>* output_dim_expr_attrs,
GenerateShapeOp::SymbolBindings* symbol_bindings);

} // namespace cinn::dialect
Loading