From eda87f1dc90129e10cca0035e27d0d03494de169 Mon Sep 17 00:00:00 2001 From: yelrose <270018958@qq.com> Date: Wed, 11 Jan 2023 09:53:55 +0800 Subject: [PATCH 1/2] make infer & train same logic --- paddle/fluid/framework/data_feed.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu index 7451a038493c2..3dff7d71616c1 100644 --- a/paddle/fluid/framework/data_feed.cu +++ b/paddle/fluid/framework/data_feed.cu @@ -2107,9 +2107,9 @@ int GraphDataGenerator::FillInferBuf() { size_t device_key_size = h_device_keys_len_[infer_cursor]; total_row_ = - (global_infer_node_type_start[infer_cursor] + infer_table_cap_ <= + (global_infer_node_type_start[infer_cursor] + buf_size_ <= device_key_size) - ? infer_table_cap_ + ? buf_size_ : device_key_size - global_infer_node_type_start[infer_cursor]; uint64_t *d_type_keys = From f414e9f60b564fe15678e66b6314fd2fad105c3b Mon Sep 17 00:00:00 2001 From: yelrose <270018958@qq.com> Date: Thu, 12 Jan 2023 12:01:47 +0800 Subject: [PATCH 2/2] make infer & train same logic --- paddle/fluid/framework/trainer.cc | 12 +++++++++++- paddle/fluid/framework/trainer.h | 1 + paddle/fluid/framework/trainer_desc.proto | 1 + python/paddle/fluid/trainer_desc.py | 3 +++ python/paddle/fluid/trainer_factory.py | 2 ++ 5 files changed, 18 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/trainer.cc b/paddle/fluid/framework/trainer.cc index 2d8e567b65a7d..9b7d3c73aebf3 100644 --- a/paddle/fluid/framework/trainer.cc +++ b/paddle/fluid/framework/trainer.cc @@ -25,6 +25,8 @@ void TrainerBase::ParseDumpConfig(const TrainerDesc& desc) { dump_fields_path_ = desc.dump_fields_path(); need_dump_field_ = false; need_dump_param_ = false; + dump_fields_mode_ = desc.dump_fields_mode(); + if (dump_fields_path_ == "") { VLOG(2) << "dump_fields_path_ is empty"; return; @@ -58,7 +60,15 @@ void TrainerBase::DumpWork(int tid) { int err_no = 0; // GetDumpPath is implemented in each Trainer std::string path = GetDumpPath(tid); - std::shared_ptr fp = fs_open_write(path, &err_no, dump_converter_); + std::shared_ptr fp; + if (dump_fields_mode_ == "a") { + VLOG(3) << "dump field mode append"; + fp = fs_open_append_write(path, &err_no, dump_converter_); + } + else { + VLOG(3) << "dump field mode overwrite"; + fp = fs_open_write(path, &err_no, dump_converter_); + } while (1) { std::string out_str; if (!queue_->Get(out_str)) { diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index a1c7ecffa443c..1661ec6254089 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -94,6 +94,7 @@ class TrainerBase { std::string dump_converter_; std::vector dump_param_; std::vector dump_fields_; + std::string dump_fields_mode_; int dump_thread_num_; std::vector dump_thread_; std::shared_ptr> queue_; diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index daded21ec62d9..1bd6d14379b48 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -69,6 +69,7 @@ message TrainerDesc { // add for gpu optional string fleet_desc = 37; optional bool is_dump_in_simple_mode = 38 [ default = false ]; + optional string dump_fields_mode = 39 [ default = "w" ]; // device worker parameters optional HogwildWorkerParameter hogwild_param = 101; optional DownpourWorkerParameter downpour_param = 103; diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index c4c17c7095aa0..ebd3060a67f29 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -184,6 +184,9 @@ def _set_dump_param(self, dump_param): for param in dump_param: self.proto_desc.dump_param.append(param) + def _set_dump_fields_mode(self, mode): + self.proto_desc.dump_fields_mode = mode + def _set_worker_places(self, worker_places): for place in worker_places: self.proto_desc.worker_places.append(place) diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py index 0fa674ed3f015..37a6cdddbba98 100644 --- a/python/paddle/fluid/trainer_factory.py +++ b/python/paddle/fluid/trainer_factory.py @@ -73,6 +73,8 @@ def _create_trainer(self, opt_info=None): if opt_info.get("dump_fields_path") is not None and len( opt_info.get("dump_fields_path")) != 0: trainer._set_dump_fields_path(opt_info["dump_fields_path"]) + if opt_info.get("dump_fields_mode") is not None: + trainer._set_dump_fields_mode(opt_info["dump_fields_mode"]) if opt_info.get( "user_define_dump_filename") is not None and len( opt_info.get("user_define_dump_filename")) != 0: