Skip to content

Commit

Permalink
blacked quantize
Browse files Browse the repository at this point in the history
  • Loading branch information
Geeks-Sid committed Mar 29, 2023
1 parent 382c0c6 commit 172febf
Showing 1 changed file with 25 additions and 18 deletions.
43 changes: 25 additions & 18 deletions quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,36 @@ def main():
quantize_script_binary = "quantize"

parser = argparse.ArgumentParser(
prog='python3 quantize.py',
description='This script quantizes the given models by applying the '
f'"{quantize_script_binary}" script on them.'
prog="python3 quantize.py",
description="This script quantizes the given models by applying the "
f'"{quantize_script_binary}" script on them.',
)
parser.add_argument(
'models', nargs='+', choices=('7B', '13B', '30B', '65B'),
help='The models to quantize.'
"models",
nargs="+",
choices=("7B", "13B", "30B", "65B"),
help="The models to quantize.",
)
parser.add_argument(
'-r', '--remove-16', action='store_true', dest='remove_f16',
help='Remove the f16 model after quantizing it.'
"-r",
"--remove-16",
action="store_true",
dest="remove_f16",
help="Remove the f16 model after quantizing it.",
)
parser.add_argument(
'-m', '--models-path', dest='models_path',
"-m",
"--models-path",
dest="models_path",
default=os.path.join(os.getcwd(), "models"),
help='Specify the directory where the models are located.'
help="Specify the directory where the models are located.",
)
parser.add_argument(
'-q', '--quantize-script-path', dest='quantize_script_path',
"-q",
"--quantize-script-path",
dest="quantize_script_path",
default=os.path.join(os.getcwd(), quantize_script_binary),
help='Specify the path to the "quantize" script.'
help='Specify the path to the "quantize" script.',
)

# TODO: Revise this code
Expand Down Expand Up @@ -75,12 +84,12 @@ def main():
)

if not os.path.isfile(f16_model_path_base):
print(f'The file %s was not found' % f16_model_path_base)
print(f"The file %s was not found" % f16_model_path_base)
sys.exit(1)

f16_model_parts_paths = map(
lambda filename: os.path.join(f16_model_path_base, filename),
glob.glob(f"{f16_model_path_base}*")
glob.glob(f"{f16_model_path_base}*"),
)

for f16_model_part_path in f16_model_parts_paths:
Expand All @@ -93,9 +102,7 @@ def main():
)
sys.exit(1)

__run_quantize_script(
args.quantize_script_path, f16_model_part_path
)
__run_quantize_script(args.quantize_script_path, f16_model_part_path)

if args.remove_f16:
os.remove(f16_model_part_path)
Expand All @@ -104,15 +111,15 @@ def main():
# This was extracted to a top-level function so that it can be run in parallel
# later, if parallelization is ever implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406


def __run_quantize_script(script_path, f16_model_part_path):
    """Run the quantize script on a single f16 model part.

    The output path is derived from the input path by replacing "f16" with
    "q4_0" in the file name only; the containing directory is left untouched.

    Args:
        script_path: Path to the quantize executable.
        f16_model_part_path: Path to the f16 model part to quantize.

    Raises:
        subprocess.CalledProcessError: If the quantize script exits with a
            non-zero status (a consequence of ``check=True``).
    """

    # Rewrite only the file name. Replacing "f16" across the whole path would
    # also rename any directory component that happens to contain "f16",
    # sending the output to a different (possibly non-existent) directory.
    model_dir, model_filename = os.path.split(f16_model_part_path)
    new_quantized_model_path = os.path.join(
        model_dir, model_filename.replace("f16", "q4_0")
    )

    # The trailing "2" is forwarded verbatim as the script's third argument.
    # NOTE(review): presumably it selects the q4_0 quantization type -- confirm
    # against the quantize tool's usage text.
    subprocess.run(
        [script_path, f16_model_part_path, new_quantized_model_path, "2"],
        check=True,
    )


Expand Down

0 comments on commit 172febf

Please sign in to comment.