
Merge steps #27

Closed
wants to merge 28 commits into from

Conversation

@Anya497 (Collaborator) commented Jan 25, 2024

  • Add merge_steps function to FullDataset class
  • New way to distribute maps between processes during validation
  • Swap errors and steps number in result tuple
  • New path to server working directory
  • Refactor data_loader and add parallelism to dataset processing

@Anya497 Anya497 force-pushed the merge_steps branch 2 times, most recently from e22e828 to 9e68fc6 Compare January 25, 2024 14:51
@gsvgit gsvgit requested review from gsvgit and emnigma January 25, 2024 14:57
@@ -91,7 +92,11 @@ def get_plain_data(self, threshold: int = 100):
result = []
for map_result, map_steps in self.maps_data.values():
if map_result[0] >= threshold:
for step in map_steps:
if len(map_steps) > 2000:
Collaborator:
The number of steps should be passed to the function as an input parameter.

Collaborator Author:

OK.

@@ -91,7 +92,11 @@ def get_plain_data(self, threshold: int = 100):
result = []
for map_result, map_steps in self.maps_data.values():
if map_result[0] >= threshold:
for step in map_steps:
if len(map_steps) > 2000:
selected_steps = random.sample(map_steps, 2000)
Collaborator:

Presumably the same constant here.

Collaborator Author:

Yes.

@@ -91,7 +92,11 @@ def get_plain_data(self, threshold: int = 100):
result = []
for map_result, map_steps in self.maps_data.values():
if map_result[0] >= threshold:
for step in map_steps:
if len(map_steps) > 2000:
selected_steps = random.sample(map_steps, 2000)
Collaborator:

And here: does it matter that it is specifically random.sample? It selects steps in random order, while if len(map_steps) > 2000 is false the steps stay in their original order. Won't the different ordering in the two cases affect training?

Owner:

It won't. Everything gets shuffled downstream anyway.
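For illustration, a minimal sketch of the per-map cap being discussed, with the step limit made a parameter as the reviewer requested (the helper name and signature are hypothetical, not the PR's actual code):

```python
import random

def sample_steps(map_steps, max_steps=2000):
    # Cap the number of steps taken from one map; random.sample picks
    # without replacement and returns the picks in random order.
    if len(map_steps) > max_steps:
        return random.sample(map_steps, max_steps)
    # Below the cap the original order is kept; as the thread notes,
    # this is fine because the data is shuffled later anyway.
    return map_steps

random.seed(0)
capped = sample_steps(list(range(5000)), max_steps=2000)
```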

new_steps_num = len(self.maps_data[map_name][1])
logging.info(
f"Steps on map {map_name} were merged with current steps with result {map_result}. {len(filtered_map_steps)} + {init_steps_num} -> {new_steps_num}. "
)
Collaborator:

similar steps removal is performed unconditionally in all of the possible cases except one. Maybe filter only once? If performance is critical, maybe do it lazily by declaring a lambda function?

Collaborator Author:

OK.
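The lazy variant the reviewer hints at can be sketched with a cached zero-argument closure, so the filter body runs at most once and only if some branch actually needs the result (names are illustrative; `dedupe` stands in for `remove_similar_steps`):

```python
from functools import cache

calls = []

def dedupe(steps):
    # Stand-in for remove_similar_steps; records each invocation.
    calls.append(1)
    return list(dict.fromkeys(steps))

def lazy_filtered(raw_steps):
    # @cache on a zero-argument closure: the expensive filter executes
    # at most once, no matter how many branches ask for the result.
    @cache
    def filtered():
        return dedupe(tuple(raw_steps))
    return filtered

get = lazy_filtered([1, 1, 2, 3, 3])
a = get()
b = get()  # served from the cache, dedupe is not called again
```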

if self.maps_data[map_name][0] == map_result and map_result[0] == 100:
init_steps_num = len(self.maps_data[map_name][1])

filtered_map_steps = self.remove_similar_steps(filtered_map_steps)
Collaborator:

let's inline this variable, re-using it serves no purpose. Same on line 147.

Collaborator Author:

I think this contradicts the earlier comment about computing remove_similar_steps only once.

break
if should_add:
merged_steps.append(new_step)
merged_steps.extend(sum(old_steps.values(), []))
Collaborator:

This extend looks odd. What is happening here?

Collaborator Author:

sum concatenates the lists into one, and extend appends all of it to merged_steps.
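A runnable illustration of that line on toy data, with `itertools.chain` shown as the usual alternative (summing lists is quadratic in the total length):

```python
from itertools import chain

old_steps = {"mapA": [1, 2], "mapB": [3], "mapC": []}
merged_steps = [0]

# sum(..., []) concatenates the per-map lists into one flat list,
# then extend appends every element to merged_steps.
merged_steps.extend(sum(old_steps.values(), []))

# Equivalent result, but linear instead of quadratic:
flat = list(chain.from_iterable(old_steps.values()))
```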


all_average_results = []
for epoch in range(config.epochs):
data_list = dataset.get_plain_data()
data_list = dataset.get_plain_data(80)
Collaborator:

A keyword is needed: a bare 80 is not linked to the parameter name, so it is impossible to interpret without reading get_plain_data.

Collaborator Author:

I'm going to rewrite the whole dataset heavily soon. I'll take this into account in the new version.

@@ -91,7 +92,11 @@ def get_plain_data(self, threshold: int = 100):
result = []
Collaborator @emnigma commented Jan 26, 2024:

  1. please rename the threshold parameter for better clarity: threshold of what?
  2. what does this function do? why is the data it returns "plain"? why do we need a threshold? maybe the function name should reflect that

Collaborator Author:

This function won't be needed in the next version of the dataset.

@@ -190,7 +191,7 @@ def train(trial: optuna.trial.Trial, dataset: FullDataset):
cmwrapper.make_copy(str(epoch + 1))

with mp.Pool(GeneralConfig.SERVER_COUNT) as p:
result = list(p.map(play_game_task, tasks))
result = list(p.map(play_game_task, tasks, 1))
Collaborator:

  1. constants should come with the keyword: chunksize=1
  2. why chunksize=1?

Collaborator Author:

OK.
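As a sketch of the suggested call shape (a thread pool is used here so the snippet is self-contained; the PR itself uses `mp.Pool`, which has the same `map` signature):

```python
from multiprocessing.dummy import Pool  # thread-based, same Pool API

def play_game_task(task):
    # Stand-in for the real per-map task function.
    return task * task

with Pool(2) as p:
    # chunksize=1 hands tasks out one at a time, which balances load
    # when per-task runtimes vary a lot (as they do across game maps);
    # spelling out the keyword makes the bare `1` self-describing.
    result = list(p.map(play_game_task, range(8), chunksize=1))
```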

tasks = [
(maps[i], FullDataset("", ""), cmwrapper)
for i in range(GeneralConfig.SERVER_COUNT)
([all_maps[i]], FullDataset("", ""), cmwrapper)
Collaborator:

tasks = [
            ([concrete_map], FullDataset("", ""), cmwrapper)
            for concrete_map in all_maps
        ]

?

Collaborator Author:

Yes, that's right.

@@ -47,4 +47,6 @@ class ResultsHandlerLinks:
BASE_NN_OUT_FEATURES_NUM = 8

# assuming we start from /VSharp/VSharp.ML.AIAgent
SERVER_WORKING_DIR = "../VSharp.ML.GameServer.Runner/bin/Release/net7.0/"
SERVER_WORKING_DIR = (
"../GameServers/VSharp/VSharp.ML.GameServer.Runner/bin/Release/net7.0/"
Collaborator @emnigma commented Jan 27, 2024:

pathlib.Path("...")?

Collaborator Author:

Yep.
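A minimal sketch of the reviewer's `pathlib` suggestion: `Path` composes segments portably and normalizes separators (the path below mirrors the diff above, but the exact constant shape is an assumption):

```python
from pathlib import Path

# assuming we start from /VSharp/VSharp.ML.AIAgent
SERVER_WORKING_DIR = (
    Path("..")
    / "GameServers"
    / "VSharp"
    / "VSharp.ML.GameServer.Runner"
    / "bin"
    / "Release"
    / "net7.0"
)
```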

.gitignore Outdated
@@ -162,3 +162,4 @@ cython_debug/

# MacOS specific
.DS_Store
AIAgent/report/
Collaborator:

the last line should be empty (the file should end with a newline)

@@ -3,4 +3,4 @@ repos:
rev: 23.12.1
hooks:
- id: black
language_version: python3.11
language_version: python3.10
Collaborator:

???
from the docs:

It is recommended to specify the latest version of Python supported by your project here

ROOT, "ml", "pretrained_models", "models_for_parallel_architecture"
)
PRETRAINED_MODEL_PATH = os.path.join("ml", "models")
RAW_FILES_PATH = os.path.join("report", "SerializedEpisodes")
Collaborator:

Is the SerializedEpisodes folder always in this location? Should it instead be passed as a command-line argument?


@dataclass(slots=True)
class Step:
Graph: TypeAlias = HeteroData
Collaborator:

Is that even allowed? :) What is the semantic meaning of this construct?

Comment on lines +238 to +236
f = open(
file_path
) # without resource manager in order to escape file descriptors leaks
Collaborator:

super counterintuitive: file resource managers exist precisely to prevent file descriptor leaks. What is the reason not to use one here?
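For reference, the conventional pattern the reviewer points at: `with` closes the descriptor even when the body raises, which prevents leaks rather than causing them:

```python
import os
import tempfile

fd, path = tempfile.mkstemp()
os.close(fd)  # close the low-level descriptor mkstemp returns

with open(path, "w") as f:
    f.write("episode data")
# The file object is closed as soon as the with-block exits,
# whether or not an exception was raised inside it.

os.remove(path)
```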

)
dataset.save()
dataset.load()
Collaborator:

why save -> load?

@@ -271,20 +270,19 @@ def main():
type=bool,
help="set this flag if dataset generation is needed",
action=argparse.BooleanOptionalAction,
default=False,
default=True,
Collaborator:

So the default action should be to generate the dataset? Wouldn't a user update an existing dataset more often in general?

ref_model_initializer = lambda: RefStateModelEncoderLastLayer(
hidden_channels=32, out_channels=8
)
print(GeneralConfig.DEVICE)
Collaborator:

why print it twice?

maps: list[GameMap], ref_model_init: t.Callable[[], torch.nn.Module]
):
global DATASET_BASE_PATH
def generate_dataset():
Collaborator:

it is better to parametrise the function; global constants are not good in general
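A minimal sketch of the parametrised shape (names and the return type are illustrative, not the PR's actual signature):

```python
from pathlib import Path

def generate_dataset(dataset_base_path: Path, map_names: list[str]) -> dict[str, Path]:
    # The caller supplies the base path instead of the function reading
    # a module-level global, which keeps it testable and explicit.
    return {name: dataset_base_path / f"{name}.dataset" for name in map_names}

paths = generate_dataset(Path("/tmp/ds"), ["Loan", "Binary"])
```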

)
dataset.save()
dataset.load()
return dataset


def get_dataset():
Collaborator:

parametrise?

Anya497 and others added 4 commits February 16, 2024 10:02
…rticular file. Split running, training and validation by different files. Create dataloader. Delete support of models learned with genetic algorithms.
@emnigma emnigma force-pushed the merge_steps branch 2 times, most recently from b9776a2 to 5016902 Compare March 24, 2024 15:45
@Anya497 Anya497 closed this Apr 12, 2024
@Anya497 Anya497 deleted the merge_steps branch April 12, 2024 12:25