Skip to content

Commit

Permalink
added tabular lists
Browse files Browse the repository at this point in the history
  • Loading branch information
ameen-91 committed Nov 29, 2024
1 parent 174a6ca commit 2f96713
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 15 deletions.
36 changes: 27 additions & 9 deletions infero/main.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
import os
import subprocess
import typer
from infero.pull.download import check_model
from infero.pull.download import check_model, pull_model
from tabulate import tabulate
from infero.convert.onnx import convert_to_onnx, convert_to_onnx_q8
from infero.utils import (
sanitize_model_name,
get_models_dir,
get_package_dir,
print_neutral,
print_success_bold,
print_error,
)
from infero.pull.models import remove_model

app = typer.Typer(name="infero")


@app.command("run")
def pull(model: str, quantize: bool = False):
print
def run(model: str, quantize: bool = False):
if check_model(model):
convert_to_onnx(model)
if quantize:
convert_to_onnx_q8(model)
model_path = os.path.join(get_models_dir(), sanitize_model_name(model))
package_dir = get_package_dir()
server_script_path = os.path.join(package_dir, "serve", "server.py")
Expand All @@ -31,14 +30,33 @@ def pull(model: str, quantize: bool = False):
typer.echo("Failed to run model")


@app.command("pull")
def pull(model: str, quantize: bool = False):
if pull_model(model):
convert_to_onnx(model)
if quantize:
convert_to_onnx_q8(model)
print_success_bold(f"Model {model} pulled successfully")
else:
print_error("Failed to get model")


@app.command("list")
def list_models():
if not os.path.exists(get_models_dir()):
print_neutral("No models found")
return
models = os.path.join(get_models_dir(), sanitize_model_name)
for model in os.listdir(models):
typer.echo(model)
models_dir = get_models_dir()
models = []
for model in os.listdir(models_dir):
quantized = (
"✅"
if os.path.exists(os.path.join(models_dir, model, "model_quantized.onnx"))
else ""
)
models.append([model, quantized])
table = tabulate(models, headers=["Name", "Quantized"], tablefmt="grid")
print_neutral(table)


@app.command("remove")
Expand Down
22 changes: 21 additions & 1 deletion infero/pull/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def check_model_integrity(model: str):

if not os.path.exists(model_path):
print_neutral(f"Model {model} not found, downloading...")
download_model(model)
return False

if not os.path.exists(vocab_path) and not os.path.exists(vocab_path_2):
print_neutral(f"Vocab file for {model} not found, downloading...")
Expand Down Expand Up @@ -122,6 +122,26 @@ def download_model(model: str):


def check_model(model: str):

if is_supported(model):
print_success(f"Model {model} is supported")
else:
print_error("Model architecture not supported")

if os.path.exists(
os.path.join(get_package_dir(), f"data/models/{sanitize_model_name(model)}")
):
print_success(f"Model {model} already exists")
chk = check_model_integrity(model)
if chk is True:
return True
else:
print_error(f"Model {model} not found, please run 'infero pull {model}'")
return False


def pull_model(model: str):

if is_supported(model):
print_success(f"Model {model} is supported")
else:
Expand Down
6 changes: 2 additions & 4 deletions infero/pull/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import os
from infero.utils import sanitize_model_name, print_success, print_error
from infero.utils import sanitize_model_name, print_success, print_error, get_models_dir


def remove_model(model):
model_path = os.path.join(
os.getcwd(), "infero/data/models", sanitize_model_name(model)
)
model_path = os.path.join(get_models_dir, sanitize_model_name(model))
if os.path.exists(model_path):
os.rmdir(model_path)
print_success(f"Model {model} removed")
Expand Down
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ transformers = "^4.46.3"
fastapi = {extras = ["standard"], version = "^0.115.5"}
torch = "^2.5.1"
psutil = "^6.1.0"
tabulate = "^0.9.0"



Expand Down

0 comments on commit 2f96713

Please sign in to comment.