Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #35

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

Dev #35

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,6 @@ dmypy.json
# Dataset
dataset/
models/
flyer_pages/
flyer_pages/

.idea/
2 changes: 1 addition & 1 deletion salt/dataset_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def get_image_data(self, image_id):
image_bgr = copy.deepcopy(image)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image_embedding = np.load(embedding_path)
return image, image_bgr, image_embedding
return image, image_bgr, image_embedding, image_name

def __add_to_our_annotation_dict(self, annotation):
image_id = annotation["image_id"]
Expand Down
114 changes: 98 additions & 16 deletions salt/editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,22 @@ def reset_inputs(self):
self.input_label = np.array([])
self.low_res_logits = None
self.curr_mask = None
self.paint_mask = None
self.curr_point_mask = None

def set_mask(self, mask):
self.curr_mask = mask
self.curr_point_mask = mask

def add_paint_mask(self, point_x, point_y):
if self.paint_mask is None:
self.paint_mask = np.zeros(self.curr_mask_shape)

self.paint_mask[point_y - 3:point_y + 3, point_x - 3:point_x + 3] = 1

def era_paint_mask(self, point_x, point_y):
if self.paint_mask is None:
self.paint_mask = np.zeros(self.curr_mask_shape)
self.paint_mask[point_y - 3:point_y + 3, point_x - 3:point_x + 3] = -1

def add_input_click(self, input_point, input_label):
if len(self.input_point) == 0:
Expand All @@ -31,10 +44,13 @@ def add_input_click(self, input_point, input_label):
def set_low_res_logits(self, low_res_logits):
self.low_res_logits = low_res_logits

def set_xy(self, xy):
self.curr_mask_shape = xy


class Editor:
def __init__(
self, onnx_models_path, dataset_path, categories=None, coco_json_path=None
self, onnx_models_path, dataset_path, categories=None, coco_json_path=None
):
self.dataset_path = dataset_path
self.coco_json_path = coco_json_path
Expand All @@ -56,7 +72,9 @@ def __init__(
self.image,
self.image_bgr,
self.image_embedding,
self.name,
) = self.dataset_explorer.get_image_data(self.image_id)
self.curr_inputs.set_xy(self.image.shape[:2])
self.display = self.image_bgr.copy()
self.onnx_helper = OnnxModels(
onnx_models_path,
Expand All @@ -75,6 +93,45 @@ def list_annotations(self):
def delete_annotations(self, annotation_id):
self.dataset_explorer.delete_annotations(self.image_id, annotation_id)

def __draw(self, selected_annotations=[]):
self.display = self.image_bgr.copy()
if self.curr_inputs.paint_mask is not None and self.curr_inputs.curr_point_mask is not None:
tmp_combination = self.curr_inputs.paint_mask + self.curr_inputs.curr_point_mask
self.display = self.du.overlay_mask_on_image(self.display, tmp_combination)

elif self.curr_inputs.paint_mask is not None:
tmp_combination = self.curr_inputs.paint_mask
self.display = self.du.overlay_mask_on_image(self.display, tmp_combination)

elif self.curr_inputs.curr_point_mask is not None:
tmp_combination = self.curr_inputs.curr_point_mask
self.display = self.du.overlay_mask_on_image(self.display, tmp_combination)
# if self.curr_inputs.curr_mask is not None:
# # self.display = self.du.draw_points(
# # self.display, self.curr_inputs.input_point, self.curr_inputs.input_label)
# self.display = self.du.overlay_mask_on_image(self.display, self.curr_inputs.curr_mask)

if self.show_other_anns:
self.__draw_known_annotations(selected_annotations)

def online_draw(self):
self.display = self.image_bgr.copy()
if self.curr_inputs.paint_mask is not None and self.curr_inputs.curr_point_mask is not None:
tmp_combination = self.curr_inputs.paint_mask + self.curr_inputs.curr_point_mask
self.display = self.du.overlay_mask_on_image(self.display, tmp_combination)

elif self.curr_inputs.paint_mask is not None:
tmp_combination = self.curr_inputs.paint_mask
self.display = self.du.overlay_mask_on_image(self.display, tmp_combination)

elif self.curr_inputs.curr_point_mask is not None:
tmp_combination = self.curr_inputs.curr_point_mask
self.display = self.du.overlay_mask_on_image(self.display, tmp_combination)
# if self.curr_inputs.curr_mask is not None:
# # self.display = self.du.draw_points(
# # self.display, self.curr_inputs.input_point, self.curr_inputs.input_label)
# self.display = self.du.overlay_mask_on_image(self.display, self.curr_inputs.curr_mask)

def __draw_known_annotations(self, selected_annotations=[]):
anns, colors = self.dataset_explorer.get_annotations(
self.image_id, return_colors=True
Expand All @@ -86,18 +143,6 @@ def __draw_known_annotations(self, selected_annotations=[]):
# Use this to list the annotations
self.display = self.du.draw_annotations(self.display, anns, colors)

def __draw(self, selected_annotations=[]):
self.display = self.image_bgr.copy()
if self.curr_inputs.curr_mask is not None:
self.display = self.du.draw_points(
self.display, self.curr_inputs.input_point, self.curr_inputs.input_label
)
self.display = self.du.overlay_mask_on_image(
self.display, self.curr_inputs.curr_mask
)
if self.show_other_anns:
self.__draw_known_annotations(selected_annotations)

def add_click(self, new_pt, new_label, selected_annotations=[]):
self.curr_inputs.add_input_click(new_pt, new_label)
masks, low_res_logits = self.onnx_helper.call(
Expand All @@ -106,7 +151,8 @@ def add_click(self, new_pt, new_label, selected_annotations=[]):
self.curr_inputs.input_point,
self.curr_inputs.input_label,
low_res_logits=self.curr_inputs.low_res_logits,
)
) # masks only True False

self.curr_inputs.set_mask(masks[0, 0, :, :])
self.curr_inputs.set_low_res_logits(low_res_logits)
self.__draw(selected_annotations)
Expand Down Expand Up @@ -136,9 +182,28 @@ def draw_selected_annotations(self, selected_annotations=[]):
self.__draw(selected_annotations)

def save_ann(self):
if self.curr_inputs.paint_mask is not None and self.curr_inputs.curr_point_mask is not None:
tmp_combination = self.curr_inputs.paint_mask + self.curr_inputs.curr_point_mask
tmp_combination[tmp_combination > 0] = True
tmp_combination[tmp_combination < 0] = False

elif self.curr_inputs.paint_mask is not None:
tmp_combination = self.curr_inputs.paint_mask
tmp_combination[tmp_combination > 0] = True
tmp_combination[tmp_combination < 0] = False

elif self.curr_inputs.curr_point_mask is not None:
tmp_combination = self.curr_inputs.curr_point_mask

else:
tmp_combination = None

self.dataset_explorer.add_annotation(
self.image_id, self.category_id, self.curr_inputs.curr_mask
self.image_id, self.category_id, tmp_combination
)
# self.dataset_explorer.add_annotation(
# self.image_id, self.category_id, self.curr_inputs.curr_point_mask
# )

def save(self):
self.dataset_explorer.save_annotation()
Expand All @@ -151,7 +216,21 @@ def next_image(self):
self.image,
self.image_bgr,
self.image_embedding,
self.name,
) = self.dataset_explorer.get_image_data(self.image_id)
self.display = self.image_bgr.copy()
self.onnx_helper.set_image_resolution(self.image.shape[1], self.image.shape[0])
self.reset()

