Merge pull request #61 from liltom-eth/llama2-wrapper
[FEATURE] Add model downloading script
liltom-eth authored Aug 26, 2023
2 parents 6006219 + 01bb622 commit 2a007d8
Showing 5 changed files with 78 additions and 6 deletions.
13 changes: 13 additions & 0 deletions README.md
@@ -158,6 +158,19 @@ python -m llama2_wrapper.server --backend_type gptq

Navigate to http://localhost:8000/docs to see the OpenAPI documentation.

#### Basic settings

| Flag | Description |
| ---------------- | ------------------------------------------------------------ |
| `-h`, `--help` | Show this help message. |
| `--model_path` | The path to the model to use for generating completions. |
| `--backend_type` | Backend for llama2; options: llama.cpp, gptq, transformers. |
| `--max_tokens` | Maximum context size. |
| `--load_in_8bit` | Whether to use bitsandbytes to run the model in 8-bit mode (only for transformers models). |
| `--verbose` | Whether to print verbose output to stderr. |
| `--host` | API address to listen on. |
| `--port` | API port to listen on. |
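
For example, to serve a local GGML model with the llama.cpp backend (the model path below is illustrative):

```bash
python -m llama2_wrapper.server --backend_type llama.cpp --model_path ./models/llama-2-7b-chat.ggmlv3.q4_0.bin --host localhost --port 8000
```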

## Benchmark

Run the benchmark script to measure performance on your device; `benchmark.py` loads the same `.env` as `app.py`:
Empty file.
59 changes: 59 additions & 0 deletions llama2_wrapper/download/__main__.py
@@ -0,0 +1,59 @@
import os
import argparse


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--repo_id",
        type=str,
        default="",
        required=True,
        help="Repo ID like 'TheBloke/Llama-2-7B-Chat-GGML'",
    )
    parser.add_argument(
        "--filename",
        type=str,
        default=None,
        help="Filename like llama-2-7b-chat.ggmlv3.q4_0.bin",
    )
    parser.add_argument(
        "--save_dir", type=str, default="./models", help="Directory to save models"
    )

    args = parser.parse_args()

    repo_id = args.repo_id
    save_dir = args.save_dir

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    if args.filename:
        # A specific file was requested: download just that file
        # from the Hugging Face Hub directly into save_dir.
        filename = args.filename
        from huggingface_hub import hf_hub_download

        print(f"Start downloading model {repo_id} {filename} to: {save_dir}")

        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=save_dir,
        )
    else:
        # No filename given: snapshot the entire repo into
        # save_dir/<repo_name>, e.g. ./models/Llama-2-7B-Chat-GGML.
        repo_name = repo_id.split("/")[1]
        save_path = os.path.join(save_dir, repo_name)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        print(f"Start downloading model {repo_id} to: {save_path}")

        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id=repo_id,
            local_dir=save_path,
        )


if __name__ == "__main__":
    main()
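
Typical invocations of this script, mirroring the two code paths above (repo and file names are the examples from the argument help strings; `--save_dir` defaults to `./models`):

```bash
# Download a single weight file into ./models
python -m llama2_wrapper.download --repo_id TheBloke/Llama-2-7B-Chat-GGML --filename llama-2-7b-chat.ggmlv3.q4_0.bin

# No --filename: snapshot the whole repo into ./models/Llama-2-7B-Chat-GGML
python -m llama2_wrapper.download --repo_id TheBloke/Llama-2-7B-Chat-GGML
```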
10 changes: 5 additions & 5 deletions llama2_wrapper/server/app.py
@@ -35,17 +35,17 @@ class Settings(BaseSettings):
default="llama.cpp",
description="Backend for llama2, options: llama.cpp, gptq, transformers",
)
max_tokens: int = Field(default=2048, ge=1, description="Maximum context size.")
max_tokens: int = Field(default=4000, ge=1, description="Maximum context size.")
load_in_8bit: bool = Field(
default=False,
description="Use bitsandbytes to run model in 8 bit mode (only for transformers models).",
description="`Whether to use bitsandbytes to run model in 8 bit mode (only for transformers models).",
)
verbose: bool = Field(
default=False,
description="Print verbose output to stderr.",
description="Whether to print verbose output to stderr.",
)
host: str = Field(default="localhost", description="Listen address")
port: int = Field(default=8000, description="Listen port")
host: str = Field(default="localhost", description="API address")
port: int = Field(default=8000, description="API port")
interrupt_requests: bool = Field(
default=True,
description="Whether to interrupt requests when a new request is received.",
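
Since `Settings` subclasses pydantic's `BaseSettings`, each field can also come from the environment (or the `.env` file that `app.py` loads) rather than from CLI flags. A minimal self-contained sketch of that pattern, assuming pydantic v1-style `BaseSettings` and its default field-name-to-environment-variable mapping:

```python
import os

from pydantic import BaseSettings, Field


class Settings(BaseSettings):
    # Field names mirror the diff above; defaults match the new values.
    max_tokens: int = Field(default=4000, ge=1, description="Maximum context size.")
    host: str = Field(default="localhost", description="API address")
    port: int = Field(default=8000, description="API port")


# Environment variables override the declared defaults at instantiation.
os.environ["MAX_TOKENS"] = "2048"
print(Settings().max_tokens)  # -> 2048
```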
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "llama2-wrapper"
-version = "0.1.11"
+version = "0.1.12"
description = "Use llama2-wrapper as your local llama2 backend for Generative Agents / Apps"
authors = ["liltom-eth <liltom.eth@gmail.com>"]
license = "MIT"
