diff --git a/examples/spiral.py b/examples/spiral.py index 98f9431..ba5b5c3 100644 --- a/examples/spiral.py +++ b/examples/spiral.py @@ -10,7 +10,7 @@ @gif.frame def plot_spiral(i): fig = plt.figure(figsize=(5, 3), dpi=100) - ax = fig.gca(projection="3d") + ax = fig.add_subplot(projection="3d") a, b = 0.5, 0.2 th = np.linspace(475, 500, N) x = a * np.exp(b * th) * np.cos(th) diff --git a/gif.py b/gif.py index 570e4e1..efd90f1 100644 --- a/gif.py +++ b/gif.py @@ -60,12 +60,31 @@ def inner(*args, **kwargs) -> Frame: # type: ignore[valid-type] return inner +def _optimize_frames(frames: List[Frame]): + import numpy as np + joined_img = PI.fromarray(np.vstack(frames)) + joined_img = joined_img.quantize(colors=256 - 1, dither=0) + palette = b'\xff\x00\xff' + joined_img.palette.getdata()[1] + joined_img_arr = np.array(joined_img) + joined_img_arr += 1 + arrays = np.vsplit(joined_img_arr, len(frames)) + + prev_array = arrays[0] + for array in arrays[1:]: + mask = array == prev_array + prev_array = array.copy() + array[mask] = 0 + frames_out = [PI.fromarray(array) for array in arrays] + return frames_out, palette + + def save( frames: List[Frame], # type: ignore[valid-type] path: str, duration: Milliseconds = 100, *, overlapping: bool = True, + optimize: bool = False, ) -> None: """Save prepared frames to .gif file @@ -78,7 +97,15 @@ def save( """ if not path.endswith(".gif"): - raise ValueError("must end with .gif") + raise ValueError(f"'{path}' must end with .gif") + + kwargs = {} + if optimize: + frames, palette = _optimize_frames(frames) + kwargs = { + "palette": palette, + "transparency": 0, + } frames[0].save( # type: ignore path, @@ -88,4 +115,5 @@ def save( duration=duration, disposal=0 if overlapping else 2, loop=0, + **kwargs ) diff --git a/tests/test_gif.py b/tests/test_gif.py index 87f0ebe..5d260a7 100644 --- a/tests/test_gif.py +++ b/tests/test_gif.py @@ -1,3 +1,4 @@ +import os import pytest from matplotlib import pyplot as plt from PIL import Image @@ -22,30 +23,35 @@ def plot(x, y): plt.scatter(x, y) -@pytest.fixture(scope="session") -def default_file(tmpdir_factory): +def make_gif(tmpdir_factory, filename, dpi=None, **kwargs): + if dpi is not None: + gif.options.matplotlib["dpi"] = 300 frames = [plot([0, 5], [0, 5]), plot([0, 10], [0, 10])] - path = str(tmpdir_factory.mktemp("matplotlib").join("default.gif")) - gif.save(frames, path) + if dpi is not None: + gif.options.reset() + path = str(tmpdir_factory.mktemp("matplotlib").join(filename)) + gif.save(frames, path, **kwargs) return path +@pytest.fixture(scope="session") +def default_file(tmpdir_factory): + return make_gif(tmpdir_factory, "default.gif") + + +@pytest.fixture(scope="session") +def optimized_file(tmpdir_factory): + return make_gif(tmpdir_factory, "optimized.gif", optimize=True) + + @pytest.fixture(scope="session") def hd_file(tmpdir_factory): - gif.options.matplotlib["dpi"] = 300 - frames = [plot([0, 5], [0, 5]), plot([0, 10], [0, 10])] - gif.options.reset() - path = str(tmpdir_factory.mktemp("matplotlib").join("hd.gif")) - gif.save(frames, path) - return path + return make_gif(tmpdir_factory, "hd.gif", dpi=300) @pytest.fixture(scope="session") def long_file(tmpdir_factory): - frames = [plot([0, 5], [0, 5]), plot([0, 10], [0, 10])] - path = str(tmpdir_factory.mktemp("matplotlib").join("long.gif")) - gif.save(frames, path, duration=2500) - return path + return make_gif(tmpdir_factory, "long.gif", duration=2500) def test_frame(): @@ -59,6 +65,12 @@ def test_default_save(default_file): assert milliseconds(img) == 200 +def test_optimization(default_file, optimized_file): + default_size = os.stat(default_file).st_size + optimized_size = os.stat(optimized_file).st_size + assert optimized_size < default_size * 0.9 + + def test_dpi_save(hd_file): img = Image.open(hd_file) assert img.format == "GIF"