Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move array_to_imagestr function to be part of public API #2879

Merged
merged 3 commits into from
Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions packages/python/plotly/_plotly_utils/data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from io import BytesIO
import base64
from .png import Writer, from_array

try:
from PIL import Image

pil_imported = True
except ImportError:
pil_imported = False


def image_array_to_data_uri(img, backend="pil", compression=4, ext="png"):
"""Converts a numpy array of uint8 into a base64 png or jpg string.

Parameters
----------
img: ndarray of uint8
array image
backend: str
'auto', 'pil' or 'pypng'. If 'auto', Pillow is used if installed,
otherwise pypng.
compression: int, between 0 and 9
compression level to be passed to the backend
ext: str, 'png' or 'jpg'
compression format used to generate b64 string
"""
# PIL and pypng error messages are quite obscure so we catch invalid compression values
if compression < 0 or compression > 9:
raise ValueError("compression level must be between 0 and 9.")
alpha = False
if img.ndim == 2:
mode = "L"
elif img.ndim == 3 and img.shape[-1] == 3:
mode = "RGB"
elif img.ndim == 3 and img.shape[-1] == 4:
mode = "RGBA"
alpha = True
else:
raise ValueError("Invalid image shape")
if backend == "auto":
backend = "pil" if pil_imported else "pypng"
if ext != "png" and backend != "pil":
raise ValueError("jpg binary strings are only available with PIL backend")

if backend == "pypng":
ndim = img.ndim
sh = img.shape
if ndim == 3:
img = img.reshape((sh[0], sh[1] * sh[2]))
w = Writer(
sh[1], sh[0], greyscale=(ndim == 2), alpha=alpha, compression=compression
)
img_png = from_array(img, mode=mode)
prefix = "data:image/png;base64,"
with BytesIO() as stream:
w.write(stream, img_png.rows)
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
else: # pil
if not pil_imported:
raise ImportError(
"pillow needs to be installed to use `backend='pil'. Please"
"install pillow or use `backend='pypng'."
)
pil_img = Image.fromarray(img)
if ext == "jpg" or ext == "jpeg":
prefix = "data:image/jpeg;base64,"
ext = "jpeg"
else:
prefix = "data:image/png;base64,"
ext = "png"
with BytesIO() as stream:
pil_img.save(stream, format=ext, compress_level=compression)
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
return base64_string
77 changes: 2 additions & 75 deletions packages/python/plotly/plotly/express/_imshow.py
Original file line number Diff line number Diff line change
@@ -1,94 +1,21 @@
import plotly.graph_objs as go
from _plotly_utils.basevalidators import ColorscaleValidator
from ._core import apply_default_cascade
from io import BytesIO
import base64
from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types
import pandas as pd
from .png import Writer, from_array
import numpy as np
from plotly.utils import image_array_to_data_uri

try:
import xarray

xarray_imported = True
except ImportError:
xarray_imported = False
try:
from PIL import Image

pil_imported = True
except ImportError:
pil_imported = False

_float_types = []


def _array_to_b64str(img, backend="pil", compression=4, ext="png"):
"""Converts a numpy array of uint8 into a base64 png string.

Parameters
----------
img: ndarray of uint8
array image
backend: str
'auto', 'pil' or 'pypng'. If 'auto', Pillow is used if installed,
otherwise pypng.
compression: int, between 0 and 9
compression level to be passed to the backend
ext: str, 'png' or 'jpg'
compression format used to generate b64 string
"""
# PIL and pypng error messages are quite obscure so we catch invalid compression values
if compression < 0 or compression > 9:
raise ValueError("compression level must be between 0 and 9.")
alpha = False
if img.ndim == 2:
mode = "L"
elif img.ndim == 3 and img.shape[-1] == 3:
mode = "RGB"
elif img.ndim == 3 and img.shape[-1] == 4:
mode = "RGBA"
alpha = True
else:
raise ValueError("Invalid image shape")
if backend == "auto":
backend = "pil" if pil_imported else "pypng"
if ext != "png" and backend != "pil":
raise ValueError("jpg binary strings are only available with PIL backend")

if backend == "pypng":
ndim = img.ndim
sh = img.shape
if ndim == 3:
img = img.reshape((sh[0], sh[1] * sh[2]))
w = Writer(
sh[1], sh[0], greyscale=(ndim == 2), alpha=alpha, compression=compression
)
img_png = from_array(img, mode=mode)
prefix = "data:image/png;base64,"
with BytesIO() as stream:
w.write(stream, img_png.rows)
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
else: # pil
if not pil_imported:
raise ImportError(
"pillow needs to be installed to use `backend='pil'. Please"
"install pillow or use `backend='pypng'."
)
pil_img = Image.fromarray(img)
if ext == "jpg" or ext == "jpeg":
prefix = "data:image/jpeg;base64,"
ext = "jpeg"
else:
prefix = "data:image/png;base64,"
ext = "png"
with BytesIO() as stream:
pil_img.save(stream, format=ext, compress_level=compression)
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
return base64_string


def _vectorize_zvalue(z, mode="max"):
alpha = 255 if mode == "max" else 0
if z is None:
Expand Down Expand Up @@ -422,7 +349,7 @@ def imshow(
for ch in range(img.shape[-1])
]
)
img_str = _array_to_b64str(
img_str = image_array_to_data_uri(
img_rescaled,
backend=binary_backend,
compression=binary_compression_level,
Expand Down
2 changes: 1 addition & 1 deletion packages/python/plotly/plotly/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pprint import PrettyPrinter

from _plotly_utils.utils import *

from _plotly_utils.data_utils import *

# Pretty printing
def _list_repr_elided(v, threshold=200, edgeitems=3, indent=0, width=80):
Expand Down