diff --git a/paddle/fluid/framework/new_executor/feed_fetch_utils.cc b/paddle/fluid/framework/new_executor/feed_fetch_utils.cc
index 1449ad2ca9df8..dee86a8463d0f 100644
--- a/paddle/fluid/framework/new_executor/feed_fetch_utils.cc
+++ b/paddle/fluid/framework/new_executor/feed_fetch_utils.cc
@@ -15,6 +15,7 @@
 #include <map>
 #include <string>
 
+#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/new_executor/feed_fetch_utils.h"
 #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
@@ -48,42 +49,202 @@ void SetColAttrForFeedFetchOps(std::shared_ptr<ProgramDesc> program_desc,
   }
 }
 
-void SplitFeedTensor(const std::vector<std::string>& feed_names,
-                     const int64_t micro_batch_num,
-                     Scope* scope,
-                     std::vector<std::vector<phi::DenseTensor>>* out) {
-  if (micro_batch_num < 2) return;
-
-  out->resize(micro_batch_num);
+void SplitFeedTensors(const std::vector<std::string>& feed_names,
+                      const int64_t micro_batch_num,
+                      Scope* scope,
+                      std::vector<std::vector<phi::DenseTensor>>* out) {
+  std::vector<phi::DenseTensor> feed_tensors;
   for (size_t i = 0; i < feed_names.size(); ++i) {
     auto feed_name = feed_names[i];
     auto feed_var = scope->GetVar(feed_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        feed_var,
+        platform::errors::NotFound("Variable %s should not be nullptr.",
+                                   feed_names[i]));
+    feed_tensors.push_back(feed_var->Get<phi::DenseTensor>());
+  }
+
+  out->resize(micro_batch_num);
+  if (micro_batch_num < 2) {
+    (*out)[0] = std::move(feed_tensors);
+    return;
+  }
+
+  for (size_t i = 0; i < feed_tensors.size(); ++i) {
+    auto& feed_tensor = feed_tensors[i];
+    int64_t numel_size = feed_tensor.dims()[0];
+    PADDLE_ENFORCE_EQ(numel_size % micro_batch_num,
+                      0,
+                      phi::errors::InvalidArgument(
+                          "Split expects feed data (%s)'s dim[0] (%d) to be "
+                          "divisible by micro_batch_num (%d).",
+                          feed_names[i],
+                          numel_size,
+                          micro_batch_num));
+    int64_t split_size = numel_size / micro_batch_num;
+    VLOG(4) << "Split feed data:" << feed_names[i] << ", dims:("
+            << feed_tensor.dims() << "), micro_batch_num:" << micro_batch_num;
+    for (int64_t j = 0; j < micro_batch_num; ++j) {
+      (*out)[j].resize(i + 1);
+      (*out)[j][i].ShareDataWith(
+          feed_tensor.Slice(j * split_size, j * split_size + split_size));
+    }
+  }
+}
+
+void FetchTensors(const std::vector<std::string>& job_fetch_names,
+                  const std::vector<std::string>& fetch_var_names,
+                  const int64_t micro_batch_id,
+                  Scope* scope,
+                  FetchUnmergedList* fetch_list) {
+  PADDLE_ENFORCE_GT(
+      fetch_list->size(),
+      micro_batch_id,
+      phi::errors::Unavailable("The fetch list size (%lld) should be greater "
+                               "than micro_batch_id (%lld)",
+                               fetch_list->size(),
+                               micro_batch_id));
+
+  fetch_list->at(micro_batch_id).resize(fetch_var_names.size());
+  for (auto& var_name : job_fetch_names) {
+    int col = find(fetch_var_names.begin(), fetch_var_names.end(), var_name) -
+              fetch_var_names.begin();
+    auto* var = scope->FindVar(var_name);
+    auto& src = var->Get<phi::DenseTensor>();
+    auto* dst =
+        &(PADDLE_GET(phi::DenseTensor, fetch_list->at(micro_batch_id)[col]));
+    TensorCopy(src, platform::CPUPlace(), dst);
+  }
+}
+
+void MergeFetchTensors(const FetchUnmergedList& fetch_list,
+                       const int64_t micro_batch_num,
+                       FetchList* out) {
+  if (fetch_list.size() == 0) return;
+
+  PADDLE_ENFORCE_EQ(
+      fetch_list.size(),
+      micro_batch_num,
+      phi::errors::Unavailable("The fetch_list size (%lld) should be equal to "
+                               "the micro_batch_num (%lld)",
+                               fetch_list.size(),
+                               micro_batch_num));
-    if (feed_var->IsType<phi::DenseTensor>()) {
-      phi::DenseTensor feed_tensor = feed_var->Get<phi::DenseTensor>();
-      int64_t numel_size = feed_tensor.dims()[0];
-      PADDLE_ENFORCE_EQ(numel_size % micro_batch_num,
-                        0,
-                        platform::errors::InvalidArgument(
-                            "Split expects feed data (%s)'s dim[0] (%d) is "
-                            "diviable by micro_batch_num (%d).",
-                            feed_name,
-                            numel_size,
-                            micro_batch_num));
-      int64_t split_size = (numel_size + micro_batch_num - 1) / micro_batch_num;
-      VLOG(4) << "Split feed data:" << feed_name << ", dims:("
-              << feed_tensor.dims() << "), micro_batch_num:" << micro_batch_num;
-      for (int64_t j = 0; j < micro_batch_num; ++j) {
-        (*out)[j].resize(i + 1);
-        (*out)[j][i].ShareDataWith(
-            feed_tensor.Slice(j * split_size, j * split_size + split_size));
+  if (micro_batch_num < 2) {
+    *out = std::move(fetch_list[0]);
+    return;
+  }
+
+  out->resize(fetch_list[0].size());
+  for (size_t i = 0; i < fetch_list[0].size(); ++i) {
+    std::vector<const phi::DenseTensor*> tensors_ptr;
+    for (auto micro_batch_id = 0; micro_batch_id < micro_batch_num;
+         ++micro_batch_id) {
+      tensors_ptr.push_back(
+          &PADDLE_GET_CONST(phi::DenseTensor, fetch_list[micro_batch_id][i]));
+    }
+    phi::DenseTensor merged_tensor;
+    MergeTensors(tensors_ptr, platform::CPUPlace(), &merged_tensor);
+    out->at(i) = std::move(merged_tensor);
+  }
+}
+
+void MergeTensors(const std::vector<const phi::DenseTensor*>& tensors,
+                  const platform::Place dst_place,
+                  phi::DenseTensor* target) {
+  PADDLE_ENFORCE_EQ(
+      tensors.empty(),
+      false,
+      phi::errors::InvalidArgument("The tensors to be merged are empty."));
+
+  DDim new_dim = tensors[0]->dims();
+  proto::VarType::Type new_type = proto::VarType::FP32;
+  phi::DataLayout new_layout = tensors[0]->layout();
+  for (auto* t : tensors) {
+    if (t->numel() && t->IsInitialized()) {
+      new_dim = t->dims();
+      new_type = framework::TransToProtoVarType(t->dtype());
+      new_layout = t->layout();
+      break;
+    }
+  }
+
+  auto rank = tensors[0]->dims().size();
+  if (rank == 0) {
+    std::vector<int> init_shape = {1};
+    new_dim = new_dim.reshape(init_shape);
+  }
+
+  for (size_t i = 1; i < tensors.size(); ++i) {
+    auto* t = tensors[i];
+    if (t->numel() && t->IsInitialized()) {
+      PADDLE_ENFORCE_EQ(
+          new_type,
+          framework::TransToProtoVarType(t->dtype()),
+          phi::errors::InvalidArgument(
+              "phi::DenseTensor data type does not match, expected type is "
+              "%s, actual type is %s.",
+              DataTypeToString(new_type),
+              DataTypeToString(framework::TransToProtoVarType(t->dtype()))));
+      PADDLE_ENFORCE_EQ(
+          new_layout,
+          t->layout(),
+          phi::errors::InvalidArgument(
+              "phi::DenseTensor layout does not match, expected layout is %s, "
+              "actual layout is %s.",
+              phi::DataLayoutToString(new_layout),
+              phi::DataLayoutToString(t->layout())));
+      if (rank > 0) {
+        auto tensor_dims = t->dims();
+        PADDLE_ENFORCE_EQ(tensor_dims.size(),
+                          new_dim.size(),
+                          phi::errors::InvalidArgument(
+                              "dimensions of DenseTensor do not match"));
+        for (int j = 1; j < t->dims().size(); j++) {
+          PADDLE_ENFORCE_EQ(
+              tensor_dims[j],
+              new_dim[j],
+              phi::errors::InvalidArgument(
+                  "DenseTensor.ddim[%d] should be equal to %d, but is %d",
+                  j,
+                  new_dim[j],
+                  tensor_dims[j]));
+        }
+        new_dim[0] += t->dims()[0];
+      } else if (rank == 0) {
+        auto tensor_dims = t->dims();
+        PADDLE_ENFORCE_EQ(tensor_dims.size(),
+                          0,
+                          phi::errors::InvalidArgument(
+                              "dimensions of DenseTensor do not match"));
+        PADDLE_ENFORCE_EQ(new_dim.size(),
+                          1,
+                          phi::errors::InvalidArgument(
+                              "dimensions of DenseTensor do not match"));
+        new_dim[0] += 1;
       }
-    } else {
-      PADDLE_THROW(phi::errors::Unimplemented(
-          "Type (%s) not support in SplitFeedTensor.",
-          ToTypeName(feed_var->Type())));
     }
   }
+
+  target->Resize(new_dim);
+  target->set_layout(new_layout);
+  target->mutable_data(dst_place, TransToPhiDataType(new_type));
+
+  int begin = 0;
+  for (auto* src : tensors) {
+    int src_dim = 1;
+    if (src->dims()[0] > 0) {
+      src_dim = src->dims()[0];
+    }
+    int end = static_cast<int>(begin + src_dim);
+    if (end == begin) {
+      continue;
+    }
+    auto dst = target->Slice(begin, end);
+    TensorCopy(*src, dst_place, &dst);
+    begin = end;
+  }
 }
 
 }  // namespace framework
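A quick sketch of what the two new building blocks do with concrete shapes. This is illustrative only: the `scope`, the feed name `"x"`, the shapes, and the loss tensors `l0`..`l3` are assumptions, not part of this patch.

```cpp
// Splitting: a feed "x" of shape [8, 16] with micro_batch_num = 4 yields four
// views of shape [2, 16]; ShareDataWith/Slice means no data is copied, and
// dim 0 must be divisible by micro_batch_num or the enforce above fires.
std::vector<std::vector<phi::DenseTensor>> feeds;
SplitFeedTensors({"x"}, /*micro_batch_num=*/4, scope, &feeds);
// feeds[j][0] now views rows [2 * j, 2 * j + 2) of "x".

// Merging: four 0-D per-micro-batch losses are stacked along a new dim 0,
// producing a tensor of shape [4] on the destination place.
std::vector<const phi::DenseTensor*> losses = {&l0, &l1, &l2, &l3};
phi::DenseTensor merged;
MergeTensors(losses, platform::CPUPlace(), &merged);  // merged.dims(): [4]
```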
diff --git a/paddle/fluid/framework/new_executor/feed_fetch_utils.h b/paddle/fluid/framework/new_executor/feed_fetch_utils.h
index df5d650aa9eba..d1eff750d4a8f 100644
--- a/paddle/fluid/framework/new_executor/feed_fetch_utils.h
+++ b/paddle/fluid/framework/new_executor/feed_fetch_utils.h
@@ -27,10 +27,24 @@ void SetColAttrForFeedFetchOps(std::shared_ptr<ProgramDesc> program_desc,
                                const int64_t micro_batch_num,
                                const int64_t micro_batch_id);
 
-void SplitFeedTensor(const std::vector<std::string>& feed_names,
-                     const int64_t micro_batch_num,
-                     Scope* scope,
-                     std::vector<std::vector<phi::DenseTensor>>* out);
+void SplitFeedTensors(const std::vector<std::string>& feed_names,
+                      const int64_t micro_batch_num,
+                      Scope* scope,
+                      std::vector<std::vector<phi::DenseTensor>>* out);
+
+void FetchTensors(const std::vector<std::string>& job_fetch_names,
+                  const std::vector<std::string>& fetch_var_names,
+                  const int64_t micro_batch_id,
+                  Scope* scope,
+                  FetchUnmergedList* fetch_list);
+
+void MergeFetchTensors(const FetchUnmergedList& fetch_list,
+                       const int64_t micro_batch_num,
+                       FetchList* out);
+
+void MergeTensors(const std::vector<const phi::DenseTensor*>& tensors,
+                  const platform::Place dst_place,
+                  phi::DenseTensor* target);
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/interpreter/job.h b/paddle/fluid/framework/new_executor/interpreter/job.h
index 493063f9e1516..952702d6e2f0a 100644
--- a/paddle/fluid/framework/new_executor/interpreter/job.h
+++ b/paddle/fluid/framework/new_executor/interpreter/job.h
@@ -31,27 +31,10 @@ class Job final {
 
   const std::string& Type() const { return type_; }
 
-  int ColAttrForFetchOp(int fetch_op_id) const {
-    return fetch_op_id_to_col_attr_.at(fetch_op_id);
-  }
-
   int64_t MicroBatchId() const { return micro_batch_id_; }
 
   std::set<std::string> SkipGcVars() const { return skip_gc_vars_; }
 
-  std::vector<int> AllFetchOpIds() const {
-    std::vector<int> fetch_op_ids;
-    fetch_op_ids.reserve(fetch_op_id_to_col_attr_.size());
-    for (auto& item : fetch_op_id_to_col_attr_) {
-      fetch_op_ids.push_back(item.first);
-    }
-    return fetch_op_ids;
-  }
-
-  void SetColAttrForFetchOp(int fetch_op_id, int col_attr) {
-    fetch_op_id_to_col_attr_[fetch_op_id] = col_attr;
-  }
-
   void SetMicroBatchId(int64_t micro_batch_id) {
     PADDLE_ENFORCE_GE(
         micro_batch_id,
@@ -71,11 +54,17 @@ class Job final {
     skip_gc_vars_ = skip_gc_vars;
   }
 
+  void SetFetchVarName(const std::string& fetch_var_name) {
+    fetch_var_names_.push_back(fetch_var_name);
+  }
+
+  std::vector<std::string> FetchVarNames() { return fetch_var_names_; }
+
  private:
   const std::string type_;
   int64_t micro_batch_id_;
-  std::unordered_map<int, int> fetch_op_id_to_col_attr_;
   std::set<std::string> skip_gc_vars_;
+  std::vector<std::string> fetch_var_names_;
 };
 
 }  // namespace interpreter
diff --git a/paddle/fluid/framework/new_executor/interpreter_base_impl.h b/paddle/fluid/framework/new_executor/interpreter_base_impl.h
index 369216e0078c4..747f945050042 100644
--- a/paddle/fluid/framework/new_executor/interpreter_base_impl.h
+++ b/paddle/fluid/framework/new_executor/interpreter_base_impl.h
@@ -67,7 +67,8 @@ class InterpreterBaseImpl {
   virtual ~InterpreterBaseImpl() = default;
 
   virtual paddle::framework::FetchList Run(
      const std::vector<std::string>& feed_names,
-      const std::vector<phi::DenseTensor>& feed_tensors) = 0;
+      const std::vector<phi::DenseTensor>& feed_tensors,
+      bool need_fetch = true) = 0;
 
   virtual paddle::framework::FetchList Run(
       const std::vector<std::string>& feed_names, bool need_fetch = true) = 0;
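In short, a `Job` now carries the fetch variable names it is responsible for instead of a fetch-op-id to column-attribute map. A minimal sketch of the new C++ usage, with illustrative values ("forward", "loss@fetch") that are not taken from this patch:

```cpp
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/new_executor/interpreter/job.h"

using paddle::framework::interpreter::Job;

// Sketch: record which fetch variables a job owns; the executor later passes
// job->FetchVarNames() to FetchTensors() instead of reading col attributes.
std::shared_ptr<Job> MakeForwardJob() {
  auto job = std::make_shared<Job>("forward");
  job->SetMicroBatchId(0);
  job->SetFetchVarName("loss@fetch");  // hypothetical fetch variable name
  return job;
}
```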
diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc
index 1a1ee56e17525..d7efd510535e8 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore.cc
@@ -65,8 +65,9 @@ InterpreterCore::~InterpreterCore() {
 
 FetchList InterpreterCore::Run(
     const std::vector<std::string>& feed_names,
-    const std::vector<phi::DenseTensor>& feed_tensors) {
-  return impl_->Run(feed_names, feed_tensors);
+    const std::vector<phi::DenseTensor>& feed_tensors,
+    bool need_fetch) {
+  return impl_->Run(feed_names, feed_tensors, need_fetch);
 }
 
 FetchList InterpreterCore::Run(const std::vector<std::string>& feed_names,
diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h
index d21bd9e1fc378..022bc0c06f5b2 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.h
+++ b/paddle/fluid/framework/new_executor/interpretercore.h
@@ -47,7 +47,8 @@ class InterpreterCore {
 
   paddle::framework::FetchList Run(
       const std::vector<std::string>& feed_names,
-      const std::vector<phi::DenseTensor>& feed_tensors);
+      const std::vector<phi::DenseTensor>& feed_tensors,
+      bool need_fetch = true);
 
   paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
                                    bool need_fetch = true);
diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc
index ee05bf9578998..fbee372871975 100644
--- a/paddle/fluid/framework/new_executor/pir_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc
@@ -1023,7 +1023,8 @@ void PirInterpreter::ConstructEventForJitInput() {
 
 paddle::framework::FetchList PirInterpreter::Run(
     const std::vector<std::string>& feed_names,
-    const std::vector<phi::DenseTensor>& feed_tensors) {
+    const std::vector<phi::DenseTensor>& feed_tensors,
+    bool need_fetch) {
   auto FeedInput = [&] {
     VLOG(4) << "Feed inputs";
     for (size_t i = 0; i < feed_names.size(); ++i) {
@@ -1100,10 +1101,12 @@ paddle::framework::FetchList PirInterpreter::Run(
 
   if (FLAGS_enable_new_ir_in_executor) {
     framework::FetchList fetch_res;
-    for (auto& var_name : fetch_var_names_) {
-      auto* var = inner_scope->FindVar(var_name);
-      VLOG(0) << "fetch " << var_name << "[" << var << "]";
-      fetch_res.push_back(var->Get<phi::DenseTensor>());
+    if (need_fetch) {
+      for (auto& var_name : fetch_var_names_) {
+        auto* var = inner_scope->FindVar(var_name);
+        VLOG(4) << "fetch " << var_name << "[" << var << "]";
+        fetch_res.push_back(var->Get<phi::DenseTensor>());
+      }
     }
 
     VLOG(4) << "get fetch list size: " << fetch_res.size();
@@ -1192,7 +1195,7 @@ FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,
     if (need_fetch) {
       for (auto& var_name : fetch_var_names_) {
         auto* var = inner_scope->FindVar(var_name);
-        VLOG(0) << "fetch " << var_name << "[" << var << "]";
+        VLOG(4) << "fetch " << var_name << "[" << var << "]";
         fetch_res.push_back(var->Get<phi::DenseTensor>());
       }
     }
diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.h b/paddle/fluid/framework/new_executor/pir_interpreter.h
index 80052308e8743..0399c3c04b01e 100644
--- a/paddle/fluid/framework/new_executor/pir_interpreter.h
+++ b/paddle/fluid/framework/new_executor/pir_interpreter.h
@@ -51,7 +51,8 @@ class PirInterpreter : public InterpreterBaseImpl {
 
   paddle::framework::FetchList Run(
       const std::vector<std::string>& feed_names,
-      const std::vector<phi::DenseTensor>& feed_tensors) override;
+      const std::vector<phi::DenseTensor>& feed_tensors,
+      bool need_fetch = true) override;
 
   paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
                                    bool need_fetch = true) override;
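The new `need_fetch` flag on the feed-tensor overload mirrors the one already on the feed-name overload: a caller that drives one interpreter per micro batch can suppress the interpreter-side fetch and collect results itself. A hedged usage sketch, assuming `core`, `feed_names`, and `feed_tensors` already exist for a single micro batch:

```cpp
// Sketch: run a single micro batch without materializing the interpreter's
// own FetchList; StandaloneExecutor then copies the fetch variables out of
// the micro-batch scope via FetchTensors().
paddle::framework::FetchList unused =
    core->Run(feed_names, feed_tensors, /*need_fetch=*/false);
```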
diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc
index 2df562e7bef18..1b287a11232a1 100644
--- a/paddle/fluid/framework/new_executor/program_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/program_interpreter.cc
@@ -203,7 +203,8 @@ void ProgramInterpreter::Build(
 
 FetchList ProgramInterpreter::Run(
     const std::vector<std::string>& feed_names,
-    const std::vector<phi::DenseTensor>& feed_tensors) {
+    const std::vector<phi::DenseTensor>& feed_tensors,
+    bool need_fetch) {
   SetDeviceId(place_);
   CheckCUDAGraphBeforeRun(feed_names);
 
@@ -226,7 +227,7 @@ FetchList ProgramInterpreter::Run(
   Scope* inner_scope =
       HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
   auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
-  if (fetch_var) {
+  if (fetch_var && need_fetch) {
     auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
 #ifdef PADDLE_WITH_CUDA
     if (platform::IsCUDAGraphCapturing()) {
diff --git a/paddle/fluid/framework/new_executor/program_interpreter.h b/paddle/fluid/framework/new_executor/program_interpreter.h
index bef6385c211fb..9c4b8f9bf1c9b 100644
--- a/paddle/fluid/framework/new_executor/program_interpreter.h
+++ b/paddle/fluid/framework/new_executor/program_interpreter.h
@@ -43,7 +43,8 @@ class ProgramInterpreter : public InterpreterBaseImpl {
 
   paddle::framework::FetchList Run(
       const std::vector<std::string>& feed_names,
-      const std::vector<phi::DenseTensor>& feed_tensors) override;
+      const std::vector<phi::DenseTensor>& feed_tensors,
+      bool need_fetch = true) override;
 
   paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
                                    bool need_fetch = true) override;
diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc
index 1aabc7a89e355..41ebdea7cc8c2 100644
--- a/paddle/fluid/framework/new_executor/standalone_executor.cc
+++ b/paddle/fluid/framework/new_executor/standalone_executor.cc
@@ -80,8 +80,6 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
 
     // TODO(phlrain) we only support cpu for now
     if (FLAGS_enable_new_ir_in_executor) {
-      auto inner_scope =
-          micro_batch_num == 1 ? scope : micro_batch_scopes_[micro_batch_id];
       std::shared_ptr<::pir::Program> base_program = ir_program;
       auto block = base_program->block();
       for (auto it = block->begin(); it != block->end(); ++it) {
@@ -102,6 +100,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
                                   .dyn_cast<pir::StrAttribute>()
                                   .AsString() +
                               "@fetch";
+          job->SetFetchVarName(fetch_var_names_[index]);
         }
       }
       auto kernel_program =
@@ -117,9 +116,9 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
       interpretercores_.emplace_back(
          std::make_shared<InterpreterCore>(place_,
-                                            fetch_var_names_,
+                                            job->FetchVarNames(),
                                             shared_program->block(),
-                                            inner_scope,
+                                            micro_batch_scopes_[micro_batch_id],
                                             execution_config));
     } else {
       interpretercores_.emplace_back(
@@ -180,9 +179,10 @@ paddle::framework::FetchList StandaloneExecutor::Run(
 
   std::vector<std::vector<phi::DenseTensor>> splited_feeds;
   if (FLAGS_enable_new_ir_in_executor) {
-    SplitFeedTensor(feed_names, plan_.MicroBatchNum(), scope_, &splited_feeds);
+    SplitFeedTensors(feed_names, plan_.MicroBatchNum(), scope_, &splited_feeds);
   }
 
+  fetch_list_.resize(plan_.MicroBatchNum());
   for (size_t job_idx = 0; job_idx < jobs.size(); ++job_idx) {
     const auto& job = jobs[job_idx];
     const std::string& job_type = job->Type();
@@ -201,9 +201,16 @@ paddle::framework::FetchList StandaloneExecutor::Run(
           interpretercores_[type_to_first_id[job_type]]);
     }
 
-    if (FLAGS_enable_new_ir_in_executor && splited_feeds.size() > 0) {
+    if (FLAGS_enable_new_ir_in_executor) {
       interpretercores_[job_idx]->Run(feed_names,
-                                      splited_feeds[job->MicroBatchId()]);
+                                      splited_feeds[job->MicroBatchId()],
+                                      /*need_fetch = */ false);
+
+      FetchTensors(job->FetchVarNames(),
+                   fetch_var_names_,
+                   job->MicroBatchId(),
+                   micro_batch_scopes_[job->MicroBatchId()],
+                   &fetch_list_);
     } else {
       if (jobs.size() > 1 && job_type != "forward") {
         const std::vector<std::string> tmp_feed_names = {};
@@ -218,11 +225,7 @@ paddle::framework::FetchList StandaloneExecutor::Run(
   // return Fetch Tensors
   if (FLAGS_enable_new_ir_in_executor) {
     framework::FetchList fetch_res;
-    for (auto& var_name : fetch_var_names_) {
-      auto* var = scope_->FindVar(var_name);
-      fetch_res.push_back(var->Get<phi::DenseTensor>());
-    }
-
+    MergeFetchTensors(fetch_list_, plan_.MicroBatchNum(), &fetch_res);
     return fetch_res;
   } else {
     auto* fetch_var = scope_->FindVar(interpreter::kFetchVarName);
diff --git a/paddle/fluid/framework/new_executor/standalone_executor.h b/paddle/fluid/framework/new_executor/standalone_executor.h
index 50bb09c9353a0..8feef6e5b2f91 100644
--- a/paddle/fluid/framework/new_executor/standalone_executor.h
+++ b/paddle/fluid/framework/new_executor/standalone_executor.h
@@ -51,6 +51,7 @@ class StandaloneExecutor {
   std::vector<Scope*> micro_batch_scopes_;
 
   std::vector<std::string> fetch_var_names_;
+  FetchUnmergedList fetch_list_;
 
   std::vector>> vec_force_events_to_wait_;
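Putting the executor-side pieces together, the per-`Run` fetch path now looks roughly like this. This is a condensed sketch of the hunks above, not a verbatim copy; `jobs` and `splited_feeds` are the locals already present in `StandaloneExecutor::Run`:

```cpp
// Condensed sketch of the new fetch flow in StandaloneExecutor::Run.
fetch_list_.resize(plan_.MicroBatchNum());  // [micro_batch][fetch column]
for (size_t job_idx = 0; job_idx < jobs.size(); ++job_idx) {
  const auto& job = jobs[job_idx];
  // Run the job's interpreter without its own fetch step ...
  interpretercores_[job_idx]->Run(feed_names,
                                  splited_feeds[job->MicroBatchId()],
                                  /*need_fetch=*/false);
  // ... then copy this job's fetch vars into its micro batch's row.
  FetchTensors(job->FetchVarNames(),
               fetch_var_names_,
               job->MicroBatchId(),
               micro_batch_scopes_[job->MicroBatchId()],
               &fetch_list_);
}
// Finally concatenate each fetch column across micro batches along dim 0.
framework::FetchList fetch_res;
MergeFetchTensors(fetch_list_, plan_.MicroBatchNum(), &fetch_res);
```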
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 7d676ef6e189c..148468e756d74 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -2026,8 +2026,6 @@ All parameter, weight, gradient are variables in Paddle.
       .def(py::init<const std::string&>(), py::arg("type"))
       .def("micro_batch_id", &framework::interpreter::Job::MicroBatchId)
       .def("type", &framework::interpreter::Job::Type)
-      .def("set_col_attr_for_fetch_op",
-           &framework::interpreter::Job::SetColAttrForFetchOp)
       .def("set_micro_batch_id", &framework::interpreter::Job::SetMicroBatchId)
       .def("set_skip_gc_vars",
           &framework::interpreter::Job::SetSkipGcVars);
diff --git a/test/auto_parallel/gpt_with_newir.py b/test/auto_parallel/gpt_with_newir.py
index 04e0fe4f669b8..3f455d156be37 100644
--- a/test/auto_parallel/gpt_with_newir.py
+++ b/test/auto_parallel/gpt_with_newir.py
@@ -207,7 +207,7 @@ def test_pp_1f1b(self):
 
         if paddle.distributed.get_rank() == 1:
             self.check_results(
-                out_1f1b_prog.history["loss"][0][1],
+                out_1f1b_prog.history["loss"][0],
                 out_1f1b_ir.history["loss"][0],
             )
 
@@ -232,7 +232,7 @@ def test_pp_fthenb(self):
             )
         if paddle.distributed.get_rank() == 1:
             self.check_results(
-                out_fthenb_prog.history["loss"][0][1],
+                out_fthenb_prog.history["loss"][0],
                 out_fthenb_ir.history["loss"][0],
             )
diff --git a/test/legacy_test/test_allclose_op.py b/test/legacy_test/test_allclose_op.py
index 474f3edb3063f..cb76671284e2c 100644
--- a/test/legacy_test/test_allclose_op.py
+++ b/test/legacy_test/test_allclose_op.py
@@ -181,7 +181,7 @@ def test_fp16(self):
         y_data = np.random.rand(10, 10).astype('float16')
         with paddle.static.program_guard(paddle.static.Program()):
             x = paddle.static.data(shape=[10, 10], name='x', dtype='float16')
-            y = paddle.static.data(shape=[10, 10], name='x', dtype='float16')
+            y = paddle.static.data(shape=[10, 10], name='y', dtype='float16')
             out = paddle.allclose(x, y, rtol=1e-05, atol=1e-08)
             if core.is_compiled_with_cuda():
                 place = paddle.CUDAPlace(0)
diff --git a/test/standalone_executor/test_standalone_custom_event.py b/test/standalone_executor/test_standalone_custom_event.py
index c695ada310230..18595eefe5a42 100644
--- a/test/standalone_executor/test_standalone_custom_event.py
+++ b/test/standalone_executor/test_standalone_custom_event.py
@@ -139,13 +139,6 @@ def create_standalone_exe(self, main_progs, startup_progs, fetch_list):
         # create jobs
         for program_id in range(prog_num):
             job = core.Job(f"prog_{program_id}")
-            # Set col_attr info for fetch_op to fetch the correct data after running multiple micro batch
-            if program_id == prog_num - 1:
-                for i in range(fetch_op_num):
-                    job.set_col_attr_for_fetch_op(
-                        fetch_op_indics[i],
-                        i * micro_batch_num + micro_batch_id,
-                    )
             job_list.append(job)
 
         job_types = []
diff --git a/test/standalone_executor/test_standalone_executor_multi_micro_batch.py b/test/standalone_executor/test_standalone_executor_multi_micro_batch.py
index 4841aae6b5dd3..6222431d5cce9 100644
--- a/test/standalone_executor/test_standalone_executor_multi_micro_batch.py
+++ b/test/standalone_executor/test_standalone_executor_multi_micro_batch.py
@@ -185,14 +185,6 @@ def run_train(self, split=False, micro_batch_num=1):
         for program_id in range(program_num):
             job = Job(f"P{program_id}")
             job.set_micro_batch_id(micro_batch_id)
-            # Set col_attr info for fetch_op to fetch the correct data after running multiple micro batch
-            if program_id == program_num - 1:
-                fetch_op_id_to_col_attr = {}
-                for i in range(fetch_op_num):
-                    job.set_col_attr_for_fetch_op(
-                        fetch_op_indics[i],
-                        i * micro_batch_num + micro_batch_id,
-                    )
             job_list.append(job)
 
         job_types = []