Skip to content

Commit

Permalink
Tests which store and load dataframe and figure (#2344)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2344

This diff adds tests for CrossValidationPlot
- Stores a json representation of the dataframe of the plot to a tempfile, then reads it back, and asserts equality
- Converts the plot to a json object, converts it back, then checks equality.

We'll need to store analysis objects, so we need to check that the dataframe and figures are serializable to json

Reviewed By: mpolson64

Differential Revision: D55967859

fbshipit-source-id: a97b5430e2e56078735a18402d2b41b68fbcb300
  • Loading branch information
mgrange1998 authored and facebook-github-bot committed Apr 10, 2024
1 parent 99615d6 commit a3661fc
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions ax/analysis/helpers/tests/test_cross_validation_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import tempfile

import plotly.graph_objects as go

import plotly.io as pio

from ax.analysis.cross_validation_plot import CrossValidationPlot

from ax.analysis.helpers.constants import Z
Expand All @@ -16,6 +20,9 @@
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_experiment
from ax.utils.testing.mock import fast_botorch_optimize
from pandas import read_json

from pandas.testing import assert_frame_equal


class TestCrossValidationHelpers(TestCase):
Expand Down Expand Up @@ -63,3 +70,26 @@ def test_obs_vs_pred_dropdown_plot(self) -> None:
fig = cross_validation_plot.get_fig()

self.assertIsInstance(fig, go.Figure)

def test_store_df_to_file(self) -> None:
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as f:
cross_validation_plot = CrossValidationPlot(
experiment=self.exp, model=self.model
)
cv_df = cross_validation_plot.get_df()
cv_df.to_json(f.name)

loaded_dataframe = read_json(f.name, dtype={"arm_name": "str"})

assert_frame_equal(cv_df, loaded_dataframe, check_dtype=False)

def test_store_plot_as_dict(self) -> None:
cross_validation_plot = CrossValidationPlot(
experiment=self.exp, model=self.model
)
cv_fig = cross_validation_plot.get_fig()

json_obj = pio.to_json(cv_fig, validate=True, remove_uids=False)

loaded_json_obj = pio.from_json(json_obj, output_type="Figure")
self.assertEqual(cv_fig, loaded_json_obj)

0 comments on commit a3661fc

Please sign in to comment.