Skip to content

Commit

Permalink
Stabilize client-server communication
Browse files Browse the repository at this point in the history
  • Loading branch information
emnigma committed Jan 23, 2024
1 parent 4c14a75 commit 88e8dc9
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 8 deletions.
2 changes: 1 addition & 1 deletion AIAgent/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ class ResultsHandlerLinks:

# assuming we start from /VSharp/VSharp.ML.AIAgent
SERVER_WORKING_DIR = (
"GameServers/VSharp/VSharp.ML.GameServer.Runner/bin/Release/net7.0/"
"/Users/emax/Data/PySymGym/GameServers/VSharp/VSharp.ML.GameServer.Runner/bin/Debug/net7.0/"
)
3 changes: 2 additions & 1 deletion AIAgent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class GeneralConfig:
IMPORT_MODEL_INIT = ...
EXPORT_MODEL_INIT = ...
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DATASET_BASE_PATH = "/Users/emax/Data/PySymGym/maps/DotNet/Maps/ManuallyCollected/bin/Debug/net7.0"


class BrokerConfig:
Expand Down Expand Up @@ -73,7 +74,7 @@ class FeatureConfig:
enabled=True, save_path=Path("./report/epochs_tables/")
)
ON_GAME_SERVER_RESTART = OnGameServerRestartFeature(
enabled=False, wait_for_reset_retries=10 * 60, wait_for_reset_time=0.1
enabled=True, wait_for_reset_retries=10 * 60, wait_for_reset_time=0.1
)


Expand Down
8 changes: 6 additions & 2 deletions AIAgent/learning/play_game.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import logging
import os
from statistics import StatisticsError
from time import perf_counter
from typing import TypeAlias

import tqdm
from func_timeout import FunctionTimedOut, func_set_timeout

from common.classes import GameResult, Map2Result
from common.constants import TQDM_FORMAT_DICT
from common.game import GameMap
from common.utils import get_states
from config import FeatureConfig, GeneralConfig
from connection.broker_conn.socket_manager import game_server_socket_manager
from connection.game_server_conn.connector import Connector
from func_timeout import FunctionTimedOut, func_set_timeout
from learning.timer.resources_manager import manage_map_inference_times_array
from learning.timer.stats import compute_statistics
from learning.timer.utils import get_map_inference_times
Expand Down Expand Up @@ -146,6 +146,10 @@ def play_game(
) as pbar:
list_of_map2result: list[Map2Result] = []
for game_map in maps:
game_map.AssemblyFullName = os.path.join(
GeneralConfig.DATASET_BASE_PATH, game_map.AssemblyFullName
)

logging.info(f"<{with_predictor.name()}> is playing {game_map.MapName}")

try:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from torch_geometric.nn import Linear
from torch.nn.functional import softmax
from .model import StateModelEncoder


class StateModelEncoderLastLayer(StateModelEncoder):
def __init__(self, hidden_channels, out_channels):
super().__init__(hidden_channels, out_channels)
self.lin_last = Linear(out_channels, 1)

def forward(
self,
game_x,
state_x,
edge_index_v_v,
edge_type_v_v,
edge_index_history_v_s,
edge_attr_history_v_s,
edge_index_in_v_s,
edge_index_s_s,
):
return softmax(
self.lin_last(
super().forward(
game_x=game_x,
state_x=state_x,
edge_index_v_v=edge_index_v_v,
edge_type_v_v=edge_type_v_v,
edge_index_history_v_s=edge_index_history_v_s,
edge_attr_history_v_s=edge_attr_history_v_s,
edge_index_in_v_s=edge_index_in_v_s,
edge_index_s_s=edge_index_s_s,
)
),
dim=0,
)
6 changes: 2 additions & 4 deletions AIAgent/run_common_model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,19 @@
import os
import random
import typing as t
from dataclasses import asdict, dataclass
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from pathlib import Path

import joblib
import numpy as np
import optuna
import pandas as pd
import torch
import torch.nn as nn
import tqdm
from common.game import GameMap
from config import GeneralConfig
from connection.broker_conn.socket_manager import game_server_socket_manager
from epochs_statistics.tables import create_pivot_table, table_to_string
from learning.play_game import play_game
from ml.common_model.dataset import FullDataset
Expand All @@ -33,7 +31,7 @@
)
from ml.common_model.utils import csv2best_models, get_model
from ml.common_model.wrapper import BestModelsWrapper, CommonModelWrapper
from ml.models.RGCNEdgeTypeTAG2VerticesDouble.model_modified import (
from ml.models.RGCNEdgeTypeTAG3VerticesDoubleHistory2.model_modified import (
StateModelEncoderLastLayer,
)
from ml.models.StateGNNEncoderConvEdgeAttr.model_modified import (
Expand Down

0 comments on commit 88e8dc9

Please sign in to comment.