Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AnalysisCard refactor #2589

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions ax/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel
from ax.analysis.markdown import * # noqa
from ax.analysis.plotly import * # noqa

__all__ = ["Analysis", "AnalysisCard", "AnalysisCardLevel"]
87 changes: 87 additions & 0 deletions ax/analysis/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Any, Optional, Protocol

import pandas as pd
from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStrategy


class AnalysisCardLevel(Enum):
DEBUG = 0
LOW = 1
MID = 2
HIGH = 3
CRITICAL = 4


class AnalysisCard:
# Name of the analysis computed, usually the class name of the Analysis which
# produced the card. Useful for grouping by when querying a large collection of
# cards.
name: str

title: str
subtitle: str
level: AnalysisCardLevel

df: pd.DataFrame # Raw data produced by the Analysis

# pyre-ignore[4] We explicitly want to allow any type here, blob is narrowed in
# AnalysisCard's subclasses
blob: Any # Data processed and ready for end-user consumption

# How to interpret the blob (ex. "dataframe", "plotly", "markdown")
blob_annotation = "dataframe"

def __init__(
self,
name: str,
title: str,
subtitle: str,
level: AnalysisCardLevel,
df: pd.DataFrame,
# pyre-ignore[2] We explicitly want to allow any type here, blob is narrowed in
# AnalysisCard's subclasses
blob: Any,
) -> None:
self.name = name
self.title = title
self.subtitle = subtitle
self.level = level
self.df = df
self.blob = blob


class Analysis(Protocol):
"""
An Analysis is a class that given either and Experiment, a GenerationStrategy, or
both can compute some data intended for end-user consumption. The data is returned
to the user in the form of an AnalysisCard which contains the raw data, a blob (the
data processed for end-user consumption), and miscellaneous metadata that can be
useful for rendering the card or a collection of cards.

The AnalysisCard is a thin wrapper around the raw data and the processed blob;
Analyses impose structure on their blob should subclass Analysis. See
PlotlyAnalysis for an example which produces cards where the blob is always a
Plotly Figure object.

A good pattern to follow when implementing your own Analyses is to configure
"settings" (like which parameter or metrics to operate on, or whether to use
observed or modeled effects) in your Analyses' __init__ methods, then to consume
these settings in the compute method.
"""

def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategy] = None,
) -> AnalysisCard:
# Note: when implementing compute always prefer experiment.lookup_data() to
# experiment.fetch_data() to avoid unintential data fetching within the report
# generation.
...
11 changes: 11 additions & 0 deletions ax/analysis/markdown/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ax.analysis.markdown.markdown_analysis import (
MarkdownAnalysis,
MarkdownAnalysisCard,
)

__all__ = ["MarkdownAnalysis", "MarkdownAnalysisCard"]
38 changes: 38 additions & 0 deletions ax/analysis/markdown/markdown_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import pandas as pd
from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel
from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStrategy


class MarkdownAnalysisCard(AnalysisCard):
name: str

title: str
subtitle: str
level: AnalysisCardLevel

df: pd.DataFrame
blob: str
blob_annotation = "markdown"

def get_markdown(self) -> str:
return self.blob


class MarkdownAnalysis(Analysis):
"""
An Analysis that computes a paragraph of Markdown formatted text.
"""

def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategy] = None,
) -> MarkdownAnalysisCard: ...
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import pandas as pd
import plotly.graph_objects as go

from ax.analysis.base_analysis import BaseAnalysis
from ax.analysis.base_plotly_visualization import BasePlotlyVisualization
from ax.analysis.old.base_analysis import BaseAnalysis
from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization

from ax.core.experiment import Experiment

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import plotly.graph_objects as go

from ax.analysis.base_analysis import BaseAnalysis
from ax.analysis.old.base_analysis import BaseAnalysis
from ax.core.experiment import Experiment


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@

import pandas as pd

from ax.analysis.base_plotly_visualization import BasePlotlyVisualization
from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization

