Skip to content

Commit 57363b8

Browse files
committed
Refactor tests
1 parent 9e9bb24 commit 57363b8

File tree

1 file changed

+14
-42
lines changed

1 file changed

+14
-42
lines changed

datumaro/tests/test_project.py

+14-42
Original file line numberDiff line numberDiff line change
@@ -134,18 +134,11 @@ def test_can_have_project_source(self):
134134

135135
def test_can_batch_launch_custom_model(self):
136136
class TestExtractor(Extractor):
137-
def __init__(self, url, n=0):
138-
super().__init__(length=n)
139-
self.n = n
140-
141137
def __iter__(self):
142-
for i in range(self.n):
138+
for i in range(5):
143139
yield DatasetItem(id=i, subset='train', image=i)
144140

145141
class TestLauncher(Launcher):
146-
def __init__(self, **kwargs):
147-
pass
148-
149142
def launch(self, inputs):
150143
for i, inp in enumerate(inputs):
151144
yield [ LabelObject(attributes={'idx': i, 'data': inp}) ]
@@ -157,7 +150,7 @@ def launch(self, inputs):
157150
project.env.launchers.register(launcher_name, TestLauncher)
158151
project.add_model(model_name, { 'launcher': launcher_name })
159152
model = project.make_executable_model(model_name)
160-
extractor = TestExtractor('', n=5)
153+
extractor = TestExtractor()
161154

162155
batch_size = 3
163156
executor = InferenceWrapper(extractor, model, batch_size=batch_size)
@@ -171,27 +164,20 @@ def launch(self, inputs):
171164

172165
def test_can_do_transform_with_custom_model(self):
173166
class TestExtractorSrc(Extractor):
174-
def __init__(self, url, n=2):
175-
super().__init__(length=n)
176-
self.n = n
177-
178167
def __iter__(self):
179-
for i in range(self.n):
168+
for i in range(2):
180169
yield DatasetItem(id=i, subset='train', image=i,
181170
annotations=[ LabelObject(i) ])
182171

183172
class TestLauncher(Launcher):
184-
def __init__(self, **kwargs):
185-
pass
186-
187173
def launch(self, inputs):
188174
for inp in inputs:
189175
yield [ LabelObject(inp) ]
190176

191177
class TestConverter(Converter):
192178
def __call__(self, extractor, save_dir):
193179
for item in extractor:
194-
with open(osp.join(save_dir, '%s.txt' % item.id), 'w+') as f:
180+
with open(osp.join(save_dir, '%s.txt' % item.id), 'w') as f:
195181
f.write(str(item.subset) + '\n')
196182
f.write(str(item.annotations[0].label) + '\n')
197183

@@ -204,8 +190,8 @@ def __iter__(self):
204190
for path in self.items:
205191
with open(path, 'r') as f:
206192
index = osp.splitext(osp.basename(path))[0]
207-
subset = f.readline()[:-1]
208-
label = int(f.readline()[:-1])
193+
subset = f.readline().strip()
194+
label = int(f.readline().strip())
209195
assert subset == 'train'
210196
yield DatasetItem(id=index, subset=subset,
211197
annotations=[ LabelObject(label) ])
@@ -261,12 +247,8 @@ def __iter__(self):
261247

262248
def test_project_filter_can_be_applied(self):
263249
class TestExtractor(Extractor):
264-
def __init__(self, url, n=10):
265-
super().__init__(length=n)
266-
self.n = n
267-
268250
def __iter__(self):
269-
for i in range(self.n):
251+
for i in range(10):
270252
yield DatasetItem(id=i, subset='train')
271253

272254
e_type = 'type'
@@ -331,30 +313,23 @@ def test_project_compound_child_can_be_modified_recursively(self):
331313
self.assertEqual(1, len(dataset.sources['child2']))
332314

333315
def test_project_can_merge_item_annotations(self):
334-
class TestExtractor(Extractor):
335-
def __init__(self, url, v=None):
336-
super().__init__()
337-
self.v = v
338-
316+
class TestExtractor1(Extractor):
339317
def __iter__(self):
340-
v1_item = DatasetItem(id=1, subset='train', annotations=[
318+
yield DatasetItem(id=1, subset='train', annotations=[
341319
LabelObject(2, id=3),
342320
LabelObject(3, attributes={ 'x': 1 }),
343321
])
344322

345-
v2_item = DatasetItem(id=1, subset='train', annotations=[
323+
class TestExtractor2(Extractor):
324+
def __iter__(self):
325+
yield DatasetItem(id=1, subset='train', annotations=[
346326
LabelObject(3, attributes={ 'x': 1 }),
347327
LabelObject(4, id=4),
348328
])
349329

350-
if self.v == 1:
351-
yield v1_item
352-
else:
353-
yield v2_item
354-
355330
project = Project()
356-
project.env.extractors.register('t1', lambda p: TestExtractor(p, v=1))
357-
project.env.extractors.register('t2', lambda p: TestExtractor(p, v=2))
331+
project.env.extractors.register('t1', TestExtractor1)
332+
project.env.extractors.register('t2', TestExtractor2)
358333
project.add_source('source1', { 'format': 't1' })
359334
project.add_source('source2', { 'format': 't2' })
360335

@@ -494,9 +469,6 @@ def test_can_produce_multilayer_config_from_dict(self):
494469
class ExtractorTest(TestCase):
495470
def test_custom_extractor_can_be_created(self):
496471
class CustomExtractor(Extractor):
497-
def __init__(self, url):
498-
super().__init__()
499-
500472
def __iter__(self):
501473
return iter([
502474
DatasetItem(id=0, subset='train'),

0 commit comments

Comments
 (0)