Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[CINN]Enhance CacheKey hash logic by considering input dtypes #50557

Merged
merged 11 commits into from
Feb 24, 2023
18 changes: 17 additions & 1 deletion paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc
100755 → 100644
Original file line number Diff line number Diff line change
@@ -45,10 +45,11 @@ CinnCacheKey::CinnCacheKey(

// Build a cache key from a graph plus per-input shapes and dtypes.
// The diff scrape left both the old 3-arg SetKey call and the new 4-arg
// call in place; only the 4-arg call matching the new signature is kept.
//
// @param graph        graph whose hash participates in the key
// @param input_shapes per-input-name dims
// @param input_dtypes per-input-name data types (must pair with input_shapes)
// @param arch_str     target architecture tag (e.g. "x86", "nvgpu")
// @param graph_hash   strategy used to hash the graph
CinnCacheKey::CinnCacheKey(const ir::Graph& graph,
                           const std::map<std::string, DDim>& input_shapes,
                           const std::map<std::string, DataType>& input_dtypes,
                           const std::string& arch_str,
                           GraphHashStrategy graph_hash)
    : graph_hash_(graph_hash) {
  // Delegate to SetKey so construction and later re-keying share one path.
  this->SetKey(graph, input_shapes, input_dtypes, arch_str);
}

void CinnCacheKey::SetKey(
@@ -58,15 +59,24 @@ void CinnCacheKey::SetKey(
graph_hash_val_ = graph_hash_(graph);
for (const auto& name_tensor : input_tensors) {
input_shapes_[name_tensor.first] = name_tensor.second->dims();
input_dtypes_[name_tensor.first] = name_tensor.second->dtype();
}
arch_str_ = arch_str;
}

// Populate the key fields from explicit shape/dtype maps instead of tensors.
// Shapes and dtypes are keyed by the same input names, so their sizes must
// agree before both maps are stored side by side.
void CinnCacheKey::SetKey(const ir::Graph& graph,
                          const std::map<std::string, DDim>& input_shapes,
                          const std::map<std::string, DataType>& input_dtypes,
                          const std::string& arch_str) {
  PADDLE_ENFORCE_EQ(
      input_shapes.size(),
      input_dtypes.size(),
      platform::errors::PreconditionNotMet(
          "Required input_shapes has same length with input_dtypes."));

  // The assignments below are independent of one another.
  arch_str_ = arch_str;
  input_dtypes_ = input_dtypes;
  input_shapes_ = input_shapes;
  graph_hash_val_ = graph_hash_(graph);
}

@@ -85,10 +95,16 @@ size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const {
for (const auto& name_shape : key.input_shapes_) {
has_str << name_shape.first;
has_str << std::hash<phi::DDim>()(name_shape.second);
PADDLE_ENFORCE_NE(key.input_dtypes_.find(name_shape.first),
key.input_dtypes_.end(),
platform::errors::PreconditionNotMet(
"%s is not in key.input_dtypes_.", name_shape.first));
has_str << phi::DataTypeToString(key.input_dtypes_.at(name_shape.first));
}

has_str << key.graph_hash_val_;
has_str << key.arch_str_;
VLOG(4) << "CinnCacheKey : " << has_str.str();
return std::hash<std::string>()(has_str.str());
}

8 changes: 7 additions & 1 deletion paddle/fluid/framework/paddle2cinn/cinn_cache_key.h
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/ddim.h"

namespace paddle {
@@ -45,6 +46,7 @@ class CinnCacheKey {
GraphHashStrategy graph_hash);
CinnCacheKey(const ir::Graph& graph,
const std::map<std::string, DDim>& input_shapes,
const std::map<std::string, DataType>& input_dtypes,
const std::string& arch_str,
GraphHashStrategy graph_hash);

@@ -56,6 +58,7 @@ class CinnCacheKey {
const std::string& arch_str);
void SetKey(const ir::Graph& graph,
const std::map<std::string, DDim>& input_shapes,
const std::map<std::string, DataType>& input_dtypes,
const std::string& arch_str);

bool operator==(const CinnCacheKey& other) const;
@@ -69,6 +72,7 @@ class CinnCacheKey {
GraphHashStrategy graph_hash_;
size_t graph_hash_val_;
std::map<std::string, DDim> input_shapes_;
std::map<std::string, DataType> input_dtypes_;
std::string arch_str_;
};

@@ -84,8 +88,10 @@ class CinnCacheKey {
\
NAME(const ir::Graph& graph, \
const std::map<std::string, DDim>& input_shapes, \
const std::map<std::string, DataType>& input_dtypes, \
const std::string& arch_str) \
: CinnCacheKey(graph, input_shapes, arch_str, HashGraph) {} \
: CinnCacheKey( \
graph, input_shapes, input_dtypes, arch_str, HashGraph) {} \
\
private: \
static size_t HashGraph(const ir::Graph& graph); \
37 changes: 27 additions & 10 deletions paddle/fluid/framework/paddle2cinn/cinn_cache_key_test.cc
Original file line number Diff line number Diff line change
@@ -39,29 +39,35 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKeyByStructure) {
x->SetType(proto::VarType::LOD_TENSOR);
ir::Graph graph(program);

DataType fp32 = DataType::FLOAT32;
phi::DenseTensor tensor;
tensor.set_type(fp32);
tensor.Resize({1, 2, 3});
const phi::DenseTensor *tensor_pointer = &tensor;
std::map<std::string, const phi::DenseTensor *> feed_tensors = {
{"X", tensor_pointer}};

DDim ddim = phi::make_ddim({1, 2, 3});
std::map<std::string, DDim> feed_shapes = {{"X", ddim}};
std::map<std::string, DataType> feed_dtypes = {{"X", fp32}};

CinnCacheKeyByStructure cache_key0(empty_graph, feed_tensors, "x86");
CinnCacheKeyByStructure cache_key1(empty_graph, feed_shapes, "x86");
CinnCacheKeyByStructure cache_key1(
empty_graph, feed_shapes, feed_dtypes, "x86");
EXPECT_EQ(cache_key0, cache_key1);

CinnCacheKeyByStructure cache_key2(graph, feed_shapes, "x86");
CinnCacheKeyByStructure cache_key3(graph, feed_shapes, "nvgpu");
CinnCacheKeyByStructure cache_key2(graph, feed_shapes, feed_dtypes, "x86");
CinnCacheKeyByStructure cache_key3(graph, feed_shapes, feed_dtypes, "nvgpu");
CinnCacheKeyByStructure cache_key4(graph, feed_tensors, "nvgpu");
EXPECT_NE(cache_key2, cache_key3);
EXPECT_EQ(cache_key3, cache_key4);

CinnCacheKeyByStructure cache_key5(
empty_graph, std::map<std::string, const phi::DenseTensor *>(), "unk");
CinnCacheKeyByStructure cache_key6(
empty_graph, std::map<std::string, DDim>(), "unk");
CinnCacheKeyByStructure cache_key6(empty_graph,
std::map<std::string, DDim>(),
std::map<std::string, DataType>(),
"unk");
EXPECT_EQ(cache_key5, cache_key6);

EXPECT_NE(cache_key1, cache_key3);
@@ -112,6 +118,7 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKeyByAddress) {
x->SetType(proto::VarType::LOD_TENSOR);
ir::Graph graph(program);

DataType fp32 = DataType::FLOAT32;
phi::DenseTensor tensor;
tensor.Resize({1, 2, 3});
const phi::DenseTensor *tensor_pointer = &tensor;
@@ -120,21 +127,29 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKeyByAddress) {

DDim ddim = phi::make_ddim({1, 2, 3});
std::map<std::string, DDim> feed_shapes = {{"X", ddim}};
std::map<std::string, DataType> feed_dtypes = {{"X", fp32}};
std::map<std::string, DataType> new_dtypes = {{"X", DataType::FLOAT64}};

CinnCacheKeyByAddress cache_key0(empty_graph, feed_tensors, "x86");
CinnCacheKeyByAddress cache_key1(empty_graph, feed_shapes, "x86");
CinnCacheKeyByAddress cache_key1(
empty_graph, feed_shapes, feed_dtypes, "x86");
EXPECT_EQ(cache_key0, cache_key1);

CinnCacheKeyByAddress cache_key2(graph, feed_shapes, "x86");
CinnCacheKeyByAddress cache_key3(graph, feed_shapes, "nvgpu");
CinnCacheKeyByAddress cache_key7(empty_graph, feed_shapes, new_dtypes, "x86");
EXPECT_NE(cache_key1, cache_key7);

CinnCacheKeyByAddress cache_key2(graph, feed_shapes, feed_dtypes, "x86");
CinnCacheKeyByAddress cache_key3(graph, feed_shapes, feed_dtypes, "nvgpu");
CinnCacheKeyByAddress cache_key4(graph, feed_tensors, "nvgpu");
EXPECT_NE(cache_key2, cache_key3);
EXPECT_EQ(cache_key3, cache_key4);

CinnCacheKeyByAddress cache_key5(
empty_graph, std::map<std::string, const phi::DenseTensor *>(), "unk");
CinnCacheKeyByAddress cache_key6(
empty_graph, std::map<std::string, DDim>(), "unk");
CinnCacheKeyByAddress cache_key6(empty_graph,
std::map<std::string, DDim>(),
std::map<std::string, DataType>(),
"unk");
EXPECT_EQ(cache_key5, cache_key6);

EXPECT_NE(cache_key1, cache_key3);
@@ -186,7 +201,9 @@ TEST(CinnCacheKeyTest, TestSameGraph) {
x2->SetType(proto::VarType::LOD_TENSOR);
ir::Graph graph2(program2);

DataType fp32 = DataType::FLOAT32;
phi::DenseTensor tensor;
tensor.set_type(fp32);
tensor.Resize({1, 2, 3});
const phi::DenseTensor *tensor_pointer = &tensor;
std::map<std::string, const phi::DenseTensor *> feed_tensors = {
2 changes: 2 additions & 0 deletions paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
Original file line number Diff line number Diff line change
@@ -78,9 +78,11 @@ const CinnCompiledObject &CinnCompiler::Compile(
CinnCacheKeyByStructure cur_key_by_struct;

if (!cache_by_address_.count(cur_key_by_address)) {
VLOG(4) << "Not found CinnCompiledObject in cache_by_address_.";
// generate the structure cache key
cur_key_by_struct.SetKey(graph, input_tensors, target.arch_str());
if (!cache_by_struct_.count(cur_key_by_struct)) {
VLOG(4) << "Not found CinnCompiledObject in cache_by_struct_.";
std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
auto compiled_res =
CompileGraph(graph, input_tensors, target, compiled_num, stream);