Skip to content

Commit

Permalink
Merge pull request #50 from kadirnar/add_manuel_mask_video
Browse files Browse the repository at this point in the history
add video support for manuel_mask
  • Loading branch information
kadirnar authored Apr 12, 2023
2 parents eda01d2 + 5ef4212 commit e40b10e
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 1 deletion.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ results = SegAutoMaskPredictor().video_predict(

# For manual box and point selection

# For image
results = SegManualMaskPredictor().image_predict(
source="image.jpg",
model_type="vit_l", # vit_l, vit_h, vit_b
Expand All @@ -61,6 +62,20 @@ results = SegManualMaskPredictor().image_predict(
show=True,
save=False,
)

# For video

results = SegManualMaskPredictor().video_predict(
source="test.mp4",
model_type="vit_l", # vit_l, vit_h, vit_b
input_point=[0, 0, 100, 100],
input_label=None,
input_box=None,
multimask_output=False,
random_color=False,
output_path="output.mp4",
)
```
```
### SAHI + Segment Anything
Expand Down
2 changes: 1 addition & 1 deletion metaseg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
from metaseg.mask_predictor import SegAutoMaskPredictor, SegManualMaskPredictor
from metaseg.sahi_predict import SahiAutoSegmentation, sahi_sliced_predict

__version__ = "0.6.0"
__version__ = "0.6.1"
84 changes: 84 additions & 0 deletions metaseg/mask_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,87 @@ def image_predict(
show_image(combined_mask)

return masks

def video_predict(
    self,
    source,
    model_type,
    input_box=None,
    input_point=None,
    input_label=None,
    multimask_output=False,
    output_path="output.mp4",
    random_color=False,
):
    """Segment every frame of a video with fixed prompts and write an overlay video.

    The same box/point prompts are applied to each frame. For every frame
    the predicted mask(s) are rendered with ``load_mask``, the prompt
    box(es) are drawn with ``load_box``, and the result is written to the
    writer returned by ``load_video``.

    Args:
        source: Video source handed to ``load_video``.
        model_type: Model identifier handed to ``self.load_model``.
        input_box: Either one box (a flat sequence of coordinates), several
            boxes (a sequence of such sequences), or ``None`` to prompt
            with points only.
        input_point: Point prompt coordinates, or ``None``. Ignored when
            several boxes are given (that path prompts with boxes only).
        input_label: Labels matching ``input_point``, or ``None``.
        multimask_output: Forwarded to ``predictor.predict`` for the
            single-box / point-only path. The multi-box path always
            requests a single mask per box.
        output_path: Destination handed to ``load_video``; also returned.
        random_color: Forwarded to ``load_mask``.

    Returns:
        ``output_path``, unchanged.

    Raises:
        ValueError: If neither a box nor a point prompt is supplied, or if
            ``input_box`` is neither a single box nor a list of boxes.
    """
    # Validate the prompts before opening any video resources.
    has_box = input_box is not None and len(input_box) > 0
    if not has_box and input_point is None:
        raise ValueError(
            "video_predict needs at least one prompt: "
            "pass input_box or input_point."
        )

    # 1 -> a single flat box, 2 -> a sequence of boxes.
    box_rank = np.ndim(input_box) if has_box else 0
    if has_box and box_rank not in (1, 2):
        raise ValueError(
            "input_box must be a single box or a list of boxes, "
            f"got a {box_rank}-dimensional value."
        )

    # Prompts are identical for every frame, so prepare them once.
    # assumes predictor.predict wants NumPy arrays for point prompts —
    # TODO confirm; arrays passed by the caller go through unchanged.
    point_coords = None if input_point is None else np.asarray(input_point)
    point_labels = None if input_label is None else np.asarray(input_label)
    single_box = np.array(input_box)[None, :] if box_rank == 1 else None

    cap, out = load_video(source, output_path)
    try:
        length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Build the model and predictor once instead of once per frame;
        # set_image() below is what binds the predictor to a given frame.
        model = self.load_model(model_type)
        predictor = SamPredictor(model)

        for _ in tqdm(range(length)):
            ret, frame = cap.read()
            if not ret:
                # The reported frame count can exceed the readable frames.
                break

            # NOTE(review): frames presumably come from cv2 in BGR order
            # while set_image may default to RGB — confirm whether
            # image_format="BGR" should be passed here.
            predictor.set_image(frame)

            if box_rank == 2:
                input_boxes, new_boxes = multi_boxes(input_box, predictor, frame)

                masks, _, _ = predictor.predict_torch(
                    point_coords=None,
                    point_labels=None,
                    boxes=new_boxes,
                    multimask_output=False,
                )
                # Keep every mask; overwriting a single variable here would
                # leave only the last box's mask in the output.
                mask_images = [
                    load_mask(mask.cpu().numpy(), random_color) for mask in masks
                ]

                for box in input_boxes:
                    frame = load_box(box.cpu().numpy(), frame)
            else:
                masks, _, _ = predictor.predict(
                    point_coords=point_coords,
                    point_labels=point_labels,
                    box=single_box,
                    multimask_output=multimask_output,
                )
                mask_images = [load_mask(masks, random_color)]
                if has_box:
                    frame = load_box(input_box, frame)

            # Overlay each mask image onto the (box-annotated) frame.
            combined_mask = frame
            for mask_image in mask_images:
                combined_mask = cv2.add(combined_mask, mask_image)
            out.write(combined_mask)
    finally:
        # Release the writer and capture even if a frame fails mid-way.
        out.release()
        cap.release()

    cv2.destroyAllWindows()
    return output_path

if __name__ == "__main__":
    # Demo: run manual-prompt segmentation on a local video using one fixed
    # box prompt and write the annotated result to "output.mp4".
    #
    # NOTE(review): model_type was previously "sam_resnet50d_ade20k", which
    # is not one of the documented model types (vit_l, vit_h, vit_b) —
    # confirm "vit_l" is the intended default for this demo.
    predictor = SegManualMaskPredictor()
    # Keyword arguments keep the call correct if the parameter order of
    # video_predict ever changes.
    predictor.video_predict(
        source="test.mp4",
        model_type="vit_l",
        input_box=[0, 0, 100, 100],
        input_point=None,
        input_label=None,
        multimask_output=False,
        output_path="output.mp4",
        random_color=False,
    )

0 comments on commit e40b10e

Please sign in to comment.