from ax.analysis.helpers.cross_validation_helpers import (
from ax.analysis.old.helpers.cross_validation_helpers import (
cv_results_to_df,
diagonal_trace,
get_plotting_limit_ignore_outliers,
)

from ax.analysis.helpers.layout_helpers import layout_format, updatemenus_format
from ax.analysis.old.helpers.layout_helpers import layout_format, updatemenus_format

from ax.analysis.helpers.scatter_helpers import (
from ax.analysis.old.helpers.scatter_helpers import (
error_scatter_trace_from_df,
extract_mean_and_error_from_df,
)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import pandas as pd
import plotly.graph_objs as go

from ax.analysis.helpers.constants import Z
from ax.analysis.old.helpers.constants import Z

from ax.analysis.helpers.plot_helpers import arm_name_to_sort_key
from ax.analysis.old.helpers.plot_helpers import arm_name_to_sort_key

from ax.modelbridge.cross_validation import CVResult

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from logging import Logger
from typing import Dict, List, Optional, Tuple, Union

from ax.analysis.helpers.constants import DECIMALS, Z
from ax.analysis.old.helpers.constants import DECIMALS, Z

from ax.core.generator_run import GeneratorRun

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
import pandas as pd

import plotly.graph_objs as go
from ax.analysis.helpers.color_helpers import rgba
from ax.analysis.old.helpers.color_helpers import rgba

from ax.analysis.helpers.constants import CI_OPACITY, COLORS, DECIMALS, Z
from ax.analysis.old.helpers.constants import CI_OPACITY, COLORS, DECIMALS, Z

from ax.analysis.helpers.plot_helpers import _format_CI, _format_dict
from ax.analysis.old.helpers.plot_helpers import _format_CI, _format_dict

from ax.core.types import TParameterization

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

import plotly.graph_objects as go
import plotly.io as pio
from ax.analysis.cross_validation_plot import CrossValidationPlot
from ax.analysis.helpers.constants import Z
from ax.analysis.helpers.cross_validation_helpers import get_min_max_with_errors
from ax.analysis.old.cross_validation_plot import CrossValidationPlot
from ax.analysis.old.helpers.constants import Z
from ax.analysis.old.helpers.cross_validation_helpers import get_min_max_with_errors
from ax.modelbridge.registry import Models
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_experiment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import copy

import plotly.graph_objects as go
from ax.analysis.cross_validation_plot import CrossValidationPlot
from ax.analysis.helpers.cross_validation_helpers import (
from ax.analysis.old.cross_validation_plot import CrossValidationPlot
from ax.analysis.old.helpers.cross_validation_helpers import (
error_scatter_data_from_cv_results,
)
from ax.analysis.helpers.scatter_helpers import error_scatter_trace_from_df
from ax.analysis.old.helpers.scatter_helpers import error_scatter_trace_from_df
from ax.modelbridge.cross_validation import cross_validate
from ax.modelbridge.registry import Models
from ax.plot.base import PlotMetric
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pandas as pd

from ax.analysis.base_plotly_visualization import BasePlotlyVisualization
from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@

import pandas as pd

from ax.analysis.base_plotly_visualization import BasePlotlyVisualization
from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization

from ax.analysis.helpers.layout_helpers import updatemenus_format
from ax.analysis.old.helpers.layout_helpers import updatemenus_format

from ax.analysis.helpers.plot_data_df_helpers import get_plot_data_in_sample_arms_df
from ax.analysis.helpers.plot_helpers import arm_name_to_sort_key, resize_subtitles
from ax.analysis.old.helpers.plot_data_df_helpers import get_plot_data_in_sample_arms_df
from ax.analysis.old.helpers.plot_helpers import arm_name_to_sort_key, resize_subtitles

from ax.analysis.helpers.scatter_helpers import (
from ax.analysis.old.helpers.scatter_helpers import (
error_dot_plot_trace_from_df,
relativize_dataframe,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import pandas as pd
import plotly.graph_objects as go

from ax.analysis.analysis_report import AnalysisReport
from ax.analysis.old.analysis_report import AnalysisReport

from ax.analysis.base_analysis import BaseAnalysis
from ax.analysis.base_plotly_visualization import BasePlotlyVisualization
from ax.analysis.old.base_analysis import BaseAnalysis
from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization

from ax.modelbridge.registry import Models
from ax.utils.common.testutils import TestCase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

import plotly.graph_objects as go

from ax.analysis.base_analysis import BaseAnalysis
from ax.analysis.base_plotly_visualization import BasePlotlyVisualization
from ax.analysis.old.base_analysis import BaseAnalysis
from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization

from ax.modelbridge.registry import Models

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# pyre-strict

import plotly.graph_objects as go
from ax.analysis.cross_validation_plot import CrossValidationPlot
from ax.analysis.old.cross_validation_plot import CrossValidationPlot
from ax.modelbridge.registry import Models
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_experiment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# pyre-strict


from ax.analysis.parallel_coordinates_plot import ParallelCoordinatesPlot
from ax.analysis.old.parallel_coordinates_plot import ParallelCoordinatesPlot
from ax.core.batch_trial import BatchTrial

from ax.utils.common.testutils import TestCase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import unittest

import plotly.graph_objects as go
from ax.analysis.predicted_outcomes_dot_plot import PredictedOutcomesDotPlot
from ax.analysis.old.predicted_outcomes_dot_plot import PredictedOutcomesDotPlot
from ax.exceptions.core import UnsupportedPlotError
from ax.modelbridge.registry import Models
from ax.utils.testing.core_stubs import get_branin_experiment
Expand Down
8 changes: 8 additions & 0 deletions ax/analysis/plotly/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard

__all__ = ["PlotlyAnalysis", "PlotlyAnalysisCard"]
39 changes: 39 additions & 0 deletions ax/analysis/plotly/plotly_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import pandas as pd
from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel
from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStrategy
from plotly import graph_objects as go


class PlotlyAnalysisCard(AnalysisCard):
name: str

title: str
subtitle: str
level: AnalysisCardLevel

df: pd.DataFrame
blob: go.Figure
blob_annotation = "plotly"

def get_figure(self) -> go.Figure:
return self.blob


class PlotlyAnalysis(Analysis):
"""
An Analysis that computes a Plotly figure.
"""

def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategy] = None,
) -> PlotlyAnalysisCard: ...
4 changes: 2 additions & 2 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import pandas as pd
import plotly.io as pio

from ax.analysis.base_analysis import BaseAnalysis
from ax.analysis.base_plotly_visualization import BasePlotlyVisualization
from ax.analysis.old.base_analysis import BaseAnalysis
from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
Expand Down
Loading