-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement restart feature #268
base: main
Are you sure you want to change the base?
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,8 @@ | ||
import glob | ||
import os | ||
import uuid | ||
import dill as pickle | ||
|
||
from abc import abstractmethod | ||
from dataclasses import dataclass | ||
from typing import Any, Callable, Optional, Sequence, Union | ||
|
@@ -113,6 +118,7 @@ def __init__(self, | |
requirements: Optional[OptimizationParameters] = None, | ||
graph_generation_params: Optional[GraphGenerationParams] = None, | ||
graph_optimizer_params: Optional[AlgorithmParameters] = None, | ||
saved_state_path='saved_optimisation_state/main', | ||
**custom_optimizer_params): | ||
self.log = default_log(self) | ||
self._objective = objective | ||
|
@@ -128,6 +134,9 @@ def __init__(self, | |
# Log random state for reproducibility of runs | ||
RandomStateHandler.log_random_state() | ||
|
||
self._saved_state_path = saved_state_path | ||
self._run_id = str(uuid.uuid1()) | ||
|
||
@property | ||
def objective(self) -> Objective: | ||
"""Returns Objective of this optimizer with information about metrics used.""" | ||
|
@@ -161,13 +170,44 @@ def set_evaluation_callback(self, callback: Optional[GraphFunction]): | |
@property | ||
def _progressbar(self): | ||
if self.requirements.show_progress: | ||
bar = tqdm(total=self.requirements.num_of_generations, desc='Generations', unit='gen', initial=0) | ||
if self.use_saved_state: | ||
bar = tqdm(total=self.requirements.num_of_generations, desc='Generations', unit='gen', | ||
initial=self.current_generation_num - 2) | ||
else: | ||
bar = tqdm(total=self.requirements.num_of_generations, desc='Generations', unit='gen', initial=0) | ||
Comment on lines
+173
to
+177
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. У этого класса нет атрибута
|
||
else: | ||
# disable call to tqdm.__init__ to avoid stdout/stderr access inside it | ||
# part of a workaround for https://github.com/nccr-itmo/FEDOT/issues/765 | ||
bar = EmptyProgressBar() | ||
return bar | ||
|
||
def save(self, saved_state_path): | ||
""" | ||
Method for serializing and saving a class object to a file using the dill library | ||
:param str saved_state_path: full path to the saved state file (including filename) | ||
""" | ||
folder_path = os.path.dirname(os.path.abspath(saved_state_path)) | ||
if not os.path.isdir(folder_path): | ||
os.makedirs(folder_path) | ||
self.log.info(f'Created directory for saving optimization state: {folder_path}') | ||
with open(saved_state_path, 'wb') as f: | ||
pickle.dump(self.__dict__, f, 2) | ||
|
||
def load(self, saved_state_path): | ||
""" | ||
Method for loading a serialized class object from file using the dill library | ||
:param str saved_state_path: full path to the saved state file | ||
""" | ||
with open(saved_state_path, 'rb') as f: | ||
self.__dict__.update(pickle.load(f)) | ||
|
||
def _find_latest_dir(self, directory: str) -> str: | ||
return max([os.path.join(directory, d) for d in os.listdir(directory) if os.path.isdir( | ||
os.path.join(directory, d))], key=os.path.getmtime) | ||
|
||
def _find_latest_file_in_dir(self, directory: str) -> str: | ||
return max(glob.glob(os.path.join(directory, '*')), key=os.path.getmtime) | ||
|
||
|
||
IterationCallback = Callable[[PopulationT, GraphOptimizer], Any] | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,7 @@ | ||
import os | ||
import time | ||
from abc import abstractmethod | ||
from datetime import timedelta, datetime | ||
from random import choice | ||
from typing import Any, Optional, Sequence, Dict | ||
|
||
|
@@ -13,6 +16,7 @@ | |
from golem.core.optimisers.optimization_parameters import GraphRequirements | ||
from golem.core.optimisers.optimizer import GraphGenerationParams, GraphOptimizer, AlgorithmParameters | ||
from golem.core.optimisers.timer import OptimisationTimer | ||
from golem.core.paths import default_data_dir | ||
from golem.utilities.grouped_condition import GroupedCondition | ||
|
||
|
||
|
@@ -40,25 +44,76 @@ def __init__(self, | |
requirements: GraphRequirements, | ||
graph_generation_params: GraphGenerationParams, | ||
graph_optimizer_params: Optional['AlgorithmParameters'] = None, | ||
use_saved_state: bool = False, | ||
saved_state_path: str = 'saved_optimisation_state/main/populational_optimiser', | ||
saved_state_file: str = None, | ||
**custom_optimizer_params | ||
): | ||
super().__init__(objective, initial_graphs, requirements, | ||
graph_generation_params, graph_optimizer_params, **custom_optimizer_params) | ||
self.population = None | ||
self.generations = GenerationKeeper(self.objective, keep_n_best=requirements.keep_n_best) | ||
self.timer = OptimisationTimer(timeout=self.requirements.timeout) | ||
|
||
dispatcher_type = MultiprocessingDispatcher if self.requirements.parallelization_mode == 'populational' else \ | ||
SequentialDispatcher | ||
|
||
self.eval_dispatcher = dispatcher_type(adapter=graph_generation_params.adapter, | ||
n_jobs=requirements.n_jobs, | ||
graph_cleanup_fn=_try_unfit_graph, | ||
delegate_evaluator=graph_generation_params.remote_evaluator) | ||
super().__init__(objective, initial_graphs, requirements, graph_generation_params, graph_optimizer_params, | ||
saved_state_path, **custom_optimizer_params) | ||
|
||
# Restore state from previous run | ||
if use_saved_state: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. лучше вынести всю restore optimisation related логику в отдельный приватный метод |
||
self.log.info('USING SAVED STATE') | ||
if saved_state_file: | ||
if os.path.isfile(saved_state_file): | ||
current_saved_state_path = saved_state_file | ||
else: | ||
raise SystemExit('ERROR: Could not restore saved optimisation state: ' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. можно наверно просто писать лог, мол начать с сохраненного состояния не удалось, оптимизация начинается с нуля |
||
f'given file with saved state {saved_state_file} not found.') | ||
else: | ||
try: | ||
full_state_path = os.path.join(default_data_dir(), self._saved_state_path) | ||
current_saved_state_path = self._find_latest_file_in_dir(self._find_latest_dir(full_state_path)) | ||
except (ValueError, FileNotFoundError): | ||
raise SystemExit('ERROR: Could not restore saved optimisation state: ' | ||
f'path with saved state {full_state_path} not found.') | ||
try: | ||
self.load(current_saved_state_path) | ||
except Exception as e: | ||
raise SystemExit('ERROR: Could not restore saved optimisation state from {full_state_path}.' | ||
f'If saved state file is broken remove it manually from the saved state dir or' | ||
f'pass a valid saved state filepath.' | ||
f'Full error message: {e}') | ||
|
||
# Override optimisation params from the saved state file with new values | ||
self.requirements.num_of_generations = requirements.num_of_generations | ||
self.requirements.timeout = requirements.timeout | ||
|
||
# Update all time parameters | ||
saved_state_timestamp = datetime.fromtimestamp(os.path.getmtime(current_saved_state_path)) | ||
elapsed_time: timedelta = saved_state_timestamp - self.timer.start_time | ||
|
||
timeout = self.requirements.timeout - elapsed_time | ||
self.timer = OptimisationTimer(timeout=timeout) | ||
self.requirements.timeout = self.requirements.timeout - timedelta(seconds=elapsed_time.total_seconds()) | ||
self.eval_dispatcher.timer = self.requirements.timeout | ||
|
||
stag_time_delta = saved_state_timestamp - self.generations._stagnation_start_time | ||
self.generations._stagnation_start_time = datetime.now() - stag_time_delta | ||
else: | ||
self.population = None | ||
self.generations = GenerationKeeper(self.objective, keep_n_best=requirements.keep_n_best) | ||
self.timer = OptimisationTimer(timeout=self.requirements.timeout) | ||
|
||
dispatcher_type = MultiprocessingDispatcher if self.requirements.parallelization_mode == 'populational' \ | ||
else SequentialDispatcher | ||
|
||
self.eval_dispatcher = dispatcher_type(adapter=graph_generation_params.adapter, | ||
n_jobs=requirements.n_jobs, | ||
graph_cleanup_fn=_try_unfit_graph, | ||
delegate_evaluator=graph_generation_params.remote_evaluator) | ||
|
||
# in how many generations structural diversity check should be performed | ||
self.gen_structural_diversity_check = self.graph_optimizer_params.structural_diversity_frequency_check | ||
|
||
self.use_saved_state = use_saved_state | ||
|
||
# early_stopping_iterations and early_stopping_timeout may be None, so use some obvious max number | ||
max_stagnation_length = requirements.early_stopping_iterations or requirements.num_of_generations | ||
max_stagnation_time = requirements.early_stopping_timeout or self.timer.timeout | ||
|
||
self.stop_optimization = \ | ||
GroupedCondition(results_as_message=True).add_condition( | ||
lambda: self.timer.is_time_limit_reached(self.current_generation_num - 1), | ||
|
@@ -70,10 +125,10 @@ def __init__(self, | |
).add_condition( | ||
lambda: (max_stagnation_length is not None and | ||
self.generations.stagnation_iter_count >= max_stagnation_length), | ||
'Optimisation finished: Early stopping iterations criteria was satisfied' | ||
'Optimisation finished: Early stopping iterations criteria was satisfied (stagnation_iter_count)' | ||
).add_condition( | ||
lambda: self.generations.stagnation_time_duration >= max_stagnation_time, | ||
'Optimisation finished: Early stopping timeout criteria was satisfied' | ||
'Optimisation finished: Early stopping timeout criteria was satisfied (stagnation_time_duration)' | ||
) | ||
# in how many generations structural diversity check should be performed | ||
self.gen_structural_diversity_check = self.graph_optimizer_params.structural_diversity_frequency_check | ||
|
@@ -86,14 +141,16 @@ def set_evaluation_callback(self, callback: Optional[GraphFunction]): | |
# Redirect callback to evaluation dispatcher | ||
self.eval_dispatcher.set_graph_evaluation_callback(callback) | ||
|
||
def optimise(self, objective: ObjectiveFunction) -> Sequence[Graph]: | ||
def optimise(self, objective: ObjectiveFunction, save_state_delta: int = 60) -> Sequence[Graph]: | ||
|
||
# eval_dispatcher defines how to evaluate objective on the whole population | ||
saved_state_path = os.path.join(default_data_dir(), self._saved_state_path, self._run_id) | ||
evaluator = self.eval_dispatcher.dispatch(objective, self.timer) | ||
last_write_time = datetime.now() | ||
|
||
with self.timer, self._progressbar as pbar: | ||
|
||
self._initial_population(evaluator) | ||
if not self.use_saved_state: | ||
self._initial_population(evaluator) | ||
|
||
while not self.stop_optimization(): | ||
try: | ||
|
@@ -108,7 +165,17 @@ def optimise(self, objective: ObjectiveFunction) -> Sequence[Graph]: | |
break | ||
# Adding of new population to history | ||
self._update_population(new_population) | ||
delta = datetime.now() - last_write_time | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. отдельный метод, код будет более читаемым |
||
# Create new file with saved state every {save_state_delta} seconds | ||
if delta.seconds >= save_state_delta: | ||
save_path = os.path.join(saved_state_path, f'{str(round(time.time()))}.pkl') | ||
self.save(save_path) | ||
self.log.info(f'State saved to {save_path}') | ||
last_write_time = datetime.now() | ||
pbar.close() | ||
save_path = os.path.join(saved_state_path, f'{str(round(time.time()))}.pkl') | ||
self.save(save_path) | ||
self.log.info(save_path) | ||
self._update_population(self.best_individuals, 'final_choices') | ||
return [ind.graph for ind in self.best_individuals] | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
'saved_optimisation_state/main'
повторяется, можно вынести в отдельную константу. и вообще все строки