diff --git a/CHANGELOG.md b/CHANGELOG.md index a93ee008eb45..0d4fcbb3080b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,7 @@ non-ascii paths while adding files from "Connected file share" (issue #4428) () - Double modal export/backup a task/project () - Fixed bug of computing Job's unsolved/resolved issues numbers () +- Dataset export for job () ### Security - TDB diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py index 1f1177b8c840..71fb144dff13 100644 --- a/cvat/apps/dataset_manager/bindings.py +++ b/cvat/apps/dataset_manager/bindings.py @@ -161,8 +161,7 @@ def _export_attributes(self, attributes): )) return exported_attributes - -class TaskData(InstanceLabelData): +class CommonData(InstanceLabelData): Shape = namedtuple("Shape", 'id, label_id') # 3d LabeledShape = namedtuple( 'LabeledShape', 'type, frame, label, points, occluded, attributes, source, rotation, group, z_order, elements, outside') @@ -178,48 +177,69 @@ class TaskData(InstanceLabelData): 'Frame', 'idx, id, frame, name, width, height, labeled_shapes, tags, shapes, labels') Labels = namedtuple('Label', 'id, name, color, type') - def __init__(self, annotation_ir, db_task, host='', create_callback=None): + def __init__(self, annotation_ir, db_task, host='', create_callback=None) -> None: self._annotation_ir = annotation_ir - self._db_task = db_task self._host = host self._create_callback = create_callback self._MAX_ANNO_SIZE = 30000 self._frame_info = {} self._frame_mapping = {} self._frame_step = db_task.data.get_frame_step() + self._db_data = db_task.data - InstanceLabelData.__init__(self, db_task) + super().__init__(db_task) self._init_frame_info() self._init_meta() + @property + def rel_range(self): + raise NotImplementedError() + + @property + def start(self) -> int: + return 0 + + @property + def stop(self) -> int: + return len(self) + + def _get_queryset(self): + raise NotImplementedError() + def abs_frame_id(self, relative_id): - if relative_id not in range(0, self._db_task.data.size): + # relative_id is frame index in segment for job, so it can start with more than just zero + if relative_id not in self.rel_range: raise ValueError("Unknown internal frame id %s" % relative_id) - return relative_id * self._frame_step + self._db_task.data.start_frame + return relative_id * self._frame_step + self._db_data.start_frame def rel_frame_id(self, absolute_id): d, m = divmod( - absolute_id - self._db_task.data.start_frame, self._frame_step) - if m or d not in range(0, self._db_task.data.size): + absolute_id - self._db_data.start_frame, self._frame_step) + if m or d not in self.rel_range: raise ValueError("Unknown frame %s" % absolute_id) return d def _init_frame_info(self): - self._deleted_frames = { k: True for k in self._db_task.data.deleted_frames } - if hasattr(self._db_task.data, 'video'): - self._frame_info = {frame: { - "path": "frame_{:06d}".format(self.abs_frame_id(frame)), - "width": self._db_task.data.video.width, - "height": self._db_task.data.video.height, - } for frame in range(self._db_task.data.size)} + self._deleted_frames = { k: True for k in self._db_data.deleted_frames } + if hasattr(self._db_data, 'video'): + self._frame_info = { + frame: { + "path": "frame_{:06d}".format(self.abs_frame_id(frame)), + "width": self._db_data.video.width, + "height": self._db_data.video.height, + } for frame in self.rel_range + } else: - self._frame_info = {self.rel_frame_id(db_image.frame): { - "id": db_image.id, - "path": db_image.path, - "width": db_image.width, - "height": db_image.height, - } for db_image in self._db_task.data.images.all()} + queryset = self._get_queryset() + self._frame_info = { + self.rel_frame_id(db_image.frame): { + "id": db_image.id, + "path": db_image.path, + "width": db_image.width, + "height": db_image.height, + } for db_image in queryset + } self._frame_mapping = { self._get_filename(info["path"]): frame_number @@ -227,94 +247,39 @@ def _init_frame_info(self): } @staticmethod - def meta_for_task(db_task, host, label_mapping=None): - db_segments = db_task.segment_set.all().prefetch_related('job_set') - - meta = OrderedDict([ - ("id", str(db_task.id)), - ("name", db_task.name), - ("size", str(db_task.data.size)), - ("mode", db_task.mode), - ("overlap", str(db_task.overlap)), - ("bugtracker", db_task.bug_tracker), - ("created", str(timezone.localtime(db_task.created_date))), - ("updated", str(timezone.localtime(db_task.updated_date))), - ("subset", db_task.subset or dm.DEFAULT_SUBSET_NAME), - ("start_frame", str(db_task.data.start_frame)), - ("stop_frame", str(db_task.data.stop_frame)), - ("frame_filter", db_task.data.frame_filter), - - ("segments", [ - ("segment", OrderedDict([ - ("id", str(db_segment.id)), - ("start", str(db_segment.start_frame)), - ("stop", str(db_segment.stop_frame)), - ("url", "{}/?id={}".format( - host, db_segment.job_set.all()[0].id))] - )) for db_segment in db_segments - ]), - - ("owner", OrderedDict([ - ("username", db_task.owner.username), - ("email", db_task.owner.email) - ]) if db_task.owner else ""), - - ("assignee", OrderedDict([ - ("username", db_task.assignee.username), - ("email", db_task.assignee.email) - ]) if db_task.assignee else ""), - ]) - - if label_mapping is not None: - labels = [] - for db_label in label_mapping.values(): - label = OrderedDict([ - ("name", db_label.name), - ("color", db_label.color), - ("type", db_label.type), - ("attributes", [ - ("attribute", OrderedDict([ - ("name", db_attr.name), - ("mutable", str(db_attr.mutable)), - ("input_type", db_attr.input_type), - ("default_value", db_attr.default_value), - ("values", db_attr.values)])) - for db_attr in db_label.attributespec_set.all()]) - ]) - - if db_label.parent: - label["parent"] = db_label.parent.name - - if db_label.type == str(LabelType.SKELETON): - label["svg"] = db_label.skeleton.svg - for db_sublabel in list(db_label.sublabels.all()): - label["svg"] = label["svg"].replace(f'data-label-id="{db_sublabel.id}"', f'data-label-name="{db_sublabel.name}"') + def _convert_db_labels(db_labels): + labels = [] + for db_label in db_labels: + label = OrderedDict([ + ("name", db_label.name), + ("color", db_label.color), + ("type", db_label.type), + ("attributes", [ + ("attribute", OrderedDict([ + ("name", db_attr.name), + ("mutable", str(db_attr.mutable)), + ("input_type", db_attr.input_type), + ("default_value", db_attr.default_value), + ("values", db_attr.values)])) + for db_attr in db_label.attributespec_set.all()]) + ]) - labels.append(('label', label)) + if db_label.parent: + label["parent"] = db_label.parent.name - meta['labels'] = labels + if db_label.type == str(LabelType.SKELETON): + label["svg"] = db_label.skeleton.svg + for db_sublabel in list(db_label.sublabels.all()): + label["svg"] = label["svg"].replace(f'data-label-id="{db_sublabel.id}"', f'data-label-name="{db_sublabel.name}"') - if hasattr(db_task.data, "video"): - meta["original_size"] = OrderedDict([ - ("width", str(db_task.data.video.width)), - ("height", str(db_task.data.video.height)) - ]) - - return meta + labels.append(('label', label)) + return labels def _init_meta(self): - self._meta = OrderedDict([ - ("task", self.meta_for_task(self._db_task, self._host, self._label_mapping)), - ("dumped", str(timezone.localtime(timezone.now()))) - ]) - - if hasattr(self._db_task.data, "video"): - # Add source to dumped file - self._meta["source"] = str( - osp.basename(self._db_task.data.video.path)) + raise NotImplementedError() def _export_tracked_shape(self, shape): - return TaskData.TrackedShape( + return CommonData.TrackedShape( type=shape["type"], frame=self.abs_frame_id(shape["frame"]), label=self._get_label_name(shape["label_id"]), @@ -332,7 +297,7 @@ def _export_tracked_shape(self, shape): ) def _export_labeled_shape(self, shape): - return TaskData.LabeledShape( + return CommonData.LabeledShape( type=shape["type"], label=self._get_label_name(shape["label_id"]), frame=self.abs_frame_id(shape["frame"]), @@ -348,13 +313,13 @@ def _export_labeled_shape(self, shape): ) def _export_shape(self, shape): - return TaskData.Shape( + return CommonData.Shape( id=shape["id"], label_id=shape["label_id"] ) def _export_tag(self, tag): - return TaskData.Tag( + return CommonData.Tag( frame=self.abs_frame_id(tag["frame"]), label=self._get_label_name(tag["label_id"]), group=tag.get("group", 0), @@ -363,9 +328,9 @@ def _export_tag(self, tag): ) def _export_track(self, track, idx): - track['shapes'] = list(filter(lambda x: x['frame'] not in self._deleted_frames, track['shapes'])) + track['shapes'] = list(filter(lambda x: not self._is_frame_deleted(x['frame']), track['shapes'])) tracked_shapes = TrackManager.get_interpolated_shapes( - track, 0, self._db_task.data.size) + track, 0, len(self)) for tracked_shape in tracked_shapes: tracked_shape["attributes"] += track["attributes"] tracked_shape["track_id"] = idx @@ -373,18 +338,18 @@ def _export_track(self, track, idx): tracked_shape["source"] = track["source"] tracked_shape["label_id"] = track["label_id"] - return TaskData.Track( + return CommonData.Track( label=self._get_label_name(track["label_id"]), group=track["group"], source=track["source"], shapes=[self._export_tracked_shape(shape) - for shape in tracked_shapes if shape["frame"] not in self._deleted_frames], + for shape in tracked_shapes if not self._is_frame_deleted(shape["frame"])], elements=[self._export_track(element, i) for i, element in enumerate(track.get("elements", []))] ) @staticmethod def _export_label(label): - return TaskData.Labels( + return CommonData.Labels( id=label.id, name=label.name, color=label.color, @@ -397,11 +362,11 @@ def get_frame(idx): frame_info = self._frame_info[idx] frame = self.abs_frame_id(idx) if frame not in frames: - frames[frame] = TaskData.Frame( + frames[frame] = CommonData.Frame( idx=idx, - id=frame_info.get('id',0), + id=frame_info.get("id", 0), frame=frame, - name=frame_info['path'], + name=frame_info["path"], height=frame_info["height"], width=frame_info["width"], labeled_shapes=[], @@ -413,14 +378,14 @@ def get_frame(idx): if include_empty: for idx in self._frame_info: - if idx not in self._deleted_frames: + if not self._is_frame_deleted(idx): get_frame(idx) anno_manager = AnnotationManager(self._annotation_ir) shape_data = '' - for shape in sorted(anno_manager.to_shapes(self._db_task.data.size), + for shape in sorted(anno_manager.to_shapes(len(self)), key=lambda shape: shape.get("z_order", 0)): - if shape['frame'] not in self._frame_info or shape['frame'] in self._deleted_frames: + if shape['frame'] not in self._frame_info or self._is_frame_deleted(shape['frame']): # After interpolation there can be a finishing frame # outside of the task boundaries. Filter it out to avoid errors. # https://github.com/openvinotoolkit/cvat/issues/2827 @@ -450,9 +415,12 @@ def get_frame(idx): @property def shapes(self): for shape in self._annotation_ir.shapes: - if shape["frame"] not in self._deleted_frames: + if not self._is_frame_deleted(shape["frame"]): yield self._export_labeled_shape(shape) + def _is_frame_deleted(self, frame): + return frame in self._deleted_frames + @property def tracks(self): for idx, track in enumerate(self._annotation_ir.tracks): @@ -577,8 +545,15 @@ def frame_step(self): return self._frame_step @property - def db_task(self): - return self._db_task + def db_instance(self): + raise NotImplementedError() + + @property + def db_data(self): + return self._db_data + + def __len__(self): + raise NotImplementedError() @staticmethod def _get_filename(path): @@ -605,7 +580,176 @@ def match_frame_fuzzy(self, path): return v return None +class JobData(CommonData): + META_FIELD = "job" + def __init__(self, annotation_ir, db_job, host='', create_callback=None): + self._db_job = db_job + self._db_task = db_job.segment.task + + super().__init__(annotation_ir, self._db_task, host, create_callback) + + def _init_meta(self): + db_segment = self._db_job.segment + self._meta = OrderedDict([ + (JobData.META_FIELD, OrderedDict([ + ("id", str(self._db_job.id)), + ("size", str(len(self))), + ("mode", self._db_task.mode), + ("overlap", str(self._db_task.overlap)), + ("bugtracker", self._db_task.bug_tracker), + ("created", str(timezone.localtime(self._db_task.created_date))), + ("updated", str(timezone.localtime(self._db_job.updated_date))), + ("subset", self._db_task.subset or dm.DEFAULT_SUBSET_NAME), + ("start_frame", str(self._db_data.start_frame + db_segment.start_frame * self._frame_step)), + ("stop_frame", str(self._db_data.start_frame + db_segment.stop_frame * self._frame_step)), + ("frame_filter", self._db_data.frame_filter), + ("segments", [ + ("segment", OrderedDict([ + ("id", str(db_segment.id)), + ("start", str(db_segment.start_frame)), + ("stop", str(db_segment.stop_frame)), + ("url", "{}/api/jobs/{}".format(self._host, self._db_job.id))])), + ]), + ("owner", OrderedDict([ + ("username", self._db_task.owner.username), + ("email", self._db_task.owner.email) + ]) if self._db_task.owner else ""), + + ("assignee", OrderedDict([ + ("username", self._db_job.assignee.username), + ("email", self._db_job.assignee.email) + ]) if self._db_job.assignee else ""), + ])), + ("dumped", str(timezone.localtime(timezone.now()))), + ]) + + if self._label_mapping is not None: + self._meta[JobData.META_FIELD]["labels"] = CommonData._convert_db_labels(self._label_mapping.values()) + + if hasattr(self._db_data, "video"): + self._meta["original_size"] = OrderedDict([ + ("width", str(self._db_data.video.width)), + ("height", str(self._db_data.video.height)) + ]) + + def __len__(self): + segment = self._db_job.segment + return segment.stop_frame - segment.start_frame + 1 + + def _get_queryset(self): + return self._db_data.images.filter(frame__in=self.abs_range) + + @property + def abs_range(self): + segment = self._db_job.segment + step = self._frame_step + start_frame = self._db_data.start_frame + segment.start_frame * step + stop_frame = self._db_data.start_frame + segment.stop_frame * step + 1 + + return range(start_frame, stop_frame, step) + + @property + def rel_range(self): + segment = self._db_job.segment + return range(segment.start_frame, segment.stop_frame + 1) + + @property + def start(self) -> int: + segment = self._db_job.segment + return segment.start_frame + + @property + def stop(self) -> int: + segment = self._db_job.segment + return segment.stop_frame + 1 + + @property + def db_instance(self): + return self._db_job + +class TaskData(CommonData): + META_FIELD = "task" + def __init__(self, annotation_ir, db_task, host='', create_callback=None): + self._db_task = db_task + super().__init__(annotation_ir, db_task, host, create_callback) + + @staticmethod + def meta_for_task(db_task, host, label_mapping=None): + db_segments = db_task.segment_set.all().prefetch_related('job_set') + + meta = OrderedDict([ + ("id", str(db_task.id)), + ("name", db_task.name), + ("size", str(db_task.data.size)), + ("mode", db_task.mode), + ("overlap", str(db_task.overlap)), + ("bugtracker", db_task.bug_tracker), + ("created", str(timezone.localtime(db_task.created_date))), + ("updated", str(timezone.localtime(db_task.updated_date))), + ("subset", db_task.subset or dm.DEFAULT_SUBSET_NAME), + ("start_frame", str(db_task.data.start_frame)), + ("stop_frame", str(db_task.data.stop_frame)), + ("frame_filter", db_task.data.frame_filter), + + ("segments", [ + ("segment", OrderedDict([ + ("id", str(db_segment.id)), + ("start", str(db_segment.start_frame)), + ("stop", str(db_segment.stop_frame)), + ("url", "{}/api/jobs/{}".format( + host, db_segment.job_set.all()[0].id))] + )) for db_segment in db_segments + ]), + + ("owner", OrderedDict([ + ("username", db_task.owner.username), + ("email", db_task.owner.email) + ]) if db_task.owner else ""), + + ("assignee", OrderedDict([ + ("username", db_task.assignee.username), + ("email", db_task.assignee.email) + ]) if db_task.assignee else ""), + ]) + + if label_mapping is not None: + meta['labels'] = CommonData._convert_db_labels(label_mapping.values()) + + if hasattr(db_task.data, "video"): + meta["original_size"] = OrderedDict([ + ("width", str(db_task.data.video.width)), + ("height", str(db_task.data.video.height)) + ]) + + return meta + + def _init_meta(self): + self._meta = OrderedDict([ + (TaskData.META_FIELD, self.meta_for_task(self._db_task, self._host, self._label_mapping)), + ("dumped", str(timezone.localtime(timezone.now()))) + ]) + + if hasattr(self._db_task.data, "video"): + # Add source to dumped file + self._meta["source"] = str( + osp.basename(self._db_task.data.video.path)) + + def __len__(self): + return self._db_data.size + + @property + def rel_range(self): + return range(len(self)) + + @property + def db_instance(self): + return self._db_task + + def _get_queryset(self): + return self._db_data.images.all() + class ProjectData(InstanceLabelData): + META_FIELD = 'project' @attrs class LabeledShape: type: str = attrib() @@ -765,7 +909,7 @@ def _init_frame_info(self): def _init_meta(self): self._meta = OrderedDict([ - ('project', OrderedDict([ + (ProjectData.META_FIELD, OrderedDict([ ('id', str(self._db_project.id)), ('name', self._db_project.name), ("bugtracker", self._db_project.bug_tracker), @@ -819,7 +963,7 @@ def _init_meta(self): labels.append(('label', label)) - self._meta['project']['labels'] = labels + self._meta[ProjectData.META_FIELD]['labels'] = labels def _export_tracked_shape(self, shape: dict, task_id: int): return ProjectData.TrackedShape( @@ -1040,7 +1184,7 @@ def split_dataset(self, dataset: dm.Dataset): for task_data in self.task_data: if task_data._db_task.id not in self.new_tasks: continue - subset_dataset: dm.Dataset = dataset.subsets()[task_data.db_task.subset].as_dataset() + subset_dataset: dm.Dataset = dataset.subsets()[task_data.db_instance.subset].as_dataset() yield subset_dataset, task_data def add_labels(self, labels: List[dict]): @@ -1096,7 +1240,7 @@ def _load_user_info(meta: dict): "updatedAt": meta['updated'] } - def _read_cvat_anno(self, cvat_frame_anno: Union[ProjectData.Frame, TaskData.Frame], labels: list): + def _read_cvat_anno(self, cvat_frame_anno: Union[ProjectData.Frame, CommonData.Frame], labels: list): categories = self.categories() label_cat = categories[dm.AnnotationType.label] def map_label(name, parent=''): return label_cat.find(name, parent)[0] @@ -1108,16 +1252,17 @@ def map_label(name, parent=''): return label_cat.find(name, parent)[0] return convert_cvat_anno_to_dm(cvat_frame_anno, label_attrs, map_label) -class CvatTaskDataExtractor(dm.SourceExtractor, CVATDataExtractorMixin): - def __init__(self, task_data, include_images=False, format_type=None, dimension=DimensionType.DIM_2D): +class CvatTaskOrJobDataExtractor(dm.SourceExtractor, CVATDataExtractorMixin): + def __init__(self, instance_data: CommonData, include_images=False, format_type=None, dimension=DimensionType.DIM_2D): super().__init__(media_type=dm.Image if dimension == DimensionType.DIM_2D else PointCloud) - self._categories = self._load_categories(task_data.meta['task']['labels']) - self._user = self._load_user_info(task_data.meta['task']) if dimension == DimensionType.DIM_3D else {} + instance_meta = instance_data.meta[instance_data.META_FIELD] + self._categories = self._load_categories(instance_meta['labels']) + self._user = self._load_user_info(instance_meta) if dimension == DimensionType.DIM_3D else {} self._dimension = dimension self._format_type = format_type dm_items = [] - is_video = task_data.meta['task']['mode'] == 'interpolation' + is_video = instance_meta['mode'] == 'interpolation' ext = '' if is_video: ext = FrameProvider.VIDEO_FRAME_EXT @@ -1125,7 +1270,7 @@ def __init__(self, task_data, include_images=False, format_type=None, dimension= if dimension == DimensionType.DIM_3D: def _make_image(image_id, **kwargs): loader = osp.join( - task_data.db_task.data.get_upload_dirname(), kwargs['path']) + instance_data.db_data.get_upload_dirname(), kwargs['path']) related_images = [] image = Img.objects.get(id=image_id) for i in image.related_files.all(): @@ -1135,7 +1280,7 @@ def _make_image(image_id, **kwargs): return loader, related_images elif include_images: - frame_provider = FrameProvider(task_data.db_task.data) + frame_provider = FrameProvider(instance_data.db_data) if is_video: # optimization for videos: use numpy arrays instead of bytes # some formats or transforms can require image data @@ -1152,7 +1297,7 @@ def _make_image(i, **kwargs): out_type=frame_provider.Type.BUFFER)[0].getvalue() return dm.ByteImage(data=loader, **kwargs) - for frame_data in task_data.group_by_frame(include_empty=True): + for frame_data in instance_data.group_by_frame(include_empty=True): image_args = { 'path': frame_data.name + ext, 'size': (frame_data.height, frame_data.width), @@ -1164,7 +1309,7 @@ def _make_image(i, **kwargs): dm_image = _make_image(frame_data.idx, **image_args) else: dm_image = dm.Image(**image_args) - dm_anno = self._read_cvat_anno(frame_data, task_data.meta['task']['labels']) + dm_anno = self._read_cvat_anno(frame_data, instance_meta['labels']) if dimension == DimensionType.DIM_2D: dm_item = dm.DatasetItem( @@ -1179,7 +1324,7 @@ def _make_image(i, **kwargs): attributes["createdAt"] = self._user["createdAt"] attributes["updatedAt"] = self._user["updatedAt"] attributes["labels"] = [] - for (idx, (_, label)) in enumerate(task_data.meta['task']['labels']): + for (idx, (_, label)) in enumerate(instance_meta['labels']): attributes["labels"].append({"label_id": idx, "name": label["name"], "color": label["color"], "type": label["type"]}) attributes["track_id"] = -1 @@ -1193,7 +1338,7 @@ def _make_image(i, **kwargs): self._items = dm_items - def _read_cvat_anno(self, cvat_frame_anno: TaskData.Frame, labels: list): + def _read_cvat_anno(self, cvat_frame_anno: CommonData.Frame, labels: list): categories = self.categories() label_cat = categories[dm.AnnotationType.label] def map_label(name, parent=''): return label_cat.find(name, parent)[0] @@ -1207,8 +1352,8 @@ def map_label(name, parent=''): return label_cat.find(name, parent)[0] class CVATProjectDataExtractor(dm.Extractor, CVATDataExtractorMixin): def __init__(self, project_data: ProjectData, include_images: bool = False, format_type: str = None, dimension: DimensionType = DimensionType.DIM_2D): super().__init__(media_type=dm.Image if dimension == DimensionType.DIM_2D else PointCloud) - self._categories = self._load_categories(project_data.meta['project']['labels']) - self._user = self._load_user_info(project_data.meta['project']) if dimension == DimensionType.DIM_3D else {} + self._categories = self._load_categories(project_data.meta[project_data.META_FIELD]['labels']) + self._user = self._load_user_info(project_data.meta[project_data.META_FIELD]) if dimension == DimensionType.DIM_3D else {} self._dimension = dimension self._format_type = format_type @@ -1271,7 +1416,7 @@ def _make_image(i, **kwargs): dm_image = image_maker_per_task[frame_data.task_id](frame_data.idx, **image_args) else: dm_image = dm.Image(**image_args) - dm_anno = self._read_cvat_anno(frame_data, project_data.meta['project']['labels']) + dm_anno = self._read_cvat_anno(frame_data, project_data.meta[project_data.META_FIELD]['labels']) if self._dimension == DimensionType.DIM_2D: dm_item = dm.DatasetItem( id=osp.splitext(frame_data.name)[0], @@ -1286,7 +1431,7 @@ def _make_image(i, **kwargs): attributes["createdAt"] = self._user["createdAt"] attributes["updatedAt"] = self._user["updatedAt"] attributes["labels"] = [] - for (idx, (_, label)) in enumerate(project_data.meta['project']['labels']): + for (idx, (_, label)) in enumerate(project_data.meta[project_data.META_FIELD]['labels']): attributes["labels"].append({"label_id": idx, "name": label["name"], "color": label["color"], "type": label["type"]}) attributes["track_id"] = -1 @@ -1309,11 +1454,16 @@ def __len__(self): return len(self._items) -def GetCVATDataExtractor(instance_data: Union[ProjectData, TaskData], include_images: bool = False, format_type: str = None, dimension: DimensionType = DimensionType.DIM_2D): +def GetCVATDataExtractor( + instance_data: Union[ProjectData, CommonData], + include_images: bool = False, + format_type: str = None, + dimension: DimensionType = DimensionType.DIM_2D, +): if isinstance(instance_data, ProjectData): return CVATProjectDataExtractor(instance_data, include_images, format_type, dimension) else: - return CvatTaskDataExtractor(instance_data, include_images, format_type, dimension) + return CvatTaskOrJobDataExtractor(instance_data, include_images, format_type, dimension) class CvatImportError(Exception): pass @@ -1469,25 +1619,25 @@ def convert_attrs(label, cvat_attrs): return item_anno -def match_dm_item(item, task_data, root_hint=None): - is_video = task_data.meta['task']['mode'] == 'interpolation' +def match_dm_item(item, instance_data, root_hint=None): + is_video = instance_data.meta[instance_data.META_FIELD]['mode'] == 'interpolation' frame_number = None if frame_number is None and item.has_image: - frame_number = task_data.match_frame(item.id + item.image.ext, root_hint) + frame_number = instance_data.match_frame(item.id + item.image.ext, root_hint) if frame_number is None: - frame_number = task_data.match_frame(item.id, root_hint, path_has_ext=False) + frame_number = instance_data.match_frame(item.id, root_hint, path_has_ext=False) if frame_number is None: frame_number = dm.util.cast(item.attributes.get('frame', item.id), int) if frame_number is None and is_video: frame_number = dm.util.cast(osp.basename(item.id)[len('frame_'):], int) - if not frame_number in task_data.frame_info: + if not frame_number in instance_data.frame_info: raise CvatImportError("Could not match item id: " "'%s' with any task frame" % item.id) return frame_number -def find_dataset_root(dm_dataset, instance_data: Union[TaskData, ProjectData]): +def find_dataset_root(dm_dataset, instance_data: Union[ProjectData, CommonData]): longest_path = max(dm_dataset, key=lambda x: len(Path(x.id).parts), default=None) if longest_path is None: @@ -1503,7 +1653,7 @@ def find_dataset_root(dm_dataset, instance_data: Union[TaskData, ProjectData]): prefix = prefix[:-1] return prefix -def import_dm_annotations(dm_dataset: dm.Dataset, instance_data: Union[TaskData, ProjectData]): +def import_dm_annotations(dm_dataset: dm.Dataset, instance_data: Union[ProjectData, CommonData]): if len(dm_dataset) == 0: return diff --git a/cvat/apps/dataset_manager/formats/cvat.py b/cvat/apps/dataset_manager/formats/cvat.py index ede5b67c1ae8..5d286e22752d 100644 --- a/cvat/apps/dataset_manager/formats/cvat.py +++ b/cvat/apps/dataset_manager/formats/cvat.py @@ -21,7 +21,7 @@ from datumaro.util.image import Image from defusedxml import ElementTree -from cvat.apps.dataset_manager.bindings import (ProjectData, TaskData, +from cvat.apps.dataset_manager.bindings import (ProjectData, CommonData, get_defaulted_subset, import_dm_annotations, match_dm_item) @@ -984,11 +984,11 @@ def dump_track(idx, track): counter += 1 for shape in annotations.shapes: - frame_step = annotations.frame_step if isinstance(annotations, TaskData) else annotations.frame_step[shape.task_id] - if isinstance(annotations, TaskData): - stop_frame = int(annotations.meta['task']['stop_frame']) + frame_step = annotations.frame_step if not isinstance(annotations, ProjectData) else annotations.frame_step[shape.task_id] + if not isinstance(annotations, ProjectData): + stop_frame = int(annotations.meta[annotations.META_FIELD]['stop_frame']) else: - task_meta = list(filter(lambda task: int(task[1]['id']) == shape.task_id, annotations.meta['project']['tasks']))[0][1] + task_meta = list(filter(lambda task: int(task[1]['id']) == shape.task_id, annotations.meta[annotations.META_FIELD]['tasks']))[0][1] stop_frame = int(task_meta['stop_frame']) track = { 'label': shape.label, @@ -1102,7 +1102,7 @@ def load_anno(file_object, annotations): attributes={'frame': el.attrib['id']}, image=el.attrib['name'] ), - task_data=annotations + instance_data=annotations )) elif el.tag in supported_shapes and (track is not None or image_is_opened): if shape and shape['type'] == 'skeleton': @@ -1258,10 +1258,10 @@ def load_anno(file_object, annotations): tag = None el.clear() -def dump_task_anno(dst_file, task_data, callback): +def dump_task_or_job_anno(dst_file, instance_data, callback): dumper = create_xml_dumper(dst_file) dumper.open_document() - callback(dumper, task_data) + callback(dumper, instance_data) dumper.close_document() def dump_project_anno(dst_file: BufferedWriter, project_data: ProjectData, callback: Callable): @@ -1270,33 +1270,34 @@ def dump_project_anno(dst_file: BufferedWriter, project_data: ProjectData, callb callback(dumper, project_data) dumper.close_document() -def dump_media_files(task_data: TaskData, img_dir: str, project_data: ProjectData = None): +def dump_media_files(instance_data: CommonData, img_dir: str, project_data: ProjectData = None): ext = '' - if task_data.meta['task']['mode'] == 'interpolation': + if instance_data.meta[instance_data.META_FIELD]['mode'] == 'interpolation': ext = FrameProvider.VIDEO_FRAME_EXT - frame_provider = FrameProvider(task_data.db_task.data) + frame_provider = FrameProvider(instance_data.db_data) frames = frame_provider.get_frames( + instance_data.start, instance_data.stop, frame_provider.Quality.ORIGINAL, frame_provider.Type.BUFFER) - for frame_id, (frame_data, _) in enumerate(frames): - if (project_data is not None and (task_data.db_task.id, frame_id) in project_data.deleted_frames) \ - or frame_id in task_data.deleted_frames: + for frame_id, (frame_data, _) in zip(instance_data.rel_range, frames): + if (project_data is not None and (instance_data.db_instance.id, frame_id) in project_data.deleted_frames) \ + or frame_id in instance_data.deleted_frames: continue - frame_name = task_data.frame_info[frame_id]['path'] if project_data is None \ - else project_data.frame_info[(task_data.db_task.id, frame_id)]['path'] + frame_name = instance_data.frame_info[frame_id]['path'] if project_data is None \ + else project_data.frame_info[(instance_data.db_instance.id, frame_id)]['path'] img_path = osp.join(img_dir, frame_name + ext) os.makedirs(osp.dirname(img_path), exist_ok=True) with open(img_path, 'wb') as f: f.write(frame_data.getvalue()) -def _export_task(dst_file, task_data, anno_callback, save_images=False): +def _export_task_or_job(dst_file, instance_data, anno_callback, save_images=False): with TemporaryDirectory() as temp_dir: with open(osp.join(temp_dir, 'annotations.xml'), 'wb') as f: - dump_task_anno(f, task_data, anno_callback) + dump_task_or_job_anno(f, instance_data, anno_callback) if save_images: - dump_media_files(task_data, osp.join(temp_dir, 'images')) + dump_media_files(instance_data, osp.join(temp_dir, 'images')) make_zip_archive(temp_dir, dst_file) @@ -1307,7 +1308,7 @@ def _export_project(dst_file: str, project_data: ProjectData, anno_callback: Cal if save_images: for task_data in project_data.task_data: - subset = get_defaulted_subset(task_data.db_task.subset, project_data.subsets) + subset = get_defaulted_subset(task_data.db_instance.subset, project_data.subsets) subset_dir = osp.join(temp_dir, 'images', subset) os.makedirs(subset_dir, exist_ok=True) dump_media_files(task_data, subset_dir, project_data) @@ -1320,7 +1321,7 @@ def _export_video(dst_file, instance_data, save_images=False): _export_project(dst_file, instance_data, anno_callback=dump_as_cvat_interpolation, save_images=save_images) else: - _export_task(dst_file, instance_data, + _export_task_or_job(dst_file, instance_data, anno_callback=dump_as_cvat_interpolation, save_images=save_images) @exporter(name='CVAT for images', ext='ZIP', version='1.1') @@ -1329,7 +1330,7 @@ def _export_images(dst_file, instance_data, save_images=False): _export_project(dst_file, instance_data, anno_callback=dump_as_cvat_annotation, save_images=save_images) else: - _export_task(dst_file, instance_data, + _export_task_or_job(dst_file, instance_data, anno_callback=dump_as_cvat_annotation, save_images=save_images) @importer(name='CVAT', ext='XML, ZIP', version='1.1') diff --git a/cvat/apps/dataset_manager/formats/kitti.py b/cvat/apps/dataset_manager/formats/kitti.py index d4d2fd8bd31a..d3296b8a79bb 100644 --- a/cvat/apps/dataset_manager/formats/kitti.py +++ b/cvat/apps/dataset_manager/formats/kitti.py @@ -1,4 +1,5 @@ # Copyright (C) 2021-2022 Intel Corporation +# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -9,8 +10,7 @@ from datumaro.plugins.kitti_format.format import KittiPath, write_label_map from pyunpack import Archive -from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, - ProjectData, import_dm_annotations) +from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, import_dm_annotations) from cvat.apps.dataset_manager.util import make_zip_archive from .transformations import RotatedBoxesToPolygons @@ -45,8 +45,7 @@ def _import(src_file, instance_data): write_label_map(color_map_path, color_map) dataset = Dataset.import_from(tmp_dir, format='kitti', env=dm_env) - labels_meta = instance_data.meta['project']['labels'] \ - if isinstance(instance_data, ProjectData) else instance_data.meta['task']['labels'] + labels_meta = instance_data.meta[instance_data.META_FIELD]['labels'] if 'background' not in [label['name'] for _, label in labels_meta]: dataset.filter('/item/annotation[label != "background"]', filter_annotations=True) diff --git a/cvat/apps/dataset_manager/formats/mot.py b/cvat/apps/dataset_manager/formats/mot.py index fae76fb44e3e..031aa3a54573 100644 --- a/cvat/apps/dataset_manager/formats/mot.py +++ b/cvat/apps/dataset_manager/formats/mot.py @@ -1,4 +1,5 @@ # Copyright (C) 2019-2022 Intel Corporation +# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -13,13 +14,15 @@ from .registry import dm_env, exporter, importer -def _import_task(dataset, task_data): +def _import_to_task(dataset, instance_data): tracks = {} label_cat = dataset.categories()[datumaro.AnnotationType.label] for item in dataset: - frame_number = int(item.id) - 1 # NOTE: MOT frames start from 1 - frame_number = task_data.abs_frame_id(frame_number) + # NOTE: MOT frames start from 1 + # job has an offset, for task offset is 0 + frame_number = int(item.id) - 1 + instance_data.start + frame_number = instance_data.abs_frame_id(frame_number) for ann in item.annotations: if ann.type != datumaro.AnnotationType.bbox: @@ -28,7 +31,7 @@ def _import_task(dataset, task_data): track_id = ann.attributes.get('track_id') if track_id is None: # Extension. Import regular boxes: - task_data.add_shape(task_data.LabeledShape( + instance_data.add_shape(instance_data.LabeledShape( type='rectangle', label=label_cat.items[ann.label].name, points=ann.points, @@ -41,7 +44,7 @@ def _import_task(dataset, task_data): )) continue - shape = task_data.TrackedShape( + shape = instance_data.TrackedShape( type='rectangle', points=ann.points, occluded=ann.attributes.get('occluded') is True, @@ -55,7 +58,7 @@ def _import_task(dataset, task_data): # build trajectories as lists of shapes in track dict if track_id not in tracks: - tracks[track_id] = task_data.Track( + tracks[track_id] = instance_data.Track( label_cat.items[ann.label].name, 0, 'manual', []) tracks[track_id].shapes.append(shape) @@ -67,10 +70,10 @@ def _import_task(dataset, task_data): prev_shape_idx = 0 prev_shape = track.shapes[0] for shape in track.shapes[1:]: - has_skip = task_data.frame_step < shape.frame - prev_shape.frame + has_skip = instance_data.frame_step < shape.frame - prev_shape.frame if has_skip and not prev_shape.outside: prev_shape = prev_shape._replace(outside=True, - frame=prev_shape.frame + task_data.frame_step) + frame=prev_shape.frame + instance_data.frame_step) prev_shape_idx += 1 track.shapes.insert(prev_shape_idx, prev_shape) prev_shape = shape @@ -78,12 +81,12 @@ def _import_task(dataset, task_data): # Append a shape with outside=True to finish the track last_shape = track.shapes[-1] - if last_shape.frame + task_data.frame_step <= \ - int(task_data.meta['task']['stop_frame']): + if last_shape.frame + instance_data.frame_step <= \ + int(instance_data.meta[instance_data.META_FIELD]['stop_frame']): track.shapes.append(last_shape._replace(outside=True, - frame=last_shape.frame + task_data.frame_step) + frame=last_shape.frame + instance_data.frame_step) ) - task_data.add_track(track) + instance_data.add_track(track) @exporter(name='MOT', ext='ZIP', version='1.1') @@ -107,7 +110,7 @@ def _import(src_file, instance_data, load_data_callback=None): # Dirty way to determine instance type to avoid circular dependency if hasattr(instance_data, '_db_project'): for sub_dataset, task_data in instance_data.split_dataset(dataset): - _import_task(sub_dataset, task_data) + _import_to_task(sub_dataset, task_data) else: - _import_task(dataset, instance_data) + _import_to_task(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/mots.py b/cvat/apps/dataset_manager/formats/mots.py index 855f043f1a07..90f17184b453 100644 --- a/cvat/apps/dataset_manager/formats/mots.py +++ b/cvat/apps/dataset_manager/formats/mots.py @@ -1,4 +1,5 @@ # Copyright (C) 2019-2022 Intel Corporation +# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -22,16 +23,16 @@ def transform_item(self, item): return item.wrap(annotations=[a for a in item.annotations if 'track_id' in a.attributes]) -def _import_task(dataset, task_data): +def _import_to_task(dataset, instance_data): tracks = {} label_cat = dataset.categories()[AnnotationType.label] - root_hint = find_dataset_root(dataset, task_data) + root_hint = find_dataset_root(dataset, instance_data) shift = 0 for item in dataset: - frame_number = task_data.abs_frame_id( - match_dm_item(item, task_data, root_hint=root_hint)) + frame_number = instance_data.abs_frame_id( + match_dm_item(item, instance_data, root_hint=root_hint)) track_ids = set() @@ -49,7 +50,7 @@ def _import_task(dataset, task_data): else: track_ids.add(track_id) - shape = task_data.TrackedShape( + shape = instance_data.TrackedShape( type='polygon', points=ann.points, occluded=ann.attributes.get('occluded') is True, @@ -64,7 +65,7 @@ def _import_task(dataset, task_data): # build trajectories as lists of shapes in track dict if track_id not in tracks: - tracks[track_id] = task_data.Track( + tracks[track_id] = instance_data.Track( label_cat.items[ann.label].name, 0, 'manual', []) tracks[track_id].shapes.append(shape) @@ -75,10 +76,10 @@ def _import_task(dataset, task_data): prev_shape_idx = 0 prev_shape = track.shapes[0] for shape in track.shapes[1:]: - has_skip = task_data.frame_step < shape.frame - prev_shape.frame + has_skip = instance_data.frame_step < shape.frame - prev_shape.frame if has_skip and not prev_shape.outside: prev_shape = prev_shape._replace(outside=True, - frame=prev_shape.frame + task_data.frame_step) + frame=prev_shape.frame + instance_data.frame_step) prev_shape_idx += 1 track.shapes.insert(prev_shape_idx, prev_shape) prev_shape = shape @@ -86,12 +87,12 @@ def _import_task(dataset, task_data): # Append a shape with outside=True to finish the track last_shape = track.shapes[-1] - if last_shape.frame + task_data.frame_step <= \ - int(task_data.meta['task']['stop_frame']): + if last_shape.frame + instance_data.frame_step <= \ + int(instance_data.meta[instance_data.META_FIELD]['stop_frame']): track.shapes.append(last_shape._replace(outside=True, - frame=last_shape.frame + task_data.frame_step) + frame=last_shape.frame + instance_data.frame_step) ) - task_data.add_track(track) + instance_data.add_track(track) @exporter(name='MOTS PNG', ext='ZIP', version='1.0') def _export(dst_file, instance_data, save_images=False): @@ -120,7 +121,7 @@ def _import(src_file, instance_data, load_data_callback=None): # Dirty way to determine instance type to avoid circular dependency if hasattr(instance_data, '_db_project'): for sub_dataset, task_data in instance_data.split_dataset(dataset): - _import_task(sub_dataset, task_data) + _import_to_task(sub_dataset, task_data) else: - _import_task(dataset, instance_data) + _import_to_task(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/pascal_voc.py b/cvat/apps/dataset_manager/formats/pascal_voc.py index acbfc9675a83..d965e25a5e44 100644 --- a/cvat/apps/dataset_manager/formats/pascal_voc.py +++ b/cvat/apps/dataset_manager/formats/pascal_voc.py @@ -1,4 +1,5 @@ # Copyright (C) 2020-2022 Intel Corporation +# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -11,8 +12,7 @@ from datumaro.components.dataset import Dataset from pyunpack import Archive -from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, - ProjectData, import_dm_annotations) +from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, import_dm_annotations) from cvat.apps.dataset_manager.util import make_zip_archive from .registry import dm_env, exporter, importer @@ -36,8 +36,7 @@ def _import(src_file, instance_data, load_data_callback=None): # put label map from the task if not present labelmap_file = osp.join(tmp_dir, 'labelmap.txt') if not osp.isfile(labelmap_file): - labels_meta = instance_data.meta['project']['labels'] \ - if isinstance(instance_data, ProjectData) else instance_data.meta['task']['labels'] + labels_meta = instance_data.meta[instance_data.META_FIELD]['labels'] labels = (label['name'] + ':::' for _, label in labels_meta) with open(labelmap_file, 'w') as f: f.write('\n'.join(labels)) diff --git a/cvat/apps/dataset_manager/formats/utils.py b/cvat/apps/dataset_manager/formats/utils.py index 50b5449b80c5..7811fbbfc902 100644 --- a/cvat/apps/dataset_manager/formats/utils.py +++ b/cvat/apps/dataset_manager/formats/utils.py @@ -49,8 +49,7 @@ def hex2rgb(color): return tuple(int(color.lstrip('#')[i:i+2], 16) for i in (0, 2, 4)) def make_colormap(instance_data): - instance_name = 'project' if 'project' in instance_data.meta.keys() else 'task' - labels = [label for _, label in instance_data.meta[instance_name]['labels']] + labels = [label for _, label in instance_data.meta[instance_data.META_FIELD]['labels']] label_names = [label['name'] for label in labels] if 'background' not in label_names: diff --git a/cvat/apps/dataset_manager/task.py b/cvat/apps/dataset_manager/task.py index ea486b68fddf..8210e331b78d 100644 --- a/cvat/apps/dataset_manager/task.py +++ b/cvat/apps/dataset_manager/task.py @@ -16,7 +16,7 @@ from cvat.apps.profiler import silk_profile from .annotation import AnnotationIR, AnnotationManager -from .bindings import TaskData +from .bindings import TaskData, JobData from .formats.registry import make_exporter, make_importer from .util import bulk_create @@ -553,24 +553,24 @@ def data(self): return self.ir_data.data def export(self, dst_file, exporter, host='', **options): - task_data = TaskData( + job_data = JobData( annotation_ir=self.ir_data, - db_task=self.db_job.segment.task, + db_job=self.db_job, host=host, ) - exporter(dst_file, task_data, **options) + exporter(dst_file, job_data, **options) def import_annotations(self, src_file, importer): - task_data = TaskData( + job_data = JobData( annotation_ir=AnnotationIR(), - db_task=self.db_job.segment.task, + db_job=self.db_job, create_callback=self.create, ) self.delete() - importer(src_file, task_data) + importer(src_file, job_data) - self.create(task_data.data.slice(self.start_frame, self.stop_frame).serialize()) + self.create(job_data.data.slice(self.start_frame, self.stop_frame).serialize()) class TaskAnnotation: def __init__(self, pk): diff --git a/cvat/apps/dataset_manager/tests/test_formats.py b/cvat/apps/dataset_manager/tests/test_formats.py index 38ffe236d83c..ba8b090fb7dd 100644 --- a/cvat/apps/dataset_manager/tests/test_formats.py +++ b/cvat/apps/dataset_manager/tests/test_formats.py @@ -21,7 +21,7 @@ import cvat.apps.dataset_manager as dm from cvat.apps.dataset_manager.annotation import AnnotationIR -from cvat.apps.dataset_manager.bindings import (CvatTaskDataExtractor, +from cvat.apps.dataset_manager.bindings import (CvatTaskOrJobDataExtractor, TaskData, find_dataset_root) from cvat.apps.dataset_manager.task import TaskAnnotation from cvat.apps.dataset_manager.util import make_zip_archive @@ -417,7 +417,7 @@ def test_can_skip_outside(self): task_ann.init_from_db() task_data = TaskData(task_ann.ir_data, Task.objects.get(pk=task["id"])) - extractor = CvatTaskDataExtractor(task_data) + extractor = CvatTaskOrJobDataExtractor(task_data) dm_dataset = datumaro.components.project.Dataset.from_extractors(extractor) self.assertEqual(4, len(dm_dataset.get("image_1").annotations)) diff --git a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py index 140d17177252..1f9c1bf896be 100644 --- a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py +++ b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py @@ -1,4 +1,5 @@ # Copyright (C) 2021-2022 Intel Corporation +# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -22,7 +23,7 @@ from rest_framework.test import APIClient, APITestCase import cvat.apps.dataset_manager as dm -from cvat.apps.dataset_manager.bindings import CvatTaskDataExtractor, TaskData +from cvat.apps.dataset_manager.bindings import CvatTaskOrJobDataExtractor, TaskData from cvat.apps.dataset_manager.task import TaskAnnotation from cvat.apps.engine.models import Task @@ -185,7 +186,7 @@ def _get_data_from_task(self, task_id, include_images): task_ann = TaskAnnotation(task_id) task_ann.init_from_db() task_data = TaskData(task_ann.ir_data, Task.objects.get(pk=task_id)) - extractor = CvatTaskDataExtractor(task_data, include_images=include_images) + extractor = CvatTaskOrJobDataExtractor(task_data, include_images=include_images) return Dataset.from_extractors(extractor) def _get_request_with_data(self, path, data, user): diff --git a/cvat/apps/engine/frame_provider.py b/cvat/apps/engine/frame_provider.py index 3871a058f5c2..15e2a9eab68e 100644 --- a/cvat/apps/engine/frame_provider.py +++ b/cvat/apps/engine/frame_provider.py @@ -1,4 +1,5 @@ # Copyright (C) 2020-2022 Intel Corporation +# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -182,6 +183,6 @@ def get_frame(self, frame_number, quality=Quality.ORIGINAL, return (frame, self.VIDEO_FRAME_MIME) return (frame, mimetypes.guess_type(frame_name)[0]) - def get_frames(self, quality=Quality.ORIGINAL, out_type=Type.BUFFER): - for idx in range(self._db_data.size): + def get_frames(self, start_frame, stop_frame, quality=Quality.ORIGINAL, out_type=Type.BUFFER): + for idx in range(start_frame, stop_frame): yield self.get_frame(idx, quality=quality, out_type=out_type)