Skip to content

Commit

Permalink
[Datumaro] Fix TFrecord converter constructor (cvat-ai#993)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiltsov-max authored and Chris Lee-Messer committed Mar 5, 2020
1 parent a8f186f commit 8b301e1
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 6 deletions.
3 changes: 1 addition & 2 deletions cvat/apps/engine/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,6 @@ def frame(self, request, pk, frame):
@action(detail=True, methods=['GET'], serializer_class=None,
url_path='dataset')
def dataset_export(self, request, pk):

db_task = self.get_object()

action = request.query_params.get("action", "")
Expand All @@ -611,7 +610,7 @@ def dataset_export(self, request, pk):
raise serializers.ValidationError(
"Unexpected parameter 'format' specified for the request")

rq_id = "task_dataset_export.{}.{}".format(pk, dst_format)
rq_id = "/api/v1/tasks/{}/dataset/{}".format(pk, dst_format)
queue = django_rq.get_queue("default")

rq_job = queue.fetch_job(rq_id)
Expand Down
25 changes: 22 additions & 3 deletions datumaro/datumaro/components/converters/tfrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,27 @@ def float_list_feature(value):
return tf_example

class DetectionApiConverter:
def __init__(self, save_images=True):
self.save_images = save_images
def __init__(self, save_images=False, cmdline_args=None):
super().__init__()

self._save_images = save_images

if cmdline_args is not None:
options = self._parse_cmdline(cmdline_args)
for k, v in options.items():
if hasattr(self, '_' + str(k)):
setattr(self, '_' + str(k), v)

@classmethod
def build_cmdline_parser(cls, parser=None):
import argparse
if not parser:
parser = argparse.ArgumentParser()

parser.add_argument('--save-images', action='store_true',
help="Save images (default: %(default)s)")

return parser

def __call__(self, extractor, save_dir):
tf = _import_tf()
Expand Down Expand Up @@ -141,6 +160,6 @@ def __call__(self, extractor, save_dir):
item,
get_label=get_label,
get_label_id=map_label_id,
save_images=self.save_images,
save_images=self._save_images,
)
writer.write(tf_example.SerializeToString())
3 changes: 2 additions & 1 deletion datumaro/tests/test_tfrecord_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def categories(self):

with TestDir() as test_dir:
self._test_can_save_and_load(
TestExtractor(), DetectionApiConverter(), test_dir)
TestExtractor(), DetectionApiConverter(save_images=True),
test_dir)

def test_labelmap_parsing(self):
text = """
Expand Down

0 comments on commit 8b301e1

Please sign in to comment.