Skip to content

Commit 4f64c86

Browse files
committed
Fix empty data loading in Insights
1 parent f1346fc commit 4f64c86

File tree

1 file changed

+19
-5
lines changed
  • captum/insights/attr_vis

1 file changed

+19
-5
lines changed

captum/insights/attr_vis/app.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python3
22
from collections import namedtuple
3+
from itertools import cycle
34
from typing import (
45
Any,
56
Callable,
@@ -199,6 +200,7 @@ class scores.
199200
self._outputs: List[VisualizationOutput] = []
200201
self._config = FilterConfig(prediction="all", classes=[], num_examples=4)
201202
self._dataset_iter = iter(dataset)
203+
self._dataset_cache = []
202204

203205
def _calculate_attribution_from_cache(
204206
self, input_index: int, model_index: int, target: Optional[Tensor]
@@ -439,7 +441,22 @@ def _calculate_vis_output(
439441
return results if results else None
440442

441443
def _get_outputs(self) -> List[Tuple[List[VisualizationOutput], SampleCache]]:
442-
batch_data = next(self._dataset_iter)
444+
# If we run out of new betches, then we need to
445+
# display data which was already shown before.
446+
# However, since the dataset given to us is a generator,
447+
# we can't reset it to return to the beginning.
448+
# Because of this, we store a small cache of stale
449+
# data, and iterate on it after the main generator
450+
# stops returning new batches.
451+
try:
452+
batch_data = next(self._dataset_iter)
453+
self._dataset_cache.append(batch_data)
454+
if len(self._dataset_cache) > self._config.num_examples:
455+
self._dataset_cache.pop(0)
456+
except StopIteration:
457+
self._dataset_iter = cycle(self._dataset_cache)
458+
batch_data = next(self._dataset_iter)
459+
443460
vis_outputs = []
444461

445462
# Type ignore for issue with passing union to function taking generic
@@ -465,10 +482,7 @@ def _get_outputs(self) -> List[Tuple[List[VisualizationOutput], SampleCache]]:
465482
def visualize(self):
466483
self._outputs = []
467484
while len(self._outputs) < self._config.num_examples:
468-
try:
469-
self._outputs.extend(self._get_outputs())
470-
except StopIteration:
471-
break
485+
self._outputs.extend(self._get_outputs())
472486
return [o[0] for o in self._outputs]
473487

474488
def get_insights_config(self):

0 commit comments

Comments
 (0)