Skip to content

Commit

Permalink
add serialization for new field in event node (PaddlePaddle#43405)
Browse files Browse the repository at this point in the history
* add serialization for new field in event node

* fix a bug
  • Loading branch information
rainyfly committed Jun 23, 2022
1 parent 4730147 commit e09341b
Show file tree
Hide file tree
Showing 8 changed files with 397 additions and 15 deletions.
91 changes: 86 additions & 5 deletions paddle/fluid/platform/profiler/dump/deserialization_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ std::unique_ptr<ProfilerResult> DeserializationReader::Parse() {
ExtraInfo extrainfo;
for (auto indx = 0; indx < node_trees_proto_->extra_info_size(); indx++) {
ExtraInfoMap extra_info_map = node_trees_proto_->extra_info(indx);
extrainfo.AddExtraInfo(extra_info_map.key(), std::string("%s"),
extrainfo.AddExtraInfo(extra_info_map.key(),
std::string("%s"),
extra_info_map.value().c_str());
}
// restore NodeTrees
Expand Down Expand Up @@ -90,6 +91,26 @@ std::unique_ptr<ProfilerResult> DeserializationReader::Parse() {
device_node); // insert into runtime_node
}
}
// handle mem node
for (int mem_node_index = 0;
mem_node_index < host_node_proto.mem_nodes_size();
mem_node_index++) {
const MemTraceEventNodeProto& mem_node_proto =
host_node_proto.mem_nodes(mem_node_index);
MemTraceEventNode* mem_node = RestoreMemTraceEventNode(mem_node_proto);
host_node->AddMemNode(mem_node);
}
// handle op supplement node
for (int op_supplement_node_index = 0;
op_supplement_node_index <
host_node_proto.op_supplement_nodes_size();
op_supplement_node_index++) {
const OperatorSupplementEventNodeProto& op_supplement_node_proto =
host_node_proto.op_supplement_nodes(op_supplement_node_index);
OperatorSupplementEventNode* op_supplement_node =
RestoreOperatorSupplementEventNode(op_supplement_node_proto);
host_node->SetOperatorSupplementNode(op_supplement_node);
}
}
// restore parent-child relationship
for (auto it = child_parent_map.begin(); it != child_parent_map.end();
Expand Down Expand Up @@ -174,6 +195,62 @@ HostTraceEventNode* DeserializationReader::RestoreHostTraceEventNode(
return new HostTraceEventNode(host_event);
}

// Rebuilds a MemTraceEventNode from its serialized representation.
//
// Every field of the MemTraceEventProto embedded in `mem_node_proto` is copied
// into a fresh MemTraceEvent; the proto enum value is converted to
// TracerMemEventType with a static_cast. The returned node is allocated with
// `new`.
MemTraceEventNode* DeserializationReader::RestoreMemTraceEventNode(
    const MemTraceEventNodeProto& mem_node_proto) {
  const MemTraceEventProto& proto = mem_node_proto.mem_event();

  MemTraceEvent event;
  // When and where the record was produced.
  event.timestamp_ns = proto.timestamp_ns();
  event.process_id = proto.process_id();
  event.thread_id = proto.thread_id();
  event.place = proto.place();
  // What kind of memory operation it was and which address it touched.
  event.type = static_cast<TracerMemEventType>(proto.type());
  event.addr = proto.addr();
  // Size delta and the running totals recorded with the event.
  event.increase_bytes = proto.increase_bytes();
  event.current_allocated = proto.current_allocated();
  event.current_reserved = proto.current_reserved();

  return new MemTraceEventNode(event);
}

// Rebuilds an OperatorSupplementEventNode from its serialized representation.
//
// Scalar fields are copied one-to-one. `input_shapes` and `dtypes` are stored
// in the proto as two parallel repeated fields each: key(i) names an entry and
// shape_vecs(i) / dtype_vecs(i) holds that entry's values. They are converted
// back into maps keyed by that name.
//
// The returned node is allocated with `new`.
OperatorSupplementEventNode*
DeserializationReader::RestoreOperatorSupplementEventNode(
    const OperatorSupplementEventNodeProto& op_supplement_node_proto) {
  const OperatorSupplementEventProto& op_supplement_event_proto =
      op_supplement_node_proto.op_supplement_event();
  OperatorSupplementEvent op_supplement_event;
  op_supplement_event.timestamp_ns = op_supplement_event_proto.timestamp_ns();
  op_supplement_event.op_type = op_supplement_event_proto.op_type();
  op_supplement_event.callstack = op_supplement_event_proto.callstack();
  op_supplement_event.process_id = op_supplement_event_proto.process_id();
  op_supplement_event.thread_id = op_supplement_event_proto.thread_id();

  // ---- input shapes ----
  std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
  const auto& input_shape_proto = op_supplement_event_proto.input_shapes();
  // Keys and value lists are parallel arrays; only walk the part both cover so
  // a malformed file cannot index past the shorter one.
  const int shape_entry_count =
      input_shape_proto.key_size() < input_shape_proto.shape_vecs_size()
          ? input_shape_proto.key_size()
          : input_shape_proto.shape_vecs_size();
  for (int i = 0; i < shape_entry_count; i++) {
    // Must be a reference: a by-value copy here would receive the push_backs
    // and then be discarded, leaving the map entry empty.
    auto& input_shape_vec = input_shapes[input_shape_proto.key(i)];
    const auto& shape_vectors_proto = input_shape_proto.shape_vecs(i);
    for (int j = 0; j < shape_vectors_proto.shapes_size(); j++) {
      const auto& shape_vector_proto = shape_vectors_proto.shapes(j);
      std::vector<int64_t> shape;
      shape.reserve(shape_vector_proto.size_size());
      for (int k = 0; k < shape_vector_proto.size_size(); k++) {
        // The proto stores dimensions as uint64; convert back to int64_t.
        shape.push_back(static_cast<int64_t>(shape_vector_proto.size(k)));
      }
      input_shape_vec.push_back(shape);
    }
  }
  op_supplement_event.input_shapes = input_shapes;

  // ---- dtypes ----
  std::map<std::string, std::vector<std::string>> dtypes;
  const auto& dtype_proto = op_supplement_event_proto.dtypes();
  const int dtype_entry_count =
      dtype_proto.key_size() < dtype_proto.dtype_vecs_size()
          ? dtype_proto.key_size()
          : dtype_proto.dtype_vecs_size();
  for (int i = 0; i < dtype_entry_count; i++) {
    // Reference for the same reason as above.
    auto& dtype_vec = dtypes[dtype_proto.key(i)];
    const auto& dtype_vec_proto = dtype_proto.dtype_vecs(i);
    for (int j = 0; j < dtype_vec_proto.dtype_size(); j++) {
      dtype_vec.push_back(dtype_vec_proto.dtype(j));
    }
  }
  op_supplement_event.dtypes = dtypes;

  return new OperatorSupplementEventNode(op_supplement_event);
}

KernelEventInfo DeserializationReader::HandleKernelEventInfoProto(
const DeviceTraceEventProto& device_event_proto) {
const KernelEventInfoProto& kernel_info_proto =
Expand Down Expand Up @@ -203,11 +280,14 @@ MemcpyEventInfo DeserializationReader::HandleMemcpyEventInfoProto(
device_event_proto.memcpy_info();
MemcpyEventInfo memcpy_info;
memcpy_info.num_bytes = memcpy_info_proto.num_bytes();
std::strncpy(memcpy_info.copy_kind, memcpy_info_proto.copy_kind().c_str(),
std::strncpy(memcpy_info.copy_kind,
memcpy_info_proto.copy_kind().c_str(),
kMemKindMaxLen - 1);
std::strncpy(memcpy_info.src_kind, memcpy_info_proto.src_kind().c_str(),
std::strncpy(memcpy_info.src_kind,
memcpy_info_proto.src_kind().c_str(),
kMemKindMaxLen - 1);
std::strncpy(memcpy_info.dst_kind, memcpy_info_proto.dst_kind().c_str(),
std::strncpy(memcpy_info.dst_kind,
memcpy_info_proto.dst_kind().c_str(),
kMemKindMaxLen - 1);
return memcpy_info;
}
Expand All @@ -218,7 +298,8 @@ MemsetEventInfo DeserializationReader::HandleMemsetEventInfoProto(
device_event_proto.memset_info();
MemsetEventInfo memset_info;
memset_info.num_bytes = memset_info_proto.num_bytes();
std::strncpy(memset_info.memory_kind, memset_info_proto.memory_kind().c_str(),
std::strncpy(memset_info.memory_kind,
memset_info_proto.memory_kind().c_str(),
kMemKindMaxLen - 1);
memset_info.value = memset_info_proto.value();
return memset_info;
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/platform/profiler/dump/deserialization_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class DeserializationReader {
KernelEventInfo HandleKernelEventInfoProto(const DeviceTraceEventProto&);
MemcpyEventInfo HandleMemcpyEventInfoProto(const DeviceTraceEventProto&);
MemsetEventInfo HandleMemsetEventInfoProto(const DeviceTraceEventProto&);
MemTraceEventNode* RestoreMemTraceEventNode(const MemTraceEventNodeProto&);
OperatorSupplementEventNode* RestoreOperatorSupplementEventNode(
const OperatorSupplementEventNodeProto&);
std::string filename_;
std::ifstream input_file_stream_;
NodeTreesProto* node_trees_proto_;
Expand Down
70 changes: 70 additions & 0 deletions paddle/fluid/platform/profiler/dump/nodetree.proto
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ enum TracerEventTypeProto {
PythonOp = 13;
// Used to mark python level userdefined
PythonUserDefined = 14;
// Used to mark mlu runtime record returned by cnpapi
MluRuntime = 15;
};

// Kind of memory operation recorded by a MemTraceEventProto.
// The serializer and deserializer convert between this enum and the C++
// TracerMemEventType with a plain static_cast, so the numeric values here
// must stay in sync with that enum.
enum TracerMemEventTypeProto {
// Used to mark memory allocation
Allocate = 0;
// Used to mark memory free
Free = 1;
};

message KernelEventInfoProto {
Expand Down Expand Up @@ -121,6 +130,58 @@ message HostTraceEventProto {
required uint64 thread_id = 6;
}

// Serialized form of a single memory trace event (one allocation or free).
message MemTraceEventProto {
// timestamp of the record, in nanoseconds
required uint64 timestamp_ns = 1;
// memory manipulation type (allocation or free)
required TracerMemEventTypeProto type = 2;
// memory address of the allocation or free
required uint64 addr = 3;
// process id of the record
required uint64 process_id = 4;
// thread id of the record
required uint64 thread_id = 5;
// signed change in bytes caused by this manipulation: positive for an
// allocation, negative for a free
required int64 increase_bytes = 6;
// place string of the record
// NOTE(review): presumably names the device the memory belongs to — confirm
required string place = 7;
// current total allocated memory
required uint64 current_allocated = 8;
// current total reserved memory
required uint64 current_reserved = 9;
}

// Serialized form of the supplementary information recorded for an operator:
// its type name, input shapes, input dtypes and call stack.
message OperatorSupplementEventProto {
// timestamp of the record, in nanoseconds
required uint64 timestamp_ns = 1;
// op type name
required string op_type = 2;
// process id of the record
required uint64 process_id = 3;
// thread id of the record
required uint64 thread_id = 4;
// input shapes, encoded as a map flattened into two parallel lists:
// key[i] is the entry name and shape_vecs[i] holds the shapes for that entry
message input_shape_proto {
repeated string key = 1;
// all shapes belonging to one key
message shape_vector {
// one shape; each element of `size` is one dimension
// NOTE(review): the C++ side stores dimensions as int64_t while this field
// is uint64 — confirm negative dimensions round-trip as intended
message shape { repeated uint64 size = 1; }
repeated shape shapes = 1;
}
repeated shape_vector shape_vecs = 2;
}
required input_shape_proto input_shapes = 5;
// dtypes, encoded the same way: key[i] is the entry name and dtype_vecs[i]
// holds the dtype strings for that entry
message dtype_proto {
repeated string key = 1;
// all dtype strings belonging to one key
message dtype_vector { repeated string dtype = 1; }
repeated dtype_vector dtype_vecs = 2;
}
required dtype_proto dtypes = 6;
// call stack
required string callstack = 7;
}

message CudaRuntimeTraceEventProto {
// record name
required string name = 1;
Expand Down Expand Up @@ -166,6 +227,12 @@ message DeviceTraceEventProto {
}
}

// Tree node wrapping a single OperatorSupplementEventProto.
message OperatorSupplementEventNodeProto {
required OperatorSupplementEventProto op_supplement_event = 1;
}

// Tree node wrapping a single MemTraceEventProto.
message MemTraceEventNodeProto { required MemTraceEventProto mem_event = 1; }

// Tree node wrapping a single DeviceTraceEventProto.
message DeviceTraceEventNodeProto {
required DeviceTraceEventProto device_event = 1;
}
Expand All @@ -180,6 +247,9 @@ message HostTraceEventNodeProto {
required int64 parentid = 2;
required HostTraceEventProto host_trace_event = 3;
repeated CudaRuntimeTraceEventNodeProto runtime_nodes = 4;
// The fields below were added in version 1.0.1
repeated MemTraceEventNodeProto mem_nodes = 5;
repeated OperatorSupplementEventNodeProto op_supplement_nodes = 6;
}

message ThreadNodeTreeProto {
Expand Down
99 changes: 91 additions & 8 deletions paddle/fluid/platform/profiler/dump/serialization_logger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@ namespace paddle {
namespace platform {

static const char* kDefaultFilename = "pid_%s_time_%s.paddle_trace.pb";
static const char* version = "1.0.0";
static const char* version = "1.0.1";
static uint32_t span_indx = 0;

static std::string DefaultFileName() {
auto pid = GetProcessId();
return string_format(std::string(kDefaultFilename), pid,
GetStringFormatLocalTime().c_str());
return string_format(
std::string(kDefaultFilename), pid, GetStringFormatLocalTime().c_str());
}

void SerializationLogger::OpenFile() {
output_file_stream_.open(filename_, std::ofstream::out |
std::ofstream::trunc |
std::ofstream::binary);
output_file_stream_.open(
filename_,
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
if (!output_file_stream_) {
LOG(WARNING) << "Unable to open file for writing profiling data."
<< std::endl;
Expand All @@ -50,7 +50,8 @@ void SerializationLogger::LogNodeTrees(const NodeTrees& node_trees) {
thread2host_event_nodes = node_trees.Traverse(true);

for (auto it = thread2host_event_nodes.begin();
it != thread2host_event_nodes.end(); ++it) {
it != thread2host_event_nodes.end();
++it) {
// 1. order every node an index, every node a parent
std::map<HostTraceEventNode*, int64_t> node_index_map;
std::map<HostTraceEventNode*, int64_t> node_parent_map;
Expand All @@ -64,7 +65,8 @@ void SerializationLogger::LogNodeTrees(const NodeTrees& node_trees) {
for (auto hostnode = it->second.begin(); hostnode != it->second.end();
++hostnode) {
for (auto childnode = (*hostnode)->GetChildren().begin();
childnode != (*hostnode)->GetChildren().end(); ++childnode) {
childnode != (*hostnode)->GetChildren().end();
++childnode) {
node_parent_map[(*childnode)] =
node_index_map[(*hostnode)]; // mark each node's parent
}
Expand Down Expand Up @@ -106,10 +108,34 @@ void SerializationLogger::LogNodeTrees(const NodeTrees& node_trees) {
(*devicenode)->LogMe(this); // fill detail information
}
}
for (auto memnode = (*hostnode)->GetMemTraceEventNodes().begin();
memnode != (*hostnode)->GetMemTraceEventNodes().end();
++memnode) {
MemTraceEventNodeProto* mem_node_proto =
current_host_trace_event_node_proto_->add_mem_nodes();
current_mem_trace_event_node_proto_ = mem_node_proto;
(*memnode)->LogMe(this);
}
}
}
}

void SerializationLogger::LogMemTraceEventNode(
const MemTraceEventNode& mem_node) {
MemTraceEventProto* mem_trace_event = new MemTraceEventProto();
mem_trace_event->set_timestamp_ns(mem_node.TimeStampNs());
mem_trace_event->set_type(
static_cast<TracerMemEventTypeProto>(mem_node.Type()));
mem_trace_event->set_addr(mem_node.Addr());
mem_trace_event->set_process_id(mem_node.ProcessId());
mem_trace_event->set_thread_id(mem_node.ThreadId());
mem_trace_event->set_increase_bytes(mem_node.IncreaseBytes());
mem_trace_event->set_place(mem_node.Place());
mem_trace_event->set_current_allocated(mem_node.CurrentAllocated());
mem_trace_event->set_current_reserved(mem_node.CurrentReserved());
current_mem_trace_event_node_proto_->set_allocated_mem_event(mem_trace_event);
}

void SerializationLogger::LogHostTraceEventNode(
const HostTraceEventNode& host_node) {
HostTraceEventProto* host_trace_event = new HostTraceEventProto();
Expand All @@ -122,6 +148,63 @@ void SerializationLogger::LogHostTraceEventNode(
host_trace_event->set_thread_id(host_node.ThreadId());
current_host_trace_event_node_proto_->set_allocated_host_trace_event(
host_trace_event);
OperatorSupplementEventNode* op_supplement_event_node =
host_node.GetOperatorSupplementEventNode();
if (op_supplement_event_node != nullptr) {
current_op_supplement_event_node_proto_ =
current_host_trace_event_node_proto_->add_op_supplement_nodes();
OperatorSupplementEventProto* op_supplement_event_proto =
new OperatorSupplementEventProto();
op_supplement_event_proto->set_op_type(op_supplement_event_node->Name());
op_supplement_event_proto->set_timestamp_ns(
op_supplement_event_node->TimeStampNs());
op_supplement_event_proto->set_process_id(
op_supplement_event_node->ProcessId());
op_supplement_event_proto->set_thread_id(
op_supplement_event_node->ThreadId());
op_supplement_event_proto->set_callstack(
op_supplement_event_node->CallStack());

OperatorSupplementEventProto::input_shape_proto* input_shape_proto =
op_supplement_event_proto->mutable_input_shapes();
for (auto it = op_supplement_event_node->InputShapes().begin();
it != op_supplement_event_node->InputShapes().end();
it++) {
input_shape_proto->add_key(it->first);
OperatorSupplementEventProto::input_shape_proto::shape_vector*
shape_vectors_proto = input_shape_proto->add_shape_vecs();
auto shape_vectors = it->second;
for (auto shape_vecs_it = shape_vectors.begin();
shape_vecs_it != shape_vectors.end();
shape_vecs_it++) {
auto shape_vector = *shape_vecs_it;
OperatorSupplementEventProto::input_shape_proto::shape_vector::shape*
shape_proto = shape_vectors_proto->add_shapes();
for (auto shape_it = shape_vector.begin();
shape_it != shape_vector.end();
shape_it++) {
shape_proto->add_size(*shape_it);
}
}
}

OperatorSupplementEventProto::dtype_proto* dtype_proto =
op_supplement_event_proto->mutable_dtypes();
for (auto it = op_supplement_event_node->Dtypes().begin();
it != op_supplement_event_node->Dtypes().end();
it++) {
dtype_proto->add_key(it->first);
OperatorSupplementEventProto::dtype_proto::dtype_vector*
dtype_vector_proto = dtype_proto->add_dtype_vecs();
auto dtype_vector = it->second;
for (auto dtype_it = dtype_vector.begin(); dtype_it != dtype_vector.end();
dtype_it++) {
dtype_vector_proto->add_dtype(*dtype_it);
}
}
current_op_supplement_event_node_proto_->set_allocated_op_supplement_event(
op_supplement_event_proto);
}
}

void SerializationLogger::LogRuntimeTraceEventNode(
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/platform/profiler/dump/serialization_logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class SerializationLogger : public BaseLogger {
void LogRuntimeTraceEventNode(const CudaRuntimeTraceEventNode&) override;
void LogNodeTrees(const NodeTrees&) override;
void LogMetaInfo(const std::unordered_map<std::string, std::string>);
void LogMemTraceEventNode(const MemTraceEventNode&) override;

private:
void OpenFile();
Expand All @@ -48,6 +49,8 @@ class SerializationLogger : public BaseLogger {
HostTraceEventNodeProto* current_host_trace_event_node_proto_;
CudaRuntimeTraceEventNodeProto* current_runtime_trace_event_node_proto_;
DeviceTraceEventNodeProto* current_device_trace_event_node_proto_;
MemTraceEventNodeProto* current_mem_trace_event_node_proto_;
OperatorSupplementEventNodeProto* current_op_supplement_event_node_proto_;
};

} // namespace platform
Expand Down
Loading

0 comments on commit e09341b

Please sign in to comment.