1
1
#!/usr/bin/env python3
2
2
from collections import namedtuple
3
+ from itertools import cycle
3
4
from typing import (
4
5
Any ,
5
6
Callable ,
@@ -199,6 +200,7 @@ class scores.
199
200
self ._outputs : List [VisualizationOutput ] = []
200
201
self ._config = FilterConfig (prediction = "all" , classes = [], num_examples = 4 )
201
202
self ._dataset_iter = iter (dataset )
203
+ self ._dataset_cache = []
202
204
203
205
def _calculate_attribution_from_cache (
204
206
self , input_index : int , model_index : int , target : Optional [Tensor ]
@@ -439,7 +441,22 @@ def _calculate_vis_output(
439
441
return results if results else None
440
442
441
443
def _get_outputs (self ) -> List [Tuple [List [VisualizationOutput ], SampleCache ]]:
442
- batch_data = next (self ._dataset_iter )
444
+ # If we run out of new betches, then we need to
445
+ # display data which was already shown before.
446
+ # However, since the dataset given to us is a generator,
447
+ # we can't reset it to return to the beginning.
448
+ # Because of this, we store a small cache of stale
449
+ # data, and iterate on it after the main generator
450
+ # stops returning new batches.
451
+ try :
452
+ batch_data = next (self ._dataset_iter )
453
+ self ._dataset_cache .append (batch_data )
454
+ if len (self ._dataset_cache ) > self ._config .num_examples :
455
+ self ._dataset_cache .pop (0 )
456
+ except StopIteration :
457
+ self ._dataset_iter = cycle (self ._dataset_cache )
458
+ batch_data = next (self ._dataset_iter )
459
+
443
460
vis_outputs = []
444
461
445
462
# Type ignore for issue with passing union to function taking generic
@@ -465,10 +482,7 @@ def _get_outputs(self) -> List[Tuple[List[VisualizationOutput], SampleCache]]:
465
482
def visualize (self ):
466
483
self ._outputs = []
467
484
while len (self ._outputs ) < self ._config .num_examples :
468
- try :
469
- self ._outputs .extend (self ._get_outputs ())
470
- except StopIteration :
471
- break
485
+ self ._outputs .extend (self ._get_outputs ())
472
486
return [o [0 ] for o in self ._outputs ]
473
487
474
488
def get_insights_config (self ):
0 commit comments