Fix empty data loading in Insights #728

Closed · wants to merge 1 commit
captum/insights/attr_vis/app.py (24 changes: 19 additions & 5 deletions)
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 from collections import namedtuple
+from itertools import cycle
 from typing import (
     Any,
     Callable,
@@ -199,6 +200,7 @@ class scores.
         self._outputs: List[VisualizationOutput] = []
         self._config = FilterConfig(prediction="all", classes=[], num_examples=4)
         self._dataset_iter = iter(dataset)

Reviewer: Wouldn't your PR be simpler if you constructed this as cycle(iter(dataset))? It seems you're effectively reimplementing it.

Author: It would be much simpler that way, but my concern is memory usage. Per the docs, it seems that cycle

    may require significant auxiliary storage (depending on the length of the iterable)

which suggests to me that it stores the entire iterable in an array and then loops over it. If the dataset is huge, this could be a problem, because every batch would end up stored in memory. That's why I made the cache store only a few batches instead, rather than cycling over the entire dataset. However, if you don't think memory is an issue, I could do it that way instead. It would probably only matter if somebody requested a lot of attributions.

Reviewer: Makes sense, thanks for clarifying. Probably best to assume memory will be an issue.
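
A minimal sketch of the trade-off discussed above (illustrative only, not part of the PR; make_batches is a hypothetical stand-in for the Insights dataset generator):

```python
from collections import deque
from itertools import cycle

def make_batches(n):
    # Hypothetical stand-in for the Insights dataset generator.
    for i in range(n):
        yield f"batch-{i}"

# cycle() keeps a copy of every element the source yields so it can
# replay them; after one full pass it holds all n batches. This is the
# "significant auxiliary storage" the itertools docs warn about.
replay_all = cycle(make_batches(1_000_000))

# A bounded cache instead keeps only the most recent k batches: O(k) memory.
recent = deque(maxlen=4)
for batch in make_batches(1_000_000):
    recent.append(batch)  # the oldest batch is evicted automatically
replay_recent = cycle(recent)  # replays at most 4 stale batches
```

The PR uses a plain list with pop(0) rather than a deque, which is equivalent for a cache this small.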

+        self._dataset_cache: List[Batch] = []

     def _calculate_attribution_from_cache(
         self, input_index: int, model_index: int, target: Optional[Tensor]
@@ -439,7 +441,22 @@ def _calculate_vis_output(
         return results if results else None
 
     def _get_outputs(self) -> List[Tuple[List[VisualizationOutput], SampleCache]]:
-        batch_data = next(self._dataset_iter)
+        # If we run out of new betches, then we need to
+        # display data which was already shown before.
+        # However, since the dataset given to us is a generator,
+        # we can't reset it to return to the beginning.
+        # Because of this, we store a small cache of stale
+        # data, and iterate on it after the main generator
+        # stops returning new batches.
+        try:
+            batch_data = next(self._dataset_iter)
+            self._dataset_cache.append(batch_data)
+            if len(self._dataset_cache) > self._config.num_examples:
+                self._dataset_cache.pop(0)
+        except StopIteration:
+            self._dataset_iter = cycle(self._dataset_cache)
+            batch_data = next(self._dataset_iter)

Contributor: betches -> batches ?

         vis_outputs = []
 
         # Type ignore for issue with passing union to function taking generic
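
Read on its own, the hunk above amounts to a small replay-cache iterator. A standalone sketch of the same pattern (the helper name and signature are illustrative, not Captum API):

```python
from itertools import cycle

def iter_with_replay(batches, cache_size):
    """Yield fresh batches; once the source is exhausted, cycle
    forever over the last `cache_size` batches seen."""
    cache = []
    for batch in batches:
        cache.append(batch)
        if len(cache) > cache_size:
            cache.pop(0)  # evict the oldest cached batch
        yield batch
    # Source exhausted: replay the small stale cache indefinitely
    # (assumes the source yielded at least one batch).
    yield from cycle(cache)
```

Once the source has produced at least one batch, iteration never raises StopIteration again, which is why the try/except in visualize() below can be dropped.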
@@ -465,10 +482,7 @@ def _get_outputs(self) -> List[Tuple[List[VisualizationOutput], SampleCache]]:
     def visualize(self):
         self._outputs = []
         while len(self._outputs) < self._config.num_examples:
-            try:
-                self._outputs.extend(self._get_outputs())
-            except StopIteration:
-                break
+            self._outputs.extend(self._get_outputs())
         return [o[0] for o in self._outputs]
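
Continuing the sketch above, a quick check of why the try/except is no longer needed: the replaying iterator keeps producing batches after the source runs dry (hypothetical values):

```python
# Two real batches, six requested: the cache replays the stale ones.
it = iter_with_replay(iter(["b0", "b1"]), cache_size=4)
assert [next(it) for _ in range(6)] == ["b0", "b1"] * 3
```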

     def get_insights_config(self):