Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

onnx script fix and added as main package #73

Merged
merged 3 commits into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion metaseg/sahi_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from PIL import Image

from metaseg import SamPredictor, sam_model_registry
from metaseg.generator import SamPredictor, sam_model_registry
from metaseg.utils import (
download_model,
load_image,
Expand Down
1 change: 1 addition & 0 deletions metaseg/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .data_utils import load_image as load_image
from .data_utils import load_mask as load_mask
from .data_utils import load_server_image as load_server_image
from .data_utils import load_video as load_video
onuralpszr marked this conversation as resolved.
Show resolved Hide resolved
from .data_utils import multi_boxes as multi_boxes
from .data_utils import plt_load_box as plt_load_box
from .data_utils import plt_load_mask as plt_load_mask
Expand Down
590 changes: 586 additions & 4 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pillow = "^9.5.0"
pycocotools = "^2.0.6"
onnx = "^1.14.0"
onnxruntime = "^1.15.1"
fal-serverless = "^0.6.35"



Expand Down
283 changes: 278 additions & 5 deletions requirements-dev.txt

Large diffs are not rendered by default.

354 changes: 349 additions & 5 deletions requirements.txt

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion scripts/amg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
from typing import Any, Dict, List

import cv2 # type: ignore
import cv2

from metaseg.generator import SamAutomaticMaskGenerator, sam_model_registry

Expand Down
46 changes: 18 additions & 28 deletions scripts/export_onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,13 @@
import warnings

import torch
from onnxruntime import InferenceSession
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic

from metaseg import build_sam, build_sam_vit_b, build_sam_vit_l
from metaseg.generator import build_sam, build_sam_vit_b, build_sam_vit_l
from metaseg.utils.onnx import SamOnnxModel

try:
import onnxruntime # type: ignore

onnxruntime_exists = True
except ImportError:
onnxruntime_exists = False

parser = argparse.ArgumentParser(
description="Export the SAM prompt encoder and mask decoder to an ONNX model."
)
Expand Down Expand Up @@ -169,11 +165,10 @@ def run_export(
dynamic_axes=dynamic_axes,
)

if onnxruntime_exists:
ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
ort_session = onnxruntime.InferenceSession(output)
_ = ort_session.run(None, ort_inputs)
print("Model has successfully been run with ONNXRuntime.")
ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
ort_session = InferenceSession(output)
_ = ort_session.run(None, ort_inputs)
print("Model has successfully been run with ONNXRuntime.")


def to_numpy(tensor):
Expand All @@ -193,18 +188,13 @@ def to_numpy(tensor):
return_extra_metrics=args.return_extra_metrics,
)

if args.quantize_out is not None:
assert onnxruntime_exists, "onnxruntime is required to quantize the model."
from onnxruntime.quantization import QuantType # type: ignore
from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore

print(f"Quantizing model and writing to {args.quantize_out}...")
quantize_dynamic(
model_input=args.output,
model_output=args.quantize_out,
optimize_model=True,
per_channel=False,
reduce_range=False,
weight_type=QuantType.QUInt8,
)
print("Done!")
print(f"Quantizing model and writing to {args.quantize_out}...")
quantize_dynamic(
model_input=args.output,
model_output=args.quantize_out,
optimize_model=True,
per_channel=False,
reduce_range=False,
weight_type=QuantType.QUInt8,
)
print("Done!")