Skip to content

Commit

Permalink
remove subparser
Browse files Browse the repository at this point in the history
  • Loading branch information
Damowerko committed Mar 11, 2024
1 parent 267040f commit 4f861cf
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions scripts/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import argparse
import os
import sys
import typing
from functools import partial
from pathlib import Path
from typing import List, Union
from typing import List

import optuna
import pytorch_lightning as pl
Expand All @@ -30,19 +31,19 @@

def main():
parser = argparse.ArgumentParser()
parser.add_argument("operation", type=str, choices=["train", "test", "study"])
parser.add_argument("model", type=str, choices=models.keys())

# program arguments
group = parser.add_argument_group("General")
group.add_argument("operation", type=str, choices=["train", "test", "study"])
group.add_argument("--no_log", action="store_true")
group.add_argument("--log_dir", type=str, default="./logs")
group.add_argument("--data_dir", type=str, default="./data/train")

# model arguments
subparsers = parser.add_subparsers(title="Model", dest="model", required=True)
for name, model in models.items():
subparser = subparsers.add_parser(name, add_help=True)
subparser = model.add_model_specific_args(subparser)
model_name = sys.argv[2]
group = parser.add_argument_group("Model Hyperparameters")
models[model_name].add_model_specific_args(group)

# data arguments
group = parser.add_argument_group("Data")
Expand Down

0 comments on commit 4f861cf

Please sign in to comment.