Skip to content

Commit

Permalink
Merge pull request #17 from senyai/senyai-optimize
Browse files Browse the repository at this point in the history
Add `optimize` argument
  • Loading branch information
maxhumber authored Mar 9, 2023
2 parents d41ef3a + a3e1d08 commit 3c088e3
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 16 deletions.
2 changes: 1 addition & 1 deletion examples/spiral.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
@gif.frame
def plot_spiral(i):
fig = plt.figure(figsize=(5, 3), dpi=100)
ax = fig.gca(projection="3d")
ax = fig.add_subplot(projection="3d")
a, b = 0.5, 0.2
th = np.linspace(475, 500, N)
x = a * np.exp(b * th) * np.cos(th)
Expand Down
30 changes: 29 additions & 1 deletion gif.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,31 @@ def inner(*args, **kwargs) -> Frame: # type: ignore[valid-type]
return inner


def _optimize_frames(frames: List[Frame]):
import numpy as np
joined_img = PI.fromarray(np.vstack(frames))
joined_img = joined_img.quantize(colors=256 - 1, dither=0)
palette = b'\xff\x00\xff' + joined_img.palette.getdata()[1]
joined_img_arr = np.array(joined_img)
joined_img_arr += 1
arrays = np.vsplit(joined_img_arr, len(frames))

prev_array = arrays[0]
for array in arrays[1:]:
mask = array == prev_array
prev_array = array.copy()
array[mask] = 0
frames_out = [PI.fromarray(array) for array in arrays]
return frames_out, palette


def save(
frames: List[Frame], # type: ignore[valid-type]
path: str,
duration: Milliseconds = 100,
*,
overlapping: bool = True,
optimize: bool = False,
) -> None:
"""Save prepared frames to .gif file
Expand All @@ -78,7 +97,15 @@ def save(
"""

if not path.endswith(".gif"):
raise ValueError("must end with .gif")
raise ValueError(f"'{path}' must end with .gif")

kwargs = {}
if optimize:
frames, palette = _optimize_frames(frames)
kwargs = {
"palette": palette,
"transparency": 0,
}

frames[0].save( # type: ignore
path,
Expand All @@ -88,4 +115,5 @@ def save(
duration=duration,
disposal=0 if overlapping else 2,
loop=0,
**kwargs
)
40 changes: 26 additions & 14 deletions tests/test_gif.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import pytest
from matplotlib import pyplot as plt
from PIL import Image
Expand All @@ -22,30 +23,35 @@ def plot(x, y):
plt.scatter(x, y)


@pytest.fixture(scope="session")
def default_file(tmpdir_factory):
def make_gif(tmpdir_factory, filename, dpi=None, **kwargs):
if dpi is not None:
gif.options.matplotlib["dpi"] = 300
frames = [plot([0, 5], [0, 5]), plot([0, 10], [0, 10])]
path = str(tmpdir_factory.mktemp("matplotlib").join("default.gif"))
gif.save(frames, path)
if dpi is not None:
gif.options.reset()
path = str(tmpdir_factory.mktemp("matplotlib").join(filename))
gif.save(frames, path, **kwargs)
return path


@pytest.fixture(scope="session")
def default_file(tmpdir_factory):
return make_gif(tmpdir_factory, "default.gif")


@pytest.fixture(scope="session")
def optimized_file(tmpdir_factory):
return make_gif(tmpdir_factory, "optimized.gif", optimize=True)


@pytest.fixture(scope="session")
def hd_file(tmpdir_factory):
gif.options.matplotlib["dpi"] = 300
frames = [plot([0, 5], [0, 5]), plot([0, 10], [0, 10])]
gif.options.reset()
path = str(tmpdir_factory.mktemp("matplotlib").join("hd.gif"))
gif.save(frames, path)
return path
return make_gif(tmpdir_factory, "hd.gif", dpi=300)


@pytest.fixture(scope="session")
def long_file(tmpdir_factory):
frames = [plot([0, 5], [0, 5]), plot([0, 10], [0, 10])]
path = str(tmpdir_factory.mktemp("matplotlib").join("long.gif"))
gif.save(frames, path, duration=2500)
return path
return make_gif(tmpdir_factory, "long.gif", duration=2500)


def test_frame():
Expand All @@ -59,6 +65,12 @@ def test_default_save(default_file):
assert milliseconds(img) == 200


def test_optimization(default_file, optimized_file):
default_size = os.stat(default_file).st_size
optimized_size = os.stat(optimized_file).st_size
assert optimized_size < default_size * 0.9


def test_dpi_save(hd_file):
img = Image.open(hd_file)
assert img.format == "GIF"
Expand Down

0 comments on commit 3c088e3

Please sign in to comment.