[AutoParallel] recompute tuning #48608

Merged
6 changes: 2 additions & 4 deletions python/paddle/distributed/auto_parallel/constants.py
@@ -54,7 +54,7 @@ def set_field_default_config(category, field, default_value):
#########################################
RECOMPUTE = "recompute"
set_field_default_config(RECOMPUTE, "enable", False)
set_field_default_config(RECOMPUTE, "checkpoints", None)
set_field_default_config(RECOMPUTE, "checkpoints", [])
set_field_default_config(RECOMPUTE, "no_recompute_segments", [])
set_field_default_config(RECOMPUTE, "enable_tuning", False)

@@ -113,12 +113,10 @@ def set_field_default_config(category, field, default_value):
# #########################################
TUNING = "tuning"
set_field_default_config(TUNING, "enable", False)
set_field_default_config(TUNING, "batch_size", 1)
set_field_default_config(TUNING, "dataset", None)
set_field_default_config(TUNING, "profile_start_step", 1)
set_field_default_config(TUNING, "profile_end_step", 1)
set_field_default_config(TUNING, "run_after_tuning", True)
set_field_default_config(TUNING, "verbose", True)
set_field_default_config(TUNING, "debug", False)

#########################################
# dataset configuration
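
With these defaults, recompute.checkpoints starts as an empty list instead of None, and the tuning section swaps verbose for debug while dropping batch_size and dataset (the engine now receives those directly, see the next file). A rough sketch of a strategy that opts into recompute tuning under these fields, assuming the paddle.distributed.fleet.auto entry point:

from paddle.distributed.fleet import auto

strategy = auto.Strategy()

# recompute section: leave checkpoints as the (new) empty-list default and
# let the tuner decide which segments are worth recomputing
strategy.recompute.enable = True
strategy.recompute.enable_tuning = True

# tuning section: profiling window plus the new debug switch
strategy.tuning.enable = True
strategy.tuning.profile_start_step = 1
strategy.tuning.profile_end_step = 5
strategy.tuning.debug = False
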
5 changes: 3 additions & 2 deletions python/paddle/distributed/auto_parallel/engine.py
@@ -609,7 +609,9 @@ def _build(self, mode):
if mode != "train":
serial_main_prog = serial_main_prog.clone(for_test=True)

auto_utils.set_recompute_ckpts(self._model, self._strategy)
auto_utils.set_recompute_segments(
self._model, self._losses, self._strategy, serial_main_prog
)
self._dist_contexts[mode] = DistributedContext(
serial_main_prog,
serial_startup_prog,
@@ -649,7 +651,6 @@ def _optimization_tuning(self, mode, dataset, batch_size):
from .tuner.optimization_tuner import OptimizationTuner

self._optimization_tuner = OptimizationTuner(
self._tuning.to_dict(),
self._dist_contexts[mode],
dataset,
self._inputs_spec,
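
Two things change in the engine: checkpoints are converted into recompute segments against the serial program (set_recompute_segments also needs the losses and the program), and OptimizationTuner no longer receives a separate tuning dict, since that information lives on the strategy. A rough sketch of how tuning is driven from user code under these assumptions (model, loss_fn, optimizer and train_dataset are placeholders):

from paddle.distributed.fleet import auto

# strategy configured as in the previous sketch, with tuning.enable = True
engine = auto.Engine(model, loss_fn, optimizer, strategy=strategy)

# with tuning enabled the engine runs the optimization tuner before training;
# the dataset and batch size reach _optimization_tuning through this call
# instead of through the removed TUNING config fields
engine.fit(train_dataset, epochs=1, batch_size=8)
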
4 changes: 4 additions & 0 deletions python/paddle/distributed/auto_parallel/strategy.py
@@ -73,6 +73,10 @@ def __deepcopy__(self, memo):
setattr(result, k, copy.deepcopy(v, memo))
return result

def get(self, k, d=None):
result_dict = self.to_dict()
return result_dict.get(k, d)


class RecomputeConfig(BaseConfig):
def __init__(self, config_dict=None):
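
The new get helper mirrors dict.get on top of to_dict(), so callers can read an optional field with a fallback without materializing the whole dict themselves. Reusing the strategy object from the earlier sketch, this is what the sharding algorithm below switches to:

# before: convert the whole config to a dict just to read one optional field
stage_range = strategy.sharding.to_dict().get("tuning_range", None)

# after: BaseConfig.get returns the fallback when the field is absent
stage_range = strategy.sharding.get("tuning_range", None)
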
95 changes: 92 additions & 3 deletions python/paddle/distributed/auto_parallel/tuner/algorithms.py
@@ -16,7 +16,7 @@
import logging
from abc import ABC, abstractmethod

from ..utils import get_logger
from ..utils import get_logger, is_recompute_op
from .trial import OptimizationTunerTrial as Trial
from .trial import TrialStatus

@@ -54,7 +54,7 @@ def changed_configs(self):
def collect_model_info(self, main_prog, startup_prog):
"""
Collect the model static info (from programs) that could be used to
pruning candidate trials and saving tuning time.For instance,
pruning candidate trials and saving tuning time. For instance,
model info like the number of model parameters and activation memory could be
used to prune candidate trials and decide the next trial.
"""
@@ -116,7 +116,7 @@ def _init_spaces(self):
self._max_stage = 3
self._trial_idx = 0

stage_range = self._config.sharding.to_dict().get("tuning_range", None)
stage_range = self._config.sharding.get("tuning_range", None)
if stage_range:
assert set(stage_range).issubset(
set([0, 1, 2, 3])
@@ -157,3 +157,92 @@ def update(self, results):
)
else:
self._trial_idx += 1


@register_algor("recompute")
class ReccomputeCheckpointAlgorithm(AlgorithmBase):
def __init__(self, config):
super().__init__(config)
self._changed_configs = ["recompute"]

def collect_model_info(self, main_prog, startup_prog):
segments = []
for op in main_prog.global_block().ops:
if not is_recompute_op(op):
continue

seg_name = op.attr('op_namescope')
if seg_name not in segments:
segments.append(seg_name)

self._total_num_trial = len(segments)
self._tuning_segments = list(range(len(segments)))
self._trail_left = 0
self._trail_right = len(segments) - 1
self._trial_idx = int(0 + (len(segments) - 1) / 2)

def _init_spaces(self):
self._recompute_mode = "all"

def next_trial(self):
if self._trial_idx < self._total_num_trial:
if self._recompute_mode == "all":
self._recompute_flag = False
new_strategy = copy.deepcopy(self._config.dist_strategy)
name = "trial-recompute-all-segments"
return Trial(new_strategy, name, self.changed_configs)
elif self._recompute_mode == "none":
self._recompute_flag = False
new_strategy = copy.deepcopy(self._config.dist_strategy)
recompute = new_strategy.recompute
recompute.enable = False
name = "trial-recompute-none-segments"
return Trial(new_strategy, name, self.changed_configs)
elif self._recompute_mode == "part":
new_no_recompute = self._tuning_segments[: self._trial_idx]
new_strategy = copy.deepcopy(self._config.dist_strategy)
recompute = new_strategy.recompute
recompute.no_recompute_segments.extend(new_no_recompute)
name = "trial-recompute-part-segments-idx{}".format(
self._trial_idx
)
return Trial(new_strategy, name, self.changed_configs)
else:
return Trial(None, None, None, status=TrialStatus.STOPPED)

def update(self, results):

et = results.get("ErrorType", None)
if self._recompute_mode == "all":
if et and et == "ResourceExhaustedError":
self._trial_idx = self._total_num_trial
self._logger.info(
"Recompute all candidate segments is failed with OOM, please reduce model size or batch size."
)
else:
self._recompute_mode = "none"
elif self._recompute_mode == "none":
if et and et == "ResourceExhaustedError":
self._recompute_mode = "part"
else:
self._trial_idx = self._total_num_trial
self._logger.info(
"Recompute is unnecessary for this model size, which will reduce the Throughtput."
)
else:
if self._trail_left >= self._trail_right:
self._trial_idx = self._total_num_trial
elif et and et == "ResourceExhaustedError":
self._trail_left = self._trail_left
self._trail_right = self._trial_idx - 1
self._trial_idx = int(
self._trail_left
+ (self._trail_right - self._trail_left) / 2
)
else:
self._trail_left = self._trial_idx + 1
self._trail_right = self._trail_right
self._trial_idx = int(
self._trail_left
+ (self._trail_right - self._trail_left) / 2
)
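
The part branch of update is a binary search over how many leading segments to exclude from recompute: an OOM result pulls the right bound below the current index, a successful trial pushes the left bound past it, and the next index is the midpoint, matching the arithmetic above. A standalone sketch of the same idea (not the tuner's actual driver loop), where runs_oom(k) is a placeholder for profiling a trial whose first k segments skip recompute:

def search_no_recompute_count(num_segments, runs_oom):
    """Return the largest k such that excluding the first k segments from
    recompute still fits in memory (None if even k == 0 runs out of memory)."""
    left, right = 0, num_segments - 1
    idx = (num_segments - 1) // 2
    best = None
    while left <= right:
        if runs_oom(idx):
            # too little recompute, ran out of memory: search the lower half
            right = idx - 1
        else:
            # fits: remember this point and try excluding even more segments
            best = idx
            left = idx + 1
        idx = left + (right - left) // 2
    return best
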
50 changes: 21 additions & 29 deletions python/paddle/distributed/auto_parallel/tuner/config.py
@@ -32,14 +32,11 @@ class TuningConfig:
tuning config: configuration for the tuning process: mode (profile or cost model), log dir, and extra tuning config for optimization, like the search range for a specific algorithm
"""

def __init__(self, user_config, strategy):
def __init__(self, strategy):

if not isinstance(strategy, Strategy):
raise TypeError("'strategy' must be object of class `Strategy`.")

if not user_config:
user_config = {}

self._tuning_passes_name = set()
self._dist_strategy = copy.deepcopy(strategy)
self._mode = None
@@ -48,9 +45,9 @@ def __init__(self, user_config, strategy):
self._project_dir = None
self._max_num_trial = None
self._early_stop = None
self._verbose = None
self._debug = None

self._initialize(user_config)
self._initialize()

@property
def mode(self):
@@ -81,29 +78,25 @@ def early_stop(self):
return self._early_stop

@property
def verbose(self):
return self._verbose
def debug(self):
return self._debug

@property
def dist_strategy(self):
return self._dist_strategy

# initialize config with user-defined value or default value
def _initialize(self, user_config):

self._mode = user_config.get("mode", "PROFILE")

self._profile_start_step = user_config.get("profile_start_step", 10)

self._profile_end_step = user_config.get("profile_end_step", 30)

self._max_num_trial = user_config.get("max_num_trial", 50)

self._early_stop = user_config.get("early_stop", None)
def _initialize(self):
tuning_strategy = self._dist_strategy.tuning

self._verbose = user_config.get("verbose", False)
self._mode = tuning_strategy.get("mode", "PROFILE")
self._profile_start_step = tuning_strategy.get("profile_start_step", 10)
self._profile_end_step = tuning_strategy.get("profile_end_step", 30)
self._max_num_trial = tuning_strategy.get("max_num_trial", 50)
self._early_stop = tuning_strategy.get("early_stop", None)
self._debug = tuning_strategy.get("debug", False)

project_dir = user_config.get("project_dir", None)
project_dir = tuning_strategy.get("project_dir", None)
if not project_dir:
project_dir = os.path.join(os.getcwd(), "OptimizationTuning")
self._project_dir = project_dir
@@ -116,15 +109,14 @@ def _initialize(self, user_config):
# TODO distinguish different args of each pass
self._tuning_passes_name.add(p)

config_name = p
p_dict = getattr(self._dist_strategy, config_name)
self.__dict__[config_name] = p_dict
p_strategy = getattr(self._dist_strategy, p)
self.__dict__[p] = p_strategy

# TODO verify the user defined configs
user_config_for_pass = user_config.get(p, None)
if user_config_for_pass:
for k, v in user_config_for_pass.items():
self.__dict__[config_name][k] = v
# # TODO verify the user defined configs
# tuning_config_for_pass = tuning_strategy.get(p, None)
# if tuning_config_for_pass:
# for k, v in tuning_config_for_pass.items():
# self.__dict__[p][k] = v

# (NOTE)tuning config ONLY wraps dist strategy for pass config which is to be tuned
def __getattr__(self, item):
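
The net effect of this file: the tuner's configuration is derived from the strategy alone, so the old two-argument construction disappears and the per-pass user-dict merging is left commented out. A before/after sketch (user_config here is just a placeholder for the dict the engine used to pass):

# before: defaults and user overrides came from a separate dict
config = TuningConfig(user_config, strategy)

# after: the strategy is the single source; mode, the profiling window and the
# debug flag are all read from strategy.tuning inside _initialize()
config = TuningConfig(strategy)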