-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merge steps #27
Merge steps #27
Conversation
Anya497
commented
Jan 25, 2024
•
edited
Loading
edited
- Add merge_steps function to FullDataset class
- New way to distribute maps between processes during validation
- Swap errors and steps number in result tuple
- New path to server working directory
- Refactor data_loader and add parallelism to dataset processing
e22e828
to
9e68fc6
Compare
AIAgent/ml/common_model/dataset.py
Outdated
@@ -91,7 +92,11 @@ def get_plain_data(self, threshold: int = 100): | |||
result = [] | |||
for map_result, map_steps in self.maps_data.values(): | |||
if map_result[0] >= threshold: | |||
for step in map_steps: | |||
if len(map_steps) > 2000: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
количество шагов надо указать как входное значение для функции
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ок
AIAgent/ml/common_model/dataset.py
Outdated
@@ -91,7 +92,11 @@ def get_plain_data(self, threshold: int = 100): | |||
result = [] | |||
for map_result, map_steps in self.maps_data.values(): | |||
if map_result[0] >= threshold: | |||
for step in map_steps: | |||
if len(map_steps) > 2000: | |||
selected_steps = random.sample(map_steps, 2000) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
вот тут она же видимо
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Да
AIAgent/ml/common_model/dataset.py
Outdated
@@ -91,7 +92,11 @@ def get_plain_data(self, threshold: int = 100): | |||
result = [] | |||
for map_result, map_steps in self.maps_data.values(): | |||
if map_result[0] >= threshold: | |||
for step in map_steps: | |||
if len(map_steps) > 2000: | |||
selected_steps = random.sample(map_steps, 2000) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
и тут важно, что именно random.sample? он выбирает шаги в случайном порядке. если len(map_steps) > 2000
будет false, то шаги будут в прямом. то, что в разных случаях разные порядки, не будет влиять на обучение?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Не будет. Там снаружи всё равно шафл всего и вся.
AIAgent/ml/common_model/dataset.py
Outdated
new_steps_num = len(self.maps_data[map_name][1]) | ||
logging.info( | ||
f"Steps on map {map_name} were merged with current steps with result {map_result}. {len(filtered_map_steps)} + {init_steps_num} -> {new_steps_num}. " | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ок
AIAgent/ml/common_model/dataset.py
Outdated
if self.maps_data[map_name][0] == map_result and map_result[0] == 100: | ||
init_steps_num = len(self.maps_data[map_name][1]) | ||
|
||
filtered_map_steps = self.remove_similar_steps(filtered_map_steps) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's inline this variable, re-using it is serving no purpose. same on line 147
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
По-моему, есть противоречие с комментарием про единоразовое вычисление remove_similar_steps.
AIAgent/ml/common_model/dataset.py
Outdated
break | ||
if should_add: | ||
merged_steps.append(new_step) | ||
merged_steps.extend(sum(old_steps.values(), [])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
странный extend. что тут происходит?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sum суммирует списки в один. А extend добавляет все в merged_steps
AIAgent/run_common_model_training.py
Outdated
|
||
all_average_results = [] | ||
for epoch in range(config.epochs): | ||
data_list = dataset.get_plain_data() | ||
data_list = dataset.get_plain_data(80) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
keyword is needed. 80 is not linked with function name, so it is impossible to interpret without exploring get_plain_data
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Я скоро буду весь датасет сильно переписывать. Учту это в новой версии
AIAgent/ml/common_model/dataset.py
Outdated
@@ -91,7 +92,11 @@ def get_plain_data(self, threshold: int = 100): | |||
result = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- please rename threshold parameter for better clarity. threshold of what?
- what does this function do? why the data it gets is plain? why do we need a threshold? maybe function name should reflect that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
В следующей версии датасета эта функция будет не нужна.
AIAgent/run_common_model_training.py
Outdated
@@ -190,7 +191,7 @@ def train(trial: optuna.trial.Trial, dataset: FullDataset): | |||
cmwrapper.make_copy(str(epoch + 1)) | |||
|
|||
with mp.Pool(GeneralConfig.SERVER_COUNT) as p: | |||
result = list(p.map(play_game_task, tasks)) | |||
result = list(p.map(play_game_task, tasks, 1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- constants should come with the keyword: chunksize=1
- why chunksize=1?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ок
AIAgent/run_common_model_training.py
Outdated
tasks = [ | ||
(maps[i], FullDataset("", ""), cmwrapper) | ||
for i in range(GeneralConfig.SERVER_COUNT) | ||
([all_maps[i]], FullDataset("", ""), cmwrapper) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tasks = [
([concrete_map], FullDataset("", ""), cmwrapper)
for concrete_map in all_maps
]
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ну да
AIAgent/common/constants.py
Outdated
@@ -47,4 +47,6 @@ class ResultsHandlerLinks: | |||
BASE_NN_OUT_FEATURES_NUM = 8 | |||
|
|||
# assuming we start from /VSharp/VSharp.ML.AIAgent | |||
SERVER_WORKING_DIR = "../VSharp.ML.GameServer.Runner/bin/Release/net7.0/" | |||
SERVER_WORKING_DIR = ( | |||
"../GameServers/VSharp/VSharp.ML.GameServer.Runner/bin/Release/net7.0/" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pathlib.Path("...")?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ага
…ew algorithm for parallel validation.
…st-training sequentially. Turn off weights loading.
…training. Add pretraining dataset generation.
.gitignore
Outdated
@@ -162,3 +162,4 @@ cython_debug/ | |||
|
|||
# MacOS specific | |||
.DS_Store | |||
AIAgent/report/ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
last line should be empty
@@ -3,4 +3,4 @@ repos: | |||
rev: 23.12.1 | |||
hooks: | |||
- id: black | |||
language_version: python3.11 | |||
language_version: python3.10 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
???
from the docs:
It is recommended to specify the latest version of Python supported by your project here
AIAgent/ml/common_model/paths.py
Outdated
ROOT, "ml", "pretrained_models", "models_for_parallel_architecture" | ||
) | ||
PRETRAINED_MODEL_PATH = os.path.join("ml", "models") | ||
RAW_FILES_PATH = os.path.join("report", "SerializedEpisodes") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does SerializedEpisodes
folder always in this location? Should it instead be passed as the cmd argument?
|
||
@dataclass(slots=True) | ||
class Step: | ||
Graph: TypeAlias = HeteroData |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
а так можно что ли?) какой семантический смысл у этой конструкции?
f = open( | ||
file_path | ||
) # without resource manager in order to escape file descriptors leaks |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
super counterintuitive: file resource managers are used to disallow file descriptor leaks. what is the reason to not use it there?
AIAgent/run_common_model_training.py
Outdated
) | ||
dataset.save() | ||
dataset.load() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why save -> load?
AIAgent/run_common_model_training.py
Outdated
@@ -271,20 +270,19 @@ def main(): | |||
type=bool, | |||
help="set this flag if dataset generation is needed", | |||
action=argparse.BooleanOptionalAction, | |||
default=False, | |||
default=True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So default user action should be to generate dataset? Wouldn't user update dataset more frequently in general?
AIAgent/run_common_model_training.py
Outdated
ref_model_initializer = lambda: RefStateModelEncoderLastLayer( | ||
hidden_channels=32, out_channels=8 | ||
) | ||
print(GeneralConfig.DEVICE) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why print it twice?
AIAgent/run_common_model_training.py
Outdated
maps: list[GameMap], ref_model_init: t.Callable[[], torch.nn.Module] | ||
): | ||
global DATASET_BASE_PATH | ||
def generate_dataset(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is better to parametrise function, global constants are not good in general
AIAgent/run_common_model_training.py
Outdated
) | ||
dataset.save() | ||
dataset.load() | ||
return dataset | ||
|
||
|
||
def get_dataset(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
parametrise?
…rticular file. Split running, training and validation by different files. Create dataloader. Delete support of models learned with genetic algorithms.
Big refactor
b9776a2
to
5016902
Compare