diff --git a/AIAgent/common/errors.py b/AIAgent/common/errors.py index e6214b1..e0fdde5 100644 --- a/AIAgent/common/errors.py +++ b/AIAgent/common/errors.py @@ -1,4 +1,4 @@ -from common.game import GameMap +from common.game import GameMap, GameMap2SVM class GameErrors(ExceptionGroup): @@ -9,3 +9,10 @@ def __new__(cls, errors: list[Exception], maps: list[GameMap]): def derive(self, excs): return GameErrors(self.message, excs) + + +class GameError(Exception): + def __init__(self, game_map2svm: GameMap2SVM, error: Exception) -> None: + self._map = game_map2svm + self._error = error + super().__init__(self._error.args) diff --git a/AIAgent/connection/broker_conn/requests.py b/AIAgent/connection/broker_conn/requests.py index a5aaab0..13af198 100644 --- a/AIAgent/connection/broker_conn/requests.py +++ b/AIAgent/connection/broker_conn/requests.py @@ -12,7 +12,9 @@ def acquire_instance(svm_info: SVMInfo) -> ServerInstanceInfo: WebsocketSourceLinks.GET_WS + "?" + urlencode(SVMInfo.to_dict(svm_info)) ) if response.status != 200: - logging.error(f"{response.status} with {content=} on acquire_instance call") + logging.error( + f"{response.status} with {content=} on acquire_instance call for {svm_info}" + ) raise RuntimeError(f"Not ok response: {response}, {content}") acquired_instance = ServerInstanceInfo.from_json( json.loads(content.decode("utf-8")) diff --git a/AIAgent/ml/play_game.py b/AIAgent/ml/play_game.py index f1e9e84..2afb876 100644 --- a/AIAgent/ml/play_game.py +++ b/AIAgent/ml/play_game.py @@ -3,9 +3,9 @@ import traceback from typing import TypeAlias -from common.errors import GameErrors +from common.errors import GameError from common.classes import GameResult, Map2Result -from common.game import GameMap, GameState, GameMap2SVM +from common.game import GameState, GameMap2SVM from config import FeatureConfig from connection.broker_conn.socket_manager import game_server_socket_manager from connection.game_server_conn.connector import Connector @@ -142,43 +142,37 @@ def play_map_with_timeout( def play_game( with_predictor: Predictor, max_steps: int, - maps: list[GameMap2SVM], + game_map2svm: GameMap2SVM, with_dataset: TrainingDataset, ): - list_of_map2result: list[Map2Result] = [] - list_of_failed_maps_with_errors: list[tuple[GameMap, Exception]] = [] - for game_map2svm in maps: - logging.info(f"<{with_predictor.name()}> is playing {game_map2svm.GameMap.MapName}") - try: - play_func = ( - play_map_with_timeout - if FeatureConfig.SAVE_IF_FAIL_OR_TIMEOUT.enabled - else play_map - ) - with game_server_socket_manager(game_map2svm.SVMInfo) as ws: - game_result, time = play_func( - with_connector=Connector(ws, game_map2svm.GameMap, max_steps), - with_predictor=with_predictor, - with_dataset=with_dataset, - ) - logging.info( - f"<{with_predictor.name()}> finished map {game_map2svm.GameMap.MapName} " - f"in {game_result.steps_count} steps, {time} seconds, " - f"actual coverage: {game_result.actual_coverage_percent:.2f}" + logging.info(f"<{with_predictor.name()}> is playing {game_map2svm.GameMap.MapName}") + try: + play_func = ( + play_map_with_timeout + if FeatureConfig.SAVE_IF_FAIL_OR_TIMEOUT.enabled + else play_map + ) + with game_server_socket_manager(game_map2svm.SVMInfo) as ws: + game_result, time = play_func( + with_connector=Connector(ws, game_map2svm.GameMap, max_steps), + with_predictor=with_predictor, + with_dataset=with_dataset, ) - list_of_map2result.append(Map2Result(game_map2svm.GameMap, game_result)) - except (FunctionTimedOut, Exception) as error: - if isinstance(error, FunctionTimedOut): - log_message = f"<{with_predictor.name()}> timeouted on map {game_map2svm.GameMap.MapName} with {error.timedOutAfter}s" - else: - log_message = f"<{with_predictor.name()}> failed on map {game_map2svm.GameMap.MapName}:\n{traceback.format_exc()}" - logging.warning(log_message) - list_of_failed_maps_with_errors.append((game_map2svm.GameMap, error)) - if list_of_failed_maps_with_errors: + logging.info( + f"<{with_predictor.name()}> finished map {game_map2svm.GameMap.MapName} " + f"in {game_result.steps_count} steps, {time} seconds, " + f"actual coverage: {game_result.actual_coverage_percent:.2f}" + ) + map2result = Map2Result(game_map2svm, game_result) + except (FunctionTimedOut, Exception) as error: + if isinstance(error, FunctionTimedOut): + log_message = f"<{with_predictor.name()}> timeouted on map {game_map2svm.GameMap.MapName} with {error.timedOutAfter}s" + else: + log_message = f"<{with_predictor.name()}> failed on map {game_map2svm.GameMap.MapName}:\n{traceback.format_exc()}" + logging.warning(log_message) FeatureConfig.SAVE_IF_FAIL_OR_TIMEOUT.save_model( with_predictor.model(), with_name=f"{with_predictor.name()}" ) - failed_maps, errors = list(zip(*list_of_failed_maps_with_errors)) - raise GameErrors(errors, failed_maps) + raise GameError(game_map2svm, error) - return list_of_map2result + return map2result diff --git a/AIAgent/ml/training/epochs_statistics.py b/AIAgent/ml/training/epochs_statistics.py index 8c2d4b9..ef3fa29 100644 --- a/AIAgent/ml/training/epochs_statistics.py +++ b/AIAgent/ml/training/epochs_statistics.py @@ -57,13 +57,12 @@ def wrapper(self, *args, **kwargs): return wrapper @update_file - def fail(self, game_maps: list[GameMap2SVM]): - for game_map2svm in game_maps: - svm_name = game_map2svm.SVMInfo.name - svm_status: Status = self._svms_status.get(svm_name, Status()) - svm_status.failed_maps.append(game_map2svm) - svm_status.count_of_failed_maps += 1 - self._svms_status[svm_name] = svm_status + def fail(self, game_map2svm: GameMap2SVM): + svm_name = game_map2svm.SVMInfo.name + svm_status: Status = self._svms_status.get(svm_name, Status()) + svm_status.failed_maps.append(game_map2svm) + svm_status.count_of_failed_maps += 1 + self._svms_status[svm_name] = svm_status def __clear_failed_maps(self): for svm_status in self._svms_status.values(): diff --git a/AIAgent/ml/training/validation.py b/AIAgent/ml/training/validation.py index c107e6f..156e1e5 100644 --- a/AIAgent/ml/training/validation.py +++ b/AIAgent/ml/training/validation.py @@ -7,8 +7,8 @@ import tqdm import mlflow from ml.training.epochs_statistics import StatisticsCollector -from common.classes import Map2Result -from common.errors import GameErrors +from common.classes import Map2Result, GameMap2SVM +from common.errors import GameError from config import GeneralConfig from ml.inference import infer from ml.play_game import play_game @@ -30,13 +30,14 @@ def wrapper(*args, **kwargs): @catch_return_exception -def play_game_task(task): - maps, dataset, wrapper = task - +def play_game_task( + task: tuple[GameMap2SVM, TrainingDataset, TrainingModelWrapper], +) -> Map2Result: + game_map2svm, dataset, wrapper = task result = play_game( with_predictor=wrapper, max_steps=GeneralConfig.MAX_STEPS, - maps=maps, + game_map2svm=game_map2svm, with_dataset=dataset, ) torch.cuda.empty_cache() @@ -67,7 +68,7 @@ def validate_coverage( Your favorite colour for progress bar. """ wrapper = TrainingModelWrapper(model) - tasks = [([game_map2svm], dataset, wrapper) for game_map2svm in dataset.maps] + tasks = [(game_map2svm, dataset, wrapper) for game_map2svm in dataset.maps] statistics_collector = StatisticsCollector(CURRENT_TABLE_PATH) with mp.Pool(server_count) as p: all_results: list[Map2Result] = list() @@ -78,10 +79,10 @@ def validate_coverage( ncols=100, colour=progress_bar_colour, ): - if isinstance(result, GameErrors): - statistics_collector.fail(result.maps) + if isinstance(result, GameError): + statistics_collector.fail(result._map) else: - all_results.extend(result) + all_results.append(result) def avg_coverage(results, path_to_coverage: str) -> int: coverage = np.average( @@ -90,7 +91,7 @@ def avg_coverage(results, path_to_coverage: str) -> int: return coverage average_result = avg_coverage( - list(map(lambda map2result: map2result.game_result, result)), + list(map(lambda map2result: map2result.game_result, all_results)), "actual_coverage_percent", ) statistics_collector.update_results(average_result, all_results)