From fac646451f1accf0485b1ae16bcacab832f62b8f Mon Sep 17 00:00:00 2001 From: David Akhmedov <113194669+Parzival-05@users.noreply.github.com> Date: Sun, 15 Sep 2024 09:47:59 +0300 Subject: [PATCH] Gameserver update: add processing of disappeared processes & update statistics (#78) * Add handling in case the process disappears * Update statistics collector: use it only for one epoch * Add synchronization to the `fail` method * Restructure & add processing of process killing by system * Add `FailedMaps` instead of `Status` & refactor * Add info about failed maps to stats table * Fix logging of exception * `GameError`: Replace error with error_name & add error's description to `GameInterruptedError` * Fix types * Remove statscollector from GameError * Save 'FunctionTimedOut'ed maps * Remove extra exception handler. * Move `avg_by_attr` from StatisticsCollector * Fix calculation of average coverage --------- Co-authored-by: Anya497 --- AIAgent/common/errors.py | 18 -- AIAgent/common/utils.py | 15 ++ .../connection/broker_conn/socket_manager.py | 23 ++- AIAgent/connection/errors_connection.py | 22 +++ .../connection/game_server_conn/connector.py | 28 +++- AIAgent/launch_servers.py | 11 +- AIAgent/ml/game/errors_game.py | 22 +++ AIAgent/ml/{ => game}/play_game.py | 29 +++- AIAgent/ml/training/epochs_statistics.py | 154 +++++++++++------- AIAgent/ml/training/validation.py | 21 +-- AIAgent/run_training.py | 51 +++--- 11 files changed, 253 insertions(+), 141 deletions(-) delete mode 100644 AIAgent/common/errors.py create mode 100644 AIAgent/common/utils.py create mode 100644 AIAgent/connection/errors_connection.py create mode 100644 AIAgent/ml/game/errors_game.py rename AIAgent/ml/{ => game}/play_game.py (84%) diff --git a/AIAgent/common/errors.py b/AIAgent/common/errors.py deleted file mode 100644 index e0fdde52..00000000 --- a/AIAgent/common/errors.py +++ /dev/null @@ -1,18 +0,0 @@ -from common.game import GameMap, GameMap2SVM - - -class GameErrors(ExceptionGroup): - def __new__(cls, errors: list[Exception], maps: list[GameMap]): - self = super().__new__(GameErrors, "There are failed or timeouted maps", errors) - self.maps = maps - return self - - def derive(self, excs): - return GameErrors(self.message, excs) - - -class GameError(Exception): - def __init__(self, game_map2svm: GameMap2SVM, error: Exception) -> None: - self._map = game_map2svm - self._error = error - super().__init__(self._error.args) diff --git a/AIAgent/common/utils.py b/AIAgent/common/utils.py new file mode 100644 index 00000000..8da88331 --- /dev/null +++ b/AIAgent/common/utils.py @@ -0,0 +1,15 @@ +from typing import TypeVar + +T = TypeVar("T") + + +def inheritors(cls: T) -> set[T]: + subclasses: set[T] = set() + work = [cls] + while work: + parent = work.pop() + for child in parent.__subclasses__(): + if child not in subclasses: + subclasses.add(child) + work.append(child) + return subclasses diff --git a/AIAgent/connection/broker_conn/socket_manager.py b/AIAgent/connection/broker_conn/socket_manager.py index 4c1f228e..4937532a 100644 --- a/AIAgent/connection/broker_conn/socket_manager.py +++ b/AIAgent/connection/broker_conn/socket_manager.py @@ -2,10 +2,19 @@ import time from contextlib import contextmanager, suppress +import psutil import websocket from config import GameServerConnectorConfig from connection.broker_conn.classes import ServerInstanceInfo, SVMInfo from connection.broker_conn.requests import acquire_instance, return_instance +from connection.errors_connection import ProcessStoppedError + + +@contextmanager +def 
process_running(pid): + if not psutil.pid_exists(pid): + raise ProcessStoppedError + yield def wait_for_connection(server_instance: ServerInstanceInfo): @@ -18,7 +27,7 @@ def wait_for_connection(server_instance: ServerInstanceInfo): ConnectionRefusedError, ConnectionResetError, websocket.WebSocketTimeoutException, - ): + ), process_running(server_instance.pid): ws.settimeout(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT_SEC) ws.connect( server_instance.ws_url, @@ -39,12 +48,12 @@ def wait_for_connection(server_instance: ServerInstanceInfo): @contextmanager def game_server_socket_manager(svm_info: SVMInfo): server_instance = acquire_instance(svm_info) - - socket = wait_for_connection(server_instance) - try: - socket.settimeout(GameServerConnectorConfig.RESPONCE_TIMEOUT_SEC) - yield socket + socket = wait_for_connection(server_instance) + try: + socket.settimeout(GameServerConnectorConfig.RESPONCE_TIMEOUT_SEC) + yield socket + finally: + socket.close() finally: - socket.close() return_instance(server_instance) diff --git a/AIAgent/connection/errors_connection.py b/AIAgent/connection/errors_connection.py new file mode 100644 index 00000000..0c83b6ce --- /dev/null +++ b/AIAgent/connection/errors_connection.py @@ -0,0 +1,22 @@ +from abc import ABC, abstractmethod + + +class GameInterruptedError(Exception, ABC): + """Game was unexpectedly interrupted due to external reasons""" + + @property + @abstractmethod + def desc(self): + pass + + +class ProcessStoppedError(GameInterruptedError): + """SVM's process unexpectedly stopped""" + + desc = "SVM's process unexpectedly stopped" + + +class ConnectionLostError(GameInterruptedError): + """Connection to SVM was lost""" + + desc = "Connection to SVM was lost" diff --git a/AIAgent/connection/game_server_conn/connector.py b/AIAgent/connection/game_server_conn/connector.py index 44332a9f..2fc49596 100644 --- a/AIAgent/connection/game_server_conn/connector.py +++ b/AIAgent/connection/game_server_conn/connector.py @@ -1,8 +1,10 @@ +from functools import wraps import logging import logging.config from typing import Optional import websocket +from connection.errors_connection import ConnectionLostError from common.game import GameMap, GameState from .messages import ( @@ -54,12 +56,30 @@ def __init__( start_message = ClientMessage(StartMessageBody(**map.to_dict())) logging.debug(f"--> StartMessage : {start_message}") - self.ws.send(start_message.to_json()) + self.send(start_message.to_json()) self._current_step = 0 self.game_is_over = False self.map = map self.steps = steps + def catch_losing_of_connection(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except ConnectionResetError as e: + raise ConnectionLostError from e + + return wrapper + + @catch_losing_of_connection + def receive(self): + return self.ws.recv() + + @catch_losing_of_connection + def send(self, msg): + return self.ws.send(msg) + def _raise_if_gameover(self, msg) -> GameOverServerMessage | str: if self.game_is_over: raise Connector.GameOver @@ -83,7 +103,7 @@ def _raise_if_gameover(self, msg) -> GameOverServerMessage | str: return msg def recv_state_or_throw_gameover(self) -> GameState: - received = self.ws.recv() + received = self.receive() data = GameStateServerMessage.from_json_handle( self._raise_if_gameover(received), expected=GameStateServerMessage, @@ -98,11 +118,11 @@ def send_step(self, next_state_id: int, predicted_usefullness: int): ) ) logging.debug(f"--> ClientMessage : {do_step_message}") - 
self.ws.send(do_step_message.to_json()) + self.send(do_step_message.to_json()) self._sent_state_id = next_state_id def recv_reward_or_throw_gameover(self) -> Reward: - received = self.ws.recv() + received = self.receive() decoded = RewardServerMessage.from_json_handle( self._raise_if_gameover(received), expected=RewardServerMessage, diff --git a/AIAgent/launch_servers.py b/AIAgent/launch_servers.py index afdc7d10..e855c919 100644 --- a/AIAgent/launch_servers.py +++ b/AIAgent/launch_servers.py @@ -175,10 +175,15 @@ async def run(): def kill_server(server_instance: ServerInstanceInfo): - os.kill(server_instance.pid, signal.SIGKILL) PROCS.remove(server_instance.pid) - - proc_info = psutil.Process(server_instance.pid) + try: + os.kill(server_instance.pid, signal.SIGKILL) + proc_info = psutil.Process(server_instance.pid) + except (ProcessLookupError, psutil.NoSuchProcess): + logging.warning( + f"Failed to kill the process with ID={server_instance.pid}: the process doesn't exist" + ) + return wait_for_reset_retries = FeatureConfig.ON_GAME_SERVER_RESTART.wait_for_reset_retries while wait_for_reset_retries: diff --git a/AIAgent/ml/game/errors_game.py b/AIAgent/ml/game/errors_game.py new file mode 100644 index 00000000..0178fd03 --- /dev/null +++ b/AIAgent/ml/game/errors_game.py @@ -0,0 +1,22 @@ +from func_timeout import FunctionTimedOut +from connection.errors_connection import GameInterruptedError +from common.utils import inheritors +from common.game import GameMap2SVM + + +class GameError(Exception): + + def __init__( + self, + game_map2svm: GameMap2SVM, + error_name: str, + ) -> None: + self._map = game_map2svm + self._error_name = error_name + + super().__init__(game_map2svm, error_name) + + def need_to_save_map(self): + gie_inheritors = inheritors(GameInterruptedError) + need_to_save_classes = list(gie_inheritors) + [FunctionTimedOut] + return self._error_name in map(lambda it: it.__name__, need_to_save_classes) diff --git a/AIAgent/ml/play_game.py b/AIAgent/ml/game/play_game.py similarity index 84% rename from AIAgent/ml/play_game.py rename to AIAgent/ml/game/play_game.py index 2afb876b..544b2154 100644 --- a/AIAgent/ml/play_game.py +++ b/AIAgent/ml/game/play_game.py @@ -3,15 +3,16 @@ import traceback from typing import TypeAlias -from common.errors import GameError from common.classes import GameResult, Map2Result from common.game import GameState, GameMap2SVM from config import FeatureConfig +from connection.errors_connection import GameInterruptedError from connection.broker_conn.socket_manager import game_server_socket_manager from connection.game_server_conn.connector import Connector from func_timeout import FunctionTimedOut, func_set_timeout from ml.protocols import Predictor from ml.training.dataset import Result, TrainingDataset, convert_input_to_tensor +from ml.game.errors_game import GameError TimeDuration: TypeAlias = float @@ -165,14 +166,28 @@ def play_game( ) map2result = Map2Result(game_map2svm, game_result) except (FunctionTimedOut, Exception) as error: + need_to_save = True + name_of_predictor = with_predictor.name() + if isinstance(error, FunctionTimedOut): - log_message = f"<{with_predictor.name()}> timeouted on map {game_map2svm.GameMap.MapName} with {error.timedOutAfter}s" + log_message = f"<{name_of_predictor}> timeouted on map {game_map2svm.GameMap.MapName} with {error.timedOutAfter}s" + elif isinstance(error, GameInterruptedError): + log_message = f"<{name_of_predictor}> failed on map {game_map2svm.GameMap.MapName} with {error.__class__.__name__}: {error.desc}" + 
need_to_save = False else: - log_message = f"<{with_predictor.name()}> failed on map {game_map2svm.GameMap.MapName}:\n{traceback.format_exc()}" + log_message = ( + f"<{name_of_predictor}> failed on map {game_map2svm.GameMap.MapName}:\n" + + "\n".join( + traceback.format_exception( + type(error), value=error, tb=error.__traceback__ + ) + ) + ) logging.warning(log_message) - FeatureConfig.SAVE_IF_FAIL_OR_TIMEOUT.save_model( - with_predictor.model(), with_name=f"{with_predictor.name()}" - ) - raise GameError(game_map2svm, error) + if need_to_save: + FeatureConfig.SAVE_IF_FAIL_OR_TIMEOUT.save_model( + with_predictor.model(), with_name=name_of_predictor + ) + raise GameError(game_map2svm=game_map2svm, error_name=error.__class__.__name__) return map2result diff --git a/AIAgent/ml/training/epochs_statistics.py b/AIAgent/ml/training/epochs_statistics.py index ef3fa297..0c0128bb 100644 --- a/AIAgent/ml/training/epochs_statistics.py +++ b/AIAgent/ml/training/epochs_statistics.py @@ -1,13 +1,14 @@ from dataclasses import dataclass from functools import wraps +import multiprocessing from pathlib import Path -from statistics import mean from typing import TypeAlias import natsort +import numpy as np import pandas as pd from common.typealias import SVMName -from common.game import GameMap, GameMap2SVM +from common.game import GameMap2SVM from common.classes import Map2Result EpochNumber: TypeAlias = int @@ -17,22 +18,24 @@ def sort_dict(d): return dict(natsort.natsorted(d.items())) +def avg_by_attr(results, path_to_coverage: str) -> int: + coverage = np.average( + list(map(lambda result: getattr(result, path_to_coverage), results)) + ) + return coverage + + @dataclass class StatsWithTable: avg: float df: pd.DataFrame -class Status: - def __init__(self): - self.epoch: EpochNumber = 0 - self.failed_maps: list[GameMap2SVM] = [] - self.count_of_failed_maps: int = 0 +class FailedMaps(list[GameMap2SVM]): + lock = multiprocessing.Lock() def __str__(self) -> str: - result = ( - f"count of failed maps={self.count_of_failed_maps}, on epoch = {self.epoch}" - ) + result = f"count of failed maps = {len(self)}" return result @@ -42,10 +45,10 @@ def __init__( file: Path, ): self._file = file + self.lock = multiprocessing.Lock() - self._epochs_info: dict[EpochNumber, dict[SVMName, StatsWithTable]] = {} - self._epoch_number: EpochNumber = 0 - self._svms_status: dict[SVMName, Status] = {} + self._svms_stats_dict: dict[SVMName, StatsWithTable] = {} + self._failed_maps_dict: dict[SVMName, FailedMaps] = {} def update_file(func): @wraps(func) @@ -56,81 +59,108 @@ def wrapper(self, *args, **kwargs): return wrapper - @update_file - def fail(self, game_map2svm: GameMap2SVM): - svm_name = game_map2svm.SVMInfo.name - svm_status: Status = self._svms_status.get(svm_name, Status()) - svm_status.failed_maps.append(game_map2svm) - svm_status.count_of_failed_maps += 1 - self._svms_status[svm_name] = svm_status + def fail(self, game_map: GameMap2SVM): + svm_name = game_map.SVMInfo.name + with self.lock: + failed_maps = self._failed_maps_dict.setdefault(svm_name, FailedMaps()) + with failed_maps.lock: + failed_maps.append(game_map) def __clear_failed_maps(self): - for svm_status in self._svms_status.values(): - svm_status.failed_maps.clear() + for failed_maps in self._failed_maps_dict.values(): + failed_maps.clear() def get_failed_maps(self) -> list[GameMap2SVM]: """ - Returns failed maps on total epoch. + Returns failed maps. NB: The list of failed maps is cleared after each request. 
""" total_failed_maps: list[GameMap2SVM] = [] - for svm_status in self._svms_status.values(): - total_failed_maps.extend(svm_status.failed_maps) + for failed_maps in self._failed_maps_dict.values(): + total_failed_maps.extend(failed_maps) self.__clear_failed_maps() return total_failed_maps @update_file def update_results( self, - average_result: float, map2results_list: list[Map2Result], ): - epoch_info = self._epochs_info.get(self._epoch_number, {}) - svms_and_map2results_lists: dict[SVMName, list[Map2Result]] = dict() - for map2result in map2results_list: - svm_name = map2result.map.SVMInfo.name - map2results_list_of_svm = svms_and_map2results_lists.get(svm_name, []) - map2results_list_of_svm.append(map2result) - svms_and_map2results_lists[svm_name] = map2results_list_of_svm - for svm_name, map2results_list in svms_and_map2results_lists.items(): - epoch_info[svm_name] = StatsWithTable( - average_result, convert_to_df(map2results_list) - ) - self._epochs_info[self._epoch_number] = sort_dict(epoch_info) - - self._epoch_number += 1 - for svm_name in svms_and_map2results_lists: - svm_status = self._svms_status.get(svm_name, Status()) - svm_status.epoch = self._epoch_number - self._svms_status[svm_name] = svm_status - - def __get_epochs_results(self) -> str: - epochs_results = str() - for _, v in self._epochs_info.items(): - avgs = list(map(lambda statsWithTable: statsWithTable.avg, v.values())) - avg_common = mean(avgs) - epoch_results = list( - map(lambda statsWithTable: statsWithTable.df, v.values()) + def generate_svms_results_dict() -> dict[SVMName, list[Map2Result]]: + map2results_dict: dict[SVMName, list[Map2Result]] = dict() + for map2result in map2results_list: + svm_name = map2result.map.SVMInfo.name + map2results_list_of_svm = map2results_dict.get(svm_name, []) + map2results_list_of_svm.append(map2result) + map2results_dict[svm_name] = map2results_list_of_svm + return map2results_dict + + def generate_svms_stats_dict( + svms_and_map2results: dict[SVMName, list[Map2Result]] + ): + svms_stats_dict: dict[SVMName, list[StatsWithTable]] = dict() + for svm_name, map2results_list in svms_and_map2results.items(): + svms_stats_dict[svm_name] = StatsWithTable( + avg_by_attr( + list( + map( + lambda map2result: map2result.game_result, + map2results_list, + ) + ), + "actual_coverage_percent", + ), + convert_to_df(map2results_list), + ) + return svms_stats_dict + + def calc_avg(): + return avg_by_attr( + list( + map( + lambda map2result: map2result.game_result, + map2results_list, + ) + ), + "actual_coverage_percent", ) - df = pd.concat(epoch_results, axis=1) - epochs_results += f"Average coverage: {str(avg_common)}\n" - names_and_averages = zip(v.keys(), avgs) - epochs_results += "".join( + + svms_and_map2results_lists = generate_svms_results_dict() + svms_stats_dict = generate_svms_stats_dict(svms_and_map2results_lists) + + self._svms_stats_dict = sort_dict(svms_stats_dict) + self.avg_coverage = calc_avg() + + def __get_results(self) -> str: + svms_stats = self._svms_stats_dict.items() + _, svms_stats_with_table = list(zip(*svms_stats)) + + df_concat = pd.concat( + list( + map(lambda stats_with_table: stats_with_table.df, svms_stats_with_table) + ), + axis=1, + ) + + results = ( + f"Average coverage: {str(self.avg_coverage)}\n" + + "".join( list( map( - lambda pair: f"Average coverage of {pair[0]} = {pair[1]}\n", - names_and_averages, + lambda svm_name_and_stats_pair: f"Average coverage of {svm_name_and_stats_pair[0]} = {svm_name_and_stats_pair[1].avg}, 
{self._failed_maps_dict.get(svm_name_and_stats_pair[0], FailedMaps())}\n", + svms_stats, ) ) ) - epochs_results += df.to_markdown(tablefmt="psql") + "\n" - return epochs_results + + df_concat.to_markdown(tablefmt="psql") + ) + return results def __update_file(self): - epochs_results = self.__get_epochs_results() + results = self.__get_results() with open(self._file, "w") as f: - f.write(epochs_results) + f.write(results) def convert_to_df(map2result_list: list[Map2Result]) -> pd.DataFrame: diff --git a/AIAgent/ml/training/validation.py b/AIAgent/ml/training/validation.py index 156e1e5e..a8a3913e 100644 --- a/AIAgent/ml/training/validation.py +++ b/AIAgent/ml/training/validation.py @@ -6,14 +6,14 @@ import torch import tqdm import mlflow -from ml.training.epochs_statistics import StatisticsCollector from common.classes import Map2Result, GameMap2SVM -from common.errors import GameError from config import GeneralConfig +from ml.game.errors_game import GameError from ml.inference import infer -from ml.play_game import play_game +from ml.game.play_game import play_game from ml.training.dataset import TrainingDataset from ml.training.wrapper import TrainingModelWrapper +from ml.training.epochs_statistics import StatisticsCollector, avg_by_attr from torch_geometric.loader import DataLoader from paths import CURRENT_TABLE_PATH @@ -80,24 +80,21 @@ def validate_coverage( colour=progress_bar_colour, ): if isinstance(result, GameError): - statistics_collector.fail(result._map) + need_to_save_map: bool = result.need_to_save_map() + if not need_to_save_map: + statistics_collector.fail(result._map) else: all_results.append(result) - def avg_coverage(results, path_to_coverage: str) -> int: - coverage = np.average( - list(map(lambda result: getattr(result, path_to_coverage), results)) - ) - return coverage + statistics_collector.update_results(all_results) - average_result = avg_coverage( + average_result = avg_by_attr( list(map(lambda map2result: map2result.game_result, all_results)), "actual_coverage_percent", ) - statistics_collector.update_results(average_result, all_results) mlflow.log_metrics( { - "average_dataset_state_result": avg_coverage( + "average_dataset_state_result": avg_by_attr( dataset.maps_results.values(), "coverage_percent" ), "average_result": average_result, diff --git a/AIAgent/run_training.py b/AIAgent/run_training.py index d0992e23..39173f03 100644 --- a/AIAgent/run_training.py +++ b/AIAgent/run_training.py @@ -129,35 +129,30 @@ def model_init(**model_params) -> nn.Module: epochs=training_config.epochs, val_config=validation_config, ) - try: - if optuna_config.study_uri is None and weights_uri is None: - - def save_study(study, _): - joblib.dump( - study, - CURRENT_STUDY_PATH, - ) - with mlflow.start_run(mlflow.last_active_run().info.run_id): - mlflow.log_artifact(CURRENT_STUDY_PATH) - - study = optuna.create_study( - sampler=sampler, direction=optuna_config.study_direction - ) - study.optimize( - objective_partial, - n_trials=optuna_config.n_trials, - gc_after_trial=True, - n_jobs=optuna_config.n_jobs, - callbacks=[save_study], - ) - else: - downloaded_artifact_path = mlflow.artifacts.download_artifacts( - optuna_config.study_uri, dst_path=str(REPORT_PATH) + if optuna_config.study_uri is None and weights_uri is None: + def save_study(study, _): + joblib.dump( + study, + CURRENT_STUDY_PATH, ) - study: optuna.Study = joblib.load(downloaded_artifact_path) - objective_partial(study.best_trial) - except RuntimeError: - logging.error("Fail to train") + with 
mlflow.start_run(mlflow.last_active_run().info.run_id): + mlflow.log_artifact(CURRENT_STUDY_PATH) + study = optuna.create_study( + sampler=sampler, direction=optuna_config.study_direction + ) + study.optimize( + objective_partial, + n_trials=optuna_config.n_trials, + gc_after_trial=True, + n_jobs=optuna_config.n_jobs, + callbacks=[save_study], + ) + else: + downloaded_artifact_path = mlflow.artifacts.download_artifacts( + optuna_config.study_uri, dst_path=str(REPORT_PATH) + ) + study: optuna.Study = joblib.load(downloaded_artifact_path) + objective_partial(study.best_trial) def objective(
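
For context on the new liveness guard in socket_manager.py: process_running raises ProcessStoppedError before yielding when the SVM process no longer exists, and wait_for_connection stacks it with contextlib.suppress so ordinary connection errors keep the retry loop alive while a dead process aborts it immediately. Below is a minimal, self-contained sketch of that composition; the try_connect_once wrapper and the local ProcessStoppedError stand-in are illustrative, not part of the patch.

from contextlib import contextmanager, suppress

import psutil


class ProcessStoppedError(Exception):
    """Stand-in for connection.errors_connection.ProcessStoppedError."""


@contextmanager
def process_running(pid: int):
    # Same shape as the patched helper: fail fast if the process is gone,
    # otherwise let the wrapped block run.
    if not psutil.pid_exists(pid):
        raise ProcessStoppedError
    yield


def try_connect_once(pid: int, connect):
    # Connection-level errors are swallowed so the caller can retry,
    # while a vanished server process escapes as ProcessStoppedError.
    with suppress(ConnectionRefusedError, ConnectionResetError), process_running(pid):
        return connect()
    return None  # a suppressed connection error means: retry later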
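
The Connector change routes all ws.send/ws.recv calls through catch_losing_of_connection, which re-raises ConnectionResetError as the domain-level ConnectionLostError while keeping the original exception as __cause__. A reduced sketch of the same pattern follows; the FakeSocket used for the demonstration is hypothetical, and the ConnectionLostError here is a plain stand-in for the abstract hierarchy in errors_connection.py.

from functools import wraps


class ConnectionLostError(Exception):
    """Stand-in for connection.errors_connection.ConnectionLostError."""


def catch_losing_of_connection(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ConnectionResetError as e:
            # Preserve the low-level error as the cause of the domain error.
            raise ConnectionLostError from e

    return wrapper


class FakeSocket:
    def recv(self):
        raise ConnectionResetError("peer closed the connection")


@catch_losing_of_connection
def receive(ws: FakeSocket):
    return ws.recv()


if __name__ == "__main__":
    try:
        receive(FakeSocket())
    except ConnectionLostError as err:
        print(f"caught {type(err).__name__}, caused by {type(err.__cause__).__name__}")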
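
GameError.need_to_save_map decides whether a failed map is kept by comparing the stored exception name against every subclass of GameInterruptedError plus FunctionTimedOut, using the transitive inheritors helper from common/utils.py; validation.py only records a map via statistics_collector.fail when this check returns False. A small sketch of that name-based check, with placeholder exception classes instead of func_timeout/GameMap2SVM and the method flattened into a free function for brevity:

def inheritors(cls):
    # Transitive walk over __subclasses__, as in common/utils.py.
    subclasses = set()
    work = [cls]
    while work:
        parent = work.pop()
        for child in parent.__subclasses__():
            if child not in subclasses:
                subclasses.add(child)
                work.append(child)
    return subclasses


class GameInterruptedError(Exception):
    """Placeholder for the abstract hierarchy in errors_connection.py."""


class ProcessStoppedError(GameInterruptedError):
    pass


class FunctionTimedOut(Exception):
    """Placeholder for func_timeout.FunctionTimedOut."""


def need_to_save_map(error_name: str) -> bool:
    # Same rule as GameError.need_to_save_map: the stored name must match
    # a GameInterruptedError subclass or FunctionTimedOut.
    saved = inheritors(GameInterruptedError) | {FunctionTimedOut}
    return error_name in {cls.__name__ for cls in saved}


assert need_to_save_map("ProcessStoppedError")
assert need_to_save_map("FunctionTimedOut")
assert not need_to_save_map("ValueError")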