Skip to content

Commit

Permalink
arg handling (pytorch#292)
Browse files Browse the repository at this point in the history
* arg handling

* phase ordering issue resolved
  • Loading branch information
mikekgfb authored and malfet committed Jul 17, 2024
1 parent 57a7964 commit 3e50c42
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 37 deletions.
10 changes: 8 additions & 2 deletions build/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,19 @@ def __post_init__(self):

@classmethod
def from_args(cls, args): # -> BuilderArgs:

# Handle disabled checkpoint_dir option
checkpoint_dir = None
if hasattr(args, "checkpoint_dir"):
checkpoint_dir = args.checkpoint_dir

is_chat_model = False
if args.is_chat_model:
is_chat_model = True
else:
for path in [
args.checkpoint_path,
args.checkpoint_dir,
checkpoint_dir,
args.dso_path,
args.pte_path,
args.gguf_path,
Expand All @@ -89,7 +95,7 @@ def from_args(cls, args): # -> BuilderArgs:

return cls(
checkpoint_path=args.checkpoint_path,
checkpoint_dir=args.checkpoint_dir,
checkpoint_dir=checkpoint_dir,
params_path=args.params_path,
params_table=args.params_table,
gguf_path=args.gguf_path,
Expand Down
44 changes: 9 additions & 35 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,10 @@

import torch

default_device = "cpu"

default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'

strict = False


def check_args(args, command_name: str):
global strict

# chat and generate support the same options
if command_name in ["generate", "chat", "gui"]:
# examples, can add more. Note that attributes convert dash to _
disallowed_args = ["output_pte_path", "output_dso_path"]
elif command_name == "export":
# examples, can add more. Note that attributes convert dash to _
disallowed_args = ["pte_path", "dso_path"]
elif command_name == "eval":
# TBD
disallowed_args = []
else:
raise RuntimeError(f"{command_name} is not a valid command")

for disallowed in disallowed_args:
if hasattr(args, disallowed):
text = f"command {command_name} does not support option {disallowed.replace('_', '-')}"
if strict:
raise RuntimeError(text)
else:
print(f"Warning: {text}")

def check_args(args, name: str) -> None:
pass

def add_arguments_for_generate(parser):
# Only generate specific options should be here
Expand Down Expand Up @@ -123,12 +97,12 @@ def _add_arguments_common(parser):
default="not_specified",
help="Model checkpoint path.",
)
parser.add_argument(
"--checkpoint-dir",
type=Path,
default=None,
help="Model checkpoint directory.",
)
# parser.add_argument(
# "--checkpoint-dir",
# type=Path,
# default=None,
# help="Model checkpoint directory.",
# )
parser.add_argument(
"--params-path",
type=Path,
Expand Down

0 comments on commit 3e50c42

Please sign in to comment.