Skip to content

Commit

Permalink
Play game with one map instead of list of one map. Clarify logs.
Browse files Browse the repository at this point in the history
  • Loading branch information
Anya497 committed Jul 24, 2024
1 parent 5eb2d55 commit 6ea4758
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 55 deletions.
9 changes: 8 additions & 1 deletion AIAgent/common/errors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from common.game import GameMap
from common.game import GameMap, GameMap2SVM


class GameErrors(ExceptionGroup):
Expand All @@ -9,3 +9,10 @@ def __new__(cls, errors: list[Exception], maps: list[GameMap]):

def derive(self, excs):
return GameErrors(self.message, excs)


class GameError(Exception):
def __init__(self, game_map2svm: GameMap2SVM, error: Exception) -> None:
self._map = game_map2svm
self._error = error
super().__init__(self._error.args)
4 changes: 3 additions & 1 deletion AIAgent/connection/broker_conn/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ def acquire_instance(svm_info: SVMInfo) -> ServerInstanceInfo:
WebsocketSourceLinks.GET_WS + "?" + urlencode(SVMInfo.to_dict(svm_info))
)
if response.status != 200:
logging.error(f"{response.status} with {content=} on acquire_instance call")
logging.error(
f"{response.status} with {content=} on acquire_instance call for {svm_info}"
)
raise RuntimeError(f"Not ok response: {response}, {content}")
acquired_instance = ServerInstanceInfo.from_json(
json.loads(content.decode("utf-8"))
Expand Down
64 changes: 29 additions & 35 deletions AIAgent/ml/play_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import traceback
from typing import TypeAlias

from common.errors import GameErrors
from common.errors import GameError
from common.classes import GameResult, Map2Result
from common.game import GameMap, GameState, GameMap2SVM
from common.game import GameState, GameMap2SVM
from config import FeatureConfig
from connection.broker_conn.socket_manager import game_server_socket_manager
from connection.game_server_conn.connector import Connector
Expand Down Expand Up @@ -142,43 +142,37 @@ def play_map_with_timeout(
def play_game(
with_predictor: Predictor,
max_steps: int,
maps: list[GameMap2SVM],
game_map2svm: GameMap2SVM,
with_dataset: TrainingDataset,
):
list_of_map2result: list[Map2Result] = []
list_of_failed_maps_with_errors: list[tuple[GameMap, Exception]] = []
for game_map2svm in maps:
logging.info(f"<{with_predictor.name()}> is playing {game_map2svm.GameMap.MapName}")
try:
play_func = (
play_map_with_timeout
if FeatureConfig.SAVE_IF_FAIL_OR_TIMEOUT.enabled
else play_map
)
with game_server_socket_manager(game_map2svm.SVMInfo) as ws:
game_result, time = play_func(
with_connector=Connector(ws, game_map2svm.GameMap, max_steps),
with_predictor=with_predictor,
with_dataset=with_dataset,
)
logging.info(
f"<{with_predictor.name()}> finished map {game_map2svm.GameMap.MapName} "
f"in {game_result.steps_count} steps, {time} seconds, "
f"actual coverage: {game_result.actual_coverage_percent:.2f}"
logging.info(f"<{with_predictor.name()}> is playing {game_map2svm.GameMap.MapName}")
try:
play_func = (
play_map_with_timeout
if FeatureConfig.SAVE_IF_FAIL_OR_TIMEOUT.enabled
else play_map
)
with game_server_socket_manager(game_map2svm.SVMInfo) as ws:
game_result, time = play_func(
with_connector=Connector(ws, game_map2svm.GameMap, max_steps),
with_predictor=with_predictor,
with_dataset=with_dataset,
)
list_of_map2result.append(Map2Result(game_map2svm.GameMap, game_result))
except (FunctionTimedOut, Exception) as error:
if isinstance(error, FunctionTimedOut):
log_message = f"<{with_predictor.name()}> timeouted on map {game_map2svm.GameMap.MapName} with {error.timedOutAfter}s"
else:
log_message = f"<{with_predictor.name()}> failed on map {game_map2svm.GameMap.MapName}:\n{traceback.format_exc()}"
logging.warning(log_message)
list_of_failed_maps_with_errors.append((game_map2svm.GameMap, error))
if list_of_failed_maps_with_errors:
logging.info(
f"<{with_predictor.name()}> finished map {game_map2svm.GameMap.MapName} "
f"in {game_result.steps_count} steps, {time} seconds, "
f"actual coverage: {game_result.actual_coverage_percent:.2f}"
)
map2result = Map2Result(game_map2svm, game_result)
except (FunctionTimedOut, Exception) as error:
if isinstance(error, FunctionTimedOut):
log_message = f"<{with_predictor.name()}> timeouted on map {game_map2svm.GameMap.MapName} with {error.timedOutAfter}s"
else:
log_message = f"<{with_predictor.name()}> failed on map {game_map2svm.GameMap.MapName}:\n{traceback.format_exc()}"
logging.warning(log_message)
FeatureConfig.SAVE_IF_FAIL_OR_TIMEOUT.save_model(
with_predictor.model(), with_name=f"{with_predictor.name()}"
)
failed_maps, errors = list(zip(*list_of_failed_maps_with_errors))
raise GameErrors(errors, failed_maps)
raise GameError(game_map2svm, error)

return list_of_map2result
return map2result
13 changes: 6 additions & 7 deletions AIAgent/ml/training/epochs_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,12 @@ def wrapper(self, *args, **kwargs):
return wrapper

@update_file
def fail(self, game_maps: list[GameMap2SVM]):
for game_map2svm in game_maps:
svm_name = game_map2svm.SVMInfo.name
svm_status: Status = self._svms_status.get(svm_name, Status())
svm_status.failed_maps.append(game_map2svm)
svm_status.count_of_failed_maps += 1
self._svms_status[svm_name] = svm_status
def fail(self, game_map2svm: GameMap2SVM):
svm_name = game_map2svm.SVMInfo.name
svm_status: Status = self._svms_status.get(svm_name, Status())
svm_status.failed_maps.append(game_map2svm)
svm_status.count_of_failed_maps += 1
self._svms_status[svm_name] = svm_status

def __clear_failed_maps(self):
for svm_status in self._svms_status.values():
Expand Down
23 changes: 12 additions & 11 deletions AIAgent/ml/training/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import tqdm
import mlflow
from ml.training.epochs_statistics import StatisticsCollector
from common.classes import Map2Result
from common.errors import GameErrors
from common.classes import Map2Result, GameMap2SVM
from common.errors import GameError
from config import GeneralConfig
from ml.inference import infer
from ml.play_game import play_game
Expand All @@ -30,13 +30,14 @@ def wrapper(*args, **kwargs):


@catch_return_exception
def play_game_task(task):
maps, dataset, wrapper = task

def play_game_task(
task: tuple[GameMap2SVM, TrainingDataset, TrainingModelWrapper],
) -> Map2Result:
game_map2svm, dataset, wrapper = task
result = play_game(
with_predictor=wrapper,
max_steps=GeneralConfig.MAX_STEPS,
maps=maps,
game_map2svm=game_map2svm,
with_dataset=dataset,
)
torch.cuda.empty_cache()
Expand Down Expand Up @@ -67,7 +68,7 @@ def validate_coverage(
Your favorite colour for progress bar.
"""
wrapper = TrainingModelWrapper(model)
tasks = [([game_map2svm], dataset, wrapper) for game_map2svm in dataset.maps]
tasks = [(game_map2svm, dataset, wrapper) for game_map2svm in dataset.maps]
statistics_collector = StatisticsCollector(CURRENT_TABLE_PATH)
with mp.Pool(server_count) as p:
all_results: list[Map2Result] = list()
Expand All @@ -78,10 +79,10 @@ def validate_coverage(
ncols=100,
colour=progress_bar_colour,
):
if isinstance(result, GameErrors):
statistics_collector.fail(result.maps)
if isinstance(result, GameError):
statistics_collector.fail(result._map)
else:
all_results.extend(result)
all_results.append(result)

def avg_coverage(results, path_to_coverage: str) -> int:
coverage = np.average(
Expand All @@ -90,7 +91,7 @@ def avg_coverage(results, path_to_coverage: str) -> int:
return coverage

average_result = avg_coverage(
list(map(lambda map2result: map2result.game_result, result)),
list(map(lambda map2result: map2result.game_result, all_results)),
"actual_coverage_percent",
)
statistics_collector.update_results(average_result, all_results)
Expand Down

0 comments on commit 6ea4758

Please sign in to comment.