From 3f51a3620b1c9ce4666f9b7b1776fe60af0941a7 Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Thu, 1 Dec 2022 16:03:06 +0800 Subject: [PATCH 01/24] [AutoParallel] recompute tuning --- .../distributed/auto_parallel/constants.py | 2 +- .../distributed/auto_parallel/dist_op.py | 24 ++++ .../distributed/auto_parallel/engine.py | 13 +- .../distributed/auto_parallel/interface.py | 12 +- .../distributed/auto_parallel/strategy.py | 4 + .../auto_parallel/tuner/algorithms.py | 90 +++++++++++++- .../distributed/auto_parallel/tuner/config.py | 44 +++---- .../auto_parallel/tuner/optimization_tuner.py | 7 +- .../auto_parallel/tuner/profiler.py | 1 + .../paddle/distributed/auto_parallel/utils.py | 42 ++++--- .../passes/auto_parallel_recompute.py | 117 +++++++----------- .../auto_parallel/test_selective_recompute.py | 70 ++++++----- 12 files changed, 260 insertions(+), 166 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index 857245b9be425..da7750e4114c7 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -54,7 +54,7 @@ def set_field_default_config(category, field, default_value): ######################################### RECOMPUTE = "recompute" set_field_default_config(RECOMPUTE, "enable", False) -set_field_default_config(RECOMPUTE, "checkpoints", None) +set_field_default_config(RECOMPUTE, "checkpoints", []) set_field_default_config(RECOMPUTE, "no_recompute_segments", []) set_field_default_config(RECOMPUTE, "enable_tuning", False) diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py index 80141730bc1a1..212a1d27dd795 100644 --- a/python/paddle/distributed/auto_parallel/dist_op.py +++ b/python/paddle/distributed/auto_parallel/dist_op.py @@ -373,3 +373,27 @@ def __call__(self, *args, **kwargs): default_dist_ctx.add_dist_op_for_program(dist_op) return output + + +# class RecomputeOperatorHelper: +# def __init__(self, op): +# self._op = op + +# def __call__(self, *args, **kwargs): +# default_prog = paddle.fluid.default_main_program() +# cur_block = default_prog.current_block() +# op_size = len(cur_block.ops) +# output = self._op(*args, **kwargs) +# new_op_size = len(cur_block.ops) + +# from .dist_context import get_default_distributed_context + +# default_dist_ctx = get_default_distributed_context() +# for idx in range(op_size, new_op_size - 1): +# op = cur_block.ops[idx] +# dist_op = DistributedOperator(op) +# dist_op.dist_attr.is_recompute = True + +# default_dist_ctx.add_dist_op_for_program(dist_op) + +# return output diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 8e27b9aac6c70..18509251bfb75 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -49,7 +49,11 @@ DistributedDataLoader, ) from .strategy import Strategy -from .process_group import new_process_group, get_all_process_groups +from .process_group import ( + new_process_group, + get_all_process_groups, + get_world_process_group, +) from .dist_context import DistributedContext, get_default_distributed_context from .interface import CollectionNames, get_collection from .cost.estimate_cost import get_cost_from_engine @@ -610,7 +614,9 @@ def _build(self, mode): if mode != "train": serial_main_prog = serial_main_prog.clone(for_test=True) - auto_utils.set_recompute_ckpts(self._model, 
self._strategy) + auto_utils.set_recompute_ckpts( + self._model, self._strategy, serial_main_prog + ) self._dist_contexts[mode] = DistributedContext( serial_main_prog, serial_startup_prog, @@ -650,7 +656,6 @@ def _optimization_tuning(self, mode, dataset, batch_size): from .tuner.optimization_tuner import OptimizationTuner self._optimization_tuner = OptimizationTuner( - self._tuning.to_dict(), self._dist_contexts[mode], dataset, self._inputs_spec, @@ -658,7 +663,7 @@ def _optimization_tuning(self, mode, dataset, batch_size): batch_size=batch_size, rank=self._cur_rank, ) - + print("[engine] world_ranks:", get_world_process_group().ranks) self._optimization_tuner.tune() if self._tuning.run_after_tuning: diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py index b85d85011a1fa..9e9153ba038be 100644 --- a/python/paddle/distributed/auto_parallel/interface.py +++ b/python/paddle/distributed/auto_parallel/interface.py @@ -195,13 +195,7 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None): return op -_g_recompute_idx = -1 - - def recompute(op): - global _g_recompute_idx - _g_recompute_idx += 1 - class RecomputeOperator: def __init__(self, op): self._op = op @@ -213,11 +207,9 @@ def __call__(self, *args, **kwargs): output = self._op(*args, **kwargs) new_op_size = len(cur_block.ops) - for idx in range(op_size, new_op_size): + for idx in range(op_size, new_op_size - 1): op = cur_block.ops[idx] - op._set_attr( - 'op_namescope', "/auto_parallel/rc_" + str(_g_recompute_idx) - ) + op._set_attr('op_namescope', "/auto_parallel/rc") return output diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py index dcfd453f63a33..cfb8cde75b03a 100644 --- a/python/paddle/distributed/auto_parallel/strategy.py +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -72,6 +72,10 @@ def __deepcopy__(self, memo): setattr(result, k, copy.deepcopy(v, memo)) return result + def get(self, k, d=None): + result_dict = self.to_dict() + return result_dict.get(k, d) + class RecomputeConfig(BaseConfig): def __init__(self, config_dict=None): diff --git a/python/paddle/distributed/auto_parallel/tuner/algorithms.py b/python/paddle/distributed/auto_parallel/tuner/algorithms.py index efc3358ebe41a..2b0306a6a973d 100644 --- a/python/paddle/distributed/auto_parallel/tuner/algorithms.py +++ b/python/paddle/distributed/auto_parallel/tuner/algorithms.py @@ -19,6 +19,7 @@ from ..utils import get_logger from .trial import TrialStatus from .trial import OptimizationTunerTrial as Trial +from ...passes.auto_parallel_recompute import RecomputeState class AlgorithmBase(ABC): @@ -54,7 +55,7 @@ def changed_configs(self): def collect_model_info(self, main_prog, startup_prog): """ Collect the model static info (from programs) that could be used to - pruning candidate trials and saving tuning time.For instance, + pruning candidate trials and saving tuning time. For instance, model info like number of model parameters and activation memory could be used to prune candidated trial and decide the next trial. 
""" @@ -116,7 +117,7 @@ def _init_spaces(self): self._max_stage = 3 self._trial_idx = 0 - stage_range = self._config.sharding.to_dict().get("tuning_range", None) + stage_range = self._config.sharding.get("tuning_range", None) if stage_range: assert set(stage_range).issubset( set([0, 1, 2, 3]) @@ -157,3 +158,88 @@ def update(self, results): ) else: self._trial_idx += 1 + + +@register_algor("recompute") +class ReccomputeCheckpointAlgorithm(AlgorithmBase): + def __init__(self, config): + super().__init__(config) + self._changed_configs = ["recompute"] + + def collect_model_info(self, main_prog, startup_prog): + checkpoints = self._config.recompute.get("checkpoints", []) + no_recompute_segments = self._config.recompute.get( + "no_recompute_segments", [] + ) + + rc_state = RecomputeState( + main_prog.global_block(), main_prog.global_block().ops + ) + rc_state.build_stats() + checkpoints = rc_state.sort_checkpoints(checkpoints) + segments = rc_state.get_recompute_segments( + checkpoints, is_logging=False + ) + + self._total_num_trial = len(segments) - len(no_recompute_segments) + self._total_segments = list(range(len(segments))) + self._tuning_segments = list( + set(self._total_segments) - set(no_recompute_segments) + ) + + def _init_spaces(self): + self._trial_idx = 0 + self._recompute_mode = "all" + + def next_trial(self): + if self._recompute_mode == "all": + self._recompute_flag = False + new_strategy = copy.deepcopy(self._config.dist_strategy) + name = "trial-recompute-all-segments" + return Trial(new_strategy, name, self.changed_configs) + elif self._recompute_mode == "none": + self._recompute_flag = False + new_strategy = copy.deepcopy(self._config.dist_strategy) + recompute = new_strategy.recompute + recompute.no_recompute_segments = self._total_segments + name = "trial-recompute-none-segments" + return Trial(new_strategy, name, self.changed_configs) + elif ( + self._recompute_mode == "part" + and self._trial_idx < self._total_num_trial + ): + index = int( + len(self._tuning_segments) * pow(0.5, self._trial_idx + 1) + ) + new_no_recompute = self._tuning_segments[:index] + new_strategy = copy.deepcopy(self._config.dist_strategy) + recompute = new_strategy.recompute + recompute.no_recompute_segments.extend(new_no_recompute) + name = "trial-recompute-part-segments [{}]".format(self._trial_idx) + return Trial(new_strategy, name, self.changed_configs) + else: + return Trial(None, None, None, status=TrialStatus.STOPPED) + + def update(self, results): + + et = results.get("ErrorType", None) + if et and et == "ResourceExhaustedError": + self._trial_idx = self._total_num_trial + if self._recompute_mode == "all": + self._logger.info( + "Last trial is failed with OOM, all remaining trials are pruned to save time !" + ) + elif self._recompute_mode == "none": + self._logger.info( + "Last trial is failed with OOM, all remaining trials are pruned to save time !" + ) + else: + self._logger.info( + "Last trial is failed with OOM, all remaining trials are pruned to save time !" 
+ ) + elif self._recompute_mode == "all": + self._recompute_mode = "none" + elif self._recompute_mode == "none": + self._recompute_mode = "part" + else: + self._trial_idx += 1 diff --git a/python/paddle/distributed/auto_parallel/tuner/config.py b/python/paddle/distributed/auto_parallel/tuner/config.py index 7bb9d4f18bcef..c93baa074fa99 100644 --- a/python/paddle/distributed/auto_parallel/tuner/config.py +++ b/python/paddle/distributed/auto_parallel/tuner/config.py @@ -32,14 +32,11 @@ class TuningConfig: tuning config: configuration for the tuning process: mode (profile or cost model), log dir, extra tuning config for optimization like search range for specific """ - def __init__(self, user_config, strategy): + def __init__(self, strategy): if not isinstance(strategy, Strategy): raise TypeError("'strategy' must be object of class `Strategy`.") - if not user_config: - user_config = {} - self._tuning_passes_name = set() self._dist_strategy = copy.deepcopy(strategy) self._mode = None @@ -50,7 +47,7 @@ def __init__(self, user_config, strategy): self._early_stop = None self._verbose = None - self._initialize(user_config) + self._initialize() @property def mode(self): @@ -89,21 +86,17 @@ def dist_strategy(self): return self._dist_strategy # initialize config with user define value or default value - def _initialize(self, user_config): - - self._mode = user_config.get("mode", "PROFILE") - - self._profile_start_step = user_config.get("profile_start_step", 10) - - self._profile_end_step = user_config.get("profile_end_step", 30) - - self._max_num_trial = user_config.get("max_num_trial", 50) - - self._early_stop = user_config.get("early_stop", None) + def _initialize(self): + tuning_strategy = self._dist_strategy.tuning - self._verbose = user_config.get("verbose", False) + self._mode = tuning_strategy.get("mode", "PROFILE") + self._profile_start_step = tuning_strategy.get("profile_start_step", 10) + self._profile_end_step = tuning_strategy.get("profile_end_step", 30) + self._max_num_trial = tuning_strategy.get("max_num_trial", 50) + self._early_stop = tuning_strategy.get("early_stop", None) + self._verbose = tuning_strategy.get("verbose", False) - project_dir = user_config.get("project_dir", None) + project_dir = tuning_strategy.get("project_dir", None) if not project_dir: project_dir = os.path.join(os.getcwd(), "OptimizationTuning") self._project_dir = project_dir @@ -116,15 +109,14 @@ def _initialize(self, user_config): # TODO distinguish different args of each passes self._tuning_passes_name.add(p) - config_name = p - p_dict = getattr(self._dist_strategy, config_name) - self.__dict__[config_name] = p_dict + p_strategy = getattr(self._dist_strategy, p) + self.__dict__[p] = p_strategy - # TODO verify the user defined configs - user_config_for_pass = user_config.get(p, None) - if user_config_for_pass: - for k, v in user_config_for_pass.items(): - self.__dict__[config_name][k] = v + # # TODO verify the user defined configs + # tuning_config_for_pass = tuning_strategy.get(p, None) + # if tuning_config_for_pass: + # for k, v in tuning_config_for_pass.items(): + # self.__dict__[p][k] = v # (NOTE)tuning config ONLY wraps dist strategy for pass config which is to be tuned def __getattr__(self, item): diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py index 3cd58f2c00402..72a8a04e48cf1 100644 --- a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py +++ 
b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py @@ -35,7 +35,6 @@ from paddle.distributed.auto_parallel.reshard import Resharder from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.process_group import ( - clear_all_process_groups, get_all_process_groups, ) from paddle.distributed.auto_parallel.utils import debug_program @@ -94,6 +93,7 @@ def get_metric(results): def parse_results(results): + print("results:", results) if results['Throughtput'] > 0: return "Throughtput: {} step / s.".format(results['Throughtput']) et = results.get("ErrorType", None) @@ -107,7 +107,7 @@ def parse_results(results): # all env need to be start a new pass are member of dist context def _copy_context(ref_dist_context): - clear_all_process_groups() + # clear_all_process_groups() new_dist_context = DistributedContext() new_dist_context._serial_main_program = ( @@ -193,7 +193,6 @@ class OptimizationTuner: def __init__( self, - user_configs, dist_context, dataset, inputs_spec, @@ -202,7 +201,7 @@ def __init__( rank, ): - self._config = TuningConfig(user_configs, dist_context._strategy) + self._config = TuningConfig(dist_context.strategy) # should not modify dist context from calling function self._baseline_dist_context = _copy_context(dist_context) self._baseline_completer = Completer(self._baseline_dist_context) diff --git a/python/paddle/distributed/auto_parallel/tuner/profiler.py b/python/paddle/distributed/auto_parallel/tuner/profiler.py index 1aeafbea76410..bd2b2e6d8b518 100644 --- a/python/paddle/distributed/auto_parallel/tuner/profiler.py +++ b/python/paddle/distributed/auto_parallel/tuner/profiler.py @@ -222,6 +222,7 @@ def profiler(args): with open(args.ctx_filename, 'rb') as f: profile_ctx = pickle.load(f, encoding='latin1') + print(profile_ctx) init_comm(profile_ctx) main_program, startup_program, loss_var = load_programs(profile_ctx) diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index be4c68d97d840..dead1491a8525 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -23,12 +23,10 @@ import paddle.fluid.core as core from paddle.fluid.framework import Variable -from paddle.distributed.fleet.meta_optimizers.common import OpRole -from paddle.distributed.auto_parallel.process_group import ( - get_all_process_groups, -) from paddle.fluid.io import is_parameter, is_belong_to_optimizer -from paddle.distributed.auto_parallel.dist_attribute import ( + +from .process_group import get_all_process_groups +from .dist_attribute import ( TensorDistributedAttribute, OperatorDistributedAttribute, ) @@ -1894,11 +1892,28 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank): server_socket.close() -def set_recompute_ckpts(model, strategy): - from .interface import _g_recompute_idx +def _is_recompute_op(op): + return op.has_attr('op_namescope') and "/auto_parallel/rc" in op.attr( + 'op_namescope' + ) + - if _g_recompute_idx > -1: - return +def get_checkpoints_from_program(program): + pass + + ops = program.global_block().ops + if not any([_is_recompute_op(op) for op in ops]): + return [] + + checkpoints = [] + for idx, op in enumerate(ops): + if not _is_recompute_op(op): + checkpoints.extend(op.output_arg_names) + + return checkpoints + + +def set_recompute_ckpts(model, strategy, program): recompute = strategy.recompute if not recompute.enable: @@ -1919,12 +1934,9 @@ def set_recompute_ckpts(model, strategy): 
exact_ckpts = recompute.checkpoints # modify strategy - recompute.checkpoints = exact_ckpts[:] - logs = { - 'Model Class': model.__class__.__name__, - 'Applied Recompute ckpts': exact_ckpts, - } - logging.info(logs) + recompute.checkpoints = exact_ckpts[:] or get_checkpoints_from_program( + program + ) def get_input_split_info(cur_rank, var, dist_context): diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index 5bdbe9d2dd5d9..8d88da8e9616a 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -32,12 +32,6 @@ ) -def _to_be_recomputed(op): - return op.has_attr('op_namescope') and "/auto_parallel/rc_" in op.attr( - 'op_namescope' - ) - - class RecomputeState(ProgramStats): def __init__(self, block, ops): super().__init__(block=block, ops=ops) @@ -45,8 +39,6 @@ def __init__(self, block, ops): self._ops = ops # {varname: {as_input_ops: op_idx, as_output_ops: op_idx}} self.var_op_deps = {} - # {segment_name: op_idx} - self.seg_op_deps = {} def build_stats(self): for i, op in enumerate(self._ops): @@ -66,63 +58,38 @@ def build_stats(self): self.var_op_deps[name]["var_as_input_ops"] = [] self.var_op_deps[name]["var_as_output_ops"] = [i] - if not _to_be_recomputed(op): - continue - - seg_name = op.attr('op_namescope') - if seg_name not in self.seg_op_deps: - self.seg_op_deps[seg_name] = [i] - else: - assert ( - self.seg_op_deps[seg_name][-1] + 1 == i - ), "The recompute segment's ops should be continuous" - self.seg_op_deps[seg_name].extend([i]) - def get_recompute_segments( - self, checkpoints_list=None, no_recompute_segments=[] + self, checkpoints, no_recompute_segments=[], is_logging=True ): - """get recompute segments and checkpoints""" + """get recompute segments from checkpoints""" segments = [] - checkpoints = checkpoints_list or [] - - if len(checkpoints) == 0: - # the segments is marked by `auto.recompute()` api - for segment_idx in self.seg_op_deps.values(): - if len(segment_idx) == 1: + start_idx = -1 + pre_segment_end_idx = -1 + while start_idx + 1 < len(checkpoints): + if start_idx == -1: + ckpt_name = checkpoints[start_idx + 1] + if ckpt_name not in self.var_op_deps: + start_idx += 1 continue - segments.append([segment_idx[0], segment_idx[-1] + 1]) - checkpoints.extend(self._ops[segment_idx[-1]].output_arg_names) - else: - # the segments is marked by `strategy.checkpoints` api - start_idx = -1 - pre_segment_end_idx = -1 - while start_idx + 1 < len(checkpoints): - if start_idx == -1: - ckpt_name = checkpoints[start_idx + 1] - if ckpt_name not in self.var_op_deps: - start_idx += 1 - continue - op_idx_list = self.var_op_deps[ckpt_name][ - "var_as_output_ops" - ] - if op_idx_list: - segments.append([0, max(op_idx_list) + 1]) - else: - flag, min_idx, max_idx = self.is_subgraph( - [checkpoints[start_idx]], [checkpoints[start_idx + 1]] + op_idx_list = self.var_op_deps[ckpt_name]["var_as_output_ops"] + if op_idx_list and max(op_idx_list) > 0: + segments.append([0, max(op_idx_list) + 1]) + else: + flag, min_idx, max_idx = self.is_subgraph( + [checkpoints[start_idx]], [checkpoints[start_idx + 1]] + ) + if flag: + min_idx = self._update_segment_start( + min_idx, pre_segment_end_idx ) - if flag: - min_idx = self._update_segment_start( - min_idx, pre_segment_end_idx - ) - segments.append([min_idx, max_idx + 1]) - else: - logging.info( - "Could not recompute op range [{}] - [{}] ".format( - min_idx, max_idx + 1 - ) + 
segments.append([min_idx, max_idx + 1]) + else: + logging.info( + "Could not recompute op range [{}] - [{}] ".format( + min_idx, max_idx + 1 ) - start_idx += 1 + ) + start_idx += 1 if no_recompute_segments: for i in reversed(sorted(no_recompute_segments)): @@ -133,16 +100,19 @@ def get_recompute_segments( ) segments.pop(i) + if not is_logging: + return segments + for i, (idx1, idx2) in enumerate(segments): - logging.info("recompute segment[{}]".format(i)) - logging.info( + print("recompute segment[{}]".format(i)) + print( "segment start op: [{}]: [{}] [{}]".format( self._ops[idx1].desc.type(), self._ops[idx1].desc.input_arg_names(), self._ops[idx1].desc.output_arg_names(), ) ) - logging.info( + print( "segment end op: [{}]: [{}] [{}]".format( self._ops[idx2 - 1].desc.type(), self._ops[idx2 - 1].desc.input_arg_names(), @@ -150,10 +120,7 @@ def get_recompute_segments( ) ) - return segments, checkpoints - - def is_recompute(self): - return any([_to_be_recomputed(op) for op in self._ops]) + return segments def modify_forward_desc_for_recompute(self, dist_context): """ @@ -209,7 +176,6 @@ def modify_forward_desc_for_recompute(self, dist_context): outputs={"Out": seed_var}, attrs={"seed": seed, "force_cpu": True}, ) - seed_op._set_attr('op_namescope', cur_op.attr('op_namescope')) # set new seed op's dist_attr naive_set_dist_op_attr_for_program_by_mesh_and_mapping( seed_op, ref_process_mesh, ref_dims_mapping, dist_context @@ -291,13 +257,14 @@ def __init__(self): self.set_attr("loss", None) self.set_attr("dist_context", None) self.set_attr("no_grad_set", None) - self.set_attr("no_recompute_segments", []) def _check_self(self): if self.get_attr("dist_context") is None: return False if self.get_attr("loss") is None: return False + if self.get_attr("checkpoints") is None: + return False return True def _check_conflict(self, other_pass): @@ -317,19 +284,18 @@ def _apply_single_impl(self, main_program, startup_program, context): # 1. build recompute state rc_state = RecomputeState(main_block, op_path) - if not rc_state.is_recompute() and not checkpoints: - return - # 2. get the segments to be recomputed rc_state.modify_forward_desc_for_recompute(self._dist_context) rc_state.build_stats() - checkpoints = rc_state.sort_checkpoints(checkpoints or []) - segments, checkpoints = rc_state.get_recompute_segments( + checkpoints = rc_state.sort_checkpoints(checkpoints) + segments = rc_state.get_recompute_segments( checkpoints, no_recompute_segments ) - if segments == [] or checkpoints == []: + if segments == []: return + print("segments:", segments) + # 3. get vars that should be hold in memory vars_should_be_hold = [] for segment in segments: @@ -337,7 +303,7 @@ def _apply_single_impl(self, main_program, startup_program, context): rc_state.get_out_of_subgraph_vars(segment[0], segment[1]) ) cross_vars = set(vars_should_be_hold) - set(checkpoints) - logging.info( + print( "found [{}] vars which cross recompute segment: [{}]," "better checkpoints might be set to reduce those vars".format( len(cross_vars), cross_vars @@ -347,6 +313,7 @@ def _apply_single_impl(self, main_program, startup_program, context): vars_should_be_hold.extend(rc_state.get_input_nodes()) vars_should_be_hold = list(set(vars_should_be_hold)) vars_in_memory = vars_should_be_hold + checkpoints + print("The vars hold in memory: [{}]".format(list(set(vars_in_memory)))) # 4. get the fwd ops desc to be recomputed. 
var_name_dict = {} # varname --> varname.subprog_XXX diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py index 97e175a39801a..16a6b37cfb0f0 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py @@ -39,7 +39,7 @@ def generate_model(use_new_recompute, recompute_granularity): gpt = GPTModel( vocab_size=1000, hidden_size=64, - num_hidden_layers=2, + num_hidden_layers=16, num_attention_heads=8, intermediate_size=256, hidden_act="gelu", @@ -69,7 +69,15 @@ def apply_pass(use_recompute=False, no_recompute_segments=[]): if use_recompute: recompute = strategy.recompute recompute.enable = True + recompute.enable_tuning = True recompute.no_recompute_segments = no_recompute_segments + + tuning = strategy.tuning + tuning.enable = True + tuning.profile_start_step = 1 + tuning.profile_end_step = 5 + tuning.run_after_tuning = True + tuning.verbose = True return strategy @@ -128,47 +136,51 @@ def recompute_vars(self, program): def test_recompute_pass(self): # mp2 training - mp_engine = self.get_engine() - history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) - mp_losses = np.array(history.history["loss"]) + # mp_engine = self.get_engine() + # history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) + # mp_losses = np.array(history.history["loss"]) # mp2 recompute with old api - rc4_engine = self.get_engine(True, False) - history = rc4_engine.fit(self.dataset, 3, batch_size=self.batch_size) - rc4_losses = np.array(history.history["loss"]) - self.check_results(mp_losses, rc4_losses) + # rc4_engine = self.get_engine(True, False) + # history = rc4_engine.fit(self.dataset, 3, batch_size=self.batch_size) + # print("***"*30) + # print(rc4_engine.main_program) + # rc4_losses = np.array(history.history["loss"]) + # self.check_results(mp_losses, rc4_losses) # mp2 recompute core_attn - rc1_engine = self.get_engine(True, True, "core_attn", [0]) - history = rc1_engine.fit(self.dataset, 3, batch_size=self.batch_size) - rc1_losses = np.array(history.history["loss"]) - self.check_results(mp_losses, rc1_losses) + # rc1_engine = self.get_engine(True, True, "core_attn", [0]) + # history = rc1_engine.fit(self.dataset, 3, batch_size=self.batch_size) + # rc1_losses = np.array(history.history["loss"]) + # self.check_results(mp_losses, rc1_losses) - # mp2 recompute full_attn - rc2_engine = self.get_engine(True, True, "full_attn") - history = rc2_engine.fit(self.dataset, 3, batch_size=self.batch_size) - rc2_losses = np.array(history.history["loss"]) - self.check_results(mp_losses, rc2_losses) + # # mp2 recompute full_attn + # rc2_engine = self.get_engine(True, True, "full_attn") + # history = rc2_engine.fit(self.dataset, 3, batch_size=self.batch_size) + # rc2_losses = np.array(history.history["loss"]) + # self.check_results(mp_losses, rc2_losses) # mp2 recompute full rc3_engine = self.get_engine(True, True, "full") - history = rc3_engine.fit(self.dataset, 3, batch_size=self.batch_size) + history = rc3_engine._tune(self.dataset, 3, batch_size=self.batch_size) + print("***" * 30) + print(rc3_engine.main_program) rc3_losses = np.array(history.history["loss"]) - self.check_results(mp_losses, rc3_losses) + # self.check_results(mp_losses, rc3_losses) - rc0_vars = self.recompute_vars(mp_engine.main_program) - rc1_vars = self.recompute_vars(rc1_engine.main_program) - rc2_vars = 
self.recompute_vars(rc2_engine.main_program) - rc3_vars = self.recompute_vars(rc3_engine.main_program) + # rc0_vars = self.recompute_vars(mp_engine.main_program) + # rc1_vars = self.recompute_vars(rc1_engine.main_program) + # rc2_vars = self.recompute_vars(rc2_engine.main_program) + # rc3_vars = self.recompute_vars(rc3_engine.main_program) - assert rc0_vars == [] - assert len(rc1_vars) < len(rc2_vars) and len(rc2_vars) < len(rc3_vars) + # assert rc0_vars == [] + # assert len(rc1_vars) < len(rc2_vars) and len(rc2_vars) < len(rc3_vars) - def test_recompute_pass_error(self): + # def test_recompute_pass_error(self): - with self.assertRaises(AssertionError): - rc_engine = self.get_engine(True, True, "full", [2]) - history = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size) + # with self.assertRaises(AssertionError): + # rc_engine = self.get_engine(True, True, "full", [2]) + # history = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size) if __name__ == "__main__": From 1de9545b8087c47b3c09941ea9f922f6ee132bed Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Thu, 1 Dec 2022 16:29:12 +0800 Subject: [PATCH 02/24] fix conflict --- .../paddle/distributed/auto_parallel/utils.py | 141 ++++++++++++++++++ 1 file changed, 141 insertions(+) diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 21557cd5f841d..d61d76f28ef88 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1408,6 +1408,30 @@ def naive_set_dist_op_attr_for_program_by_mesh_and_mapping( ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr) +def naive_set_dist_op_attr_for_program_by_mesh( + new_op, process_mesh, ctx, is_recompute=False +): + # hack to skip coalesce var for dist attr + if not is_recompute: + return + assert process_mesh is not None + + new_op_dist_attr = OperatorDistributedAttribute() + + for input_varname in new_op.desc.input_arg_names(): + var = ctx.serial_main_program.global_block().var(input_varname) + mapping = ctx.get_tensor_dist_attr_for_program(var).dims_mapping + new_op_dist_attr.set_input_dims_mapping(input_varname, mapping) + for output_varname in new_op.desc.output_arg_names(): + var = ctx.serial_main_program.global_block().var(output_varname) + mapping = ctx.get_tensor_dist_attr_for_program(var).dims_mapping + new_op_dist_attr.set_output_dims_mapping(output_varname, mapping) + + new_op_dist_attr.process_mesh = process_mesh + new_op_dist_attr.is_recompute = is_recompute + ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr) + + def update_op_dims_mapping_by_default_dist_impl(dist_op): changed = False op_dist_attr = dist_op.dist_attr @@ -2118,3 +2142,120 @@ def _copy_dist_attr_from_cpp_for_graph(dist_context): py_dist_attr = dist_context.get_op_dist_attr_for_graph(node) cpp_dist_attr = node.op().dist_attr _copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr) + + +def insert_dependencies_for_two_ops( + block, + idx, + prior_op, + posterior_op, + dist_context, + is_recompute=False, + sync=False, +): + """ + dependency: prior_op should be run before posterior_op + """ + + assert ( + len(prior_op.output_arg_names) >= 1 + ), "first op of dependency should at least have one output. [{}]".format( + str(prior_op) + ) + assert ( + len(posterior_op.input_arg_names) >= 1 + ), "second op of dependency should at least have one input. 
[{}]".format( + str(posterior_op) + ) + prior_op_mesh = dist_context.get_op_dist_attr_for_program( + prior_op + ).process_mesh + posterior_mesh = dist_context.get_op_dist_attr_for_program( + posterior_op + ).process_mesh + assert ( + prior_op_mesh == posterior_mesh + ), "two ops of dependency should have same mesh but got [{}] and [{}]".format( + str(prior_op_mesh), str(posterior_mesh) + ) + + def _select_best_depend_var(vars): + + vars_with_numels = [(var, get_var_numel(var)) for var in vars] + vars_with_numels.sort(key=lambda x: x[1]) + + return vars_with_numels[-1][0] + + first_var = _select_best_depend_var( + [block.var(name) for name in prior_op.output_arg_names] + ) + second_var = _select_best_depend_var( + [block.var(name) for name in posterior_op.input_arg_names] + ) + + return insert_dependencies_for_two_vars( + block, + idx, + first_var, + second_var, + dist_context, + OpRole.Backward, + prior_op_mesh, + is_recompute, + sync, + ) + + +def insert_dependencies_for_two_vars( + block, + idx, + prior_var, + post_var, + dist_context, + oprole, + process_mesh=None, + is_recompute=False, + sync=False, +): + """ + dependency: op that generates prior_var should be run before op that generates post_var + """ + assert block.has_var(prior_var.name) + assert block.has_var(post_var.name) + if process_mesh is None: + process_mesh = dist_context.get_tensor_dist_attr_for_program( + post_var + ).process_mesh + assert process_mesh is not None + + depend_op = block._insert_op_without_sync( + idx, + type='nop', + inputs={ + "X": prior_var, + }, + outputs={"Out": post_var}, + ) + # depend_op.desc.set_type("depend") + depend_op._set_attr(OP_ROLE_KEY, oprole) + # depend_op.desc.set_input("Dep", [first_var.name]) + # self.desc.set_output(out_proto.name, out_arg_names) + + naive_set_dist_op_attr_for_program_by_mesh( + depend_op, process_mesh, dist_context, is_recompute + ) + + if sync: + block._sync_with_cpp() + + return depend_op + + +def use_standalone_executor(): + return os.environ.get('FLAGS_CONVERT_GRAPH_TO_PROGRAM', None) in [ + 1, + '1', + True, + 'True', + 'true', + ] From 0dacac22a193d9ce62d072517056e74a6eb0f074 Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Thu, 1 Dec 2022 16:35:16 +0800 Subject: [PATCH 03/24] update comment --- .../auto_parallel/tuner/optimization_tuner.py | 3 +- .../passes/auto_parallel_recompute.py | 29 ++++++++----------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py index ab8440d22efcc..98d40bdab50dc 100644 --- a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py @@ -31,6 +31,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.process_group import ( + clear_all_process_groups, get_all_process_groups, ) from paddle.distributed.auto_parallel.reshard import Resharder @@ -109,7 +110,7 @@ def parse_results(results): # all env need to be start a new pass are member of dist context def _copy_context(ref_dist_context): - # clear_all_process_groups() + clear_all_process_groups() new_dist_context = DistributedContext() new_dist_context._serial_main_program = ( diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index 
3c5ad79ea57c0..0ff3b8d3b8140 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -14,16 +14,6 @@ import logging -from paddle.distributed.auto_parallel.dist_attribute import ( - OperatorDistributedAttribute, -) -from paddle.distributed.auto_parallel.utils import ( - get_loss_op, - insert_dependencies_for_two_ops, - naive_set_dist_op_attr_for_program_by_mesh_and_mapping, - set_dist_op_desc_original_id, - set_var_dist_attr, -) from paddle.fluid import core from paddle.fluid import framework as framework from paddle.fluid import unique_name @@ -35,6 +25,14 @@ _rename_arg_, ) +from ..auto_parallel.dist_attribute import OperatorDistributedAttribute +from ..auto_parallel.utils import ( + get_loss_op, + insert_dependencies_for_two_ops, + naive_set_dist_op_attr_for_program_by_mesh_and_mapping, + set_dist_op_desc_original_id, + set_var_dist_attr, +) from .pass_base import PassBase, register_pass @@ -110,15 +108,15 @@ def get_recompute_segments( return segments for i, (idx1, idx2) in enumerate(segments): - print("recompute segment[{}]".format(i)) - print( + logging.info("recompute segment[{}]".format(i)) + logging.info( "segment start op: [{}]: [{}] [{}]".format( self._ops[idx1].desc.type(), self._ops[idx1].desc.input_arg_names(), self._ops[idx1].desc.output_arg_names(), ) ) - print( + logging.info( "segment end op: [{}]: [{}] [{}]".format( self._ops[idx2 - 1].desc.type(), self._ops[idx2 - 1].desc.input_arg_names(), @@ -301,8 +299,6 @@ def _apply_single_impl(self, main_program, startup_program, context): if segments == []: return - print("segments:", segments) - # 3. get vars that should be hold in memory vars_should_be_hold = [] for segment in segments: @@ -310,7 +306,7 @@ def _apply_single_impl(self, main_program, startup_program, context): rc_state.get_out_of_subgraph_vars(segment[0], segment[1]) ) cross_vars = set(vars_should_be_hold) - set(checkpoints) - print( + logging.info( "found [{}] vars which cross recompute segment: [{}]," "better checkpoints might be set to reduce those vars".format( len(cross_vars), cross_vars @@ -320,7 +316,6 @@ def _apply_single_impl(self, main_program, startup_program, context): vars_should_be_hold.extend(rc_state.get_input_nodes()) vars_should_be_hold = list(set(vars_should_be_hold)) vars_in_memory = vars_should_be_hold + checkpoints - print("The vars hold in memory: [{}]".format(list(set(vars_in_memory)))) # 4. get the fwd ops desc to be recomputed. 
var_name_dict = {} # varname --> varname.subprog_XXX From 5c8d6e130e0179829475a1b1dd44f253871d3173 Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Thu, 1 Dec 2022 17:31:55 +0800 Subject: [PATCH 04/24] bug fix --- python/paddle/distributed/auto_parallel/engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index da86609092e7d..680db45e5e7dd 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -609,7 +609,9 @@ def _build(self, mode): if mode != "train": serial_main_prog = serial_main_prog.clone(for_test=True) - auto_utils.set_recompute_ckpts(self._model, self._strategy) + auto_utils.set_recompute_ckpts( + self._model, self._strategy, serial_main_prog + ) self._dist_contexts[mode] = DistributedContext( serial_main_prog, serial_startup_prog, From a91fdec63451dd93b7863603a79bd4b4bb5650ad Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Thu, 1 Dec 2022 22:22:18 +0800 Subject: [PATCH 05/24] update rc algo --- .../auto_parallel/tuner/algorithms.py | 103 ++++++++++-------- .../auto_parallel/tuner/optimization_tuner.py | 1 - .../auto_parallel/tuner/profiler.py | 12 +- .../paddle/distributed/auto_parallel/utils.py | 1 - .../unittests/auto_parallel/get_gpt_model.py | 7 +- .../auto_parallel/test_selective_recompute.py | 29 ++--- 6 files changed, 78 insertions(+), 75 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/tuner/algorithms.py b/python/paddle/distributed/auto_parallel/tuner/algorithms.py index 3490da12ec5dc..940e0cff97ecb 100644 --- a/python/paddle/distributed/auto_parallel/tuner/algorithms.py +++ b/python/paddle/distributed/auto_parallel/tuner/algorithms.py @@ -168,9 +168,6 @@ def __init__(self, config): def collect_model_info(self, main_prog, startup_prog): checkpoints = self._config.recompute.get("checkpoints", []) - no_recompute_segments = self._config.recompute.get( - "no_recompute_segments", [] - ) rc_state = RecomputeState( main_prog.global_block(), main_prog.global_block().ops @@ -181,65 +178,75 @@ def collect_model_info(self, main_prog, startup_prog): checkpoints, is_logging=False ) - self._total_num_trial = len(segments) - len(no_recompute_segments) - self._total_segments = list(range(len(segments))) - self._tuning_segments = list( - set(self._total_segments) - set(no_recompute_segments) - ) + self._total_num_trial = len(segments) + self._tuning_segments = list(range(len(segments))) + self._trail_left = 0 + self._trail_right = len(segments) - 1 + self._trial_idx = int(0 + (len(segments)) / 2) def _init_spaces(self): - self._trial_idx = 0 self._recompute_mode = "all" def next_trial(self): - if self._recompute_mode == "all": - self._recompute_flag = False - new_strategy = copy.deepcopy(self._config.dist_strategy) - name = "trial-recompute-all-segments" - return Trial(new_strategy, name, self.changed_configs) - elif self._recompute_mode == "none": - self._recompute_flag = False - new_strategy = copy.deepcopy(self._config.dist_strategy) - recompute = new_strategy.recompute - recompute.no_recompute_segments = self._total_segments - name = "trial-recompute-none-segments" - return Trial(new_strategy, name, self.changed_configs) - elif ( - self._recompute_mode == "part" - and self._trial_idx < self._total_num_trial - ): - index = int( - len(self._tuning_segments) * pow(0.5, self._trial_idx + 1) - ) - new_no_recompute = self._tuning_segments[:index] - new_strategy = 
copy.deepcopy(self._config.dist_strategy) - recompute = new_strategy.recompute - recompute.no_recompute_segments.extend(new_no_recompute) - name = "trial-recompute-part-segments [{}]".format(self._trial_idx) - return Trial(new_strategy, name, self.changed_configs) + if self._trial_idx < self._total_num_trial: + if self._recompute_mode == "all": + self._recompute_flag = False + new_strategy = copy.deepcopy(self._config.dist_strategy) + name = "trial-recompute-all-segments" + return Trial(new_strategy, name, self.changed_configs) + elif self._recompute_mode == "none": + self._recompute_flag = False + new_strategy = copy.deepcopy(self._config.dist_strategy) + recompute = new_strategy.recompute + recompute.enable = False + recompute.checkpoints = [] + name = "trial-recompute-none-segments" + return Trial(new_strategy, name, self.changed_configs) + elif self._recompute_mode == "part": + new_no_recompute = self._tuning_segments[: self._trial_idx] + new_strategy = copy.deepcopy(self._config.dist_strategy) + recompute = new_strategy.recompute + recompute.no_recompute_segments.extend(new_no_recompute) + name = "trial-recompute-part-segments-idx{}".format( + self._trial_idx + ) + return Trial(new_strategy, name, self.changed_configs) else: return Trial(None, None, None, status=TrialStatus.STOPPED) def update(self, results): et = results.get("ErrorType", None) - if et and et == "ResourceExhaustedError": - self._trial_idx = self._total_num_trial - if self._recompute_mode == "all": - self._logger.info( - "Last trial is failed with OOM, all remaining trials are pruned to save time !" - ) - elif self._recompute_mode == "none": + if self._recompute_mode == "all": + if et and et == "ResourceExhaustedError": + self._trial_idx = self._total_num_trial self._logger.info( - "Last trial is failed with OOM, all remaining trials are pruned to save time !" + "Recompute all candidate segments is failed with OOM, please reduce model size or batch size." ) else: + self._recompute_mode = "none" + elif self._recompute_mode == "none": + if et and et == "ResourceExhaustedError": + self._recompute_mode = "part" + else: + self._trial_idx = self._total_num_trial self._logger.info( - "Last trial is failed with OOM, all remaining trials are pruned to save time !" + "Recompute is unnecessary for this model size, which will reduce the flops." 
) - elif self._recompute_mode == "all": - self._recompute_mode = "none" - elif self._recompute_mode == "none": - self._recompute_mode = "part" else: - self._trial_idx += 1 + if self._trail_left >= self._trail_right: + self._trial_idx = self._total_num_trial + elif et and et == "ResourceExhaustedError": + self._trail_left = self._trail_left + self._trail_right = self._trial_idx - 1 + self._trial_idx = int( + self._trail_left + + (self._trail_right - self._trail_left) / 2 + ) + else: + self._trail_left = self._trial_idx + 1 + self._trail_right = self._trail_right + self._trial_idx = int( + self._trail_left + + (self._trail_right - self._trail_left) / 2 + ) diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py index 98d40bdab50dc..10da5896f62b6 100644 --- a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py @@ -96,7 +96,6 @@ def get_metric(results): def parse_results(results): - print("results:", results) if results['Throughtput'] > 0: return "Throughtput: {} step / s.".format(results['Throughtput']) et = results.get("ErrorType", None) diff --git a/python/paddle/distributed/auto_parallel/tuner/profiler.py b/python/paddle/distributed/auto_parallel/tuner/profiler.py index 9c0ac057acfb3..65ba8ca063d65 100644 --- a/python/paddle/distributed/auto_parallel/tuner/profiler.py +++ b/python/paddle/distributed/auto_parallel/tuner/profiler.py @@ -221,7 +221,6 @@ def profiler(args): with open(args.ctx_filename, 'rb') as f: profile_ctx = pickle.load(f, encoding='latin1') - print(profile_ctx) init_comm(profile_ctx) main_program, startup_program, loss_var = load_programs(profile_ctx) @@ -232,13 +231,12 @@ def profiler(args): exe = get_executor() - exe.run(startup_program) - - # profile main - duration = 0 - eval_step = 0 - data_loader._inner_dataloader.start() try: + exe.run(startup_program) + # profile main + duration = 0 + eval_step = 0 + data_loader._inner_dataloader.start() while eval_step < args.profile_end_step: start_time = time.time() diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index d61d76f28ef88..1069d7b269230 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1927,7 +1927,6 @@ def _is_recompute_op(op): def get_checkpoints_from_program(program): - pass ops = program.global_block().ops if not any([_is_recompute_op(op) for op in ops]): diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py index b77d42653abdb..35bf1a323d15c 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py @@ -28,12 +28,9 @@ GPTPretrainingCriterion, ) -sequence_len = 512 -vocab_size = 1000 - class FakeDataset(paddle.io.Dataset): - def __init__(self, num_samples): + def __init__(self, num_samples, vocab_size=1000, sequence_len=512): self.num_samples = num_samples self.sequence_len = sequence_len self.vocab_size = vocab_size @@ -57,7 +54,7 @@ def __len__(self): return self.num_samples -def create_data_holder(batch_size): +def create_data_holder(batch_size, vocab_size=1000, sequence_len=512): tokens = paddle.static.InputSpec( name="tokens", shape=[batch_size, sequence_len], dtype='int64' ) diff --git 
a/python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py index 42caacda6268a..ce3dba22c14b7 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py @@ -38,11 +38,11 @@ def generate_model(use_new_recompute, recompute_granularity): modeling._global_process_mesh = auto.ProcessMesh(mesh=[0], dim_names=["x"]) gpt = GPTModel( - vocab_size=1000, - hidden_size=64, - num_hidden_layers=16, - num_attention_heads=8, - intermediate_size=256, + vocab_size=50304, + hidden_size=2048, + num_hidden_layers=48, + num_attention_heads=16, + intermediate_size=2048 * 4, hidden_act="gelu", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, @@ -57,7 +57,7 @@ def generate_model(use_new_recompute, recompute_granularity): recompute_granularity=recompute_granularity, ) model = GPTForPretraining( - gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02 + gpt, vocab_size=50304, hidden_size=2048, initializer_range=0.02 ) criterion = GPTPretrainingCriterion() return model, criterion @@ -91,10 +91,14 @@ class TestRecomputePassWithRecomputeAPI(unittest.TestCase): def setUp(self): self.rtol = 1e-6 self.atol = 1e-8 - self.batch_size = 1 - self.batch_num = 2 + self.batch_size = 8 + self.batch_num = 200 self.clip_norm = 0.2 - self.dataset = FakeDataset(self.batch_size * self.batch_num) + self.dataset = FakeDataset( + self.batch_size * self.batch_num, + vocab_size=50304, + sequence_len=1024, + ) def init(self, engine): paddle.seed(2022) @@ -163,10 +167,9 @@ def test_recompute_pass(self): # mp2 recompute full rc3_engine = self.get_engine(True, True, "full") - history = rc3_engine._tune(self.dataset, 3, batch_size=self.batch_size) - print("***" * 30) - print(rc3_engine.main_program) - rc3_losses = np.array(history.history["loss"]) + rc3_engine._tune(self.dataset, 3, batch_size=self.batch_size) + # print("***" * 30) + # rc3_losses = np.array(history.history["loss"]) # self.check_results(mp_losses, rc3_losses) # rc0_vars = self.recompute_vars(mp_engine.main_program) From 8c72bc1e5e213df73d2ec133c4f0b0e570517351 Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Fri, 2 Dec 2022 10:32:22 +0800 Subject: [PATCH 06/24] tiny fix --- python/paddle/distributed/passes/auto_parallel_recompute.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index 0ff3b8d3b8140..c931525614a18 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -262,13 +262,14 @@ def __init__(self): self.set_attr("loss", None) self.set_attr("dist_context", None) self.set_attr("no_grad_set", None) + self.set_attr("no_recompute_segments", []) def _check_self(self): if self.get_attr("dist_context") is None: return False if self.get_attr("loss") is None: return False - if self.get_attr("checkpoints") is None: + if not self.get_attr("checkpoints"): return False return True From 5f8177382e6001ef911329ce0b856c1752021083 Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Fri, 2 Dec 2022 20:00:24 +0800 Subject: [PATCH 07/24] fix clear process_group --- .../auto_parallel/tuner/algorithms.py | 6 +- .../auto_parallel/tuner/optimization_tuner.py | 6 + .../passes/auto_parallel_recompute.py | 45 ++++---- 
.../unittests/auto_parallel/CMakeLists.txt | 2 + .../auto_parallel/test_selective_recompute.py | 91 +++++++-------- .../auto_parallel/test_tuning_recompute.py | 104 ++++++++++++++++++ 6 files changed, 174 insertions(+), 80 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py diff --git a/python/paddle/distributed/auto_parallel/tuner/algorithms.py b/python/paddle/distributed/auto_parallel/tuner/algorithms.py index 940e0cff97ecb..b5e6c6c0aa85e 100644 --- a/python/paddle/distributed/auto_parallel/tuner/algorithms.py +++ b/python/paddle/distributed/auto_parallel/tuner/algorithms.py @@ -174,9 +174,7 @@ def collect_model_info(self, main_prog, startup_prog): ) rc_state.build_stats() checkpoints = rc_state.sort_checkpoints(checkpoints) - segments = rc_state.get_recompute_segments( - checkpoints, is_logging=False - ) + segments = rc_state.get_recompute_segments(checkpoints) self._total_num_trial = len(segments) self._tuning_segments = list(range(len(segments))) @@ -231,7 +229,7 @@ def update(self, results): else: self._trial_idx = self._total_num_trial self._logger.info( - "Recompute is unnecessary for this model size, which will reduce the flops." + "Recompute is unnecessary for this model size, which will reduce the Throughtput." ) else: if self._trail_left >= self._trail_right: diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py index 10da5896f62b6..ded9d9f4d509e 100644 --- a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py @@ -33,6 +33,7 @@ from paddle.distributed.auto_parallel.process_group import ( clear_all_process_groups, get_all_process_groups, + new_process_group, ) from paddle.distributed.auto_parallel.reshard import Resharder from paddle.distributed.auto_parallel.utils import ( @@ -109,7 +110,12 @@ def parse_results(results): # all env need to be start a new pass are member of dist context def _copy_context(ref_dist_context): + # clear all process groups and recover the world process group clear_all_process_groups() + ranks = [] + for process_mesh in ref_dist_context._process_meshes: + ranks.extend(process_mesh.processes) + new_process_group(list(set(ranks))) new_dist_context = DistributedContext() new_dist_context._serial_main_program = ( diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index c931525614a18..adc202aff91ff 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -44,6 +44,10 @@ def __init__(self, block, ops): # {varname: {as_input_ops: op_idx, as_output_ops: op_idx}} self.var_op_deps = {} + @property + def ops(self): + return self._ops + def build_stats(self): for i, op in enumerate(self._ops): for name in op.desc.input_arg_names(): @@ -62,9 +66,7 @@ def build_stats(self): self.var_op_deps[name]["var_as_input_ops"] = [] self.var_op_deps[name]["var_as_output_ops"] = [i] - def get_recompute_segments( - self, checkpoints, no_recompute_segments=[], is_logging=True - ): + def get_recompute_segments(self, checkpoints, no_recompute_segments=[]): """get recompute segments from checkpoints""" segments = [] start_idx = -1 @@ -104,26 +106,6 @@ def get_recompute_segments( ) segments.pop(i) - if not is_logging: - return segments - - for i, (idx1, idx2) in 
enumerate(segments): - logging.info("recompute segment[{}]".format(i)) - logging.info( - "segment start op: [{}]: [{}] [{}]".format( - self._ops[idx1].desc.type(), - self._ops[idx1].desc.input_arg_names(), - self._ops[idx1].desc.output_arg_names(), - ) - ) - logging.info( - "segment end op: [{}]: [{}] [{}]".format( - self._ops[idx2 - 1].desc.type(), - self._ops[idx2 - 1].desc.input_arg_names(), - self._ops[idx2 - 1].desc.output_arg_names(), - ) - ) - return segments def modify_forward_desc_for_recompute(self, dist_context): @@ -300,6 +282,23 @@ def _apply_single_impl(self, main_program, startup_program, context): if segments == []: return + for i, (idx1, idx2) in enumerate(segments): + logging.info("recompute segment[{}]".format(i)) + logging.info( + "segment start op: [{}]: [{}] [{}]".format( + rc_state.ops[idx1].desc.type(), + rc_state.ops[idx1].desc.input_arg_names(), + rc_state.ops[idx1].desc.output_arg_names(), + ) + ) + logging.info( + "segment end op: [{}]: [{}] [{}]".format( + rc_state.ops[idx2 - 1].desc.type(), + rc_state.ops[idx2 - 1].desc.input_arg_names(), + rc_state.ops[idx2 - 1].desc.output_arg_names(), + ) + ) + # 3. get vars that should be hold in memory vars_should_be_hold = [] for segment in segments: diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 18fad917b6839..147c64653b427 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -74,6 +74,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU) set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120) py_test_modules(test_selective_recompute MODULES test_selective_recompute) set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50) + py_test_modules(test_tuning_recompute MODULES test_tuning_recompute) + set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 50) py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py index ce3dba22c14b7..64563314ac2ca 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_selective_recompute.py @@ -38,11 +38,11 @@ def generate_model(use_new_recompute, recompute_granularity): modeling._global_process_mesh = auto.ProcessMesh(mesh=[0], dim_names=["x"]) gpt = GPTModel( - vocab_size=50304, - hidden_size=2048, - num_hidden_layers=48, - num_attention_heads=16, - intermediate_size=2048 * 4, + vocab_size=1000, + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=8, + intermediate_size=256, hidden_act="gelu", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, @@ -57,7 +57,7 @@ def generate_model(use_new_recompute, recompute_granularity): recompute_granularity=recompute_granularity, ) model = GPTForPretraining( - gpt, vocab_size=50304, hidden_size=2048, initializer_range=0.02 + gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02 ) criterion = GPTPretrainingCriterion() return model, criterion @@ -70,15 +70,7 @@ def apply_pass(use_recompute=False, no_recompute_segments=[]): if use_recompute: recompute = strategy.recompute recompute.enable = True - recompute.enable_tuning = True recompute.no_recompute_segments = no_recompute_segments - - tuning = strategy.tuning - 
tuning.enable = True - tuning.profile_start_step = 1 - tuning.profile_end_step = 5 - tuning.run_after_tuning = True - tuning.verbose = True return strategy @@ -91,14 +83,10 @@ class TestRecomputePassWithRecomputeAPI(unittest.TestCase): def setUp(self): self.rtol = 1e-6 self.atol = 1e-8 - self.batch_size = 8 - self.batch_num = 200 + self.batch_size = 1 + self.batch_num = 2 self.clip_norm = 0.2 - self.dataset = FakeDataset( - self.batch_size * self.batch_num, - vocab_size=50304, - sequence_len=1024, - ) + self.dataset = FakeDataset(self.batch_size * self.batch_num) def init(self, engine): paddle.seed(2022) @@ -141,50 +129,47 @@ def recompute_vars(self, program): def test_recompute_pass(self): # mp2 training - # mp_engine = self.get_engine() - # history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) - # mp_losses = np.array(history.history["loss"]) + mp_engine = self.get_engine() + history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) + mp_losses = np.array(history.history["loss"]) # mp2 recompute with old api - # rc4_engine = self.get_engine(True, False) - # history = rc4_engine.fit(self.dataset, 3, batch_size=self.batch_size) - # print("***"*30) - # print(rc4_engine.main_program) - # rc4_losses = np.array(history.history["loss"]) - # self.check_results(mp_losses, rc4_losses) + rc4_engine = self.get_engine(True, False) + history = rc4_engine.fit(self.dataset, 3, batch_size=self.batch_size) + rc4_losses = np.array(history.history["loss"]) + self.check_results(mp_losses, rc4_losses) # mp2 recompute core_attn - # rc1_engine = self.get_engine(True, True, "core_attn", [0]) - # history = rc1_engine.fit(self.dataset, 3, batch_size=self.batch_size) - # rc1_losses = np.array(history.history["loss"]) - # self.check_results(mp_losses, rc1_losses) + rc1_engine = self.get_engine(True, True, "core_attn", [0]) + history = rc1_engine.fit(self.dataset, 3, batch_size=self.batch_size) + rc1_losses = np.array(history.history["loss"]) + self.check_results(mp_losses, rc1_losses) - # # mp2 recompute full_attn - # rc2_engine = self.get_engine(True, True, "full_attn") - # history = rc2_engine.fit(self.dataset, 3, batch_size=self.batch_size) - # rc2_losses = np.array(history.history["loss"]) - # self.check_results(mp_losses, rc2_losses) + # mp2 recompute full_attn + rc2_engine = self.get_engine(True, True, "full_attn") + history = rc2_engine.fit(self.dataset, 3, batch_size=self.batch_size) + rc2_losses = np.array(history.history["loss"]) + self.check_results(mp_losses, rc2_losses) # mp2 recompute full rc3_engine = self.get_engine(True, True, "full") - rc3_engine._tune(self.dataset, 3, batch_size=self.batch_size) - # print("***" * 30) - # rc3_losses = np.array(history.history["loss"]) - # self.check_results(mp_losses, rc3_losses) + history = rc3_engine.fit(self.dataset, 3, batch_size=self.batch_size) + rc3_losses = np.array(history.history["loss"]) + self.check_results(mp_losses, rc3_losses) - # rc0_vars = self.recompute_vars(mp_engine.main_program) - # rc1_vars = self.recompute_vars(rc1_engine.main_program) - # rc2_vars = self.recompute_vars(rc2_engine.main_program) - # rc3_vars = self.recompute_vars(rc3_engine.main_program) + rc0_vars = self.recompute_vars(mp_engine.main_program) + rc1_vars = self.recompute_vars(rc1_engine.main_program) + rc2_vars = self.recompute_vars(rc2_engine.main_program) + rc3_vars = self.recompute_vars(rc3_engine.main_program) - # assert rc0_vars == [] - # assert len(rc1_vars) < len(rc2_vars) and len(rc2_vars) < len(rc3_vars) + assert rc0_vars == [] + assert 
len(rc1_vars) < len(rc2_vars) and len(rc2_vars) < len(rc3_vars) - # def test_recompute_pass_error(self): + def test_recompute_pass_error(self): - # with self.assertRaises(AssertionError): - # rc_engine = self.get_engine(True, True, "full", [2]) - # history = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size) + with self.assertRaises(AssertionError): + rc_engine = self.get_engine(True, True, "full", [2]) + history = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py new file mode 100644 index 0000000000000..cf0f379c31d62 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py @@ -0,0 +1,104 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +from get_gpt_model import FakeDataset + +import paddle +from paddle.distributed.fleet import auto + +sys.path.append("..") +import auto_parallel_gpt_model as modeling +from auto_parallel_gpt_model import ( + GPTForPretraining, + GPTModel, + GPTPretrainingCriterion, +) + + +def generate_model(): + modeling.init_global() + modeling._global_parallel_strategy = "serial" + + gpt = GPTModel( + vocab_size=50304, + hidden_size=1024, + num_hidden_layers=2, + num_attention_heads=16, + intermediate_size=1024 * 4, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=1024, + type_vocab_size=1, + initializer_range=0.02, + pad_token_id=0, + eos_token_id=7, + bos_token_id=0, + eol_token_id=3, + use_new_recompute=True, + recompute_granularity="full", + ) + model = GPTForPretraining( + gpt, vocab_size=50304, hidden_size=1024, initializer_range=0.02 + ) + criterion = GPTPretrainingCriterion() + return model, criterion + + +def apply_pass(): + strategy = auto.Strategy() + strategy.auto_mode = "semi" + + recompute = strategy.recompute + recompute.enable = True + recompute.enable_tuning = True + + tuning = strategy.tuning + tuning.enable = True + tuning.profile_start_step = 1 + tuning.profile_end_step = 5 + tuning.run_after_tuning = True + tuning.verbose = True + return strategy + + +class TestRecomputePassTuning(unittest.TestCase): + def setUp(self): + + self.batch_size = 8 + self.batch_num = 200 + self.dataset = FakeDataset( + self.batch_size * self.batch_num, + vocab_size=50304, + sequence_len=1024, + ) + + def test_recompute_pass(self): + + strategy = apply_pass() + clip = paddle.nn.ClipGradByGlobalNorm(0.2) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model() + + engine = auto.Engine(model, loss, opt, strategy=strategy) + engine._tune(self.dataset, 3, batch_size=self.batch_size) + + assert not engine._dist_contexts['train'].strategy.recompute.enable + + +if __name__ == "__main__": + unittest.main() From 
0fbb7ff3b6a0f4e7a03d386a04a2d766377d5a3e Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Fri, 2 Dec 2022 20:01:59 +0800 Subject: [PATCH 08/24] remove comment --- .../distributed/auto_parallel/dist_op.py | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py index 4ce67ef16a266..484bf45111dc9 100644 --- a/python/paddle/distributed/auto_parallel/dist_op.py +++ b/python/paddle/distributed/auto_parallel/dist_op.py @@ -377,27 +377,3 @@ def __call__(self, *args, **kwargs): default_dist_ctx.add_dist_op_for_program(dist_op) return output - - -# class RecomputeOperatorHelper: -# def __init__(self, op): -# self._op = op - -# def __call__(self, *args, **kwargs): -# default_prog = paddle.fluid.default_main_program() -# cur_block = default_prog.current_block() -# op_size = len(cur_block.ops) -# output = self._op(*args, **kwargs) -# new_op_size = len(cur_block.ops) - -# from .dist_context import get_default_distributed_context - -# default_dist_ctx = get_default_distributed_context() -# for idx in range(op_size, new_op_size - 1): -# op = cur_block.ops[idx] -# dist_op = DistributedOperator(op) -# dist_op.dist_attr.is_recompute = True - -# default_dist_ctx.add_dist_op_for_program(dist_op) - -# return output From fe13bb3fc1ba61f1c57acae2c45743c26ae3d2fb Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Fri, 2 Dec 2022 20:50:05 +0800 Subject: [PATCH 09/24] update segment print --- .../passes/auto_parallel_recompute.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index adc202aff91ff..ee6505a9f1ce0 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -44,10 +44,6 @@ def __init__(self, block, ops): # {varname: {as_input_ops: op_idx, as_output_ops: op_idx}} self.var_op_deps = {} - @property - def ops(self): - return self._ops - def build_stats(self): for i, op in enumerate(self._ops): for name in op.desc.input_arg_names(): @@ -286,16 +282,16 @@ def _apply_single_impl(self, main_program, startup_program, context): logging.info("recompute segment[{}]".format(i)) logging.info( "segment start op: [{}]: [{}] [{}]".format( - rc_state.ops[idx1].desc.type(), - rc_state.ops[idx1].desc.input_arg_names(), - rc_state.ops[idx1].desc.output_arg_names(), + rc_state._ops[idx1].desc.type(), + rc_state._ops[idx1].desc.input_arg_names(), + rc_state._ops[idx1].desc.output_arg_names(), ) ) logging.info( "segment end op: [{}]: [{}] [{}]".format( - rc_state.ops[idx2 - 1].desc.type(), - rc_state.ops[idx2 - 1].desc.input_arg_names(), - rc_state.ops[idx2 - 1].desc.output_arg_names(), + rc_state._ops[idx2 - 1].desc.type(), + rc_state._ops[idx2 - 1].desc.input_arg_names(), + rc_state._ops[idx2 - 1].desc.output_arg_names(), ) ) From 95717d0e60145af406961581bd4b113624362f14 Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Tue, 6 Dec 2022 11:44:16 +0800 Subject: [PATCH 10/24] fix import OpRole --- python/paddle/distributed/passes/auto_parallel_fp16.py | 3 +-- python/paddle/distributed/passes/auto_parallel_grad_clip.py | 3 +-- .../paddle/distributed/passes/auto_parallel_gradient_merge.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 
7aed31b01ec2b..f3c2db8ed1f42 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -22,13 +22,12 @@ get_world_process_group, ) from paddle.distributed.auto_parallel.utils import ( - OP_ROLE_KEY, - OpRole, is_backward_op, is_forward_op, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, ) +from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.fluid import unique_name from paddle.fluid.contrib.mixed_precision.fp16_utils import ( AutoMixedPrecisionLists, diff --git a/python/paddle/distributed/passes/auto_parallel_grad_clip.py b/python/paddle/distributed/passes/auto_parallel_grad_clip.py index af5259680e4a5..7258eca661d63 100644 --- a/python/paddle/distributed/passes/auto_parallel_grad_clip.py +++ b/python/paddle/distributed/passes/auto_parallel_grad_clip.py @@ -17,6 +17,7 @@ import numpy as np import paddle +from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from ..auto_parallel.dist_attribute import ( OperatorDistributedAttribute, @@ -25,8 +26,6 @@ from ..auto_parallel.process_group import get_world_process_group from ..auto_parallel.reshard import Resharder from ..auto_parallel.utils import ( - OP_ROLE_KEY, - OpRole, _get_comm_group, insert_dependencies_for_two_vars, is_gradient_clip_op, diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py index 8ac3492c2b14d..01d19722d07a6 100644 --- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py +++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py @@ -19,12 +19,11 @@ get_world_process_group, ) from paddle.distributed.auto_parallel.utils import ( - OP_ROLE_KEY, - OpRole, is_optimize_op, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, ) +from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.fluid import layers from paddle.fluid.framework import device_guard from paddle.framework import core From ce325c2f171cd1f9aedfb0ca1fe7ac1fed8063eb Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Tue, 6 Dec 2022 20:00:19 +0800 Subject: [PATCH 11/24] adapt amp pass and grad_clip pass for opt_tuner --- .../auto_parallel/tuner/optimization_tuner.py | 32 +++++++++++++++---- .../auto_parallel/tuner/profiler.py | 11 ++++--- .../passes/auto_parallel_recompute.py | 2 +- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py index ded9d9f4d509e..a5c7be63b7b41 100644 --- a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py @@ -41,7 +41,7 @@ set_grad_var_shape, ) from paddle.distributed.passes import PassContext, new_pass -from paddle.fluid import program_guard +from paddle.fluid import program_guard, unique_name from paddle.fluid.backward import append_backward from ..utils import get_logger @@ -304,7 +304,6 @@ def _apply_optimization(self, trial): config = copy.deepcopy(new_strategy.amp.to_dict()) config["dist_context"] = dist_context config["params_grads"] = dist_context._params_grads - # TODO AMP Pass should not use loss var config["loss"] = dist_context.serial_loss config["input_data"] = ( @@ -317,13 +316,13 @@ def _apply_optimization(self, trial): auto_parallel_fp16_pass.apply( 
[main_program], [startup_program], pass_context ) - dist_context.serial_loss = auto_parallel_fp16_pass.get_loss() + dist_context._serial_loss = auto_parallel_fp16_pass.get_loss() else: auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) auto_parallel_amp_pass.apply( [main_program], [startup_program], pass_context ) - dist_context.serial_loss = auto_parallel_amp_pass.get_loss() + dist_context._serial_loss = auto_parallel_amp_pass.get_loss() if new_strategy.recompute.enable: config = copy.deepcopy(new_strategy.recompute.to_dict()) @@ -349,10 +348,11 @@ def _apply_optimization(self, trial): # Generate optimizer # FIXME should be remove from apply pass after pass support optimizers + optimizer = copy.deepcopy(dist_context.serial_optimizer) + dist_context._serial_optimizer = optimizer with program_guard(dist_main_prog, dist_startup_prog): - optimizer_ops = dist_context.serial_optimizer.apply_gradients( - dist_params_grads - ) + with unique_name.guard("opt_"): + optimizer_ops = optimizer.apply_gradients(dist_params_grads) completer.complete_update_annotation(dist_main_prog) # Do reshard process @@ -366,6 +366,13 @@ def _apply_optimization(self, trial): ) resharder.reshard() + config = {} + config["dist_context"] = dist_context + config["global_rank"] = self.rank + config["use_sharding"] = new_strategy.sharding.enable + dp_pass = new_pass("auto_parallel_data_parallel_optimization", config) + dp_pass.apply([dist_main_prog], [dist_startup_prog], pass_context) + if new_strategy.sharding.enable: config = copy.deepcopy(new_strategy.sharding.to_dict()) config["dist_context"] = dist_context @@ -377,6 +384,17 @@ def _apply_optimization(self, trial): auto_parallel_sharding_pass.apply( [dist_main_prog], [dist_startup_prog], pass_context ) + dist_params_grads = pass_context.get_attr("params_grads") + + # gradient clip + config = copy.deepcopy(new_strategy.sharding.to_dict()) + config["dist_context"] = dist_context + config["params_grads"] = dist_params_grads + config["rank_id"] = self.rank + auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config) + auto_parallel_clip_pass.apply( + [dist_main_prog], [dist_startup_prog], pass_context + ) if new_strategy.gradient_merge.enable: config = copy.deepcopy(new_strategy.gradient_merge.to_dict()) diff --git a/python/paddle/distributed/auto_parallel/tuner/profiler.py b/python/paddle/distributed/auto_parallel/tuner/profiler.py index 65ba8ca063d65..cdd4a0045c8c9 100644 --- a/python/paddle/distributed/auto_parallel/tuner/profiler.py +++ b/python/paddle/distributed/auto_parallel/tuner/profiler.py @@ -89,7 +89,7 @@ def init_process_groups(group_map, rank): # TODO should instantiate global group first all_process_groups = get_all_process_groups() for process_group in all_process_groups: - if process_group.id == 0 or rank not in process_group.ranks: + if rank not in process_group.ranks: continue print(process_group) process_group.instantiate() @@ -173,10 +173,11 @@ def init_comm(profile_ctx): genv = _get_global_env() genv = dist_env print( - "current process rank: {}, device_id: {}, ip: {}.", - genv.rank, - genv.device_id, - genv.current_endpoint, + "current process rank: {}, device_id: {}, ip: {}.".format( + genv.rank, + genv.device_id, + genv.current_endpoint, + ) ) # init nccl comm diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index ee6505a9f1ce0..416c9cc639345 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ 
b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -86,7 +86,7 @@ def get_recompute_segments(self, checkpoints, no_recompute_segments=[]): ) segments.append([min_idx, max_idx + 1]) else: - logging.info( + logging.debug( "Could not recompute op range [{}] - [{}] ".format( min_idx, max_idx + 1 ) From a3e128f369e611550d8c9bb5ccab5f78d27cca7a Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Tue, 6 Dec 2022 21:17:59 +0800 Subject: [PATCH 12/24] update tuning config --- .../paddle/distributed/auto_parallel/constants.py | 4 +--- .../distributed/auto_parallel/tuner/config.py | 8 ++++---- .../auto_parallel/tuner/optimization_tuner.py | 6 +++--- python/paddle/distributed/auto_parallel/utils.py | 13 +++++++++---- .../auto_parallel/optimization_tuner_api.py | 2 +- 5 files changed, 18 insertions(+), 15 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index da7750e4114c7..ce72304dc75cd 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -113,12 +113,10 @@ def set_field_default_config(category, field, default_value): # ######################################### TUNING = "tuning" set_field_default_config(TUNING, "enable", False) -set_field_default_config(TUNING, "batch_size", 1) -set_field_default_config(TUNING, "dataset", None) set_field_default_config(TUNING, "profile_start_step", 1) set_field_default_config(TUNING, "profile_end_step", 1) set_field_default_config(TUNING, "run_after_tuning", True) -set_field_default_config(TUNING, "verbose", True) +set_field_default_config(TUNING, "debug", False) ######################################### # dataset configuration diff --git a/python/paddle/distributed/auto_parallel/tuner/config.py b/python/paddle/distributed/auto_parallel/tuner/config.py index 6196382f91010..78f94b87b360b 100644 --- a/python/paddle/distributed/auto_parallel/tuner/config.py +++ b/python/paddle/distributed/auto_parallel/tuner/config.py @@ -45,7 +45,7 @@ def __init__(self, strategy): self._project_dir = None self._max_num_trial = None self._early_stop = None - self._verbose = None + self._debug = None self._initialize() @@ -78,8 +78,8 @@ def early_stop(self): return self._early_stop @property - def verbose(self): - return self._verbose + def debug(self): + return self._debug @property def dist_strategy(self): @@ -94,7 +94,7 @@ def _initialize(self): self._profile_end_step = tuning_strategy.get("profile_end_step", 30) self._max_num_trial = tuning_strategy.get("max_num_trial", 50) self._early_stop = tuning_strategy.get("early_stop", None) - self._verbose = tuning_strategy.get("verbose", False) + self._debug = tuning_strategy.get("debug", False) project_dir = tuning_strategy.get("project_dir", None) if not project_dir: diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py index a5c7be63b7b41..82348002133c0 100644 --- a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py @@ -269,7 +269,7 @@ def _build_programs_without_optimization(self): ) self._baseline_dist_context._params_grads = params_grads - if self._config.verbose: + if self._config.debug: baseline_dir = os.path.join(self.project_dir, "baseline") if not os.path.exists(baseline_dir): pathlib.Path(baseline_dir).mkdir(parents=True, exist_ok=True) @@ -511,7 +511,7 @@ def 
_profile_trial(self, trial): with open(ctx_path, 'wb') as f: pickle.dump(profile_ctx, f, protocol=4) - if self._config.verbose: + if self._config.debug: debug_program(trial.main_program, trial_dir, "main_program") debug_program(trial.startup_program, trial_dir, "startup_program") @@ -604,7 +604,7 @@ def clear(self): Clear the temporary file generated in tuning procedure. """ # TODO clear up zombie process created by tuning - if not self._config.verbose: + if not self._config.debug: for trial in self._finished_trials: trial_dir = self._get_trial_dir(trial) shutil.rmtree(trial_dir, ignore_errors=True) diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 1069d7b269230..5021390d05ad5 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1950,10 +1950,15 @@ def set_recompute_ckpts(model, strategy, program): # TODO support more PaddleNLP/CV models here # extract ckpts by specific model if isinstance(model, paddle.nn.Layer): - if hasattr(model, "gpt") and model.__class__.__name__ in [ - 'GPTForPretraining', - 'GPTForPretrainingAuto', - ]: + if ( + hasattr(model, "gpt") + and model.__class__.__name__ + in [ + 'GPTForPretraining', + 'GPTForPretrainingAuto', + ] + and hasattr(model.gpt, "checkpoints") + ): exact_ckpts = model.gpt.checkpoints else: exact_ckpts = recompute.checkpoints diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py index 10005008cdbe5..dfb554ac722d1 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py @@ -98,7 +98,7 @@ def train(fetch): tuning.profile_start_step = 1 tuning.profile_end_step = 5 tuning.run_after_tuning = True - tuning.verbose = True + tuning.debug = True dataset = MyDataset(batch_num * batch_size) engine = auto.Engine( From fb73b4d4527441605ef54a499c4bfa019978bb47 Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Wed, 7 Dec 2022 10:23:56 +0800 Subject: [PATCH 13/24] fix import --- python/paddle/distributed/auto_parallel/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 5021390d05ad5..fd98275c22f0e 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -22,9 +22,9 @@ import numpy as np import paddle -import paddle.fluid.core as core from paddle.fluid.framework import Variable from paddle.fluid.io import is_belong_to_optimizer, is_parameter +from paddle.framework import core from .dist_attribute import ( OperatorDistributedAttribute, @@ -32,8 +32,8 @@ ) from .process_group import get_all_process_groups -OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() OpRole = core.op_proto_and_checker_maker.OpRole +OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() __no_shape_var_type__ = [ core.VarDesc.VarType.READER, From f5f76d250bad2ea9499770036d12943b5d055ca3 Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Thu, 8 Dec 2022 14:34:50 +0800 Subject: [PATCH 14/24] annotate recompute info on ops and upgrade recompute pass --- .../distributed/auto_parallel/engine.py | 4 +- .../distributed/auto_parallel/interface.py | 12 +- .../auto_parallel/tuner/algorithms.py | 18 +- 
.../paddle/distributed/auto_parallel/utils.py | 73 ++++-- .../passes/auto_parallel_recompute.py | 220 +++++++++--------- 5 files changed, 185 insertions(+), 142 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 680db45e5e7dd..fc49ea3c1aaaf 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -609,8 +609,8 @@ def _build(self, mode): if mode != "train": serial_main_prog = serial_main_prog.clone(for_test=True) - auto_utils.set_recompute_ckpts( - self._model, self._strategy, serial_main_prog + auto_utils.set_recompute_scope( + self._model, self._losses, self._strategy, serial_main_prog ) self._dist_contexts[mode] = DistributedContext( serial_main_prog, diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py index d7d5c2ccb2b19..882b63b39395b 100644 --- a/python/paddle/distributed/auto_parallel/interface.py +++ b/python/paddle/distributed/auto_parallel/interface.py @@ -195,7 +195,13 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None): return op +_g_recompute_idx = -1 + + def recompute(op): + global _g_recompute_idx + _g_recompute_idx += 1 + class RecomputeOperator: def __init__(self, op): self._op = op @@ -207,9 +213,11 @@ def __call__(self, *args, **kwargs): output = self._op(*args, **kwargs) new_op_size = len(cur_block.ops) - for idx in range(op_size, new_op_size - 1): + for idx in range(op_size, new_op_size): op = cur_block.ops[idx] - op._set_attr('op_namescope', "/auto_parallel/rc") + op._set_attr( + 'op_namescope', "/auto_parallel/rc_" + str(_g_recompute_idx) + ) return output diff --git a/python/paddle/distributed/auto_parallel/tuner/algorithms.py b/python/paddle/distributed/auto_parallel/tuner/algorithms.py index b5e6c6c0aa85e..1fcbb88ab05eb 100644 --- a/python/paddle/distributed/auto_parallel/tuner/algorithms.py +++ b/python/paddle/distributed/auto_parallel/tuner/algorithms.py @@ -16,8 +16,7 @@ import logging from abc import ABC, abstractmethod -from ...passes.auto_parallel_recompute import RecomputeState -from ..utils import get_logger +from ..utils import get_logger, is_recompute_op from .trial import OptimizationTunerTrial as Trial from .trial import TrialStatus @@ -167,14 +166,14 @@ def __init__(self, config): self._changed_configs = ["recompute"] def collect_model_info(self, main_prog, startup_prog): - checkpoints = self._config.recompute.get("checkpoints", []) + segments = [] + for op in main_prog.global_block().ops: + if not is_recompute_op(op): + continue - rc_state = RecomputeState( - main_prog.global_block(), main_prog.global_block().ops - ) - rc_state.build_stats() - checkpoints = rc_state.sort_checkpoints(checkpoints) - segments = rc_state.get_recompute_segments(checkpoints) + seg_name = op.attr('op_namescope') + if seg_name not in segments: + segments.append(seg_name) self._total_num_trial = len(segments) self._tuning_segments = list(range(len(segments))) @@ -197,7 +196,6 @@ def next_trial(self): new_strategy = copy.deepcopy(self._config.dist_strategy) recompute = new_strategy.recompute recompute.enable = False - recompute.checkpoints = [] name = "trial-recompute-none-segments" return Trial(new_strategy, name, self.changed_configs) elif self._recompute_mode == "part": diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index fd98275c22f0e..fd3e97809c4bf 100644 --- 
a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1920,27 +1920,17 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank): server_socket.close() -def _is_recompute_op(op): +def is_recompute_op(op): return op.has_attr('op_namescope') and "/auto_parallel/rc" in op.attr( 'op_namescope' ) -def get_checkpoints_from_program(program): +def set_recompute_scope(model, losses, strategy, program): + from ..passes.auto_parallel_recompute import RecomputeState - ops = program.global_block().ops - if not any([_is_recompute_op(op) for op in ops]): - return [] - - checkpoints = [] - for idx, op in enumerate(ops): - if not _is_recompute_op(op): - checkpoints.extend(op.output_arg_names) - - return checkpoints - - -def set_recompute_ckpts(model, strategy, program): + if not losses: + return recompute = strategy.recompute if not recompute.enable: @@ -1949,6 +1939,7 @@ def set_recompute_ckpts(model, strategy, program): # NOTE: hack to enable recompute in engine api for GPT-3 # TODO support more PaddleNLP/CV models here # extract ckpts by specific model + ckpts = [] if isinstance(model, paddle.nn.Layer): if ( hasattr(model, "gpt") @@ -1959,16 +1950,54 @@ def set_recompute_ckpts(model, strategy, program): ] and hasattr(model.gpt, "checkpoints") ): - exact_ckpts = model.gpt.checkpoints + ckpts = model.gpt.checkpoints else: - exact_ckpts = recompute.checkpoints + ckpts = recompute.checkpoints else: - exact_ckpts = recompute.checkpoints + ckpts = recompute.checkpoints - # modify strategy - recompute.checkpoints = exact_ckpts[:] or get_checkpoints_from_program( - program - ) + if not ckpts: + return + + block = program.global_block() + rc_state = RecomputeState(block, block.ops) + rc_state.build_stats() + checkpoints = rc_state.sort_checkpoints(ckpts) + + segments = [] + start_idx = -1 + pre_segment_end_idx = -1 + while start_idx + 1 < len(checkpoints): + if start_idx == -1: + ckpt_name = checkpoints[start_idx + 1] + if ckpt_name not in rc_state.var_op_deps: + start_idx += 1 + continue + op_idx_list = rc_state.var_op_deps[ckpt_name]["var_as_output_ops"] + if op_idx_list and max(op_idx_list) > 0: + segments.append([0, max(op_idx_list) + 1]) + else: + flag, min_idx, max_idx = rc_state.is_subgraph( + [checkpoints[start_idx]], [checkpoints[start_idx + 1]] + ) + if flag: + min_idx = rc_state._update_segment_start( + min_idx, pre_segment_end_idx + ) + segments.append([min_idx, max_idx + 1]) + else: + logging.debug( + "Could not recompute op range [{}] - [{}] ".format( + min_idx, max_idx + 1 + ) + ) + start_idx += 1 + + for i, segment in enumerate(segments): + for j in range(segment[0], segment[1]): + block.ops[j]._set_attr( + 'op_namescope', "/auto_parallel/rc_" + str(i) + ) def get_input_split_info(cur_rank, var, dist_context): diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index 416c9cc639345..7fb4a688ccb98 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -14,9 +14,8 @@ import logging -from paddle.fluid import core -from paddle.fluid import framework as framework -from paddle.fluid import unique_name +from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole +from paddle.fluid import core, framework, unique_name from paddle.fluid.backward import ( ProgramStats, _append_grad_suffix_, @@ -29,6 +28,8 @@ from ..auto_parallel.utils import ( get_loss_op, 
insert_dependencies_for_two_ops, + is_backward_op, + is_recompute_op, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_dist_op_desc_original_id, set_var_dist_attr, @@ -39,14 +40,24 @@ class RecomputeState(ProgramStats): def __init__(self, block, ops): super().__init__(block=block, ops=ops) - self._block = block - self._ops = ops - # {varname: {as_input_ops: op_idx, as_output_ops: op_idx}} - self.var_op_deps = {} - - def build_stats(self): - for i, op in enumerate(self._ops): - for name in op.desc.input_arg_names(): + self.seg_op_deps = {} + self._checkpoints = [] + self._reserved_vars = [] + + @property + def checkpoints(self): + return self._checkpoints + + @property + def reserved_vars(self): + return self._reserved_vars + + def build_states(self): + for i, op in enumerate(self.ops): + if is_backward_op(op): + break + + for name in op.input_arg_names: if name in self.var_op_deps: self.var_op_deps[name]["var_as_input_ops"].extend([i]) else: @@ -54,7 +65,7 @@ def build_stats(self): self.var_op_deps[name]["var_as_input_ops"] = [i] self.var_op_deps[name]["var_as_output_ops"] = [] - for name in op.desc.output_arg_names(): + for name in op.output_arg_names: if name in self.var_op_deps: self.var_op_deps[name]["var_as_output_ops"].extend([i]) else: @@ -62,60 +73,31 @@ def build_stats(self): self.var_op_deps[name]["var_as_input_ops"] = [] self.var_op_deps[name]["var_as_output_ops"] = [i] - def get_recompute_segments(self, checkpoints, no_recompute_segments=[]): - """get recompute segments from checkpoints""" - segments = [] - start_idx = -1 - pre_segment_end_idx = -1 - while start_idx + 1 < len(checkpoints): - if start_idx == -1: - ckpt_name = checkpoints[start_idx + 1] - if ckpt_name not in self.var_op_deps: - start_idx += 1 - continue - op_idx_list = self.var_op_deps[ckpt_name]["var_as_output_ops"] - if op_idx_list and max(op_idx_list) > 0: - segments.append([0, max(op_idx_list) + 1]) - else: - flag, min_idx, max_idx = self.is_subgraph( - [checkpoints[start_idx]], [checkpoints[start_idx + 1]] - ) - if flag: - min_idx = self._update_segment_start( - min_idx, pre_segment_end_idx - ) - segments.append([min_idx, max_idx + 1]) - else: - logging.debug( - "Could not recompute op range [{}] - [{}] ".format( - min_idx, max_idx + 1 - ) - ) - start_idx += 1 - - if no_recompute_segments: - for i in reversed(sorted(no_recompute_segments)): - assert i < len( - segments - ), "the no_recompute_segments idx [{}] should be lower the number of segment [{}]".format( - i, len(segments) - ) - segments.pop(i) + if not is_recompute_op(op): + self._checkpoints.extend(op.output_arg_names) + continue - return segments + seg_name = op.attr('op_namescope') + if seg_name not in self.seg_op_deps: + self.seg_op_deps[seg_name] = [i] + else: + assert ( + self.seg_op_deps[seg_name][-1] + 1 == i + ), "The recompute segment's ops should be continuous" + self.seg_op_deps[seg_name].extend([i]) def modify_forward_desc_for_recompute(self, dist_context): """ If program's foward part has 'dropout' op, this function will insert a seed op before it to guarantee that two dropout op have the same outputs. 
""" - op_types = [op.desc.type() for op in self._ops] + op_types = [op.type for op in self.ops] if "dropout" not in op_types: return op_idx = 0 - while op_idx < len(self._ops): - cur_op = self._ops[op_idx] + while op_idx < len(self.ops): + cur_op = self.ops[op_idx] if "grad" in cur_op.type: break if cur_op.type != "dropout": @@ -124,6 +106,10 @@ def modify_forward_desc_for_recompute(self, dist_context): if cur_op.input("Seed") is not None and len(cur_op.input("Seed")): op_idx += 1 continue + if cur_op.type == "seed": + self._reserved_vars.extend(cur_op.output_arg_names) + op_idx += 1 + continue cur_op_dist_attr = dist_context.get_op_dist_attr_for_program(cur_op) # insert seed op to guarantee that two dropout op have the same outputs @@ -131,6 +117,7 @@ def modify_forward_desc_for_recompute(self, dist_context): var_unique_name = unique_name.generate_with_ignorable_key( ".".join([op_unique_name, 'tmp']) ) + self._reserved_vars.append(var_unique_name) seed_var = self._block.create_var( name=var_unique_name, dtype='int32', @@ -165,7 +152,7 @@ def modify_forward_desc_for_recompute(self, dist_context): ) # modify dropout op's desc - self._ops.insert(op_idx, seed_op) + self.ops.insert(op_idx, seed_op) cur_op.desc.set_input("Seed", [var_unique_name]) cur_op._remove_attr("fix_seed") cur_op._remove_attr("seed") @@ -176,6 +163,24 @@ def modify_forward_desc_for_recompute(self, dist_context): self._block._sync_with_cpp() + def get_recompute_segments(self, no_recompute_segments=[]): + segments = [] + for segment_idx in self.seg_op_deps.values(): + if len(segment_idx) == 1: + continue + segments.append([segment_idx[0], segment_idx[-1] + 1]) + self._checkpoints.extend(self.ops[segment_idx[-1]].output_arg_names) + + for i in reversed(sorted(no_recompute_segments)): + assert i < len( + segments + ), "the no_recompute_segments idx [{}] should be lower the number of segment [{}]".format( + i, len(segments) + ) + segments.pop(i) + + return segments + def _find_op_index(block, cur_op): for idx in range(block.desc.op_size()): @@ -184,7 +189,7 @@ def _find_op_index(block, cur_op): return -1 -def _get_stop_gradients(program, no_grad_set): +def _get_stop_gradients(program, no_grad_set=None): """get no grad var""" if no_grad_set is None: no_grad_set = set() @@ -202,16 +207,15 @@ def _get_stop_gradients(program, no_grad_set): def _add_needed_descs_to_block( - descs, block, main_block, in_memory_vars, dist_context + descs, block, main_block, vars_should_be_hold, dist_context ): """ Get the recomputed ops which will insert the backward part """ if len(descs) == 0: return [] + result_descs = [] - op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() - backward = core.op_proto_and_checker_maker.OpRole.Backward for desc in descs: if isinstance(desc, framework.Operator): desc = desc.desc @@ -221,22 +225,29 @@ def _add_needed_descs_to_block( for name in desc.output_arg_names(): if main_block.has_var(name) and main_block.var(name).persistable: continue - if name not in in_memory_vars: + if name not in vars_should_be_hold: is_needed = True if is_needed: new_op_desc = block.desc.append_op() new_op_desc.copy_from(desc) set_dist_op_desc_original_id(new_op_desc, desc, dist_context) - new_op_desc._set_attr(op_role_attr_name, backward) + new_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward) result_descs.append(new_op_desc) return result_descs +def _find_op_path(main_program, loss, no_grad_set=None): + no_grad_set_name = _get_stop_gradients(main_program, no_grad_set) + op_path = _find_op_path_( + 
main_program.global_block(), [loss], [], no_grad_set_name + ) + return op_path + + @register_pass("auto_parallel_recompute") class RecomputePass(PassBase): def __init__(self): super().__init__() - self.set_attr("checkpoints", None) self.set_attr("loss", None) self.set_attr("dist_context", None) self.set_attr("no_grad_set", None) @@ -247,71 +258,64 @@ def _check_self(self): return False if self.get_attr("loss") is None: return False - if not self.get_attr("checkpoints"): - return False return True def _check_conflict(self, other_pass): return True def _apply_single_impl(self, main_program, startup_program, context): - checkpoints = self.get_attr("checkpoints") - no_recompute_segments = self.get_attr("no_recompute_segments") loss = self.get_attr("loss") no_grad_set = self.get_attr("no_grad_set") + no_recompute_segments = self.get_attr("no_recompute_segments") self._dist_context = self.get_attr("dist_context") # 0. get op_path which is related to loss main_block = main_program.global_block() - no_grad_set_name = _get_stop_gradients(main_program, no_grad_set) - op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name) - + op_path = _find_op_path(main_program, loss, no_grad_set) # 1. build recompute state rc_state = RecomputeState(main_block, op_path) # 2. get the segments to be recomputed rc_state.modify_forward_desc_for_recompute(self._dist_context) - rc_state.build_stats() - checkpoints = rc_state.sort_checkpoints(checkpoints) - segments = rc_state.get_recompute_segments( - checkpoints, no_recompute_segments - ) + rc_state.build_states() + segments = rc_state.get_recompute_segments(no_recompute_segments) if segments == []: return + print("segments:", segments) for i, (idx1, idx2) in enumerate(segments): - logging.info("recompute segment[{}]".format(i)) - logging.info( + print("recompute segment[{}]".format(i)) + print( "segment start op: [{}]: [{}] [{}]".format( - rc_state._ops[idx1].desc.type(), - rc_state._ops[idx1].desc.input_arg_names(), - rc_state._ops[idx1].desc.output_arg_names(), + rc_state.ops[idx1].type, + rc_state.ops[idx1].input_arg_names, + rc_state.ops[idx1].output_arg_names, ) ) - logging.info( + print( "segment end op: [{}]: [{}] [{}]".format( - rc_state._ops[idx2 - 1].desc.type(), - rc_state._ops[idx2 - 1].desc.input_arg_names(), - rc_state._ops[idx2 - 1].desc.output_arg_names(), + rc_state.ops[idx2 - 1].type, + rc_state.ops[idx2 - 1].input_arg_names, + rc_state.ops[idx2 - 1].output_arg_names, ) ) - # 3. get vars that should be hold in memory vars_should_be_hold = [] for segment in segments: vars_should_be_hold.extend( rc_state.get_out_of_subgraph_vars(segment[0], segment[1]) ) - cross_vars = set(vars_should_be_hold) - set(checkpoints) + cross_vars = set(vars_should_be_hold) - set(rc_state.checkpoints) logging.info( "found [{}] vars which cross recompute segment: [{}]," "better checkpoints might be set to reduce those vars".format( len(cross_vars), cross_vars ) ) - vars_should_be_hold.extend(rc_state.get_reserved_vars()) + vars_should_be_hold.extend(rc_state.reserved_vars) vars_should_be_hold.extend(rc_state.get_input_nodes()) - vars_should_be_hold = list(set(vars_should_be_hold)) - vars_in_memory = vars_should_be_hold + checkpoints + vars_should_be_hold = list( + set(vars_should_be_hold) | set(rc_state.checkpoints) + ) # 4. get the fwd ops desc to be recomputed. 
var_name_dict = {} # varname --> varname.subprog_XXX @@ -322,27 +326,31 @@ def _apply_single_impl(self, main_program, startup_program, context): var_suffix = ".subprog_%d" % i for op in fwd_ops: input_and_output_names = [] - input_and_output_names.extend(op.desc.input_arg_names()) - input_and_output_names.extend(op.desc.output_arg_names()) + input_and_output_names.extend(op.input_arg_names) + input_and_output_names.extend(op.output_arg_names) + cur_op_dist_attr = ( self._dist_context.get_op_dist_attr_for_program(op) ) assert cur_op_dist_attr is not None + for name in input_and_output_names: - if main_block.var(name).persistable or name in checkpoints: - continue - if name in vars_should_be_hold: + if ( + main_block.var(name).persistable + or name in vars_should_be_hold + ): continue if name not in var_name_dict: ref_process_mesh = cur_op_dist_attr.process_mesh - if name in op.desc.input_arg_names(): + try: ref_dims_mapping = ( cur_op_dist_attr.get_input_dims_mapping(name) ) - else: + except: ref_dims_mapping = ( cur_op_dist_attr.get_output_dims_mapping(name) ) + # record recomputed var's old_name and new_name (old_name.subprog_XXX) # create new var with new name var_name_dict[name] = name + var_suffix @@ -367,7 +375,7 @@ def _apply_single_impl(self, main_program, startup_program, context): fwd_ops, buffer_block, main_block, - vars_in_memory, + vars_should_be_hold, self._dist_context, ) # rename recomputed ops' input and output var name @@ -395,15 +403,15 @@ def _apply_single_impl(self, main_program, startup_program, context): grad_op._remove_attr("fix_seed") grad_op._remove_attr("seed") - # rename grad op's var_name which is not in 'vars_in_memory' - for key in var_name_dict: - if ( - key - not in grad_op.input_arg_names + grad_op.output_arg_names - ): + input_and_output_names = [] + input_and_output_names.extend(grad_op.input_arg_names) + input_and_output_names.extend(grad_op.output_arg_names) + + for varname in var_name_dict: + if varname not in input_and_output_names: continue self.reset_op_dist_attr(grad_op, var_name_dict) - _rename_arg_([grad_op.desc], key, var_name_dict[key]) + _rename_arg_([grad_op.desc], varname, var_name_dict[varname]) # insert recomputed ops original_id = grad_op.desc.original_id() @@ -462,13 +470,13 @@ def _apply_single_impl(self, main_program, startup_program, context): def reset_op_dist_attr(self, op, var_name_dict): op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op) assert op_dist_attr is not None - for input in op.desc.input_arg_names(): + for input in op.input_arg_names: if input in var_name_dict.keys(): in_dist_attr = op_dist_attr.get_input_dist_attr(input) op_dist_attr.set_input_dist_attr( var_name_dict[input], in_dist_attr ) - for output in op.desc.output_arg_names(): + for output in op.output_arg_names: if output in var_name_dict.keys(): out_dist_attr = op_dist_attr.get_output_dist_attr(output) op_dist_attr.set_output_dist_attr( From de60fec7ab5a040412eac6dfef66e96fafebac52 Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Thu, 8 Dec 2022 14:45:49 +0800 Subject: [PATCH 15/24] add op_namescope for seed op --- .../passes/auto_parallel_recompute.py | 43 ++++++++++--------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index 7fb4a688ccb98..d75cd7694032e 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -86,6 +86,24 
@@ def build_states(self): ), "The recompute segment's ops should be continuous" self.seg_op_deps[seg_name].extend([i]) + def get_recompute_segments(self, no_recompute_segments=[]): + segments = [] + for segment_idx in self.seg_op_deps.values(): + if len(segment_idx) == 1: + continue + segments.append([segment_idx[0], segment_idx[-1] + 1]) + self._checkpoints.extend(self.ops[segment_idx[-1]].output_arg_names) + + for i in reversed(sorted(no_recompute_segments)): + assert i < len( + segments + ), "the no_recompute_segments idx [{}] should be lower the number of segment [{}]".format( + i, len(segments) + ) + segments.pop(i) + + return segments + def modify_forward_desc_for_recompute(self, dist_context): """ If program's foward part has 'dropout' op, this function will insert @@ -118,7 +136,7 @@ def modify_forward_desc_for_recompute(self, dist_context): ".".join([op_unique_name, 'tmp']) ) self._reserved_vars.append(var_unique_name) - seed_var = self._block.create_var( + seed_var = self.block.create_var( name=var_unique_name, dtype='int32', type=core.VarDesc.VarType.LOD_TENSOR, @@ -139,13 +157,14 @@ def modify_forward_desc_for_recompute(self, dist_context): else int(cur_op.attr("seed")) ) # TODO add dependency for seed op to ensure it be issued just before recompute. - seed_op = self._block._insert_op_without_sync( + seed_op = self.block._insert_op_without_sync( index=cur_op.idx, type="seed", inputs={}, outputs={"Out": seed_var}, attrs={"seed": seed, "force_cpu": True}, ) + seed_op._set_attr('op_namescope', cur_op.attr('op_namescope')) # set new seed op's dist_attr naive_set_dist_op_attr_for_program_by_mesh_and_mapping( seed_op, ref_process_mesh, ref_dims_mapping, dist_context @@ -161,25 +180,7 @@ def modify_forward_desc_for_recompute(self, dist_context): ) op_idx += 2 - self._block._sync_with_cpp() - - def get_recompute_segments(self, no_recompute_segments=[]): - segments = [] - for segment_idx in self.seg_op_deps.values(): - if len(segment_idx) == 1: - continue - segments.append([segment_idx[0], segment_idx[-1] + 1]) - self._checkpoints.extend(self.ops[segment_idx[-1]].output_arg_names) - - for i in reversed(sorted(no_recompute_segments)): - assert i < len( - segments - ), "the no_recompute_segments idx [{}] should be lower the number of segment [{}]".format( - i, len(segments) - ) - segments.pop(i) - - return segments + self.block._sync_with_cpp() def _find_op_index(block, cur_op): From 7b47635919e7772571705da789ac06b9db5c89ca Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Thu, 8 Dec 2022 14:56:44 +0800 Subject: [PATCH 16/24] record reserved vars --- .../paddle/distributed/passes/auto_parallel_recompute.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index d75cd7694032e..52876bcce2546 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -118,14 +118,14 @@ def modify_forward_desc_for_recompute(self, dist_context): cur_op = self.ops[op_idx] if "grad" in cur_op.type: break - if cur_op.type != "dropout": + if cur_op.type == "seed": + self._reserved_vars.extend(cur_op.output_arg_names) op_idx += 1 continue - if cur_op.input("Seed") is not None and len(cur_op.input("Seed")): + if cur_op.type != "dropout": op_idx += 1 continue - if cur_op.type == "seed": - self._reserved_vars.extend(cur_op.output_arg_names) + if cur_op.input("Seed") is not None and 
len(cur_op.input("Seed")): op_idx += 1 continue From 159c60f3f2b062d663b3f13a542880fbb7173541 Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Fri, 9 Dec 2022 10:26:04 +0800 Subject: [PATCH 17/24] fix recompute var's dist_attr --- python/paddle/distributed/passes/auto_parallel_recompute.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index 52876bcce2546..4a42058743aa5 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -343,11 +343,11 @@ def _apply_single_impl(self, main_program, startup_program, context): continue if name not in var_name_dict: ref_process_mesh = cur_op_dist_attr.process_mesh - try: + if name in op.input_arg_names: ref_dims_mapping = ( cur_op_dist_attr.get_input_dims_mapping(name) ) - except: + else: ref_dims_mapping = ( cur_op_dist_attr.get_output_dims_mapping(name) ) From 8e019e618ea9ae275dfa05db62f16bd962d6fe1e Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Fri, 9 Dec 2022 13:46:30 +0800 Subject: [PATCH 18/24] fix strategy unittest --- .../fluid/tests/unittests/auto_parallel/test_strategy.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py index 8649c0f8dffcd..529d1d5f6255d 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py @@ -24,7 +24,7 @@ def test_default_config(self): recompute = strategy.recompute self.assertEqual(recompute.enable, False) - self.assertIsNone(recompute.checkpoints) + self.assertEqual(recompute.checkpoints, []) amp = strategy.amp self.assertEqual(amp.enable, False) @@ -66,12 +66,10 @@ def test_default_config(self): tuning = strategy.tuning self.assertEqual(tuning.enable, False) - self.assertEqual(tuning.batch_size, 1) - self.assertIsNone(tuning.dataset) self.assertEqual(tuning.profile_start_step, 1) self.assertEqual(tuning.profile_end_step, 1) self.assertEqual(tuning.run_after_tuning, True) - self.assertEqual(tuning.verbose, True) + self.assertEqual(tuning.debug, False) def test_modify_config(self): strategy = auto.Strategy() From 4d4e1d60c2106b632c087b423441395cac39f445 Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Mon, 12 Dec 2022 20:07:43 +0800 Subject: [PATCH 19/24] adapt for fp16 --- python/paddle/distributed/passes/auto_parallel_amp.py | 6 ++++++ python/paddle/distributed/passes/auto_parallel_fp16.py | 6 ++++++ .../paddle/distributed/passes/auto_parallel_recompute.py | 9 +++++---- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index e96cd4ec77d8f..cba613676d58d 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -226,6 +226,9 @@ def _insert_cast_op_forward( dist_context, out_var, ref_mapping, ref_mesh ) + op_namescope = "/" + if op.has_attr('op_namescope'): + op_namescope = op.attr('op_namescope') cast_op = self._block._insert_op_without_sync( idx, type="cast", @@ -236,6 +239,9 @@ def _insert_cast_op_forward( "out_dtype": out_var.dtype, }, ) + cast_op._set_attr( + 'op_namescope', op_namescope + ) # for recompute 
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( cast_op, ref_mesh, ref_mapping, dist_context ) diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index f3c2db8ed1f42..0e834343e2800 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -416,6 +416,9 @@ def _insert_forward_cast_ops( dist_context, cast_var, ref_mapping, ref_mesh ) + op_namescope = "/" + if op.has_attr('op_namescope'): + op_namescope = op.attr('op_namescope') cast_op = block._insert_op_without_sync( idx, type="cast", @@ -427,6 +430,9 @@ def _insert_forward_cast_ops( OP_ROLE_KEY: OpRole.Forward, }, ) + cast_op._set_attr( + 'op_namescope', op_namescope + ) # for recompute naive_set_dist_op_attr_for_program_by_mesh_and_mapping( cast_op, ref_mesh, ref_mapping, dist_context ) diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index 4a42058743aa5..d143c6c84d0d1 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -282,17 +282,18 @@ def _apply_single_impl(self, main_program, startup_program, context): if segments == []: return - print("segments:", segments) for i, (idx1, idx2) in enumerate(segments): - print("recompute segment[{}]".format(i)) - print( + logging.info( + "recompute segment[{}/{}]".format(i + 1, len(segments)) + ) + logging.info( "segment start op: [{}]: [{}] [{}]".format( rc_state.ops[idx1].type, rc_state.ops[idx1].input_arg_names, rc_state.ops[idx1].output_arg_names, ) ) - print( + logging.info( "segment end op: [{}]: [{}] [{}]".format( rc_state.ops[idx2 - 1].type, rc_state.ops[idx2 - 1].input_arg_names, From 1071160739f1bd9e7d811ded65b9863ddf2b2add Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Mon, 12 Dec 2022 22:09:31 +0800 Subject: [PATCH 20/24] update unittest --- .../paddle/distributed/auto_parallel/tuner/algorithms.py | 2 +- .../fluid/tests/unittests/auto_parallel/CMakeLists.txt | 2 +- .../unittests/auto_parallel/test_tuning_recompute.py | 8 +++----- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/tuner/algorithms.py b/python/paddle/distributed/auto_parallel/tuner/algorithms.py index 1fcbb88ab05eb..74e8f3e9ee3f1 100644 --- a/python/paddle/distributed/auto_parallel/tuner/algorithms.py +++ b/python/paddle/distributed/auto_parallel/tuner/algorithms.py @@ -179,7 +179,7 @@ def collect_model_info(self, main_prog, startup_prog): self._tuning_segments = list(range(len(segments))) self._trail_left = 0 self._trail_right = len(segments) - 1 - self._trial_idx = int(0 + (len(segments)) / 2) + self._trial_idx = int(0 + (len(segments) - 1) / 2) def _init_spaces(self): self._recompute_mode = "all" diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 35143858c8abc..d6ccfb8f6d57b 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -75,7 +75,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_selective_recompute MODULES test_selective_recompute) set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50) py_test_modules(test_tuning_recompute MODULES test_tuning_recompute) - set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 50) 
+ set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 160) py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py index cf0f379c31d62..5e925ee5c1bad 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py @@ -36,12 +36,12 @@ def generate_model(): gpt = GPTModel( vocab_size=50304, hidden_size=1024, - num_hidden_layers=2, + num_hidden_layers=12, num_attention_heads=16, intermediate_size=1024 * 4, hidden_act="gelu", - hidden_dropout_prob=0.0, - attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, max_position_embeddings=1024, type_vocab_size=1, initializer_range=0.02, @@ -97,8 +97,6 @@ def test_recompute_pass(self): engine = auto.Engine(model, loss, opt, strategy=strategy) engine._tune(self.dataset, 3, batch_size=self.batch_size) - assert not engine._dist_contexts['train'].strategy.recompute.enable - if __name__ == "__main__": unittest.main() From a9062e7e5c7967a9064f99bd2f2e3e89aa3633ef Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Mon, 12 Dec 2022 22:36:17 +0800 Subject: [PATCH 21/24] revert copy opt --- .../distributed/auto_parallel/tuner/optimization_tuner.py | 6 +++--- .../paddle/distributed/passes/auto_parallel_recompute.py | 8 ++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py index 82348002133c0..c3de081c752ba 100644 --- a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py @@ -348,11 +348,11 @@ def _apply_optimization(self, trial): # Generate optimizer # FIXME should be remove from apply pass after pass support optimizers - optimizer = copy.deepcopy(dist_context.serial_optimizer) - dist_context._serial_optimizer = optimizer with program_guard(dist_main_prog, dist_startup_prog): with unique_name.guard("opt_"): - optimizer_ops = optimizer.apply_gradients(dist_params_grads) + optimizer_ops = dist_context.serial_optimizer.apply_gradients( + dist_params_grads + ) completer.complete_update_annotation(dist_main_prog) # Do reshard process diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index d143c6c84d0d1..d99f335517a16 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -52,6 +52,9 @@ def checkpoints(self): def reserved_vars(self): return self._reserved_vars + def is_recompute(self): + return any([is_recompute_op(op) for op in self.ops]) + def build_states(self): for i, op in enumerate(self.ops): if is_backward_op(op): @@ -273,8 +276,12 @@ def _apply_single_impl(self, main_program, startup_program, context): # 0. get op_path which is related to loss main_block = main_program.global_block() op_path = _find_op_path(main_program, loss, no_grad_set) + # 1. build recompute state rc_state = RecomputeState(main_block, op_path) + if not rc_state.is_recompute(): + return + # 2. 
get the segments to be recomputed rc_state.modify_forward_desc_for_recompute(self._dist_context) rc_state.build_states() @@ -300,6 +307,7 @@ def _apply_single_impl(self, main_program, startup_program, context): rc_state.ops[idx2 - 1].output_arg_names, ) ) + # 3. get vars that should be hold in memory vars_should_be_hold = [] for segment in segments: From a498f9b47cdf3b5ab475c25b821cae29393582a4 Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Tue, 13 Dec 2022 10:11:59 +0800 Subject: [PATCH 22/24] update unittest --- .../paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt | 2 +- .../tests/unittests/auto_parallel/test_tuning_recompute.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index d6ccfb8f6d57b..35143858c8abc 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -75,7 +75,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_selective_recompute MODULES test_selective_recompute) set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50) py_test_modules(test_tuning_recompute MODULES test_tuning_recompute) - set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 160) + set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 50) py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py index 5e925ee5c1bad..c9087f064701d 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py @@ -36,7 +36,7 @@ def generate_model(): gpt = GPTModel( vocab_size=50304, hidden_size=1024, - num_hidden_layers=12, + num_hidden_layers=2, num_attention_heads=16, intermediate_size=1024 * 4, hidden_act="gelu", @@ -97,6 +97,8 @@ def test_recompute_pass(self): engine = auto.Engine(model, loss, opt, strategy=strategy) engine._tune(self.dataset, 3, batch_size=self.batch_size) + assert not engine._dist_contexts['train'].strategy.recompute.enable + if __name__ == "__main__": unittest.main() From 036aedf7883367b4a01d3eee8335167daa93b7d4 Mon Sep 17 00:00:00 2001 From: zhaoyingli Date: Tue, 13 Dec 2022 15:25:11 +0800 Subject: [PATCH 23/24] rename set_recompute_segments --- python/paddle/distributed/auto_parallel/engine.py | 2 +- python/paddle/distributed/auto_parallel/utils.py | 2 +- .../tests/unittests/auto_parallel/CMakeLists.txt | 2 +- .../unittests/auto_parallel/test_tuning_recompute.py | 11 ++++++++--- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index fc49ea3c1aaaf..dc7470283aef8 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -609,7 +609,7 @@ def _build(self, mode): if mode != "train": serial_main_prog = serial_main_prog.clone(for_test=True) - auto_utils.set_recompute_scope( + auto_utils.set_recompute_segments( self._model, self._losses, self._strategy, serial_main_prog ) self._dist_contexts[mode] = DistributedContext( diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 
fd3e97809c4bf..abc2be9ea9541 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1926,7 +1926,7 @@ def is_recompute_op(op): ) -def set_recompute_scope(model, losses, strategy, program): +def set_recompute_segments(model, losses, strategy, program): from ..passes.auto_parallel_recompute import RecomputeState if not losses: diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 35143858c8abc..21c0f88438ad8 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -75,7 +75,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_selective_recompute MODULES test_selective_recompute) set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50) py_test_modules(test_tuning_recompute MODULES test_tuning_recompute) - set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 50) + set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 240) py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py index c9087f064701d..ecf6179fb821d 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py @@ -36,7 +36,7 @@ def generate_model(): gpt = GPTModel( vocab_size=50304, hidden_size=1024, - num_hidden_layers=2, + num_hidden_layers=14, num_attention_heads=16, intermediate_size=1024 * 4, hidden_act="gelu", @@ -70,7 +70,7 @@ def apply_pass(): tuning = strategy.tuning tuning.enable = True tuning.profile_start_step = 1 - tuning.profile_end_step = 5 + tuning.profile_end_step = 2 tuning.run_after_tuning = True tuning.verbose = True return strategy @@ -97,7 +97,12 @@ def test_recompute_pass(self): engine = auto.Engine(model, loss, opt, strategy=strategy) engine._tune(self.dataset, 3, batch_size=self.batch_size) - assert not engine._dist_contexts['train'].strategy.recompute.enable + assert ( + engine._dist_contexts[ + 'train' + ].strategy.recompute.no_recompute_segments + > 0 + ) if __name__ == "__main__": From ebe2ffbd9487080e72b8cacd1870c629651da013 Mon Sep 17 00:00:00 2001 From: zhaoyinglia Date: Wed, 14 Dec 2022 10:06:08 +0800 Subject: [PATCH 24/24] fix unittest --- .../unittests/auto_parallel/test_tuning_recompute.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py index ecf6179fb821d..a2a7deee6d216 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py @@ -98,9 +98,11 @@ def test_recompute_pass(self): engine._tune(self.dataset, 3, batch_size=self.batch_size) assert ( - engine._dist_contexts[ - 'train' - ].strategy.recompute.no_recompute_segments + len( + engine._dist_contexts[ + 'train' + ].strategy.recompute.no_recompute_segments + ) > 0 )
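
Taken together, the series above wires recompute-segment annotation (via `op_namescope`) into the optimization tuner. The snippet below is an illustrative, condensed sketch of the end-to-end usage that `test_tuning_recompute.py` exercises; it is not part of any patch in this series. It reuses that test's helpers (`generate_model`, `FakeDataset`), so the imports only resolve when run from the auto_parallel unittest directory; elsewhere, substitute your own model, loss, and dataset.

    import paddle
    from paddle.distributed.fleet import auto

    # Helpers defined by the new test above; assumed to be importable from the
    # auto_parallel test directory (get_gpt_model.py, test_tuning_recompute.py).
    from get_gpt_model import FakeDataset
    from test_tuning_recompute import generate_model

    # Strategy: recompute is enabled and marked tunable, and the tuner profiles
    # steps 1..5 of each candidate trial before choosing one.
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.recompute.enable = True
    strategy.recompute.enable_tuning = True
    strategy.tuning.enable = True
    strategy.tuning.profile_start_step = 1
    strategy.tuning.profile_end_step = 5
    strategy.tuning.run_after_tuning = True

    model, loss = generate_model()
    clip = paddle.nn.ClipGradByGlobalNorm(0.2)
    opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
    dataset = FakeDataset(8 * 200, vocab_size=50304, sequence_len=1024)

    # The tuner builds and profiles one program per trial (all-recompute,
    # no-recompute, and partial-recompute candidates) and writes the chosen
    # strategy back into the engine's dist_context before training.
    engine = auto.Engine(model, loss, opt, strategy=strategy)
    engine._tune(dataset, 3, batch_size=8)

After `engine._tune(...)` returns, the tuned strategy lives in `engine._dist_contexts['train'].strategy`, which is what the final patch asserts on: a non-empty `recompute.no_recompute_segments` list for this model size.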