[CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)
* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------

Co-authored-by: jiangcheng <thisjiang@qq.com>
Aurelius84 and thisjiang authored Feb 24, 2023
1 parent 6e37a2c commit 21c6ecc
Showing 11 changed files with 236 additions and 39 deletions.
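
Before the per-file diffs, a minimal sketch of the idea behind the change, assuming hypothetical names (SimpleCacheKey, HashKey) and simplified types rather than the actual Paddle classes: the compilation cache key must fold the input dtypes into its hash, otherwise two graphs fed tensors of the same shapes but different dtypes would collide and reuse the wrong compiled object.

```cpp
// Hypothetical sketch, not the Paddle implementation.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <sstream>
#include <string>
#include <vector>

struct SimpleCacheKey {
  std::map<std::string, std::vector<int64_t>> input_shapes;
  std::map<std::string, std::string> input_dtypes;  // e.g. "float32"
  std::string arch;                                 // e.g. "x86" or "nvgpu"
};

// Serialize shapes, dtypes, and target into one string and hash it, mirroring
// the string-concatenation style of CinnCacheKey::Hash in the diff below.
std::size_t HashKey(const SimpleCacheKey& key) {
  std::ostringstream os;
  for (const auto& name_shape : key.input_shapes) {
    os << name_shape.first << ",[";
    for (int64_t d : name_shape.second) os << d << " ";
    // Without this line, float32 and float64 inputs of the same shape would
    // produce identical keys and wrongly share a compiled object.
    os << "]," << key.input_dtypes.at(name_shape.first) << ";";
  }
  os << key.arch;
  return std::hash<std::string>()(os.str());
}
```

Two keys that differ only in a dtype then hash (and compare) differently, which is what the new input_dtypes_ member below provides.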
16 changes: 8 additions & 8 deletions paddle/fluid/framework/paddle2cinn/CMakeLists.txt
@@ -8,14 +8,6 @@ pass_library(
errors
enforce)

cc_library(
cinn_cache_key
SRCS cinn_cache_key.cc
DEPS graph graph_helper lod_tensor proto_desc)
cc_library(
cinn_subgraph_detector
SRCS cinn_subgraph_detector.cc
DEPS graph graph_helper subgraph_detector lod_tensor proto_desc)
cc_library(
transform_desc
SRCS transform_desc.cc
@@ -24,6 +16,14 @@ cc_library(
transform_type
SRCS transform_type.cc
DEPS errors enforce cinn)
cc_library(
cinn_cache_key
SRCS cinn_cache_key.cc
DEPS graph graph_helper lod_tensor proto_desc transform_type)
cc_library(
cinn_subgraph_detector
SRCS cinn_subgraph_detector.cc
DEPS graph graph_helper subgraph_detector lod_tensor proto_desc)
cc_library(
cinn_graph_symbolization
SRCS cinn_graph_symbolization.cc
77 changes: 58 additions & 19 deletions paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc
100755 → 100644
@@ -23,6 +23,7 @@

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/transform_type.h"
#include "paddle/phi/core/ddim.h"

namespace paddle {
@@ -45,10 +46,11 @@ CinnCacheKey::CinnCacheKey(

CinnCacheKey::CinnCacheKey(const ir::Graph& graph,
const std::map<std::string, DDim>& input_shapes,
const std::map<std::string, DataType>& input_dtypes,
const std::string& arch_str,
GraphHashStrategy graph_hash)
: graph_hash_(graph_hash) {
this->SetKey(graph, input_shapes, arch_str);
this->SetKey(graph, input_shapes, input_dtypes, arch_str);
}

void CinnCacheKey::SetKey(
@@ -58,15 +60,24 @@ void CinnCacheKey::SetKey(
graph_hash_val_ = graph_hash_(graph);
for (const auto& name_tensor : input_tensors) {
input_shapes_[name_tensor.first] = name_tensor.second->dims();
input_dtypes_[name_tensor.first] = name_tensor.second->dtype();
}
arch_str_ = arch_str;
}

void CinnCacheKey::SetKey(const ir::Graph& graph,
const std::map<std::string, DDim>& input_shapes,
const std::map<std::string, DataType>& input_dtypes,
const std::string& arch_str) {
PADDLE_ENFORCE_EQ(
input_shapes.size(),
input_dtypes.size(),
platform::errors::PreconditionNotMet(
"Required input_shapes has same length with input_dtypes."));

graph_hash_val_ = graph_hash_(graph);
input_shapes_ = input_shapes;
input_dtypes_ = input_dtypes;
arch_str_ = arch_str;
}

@@ -76,19 +87,26 @@ bool CinnCacheKey::operator!=(const CinnCacheKey& other) const {

bool CinnCacheKey::operator==(const CinnCacheKey& other) const {
return graph_hash_val_ == other.graph_hash_val_ &&
input_shapes_ == other.input_shapes_ && arch_str_ == other.arch_str_;
input_shapes_ == other.input_shapes_ &&
input_dtypes_ == other.input_dtypes_ && arch_str_ == other.arch_str_;
}

size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const {
std::ostringstream has_str;

for (const auto& name_shape : key.input_shapes_) {
has_str << name_shape.first;
has_str << std::hash<phi::DDim>()(name_shape.second);
has_str << name_shape.first << ",";
has_str << "[" << name_shape.second << "],";
PADDLE_ENFORCE_NE(key.input_dtypes_.find(name_shape.first),
key.input_dtypes_.end(),
platform::errors::PreconditionNotMet(
"%s is not in key.input_dtypes_.", name_shape.first));
has_str << key.input_dtypes_.at(name_shape.first) << ";";
}

has_str << key.arch_str_ << ",";
has_str << key.graph_hash_val_;
has_str << key.arch_str_;
VLOG(1) << "CinnCacheKey : " << has_str.str();
return std::hash<std::string>()(has_str.str());
}

@@ -101,24 +119,45 @@ size_t CinnCacheKeyByStructure::HashGraph(const ir::Graph& graph) {

// graph.Nodes() return unordered_set, here using set to avoid the same graph
// may return different result
std::set<ir::Node*, bool (*)(ir::Node*, ir::Node*)> node_set(compare),
output_set(compare);
node_set.insert(graph.Nodes().begin(), graph.Nodes().end());

std::string hash_str;
for (ir::Node* n : node_set) {
hash_str.append(n->Name());

output_set.clear();
output_set.insert(n->outputs.begin(), n->outputs.end());
for (auto* out : output_set) {
hash_str.append(out->Name());
std::set<ir::Node*, bool (*)(ir::Node*, ir::Node*)> node_set(compare);
for (ir::Node* node : graph.Nodes()) {
if (node->IsOp()) {
// only op nodes are hashed: graphs with the same ops share a cache key
node_set.insert(node);
}
}

static std::unordered_set<std::string> ignore_attr = {"op_callstack",
"op_device",
"op_namescope",
"op_role",
"op_role_var",
"with_quant_attr"};

std::ostringstream hash_str;
for (ir::Node* op : node_set) {
hash_str << op->Name() << ":";
hash_str << "input_num=" << op->inputs.size() << ",";
hash_str << "output_num=" << op->outputs.size() << ",";

const auto& attrs_unordered_map = op->Op()->GetAttrMap();
std::map<std::string, Attribute> attrs_map(attrs_unordered_map.begin(),
attrs_unordered_map.end());
for (const auto& attr : attrs_map) {
if (ignore_attr.count(attr.first)) {
continue;
}
const auto& attr_str = PaddleAttributeToString(attr.second);
if (!attr_str.empty()) {
hash_str << attr.first << "=" << attr_str << ",";
}
}
hash_str << ";";
}

VLOG(1) << "The hash graph:\n" << hash_str;
VLOG(1) << "The hash graph:\n" << hash_str.str();

size_t hash_val = std::hash<std::string>()(hash_str);
size_t hash_val = std::hash<std::string>()(hash_str.str());
VLOG(4) << "The graph's hash value by graph structure is: " << hash_val;
return hash_val;
}
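
One of the commit bullets above is "using ordered map for attributes": the hunk copies each op's unordered attribute map into a std::map before serializing it. A hedged sketch of why that copy matters for a stable hash, using plain ints in place of Paddle's Attribute variant:

```cpp
// Hypothetical sketch: deterministic attribute serialization.
#include <iostream>
#include <map>
#include <string>
#include <unordered_map>

int main() {
  // Iteration order of an unordered_map is unspecified, so serializing it
  // directly could yield different strings (and cache keys) across runs.
  std::unordered_map<std::string, int> attrs = {{"axis", 1}, {"num", 2}, {"seed", 0}};

  // Copying into std::map sorts entries by key, so the string is stable.
  std::map<std::string, int> ordered(attrs.begin(), attrs.end());

  std::string hash_str;
  for (const auto& attr : ordered) {
    hash_str += attr.first + "=" + std::to_string(attr.second) + ",";
  }
  std::cout << hash_str << "\n";  // always prints "axis=1,num=2,seed=0,"
  return 0;
}
```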
8 changes: 7 additions & 1 deletion paddle/fluid/framework/paddle2cinn/cinn_cache_key.h
@@ -19,6 +19,7 @@

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/ddim.h"

namespace paddle {
@@ -45,6 +46,7 @@ class CinnCacheKey {
GraphHashStrategy graph_hash);
CinnCacheKey(const ir::Graph& graph,
const std::map<std::string, DDim>& input_shapes,
const std::map<std::string, DataType>& input_dtypes,
const std::string& arch_str,
GraphHashStrategy graph_hash);

@@ -56,6 +58,7 @@
const std::string& arch_str);
void SetKey(const ir::Graph& graph,
const std::map<std::string, DDim>& input_shapes,
const std::map<std::string, DataType>& input_dtypes,
const std::string& arch_str);

bool operator==(const CinnCacheKey& other) const;
@@ -69,6 +72,7 @@ class CinnCacheKey {
GraphHashStrategy graph_hash_;
size_t graph_hash_val_;
std::map<std::string, DDim> input_shapes_;
std::map<std::string, DataType> input_dtypes_;
std::string arch_str_;
};

@@ -84,8 +88,10 @@ class CinnCacheKey {
\
NAME(const ir::Graph& graph, \
const std::map<std::string, DDim>& input_shapes, \
const std::map<std::string, DataType>& input_dtypes, \
const std::string& arch_str) \
: CinnCacheKey(graph, input_shapes, arch_str, HashGraph) {} \
: CinnCacheKey( \
graph, input_shapes, input_dtypes, arch_str, HashGraph) {} \
\
private: \
static size_t HashGraph(const ir::Graph& graph); \
37 changes: 27 additions & 10 deletions paddle/fluid/framework/paddle2cinn/cinn_cache_key_test.cc
@@ -39,29 +39,35 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKeyByStructure) {
x->SetType(proto::VarType::LOD_TENSOR);
ir::Graph graph(program);

DataType fp32 = DataType::FLOAT32;
phi::DenseTensor tensor;
tensor.set_type(fp32);
tensor.Resize({1, 2, 3});
const phi::DenseTensor *tensor_pointer = &tensor;
std::map<std::string, const phi::DenseTensor *> feed_tensors = {
{"X", tensor_pointer}};

DDim ddim = phi::make_ddim({1, 2, 3});
std::map<std::string, DDim> feed_shapes = {{"X", ddim}};
std::map<std::string, DataType> feed_dtypes = {{"X", fp32}};

CinnCacheKeyByStructure cache_key0(empty_graph, feed_tensors, "x86");
CinnCacheKeyByStructure cache_key1(empty_graph, feed_shapes, "x86");
CinnCacheKeyByStructure cache_key1(
empty_graph, feed_shapes, feed_dtypes, "x86");
EXPECT_EQ(cache_key0, cache_key1);

CinnCacheKeyByStructure cache_key2(graph, feed_shapes, "x86");
CinnCacheKeyByStructure cache_key3(graph, feed_shapes, "nvgpu");
CinnCacheKeyByStructure cache_key2(graph, feed_shapes, feed_dtypes, "x86");
CinnCacheKeyByStructure cache_key3(graph, feed_shapes, feed_dtypes, "nvgpu");
CinnCacheKeyByStructure cache_key4(graph, feed_tensors, "nvgpu");
EXPECT_NE(cache_key2, cache_key3);
EXPECT_EQ(cache_key3, cache_key4);

CinnCacheKeyByStructure cache_key5(
empty_graph, std::map<std::string, const phi::DenseTensor *>(), "unk");
CinnCacheKeyByStructure cache_key6(
empty_graph, std::map<std::string, DDim>(), "unk");
CinnCacheKeyByStructure cache_key6(empty_graph,
std::map<std::string, DDim>(),
std::map<std::string, DataType>(),
"unk");
EXPECT_EQ(cache_key5, cache_key6);

EXPECT_NE(cache_key1, cache_key3);
@@ -112,6 +118,7 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKeyByAddress) {
x->SetType(proto::VarType::LOD_TENSOR);
ir::Graph graph(program);

DataType fp32 = DataType::FLOAT32;
phi::DenseTensor tensor;
tensor.Resize({1, 2, 3});
const phi::DenseTensor *tensor_pointer = &tensor;
@@ -120,21 +127,29 @@

DDim ddim = phi::make_ddim({1, 2, 3});
std::map<std::string, DDim> feed_shapes = {{"X", ddim}};
std::map<std::string, DataType> feed_dtypes = {{"X", fp32}};
std::map<std::string, DataType> new_dtypes = {{"X", DataType::FLOAT64}};

CinnCacheKeyByAddress cache_key0(empty_graph, feed_tensors, "x86");
CinnCacheKeyByAddress cache_key1(empty_graph, feed_shapes, "x86");
CinnCacheKeyByAddress cache_key1(
empty_graph, feed_shapes, feed_dtypes, "x86");
EXPECT_EQ(cache_key0, cache_key1);

CinnCacheKeyByAddress cache_key2(graph, feed_shapes, "x86");
CinnCacheKeyByAddress cache_key3(graph, feed_shapes, "nvgpu");
CinnCacheKeyByAddress cache_key7(empty_graph, feed_shapes, new_dtypes, "x86");
EXPECT_NE(cache_key1, cache_key7);

CinnCacheKeyByAddress cache_key2(graph, feed_shapes, feed_dtypes, "x86");
CinnCacheKeyByAddress cache_key3(graph, feed_shapes, feed_dtypes, "nvgpu");
CinnCacheKeyByAddress cache_key4(graph, feed_tensors, "nvgpu");
EXPECT_NE(cache_key2, cache_key3);
EXPECT_EQ(cache_key3, cache_key4);

CinnCacheKeyByAddress cache_key5(
empty_graph, std::map<std::string, const phi::DenseTensor *>(), "unk");
CinnCacheKeyByAddress cache_key6(
empty_graph, std::map<std::string, DDim>(), "unk");
CinnCacheKeyByAddress cache_key6(empty_graph,
std::map<std::string, DDim>(),
std::map<std::string, DataType>(),
"unk");
EXPECT_EQ(cache_key5, cache_key6);

EXPECT_NE(cache_key1, cache_key3);
@@ -186,7 +201,9 @@ TEST(CinnCacheKeyTest, TestSameGraph) {
x2->SetType(proto::VarType::LOD_TENSOR);
ir::Graph graph2(program2);

DataType fp32 = DataType::FLOAT32;
phi::DenseTensor tensor;
tensor.set_type(fp32);
tensor.Resize({1, 2, 3});
const phi::DenseTensor *tensor_pointer = &tensor;
std::map<std::string, const phi::DenseTensor *> feed_tensors = {
6 changes: 5 additions & 1 deletion paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
@@ -39,6 +39,7 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h"
#include "paddle/fluid/framework/paddle2cinn/transform_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/analysis/dot.h"
@@ -78,9 +79,11 @@ const CinnCompiledObject &CinnCompiler::Compile(
CinnCacheKeyByStructure cur_key_by_struct;

if (!cache_by_address_.count(cur_key_by_address)) {
VLOG(4) << "Not found CinnCompiledObject in cache_by_address_.";
// generate the structure cache key
cur_key_by_struct.SetKey(graph, input_tensors, target.arch_str());
if (!cache_by_struct_.count(cur_key_by_struct)) {
VLOG(4) << "Not found CinnCompiledObject in cache_by_struct_.";
std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
auto compiled_res =
CompileGraph(graph, input_tensors, target, compiled_num, stream);
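
The two VLOG lines added in this hunk trace a two-level lookup: the address-based key is tried first, and only on a miss is the structure-based key built and checked before actually compiling. A simplified, hedged sketch of that pattern with string keys and a stub compile step (CompiledObject and CompileGraphStub are hypothetical stand-ins, not the real CinnCompiler types):

```cpp
#include <string>
#include <unordered_map>

struct CompiledObject { std::string ir; };

CompiledObject CompileGraphStub() { return {"compiled"}; }  // stand-in for real compilation

const CompiledObject& LookupOrCompile(
    std::unordered_map<std::string, CompiledObject>& cache_by_address,
    std::unordered_map<std::string, CompiledObject>& cache_by_struct,
    const std::string& key_by_address,
    const std::string& key_by_struct) {
  auto by_addr = cache_by_address.find(key_by_address);
  if (by_addr != cache_by_address.end()) {
    return by_addr->second;  // fast path: same graph and inputs seen before
  }
  auto by_struct = cache_by_struct.find(key_by_struct);
  if (by_struct == cache_by_struct.end()) {
    // Both caches miss: compile once and remember it under the structure key.
    by_struct = cache_by_struct.emplace(key_by_struct, CompileGraphStub()).first;
  }
  // Also record the result under the address key so the fast path hits next time.
  return cache_by_address.emplace(key_by_address, by_struct->second).first->second;
}
```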
@@ -180,7 +183,8 @@ std::string CinnCompiler::VizGraph(const Graph &graph) const {
shape.begin(), shape.end(), shape_str.begin(), [](const auto &val) {
return std::to_string(val);
});
label += "\n" + string::join_strings(shape_str, ',');
label += "\n[" + string::join_strings(shape_str, ',') + "]";
label += "\n" + VarDataTypeToString(n->Var()->GetDataType());
}
dot.AddNode(
node_id,
27 changes: 27 additions & 0 deletions paddle/fluid/framework/paddle2cinn/transform_desc.cc
@@ -97,6 +97,33 @@ ::cinn::frontend::paddle::cpp::VarDescAPI::Type TransformVarDataTypeToCinn(
#undef SET_DATA_TYPE_CASE_ITEM
}

std::string VarDataTypeToString(
const ::paddle::framework::proto::VarType::Type &type) {
#define SET_DATA_TYPE_CASE_ITEM(type__) \
case ::paddle::framework::proto::VarType::type__: \
return std::string(#type__); \
break;

switch (type) {
SET_DATA_TYPE_CASE_ITEM(BOOL);
SET_DATA_TYPE_CASE_ITEM(SIZE_T);
SET_DATA_TYPE_CASE_ITEM(UINT8);
SET_DATA_TYPE_CASE_ITEM(INT8);
SET_DATA_TYPE_CASE_ITEM(INT16);
SET_DATA_TYPE_CASE_ITEM(INT32);
SET_DATA_TYPE_CASE_ITEM(INT64);
SET_DATA_TYPE_CASE_ITEM(FP16);
SET_DATA_TYPE_CASE_ITEM(FP32);
SET_DATA_TYPE_CASE_ITEM(FP64);
SET_DATA_TYPE_CASE_ITEM(BF16);
SET_DATA_TYPE_CASE_ITEM(COMPLEX64);
SET_DATA_TYPE_CASE_ITEM(COMPLEX128);
default:
PADDLE_THROW(platform::errors::NotFound("Cannot find var data type"));
}
#undef SET_DATA_TYPE_CASE_ITEM
}

::paddle::framework::proto::VarType::Type TransformVarDataTypeFromCpp(
const ::cinn::frontend::paddle::cpp::VarDescAPI::Type &type) {
#define SET_DATA_TYPE_CASE_ITEM(type__) \
4 changes: 4 additions & 0 deletions paddle/fluid/framework/paddle2cinn/transform_desc.h
@@ -74,6 +74,10 @@ void TransformProgramDescFromCinn(
const ::cinn::frontend::paddle::cpp::ProgramDesc& cpp_desc,
framework::ProgramDesc* pb_desc);

// debug helper: returns the readable name of a proto var data type
std::string VarDataTypeToString(
const ::paddle::framework::proto::VarType::Type& type);

} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
10 changes: 10 additions & 0 deletions paddle/fluid/framework/paddle2cinn/transform_desc_test.cc
@@ -232,6 +232,16 @@ TEST(TransformProgramDesc, pb2cpp) {
ASSERT_EQ(cpp_prog.BlocksSize(), correct_prog.BlocksSize());
}

TEST(HelperFunction, VarDataTypeToString) {
const auto &pd_fp32_var = CreatePbVarDesc();
const auto &debug_fp32_string =
VarDataTypeToString(pd_fp32_var.GetDataType());
ASSERT_EQ(debug_fp32_string, std::string("FP32"));

ASSERT_EQ(VarDataTypeToString(::paddle::framework::proto::VarType::INT32),
std::string("INT32"));
}

} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
