diff --git a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc index d5207398adca9..7c18f9288c5e7 100644 --- a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc +++ b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc @@ -17,6 +17,16 @@ #include #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" +// The difference between "sequential_run" and "serial_run": +// "sequential_run" dispatches OPs one by one according to the sequence in the +// Program, while "serial_run" ensures that all Ops are scheduled in a single +// thread. In standalone executor, "sequential_run" is also "serial_run", while +// "serial_run" is not necessarily "sequential_run". +PADDLE_DEFINE_EXPORTED_bool(new_executor_sequential_run, + false, + "Enable sequential execution for standalone " + "executor, only applied to GPU OPs."); + namespace paddle { namespace framework { namespace interpreter { @@ -43,7 +53,7 @@ const std::string StringizeDownstreamMap( } const std::map>& DependencyBuilder::Build( - const std::vector& instructions, bool is_sequential_run) { + const std::vector& instructions) { PADDLE_ENFORCE_EQ( is_build_, false, @@ -56,7 +66,7 @@ const std::map>& DependencyBuilder::Build( BuildOpHappensBefore(); ShrinkDownstreamMap(); - if (is_sequential_run) { + if (FLAGS_new_executor_sequential_run) { AddDependencyForSequentialRun(); } diff --git a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.h b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.h index ca7331d4b78e4..ec1119e701da3 100644 --- a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.h +++ b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.h @@ -36,7 +36,7 @@ class DependencyBuilder { // build op dependencies and return the mapping from op to its downstream-op // set const std::map>& Build( - const 
std::vector& instructions, bool is_sequential_run); + const std::vector& instructions); const std::map>& OpDownstreamMap() const; diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc index 920fec72bd43a..1018da460eef4 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc @@ -70,21 +70,30 @@ inline std::string RunTypeToString(DownstreamRunType run_type) { } void StreamAnalyzer::ConstructEvents( - const DependencyBuilder& dependency_builder, std::vector* instructions) const { + std::vector cross_step_merged_instructions = *instructions; + for (const Instruction& instr : *instructions) { + cross_step_merged_instructions.emplace_back(instr); + } + + DependencyBuilder dependency_builder; + dependency_builder.Build(cross_step_merged_instructions); + const std::map>& downstream_map = dependency_builder.OpDownstreamMap(); - const size_t instr_num = instructions->size(); + const size_t instr_num = cross_step_merged_instructions.size(); std::vector>> run_type_info( instr_num, std::vector>( - /*number_of_run_type = */ 3)); // instr_id -> run_type -> + /*number_of_run_type = */ 2)); // instr_id -> run_type -> // next_instr_id - AnalyseAllRunType(*instructions, downstream_map, &run_type_info); + AnalyseAllRunType( + cross_step_merged_instructions, downstream_map, &run_type_info); std::map>> event_info; // DeviceContext -> waiter_instr_id -> recorder_instr_ids - AnalyseAllEventInfo(*instructions, run_type_info, &event_info); + AnalyseAllEventInfo( + cross_step_merged_instructions, run_type_info, &event_info); ShrinkEventInfo(dependency_builder, &event_info); // Construct events @@ -93,7 +102,17 @@ void StreamAnalyzer::ConstructEvents( for (auto& waiter_item : context_item.second) { size_t waiter_instr_id = waiter_item.first; std::set& recorder_instr_ids = waiter_item.second; + + if 
(waiter_instr_id >= instructions->size()) { + waiter_instr_id -= instructions->size(); + } + for (size_t recorder_instr_id : recorder_instr_ids) { + // Redundant record + if (recorder_instr_id >= instructions->size()) { + continue; + } + Instruction& recorder_instr = instructions->at(recorder_instr_id); Instruction& waiter_instr = instructions->at(waiter_instr_id); platform::DeviceType waiter_type = GetWaiterType(waiter_instr); diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h index b9a228869d4c9..de0e6c741c245 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h @@ -37,8 +37,7 @@ class StreamAnalyzer { ~StreamAnalyzer() {} - void ConstructEvents(const DependencyBuilder& dependency_builder, - std::vector* instructions) const; + void ConstructEvents(std::vector* instructions) const; platform::DeviceContext* ParseDeviceContext( const OpFuncNode& op_func_node) const; diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 070230af4d786..a0aa82102e315 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -33,15 +33,6 @@ #endif #include "paddle/phi/backends/device_manager.h" -// The difference between "sequential_run" and "serial_run": -// "sequential_run" dispatches OPs one by one according to the sequence in the -// Program, while "serial_run" ensures that all Ops are scheduled in a singal -// thread. In standalone executor, "sequential_run" is also "serial_run", while -// "serial_run" is not necessarily "sequential_run". 
-PADDLE_DEFINE_EXPORTED_bool(new_executor_sequential_run, - false, - "Enable sequential execution for standalone " - "executor, only applied to GPU OPs."); PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, false, "Use inplace in new executor"); @@ -519,9 +510,7 @@ void InterpreterCore::BuildOperatorDependences() { // and set the dependecy_count_ size_t instr_num = vec_instruction_.size(); dependecy_count_.resize(instr_num); - auto downstream_map = dependency_builder_.Build( - vec_instruction_, - /*is_sequential_run=*/FLAGS_new_executor_sequential_run); + auto downstream_map = dependency_builder_.Build(vec_instruction_); for (size_t instr_id = 0; instr_id < instr_num; ++instr_id) { Instruction& cur_instr = vec_instruction_[instr_id]; @@ -588,7 +577,13 @@ void InterpreterCore::Convert( BuildOperatorDependences(); - stream_analyzer_.ConstructEvents(dependency_builder_, &vec_instruction_); + // NOTE(Ruibiao): For cross-step stream synchronization, an event may be + // recorded in the first step and waited in the second step. So, in the first + // step, the WaitEvent may be called without RecordEvent. Considering that + // before the first call to RecordEvent, an Event represents an empty set of + // work and WaitEvent always returns success immediately, we omit the + // prelude-record for the first step here. + stream_analyzer_.ConstructEvents(&vec_instruction_); // add event for the input var of jit program, since there are async copied // from gpu_pinned place to gpu place on compute stream. 
diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index f42270f34a220..9fb4e0b7eebaf 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -301,7 +301,7 @@ class Instruction { void AddEventToRecord(std::shared_ptr event, platform::DeviceType waiter_type) { - event_to_record_ = std::make_unique(id_, event, waiter_type); + event_to_record_ = std::make_shared(id_, event, waiter_type); } void AddEventToWait(size_t instr_id, @@ -379,7 +379,7 @@ class Instruction { std::vector next_instrs_in_different_thread; std::vector next_instrs_in_same_thread; - std::unique_ptr event_to_record_; + std::shared_ptr event_to_record_; std::vector events_to_wait_; OpFuncNode op_func_node_; diff --git a/python/paddle/fluid/tests/unittests/standalone_executor/CMakeLists.txt b/python/paddle/fluid/tests/unittests/standalone_executor/CMakeLists.txt index ee215ebf27a39..a983215420043 100644 --- a/python/paddle/fluid/tests/unittests/standalone_executor/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/standalone_executor/CMakeLists.txt @@ -5,49 +5,13 @@ file( string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}") foreach(target ${TEST_INTERP_CASES}) - py_test_modules( - ${target} - MODULES - ${target} - ENVS - FLAGS_host_trace_level=10 - FLAGS_static_executor_perfstat_filepath=./perfstat - FLAGS_allocator_strategy=auto_growth - FLAGS_use_stream_safe_cuda_allocator=true - FLAGS_fast_eager_deletion_mode=false - FLAGS_eager_delete_tensor_gb=0) - - py_test_modules( - ${target}_non_eager_deletion - MODULES - ${target} - ENVS - FLAGS_allocator_strategy=auto_growth - FLAGS_use_stream_safe_cuda_allocator=true - FLAGS_fast_eager_deletion_mode=false - FLAGS_eager_delete_tensor_gb=0.000001) - - py_test_modules( - ${target}_fast_gc - MODULES - ${target} - ENVS - FLAGS_allocator_strategy=auto_growth - 
FLAGS_use_stream_safe_cuda_allocator=true - FLAGS_fast_eager_deletion_mode=true - FLAGS_eager_delete_tensor_gb=0) - - py_test_modules( - ${target}_fast_gc_non_eager_deletion - MODULES - ${target} - ENVS - FLAGS_allocator_strategy=auto_growth - FLAGS_use_stream_safe_cuda_allocator=true - FLAGS_fast_eager_deletion_mode=true - FLAGS_eager_delete_tensor_gb=0.000001) + py_test_modules(${target} MODULES ${target}) endforeach() +py_test_modules( + test_standalone_executor_no_fast_gc MODULES test_standalone_executor ENVS + FLAGS_fast_eager_deletion_mode=false) + py_test_modules( test_standalone_executor_sequential_run MODULES test_standalone_executor ENVS FLAGS_new_executor_sequential_run=true) @@ -56,5 +20,8 @@ py_test_modules( test_standalone_executor_serial_run MODULES test_standalone_executor ENVS FLAGS_new_executor_serial_run=true) -py_test_modules(test_convert_graph_to_program MODULES test_standalone_executor - ENVS FLAGS_CONVERT_GRAPH_TO_PROGRAM=true) +py_test_modules( + test_standalone_executor_stats MODULES test_standalone_executor ENVS + FLAGS_host_trace_level=10 FLAGS_static_executor_perfstat_filepath=./perfstat) + +set_tests_properties(test_standalone_cross_step_overlap PROPERTIES TIMEOUT 30) diff --git a/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_cross_step_overlap.py b/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_cross_step_overlap.py new file mode 100644 index 0000000000000..a4fe9f9d25849 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_cross_step_overlap.py @@ -0,0 +1,82 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle import static + +paddle.enable_static() + + +class TestCrossStepOverlap(unittest.TestCase): + def setUp(self): + self.shape = [16, 513, 513, 19] + self.x_value = 2 + self.y_value = 3 + self.overlap_op_num = 1500 + self.step_num = 3 + + def test_cross_step_overlap(self): + if not paddle.fluid.core.is_compiled_with_cuda(): + return + + # In this test case, z=x+y is calculated in the default stream, + # and at the same time, numerous reduce_min ops that output to y + # are executed in another stream (i.e., the custom stream). + # These reduce_min ops are carefully designed so that their kernel + # calculation will overlap with the fill_constant kernels (output + # to x and y) in the next step, and therefore cross-step multi-stream + # synchronization is required. An Event should be recorded after the + # last reduce_min in the first step and waited before the fill_constant + # in the second step. Otherwise, the result of z will be wrong. 
+ program = static.Program() + with static.program_guard(program): + x = paddle.full( + self.shape, fill_value=self.x_value, dtype='float64' + ) + y = paddle.full( + self.shape, fill_value=self.y_value, dtype='float64' + ) + z = paddle.add(x, y) + + block = program.global_block() + block.var(x.name).desc.set_persistable(True) + block.var(y.name).desc.set_persistable(True) + for i in range(self.overlap_op_num): + block.append_op( + type='reduce_min', + inputs={'X': x.name}, + outputs={'Out': y.name}, + attrs={'axis': 0, 'keepdim': True}, + ) + block.ops[-1].dist_attr.execution_stream = "custom" + + exe = static.Executor() + results = [] + for i in range(self.step_num): + result = exe.run(program, fetch_list=[z]) + results.append(result) + + for result in results: + self.assertAlmostEqual( + np.sum(result), + (self.x_value + self.y_value) * np.prod(self.shape), + ) + + +if __name__ == "__main__": + unittest.main()