@@ -247,6 +247,8 @@ def generate_video(
247
247
temp_path : str = "/tmp/gifsplanation" ,
248
248
show : bool = True ,
249
249
verbose : bool = True ,
250
+ extra_loops : int = 0 ,
251
+ cmap : str = None ,
250
252
):
251
253
"""Generate a video from the generated images.
252
254
@@ -260,6 +262,10 @@ def generate_video(
260
262
temp_path: A temp path to write images.
261
263
show: To try and show the video in a jupyter notebook.
262
264
verbose: True to print debug text
265
+ extra_loops: The video does one loop by default. This will repeat
266
+ those loops to make it easier to watch.
267
+ cmap: The cmap value passed to matplotlib. e.g. 'gray' for a
268
+ grayscale image.
263
269
264
270
Returns:
265
271
The filename of the video if show=False, otherwise it will
@@ -277,12 +283,16 @@ def generate_video(
277
283
# Add reversed so we have an animation cycle
278
284
towrite = list (reversed (imgs )) + list (imgs )
279
285
ys = list (reversed (params ["preds" ])) + list (params ["preds" ])
286
+
287
+ for n in range (extra_loops ):
288
+ towrite += towrite
289
+ ys += ys
280
290
281
291
for idx , img in enumerate (towrite ):
282
292
283
293
px = 1 / plt .rcParams ["figure.dpi" ]
284
294
full_frame (img [0 ].shape [0 ] * px , img [0 ].shape [1 ] * px )
285
- plt .imshow (img [0 ], interpolation = "none" )
295
+ plt .imshow (img [0 ], interpolation = "none" , cmap = cmap )
286
296
287
297
if watermark :
288
298
# Write prob output in upper left
0 commit comments