Skip to content

Commit

Permalink
add serialization for new field in event node (#43405)
Browse files Browse the repository at this point in the history
* add serialization for new field in event node

* fix a bug
  • Loading branch information
rainyfly authored Jun 13, 2022
1 parent 30b1063 commit 360b838
Show file tree
Hide file tree
Showing 8 changed files with 307 additions and 1 deletion.
76 changes: 76 additions & 0 deletions paddle/fluid/platform/profiler/dump/deserialization_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,26 @@ std::unique_ptr<ProfilerResult> DeserializationReader::Parse() {
device_node); // insert into runtime_node
}
}
// handle mem node
for (int mem_node_index = 0;
mem_node_index < host_node_proto.mem_nodes_size();
mem_node_index++) {
const MemTraceEventNodeProto& mem_node_proto =
host_node_proto.mem_nodes(mem_node_index);
MemTraceEventNode* mem_node = RestoreMemTraceEventNode(mem_node_proto);
host_node->AddMemNode(mem_node);
}
// handle op supplement node
for (int op_supplement_node_index = 0;
op_supplement_node_index <
host_node_proto.op_supplement_nodes_size();
op_supplement_node_index++) {
const OperatorSupplementEventNodeProto& op_supplement_node_proto =
host_node_proto.op_supplement_nodes(op_supplement_node_index);
OperatorSupplementEventNode* op_supplement_node =
RestoreOperatorSupplementEventNode(op_supplement_node_proto);
host_node->SetOperatorSupplementNode(op_supplement_node);
}
}
// restore parent-child relationship
for (auto it = child_parent_map.begin(); it != child_parent_map.end();
Expand Down Expand Up @@ -176,6 +196,62 @@ HostTraceEventNode* DeserializationReader::RestoreHostTraceEventNode(
return new HostTraceEventNode(host_event);
}

// Rebuilds a MemTraceEventNode from its serialized representation.
//
// Each field of the embedded MemTraceEventProto is copied into a
// default-constructed MemTraceEvent; the proto enum value is converted to
// TracerMemEventType with a static_cast. The resulting node is allocated with
// `new` and returned as a raw pointer.
MemTraceEventNode* DeserializationReader::RestoreMemTraceEventNode(
    const MemTraceEventNodeProto& mem_node_proto) {
  const MemTraceEventProto& src = mem_node_proto.mem_event();

  MemTraceEvent restored;
  // When the event happened and which process/thread recorded it.
  restored.timestamp_ns = src.timestamp_ns();
  restored.process_id = src.process_id();
  restored.thread_id = src.thread_id();
  // What happened: operation kind, address and place.
  restored.type = static_cast<TracerMemEventType>(src.type());
  restored.addr = src.addr();
  restored.place = src.place();
  // Memory accounting carried by the record.
  restored.increase_bytes = src.increase_bytes();
  restored.current_allocated = src.current_allocated();
  restored.current_reserved = src.current_reserved();

  return new MemTraceEventNode(restored);
}

// Rebuilds an OperatorSupplementEventNode from its serialized representation.
//
// Scalar fields (timestamp, op type, call stack, process id, thread id) are
// copied directly. `input_shapes` and `dtypes` are stored in the proto as
// parallel repeated fields: key(i) names an entry and shape_vecs(i) /
// dtype_vecs(i) holds the values belonging to that entry. They are rebuilt
// here into std::map containers keyed by that name.
//
// The resulting node is allocated with `new` and returned as a raw pointer.
OperatorSupplementEventNode*
DeserializationReader::RestoreOperatorSupplementEventNode(
    const OperatorSupplementEventNodeProto& op_supplement_node_proto) {
  const OperatorSupplementEventProto& op_supplement_event_proto =
      op_supplement_node_proto.op_supplement_event();
  OperatorSupplementEvent op_supplement_event;
  op_supplement_event.timestamp_ns = op_supplement_event_proto.timestamp_ns();
  op_supplement_event.op_type = op_supplement_event_proto.op_type();
  op_supplement_event.callstack = op_supplement_event_proto.callstack();
  op_supplement_event.process_id = op_supplement_event_proto.process_id();
  op_supplement_event.thread_id = op_supplement_event_proto.thread_id();

  // Restore input shapes: name -> list of shapes -> list of dimension sizes.
  std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
  const auto& input_shape_proto = op_supplement_event_proto.input_shapes();
  for (int i = 0; i < input_shape_proto.key_size(); i++) {
    // Bind by reference: the restored shapes must be stored in the map entry
    // itself. Binding by value would fill a temporary copy and leave the map
    // entry empty.
    auto& input_shape_vec = input_shapes[input_shape_proto.key(i)];
    if (i >= input_shape_proto.shape_vecs_size()) {
      // Malformed input (more keys than value lists): keep the key with no
      // shapes instead of indexing past the end of the repeated field.
      continue;
    }
    const auto& shape_vectors_proto = input_shape_proto.shape_vecs(i);
    for (int j = 0; j < shape_vectors_proto.shapes_size(); j++) {
      const auto& shape_vector_proto = shape_vectors_proto.shapes(j);
      std::vector<int64_t> shape;
      shape.reserve(shape_vector_proto.size_size());
      for (int k = 0; k < shape_vector_proto.size_size(); k++) {
        // The proto stores sizes as uint64; convert back to int64_t.
        shape.push_back(static_cast<int64_t>(shape_vector_proto.size(k)));
      }
      input_shape_vec.push_back(shape);
    }
  }
  op_supplement_event.input_shapes = input_shapes;

  // Restore dtypes: name -> list of dtype strings.
  std::map<std::string, std::vector<std::string>> dtypes;
  const auto& dtype_proto = op_supplement_event_proto.dtypes();
  for (int i = 0; i < dtype_proto.key_size(); i++) {
    // Same as above: a reference is required so the map entry is filled.
    auto& dtype_vec = dtypes[dtype_proto.key(i)];
    if (i >= dtype_proto.dtype_vecs_size()) {
      continue;  // malformed input: key without a matching value list
    }
    const auto& dtype_vec_proto = dtype_proto.dtype_vecs(i);
    dtype_vec.reserve(dtype_vec.size() + dtype_vec_proto.dtype_size());
    for (int j = 0; j < dtype_vec_proto.dtype_size(); j++) {
      dtype_vec.push_back(dtype_vec_proto.dtype(j));
    }
  }
  op_supplement_event.dtypes = dtypes;

  return new OperatorSupplementEventNode(op_supplement_event);
}

KernelEventInfo DeserializationReader::HandleKernelEventInfoProto(
const DeviceTraceEventProto& device_event_proto) {
const KernelEventInfoProto& kernel_info_proto =
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/platform/profiler/dump/deserialization_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class DeserializationReader {
KernelEventInfo HandleKernelEventInfoProto(const DeviceTraceEventProto&);
MemcpyEventInfo HandleMemcpyEventInfoProto(const DeviceTraceEventProto&);
MemsetEventInfo HandleMemsetEventInfoProto(const DeviceTraceEventProto&);
MemTraceEventNode* RestoreMemTraceEventNode(const MemTraceEventNodeProto&);
OperatorSupplementEventNode* RestoreOperatorSupplementEventNode(
const OperatorSupplementEventNodeProto&);
std::string filename_;
std::ifstream input_file_stream_;
NodeTreesProto* node_trees_proto_;
Expand Down
70 changes: 70 additions & 0 deletions paddle/fluid/platform/profiler/dump/nodetree.proto
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ enum TracerEventTypeProto {
PythonOp = 13;
// Used to mark python level userdefined
PythonUserDefined = 14;
// Used to mark mlu runtime record returned by cnpapi
MluRuntime = 15;
};

// Kind of memory operation recorded by a MemTraceEventProto.
// The deserialization reader and the serialization logger convert between
// this enum and the C++ TracerMemEventType with a plain static_cast, so the
// numeric values of the two enums have to stay in step.
enum TracerMemEventTypeProto {
// Used to mark memory allocation
Allocate = 0;
// Used to mark memory free
Free = 1;
};

message KernelEventInfoProto {
Expand Down Expand Up @@ -121,6 +130,58 @@ message HostTraceEventProto {
required uint64 thread_id = 6;
}

// Serialized form of a single memory trace event (one allocation or free).
message MemTraceEventProto {
// timestamp of the record
required uint64 timestamp_ns = 1;
// kind of memory operation (allocation or free)
required TracerMemEventTypeProto type = 2;
// memory address that was allocated or freed
required uint64 addr = 3;
// process id of the record
required uint64 process_id = 4;
// thread id of the record
required uint64 thread_id = 5;
// signed change in bytes caused by this operation: positive for an
// allocation, negative for a free
required int64 increase_bytes = 6;
// place the operation applies to, stored as a string
required string place = 7;
// current total allocated memory
required uint64 current_allocated = 8;
// current total reserved memory
required uint64 current_reserved = 9;
}

// Serialized form of the supplementary information recorded for an operator:
// its type name, input shapes, dtypes and call stack.
message OperatorSupplementEventProto {
// timestamp of the record
required uint64 timestamp_ns = 1;
// op type name
required string op_type = 2;
// process id of the record
required uint64 process_id = 3;
// thread id of the record
required uint64 thread_id = 4;
// input shapes, stored as two parallel repeated fields: key(i) is the entry
// name and shape_vecs(i) holds the list of shapes for that entry. The logger
// appends one key and one shape_vector per map entry, and the reader pairs
// them back up by index.
message input_shape_proto {
repeated string key = 1;
// list of shapes belonging to one key
message shape_vector {
// one shape: its dimension sizes in order
// NOTE(review): sizes are uint64 here while the C++ side stores int64_t
// values; confirm negative dimensions are not expected.
message shape { repeated uint64 size = 1; }
repeated shape shapes = 1;
}
repeated shape_vector shape_vecs = 2;
}
required input_shape_proto input_shapes = 5;
// dtypes, stored with the same parallel layout: key(i) is the entry name and
// dtype_vecs(i) holds the dtype strings for that entry
message dtype_proto {
repeated string key = 1;
// list of dtype names belonging to one key
message dtype_vector { repeated string dtype = 1; }
repeated dtype_vector dtype_vecs = 2;
}
required dtype_proto dtypes = 6;
// call stack
required string callstack = 7;
}

message CudaRuntimeTraceEventProto {
// record name
required string name = 1;
Expand Down Expand Up @@ -166,6 +227,12 @@ message DeviceTraceEventProto {
}
}

// Tree-node wrapper holding exactly one OperatorSupplementEventProto.
message OperatorSupplementEventNodeProto {
required OperatorSupplementEventProto op_supplement_event = 1;
}

// Tree-node wrapper holding exactly one MemTraceEventProto.
message MemTraceEventNodeProto { required MemTraceEventProto mem_event = 1; }

// Tree-node wrapper holding exactly one DeviceTraceEventProto.
message DeviceTraceEventNodeProto {
required DeviceTraceEventProto device_event = 1;
}
Expand All @@ -180,6 +247,9 @@ message HostTraceEventNodeProto {
required int64 parentid = 2;
required HostTraceEventProto host_trace_event = 3;
repeated CudaRuntimeTraceEventNodeProto runtime_nodes = 4;
// below is added in version 1.0.1
repeated MemTraceEventNodeProto mem_nodes = 5;
repeated OperatorSupplementEventNodeProto op_supplement_nodes = 6;
}

message ThreadNodeTreeProto {
Expand Down
78 changes: 77 additions & 1 deletion paddle/fluid/platform/profiler/dump/serialization_logger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace paddle {
namespace platform {

static const char* kDefaultFilename = "pid_%s_time_%s.paddle_trace.pb";
static const char* version = "1.0.0";
static const char* version = "1.0.1";
static uint32_t span_indx = 0;

static std::string DefaultFileName() {
Expand Down Expand Up @@ -106,10 +106,33 @@ void SerializationLogger::LogNodeTrees(const NodeTrees& node_trees) {
(*devicenode)->LogMe(this); // fill detail information
}
}
for (auto memnode = (*hostnode)->GetMemTraceEventNodes().begin();
memnode != (*hostnode)->GetMemTraceEventNodes().end(); ++memnode) {
MemTraceEventNodeProto* mem_node_proto =
current_host_trace_event_node_proto_->add_mem_nodes();
current_mem_trace_event_node_proto_ = mem_node_proto;
(*memnode)->LogMe(this);
}
}
}
}

void SerializationLogger::LogMemTraceEventNode(
const MemTraceEventNode& mem_node) {
MemTraceEventProto* mem_trace_event = new MemTraceEventProto();
mem_trace_event->set_timestamp_ns(mem_node.TimeStampNs());
mem_trace_event->set_type(
static_cast<TracerMemEventTypeProto>(mem_node.Type()));
mem_trace_event->set_addr(mem_node.Addr());
mem_trace_event->set_process_id(mem_node.ProcessId());
mem_trace_event->set_thread_id(mem_node.ThreadId());
mem_trace_event->set_increase_bytes(mem_node.IncreaseBytes());
mem_trace_event->set_place(mem_node.Place());
mem_trace_event->set_current_allocated(mem_node.CurrentAllocated());
mem_trace_event->set_current_reserved(mem_node.CurrentReserved());
current_mem_trace_event_node_proto_->set_allocated_mem_event(mem_trace_event);
}

void SerializationLogger::LogHostTraceEventNode(
const HostTraceEventNode& host_node) {
HostTraceEventProto* host_trace_event = new HostTraceEventProto();
Expand All @@ -122,6 +145,59 @@ void SerializationLogger::LogHostTraceEventNode(
host_trace_event->set_thread_id(host_node.ThreadId());
current_host_trace_event_node_proto_->set_allocated_host_trace_event(
host_trace_event);
OperatorSupplementEventNode* op_supplement_event_node =
host_node.GetOperatorSupplementEventNode();
if (op_supplement_event_node != nullptr) {
current_op_supplement_event_node_proto_ =
current_host_trace_event_node_proto_->add_op_supplement_nodes();
OperatorSupplementEventProto* op_supplement_event_proto =
new OperatorSupplementEventProto();
op_supplement_event_proto->set_op_type(op_supplement_event_node->Name());
op_supplement_event_proto->set_timestamp_ns(
op_supplement_event_node->TimeStampNs());
op_supplement_event_proto->set_process_id(
op_supplement_event_node->ProcessId());
op_supplement_event_proto->set_thread_id(
op_supplement_event_node->ThreadId());
op_supplement_event_proto->set_callstack(
op_supplement_event_node->CallStack());

OperatorSupplementEventProto::input_shape_proto* input_shape_proto =
op_supplement_event_proto->mutable_input_shapes();
for (auto it = op_supplement_event_node->InputShapes().begin();
it != op_supplement_event_node->InputShapes().end(); it++) {
input_shape_proto->add_key(it->first);
OperatorSupplementEventProto::input_shape_proto::shape_vector*
shape_vectors_proto = input_shape_proto->add_shape_vecs();
auto shape_vectors = it->second;
for (auto shape_vecs_it = shape_vectors.begin();
shape_vecs_it != shape_vectors.end(); shape_vecs_it++) {
auto shape_vector = *shape_vecs_it;
OperatorSupplementEventProto::input_shape_proto::shape_vector::shape*
shape_proto = shape_vectors_proto->add_shapes();
for (auto shape_it = shape_vector.begin();
shape_it != shape_vector.end(); shape_it++) {
shape_proto->add_size(*shape_it);
}
}
}

OperatorSupplementEventProto::dtype_proto* dtype_proto =
op_supplement_event_proto->mutable_dtypes();
for (auto it = op_supplement_event_node->Dtypes().begin();
it != op_supplement_event_node->Dtypes().end(); it++) {
dtype_proto->add_key(it->first);
OperatorSupplementEventProto::dtype_proto::dtype_vector*
dtype_vector_proto = dtype_proto->add_dtype_vecs();
auto dtype_vector = it->second;
for (auto dtype_it = dtype_vector.begin(); dtype_it != dtype_vector.end();
dtype_it++) {
dtype_vector_proto->add_dtype(*dtype_it);
}
}
current_op_supplement_event_node_proto_->set_allocated_op_supplement_event(
op_supplement_event_proto);
}
}

void SerializationLogger::LogRuntimeTraceEventNode(
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/platform/profiler/dump/serialization_logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class SerializationLogger : public BaseLogger {
void LogRuntimeTraceEventNode(const CudaRuntimeTraceEventNode&) override;
void LogNodeTrees(const NodeTrees&) override;
void LogMetaInfo(const std::unordered_map<std::string, std::string>);
void LogMemTraceEventNode(const MemTraceEventNode&) override;

private:
void OpenFile();
Expand All @@ -48,6 +49,8 @@ class SerializationLogger : public BaseLogger {
HostTraceEventNodeProto* current_host_trace_event_node_proto_;
CudaRuntimeTraceEventNodeProto* current_runtime_trace_event_node_proto_;
DeviceTraceEventNodeProto* current_device_trace_event_node_proto_;
MemTraceEventNodeProto* current_mem_trace_event_node_proto_;
OperatorSupplementEventNodeProto* current_op_supplement_event_node_proto_;
};

} // namespace platform
Expand Down
20 changes: 20 additions & 0 deletions paddle/fluid/platform/profiler/dump/test_serialization_logger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ using paddle::platform::ProfilerResult;
using paddle::platform::RuntimeTraceEvent;
using paddle::platform::SerializationLogger;
using paddle::platform::TracerEventType;
using paddle::platform::TracerMemEventType;

TEST(SerializationLoggerTest, dump_case0) {
std::list<HostTraceEvent> host_events;
Expand All @@ -50,6 +51,19 @@ TEST(SerializationLoggerTest, dump_case0) {
std::string("op2"), TracerEventType::Operator, 21000, 30000, 10, 10));
host_events.push_back(HostTraceEvent(
std::string("op3"), TracerEventType::Operator, 31000, 40000, 10, 11));
mem_events.push_back(MemTraceEvent(11500, 0x1000,
TracerMemEventType::Allocate, 10, 10, 50,
"GPU:0", 50, 50));
mem_events.push_back(MemTraceEvent(11900, 0x1000, TracerMemEventType::Free,
10, 10, -50, "GPU:0", 0, 50));
std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
std::map<std::string, std::vector<std::string>> dtypes;
input_shapes[std::string("X")].push_back(std::vector<int64_t>{1, 2, 3});
input_shapes[std::string("X")].push_back(std::vector<int64_t>{4, 5, 6, 7});
dtypes[std::string("X")].push_back(std::string("int8"));
dtypes[std::string("X")].push_back(std::string("float32"));
op_supplement_events.push_back(OperatorSupplementEvent(
11600, "op1", input_shapes, dtypes, "op1()", 10, 10));
runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch1"), 15000,
17000, 10, 10, 1, 0));
runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch2"), 25000,
Expand Down Expand Up @@ -91,6 +105,8 @@ TEST(SerializationLoggerTest, dump_case0) {
if ((*it)->Name() == "op1") {
EXPECT_EQ((*it)->GetChildren().size(), 0u);
EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u);
EXPECT_EQ((*it)->GetMemTraceEventNodes().size(), 2u);
EXPECT_NE((*it)->GetOperatorSupplementEventNode(), nullptr);
}
}
for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) {
Expand All @@ -100,6 +116,7 @@ TEST(SerializationLoggerTest, dump_case0) {
}
}
tree.LogMe(&logger);
logger.LogMetaInfo(std::unordered_map<std::string, std::string>());
}

TEST(SerializationLoggerTest, dump_case1) {
Expand Down Expand Up @@ -154,6 +171,7 @@ TEST(SerializationLoggerTest, dump_case1) {
}
}
tree.LogMe(&logger);
logger.LogMetaInfo(std::unordered_map<std::string, std::string>());
}

TEST(DeserializationReaderTest, restore_case0) {
Expand All @@ -173,6 +191,8 @@ TEST(DeserializationReaderTest, restore_case0) {
if ((*it)->Name() == "op1") {
EXPECT_EQ((*it)->GetChildren().size(), 0u);
EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u);
EXPECT_EQ((*it)->GetMemTraceEventNodes().size(), 2u);
EXPECT_NE((*it)->GetOperatorSupplementEventNode(), nullptr);
}
}
for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) {
Expand Down
Loading

0 comments on commit 360b838

Please sign in to comment.