|
3 | 3 | import numpy as np
|
4 | 4 | import pandas as pd
|
5 | 5 | import pytest
|
| 6 | +from packaging import version |
| 7 | +import unittest.mock as mock |
6 | 8 | from plotly.express._core import build_dataframe
|
7 | 9 | from pandas.testing import assert_frame_equal
|
8 | 10 |
|
| 11 | +# Fixtures |
| 12 | +# -------- |
| 13 | +@pytest.fixture |
| 14 | +def add_interchange_module_for_old_pandas(): |
| 15 | + if not hasattr(pd.api, "interchange"): |
| 16 | + pd.api.interchange = mock.MagicMock() |
| 17 | + # to make the following import work: `import pandas.api.interchange` |
| 18 | + with mock.patch.dict( |
| 19 | + "sys.modules", {"pandas.api.interchange": pd.api.interchange} |
| 20 | + ): |
| 21 | + yield |
| 22 | + else: |
| 23 | + yield |
| 24 | + |
9 | 25 |
|
10 | 26 | def test_numpy():
|
11 | 27 | fig = px.scatter(x=[1, 2, 3], y=[2, 3, 4], color=[1, 3, 9])
|
@@ -233,6 +249,47 @@ def test_build_df_with_index():
|
233 | 249 | assert_frame_equal(tips.reset_index()[out["data_frame"].columns], out["data_frame"])
|
234 | 250 |
|
235 | 251 |
|
| 252 | +def test_build_df_using_interchange_protocol_mock( |
| 253 | + add_interchange_module_for_old_pandas, |
| 254 | +): |
| 255 | + class CustomDataFrame: |
| 256 | + def __dataframe__(self): |
| 257 | + pass |
| 258 | + |
| 259 | + input_dataframe = CustomDataFrame() |
| 260 | + args = dict(data_frame=input_dataframe, x="petal_width", y="sepal_length") |
| 261 | + |
| 262 | + iris_pandas = px.data.iris() |
| 263 | + |
| 264 | + with mock.patch("pandas.__version__", "2.0.2"): |
| 265 | + with mock.patch( |
| 266 | + "pandas.api.interchange.from_dataframe", return_value=iris_pandas |
| 267 | + ) as mock_from_dataframe: |
| 268 | + build_dataframe(args, go.Scatter) |
| 269 | + mock_from_dataframe.assert_called_once_with(input_dataframe) |
| 270 | + |
| 271 | + |
| 272 | +@pytest.mark.skipif( |
| 273 | + version.parse(pd.__version__) < version.parse("2.0.2"), |
| 274 | + reason="plotly doesn't use a dataframe interchange protocol for pandas < 2.0.2", |
| 275 | +) |
| 276 | +@pytest.mark.parametrize("test_lib", ["vaex", "polars"]) |
| 277 | +def test_build_df_from_vaex_and_polars(test_lib): |
| 278 | + if test_lib == "vaex": |
| 279 | + import vaex as lib |
| 280 | + else: |
| 281 | + import polars as lib |
| 282 | + |
| 283 | + # take out the 'species' columns since the vaex implementation does not cover strings yet |
| 284 | + iris_pandas = px.data.iris()[["petal_width", "sepal_length"]] |
| 285 | + iris_vaex = lib.from_pandas(iris_pandas) |
| 286 | + args = dict(data_frame=iris_vaex, x="petal_width", y="sepal_length") |
| 287 | + out = build_dataframe(args, go.Scatter) |
| 288 | + assert_frame_equal( |
| 289 | + iris_pandas.reset_index()[out["data_frame"].columns], out["data_frame"] |
| 290 | + ) |
| 291 | + |
| 292 | + |
236 | 293 | def test_timezones():
|
237 | 294 | df = pd.DataFrame({"date": ["2015-04-04 19:31:30+1:00"], "value": [3]})
|
238 | 295 | df["date"] = pd.to_datetime(df["date"])
|
|
0 commit comments