Merge pull request #47 from kadirnar/app
update web demo
kadirnar committed Apr 12, 2023
2 parents f3b8364 + 3198208 commit eda01d2
Showing 6 changed files with 246 additions and 48 deletions.
2 changes: 1 addition & 1 deletion metaseg/__init__.py
@@ -10,4 +10,4 @@
from metaseg.mask_predictor import SegAutoMaskPredictor, SegManualMaskPredictor
from metaseg.sahi_predict import SahiAutoSegmentation, sahi_sliced_predict

__version__ = "0.5.2"
__version__ = "0.6.0"
4 changes: 2 additions & 2 deletions metaseg/mask_predictor.py
@@ -71,7 +71,7 @@ def image_predict(
show_image(combined_mask)

if save:
save_image(output_path=output_path, image=combined_mask)
save_image(output_path=output_path, output_image=combined_mask)

return masks

@@ -180,7 +180,7 @@ def image_predict(

combined_mask = cv2.add(image, mask_image)
if save:
save_image(output_path=output_path, image=combined_mask)
save_image(output_path=output_path, output_image=combined_mask)

if show:
show_image(combined_mask)
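Both hunks in this file change the keyword passed to save_image from image to output_image, matching a rename of that parameter in the helper. A minimal sketch of the signature this rename implies is below; the cv2.imwrite body is an assumption for illustration, since the real helper lives in metaseg.utils and is not shown in this diff:

import cv2
import numpy as np

def save_image(output_path: str, output_image: np.ndarray) -> None:
    # Hypothetical stand-in for the metaseg.utils helper: persist the
    # combined mask visualization to disk under the given path.
    cv2.imwrite(output_path, output_image)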
7 changes: 6 additions & 1 deletion metaseg/sahi_predict.py
@@ -1,6 +1,8 @@
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from metaseg import SamPredictor, sam_model_registry
from metaseg.utils import download_model, load_image, multi_boxes, plt_load_box, plt_load_mask
@@ -110,7 +112,10 @@ def predict(
plt_load_box(box.cpu().numpy(), plt.gca())
plt.axis("off")
if save:
plt.savefig("output.png")
plt.savefig("output.png", bbox_inches="tight")
output_image = cv2.imread("output.png")
output_image = Image.fromarray(output_image)
return output_image
if show:
plt.show()
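
With this change, predict writes the matplotlib figure to output.png when save=True, reads it back with OpenCV, and returns it wrapped as a PIL image so the Gradio demo can display it directly. A hedged usage sketch follows; the box values and file name are illustrative, and in the web demo the boxes come from sahi_sliced_predict (see demo.py below). Note that cv2.imread returns a BGR array, so the returned PIL image will have swapped channels unless converted first.

from metaseg import SahiAutoSegmentation

# Illustrative xyxy box; in practice these come from sahi_sliced_predict.
boxes = [[100, 100, 400, 400]]

result = SahiAutoSegmentation().predict(
    source="input.png",  # hypothetical local image path
    model_type="vit_l",
    input_box=boxes,
    multimask_output=False,
    save=True,  # triggers the savefig/imread/Image.fromarray path above
)
# result is a PIL.Image.Image built from the saved figure.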

188 changes: 145 additions & 43 deletions metaseg/webapp/app.py
@@ -1,14 +1,12 @@
import gradio as gr

from metaseg import SegAutoMaskGenerator
from demo import automask_image_app, automask_video_app, sahi_autoseg_app


def image_app():
with gr.Blocks():
with gr.Row():
with gr.Column():
seg_automask_image_file = gr.Image(type="filepath").style(height=260)

with gr.Row():
with gr.Column():
seg_automask_image_model_type = gr.Dropdown(
@@ -21,34 +19,35 @@ def image_app():
label="Model Type",
)

seg_automask_image_points_per_side = gr.Slider(
minimum=0,
maximum=32,
step=2,
value=16,
label="Points per Side",
)

seg_automask_image_points_per_batch = gr.Slider(
minimum=0,
maximum=64,
step=2,
value=64,
label="Points per Batch",
)

seg_automask_image_min_area = gr.Number(
value=0,
label="Min Area",
)
with gr.Row():
with gr.Column():
seg_automask_image_points_per_side = gr.Slider(
minimum=0,
maximum=32,
step=2,
value=16,
label="Points per Side",
)

seg_automask_image_points_per_batch = gr.Slider(
minimum=0,
maximum=64,
step=2,
value=64,
label="Points per Batch",
)

seg_automask_image_predict = gr.Button(value="Generator")

with gr.Column():
output_image = gr.Image()

seg_automask_image_predict.click(
fn=SegAutoMaskGenerator().save_image,
fn=automask_image_app,
inputs=[
seg_automask_image_file,
seg_automask_image_model_type,
@@ -65,7 +64,6 @@ def video_app():
with gr.Row():
with gr.Column():
seg_automask_video_file = gr.Video().style(height=260)

with gr.Row():
with gr.Column():
seg_automask_video_model_type = gr.Dropdown(
@@ -77,34 +75,35 @@ def video_app():
value="vit_l",
label="Model Type",
)

seg_automask_video_points_per_side = gr.Slider(
minimum=0,
maximum=32,
step=2,
value=16,
label="Points per Side",
)
seg_automask_video_points_per_batch = gr.Slider(
minimum=0,
maximum=64,
step=2,
value=64,
label="Points per Batch",
seg_automask_video_min_area = gr.Number(
value=1000,
label="Min Area",
)
with gr.Row():
with gr.Column():
seg_automask_video_min_area = gr.Number(
value=1000,
label="Min Area",
)

with gr.Row():
with gr.Column():
seg_automask_video_points_per_side = gr.Slider(
minimum=0,
maximum=32,
step=2,
value=16,
label="Points per Side",
)

seg_automask_video_points_per_batch = gr.Slider(
minimum=0,
maximum=64,
step=2,
value=64,
label="Points per Batch",
)

seg_automask_video_predict = gr.Button(value="Generator")
with gr.Column():
output_video = gr.Video()

seg_automask_video_predict.click(
fn=SegAutoMaskGenerator().save_video,
fn=automask_video_app,
inputs=[
seg_automask_video_file,
seg_automask_video_model_type,
@@ -116,6 +115,107 @@ def video_app():
)


def sahi_app():
with gr.Blocks():
with gr.Row():
with gr.Column():
sahi_image_file = gr.Image(type="filepath").style(height=260)
sahi_autoseg_model_type = gr.Dropdown(
choices=[
"vit_h",
"vit_l",
"vit_b",
],
value="vit_l",
label="Sam Model Type",
)

with gr.Row():
with gr.Column():
sahi_model_type = gr.Dropdown(
choices=[
"yolov5",
"yolov8",
],
value="yolov5",
label="Detector Model Type",
)
sahi_image_size = gr.Slider(
minimum=0,
maximum=1600,
step=32,
value=640,
label="Image Size",
)

sahi_overlap_width = gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.2,
label="Overlap Width",
)

sahi_slice_width = gr.Slider(
minimum=0,
maximum=640,
step=32,
value=256,
label="Slice Width",
)

with gr.Row():
with gr.Column():
sahi_model_path = gr.Dropdown(
choices=["yolov5l.pt", "yolov5l6.pt", "yolov8l.pt", "yolov8x.pt"],
value="yolov5l6.pt",
label="Detector Model Path",
)

sahi_conf_th = gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.2,
label="Confidence Threshold",
)
sahi_overlap_height = gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.2,
label="Overlap Height",
)
sahi_slice_height = gr.Slider(
minimum=0,
maximum=640,
step=32,
value=256,
label="Slice Height",
)
sahi_image_predict = gr.Button(value="Generator")

with gr.Column():
output_image = gr.Image()

sahi_image_predict.click(
fn=sahi_autoseg_app,
inputs=[
sahi_image_file,
sahi_autoseg_model_type,
sahi_model_type,
sahi_model_path,
sahi_conf_th,
sahi_image_size,
sahi_slice_height,
sahi_slice_width,
sahi_overlap_height,
sahi_overlap_width,
],
outputs=[output_image],
)


def metaseg_app():
app = gr.Blocks()
with app:
@@ -125,8 +225,10 @@ def metaseg_app():
image_app()
with gr.Tab("Video"):
video_app()
with gr.Tab("SAHI"):
sahi_app()

app.queue(concurrency_count=2)
app.queue(concurrency_count=1)
app.launch(debug=True, enable_queue=True)
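
app.py now delegates all three tabs (Image, Video, SAHI) to the wrapper functions in the new demo.py instead of calling SegAutoMaskGenerator directly. A short launch sketch follows, assuming the script is run from metaseg/webapp so that the flat `from demo import ...` import in app.py resolves:

# Run from the metaseg/webapp directory (assumption based on the flat
# `from demo import ...` import above).
from app import metaseg_app

# Builds the Blocks UI with the Image, Video and SAHI tabs, enables the
# request queue, and calls launch() as defined at the bottom of app.py.
metaseg_app()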


91 changes: 91 additions & 0 deletions metaseg/webapp/demo.py
@@ -0,0 +1,91 @@
from metaseg import SahiAutoSegmentation, SegAutoMaskPredictor, SegManualMaskPredictor, sahi_sliced_predict

# For image


def automask_image_app(image_path, model_type, points_per_side, points_per_batch, min_area):
SegAutoMaskPredictor().image_predict(
source=image_path,
model_type=model_type, # vit_l, vit_h, vit_b
points_per_side=points_per_side,
points_per_batch=points_per_batch,
min_area=min_area,
output_path="output.png",
show=False,
save=True,
)
return "output.png"


# For video


def automask_video_app(video_path, model_type, points_per_side, points_per_batch, min_area):
SegAutoMaskPredictor().video_predict(
source=video_path,
model_type=model_type, # vit_l, vit_h, vit_b
points_per_side=points_per_side,
points_per_batch=points_per_batch,
min_area=min_area,
output_path="output.mp4",
)
return "output.mp4"


# For manual box and point selection


def manual_app(image_path, model_type, input_point, input_label, input_box, multimask_output, random_color):
SegManualMaskPredictor().image_predict(
source=image_path,
model_type=model_type, # vit_l, vit_h, vit_b
input_point=input_point,
input_label=input_label,
input_box=input_box,
multimask_output=multimask_output,
random_color=random_color,
output_path="output.png",
show=False,
save=True,
)
return "output.png"


# For sahi sliced prediction


def sahi_autoseg_app(
image_path,
sam_model_type,
detection_model_type,
detection_model_path,
conf_th,
image_size,
slice_height,
slice_width,
overlap_height_ratio,
overlap_width_ratio,
):
boxes = sahi_sliced_predict(
image_path=image_path,
detection_model_type=detection_model_type, # yolov8, detectron2, mmdetection, torchvision
detection_model_path=detection_model_path,
conf_th=conf_th,
image_size=image_size,
slice_height=slice_height,
slice_width=slice_width,
overlap_height_ratio=overlap_height_ratio,
overlap_width_ratio=overlap_width_ratio,
)

SahiAutoSegmentation().predict(
source=image_path,
model_type=sam_model_type,
input_box=boxes,
multimask_output=False,
random_color=False,
show=False,
save=True,
)

return "output.png"
