diff --git a/infero/main.py b/infero/main.py index b8c4638..231f847 100644 --- a/infero/main.py +++ b/infero/main.py @@ -17,6 +17,9 @@ @app.command("run") def run(model: str, quantize: bool = False): + convert_to_onnx(model) + if quantize: + convert_to_onnx_q8(model) if check_model(model): model_path = os.path.join(get_models_dir(), sanitize_model_name(model)) package_dir = get_package_dir() @@ -31,9 +34,6 @@ def run(model: str, quantize: bool = False): @app.command("pull") def pull(model: str, quantize: bool = False): if pull_model(model): - convert_to_onnx(model) - if quantize: - convert_to_onnx_q8(model) print_success_bold(f"Model {model} pulled successfully") else: print_error("Failed to get model")