Skip to content

Commit

Permalink
fix: 🛠️ import movement and function clean up #80
Browse files Browse the repository at this point in the history
  • Loading branch information
onuralpszr committed Jun 27, 2023
2 parents e1ca934 + f040b1c commit d1c2705
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 39 deletions.
5 changes: 2 additions & 3 deletions metaseg/mask_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
load_mask,
load_video,
multi_boxes,
save_image,
show_image,
)

Expand Down Expand Up @@ -77,7 +76,7 @@ def image_predict(
show_image(combined_mask)

if save:
save_image(output_path=output_path, output_image=combined_mask)
cv2.imwrite(output_path, combined_mask)

return masks

Expand Down Expand Up @@ -202,7 +201,7 @@ def image_predict(

combined_mask = cv2.add(image, mask_image)
if save:
save_image(output_path=output_path, output_image=combined_mask)
cv2.imwrite(output_path, combined_mask)

if show:
show_image(combined_mask)
Expand Down
5 changes: 2 additions & 3 deletions metaseg/utils/amg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from itertools import product
from typing import Any, Dict, Generator, ItemsView, List, Tuple

import cv2
import numpy as np
import torch
from pycocotools import mask as mask_utils


class MaskData:
Expand Down Expand Up @@ -271,7 +273,6 @@ def remove_small_regions(
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of if the mask has been modified.
"""
import cv2 # type: ignore

assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
Expand All @@ -292,8 +293,6 @@ def remove_small_regions(


def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
from pycocotools import mask as mask_utils # type: ignore

h, w = uncompressed_rle["size"]
rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
Expand Down
46 changes: 13 additions & 33 deletions metaseg/utils/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
def load_image(image_path):
import cv2
from io import BytesIO
from os import system
from uuid import uuid4

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from torch import tensor


def load_image(image_path):
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return image


def load_server_image(image_path):
import os
from io import BytesIO
from uuid import uuid4

from PIL import Image

imagedir = str(uuid4())
os.system(f"mkdir -p {imagedir}")
system(f"mkdir -p {imagedir}")
image = Image.open(BytesIO(image_path))
if image.mode != "RGB":
image = image.convert("RGB")
Expand All @@ -26,29 +29,22 @@ def load_server_image(image_path):


def load_video(video_path, output_path="output.mp4"):
import cv2

cap = cv2.VideoCapture(video_path)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*"XVID")
fps = int(cap.get(cv2.CAP_PROP_FPS))
out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))

return cap, out


def read_image(image_path):
import cv2

image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return image


def load_mask(mask, random_color):
import numpy as np

if random_color:
color = np.random.rand(3) * 255
else:
Expand All @@ -61,16 +57,12 @@ def load_mask(mask, random_color):


def load_box(box, image):
import cv2

x, y, w, h = int(box[0]), int(box[1]), int(box[2]), int(box[3])
cv2.rectangle(image, (x, y), (w, h), (0, 255, 0), 2)
return image


def plt_load_mask(mask, ax, random_color=False):
import numpy as np

if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
Expand All @@ -81,8 +73,6 @@ def plt_load_mask(mask, ax, random_color=False):


def plt_load_box(box, ax):
import matplotlib.pyplot as plt

x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(
Expand All @@ -91,24 +81,14 @@ def plt_load_box(box, ax):


def multi_boxes(boxes, predictor, image):
import torch

input_boxes = torch.tensor(boxes, device=predictor.device)
input_boxes = tensor(boxes, device=predictor.device)
transformed_boxes = predictor.transform.apply_boxes_torch(
input_boxes, image.shape[:2]
)
return input_boxes, transformed_boxes


def show_image(output_image):
import cv2

cv2.imshow("output", output_image)
cv2.waitKey(0)
cv2.destroyAllWindows()


def save_image(output_image, output_path):
import cv2

cv2.imwrite(output_path, output_image)

0 comments on commit d1c2705

Please sign in to comment.