diff --git a/repos/kaleido/py/kaleido/scopes/plotly.py b/repos/kaleido/py/kaleido/scopes/plotly.py index b8ee4e8f..3896fad8 100644 --- a/repos/kaleido/py/kaleido/scopes/plotly.py +++ b/repos/kaleido/py/kaleido/scopes/plotly.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from kaleido.scopes.base import BaseScope from _plotly_utils.utils import PlotlyJSONEncoder +from plotly.graph_objects import Figure import base64 @@ -65,13 +66,20 @@ def transform(self, figure, format=None, width=None, height=None, scale=None): :return: image bytes """ # TODO: validate args + if isinstance(figure, Figure): + figure = figure.to_dict() - # Apply defaults + # Apply default format and scale format = format if format is not None else self.default_format - width = width if width is not None else self.default_width - height = height if height is not None else self.default_height scale = scale if scale is not None else self.default_scale + # Get figure layout + layout = figure.get("layout", {}) + + # Compute default width / height + width = width or layout.get("width", None) or self.default_width + height = height or layout.get("height", None) or self.default_height + # Normalize format original_format = format format = format.lower() @@ -96,7 +104,7 @@ def transform(self, figure, format=None, width=None, height=None, scale=None): ) # Check for export error, later can customize error messages for plotly Python users - code = response.pop("code", 0) + code = response.get("code", 0) if code != 0: message = response.get("message", None) raise ValueError( @@ -105,7 +113,7 @@ def transform(self, figure, format=None, width=None, height=None, scale=None): ) ) - img = response.pop("result", None).encode("utf-8") + img = response.get("result").encode("utf-8") # Base64 decode binary types if format not in self._text_formats: diff --git a/repos/kaleido/py/tests/plotly/test_plotly.py b/repos/kaleido/py/tests/plotly/test_plotly.py index 73c8bfaf..d406ee5c 100644 --- a/repos/kaleido/py/tests/plotly/test_plotly.py +++ b/repos/kaleido/py/tests/plotly/test_plotly.py @@ -1,13 +1,18 @@ import os -# from os.path import join -import pathlib +import sys from .. import baseline_root, tests_root from kaleido.scopes.plotly import PlotlyScope import pytest from .fixtures import all_figures, all_formats, mapbox_figure, simple_figure +import plotly.graph_objects as go + import plotly.io as pio pio.templates.default = None +if sys.version_info >= (3, 3): + from unittest.mock import Mock +else: + from mock import Mock os.environ['LIBGL_ALWAYS_SOFTWARE'] = '1' os.environ['GALLIUM_DRIVER'] = 'softpipe' @@ -118,3 +123,41 @@ def test_bad_format_file(): local_scope.transform(fig, format='bogus') e.match("Invalid format") + + +def test_figure_size(): + # Create mocked scope + scope = PlotlyScope() + transform_mock = Mock(return_value={"code": 0, "result": "image"}) + scope._perform_transform = transform_mock + + # Set defualt width / height + scope.default_width = 543 + scope.default_height = 567 + scope.default_format = "svg" + scope.default_scale = 2 + + # Make sure default width/height is used when no figure + # width/height specified + transform_mock.reset_mock() + fig = go.Figure() + scope.transform(fig) + transform_mock.assert_called_once_with( + fig.to_dict(), format="svg", scale=2, width=543, height=567 + ) + + # Make sure figure's width/height takes precedence over defaults + transform_mock.reset_mock() + fig = go.Figure().update_layout(width=123, height=234) + scope.transform(fig) + transform_mock.assert_called_once_with( + fig.to_dict(), format="svg", scale=2, width=123, height=234 + ) + + # Make sure kwargs take precedence over Figure layout values + transform_mock.reset_mock() + fig = go.Figure().update_layout(width=123, height=234) + scope.transform(fig, width=987, height=876) + transform_mock.assert_called_once_with( + fig.to_dict(), format="svg", scale=2, width=987, height=876 + )