Skip to content

Commit

Permalink
Swap errors and steps number, refactor in merge_steps function, use new algorithm for parallel validation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Anya497 committed Jan 25, 2024
1 parent 9b9075d commit 9e68fc6
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 30 deletions.
4 changes: 3 additions & 1 deletion AIAgent/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,6 @@ class ResultsHandlerLinks:
BASE_NN_OUT_FEATURES_NUM = 8

# assuming we start from /VSharp/VSharp.ML.AIAgent
SERVER_WORKING_DIR = "../GameServers/VSharp/VSharp.ML.GameServer.Runner/bin/Release/net7.0/"
SERVER_WORKING_DIR = (
"../GameServers/VSharp/VSharp.ML.GameServer.Runner/bin/Release/net7.0/"
)
2 changes: 1 addition & 1 deletion AIAgent/learning/play_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def add_single_step(input, output):
map_result = (
model_result.actual_coverage_percent,
-model_result.tests_count,
model_result.errors_count,
-model_result.steps_count,
model_result.errors_count,
)
with_dataset.update(with_connector.map.MapName, map_result, map_steps)
return model_result, end_time - start_time
Expand Down
55 changes: 35 additions & 20 deletions AIAgent/ml/common_model/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import os
import numpy as np
import random

import tqdm
import logging
Expand Down Expand Up @@ -91,7 +92,11 @@ def get_plain_data(self, threshold: int = 100):
result = []
for map_result, map_steps in self.maps_data.values():
if map_result[0] >= threshold:
for step in map_steps:
if len(map_steps) > 2000:
selected_steps = random.sample(map_steps, 2000)
else:
selected_steps = map_steps
for step in selected_steps:
if step.use_for_train:
result.append(step)
return result
Expand Down Expand Up @@ -125,12 +130,15 @@ def update(
x.to("cpu")
filtered_map_steps = self.filter_map_steps(map_steps)
if map_name in self.maps_data.keys():
if self.maps_data[map_name][0] == map_result:
logging.info(
f"Steps on map {map_name} were merged with current steps with result {map_result}"
)
if self.maps_data[map_name][0] == map_result and map_result[0] == 100:
init_steps_num = len(self.maps_data[map_name][1])

filtered_map_steps = self.remove_similar_steps(filtered_map_steps)
self.merge_steps(filtered_map_steps, map_name)
new_steps_num = len(self.maps_data[map_name][1])
logging.info(
f"Steps on map {map_name} were merged with current steps with result {map_result}. {len(filtered_map_steps)} + {init_steps_num} -> {new_steps_num}. "
)
if self.maps_data[map_name][0] < map_result:
logging.info(
f"The model with result = {self.maps_data[map_name][0]} was replaced with the model with "
Expand Down Expand Up @@ -177,21 +185,28 @@ def flatten_and_sort_hetero_data(step: HeteroData):
old_steps = create_dict(self.maps_data[map_name][1])

for vertex_num in new_steps.keys():
flattened_old_steps = []
if vertex_num in old_steps.keys():
for old_step in old_steps[vertex_num]:
flattened_old_steps.append(flatten_and_sort_hetero_data(old_step))
for new_step in new_steps[vertex_num]:
if vertex_num in old_steps.keys():
for step_num, old_step in enumerate(old_steps[vertex_num]):
new_g_v, new_s_v = flatten_and_sort_hetero_data(new_step)
old_g_v, old_s_v = flatten_and_sort_hetero_data(old_step)

if np.array_equal(new_g_v, old_g_v) and np.array_equal(
new_s_v, old_s_v
):
y_true_sum = old_step.y_true + new_step.y_true
y_true_sum[y_true_sum != 0] = 1

old_step.y_true = y_true_sum / torch.sum(y_true_sum)
old_steps[vertex_num][step_num] = old_step
break
merged_steps.append(new_step)
new_g_v, new_s_v = flatten_and_sort_hetero_data(new_step)
should_add = True
for step_num, (old_g_v, old_s_v) in enumerate(flattened_old_steps):
if np.array_equal(new_g_v, old_g_v) and np.array_equal(
new_s_v, old_s_v
):
y_true_sum = (
old_steps[vertex_num][step_num].y_true + new_step.y_true
)
y_true_sum[y_true_sum != 0] = 1

old_steps[vertex_num][step_num].y_true = y_true_sum / torch.sum(
y_true_sum
)
should_add = False
break
if should_add:
merged_steps.append(new_step)
merged_steps.extend(sum(old_steps.values(), []))
self.maps_data[map_name] = (self.maps_data[map_name][0], merged_steps)
12 changes: 4 additions & 8 deletions AIAgent/run_common_model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,19 +145,16 @@ def train(trial: optuna.trial.Trial, dataset: FullDataset):

with game_server_socket_manager() as ws:
all_maps = get_maps(websocket=ws, type=MapsType.TRAIN)
maps = np.array_split(all_maps, GeneralConfig.SERVER_COUNT)
random.shuffle(maps)
tasks = [
(maps[i], FullDataset("", ""), cmwrapper)
for i in range(GeneralConfig.SERVER_COUNT)
([all_maps[i]], FullDataset("", ""), cmwrapper)
for i in range(len(all_maps))
]

mp.set_start_method("spawn", force=True)
# p = Pool(GeneralConfig.SERVER_COUNT)

all_average_results = []
for epoch in range(config.epochs):
data_list = dataset.get_plain_data()
data_list = dataset.get_plain_data(80)
data_loader = DataLoader(data_list, batch_size=config.batch_size, shuffle=True)
print("DataLoader size", len(data_loader))

Expand Down Expand Up @@ -194,7 +191,7 @@ def train(trial: optuna.trial.Trial, dataset: FullDataset):
cmwrapper.make_copy(str(epoch + 1))

with mp.Pool(GeneralConfig.SERVER_COUNT) as p:
result = list(p.map(play_game_task, tasks))
result = list(p.map(play_game_task, tasks, 1))

all_results = []
for maps_result, maps_data in result:
Expand Down Expand Up @@ -228,7 +225,6 @@ def train(trial: optuna.trial.Trial, dataset: FullDataset):
torch.save(model.state_dict(), Path(path_to_model))
del data_list
del data_loader
# p.close()

return max(all_average_results)

Expand Down

0 comments on commit 9e68fc6

Please sign in to comment.