Skip to content

Commit 7c16b42

Browse files
committed
add options for extra loops and the cmap value
1 parent bd87691 commit 7c16b42

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

captum/attr/_core/latent_shift.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,8 @@ def generate_video(
247247
temp_path: str = "/tmp/gifsplanation",
248248
show: bool = True,
249249
verbose: bool = True,
250+
extra_loops: int = 0,
251+
cmap: str = None,
250252
):
251253
"""Generate a video from the generated images.
252254
@@ -260,6 +262,10 @@ def generate_video(
260262
temp_path: A temp path to write images.
261263
show: To try and show the video in a jupyter notebook.
262264
verbose: True to print debug text
265+
extra_loops: The video does one loop by default. This will repeat
266+
those loops to make it easier to watch.
267+
cmap: The cmap value passed to matplotlib. e.g. 'gray' for a
268+
grayscale image.
263269
264270
Returns:
265271
The filename of the video if show=False, otherwise it will
@@ -277,12 +283,16 @@ def generate_video(
277283
# Add reversed so we have an animation cycle
278284
towrite = list(reversed(imgs)) + list(imgs)
279285
ys = list(reversed(params["preds"])) + list(params["preds"])
286+
287+
for n in range(extra_loops):
288+
towrite += towrite
289+
ys += ys
280290

281291
for idx, img in enumerate(towrite):
282292

283293
px = 1 / plt.rcParams["figure.dpi"]
284294
full_frame(img[0].shape[0] * px, img[0].shape[1] * px)
285-
plt.imshow(img[0], interpolation="none")
295+
plt.imshow(img[0], interpolation="none", cmap=cmap)
286296

287297
if watermark:
288298
# Write prob output in upper left

0 commit comments

Comments
 (0)