def jump2image(self, image_id):
self.image_id = image_id - 1
(
self.image,
self.image_bgr,
self.image_embedding,
self.name,
) = self.dataset_explorer.get_image_data(self.image_id)

self.display = self.image_bgr.copy()
self.onnx_helper.set_image_resolution(self.image.shape[1], self.image.shape[0])
self.reset()
Expand All @@ -164,6 +243,7 @@ def prev_image(self):
self.image,
self.image_bgr,
self.image_embedding,
self.name,
) = self.dataset_explorer.get_image_data(self.image_id)
self.display = self.image_bgr.copy()
self.onnx_helper.set_image_resolution(self.image.shape[1], self.image.shape[0])
Expand All @@ -189,3 +269,5 @@ def get_categories(self, get_colors=False):
def select_category(self, category_name):
category_id = self.categories.index(category_name)
self.category_id = category_id


98 changes: 82 additions & 16 deletions salt/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
QWidget,
QLabel,
QRadioButton,
QComboBox,
)

selected_annotations = []
Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(self, editor):
self.setScene(self.scene)

self.image_item = None
self.flag = False

def set_image(self, q_img):
pixmap = QPixmap.fromImage(q_img)
Expand All @@ -60,17 +62,17 @@ def set_image(self, q_img):
self.setSceneRect(QRectF(pixmap.rect()))

def wheelEvent(self, event: QWheelEvent):
modifiers = QApplication.keyboardModifiers()
if modifiers == Qt.ControlModifier:
adj = (event.angleDelta().y() / 120) * 0.1
self.scale(1 + adj, 1 + adj)
zoom_in_factor = 1.25
zoom_out_factor = 1 / zoom_in_factor
old_pos = self.mapToScene(event.pos())
if event.angleDelta().y() > 0:
zoom_factor = zoom_in_factor
else:
delta_y = event.angleDelta().y()
delta_x = event.angleDelta().x()
x = self.horizontalScrollBar().value()
self.horizontalScrollBar().setValue(x - delta_x)
y = self.verticalScrollBar().value()
self.verticalScrollBar().setValue(y - delta_y)
zoom_factor = zoom_out_factor
self.scale(zoom_factor, zoom_factor)
new_pos = self.mapToScene(event.pos())
delta = new_pos - old_pos
self.translate(delta.x(), delta.y())

