Implement download subcommand, optional positional model name argument #234

Merged · 14 commits · Apr 19, 2024
5 changes: 5 additions & 0 deletions .github/workflows/pull.yml
@@ -361,6 +361,11 @@ jobs:
cat ./output_eager2
echo "Tests complete."

- name: Test download
run: |

python torchchat.py generate stories15M

test-tinystories-eager:
strategy:
matrix:
2 changes: 2 additions & 0 deletions .gitignore
@@ -5,3 +5,5 @@ __pycache__/

# C extensions
*.so

.model-artifacts/
22 changes: 13 additions & 9 deletions README.md
@@ -32,14 +32,10 @@ python torchchat.py --help

```

### Download a Model and Tokenizer
### Generating Text

```
#download a model
python torchchat.py download llama2

#generate text using the model

python torchchat.py generate stories15M
```
That’s all there is to it!
Read on to learn how to use the full power of torchchat.
@@ -48,7 +44,15 @@ Read on to learn how to use the full power of torchchat.
For the full details on all commands and parameters run `python torchchat.py --help`

### Download
TODO: Fill this out
For supported models, torchchat can download model weights. Most models use HuggingFace as the distribution channel, so you will need to create a HuggingFace account and install `huggingface-cli`.

To install `huggingface-cli`, run `pip install huggingface-cli`. After installing, create a user access token [as documented here](https://huggingface.co/docs/hub/en/security-tokens). Run `huggingface-cli login`, which will prompt for the newly created token. Once this is done, torchchat will be able to download model artifacts from HuggingFace.

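Taken together, the one-time setup described above is:

```
pip install huggingface-cli
huggingface-cli login
```

After that, a supported model can be fetched by name: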
```
python torchchat.py download llama2
```

### Chat
Designed for interactive and conversational use.
@@ -69,7 +73,7 @@ For more information run `python torchchat.py generate --help`

**Examples**
```
#Generate for Mac with some parameters
python torchchat.py generate llama2 --device=cpu --dtype=fp16
```

### Export
@@ -80,7 +84,7 @@ For more information run `python torchchat.py export --help`
**Examples**

```
#Export Example
python torchchat.py export stories15M --output-pte-path=stories15m.pte
```

### Browser
31 changes: 25 additions & 6 deletions build/builder.py
@@ -14,7 +14,7 @@
import torch
import torch._dynamo.config
import torch._inductor.config

from config.model_config import resolve_model_config
from quantize import name_to_dtype, quantize_model

from sentencepiece import SentencePieceProcessor
@@ -42,7 +42,7 @@ class BuilderArgs:
def __post_init__(self):
if not (
(self.checkpoint_path and self.checkpoint_path.is_file())
or (self.checkpoint_dir and self.checkpoint_path.is_dir())
or (self.checkpoint_dir and self.checkpoint_dir.is_dir())
or (self.gguf_path and self.gguf_path.is_file())
or (self.dso_path and Path(self.dso_path).is_file())
or (self.pte_path and Path(self.pte_path).is_file())
@@ -73,7 +73,17 @@ def from_args(cls, args): # -> BuilderArgs:
# Handle disabled checkpoint_dir option
checkpoint_dir = None
if hasattr(args, "checkpoint_dir"):
checkpoint_dir = args.checkpoint_dir
checkpoint_dir = args.checkpoint_dir

checkpoint_path = args.checkpoint_path
if args.model: # Using a named, well-known model
model_config = resolve_model_config(args.model)

checkpoint_path = (
Path(args.model_directory)
/ model_config.name
/ model_config.checkpoint_file
)

is_chat_model = False
if args.is_chat_model:
@@ -94,8 +104,8 @@ def from_args(cls, args): # -> BuilderArgs:
is_chat_model = True

return cls(
checkpoint_path=args.checkpoint_path,
checkpoint_dir=checkpoint_dir,
checkpoint_path=checkpoint_path,
params_path=args.params_path,
params_table=args.params_table,
gguf_path=args.gguf_path,
@@ -134,9 +144,12 @@ def from_args(cls, args): # -> TokenizerArgs:

if args.tokenizer_path:
tokenizer_path = args.tokenizer_path
elif args.model: # Using a named, well-known model
model_config = resolve_model_config(args.model)
tokenizer_path = Path(args.model_directory) / model_config.name / "tokenizer.model"
Contributor:
Well known doesn't mean it's local. how do you know where the tokenizer is?

elif args.checkpoint_path:
tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
elif args.checkpoint_dir:
elif hasattr(args, "checkpoint_dir") and args.checkpoint_dir:
tokenizer_path = args.checkpoint_dir / "tokenizer.model"
else:
raise RuntimeError("cannot find tokenizer model")
@@ -356,4 +369,10 @@ def validate_args(model: Transformer, tokenizer_args: TokenizerArgs):
is_tiktoken = tokenizer_args.is_tiktoken
if use_tiktoken != is_tiktoken:
raise RuntimeError(f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)}) does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)})")


def resolve_model_name(model: str) -> str:
# If the provided model name is an alias, retrieve the full path.
if model in model_aliases:
return model_aliases[model]
else:
return model
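
For reference, the checkpoint and tokenizer resolution added to `BuilderArgs.from_args` and `TokenizerArgs.from_args` above boils down to the following sketch. The helper name is illustrative and not part of this PR; the default directory comes from the `--model-directory` option added in cli.py.

```
from pathlib import Path

from config.model_config import resolve_model_config


def resolve_artifact_paths(model: str, model_directory: str = ".model-artifacts"):
    # A named model maps to files under <model-directory>/<config.name>/,
    # using the config's checkpoint_file and a sibling tokenizer.model.
    config = resolve_model_config(model)
    model_dir = Path(model_directory) / config.name
    return model_dir / config.checkpoint_file, model_dir / "tokenizer.model"


# e.g. resolve_artifact_paths("stories15M") ->
#   (.model-artifacts/stories15M/stories15M.pt,
#    .model-artifacts/stories15M/tokenizer.model)
```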
@@ -4,6 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import re
import sys
from pathlib import Path
@@ -22,19 +23,20 @@
@torch.inference_mode()
def convert_hf_checkpoint(
*,
checkpoint_dir: Optional[Path] = None,
model_dir: Optional[Path] = None,
model_name: Optional[str] = None,
remove_bin_files: bool = False,
) -> None:
if checkpoint_dir is None:
checkpoint_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
if model_dir is None:
model_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
Comment on lines +30 to +31
Contributor:
why is the default something that's not even in models.json?

Member Author:
@mikekgfb Do we need this default value anymore?

Contributor:
no, but we should put this or a similarly situated chat model into the models.json.

BTW, I really think it's bad to have even the model name default to something (unless we're so excited about llama3 that we make it that.... but that will require users to have obtained a token)

if model_name is None:
model_name = checkpoint_dir.name
model_name = model_dir.name

config = ModelArgs.from_name(model_name)
print(f"Model config {config.__dict__}")

# Load the json file containing weight mapping
model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
model_map_json = model_dir / "pytorch_model.bin.index.json"

assert model_map_json.is_file()

@@ -56,7 +58,7 @@ def convert_hf_checkpoint(
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}

def permute(w, n_heads):
dim = config.dim
@@ -97,8 +99,13 @@ def permute(w, n_heads):
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
torch.save(final_result, checkpoint_dir / "model.pth")
print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
torch.save(final_result, model_dir / "model.pth")
print("Done.")

if remove_bin_files:
for file in bin_files:
os.remove(file)


if __name__ == "__main__":
@@ -114,6 +121,6 @@ def permute(w, n_heads):

args = parser.parse_args()
convert_hf_checkpoint(
checkpoint_dir=args.checkpoint_dir,
model_dir=args.checkpoint_dir,
model_name=args.model_name,
)
30 changes: 30 additions & 0 deletions cli.py
@@ -14,6 +14,11 @@
def check_args(args, name: str) -> None:
pass

def add_arguments_for_download(parser):
# Only download specific options should be here
_add_arguments_common(parser)


def add_arguments_for_generate(parser):
# Only generate specific options should be here
_add_arguments_common(parser)
@@ -39,6 +44,19 @@ def add_arguments_for_browser(parser):
)

def _add_arguments_common(parser):
# Model specification. TODO Simplify this.
Contributor:
In general, i'd add a lot more detail to the --help output. In general expect the user doesn't know the term and wants to know how that param will be used. The user should be able to figure out exactly what to do just by using --help and have no prior experience using an LLM.

Member Author:
FYI I did actually just re-order the args here, so this isn't a new parameter. However, I do agree so I think we should take this as a follow-up.

# A model can be specified using a positional model name or HuggingFace
# path. Alternatively, the model can be specified via --gguf-path or via
# an explicit --checkpoint-dir, --checkpoint-path, or --tokenizer-path.

parser.add_argument(
"model",
type=str,
nargs="?",
default=None,
help="Model name for well-known models.",
)

# TODO: Refactor this so that only common options are here
# and subcommand-specific options are inside individual
# add_arguments_for_generate, add_arguments_for_export etc.
@@ -168,6 +186,18 @@ def _add_arguments_common(parser):
default=None,
help="maximum length sequence to evaluate",
)
parser.add_argument(
"--hf-token",
type=str,
default=None,
help="A HuggingFace API token to use when downloading model artifacts",
)
parser.add_argument(
"--model-directory",
type=Path,
default=".model-artifacts",
help="The directory to store downloaded model artifacts",
)


def arg_init(args):
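With the positional `model` argument and the download options above, invocations take one of two rough forms. A usage sketch follows: the explicit paths are placeholders, while the named-model commands come from this PR's README and CI workflow.

```
# Positional model name, resolved through config/data/models.json and stored
# under --model-directory (default: .model-artifacts/):
python torchchat.py download llama2
python torchchat.py generate stories15M

# Explicit artifact paths, bypassing name resolution (placeholder paths):
python torchchat.py generate --checkpoint-path <dir>/model.pth --tokenizer-path <dir>/tokenizer.model
```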
Empty file added config/__init__.py
Empty file.
28 changes: 28 additions & 0 deletions config/data/models.json
@@ -0,0 +1,28 @@
{
"meta-llama/Llama-2-7b-chat-hf": {
"aliases": ["llama2", "llama2-7b"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Llama-2-7b-chat-hf"
},
"mistralai/Mistral-7B-Instruct-v0.2": {
"aliases": ["mistral-7b-instruct"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "mistralai/Mistral-7B-Instruct-v0.2"
},
"stories15M": {
"distribution_channel": "DirectDownload",
"distribution_path": [
"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt",
"https://github.com/karpathy/llama2.c/raw/master/tokenizer.model"
],
"checkpoint_file": "stories15M.pt"
},
"stories110M": {
"distribution_channel": "DirectDownload",
"distribution_path": [
"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt",
"https://github.com/karpathy/llama2.c/raw/master/tokenizer.model"
],
"checkpoint_file": "stories110M.pt"
}
}
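
Registering another model is a matter of adding an entry with the same fields. A purely hypothetical `DirectDownload` entry (the name, URLs, and filename below are placeholders) would look like:

```
"myorg/my-tiny-model": {
  "aliases": ["my-tiny-model"],
  "distribution_channel": "DirectDownload",
  "distribution_path": [
    "https://example.com/my-tiny-model.pt",
    "https://example.com/tokenizer.model"
  ],
  "checkpoint_file": "my-tiny-model.pt"
}
```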
86 changes: 86 additions & 0 deletions config/model_config.py
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Dict, Sequence, Union

"""
Known Model Configs:

For models that are known to work with torchchat, we provide a config under
config/data/models.json to support automatically downloading the model and
converting to the expected format for use with torchchat.

There are two supported distribution channels:

1) HuggingFaceSnapshot: Download a model from HuggingFace.
2) DirectDownload: Download a list of model artifacts from URLs. No conversion
is done.
"""


# Specifies the distribution channel to download model artifacts from. Enum
# variants are specified as strings to simplify JSON (de)serialization.
class ModelDistributionChannel(str, Enum):
# Download a full model snapshot from HuggingFace, such as
# meta-llama/Llama-2-7b-chat-hf and convert to torchchat format.
HuggingFaceSnapshot = "HuggingFaceSnapshot"

# Download one or more files over HTTP(S).
DirectDownload = "DirectDownload"


@dataclass
class ModelConfig:
name: str = field(default="")
aliases: Sequence[str] = field(default_factory=list)
distribution_path: Union[str, Sequence[str]] = field(default="")
distribution_channel: ModelDistributionChannel = field(
default=ModelDistributionChannel.HuggingFaceSnapshot
)
checkpoint_file: str = field(default="model.pth")


# Keys are stored in lowercase.
model_aliases: Dict[str, str] = None
model_configs: Dict[str, ModelConfig] = None


def resolve_model_config(model: str) -> ModelConfig:
global model_aliases
global model_configs

model = model.lower()

# Lazy load model config from JSON.
if not model_configs:
model_aliases = {}
model_configs = {}

with open(
Path(__file__).parent.parent / "config" / "data" / "models.json", "r"
) as f:
model_config_dict = json.load(f)

for key, value in model_config_dict.items():
config = ModelConfig(**value)
config.name = key

key = key.lower()
model_configs[key] = config

for alias in config.aliases:
model_aliases[alias.lower()] = key

if model in model_aliases:
model = model_aliases[model]
Comment on lines +80 to +81
Contributor:
nit: model = model_aliases.get(model, model) is shorter FWIW


if model not in model_configs:
raise ValueError(f"Unknown model '{model}'.")

return model_configs[model]
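
A brief usage sketch of `resolve_model_config`; the expected values follow the entries in `config/data/models.json` added in this PR.

```
from config.model_config import ModelDistributionChannel, resolve_model_config

# "llama2" is an alias in config/data/models.json; it resolves to the full entry.
config = resolve_model_config("llama2")
assert config.name == "meta-llama/Llama-2-7b-chat-hf"
assert config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot
assert config.checkpoint_file == "model.pth"  # dataclass default; the entry sets none

# Names not present in models.json raise:
# resolve_model_config("not-a-model")  # ValueError: Unknown model 'not-a-model'.
```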