make infer & train same logic #196

Merged 2 commits on Feb 2, 2023
4 changes: 2 additions & 2 deletions paddle/fluid/framework/data_feed.cu
@@ -2107,9 +2107,9 @@ int GraphDataGenerator::FillInferBuf() {

   size_t device_key_size = h_device_keys_len_[infer_cursor];
   total_row_ =
-      (global_infer_node_type_start[infer_cursor] + infer_table_cap_ <=
+      (global_infer_node_type_start[infer_cursor] + buf_size_ <=
        device_key_size)
-          ? infer_table_cap_
+          ? buf_size_
           : device_key_size - global_infer_node_type_start[infer_cursor];
 
   uint64_t *d_type_keys =
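For reference, the updated ternary caps the number of rows filled per infer pass at buf_size_ (the buffer size the training path already uses) instead of infer_table_cap_, i.e. it takes whichever is smaller: the buffer size or the keys remaining after the cursor. A minimal Python sketch of that arithmetic, with names mirroring the C++ members purely for illustration:

```python
def rows_to_fill(start, total_keys, buf_size):
    """Rows consumed by one fill pass: the remaining keys, capped at buf_size."""
    if start + buf_size <= total_keys:
        return buf_size
    return total_keys - start  # tail of the key list is smaller than the buffer

# Equivalent to min(buf_size, total_keys - start).
assert rows_to_fill(start=0, total_keys=10, buf_size=4) == 4
assert rows_to_fill(start=8, total_keys=10, buf_size=4) == 2
```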
12 changes: 11 additions & 1 deletion paddle/fluid/framework/trainer.cc
@@ -25,6 +25,8 @@ void TrainerBase::ParseDumpConfig(const TrainerDesc& desc) {
   dump_fields_path_ = desc.dump_fields_path();
   need_dump_field_ = false;
   need_dump_param_ = false;
+  dump_fields_mode_ = desc.dump_fields_mode();
+
   if (dump_fields_path_ == "") {
     VLOG(2) << "dump_fields_path_ is empty";
     return;
@@ -58,7 +60,15 @@ void TrainerBase::DumpWork(int tid) {
   int err_no = 0;
   // GetDumpPath is implemented in each Trainer
   std::string path = GetDumpPath(tid);
-  std::shared_ptr<FILE> fp = fs_open_write(path, &err_no, dump_converter_);
+  std::shared_ptr<FILE> fp;
+  if (dump_fields_mode_ == "a") {
+    VLOG(3) << "dump field mode append";
+    fp = fs_open_append_write(path, &err_no, dump_converter_);
+  }
+  else {
+    VLOG(3) << "dump field mode overwrite";
+    fp = fs_open_write(path, &err_no, dump_converter_);
+  }
   while (1) {
     std::string out_str;
     if (!queue_->Get(out_str)) {
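The new dump_fields_mode switch decides whether DumpWork appends to an existing dump file or truncates it; "w" (overwrite) remains the default, matching the previous behavior. A rough Python analogy of the mode handling, using the standard open() modes as stand-ins for Paddle's fs_open_* helpers (not the actual framework API):

```python
def open_dump_file(path, dump_fields_mode="w"):
    """Mirrors the branch in TrainerBase::DumpWork: 'a' preserves earlier
    dump records and appends new ones; any other value truncates the file."""
    if dump_fields_mode == "a":
        return open(path, "a")  # append mode
    return open(path, "w")      # overwrite mode (the previous default behavior)
```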
1 change: 1 addition & 0 deletions paddle/fluid/framework/trainer.h
@@ -94,6 +94,7 @@ class TrainerBase {
   std::string dump_converter_;
   std::vector<std::string> dump_param_;
   std::vector<std::string> dump_fields_;
+  std::string dump_fields_mode_;
   int dump_thread_num_;
   std::vector<std::thread> dump_thread_;
   std::shared_ptr<paddle::framework::ChannelObject<std::string>> queue_;
1 change: 1 addition & 0 deletions paddle/fluid/framework/trainer_desc.proto
@@ -69,6 +69,7 @@ message TrainerDesc {
   // add for gpu
   optional string fleet_desc = 37;
   optional bool is_dump_in_simple_mode = 38 [ default = false ];
+  optional string dump_fields_mode = 39 [ default = "w" ];
   // device worker parameters
   optional HogwildWorkerParameter hogwild_param = 101;
   optional DownpourWorkerParameter downpour_param = 103;
3 changes: 3 additions & 0 deletions python/paddle/fluid/trainer_desc.py
@@ -184,6 +184,9 @@ def _set_dump_param(self, dump_param):
         for param in dump_param:
             self.proto_desc.dump_param.append(param)
 
+    def _set_dump_fields_mode(self, mode):
+        self.proto_desc.dump_fields_mode = mode
+
     def _set_worker_places(self, worker_places):
         for place in worker_places:
             self.proto_desc.worker_places.append(place)
2 changes: 2 additions & 0 deletions python/paddle/fluid/trainer_factory.py
@@ -73,6 +73,8 @@ def _create_trainer(self, opt_info=None):
                 if opt_info.get("dump_fields_path") is not None and len(
                         opt_info.get("dump_fields_path")) != 0:
                     trainer._set_dump_fields_path(opt_info["dump_fields_path"])
+                if opt_info.get("dump_fields_mode") is not None:
+                    trainer._set_dump_fields_mode(opt_info["dump_fields_mode"])
                 if opt_info.get(
                         "user_define_dump_filename") is not None and len(
                             opt_info.get("user_define_dump_filename")) != 0:
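Taken together, the Python side forwards a "dump_fields_mode" entry from opt_info through TrainerDesc's proto into the C++ trainer. A hypothetical opt_info fragment showing the intended usage; the values and the field names other than dump_fields_mode and dump_fields_path are illustrative assumptions, not part of this PR:

```python
# Illustrative only: how a training script's opt_info might request append-mode dumping.
opt_info = {
    "dump_fields_path": "hdfs:/path/to/dump",  # where DumpWork writes its output
    "dump_fields": ["click", "show"],          # example field names, not from this PR
    "dump_fields_mode": "a",                   # "a" = append, "w" = overwrite (default)
}

# TrainerFactory._create_trainer(opt_info) then calls
# trainer._set_dump_fields_mode(opt_info["dump_fields_mode"]),
# which writes the value into TrainerDesc's dump_fields_mode field (proto field 39).
```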