Fixed pyqt conflict with opencv #39

Open
wants to merge 12 commits into base: main
26 changes: 18 additions & 8 deletions README.md
@@ -6,30 +6,40 @@ Under active development, apologies for rough edges and bugs. Use at your own risk.

## Installation

### Pre-processing
1. Install [Segment Anything](https://github.com/facebookresearch/segment-anything) on any machine with a GPU. (Need not be the labelling machine.)
2. Create a conda environment using `conda env create -f environment.yaml` on the labelling machine. (Need not have a GPU.)
3. (Optional) Install [coco-viewer](https://github.com/trsvchn/coco-viewer) to scroll through your annotations quickly.

### Labelling
1. Create a conda environment using `conda env create -f environment.yaml` on the labelling machine. (Need not have a GPU.)
2. (Optional) Install [coco-viewer](https://github.com/trsvchn/coco-viewer) to scroll through your annotations quickly.

## Usage

### (Optional) Create a container with the configured environment

docker run -it -v /home/marcoambrosio/dataset/:/root/dataset --privileged --env=NVIDIA_VISIBLE_DEVICES=all --env=NVIDIA_DRIVER_CAPABILITIES=all --gpus 1 --name salt andreaostuni/salt:salt-cuda-11.8-base /bin/bash

### On the pre-processing machine
1. Setup your dataset in the following format `<dataset_name>/images/*` and create empty folder `<dataset_name>/embeddings`.
- Annotations will be saved in `<dataset_name>/annotations.json` by default.
2. Copy the `helpers` scripts to the base folder of your `segment-anything` folder.
   - Call `extract_embeddings.py` to extract embeddings for your images. For example `python3 extract_embeddings.py --dataset-path <path_to_dataset>`
   - Call `generate_onnx.py` to generate `*.onnx` files in `models`. For example `python3 generate_onnx.py --dataset-path <path_to_dataset> --onnx-models-path <path_to_dataset>/models`
3. Copy the models into the `models` folder.
4. Symlink your dataset in SALT's root folder as `<dataset_name>`.

### On the labelling machine
1. Call `segment_anything_annotator.py` with argument `<dataset_name>` and categories `cat1,cat2,cat3..`. For example ` python3 segment_anything_annotator.py --dataset-path <path_to_dataset> --categories cat1,cat2,cat3 `
- There are a few keybindings that make the annotation process fast.
- Click on the object using left clicks and right click (to indicate outside object boundary).
- `n` adds predicted mask into your annotations. (Add button)
- `r` rejects the predicted mask. (Reject button)
    - `a` and `d` to cycle through images in your set. (Next and Prev)
- `l` and `k` to increase and decrease the transparency of the other annotations.
- `Ctrl + S` to save progress to the COCO-style annotations file.
2. Use [coco-viewer](https://github.com/trsvchn/coco-viewer) to view your annotations.
- `python cocoviewer.py -i <dataset> -a <dataset>/annotations.json`

3. Call `coco_to_binary_mask.py` with argument `<dataset_name>`. For example `python3 coco_to_binary_mask.py --dataset-path <path_to_dataset>`. It creates a new folder `masks` in the dataset folder containing the binary masks. For now a single binary mask per image is created, with all segmented regions merged. (Multiple categories are not supported yet.)
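The core of `coco_to_binary_mask.py` is merging per-annotation masks into one binary image. A minimal numpy-only sketch of that step (the helper name `merge_annotation_masks` is illustrative, not part of the repo):

```python
import numpy as np

def merge_annotation_masks(ann_masks, height, width):
    """Merge per-annotation 0/1 masks into one binary mask.

    ann_masks: list of (height, width) arrays, one per annotation,
    as returned by pycocotools' COCO.annToMask.
    """
    merged = np.zeros((height, width), dtype=np.uint8)
    for m in ann_masks:
        merged += m.astype(np.uint8)
    merged[merged >= 1] = 1  # overlapping regions collapse to 1
    return merged
```

An image with no annotations simply yields an all-zero mask.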

## Demo

![How it Works Gif!](https://github.com/anuragxel/salt/raw/main/assets/how-it-works.gif)
54 changes: 54 additions & 0 deletions coco_to_binary_mask.py
@@ -0,0 +1,54 @@
# Convert a COCO annotations.json into per-image binary mask PNGs.
import os
import argparse

from pycocotools.coco import COCO
from matplotlib import image
import numpy as np
from re import findall

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-path", type=str, default="./dataset")
    args = parser.parse_args()
    # TODO: differentiate masks of different categories

    dataset_path = args.dataset_path
    masks_path = os.path.join(dataset_path, "masks")
    os.makedirs(masks_path, exist_ok=True)
    annFile = os.path.join(dataset_path, "annotations.json")

    coco = COCO(annFile)

    catIds = coco.getCatIds()
    imgIds = coco.getImgIds()

    for imgId in imgIds:
        img = coco.loadImgs(imgId)[0]
        width = coco.imgs[imgId]["width"]
        height = coco.imgs[imgId]["height"]
        annIds = coco.getAnnIds(imgIds=img["id"], catIds=catIds, iscrowd=None)
        anns = coco.loadAnns(annIds)
        # Use the digits in the file name as the mask index.
        img_id = findall(r"\d+", img["file_name"])[0]

        # Merge all annotation masks into a single binary mask;
        # images without annotations get an all-zero mask.
        mask = np.zeros((height, width))
        for ann in anns:
            mask += coco.annToMask(ann)
        mask[mask >= 1] = 1

        mask_png_name = "mask" + str(img_id) + ".png"
        mask_png_path = os.path.join(masks_path, mask_png_name)
        image.imsave(mask_png_path, mask, cmap="gray")
5 changes: 4 additions & 1 deletion environment.yaml
@@ -17,6 +17,9 @@ dependencies:
- wheel=0.38.4
- xz=5.2.10
- zlib=1.2.13
- qt
- pyqt
- qtpy
- pip:
- black==23.3.0
- click==8.1.3
@@ -37,7 +40,7 @@ dependencies:
- networkx==3.1
- numpy==1.24.2
- onnxruntime==1.14.1
- opencv-python==4.7.0.72
- opencv-python-headless==4.7.0.72
- packaging==23.0
- pathspec==0.11.1
- pillow==9.5.0
12 changes: 9 additions & 3 deletions salt/dataset_explorer.py
@@ -4,7 +4,7 @@
import shutil
import itertools
import numpy as np
from simplification.cutil import simplify_coords_vwp
from simplification.cutil import simplify_coords_vwp, simplify_coords
import os, cv2, copy
from distinctipy import distinctipy

@@ -87,8 +87,9 @@ def parse_mask_to_coco(image_id, anno_id, image_mask, category_id, poly=False):
    )
    if poly == True:
        for contour in contours:
            sc = simplify_coords_vwp(contour[:,0,:], 2).ravel().tolist()
            annotation["segmentation"].append(sc)
            sc = contour.ravel().tolist()
            if len(sc) > 4:
                annotation["segmentation"].append(sc)
    return annotation
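The new polygon filter in this hunk keeps only flattened contours longer than 4 values, i.e. at least 3 points. A standalone sketch, assuming `contours` are `(N, 1, 2)` integer arrays shaped like `cv2.findContours` output (the name `contours_to_polygons` is illustrative, not in the repo):

```python
import numpy as np

def contours_to_polygons(contours):
    """Flatten contours to COCO polygon lists, dropping degenerate ones."""
    polygons = []
    for contour in contours:
        sc = contour.ravel().tolist()  # flatten to [x0, y0, x1, y1, ...]
        if len(sc) > 4:  # fewer than 3 points cannot form a polygon
            polygons.append(sc)
    return polygons
```

A two-point contour (a line segment) flattens to 4 values and is discarded, which avoids writing invalid COCO polygons.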


@@ -202,3 +203,8 @@ def add_annotation(self, image_id, category_id, mask, poly=True):
    def save_annotation(self):
        with open(self.coco_json_path, "w") as f:
            json.dump(self.coco_json, f)

    def update_annotation(self, image_id, category_id, selected_annotations, mask):
        for annotation in selected_annotations:
            self.coco_json["annotations"][annotation]["category_id"] = category_id

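The new `update_annotation` method edits the in-memory COCO dict in place. A standalone sketch of the same idea, assuming (as the method above does) that `selected_annotations` holds indices into the `annotations` list; the helper name is illustrative:

```python
def update_category(coco_json, selected_annotations, category_id):
    """Reassign the category of the selected annotations in place."""
    for idx in selected_annotations:
        coco_json["annotations"][idx]["category_id"] = category_id
    return coco_json
```

The change only becomes persistent once `save_annotation` rewrites the JSON file.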
5 changes: 4 additions & 1 deletion salt/display_utils.py
@@ -58,7 +58,10 @@ def draw_box_on_image(self, image, ann, color):
    def draw_annotations(self, image, annotations, colors):
        for ann, color in zip(annotations, colors):
            image = self.draw_box_on_image(image, ann, color)
            mask = self.__convert_ann_to_mask(ann, image.shape[0], image.shape[1])
            if type(ann["segmentation"]) is dict:
                mask = coco_mask.decode(ann["segmentation"])
            else:
                mask = self.__convert_ann_to_mask(ann, image.shape[0], image.shape[1])
            image = self.overlay_mask_on_image(image, mask, color)
        return image
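The `type(...) is dict` branch exists because COCO stores segmentations in two formats: a dict means run-length encoding (RLE), a list means polygons. For uncompressed RLE (a plain list of run lengths), decoding can be sketched without pycocotools; compressed RLE (string counts) still needs `coco_mask.decode`:

```python
import numpy as np

def decode_uncompressed_rle(segmentation):
    """Decode an uncompressed COCO RLE dict into a binary mask.

    COCO RLE counts alternate runs of 0s and 1s, starting with 0,
    over the mask flattened in column-major (Fortran) order.
    """
    height, width = segmentation["size"]
    flat = np.zeros(height * width, dtype=np.uint8)
    pos, value = 0, 0
    for run in segmentation["counts"]:
        flat[pos:pos + run] = value
        pos += run
        value = 1 - value
    return flat.reshape((height, width), order="F")
```

This is only a sketch of the format; in the annotator itself, `coco_mask.decode` handles both RLE variants.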

6 changes: 6 additions & 0 deletions salt/editor.py
@@ -143,6 +143,12 @@ def save_ann(self):
    def save(self):
        self.dataset_explorer.save_annotation()

    def change_category(self, selected_annotations=[]):
        self.dataset_explorer.update_annotation(
            self.image_id, self.category_id, selected_annotations, self.curr_inputs.curr_mask
        )
        self.__draw(selected_annotations)

    def next_image(self):
        if self.image_id == self.dataset_explorer.get_num_images() - 1:
            return
24 changes: 19 additions & 5 deletions salt/interface.py
@@ -134,29 +134,34 @@ def reset(self):
        global selected_annotations
        self.editor.reset(selected_annotations)
        self.graphics_view.imshow(self.editor.display)
        self.get_side_panel_annotations()

    def add(self):
        global selected_annotations
        self.editor.save_ann()
        self.editor.reset(selected_annotations)
        self.graphics_view.imshow(self.editor.display)
        self.get_side_panel_annotations()

    def next_image(self):
        global selected_annotations
        self.editor.next_image()
        selected_annotations = []
        self.graphics_view.imshow(self.editor.display)
        self.get_side_panel_annotations()

    def prev_image(self):
        global selected_annotations
        self.editor.prev_image()
        selected_annotations = []
        self.graphics_view.imshow(self.editor.display)
        self.get_side_panel_annotations()

    def toggle(self):
        global selected_annotations
        self.editor.toggle(selected_annotations)
        self.graphics_view.imshow(self.editor.display)
        self.get_side_panel_annotations()

    def transparency_up(self):
        global selected_annotations
@@ -170,6 +175,13 @@ def transparency_down(self):
    def save_all(self):
        self.editor.save()

    def change_category(self):
        global selected_annotations
        self.editor.change_category(selected_annotations)
        self.editor.reset(selected_annotations)
        self.graphics_view.imshow(self.editor.display)
        self.get_side_panel_annotations()

    def get_top_bar(self):
        top_bar = QWidget()
        button_layout = QHBoxLayout(top_bar)
@@ -187,6 +199,7 @@ def get_top_bar(self):
            (
                "Remove Selected Annotations",
                lambda: self.delete_annotations(),
            ),
            ("Change Category", lambda: self.change_category()),
        ]
        for button, lmb in buttons:
            bt = QPushButton(button)
@@ -248,29 +261,30 @@ def annotation_list_item_clicked(self, item):
        self.graphics_view.imshow(self.editor.display)

    def keyPressEvent(self, event):
        if event.key() == Qt.Key_Escape:
            self.app.quit()
        # if event.key() == Qt.Key_Escape:
        #     self.app.quit()
        if event.key() == Qt.Key_A:
            self.prev_image()
            self.get_side_panel_annotations()
        if event.key() == Qt.Key_D:
            self.next_image()
            self.get_side_panel_annotations()
        if event.key() == Qt.Key_K:
            self.transparency_down()
        if event.key() == Qt.Key_L:
            self.transparency_up()
        if event.key() == Qt.Key_N:
            self.add()
            self.get_side_panel_annotations()
        if event.key() == Qt.Key_R:
            self.reset()
        if event.key() == Qt.Key_T:
            self.toggle()
        if event.key() == Qt.Key_C:
            self.change_category()
        if event.modifiers() == Qt.ControlModifier and event.key() == Qt.Key_S:
            self.save_all()
        elif event.key() == Qt.Key_Space:
            print("Space pressed")
            # self.clear_annotations(selected_annotations)
            # Do something if the space bar is pressed
            # pass