Skip to content

Commit

Permalink
[GSoC2024]Import annotations keeping current ones(cvat-ai#4747)
Browse files Browse the repository at this point in the history
Keep current annotations without deleting them, adding the imported ones.
  • Loading branch information
EBayego committed Apr 15, 2024
1 parent fbc2610 commit aba9401
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 16 deletions.
12 changes: 10 additions & 2 deletions cvat-ui/src/actions/import-actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ export const importDatasetAsync = (
sourceStorage: Storage,
file: File | string,
convMaskToPoly: boolean,
keepOldAnnotations: boolean,
): ThunkAction => (
async (dispatch, getState) => {
const resource = instance instanceof core.classes.Project ? 'dataset' : 'annotation';
Expand All @@ -86,6 +87,7 @@ export const importDatasetAsync = (
await instance.annotations
.importDataset(format, useDefaultSettings, sourceStorage, file, {
convMaskToPoly,
keepOldAnnotations,
updateStatusCallback: (message: string, progress: number) => (
dispatch(importActions.importDatasetUpdateStatus(
instance, Math.floor(progress * 100), message,
Expand All @@ -97,7 +99,10 @@ export const importDatasetAsync = (
throw Error('Only one importing of annotation/dataset allowed at the same time');
}
dispatch(importActions.importDataset(instance, format));
await instance.annotations.upload(format, useDefaultSettings, sourceStorage, file, { convMaskToPoly });
await instance.annotations.upload(format, useDefaultSettings, sourceStorage, file, {
convMaskToPoly,
keepOldAnnotations,
});
} else { // job
if (state.import.tasks.dataset.current?.[instance.taskId]) {
throw Error('Annotations is being uploaded for the task');
Expand All @@ -108,7 +113,10 @@ export const importDatasetAsync = (

dispatch(importActions.importDataset(instance, format));

await instance.annotations.upload(format, useDefaultSettings, sourceStorage, file, { convMaskToPoly });
await instance.annotations.upload(format, useDefaultSettings, sourceStorage, file, {
convMaskToPoly,
keepOldAnnotations,
});
await instance.logger.log(EventScope.uploadAnnotations);
await instance.annotations.clear(true);
await instance.actions.clear();
Expand Down
34 changes: 33 additions & 1 deletion cvat-ui/src/components/import-dataset/import-dataset-modal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const initialValues: FormValues = {
interface UploadParams {
resource: 'annotation' | 'dataset' | null;
convMaskToPoly: boolean;
keepOldAnnotations: boolean;
useDefaultSettings: boolean;
sourceStorage: Storage;
selectedFormat: string | null;
Expand Down Expand Up @@ -83,6 +84,7 @@ enum ReducerActionType {
SET_FILE_NAME = 'SET_FILE_NAME',
SET_SELECTED_FORMAT = 'SET_SELECTED_FORMAT',
SET_CONV_MASK_TO_POLY = 'SET_CONV_MASK_TO_POLY',
SET_KEEP_OLD_ANNOTATIONS = 'SET_KEEP_OLD_ANNOTATIONS',
SET_SOURCE_STORAGE = 'SET_SOURCE_STORAGE',
SET_RESOURCE = 'SET_RESOURCE',
}
Expand Down Expand Up @@ -121,6 +123,9 @@ export const reducerActions = {
setConvMaskToPoly: (convMaskToPoly: boolean) => (
createAction(ReducerActionType.SET_CONV_MASK_TO_POLY, { convMaskToPoly })
),
setKeepOldAnnotations: (keepOldAnnotations: boolean) => (
createAction(ReducerActionType.SET_KEEP_OLD_ANNOTATIONS, { keepOldAnnotations })
),
setSourceStorage: (sourceStorage: Storage) => (
createAction(ReducerActionType.SET_SOURCE_STORAGE, { sourceStorage })
),
Expand Down Expand Up @@ -246,6 +251,16 @@ const reducer = (state: State, action: ActionUnion<typeof reducerActions>): Stat
};
}

if (action.type === ReducerActionType.SET_KEEP_OLD_ANNOTATIONS) {
return {
...state,
uploadParams: {
...state.uploadParams,
keepOldAnnotations: action.payload.keepOldAnnotations,
},
};
}

if (action.type === ReducerActionType.SET_SOURCE_STORAGE) {
return {
...state,
Expand Down Expand Up @@ -292,6 +307,7 @@ function ImportDatasetModal(props: StateToProps): JSX.Element {
uploadParams: {
resource: null,
convMaskToPoly: true,
keepOldAnnotations: true,
useDefaultSettings: true,
sourceStorage: new Storage({
location: StorageLocation.LOCAL,
Expand Down Expand Up @@ -460,6 +476,7 @@ function ImportDatasetModal(props: StateToProps): JSX.Element {
uploadParams.sourceStorage,
uploadParams.file || uploadParams.fileName as string,
uploadParams.convMaskToPoly,
uploadParams.keepOldAnnotations,
));
const resToPrint = uploadParams.resource.charAt(0).toUpperCase() + uploadParams.resource.slice(1);
Notification.info({
Expand Down Expand Up @@ -488,7 +505,7 @@ function ImportDatasetModal(props: StateToProps): JSX.Element {

const handleImport = useCallback(
(): void => {
if (isAnnotation()) {
if (isAnnotation() && !uploadParams.keepOldAnnotations) {
confirmUpload();
} else {
onUpload();
Expand Down Expand Up @@ -538,6 +555,7 @@ function ImportDatasetModal(props: StateToProps): JSX.Element {
initialValues={{
...initialValues,
convMaskToPoly: uploadParams.convMaskToPoly,
keepOldAnnotations: uploadParams.keepOldAnnotations,
}}
onFinish={handleImport}
layout='vertical'
Expand Down Expand Up @@ -588,6 +606,20 @@ function ImportDatasetModal(props: StateToProps): JSX.Element {
)}
</Select>
</Form.Item>
<Space className='cvat-modal-import-switch-keep-old-annotations-container'>
<Form.Item
name='keepOldAnnotations'
valuePropName='checked'
className='cvat-modal-import-switch-keep-old-annotations'
>
<Switch
onChange={(value: boolean) => {
dispatch(reducerActions.setKeepOldAnnotations(value));
}}
/>
</Form.Item>
<Text strong>Keep Current Annotations</Text>
</Space>
<Space className='cvat-modal-import-switch-conv-mask-to-poly-container'>
<Form.Item
name='convMaskToPoly'
Expand Down
6 changes: 6 additions & 0 deletions cvat-ui/src/components/import-dataset/styles.scss
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,14 @@
.cvat-modal-import-switch-conv-mask-to-poly {
display: table-cell;
}
.cvat-modal-import-switch-keep-old-annotations {
display: table-cell;
}

.cvat-modal-import-switch-use-default-storage-container,
.cvat-modal-import-switch-conv-mask-to-poly-container {
width: 100%;
}
.cvat-modal-import-switch-keep-old-annotations-container {
width: 100%;
}
17 changes: 11 additions & 6 deletions cvat/apps/dataset_manager/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,10 @@ def import_annotations(self, src_file, importer, **options):
db_job=self.db_job,
create_callback=self.create,
)
self.delete()

keep_old_annotations = options.get('keep_old_annotations', True)
if not keep_old_annotations:
self.delete()

temp_dir_base = self.db_job.get_tmp_dirname()
os.makedirs(temp_dir_base, exist_ok=True)
Expand Down Expand Up @@ -796,7 +799,9 @@ def import_annotations(self, src_file, importer, **options):
db_task=self.db_task,
create_callback=self.create,
)
self.delete()
keep_old_annotations = options.get('keep_old_annotations', True)
if not keep_old_annotations:
self.delete()

temp_dir_base = self.db_task.get_tmp_dirname()
os.makedirs(temp_dir_base, exist_ok=True)
Expand Down Expand Up @@ -910,25 +915,25 @@ def export_task(task_id, dst_file, format_name, server_url=None, save_images=Fal
task.export(f, exporter, host=server_url, save_images=save_images)

@transaction.atomic
def import_task_annotations(src_file, task_id, format_name, conv_mask_to_poly):
def import_task_annotations(src_file, task_id, format_name, conv_mask_to_poly, keep_old_annotations):
task = TaskAnnotation(task_id)
task.init_from_db()

importer = make_importer(format_name)
with open(src_file, 'rb') as f:
try:
task.import_annotations(f, importer, conv_mask_to_poly=conv_mask_to_poly)
task.import_annotations(f, importer, conv_mask_to_poly=conv_mask_to_poly, keep_old_annotations=keep_old_annotations)
except (DatasetError, DatasetImportError, DatasetNotFoundError) as ex:
raise CvatImportError(str(ex))

@transaction.atomic
def import_job_annotations(src_file, job_id, format_name, conv_mask_to_poly):
def import_job_annotations(src_file, job_id, format_name, conv_mask_to_poly, keep_old_annotations):
job = JobAnnotation(job_id)
job.init_from_db()

importer = make_importer(format_name)
with open(src_file, 'rb') as f:
try:
job.import_annotations(f, importer, conv_mask_to_poly=conv_mask_to_poly)
job.import_annotations(f, importer, conv_mask_to_poly=conv_mask_to_poly, keep_old_annotations=keep_old_annotations)
except (DatasetError, DatasetImportError, DatasetNotFoundError) as ex:
raise CvatImportError(str(ex))
28 changes: 21 additions & 7 deletions cvat/apps/engine/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def upload_finished(self, request):
format_name = request.query_params.get("format", "")
filename = request.query_params.get("filename", "")
conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True))
keep_old_annotations = to_bool(request.query_params.get('keep_old_annotations', True))
tmp_dir = self._object.get_tmp_dirname()
uploaded_file = None
if os.path.isfile(os.path.join(tmp_dir, filename)):
Expand All @@ -429,7 +430,8 @@ def upload_finished(self, request):
rq_func=dm.project.import_dataset_as_project,
db_obj=self._object,
format_name=format_name,
conv_mask_to_poly=conv_mask_to_poly
conv_mask_to_poly=conv_mask_to_poly,
keep_old_annotations=keep_old_annotations
)
elif self.action == 'import_backup':
filename = request.query_params.get("filename", "")
Expand Down Expand Up @@ -1003,6 +1005,7 @@ def _handle_upload_annotations(request):
format_name = request.query_params.get("format", "")
filename = request.query_params.get("filename", "")
conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True))
keep_old_annotations = to_bool(request.query_params.get('keep_old_annotations', True))
tmp_dir = self._object.get_tmp_dirname()
if os.path.isfile(os.path.join(tmp_dir, filename)):
annotation_file = os.path.join(tmp_dir, filename)
Expand All @@ -1014,6 +1017,7 @@ def _handle_upload_annotations(request):
db_obj=self._object,
format_name=format_name,
conv_mask_to_poly=conv_mask_to_poly,
keep_old_annotations=keep_old_annotations,
)
return Response(data='No such file were uploaded',
status=status.HTTP_400_BAD_REQUEST)
Expand Down Expand Up @@ -1347,19 +1351,22 @@ def annotations(self, request, pk):
elif request.method == 'POST' or request.method == 'OPTIONS':
# NOTE: initialization process of annotations import
format_name = request.query_params.get('format', '')
keep_old_annotations = to_bool(request.query_params.get('keep_old_annotations', True))
return self.import_annotations(
request=request,
db_obj=self._object,
import_func=_import_annotations,
rq_func=dm.task.import_task_annotations,
rq_id_template=self.IMPORT_RQ_ID_TEMPLATE
rq_id_template=self.IMPORT_RQ_ID_TEMPLATE,
keep_old_annotations=keep_old_annotations
)
elif request.method == 'PUT':
format_name = request.query_params.get('format', '')
if format_name:
# NOTE: continue process of import annotations
use_settings = to_bool(request.query_params.get('use_default_location', True))
conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True))
keep_old_annotations = to_bool(request.query_params.get('keep_old_annotations', True))
obj = self._object if use_settings else request.query_params
location_conf = get_location_configuration(
obj=obj, use_settings=use_settings, field_name=StorageType.SOURCE
Expand All @@ -1371,7 +1378,8 @@ def annotations(self, request, pk):
db_obj=self._object,
format_name=format_name,
location_conf=location_conf,
conv_mask_to_poly=conv_mask_to_poly
conv_mask_to_poly=conv_mask_to_poly,
keep_old_annotations=keep_old_annotations
)
else:
serializer = LabeledDataSerializer(data=request.data)
Expand Down Expand Up @@ -1647,6 +1655,7 @@ def upload_finished(self, request):
format_name = request.query_params.get("format", "")
filename = request.query_params.get("filename", "")
conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True))
keep_old_annotations = to_bool(request.query_params.get('keep_old_annotations', True))
tmp_dir = self.get_upload_dir()
if os.path.isfile(os.path.join(tmp_dir, filename)):
annotation_file = os.path.join(tmp_dir, filename)
Expand All @@ -1658,6 +1667,7 @@ def upload_finished(self, request):
db_obj=self._object,
format_name=format_name,
conv_mask_to_poly=conv_mask_to_poly,
keep_old_annotations=keep_old_annotations,
)
else:
return Response(data='No such file were uploaded',
Expand Down Expand Up @@ -1789,19 +1799,22 @@ def annotations(self, request, pk):

elif request.method == 'POST' or request.method == 'OPTIONS':
format_name = request.query_params.get('format', '')
keep_old_annotations = to_bool(request.query_params.get('keep_old_annotations', True))
return self.import_annotations(
request=request,
db_obj=self._object,
import_func=_import_annotations,
rq_func=dm.task.import_job_annotations,
rq_id_template=self.IMPORT_RQ_ID_TEMPLATE
rq_id_template=self.IMPORT_RQ_ID_TEMPLATE,
keep_old_annotations=keep_old_annotations
)

elif request.method == 'PUT':
format_name = request.query_params.get('format', '')
if format_name:
use_settings = to_bool(request.query_params.get('use_default_location', True))
conv_mask_to_poly = to_bool(request.query_params.get('conv_mask_to_poly', True))
keep_old_annotations = to_bool(request.query_params.get('keep_old_annotations', True))
obj = self._object.segment.task if use_settings else request.query_params
location_conf = get_location_configuration(
obj=obj, use_settings=use_settings, field_name=StorageType.SOURCE
Expand All @@ -1813,7 +1826,8 @@ def annotations(self, request, pk):
db_obj=self._object,
format_name=format_name,
location_conf=location_conf,
conv_mask_to_poly=conv_mask_to_poly
conv_mask_to_poly=conv_mask_to_poly,
keep_old_annotations=keep_old_annotations
)
else:
serializer = LabeledDataSerializer(data=request.data)
Expand Down Expand Up @@ -2813,7 +2827,7 @@ def rq_exception_handler(rq_job, exc_type, exc_value, tb):
return True

def _import_annotations(request, rq_id_template, rq_func, db_obj, format_name,
filename=None, location_conf=None, conv_mask_to_poly=True):
filename=None, location_conf=None, conv_mask_to_poly=True, keep_old_annotations=True):

format_desc = {f.DISPLAY_NAME: f
for f in dm.views.get_import_formats()}.get(format_name)
Expand Down Expand Up @@ -2882,7 +2896,7 @@ def _import_annotations(request, rq_id_template, rq_func, db_obj, format_name,
filename = tf.name

func = import_resource_with_clean_up_after
func_args = (rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly)
func_args = (rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly, keep_old_annotations)

if location == Location.CLOUD_STORAGE:
func_args = (db_storage, key, func) + func_args
Expand Down

0 comments on commit aba9401

Please sign in to comment.