@@ -134,18 +134,11 @@ def test_can_have_project_source(self):
134
134
135
135
def test_can_batch_launch_custom_model (self ):
136
136
class TestExtractor (Extractor ):
137
- def __init__ (self , url , n = 0 ):
138
- super ().__init__ (length = n )
139
- self .n = n
140
-
141
137
def __iter__ (self ):
142
- for i in range (self . n ):
138
+ for i in range (5 ):
143
139
yield DatasetItem (id = i , subset = 'train' , image = i )
144
140
145
141
class TestLauncher (Launcher ):
146
- def __init__ (self , ** kwargs ):
147
- pass
148
-
149
142
def launch (self , inputs ):
150
143
for i , inp in enumerate (inputs ):
151
144
yield [ LabelObject (attributes = {'idx' : i , 'data' : inp }) ]
@@ -157,7 +150,7 @@ def launch(self, inputs):
157
150
project .env .launchers .register (launcher_name , TestLauncher )
158
151
project .add_model (model_name , { 'launcher' : launcher_name })
159
152
model = project .make_executable_model (model_name )
160
- extractor = TestExtractor ('' , n = 5 )
153
+ extractor = TestExtractor ()
161
154
162
155
batch_size = 3
163
156
executor = InferenceWrapper (extractor , model , batch_size = batch_size )
@@ -171,27 +164,20 @@ def launch(self, inputs):
171
164
172
165
def test_can_do_transform_with_custom_model (self ):
173
166
class TestExtractorSrc (Extractor ):
174
- def __init__ (self , url , n = 2 ):
175
- super ().__init__ (length = n )
176
- self .n = n
177
-
178
167
def __iter__ (self ):
179
- for i in range (self . n ):
168
+ for i in range (2 ):
180
169
yield DatasetItem (id = i , subset = 'train' , image = i ,
181
170
annotations = [ LabelObject (i ) ])
182
171
183
172
class TestLauncher (Launcher ):
184
- def __init__ (self , ** kwargs ):
185
- pass
186
-
187
173
def launch (self , inputs ):
188
174
for inp in inputs :
189
175
yield [ LabelObject (inp ) ]
190
176
191
177
class TestConverter (Converter ):
192
178
def __call__ (self , extractor , save_dir ):
193
179
for item in extractor :
194
- with open (osp .join (save_dir , '%s.txt' % item .id ), 'w+ ' ) as f :
180
+ with open (osp .join (save_dir , '%s.txt' % item .id ), 'w' ) as f :
195
181
f .write (str (item .subset ) + '\n ' )
196
182
f .write (str (item .annotations [0 ].label ) + '\n ' )
197
183
@@ -204,8 +190,8 @@ def __iter__(self):
204
190
for path in self .items :
205
191
with open (path , 'r' ) as f :
206
192
index = osp .splitext (osp .basename (path ))[0 ]
207
- subset = f .readline ()[: - 1 ]
208
- label = int (f .readline ()[: - 1 ] )
193
+ subset = f .readline (). strip ()
194
+ label = int (f .readline (). strip () )
209
195
assert subset == 'train'
210
196
yield DatasetItem (id = index , subset = subset ,
211
197
annotations = [ LabelObject (label ) ])
@@ -261,12 +247,8 @@ def __iter__(self):
261
247
262
248
def test_project_filter_can_be_applied (self ):
263
249
class TestExtractor (Extractor ):
264
- def __init__ (self , url , n = 10 ):
265
- super ().__init__ (length = n )
266
- self .n = n
267
-
268
250
def __iter__ (self ):
269
- for i in range (self . n ):
251
+ for i in range (10 ):
270
252
yield DatasetItem (id = i , subset = 'train' )
271
253
272
254
e_type = 'type'
@@ -331,30 +313,23 @@ def test_project_compound_child_can_be_modified_recursively(self):
331
313
self .assertEqual (1 , len (dataset .sources ['child2' ]))
332
314
333
315
def test_project_can_merge_item_annotations (self ):
334
- class TestExtractor (Extractor ):
335
- def __init__ (self , url , v = None ):
336
- super ().__init__ ()
337
- self .v = v
338
-
316
+ class TestExtractor1 (Extractor ):
339
317
def __iter__ (self ):
340
- v1_item = DatasetItem (id = 1 , subset = 'train' , annotations = [
318
+ yield DatasetItem (id = 1 , subset = 'train' , annotations = [
341
319
LabelObject (2 , id = 3 ),
342
320
LabelObject (3 , attributes = { 'x' : 1 }),
343
321
])
344
322
345
- v2_item = DatasetItem (id = 1 , subset = 'train' , annotations = [
323
+ class TestExtractor2 (Extractor ):
324
+ def __iter__ (self ):
325
+ yield DatasetItem (id = 1 , subset = 'train' , annotations = [
346
326
LabelObject (3 , attributes = { 'x' : 1 }),
347
327
LabelObject (4 , id = 4 ),
348
328
])
349
329
350
- if self .v == 1 :
351
- yield v1_item
352
- else :
353
- yield v2_item
354
-
355
330
project = Project ()
356
- project .env .extractors .register ('t1' , lambda p : TestExtractor ( p , v = 1 ) )
357
- project .env .extractors .register ('t2' , lambda p : TestExtractor ( p , v = 2 ) )
331
+ project .env .extractors .register ('t1' , TestExtractor1 )
332
+ project .env .extractors .register ('t2' , TestExtractor2 )
358
333
project .add_source ('source1' , { 'format' : 't1' })
359
334
project .add_source ('source2' , { 'format' : 't2' })
360
335
@@ -494,9 +469,6 @@ def test_can_produce_multilayer_config_from_dict(self):
494
469
class ExtractorTest (TestCase ):
495
470
def test_custom_extractor_can_be_created (self ):
496
471
class CustomExtractor (Extractor ):
497
- def __init__ (self , url ):
498
- super ().__init__ ()
499
-
500
472
def __iter__ (self ):
501
473
return iter ([
502
474
DatasetItem (id = 0 , subset = 'train' ),
0 commit comments