diff --git a/cli.py b/cli.py index a8c8013a6..4374b8953 100644 --- a/cli.py +++ b/cli.py @@ -80,12 +80,6 @@ def _add_arguments_common(parser): help="Model name for well-known models", ) - -def add_arguments(parser): - # TODO: Refactor this so that only common options are here - # and command-specific options are inside individual - # add_arguments_for_generate, add_arguments_for_export etc. - parser.add_argument( "--chat", action="store_true", @@ -301,10 +295,10 @@ def add_arguments(parser): def arg_init(args): - if Path(args.quantize).is_file(): + if hasattr(args, 'quantize') and Path(args.quantize).is_file(): with open(args.quantize, "r") as f: args.quantize = json.loads(f.read()) - if args.seed: + if hasattr(args, 'seed') and args.seed: torch.manual_seed(args.seed) return args diff --git a/eval.py b/eval.py index 4739c34f7..aa9b8a5c7 100644 --- a/eval.py +++ b/eval.py @@ -20,7 +20,7 @@ from build.model import Transformer from build.utils import set_precision -from cli import add_arguments, add_arguments_for_eval, arg_init +from cli import add_arguments_for_eval, arg_init from generate import encode_tokens, model_forward torch._dynamo.config.automatic_dynamic_shapes = True @@ -289,7 +289,6 @@ def main(args) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser(description="torchchat eval CLI") - add_arguments(parser) add_arguments_for_eval(parser) args = parser.parse_args() args = arg_init(args) diff --git a/export.py b/export.py index b3fe0af70..549d01f5c 100644 --- a/export.py +++ b/export.py @@ -19,7 +19,7 @@ ) from build.utils import set_backend, set_precision -from cli import add_arguments, add_arguments_for_export, arg_init, check_args +from cli import add_arguments_for_export, arg_init, check_args from export_aoti import export_model as export_model_aoti try: @@ -104,7 +104,6 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser(description="torchchat export CLI") - add_arguments(parser) add_arguments_for_export(parser) args = parser.parse_args() check_args(args, "export") diff --git a/generate.py b/generate.py index b61efbb8f..9aa19e6f1 100644 --- a/generate.py +++ b/generate.py @@ -25,7 +25,7 @@ ) from build.model import Transformer from build.utils import device_sync, set_precision -from cli import add_arguments, add_arguments_for_generate, arg_init, check_args +from cli import add_arguments_for_generate, arg_init, check_args logger = logging.getLogger(__name__) @@ -710,7 +710,6 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser(description="torchchat generate CLI") - add_arguments(parser) add_arguments_for_generate(parser) args = parser.parse_args() check_args(args, "generate") diff --git a/torchchat.py b/torchchat.py index 5e2df0f3b..396283c47 100644 --- a/torchchat.py +++ b/torchchat.py @@ -10,7 +10,6 @@ import sys from cli import ( - add_arguments, add_arguments_for_browser, add_arguments_for_chat, add_arguments_for_download, @@ -30,17 +29,14 @@ # Initialize the top-level parser parser = argparse.ArgumentParser( prog="torchchat", - description="Welcome to the torchchat CLI!", add_help=True, ) - # Default command is to print help - parser.set_defaults(func=parser.print_help()) - add_arguments(parser) subparsers = parser.add_subparsers( dest="command", help="The specific command to run", ) + subparsers.required = True parser_chat = subparsers.add_parser( "chat", @@ -90,20 +86,6 @@ ) add_arguments_for_remove(parser_remove) - # Move all flags to the front of sys.argv since we don't - # want to use the subparser syntax - flag_args = [] - positional_args = [] - i = 1 - while i < len(sys.argv): - if sys.argv[i].startswith("-"): - flag_args += sys.argv[i : i + 2] - i += 2 - else: - positional_args.append(sys.argv[i]) - i += 1 - sys.argv = sys.argv[:1] + flag_args + positional_args - # Now parse the arguments args = parser.parse_args() args = arg_init(args)