def imshow(self, img):
height, width, channel = img.shape
Expand All @@ -90,12 +92,39 @@ def mousePressEvent(self, event: QMouseEvent) -> None:
pos = event.pos()
pos_in_item = self.mapToScene(pos) - self.image_item.pos()
x, y = pos_in_item.x(), pos_in_item.y()
if event.button() == Qt.LeftButton:
label = 1
elif event.button() == Qt.RightButton:
label = 0
self.editor.add_click([int(x), int(y)], label, selected_annotations)
self.imshow(self.editor.display)
if self.mode == 'point':
self.flag = False
if event.button() == Qt.LeftButton:
label = 1
elif event.button() == Qt.RightButton:
label = 0
self.editor.add_click([int(x), int(y)], label)
elif self.mode == 'paint':
self.flag = True
self.editor.curr_inputs.add_paint_mask(int(x), int(y))
elif self.mode == 'eraser':
self.flag = True
self.editor.curr_inputs.era_paint_mask(int(x), int(y))
self.editor.online_draw()
self.imshow(self.editor.display)

def mouseMoveEvent(self, event: QMouseEvent) -> None:
pos = event.pos()
pos_in_item = self.mapToScene(pos) - self.image_item.pos()
x, y = pos_in_item.x(), pos_in_item.y()
if self.flag:
if self.mode == 'paint':
self.editor.curr_inputs.add_paint_mask(int(x), int(y))
elif self.mode == 'eraser':
self.editor.curr_inputs.era_paint_mask(int(x), int(y))
self.editor.online_draw()
self.imshow(self.editor.display)

def mouseReleaseEvent(self, event: QMouseEvent) -> None:
self.flag = False

def update_PPE_mode(self, mode):
self.mode = mode


class ApplicationInterface(QWidget):
Expand Down Expand Up @@ -126,9 +155,15 @@ def __init__(self, app, editor, panel_size=(1920, 1080)):

self.layout.addLayout(self.main_window)

self.label = QLabel()
self.label.resize(200, 100)
self.label.setText(f'{self.editor.name} ... 1/{self.editor.dataset_explorer.get_num_images()}')
self.layout.addWidget(self.label)

self.setLayout(self.layout)

self.graphics_view.imshow(self.editor.display)
self.execute_mode()

def reset(self):
global selected_annotations
Expand All @@ -144,12 +179,14 @@ def add(self):
def next_image(self):
global selected_annotations
self.editor.next_image()
self._update_label(self.editor.name, self.editor.image_id)
selected_annotations = []
self.graphics_view.imshow(self.editor.display)

def prev_image(self):
global selected_annotations
self.editor.prev_image()
self._update_label(self.editor.name, self.editor.image_id)
selected_annotations = []
self.graphics_view.imshow(self.editor.display)

Expand All @@ -169,6 +206,7 @@ def transparency_down(self):

def save_all(self):
self.editor.save()
self._update_label(self.editor.name, self.editor.image_id)

def get_top_bar(self):
top_bar = QWidget()
Expand All @@ -193,6 +231,18 @@ def get_top_bar(self):
bt.clicked.connect(lmb)
button_layout.addWidget(bt)

self.box = QComboBox(top_bar)
self.box.addItems([str(x + 1) for x in range(self.editor.dataset_explorer.get_num_images())])
self.box.currentIndexChanged.connect(self.jump2slice)

self.point_paint_era = QComboBox(top_bar)
self.point_paint_era.addItems(['point', 'paint', 'eraser'])
self.point_paint_era.setCurrentIndex(0)
self.point_paint_era.currentIndexChanged.connect(self.execute_mode)

button_layout.addWidget(self.box)
button_layout.addWidget(self.point_paint_era)

return top_bar

def get_side_panel(self):
Expand Down Expand Up @@ -247,6 +297,22 @@ def annotation_list_item_clicked(self, item):
self.editor.draw_selected_annotations(selected_annotations)
self.graphics_view.imshow(self.editor.display)

def _update_label(self, name, image_id):
self.label.setText(f'{name} ... {image_id + 1}/{self.editor.dataset_explorer.get_num_images()}')
self.layout.addWidget(self.label)
self.setLayout(self.layout)

def jump2slice(self):
self.editor.jump2image(int(self.box.currentText()))
self._update_label(self.editor.name, self.editor.image_id)
self.graphics_view.imshow(self.editor.display)

def execute_mode(self):
# Here is for change point, paint, eraser mode
self.graphics_view.update_PPE_mode(self.point_paint_era.currentText())
# print(self.point_paint_era.currentText())
# pass

def keyPressEvent(self, event):
if event.key() == Qt.Key_Escape:
self.app.quit()
Expand Down
2 changes: 1 addition & 1 deletion salt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ def apply_coords(coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarr
coords = deepcopy(coords).astype(float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
return coords