Skip to content

Commit

Permalink
feat: Image annotation shapes improvements #1860 (#2263)
Browse files Browse the repository at this point in the history
  • Loading branch information
marek-mihok authored Feb 13, 2024
1 parent 0c2ac5c commit fc273d9
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 56 deletions.
48 changes: 47 additions & 1 deletion ui/src/image_annotator.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import { act, fireEvent, render, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import React from 'react'
import { ImageAnnotator, XImageAnnotator } from './image_annotator'
import { ImageAnnotator, XImageAnnotator, ZOOM_STEP } from './image_annotator'
import { wave } from './ui'

const
Expand Down Expand Up @@ -1123,6 +1123,52 @@ describe('ImageAnnotator.tsx', () => {
expect(wave.args[name]).toMatchObject([{ tag: 'person', shape: { rect: { x1: 19, x2: 109, y1: 19, y2: 109 } } }, polygon])
})

it('Shows correct cursor when hovering over rect corner while the image is zoomed', async () => {
const { container } = render(<XImageAnnotator model={model} />)
await waitForLoad(container)
const canvasEl = container.querySelector('canvas') as HTMLCanvasElement

fireEvent.click(canvasEl, { clientX: 50, clientY: 50 })
fireEvent.mouseMove(canvasEl, { clientX: 10, clientY: 10 })
expect(canvasEl.style.cursor).toBe('nwse-resize')

// Zoom-in 3 steps.
fireEvent.wheel(canvasEl, { deltaY: -1 })
fireEvent.wheel(canvasEl, { deltaY: -1 })
fireEvent.wheel(canvasEl, { deltaY: -1 })

// HACK: Move to the same position to reset grab cursor which appears after zoom-in.
fireEvent.mouseMove(canvasEl, { clientX: 10, clientY: 10 })
expect(canvasEl.style.cursor).not.toBe('nwse-resize')

const zoomFactor = 1 + ZOOM_STEP * 3
fireEvent.mouseMove(canvasEl, { clientX: 10 * zoomFactor, clientY: 10 * zoomFactor })
expect(canvasEl.style.cursor).toBe('nwse-resize')
})

it('Shows correct cursor when hovering over polygon corner while the image is zoomed', async () => {
const { container } = render(<XImageAnnotator model={model} />)
await waitForLoad(container)
const canvasEl = container.querySelector('canvas') as HTMLCanvasElement

fireEvent.click(canvasEl, { clientX: 180, clientY: 120 })
fireEvent.mouseMove(canvasEl, { clientX: 105, clientY: 100 })
expect(canvasEl.style.cursor).toBe('move')

// Zoom-in 3 steps.
fireEvent.wheel(canvasEl, { deltaY: -1 })
fireEvent.wheel(canvasEl, { deltaY: -1 })
fireEvent.wheel(canvasEl, { deltaY: -1 })

// HACK: Move to the same position to reset grab cursor which appears after zoom-in.
fireEvent.mouseMove(canvasEl, { clientX: 105, clientY: 100 })
expect(canvasEl.style.cursor).not.toBe('move')

const zoomFactor = 1 + ZOOM_STEP * 3
fireEvent.mouseMove(canvasEl, { clientX: 105 * zoomFactor, clientY: 100 * zoomFactor })
expect(canvasEl.style.cursor).toBe('move')
})

// TODO: Add polygon version of this test.
it('Moves rect correctly when the image is zoomed - 2 zoom steps', async () => {
const { container } = render(<XImageAnnotator model={model} />)
Expand Down
28 changes: 15 additions & 13 deletions ui/src/image_annotator.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,9 @@ export type Position = {
export type DrawnShape = ImageAnnotatorItem & { isFocused?: B, boundaryRect?: ImageAnnotatorRect | null }
export type DrawnPoint = ImageAnnotatorPoint & { isAux?: B }

export const ZOOM_STEP = 0.15

const
ZOOM_STEP = 0.15,
tableBorderStyle = `0.5px solid ${cssVar('$neutralTertiaryAlt')}`,
css = stylesheet({
title: {
Expand Down Expand Up @@ -202,17 +203,17 @@ const
},
eventToCursor = (e: React.MouseEvent, rect: DOMRect, zoom: F, position: ImageAnnotatorPoint) =>
({ cursor_x: (e.clientX - rect.left - position.x) / zoom, cursor_y: (e.clientY - rect.top - position.y) / zoom }),
getIntersectedShape = (shapes: DrawnShape[], cursor_x: F, cursor_y: F) => shapes.find(({ shape, isFocused }) => {
if (shape.rect) return isIntersectingRect(cursor_x, cursor_y, shape.rect, isFocused)
if (shape.polygon) return isIntersectingPolygon({ x: cursor_x, y: cursor_y }, shape.polygon.vertices, isFocused)
getIntersectedShape = (shapes: DrawnShape[], cursor_x: F, cursor_y: F, zoom: F) => shapes.find(({ shape, isFocused }) => {
if (shape.rect) return isIntersectingRect(cursor_x, cursor_y, shape.rect, isFocused, zoom)
if (shape.polygon) return isIntersectingPolygon({ x: cursor_x, y: cursor_y }, shape.polygon.vertices, isFocused, zoom)
}),
getCorrectCursorNonDragging = (cursorX: F, cursorY: F, shapes: DrawnShape[], isSelect = true) => {
getCorrectCursorNonDragging = (cursorX: F, cursorY: F, shapes: DrawnShape[], zoom: F, isSelect = true) => {
if (!isSelect) return 'crosshair'
// This is an expensive operation, so we only do it if we're not dragging to prevent rendering jank.
const intersected = getIntersectedShape(shapes, cursorX, cursorY)
const intersected = getIntersectedShape(shapes, cursorX, cursorY, zoom)

if (intersected?.isFocused && intersected.shape.rect) return getRectCornerCursor(intersected.shape.rect, cursorX, cursorY) || 'move'
else if (intersected?.isFocused && intersected.shape.polygon) return getPolygonPointCursor(intersected.shape.polygon.vertices, cursorX, cursorY) || 'move'
if (intersected?.isFocused && intersected.shape.rect) return getRectCornerCursor(intersected.shape.rect, cursorX, cursorY, zoom) || 'move'
else if (intersected?.isFocused && intersected.shape.polygon) return getPolygonPointCursor(intersected.shape.polygon.vertices, cursorX, cursorY, zoom) || 'move'
return intersected ? 'pointer' : 'auto'
},
getCorrectCursorDragging = (rectRef: RectAnnotator | null, polygonRef: PolygonAnnotator | null, isSelect = false) => {
Expand Down Expand Up @@ -309,7 +310,7 @@ export const XImageAnnotator = ({ model }: { model: ImageAnnotator }) => {
changeActiveShape = (shape: keyof ImageAnnotatorShape | 'select') => {
if (canvasRef.current) {
const { x, y } = mousePositionRef.current
canvasRef.current.style.cursor = getCorrectCursorNonDragging(x, y, drawnShapes, shape === 'select')
canvasRef.current.style.cursor = getCorrectCursorNonDragging(x, y, drawnShapes, zoom, shape === 'select')
}
setActiveShape(shape)
if (model.events?.includes('tool_change')) wave.emit(model.name, 'tool_change', shape)
Expand All @@ -335,6 +336,7 @@ export const XImageAnnotator = ({ model }: { model: ImageAnnotator }) => {
imgPositionRef.current = { x, y }
canvasCtx.setTransform(zoom, 0, 0, zoom, imgPositionRef.current.x, imgPositionRef.current.y)
imgRef.current.style.transform = `translate(${imgPositionRef.current.x}px, ${imgPositionRef.current.y}px) scale(${zoom})`
if (canvasCtxRef?.current) canvasCtxRef.current.lineWidth = 2 / zoom
redrawExistingShapes()
},
resetShapeCreation = () => {
Expand Down Expand Up @@ -367,7 +369,7 @@ export const XImageAnnotator = ({ model }: { model: ImageAnnotator }) => {

const
{ cursor_x, cursor_y } = eventToCursor(e, canvas.getBoundingClientRect(), zoom, imgPositionRef.current),
intersected = getIntersectedShape(drawnShapes, cursor_x, cursor_y)
intersected = getIntersectedShape(drawnShapes, cursor_x, cursor_y, zoom)

if (e.buttons !== 1 && !intersected?.shape.polygon) return // Ignore right-click.

Expand Down Expand Up @@ -416,7 +418,7 @@ export const XImageAnnotator = ({ model }: { model: ImageAnnotator }) => {
} else {
canvas.style.cursor = clickStartPosition?.dragging
? getCorrectCursorDragging(rectRef.current, polygonRef.current, isSelect)
: getCorrectCursorNonDragging(cursor_x, cursor_y, drawnShapes, isSelect)
: getCorrectCursorNonDragging(cursor_x, cursor_y, drawnShapes, zoom, isSelect)
}

switch (activeShape) {
Expand Down Expand Up @@ -472,7 +474,7 @@ export const XImageAnnotator = ({ model }: { model: ImageAnnotator }) => {
start = clickStartPositionRef.current,
rect = canvas.getBoundingClientRect(),
{ cursor_x, cursor_y } = eventToCursor(e, rect, zoom, imgPositionRef.current),
intersected = getIntersectedShape(drawnShapes, cursor_x, cursor_y)
intersected = getIntersectedShape(drawnShapes, cursor_x, cursor_y, zoom)

if (model.events?.includes('click') && activeShape !== 'select' && start) {
wave.emit(model.name, 'click', {
Expand Down Expand Up @@ -537,7 +539,7 @@ export const XImageAnnotator = ({ model }: { model: ImageAnnotator }) => {
}

clickStartPositionRef.current = undefined
canvas.style.cursor = getCorrectCursorNonDragging(cursor_x, cursor_y, drawnShapes)
canvas.style.cursor = getCorrectCursorNonDragging(cursor_x, cursor_y, drawnShapes, zoom)
},
moveAllSelectedShapes = (dx: U, dy: U) => {
drawnShapes.forEach(s => {
Expand Down
42 changes: 20 additions & 22 deletions ui/src/image_annotator_polygon.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import { F, S, U } from './core'
import { DrawnPoint, DrawnShape, ImageAnnotatorPoint, ImageAnnotatorRect } from "./image_annotator"
import { ARC_RADIUS } from "./image_annotator_rect"
import { ARC_RADIUS_DEFAULT } from "./image_annotator_rect"

export class PolygonAnnotator {
private currPolygonPoints: ImageAnnotatorPoint[] = []
private boundaryRect: ImageAnnotatorRect | null = null
private draggedPoint: DrawnPoint | null = null
private draggedShape: DrawnShape | null = null
private getArcRadius = () => this.ctx ? ARC_RADIUS_DEFAULT / this.ctx.getTransform().a : ARC_RADIUS_DEFAULT


constructor(private canvas: HTMLCanvasElement, private ctx: CanvasRenderingContext2D | null) { }
Expand Down Expand Up @@ -102,7 +103,7 @@ export class PolygonAnnotator {

onMouseDown(cursor_x: U, cursor_y: U, shape: DrawnShape) {
if (!shape.shape.polygon) return
this.draggedPoint = shape.shape.polygon.vertices.find(p => isIntersectingPoint(p, cursor_x, cursor_y)) || null
this.draggedPoint = shape.shape.polygon.vertices.find(p => isIntersectingPoint(p, cursor_x, cursor_y, this.getArcRadius())) || null
this.draggedShape = shape
}

Expand All @@ -126,15 +127,15 @@ export class PolygonAnnotator {
}

tryToAddAuxPoint = (cursor_x: F, cursor_y: F, items: DrawnPoint[]) => {
const clickedPoint = items.find(p => isIntersectingPoint(p, cursor_x, cursor_y))
const clickedPoint = items.find(p => isIntersectingPoint(p, cursor_x, cursor_y, this.getArcRadius()))
if (clickedPoint?.isAux) {
clickedPoint.isAux = false
return true
}
}

tryToRemovePoint = (cursor_x: F, cursor_y: F, items: DrawnPoint[]) => {
return items.filter(p => !isIntersectingPoint(p, cursor_x, cursor_y))
return items.filter(p => !isIntersectingPoint(p, cursor_x, cursor_y, this.getArcRadius()))
}

getPolygonPointsWithAux = (points: DrawnPoint[]) => {
Expand Down Expand Up @@ -175,7 +176,7 @@ export class PolygonAnnotator {
this.ctx.stroke()
}

drawPolygon = (points: DrawnPoint[], color: S, joinLastPoint = true, isFocused = false) => {
drawPolygon = (points: DrawnPoint[], color: S, joinLastPoint = true, isFocused = false, preview = false) => {
if (!points.length || !this.ctx) return

this.ctx.fillStyle = color
Expand All @@ -187,17 +188,15 @@ export class PolygonAnnotator {

_points.forEach(({ x, y }) => this.drawLine(x, y))
if (joinLastPoint) this.drawLine(points[0].x, points[0].y)
if (isFocused) {
this.ctx.fillStyle = color.substring(0, color.length - 2) + '0.2)'
this.ctx.fill()
_points.forEach(({ x, y, isAux }) => this.drawPoint(x, y, isAux))
}
this.ctx.fillStyle = color.substring(0, color.length - 2) + '0.2)'
if (!preview) this.ctx.fill()
if (isFocused) _points.forEach(({ x, y, isAux }) => this.drawPoint(x, y, color, isAux))
}

drawPreviewLine = (cursor_x: F, cursor_y: F, color: S) => {
if (!this.ctx || !this.currPolygonPoints.length) return

this.drawPolygon(this.currPolygonPoints, color, false)
this.drawPolygon(this.currPolygonPoints, color, false, false, true)
const { x, y } = this.currPolygonPoints[this.currPolygonPoints.length - 1]

this.ctx.beginPath()
Expand All @@ -207,28 +206,28 @@ export class PolygonAnnotator {
this.ctx.stroke()
}

drawPoint = (x: F, y: F, isAux = false) => {
drawPoint = (x: F, y: F, color: S, isAux = false,) => {
if (!this.ctx) return

const path = new Path2D()
path.arc(x, y, ARC_RADIUS, 0, 2 * Math.PI)
this.ctx.strokeStyle = isAux ? '#5e5c5c' : '#000'
path.arc(x, y, this.getArcRadius(), 0, 2 * Math.PI)
this.ctx.strokeStyle = isAux ? color.substring(0, color.length - 2) + '0.5)' : color
this.ctx.fillStyle = isAux ? '#b8b8b8' : '#FFF'
this.ctx.fill(path)
this.ctx.stroke(path)
}

isIntersectingFirstPoint = (cursor_x: F, cursor_y: F) => {
if (!this.currPolygonPoints.length) return false
return isIntersectingPoint(this.currPolygonPoints[0], cursor_x, cursor_y)
return isIntersectingPoint(this.currPolygonPoints[0], cursor_x, cursor_y, this.getArcRadius())
}
}

// Credit: https://gist.github.com/vlasky/d0d1d97af30af3191fc214beaf379acc?permalink_comment_id=3658988#gistcomment-3658988
const cross = (x: ImageAnnotatorPoint, y: ImageAnnotatorPoint, z: ImageAnnotatorPoint) => (y.x - x.x) * (z.y - x.y) - (z.x - x.x) * (y.y - x.y)
export
const isIntersectingPolygon = (p: ImageAnnotatorPoint, points: ImageAnnotatorPoint[], isFocused = false) => {
if (isFocused && points.some(point => isIntersectingPoint(point, p.x, p.y))) return true
const isIntersectingPolygon = (p: ImageAnnotatorPoint, points: ImageAnnotatorPoint[], isFocused = false, zoom: F = 1) => {
if (isFocused && points.some(point => isIntersectingPoint(point, p.x, p.y, ARC_RADIUS_DEFAULT / zoom))) return true
let windingNumber = 0

points.forEach((point, idx) => {
Expand All @@ -244,13 +243,12 @@ export

return windingNumber !== 0
},
isIntersectingPoint = ({ x, y }: ImageAnnotatorPoint, cursor_x: F, cursor_y: F) => {
// TODO: Divide ARC_RADIUS by "scale" to make the offset lower when the image is zoomed in.
const offset = 2 * ARC_RADIUS
isIntersectingPoint = ({ x, y }: ImageAnnotatorPoint, cursor_x: F, cursor_y: F, arcRadius: F = ARC_RADIUS_DEFAULT) => {
const offset = 2 * arcRadius
return cursor_x >= x - offset && cursor_x <= x + offset && cursor_y >= y - offset && cursor_y < y + offset
},
getPolygonPointCursor = (items: DrawnPoint[], cursor_x: F, cursor_y: F) => {
const intersectedPoint = items.find(p => isIntersectingPoint(p, cursor_x, cursor_y))
getPolygonPointCursor = (items: DrawnPoint[], cursor_x: F, cursor_y: F, zoom: F = 1) => {
const intersectedPoint = items.find(p => isIntersectingPoint(p, cursor_x, cursor_y, ARC_RADIUS_DEFAULT / zoom))
return intersectedPoint?.isAux
? 'pointer'
: intersectedPoint
Expand Down
42 changes: 22 additions & 20 deletions ui/src/image_annotator_rect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@ import { DrawnShape, ImageAnnotatorRect, Position } from './image_annotator'
const
MIN_RECT_WIDTH = 5,
MIN_RECT_HEIGHT = 5
export const ARC_RADIUS = 4
export const ARC_RADIUS_DEFAULT = 4

// Needs some canvas-related refactoring love.
export class RectAnnotator {
private resizedCorner?: 'topLeft' | 'topRight' | 'bottomLeft' | 'bottomRight'
private movedRect?: DrawnShape
private getArcRadius = () => this.ctx ? ARC_RADIUS_DEFAULT / this.ctx.getTransform().a : ARC_RADIUS_DEFAULT

constructor(private canvas: HTMLCanvasElement, private ctx: CanvasRenderingContext2D | null) { }

drawCircle = (x: U, y: U) => {
drawCircle = (x: U, y: U, strokeColor: S) => {
if (!this.ctx) return

const path = new Path2D()
path.arc(x, y, ARC_RADIUS, 0, 2 * Math.PI)
this.ctx.strokeStyle = '#000'
path.arc(x, y, this.getArcRadius(), 0, 2 * Math.PI)
this.ctx.strokeStyle = strokeColor
this.ctx.fillStyle = '#FFF'
this.ctx.fill(path)
this.ctx.stroke(path)
Expand All @@ -27,13 +29,13 @@ export class RectAnnotator {
if (!this.ctx) return
this.ctx.strokeStyle = strokeColor
this.ctx.strokeRect(x1, y1, x2 - x1, y2 - y1)
this.ctx.fillStyle = strokeColor.substring(0, strokeColor.length - 2) + '0.2)'
this.ctx.fillRect(x1, y1, x2 - x1, y2 - y1)
if (isFocused) {
this.ctx.fillStyle = strokeColor.substring(0, strokeColor.length - 2) + '0.2)'
this.ctx.fillRect(x1, y1, x2 - x1, y2 - y1)
this.drawCircle(x1, y1)
this.drawCircle(x2, y1)
this.drawCircle(x2, y2)
this.drawCircle(x1, y2)
this.drawCircle(x1, y1, strokeColor)
this.drawCircle(x2, y1, strokeColor)
this.drawCircle(x2, y2, strokeColor)
this.drawCircle(x1, y2, strokeColor)
}
}

Expand Down Expand Up @@ -72,7 +74,7 @@ export class RectAnnotator {
onMouseDown(cursor_x: U, cursor_y: U, shape: DrawnShape) {
if (!shape.shape.rect) return
this.movedRect = shape
this.resizedCorner = getCorner(cursor_x, cursor_y, shape.shape.rect)
this.resizedCorner = getCorner(cursor_x, cursor_y, shape.shape.rect, this.getArcRadius())
}

isMovedOrResized = () => !!this.movedRect || !!this.resizedCorner
Expand Down Expand Up @@ -149,23 +151,23 @@ export const
if (x1 > x2) [rect.x1, rect.x2] = [x2, x1]
if (y1 > y2) [rect.y1, rect.y2] = [y2, y1]
},
isIntersectingRect = (cursor_x: U, cursor_y: U, rect?: ImageAnnotatorRect, isFocused = false) => {
isIntersectingRect = (cursor_x: U, cursor_y: U, rect?: ImageAnnotatorRect, isFocused = false, zoom: F = 1) => {
if (!rect) return false
if (isFocused && getCorner(cursor_x, cursor_y, rect)) return true
if (isFocused && getCorner(cursor_x, cursor_y, rect, ARC_RADIUS_DEFAULT / zoom)) return true

const { x2, x1, y2, y1 } = rect
return cursor_x >= x1 && cursor_x <= x2 && cursor_y >= y1 && cursor_y <= y2
},
getCorner = (x: U, y: U, { x1, y1, x2, y2 }: ImageAnnotatorRect) => {
if (x > x1 - ARC_RADIUS && x < x1 + ARC_RADIUS && y > y1 - ARC_RADIUS && y < y1 + ARC_RADIUS) return 'topLeft'
else if (x > x2 - ARC_RADIUS && x < x2 + ARC_RADIUS && y > y1 - ARC_RADIUS && y < y1 + ARC_RADIUS) return 'topRight'
else if (x > x1 - ARC_RADIUS && x < x1 + ARC_RADIUS && y > y2 - ARC_RADIUS && y < y2 + ARC_RADIUS) return 'bottomLeft'
else if (x > x2 - ARC_RADIUS && x < x2 + ARC_RADIUS && y > y2 - ARC_RADIUS && y < y2 + ARC_RADIUS) return 'bottomRight'
getCorner = (x: U, y: U, { x1, y1, x2, y2 }: ImageAnnotatorRect, arcRadius: F = ARC_RADIUS_DEFAULT) => {
if (x > x1 - arcRadius && x < x1 + arcRadius && y > y1 - arcRadius && y < y1 + arcRadius) return 'topLeft'
else if (x > x2 - arcRadius && x < x2 + arcRadius && y > y1 - arcRadius && y < y1 + arcRadius) return 'topRight'
else if (x > x1 - arcRadius && x < x1 + arcRadius && y > y2 - arcRadius && y < y2 + arcRadius) return 'bottomLeft'
else if (x > x2 - arcRadius && x < x2 + arcRadius && y > y2 - arcRadius && y < y2 + arcRadius) return 'bottomRight'
},
getRectCursorByCorner = (corner?: 'topLeft' | 'topRight' | 'bottomLeft' | 'bottomRight') => {
if (corner === 'topLeft' || corner === 'bottomRight') return 'nwse-resize'
if (corner === 'bottomLeft' || corner === 'topRight') return 'nesw-resize'
},
getRectCornerCursor = (shape: ImageAnnotatorRect, cursor_x: U, cursor_y: U) => {
return getRectCursorByCorner(getCorner(cursor_x, cursor_y, shape))
getRectCornerCursor = (shape: ImageAnnotatorRect, cursor_x: U, cursor_y: U, zoom: F = 1) => {
return getRectCursorByCorner(getCorner(cursor_x, cursor_y, shape, ARC_RADIUS_DEFAULT / zoom))
}

0 comments on commit fc273d9

Please sign in to comment.