import os
import argparse

from dataclasses import dataclass, field
from typing import List, Optional

# Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp


@dataclass
class GptParams:
    seed: int = -1
    n_threads: int = min(4, os.cpu_count() or 1)
    n_predict: int = 128
    repeat_last_n: int = 64
    n_parts: int = -1
    n_ctx: int = 512
    n_batch: int = 8
    n_keep: int = 0

    top_k: int = 40
    top_p: float = 0.95
    temp: float = 0.80
    repeat_penalty: float = 1.10

    model: str = "./models/llama-7B/ggml-model.bin"
    prompt: str = ""
    input_prefix: str = " "

    antiprompt: List[str] = field(default_factory=list)

    memory_f16: bool = True
    random_prompt: bool = False
    use_color: bool = False
    interactive: bool = False

    embedding: bool = False
    interactive_start: bool = False

    instruct: bool = False
    ignore_eos: bool = False
    perplexity: bool = False
    use_mlock: bool = False
    mem_test: bool = False
    verbose_prompt: bool = False
    file: Optional[str] = None

    # If chat ended prematurely, append this to the conversation to fix it.
    # Set to "\nUser:" etc.
    # This is an alternative to input_prefix, which always adds it and so may duplicate "User:".
    fix_prefix: str = " "
    output_postfix: str = ""
    input_echo: bool = True

    # Default instructions for Alpaca
    # switch to "Human" and "Assistant" for Vicuna.
    # TODO: TBD how they are gonna handle this upstream
    instruct_inp_prefix: str = "\n\n### Instruction:\n\n"
    instruct_inp_suffix: str = "\n\n### Response:\n\n"
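
    # Example (values are illustrative): GptParams can be constructed directly with
    # keyword overrides; fields that are not passed keep the defaults above, e.g.
    #   params = GptParams(n_ctx=2048, interactive=True, antiprompt=["User:"])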


def gpt_params_parse(argv=None, params: Optional[GptParams] = None):
    if params is None:
        params = GptParams()

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-s", "--seed", type=int, default=-1, help="RNG seed (use random seed for <= 0)", dest="seed")
    parser.add_argument("-t", "--threads", type=int, default=min(4, os.cpu_count() or 1), help="number of threads to use during computation", dest="n_threads")
    parser.add_argument("-p", "--prompt", type=str, default="", help="initial prompt", dest="prompt")
    parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load", dest="file")
    parser.add_argument("-c", "--ctx_size", type=int, default=512, help="size of the prompt context", dest="n_ctx")
    parser.add_argument("--memory_f32", action="store_false", help="use f32 instead of f16 for memory key+value", dest="memory_f16")
    parser.add_argument("--top_p", type=float, default=0.95, help="top-p sampling", dest="top_p")
    parser.add_argument("--top_k", type=int, default=40, help="top-k sampling", dest="top_k")
    parser.add_argument("--temp", type=float, default=0.80, help="temperature", dest="temp")
    parser.add_argument("--n_predict", type=int, default=128, help="number of tokens to predict", dest="n_predict")
    parser.add_argument("--repeat_last_n", type=int, default=64, help="last n tokens to consider for repeat penalty", dest="repeat_last_n")
    parser.add_argument("--repeat_penalty", type=float, default=1.10, help="penalize repeat sequence of tokens", dest="repeat_penalty")
    parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size for prompt processing", dest="n_batch")
    parser.add_argument("--keep", type=int, default=0, help="number of tokens to keep from the initial prompt", dest="n_keep")
    parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path", dest="model")
    parser.add_argument(
        "-i", "--interactive", action="store_true", help="run in interactive mode", dest="interactive"
    )
    parser.add_argument("--embedding", action="store_true", help="run in embedding mode", dest="embedding")
    parser.add_argument(
        "--interactive-start",
        action="store_true",
        help="run in interactive mode",
        dest="interactive"
    )
    parser.add_argument(
        "--interactive-first",
        action="store_true",
        help="run in interactive mode and wait for input right away",
        dest="interactive_start"
    )
    parser.add_argument(
        "-ins",
        "--instruct",
        action="store_true",
        help="run in instruction mode (use with Alpaca or Vicuna models)",
        dest="instruct"
    )
    parser.add_argument(
        "--color",
        action="store_true",
        help="colorise output to distinguish prompt and user input from generations",
        dest="use_color"
    )
    parser.add_argument("--mlock", action="store_true", help="force system to keep model in RAM rather than swapping or compressing", dest="use_mlock")
    parser.add_argument("--mtest", action="store_true", help="compute maximum memory usage", dest="mem_test")
    parser.add_argument(
        "-r",
        "--reverse-prompt",
        type=str,
        action="append",
        help="poll user input upon seeing PROMPT (can be specified more than once for multiple prompts).",
        dest="antiprompt"
    )
    parser.add_argument("--perplexity", action="store_true", help="compute perplexity over the prompt", dest="perplexity")
    parser.add_argument("--ignore-eos", action="store_true", help="ignore end of stream token and continue generating", dest="ignore_eos")
    parser.add_argument("--n_parts", type=int, default=-1, help="number of model parts", dest="n_parts")
    parser.add_argument("--random-prompt", action="store_true", help="start with a randomized prompt", dest="random_prompt")
    parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix")
    parser.add_argument("--fix-prefix", type=str, default="", help="string appended when generation stops prematurely after n_predict tokens", dest="fix_prefix")
    parser.add_argument("--out-postfix", type=str, default="", help="string appended after the generated output", dest="output_postfix")
    parser.add_argument("--input-noecho", action="store_false", help="don't echo the input", dest="input_echo")
    args = parser.parse_args(argv)
    return args
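
# Example (illustrative argv): parse flags into an argparse Namespace, then build a
# GptParams from it; fields without a corresponding flag keep their dataclass defaults.
#   args = gpt_params_parse(["-m", "./models/llama-7B/ggml-model.bin", "-i", "--color"])
#   params = GptParams(**vars(args))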

def gpt_random_prompt(rng):
    return [
        "So",
        "Once upon a time",
        "When",
        "The",
        "After",
        "If",
        "import",
        "He",
        "She",
        "They",
    ][rng % 10]
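
# Example: pick a fixed starter from the list above; only rng % 10 matters, e.g.
#   gpt_random_prompt(42)  # -> "When"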

if __name__ == "__main__":
    print(GptParams(**vars(gpt_params_parse())))