diff --git a/quantize.py b/quantize.py index 641df8dda1b1e..b01307dbd192e 100644 --- a/quantize.py +++ b/quantize.py @@ -25,27 +25,36 @@ def main(): quantize_script_binary = "quantize" parser = argparse.ArgumentParser( - prog='python3 quantize.py', - description='This script quantizes the given models by applying the ' - f'"{quantize_script_binary}" script on them.' + prog="python3 quantize.py", + description="This script quantizes the given models by applying the " + f'"{quantize_script_binary}" script on them.', ) parser.add_argument( - 'models', nargs='+', choices=('7B', '13B', '30B', '65B'), - help='The models to quantize.' + "models", + nargs="+", + choices=("7B", "13B", "30B", "65B"), + help="The models to quantize.", ) parser.add_argument( - '-r', '--remove-16', action='store_true', dest='remove_f16', - help='Remove the f16 model after quantizing it.' + "-r", + "--remove-16", + action="store_true", + dest="remove_f16", + help="Remove the f16 model after quantizing it.", ) parser.add_argument( - '-m', '--models-path', dest='models_path', + "-m", + "--models-path", + dest="models_path", default=os.path.join(os.getcwd(), "models"), - help='Specify the directory where the models are located.' + help="Specify the directory where the models are located.", ) parser.add_argument( - '-q', '--quantize-script-path', dest='quantize_script_path', + "-q", + "--quantize-script-path", + dest="quantize_script_path", default=os.path.join(os.getcwd(), quantize_script_binary), - help='Specify the path to the "quantize" script.' + help='Specify the path to the "quantize" script.', ) # TODO: Revise this code @@ -75,12 +84,12 @@ def main(): ) if not os.path.isfile(f16_model_path_base): - print(f'The file %s was not found' % f16_model_path_base) + print(f"The file %s was not found" % f16_model_path_base) sys.exit(1) f16_model_parts_paths = map( lambda filename: os.path.join(f16_model_path_base, filename), - glob.glob(f"{f16_model_path_base}*") + glob.glob(f"{f16_model_path_base}*"), ) for f16_model_part_path in f16_model_parts_paths: @@ -93,9 +102,7 @@ def main(): ) sys.exit(1) - __run_quantize_script( - args.quantize_script_path, f16_model_part_path - ) + __run_quantize_script(args.quantize_script_path, f16_model_part_path) if args.remove_f16: os.remove(f16_model_part_path) @@ -104,6 +111,7 @@ def main(): # This was extracted to a top-level function for parallelization, if # implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406 + def __run_quantize_script(script_path, f16_model_part_path): """Run the quantize script specifying the path to it and the path to the f16 model to quantize. @@ -111,8 +119,7 @@ def __run_quantize_script(script_path, f16_model_part_path): new_quantized_model_path = f16_model_part_path.replace("f16", "q4_0") subprocess.run( - [script_path, f16_model_part_path, new_quantized_model_path, "2"], - check=True + [script_path, f16_model_part_path, new_quantized_model_path, "2"], check=True )