Skip to content

Commit

Permalink
SQAAnalysisCard refactor (#2644)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2644

* Refactor `SQAAnalysis` -> `SQAAnalysisCard`
* Delete `SQAAnalysisReport` and related encoders/decoders

Reviewed By: danielcohenlive

Differential Revision: D60258359

fbshipit-source-id: 4b1a96742f365b17c96f22ae9ee3211cf50019d6
  • Loading branch information
Cesar-Cardoso authored and facebook-github-bot committed Aug 8, 2024
1 parent e459a08 commit a3debca
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 317 deletions.
36 changes: 1 addition & 35 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

import pandas as pd
import plotly.io as pio

from ax.analysis.old.base_analysis import BaseAnalysis
from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
Expand Down Expand Up @@ -54,7 +50,6 @@
from ax.storage.sqa_store.db import session_scope
from ax.storage.sqa_store.sqa_classes import (
SQAAbandonedArm,
SQAAnalysis,
SQAArm,
SQAData,
SQAExperiment,
Expand All @@ -67,16 +62,10 @@
SQATrial,
)
from ax.storage.sqa_store.sqa_config import SQAConfig
from ax.storage.utils import (
AnalysisType,
DomainType,
MetricIntent,
ParameterConstraintType,
)
from ax.storage.utils import DomainType, MetricIntent, ParameterConstraintType
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
from pandas import read_json
from pyre_extensions import assert_is_instance
from sqlalchemy.orm.exc import DetachedInstanceError

Expand Down Expand Up @@ -993,29 +982,6 @@ def data_from_sqa(
dat.db_id = data_sqa.id
return dat

def analysis_from_sqa(
self,
analysis_sqa: SQAAnalysis,
experiment: Experiment,
) -> BaseAnalysis:
"""Convert SQLAlchemy Analysis to Ax Analysis Object."""
# TODO: generalize solution for pd dataframe type casting of "arm_name" column.
if analysis_sqa.experiment_analysis_type == AnalysisType.PLOTLY_VISUALIZATION:
return BasePlotlyVisualization(
experiment=experiment,
df_input=read_json(
analysis_sqa.dataframe_json, dtype={"arm_name": "str"}
),
fig_input=pio.from_json(analysis_sqa.fig_json, output_type="Figure"),
)
else:
return BaseAnalysis(
experiment=experiment,
df_input=read_json(
analysis_sqa.dataframe_json, dtype={"arm_name": "str"}
),
)

def _metric_from_sqa_util(self, metric_sqa: SQAMetric) -> Metric:
"""Convert SQLAlchemy Metric to Ax Metric"""
if metric_sqa.metric_type not in self.config.reverse_metric_registry:
Expand Down
48 changes: 1 addition & 47 deletions ax/storage/sqa_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,6 @@
from logging import Logger
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

import plotly
import plotly.io as pio

from ax.analysis.old.base_analysis import BaseAnalysis
from ax.analysis.old.base_plotly_visualization import BasePlotlyVisualization

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial
from ax.core.batch_trial import AbandonedArm, BatchTrial
Expand Down Expand Up @@ -51,7 +45,6 @@
from ax.storage.json_store.encoder import object_to_json
from ax.storage.sqa_store.sqa_classes import (
SQAAbandonedArm,
SQAAnalysis,
SQAArm,
SQAData,
SQAExperiment,
Expand All @@ -64,12 +57,7 @@
SQATrial,
)
from ax.storage.sqa_store.sqa_config import SQAConfig
from ax.storage.utils import (
AnalysisType,
DomainType,
MetricIntent,
ParameterConstraintType,
)
from ax.storage.utils import DomainType, MetricIntent, ParameterConstraintType
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
Expand Down Expand Up @@ -1063,37 +1051,3 @@ def data_to_sqa(
)
),
)

def analysis_to_sqa(
self,
analysis: BaseAnalysis,
) -> SQAAnalysis:
"""Convert Ax analysis to SQLAlchemy."""
# pyre-fixme: Expected `Base` for 1st...ot `typing.Type[BaseAnalysis]`.
analysis_class: SQAAnalysis = self.config.class_to_sqa_class[BaseAnalysis]

is_plotly_visualization: bool = isinstance(analysis, BasePlotlyVisualization)

# pyre-fixme[29]: `SQAAnalysis` is not a function.
return analysis_class(
id=-1,
analysis_class_name=type(analysis).__name__,
time_analysis_start=-1,
time_analysis_completed=-1,
experiment_analysis_type=(
AnalysisType.PLOTLY_VISUALIZATION
if is_plotly_visualization
else AnalysisType.ANALYSIS
),
dataframe_json=analysis.df.to_json(),
fig_json=(
None
if not is_plotly_visualization
else pio.to_json(
checked_cast(BasePlotlyVisualization, analysis).fig,
validate=True,
remove_uids=False,
)
),
plotly_version=plotly.__version__,
)
89 changes: 30 additions & 59 deletions ax/storage/sqa_store/sqa_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from datetime import datetime
from typing import Any, Dict, List, Optional

from ax.analysis.analysis import AnalysisCardLevel

from ax.core.base_trial import TrialStatus
from ax.core.batch_trial import LifecycleStage
from ax.core.parameter import ParameterType
Expand All @@ -34,13 +36,7 @@
)
from ax.storage.sqa_store.sqa_enum import IntEnum, StringEnum
from ax.storage.sqa_store.timestamp import IntTimestamp
from ax.storage.utils import (
AnalysisType,
DataType,
DomainType,
MetricIntent,
ParameterConstraintType,
)
from ax.storage.utils import DataType, DomainType, MetricIntent, ParameterConstraintType
from sqlalchemy import (
BigInteger,
Boolean,
Expand Down Expand Up @@ -462,6 +458,31 @@ class SQATrial(Base):
)


class SQAAnalysisCard(Base):
__tablename__: str = "analysis_card"

# pyre-fixme[8]: Attribute has type `int` but is used as type `Column[int]`.
id: int = Column(Integer, primary_key=True)
# pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`.
name: str = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
# pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`.
title: str = Column(String(LONG_STRING_FIELD_LENGTH), nullable=False)
# pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`.
subtitle: str = Column(Text, nullable=False)
# pyre-fixme[8]: Attribute has type `int` but is used as type `Column[int]`.
level: AnalysisCardLevel = Column(IntEnum(AnalysisCardLevel), nullable=False)
# pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`.
dataframe_json: str = Column(Text(LONGTEXT_BYTES), nullable=False)
# pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`.
blob: str = Column(Text(LONGTEXT_BYTES), nullable=False)
# pyre-fixme[8]: Attribute has type `str` but is used as type `Column[str]`.
blob_annotation: str = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
# pyre-fixme[8]: Attribute has type `int` but is used as type `Column[int]`.
time_created: datetime = Column(IntTimestamp, nullable=False)
# pyre-fixme[8]: Attribute has type `int` but is used as type `Column[int]`.
experiment_id: int = Column(Integer, ForeignKey("experiment_v2.id"), nullable=False)


class SQAExperiment(Base):
__tablename__: str = "experiment_v2"

Expand Down Expand Up @@ -520,56 +541,6 @@ class SQAExperiment(Base):
uselist=False,
lazy=True,
)


class SQAAnalysis(Base):
__tablename__: str = "analysis"

# pyre-fixme[8]: Attribute has type `int`; used as `Column[int]`.
id: int = Column(Integer, primary_key=True)

# pyre-fixme[8]: Attribute has type `str`; used as `Column[str]`.
analysis_class_name: str = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False)
# pyre-fixme[8]: Attribute has type `datetime`; used as `Column[typing.Any]`.
time_analysis_start: datetime = Column(IntTimestamp, nullable=False)
# pyre-fixme[8]: Attribute has type `datetime`; used as `Column[typing.Any]`.
time_analysis_completed: datetime = Column(IntTimestamp, nullable=False)

# pyre-fixme[8]: Attribute has type `AnalysisType`; used as
# `Column[typing.Any]`.
experiment_analysis_type: AnalysisType = Column(
StringEnum(AnalysisType), nullable=False
)

# pyre-fixme[8]: Attribute has type `str`; used as `Column[str]`.
dataframe_json: str = Column(Text(LONGTEXT_BYTES), nullable=False)

# pyre-fixme[8]: Attribute has type `Optional[str]`; used as
# `Column[Optional[str]]`.
fig_json: Optional[str] = Column(Text(LONGTEXT_BYTES), nullable=True)
# pyre-fixme[8]: Attribute has type `Optional[str]`; used as
# `Column[Optional[str]]`.
plotly_version: Optional[str] = Column(
String(NAME_OR_TYPE_FIELD_LENGTH), nullable=True
)

# pyre-fixme[8]: Attribute has type `int`; used as `Column[Optional[int]]`.
experiment_id: int = Column(Integer)
# pyre-fixme[8]: Attribute has type `int`; used as `Column[Optional[int]]`.
analysis_report_id: int = Column(Integer, ForeignKey("analysis_report_v2.id"))


class SQAAnalysisReport(Base):
__tablename__: str = "analysis_report_v2"

# pyre-fixme[8]: Attribute has type `int`; used as `Column[int]`.
id: int = Column(Integer, primary_key=True)
# pyre-fixme[8]: Attribute has type `datetime`; used as `Column[typing.Any]`.
time_report_start: datetime = Column(IntTimestamp, nullable=False)

analyses: Optional[List[SQAAnalysis]] = relationship(
"SQAAnalysis", cascade="all, delete-orphan", lazy="selectin"
analysis_cards: List[SQAAnalysisCard] = relationship(
"SQAAnalysisCard", cascade="all, delete-orphan", lazy="selectin"
)

# pyre-fixme[8]: Attribute has type `int`; used as `Column[Optional[int]]`.
experiment_id: int = Column(Integer, ForeignKey("experiment_v2.id"))
9 changes: 3 additions & 6 deletions ax/storage/sqa_store/sqa_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from enum import Enum
from typing import Any, Callable, Dict, Optional, Type, Union

from ax.analysis.old.analysis_report import AnalysisReport
from ax.analysis.old.base_analysis import BaseAnalysis
from ax.analysis.analysis import AnalysisCard

from ax.core.arm import Arm
from ax.core.batch_trial import AbandonedArm
Expand All @@ -36,8 +35,7 @@
from ax.storage.sqa_store.db import SQABase
from ax.storage.sqa_store.sqa_classes import (
SQAAbandonedArm,
SQAAnalysis,
SQAAnalysisReport,
SQAAnalysisCard,
SQAArm,
SQAData,
SQAExperiment,
Expand Down Expand Up @@ -80,8 +78,7 @@ def _default_class_to_sqa_class(self=None) -> Dict[Type[Base], Type[SQABase]]:
Metric: SQAMetric,
Runner: SQARunner,
Trial: SQATrial,
BaseAnalysis: SQAAnalysis,
AnalysisReport: SQAAnalysisReport,
AnalysisCard: SQAAnalysisCard,
}

class_to_sqa_class: Dict[Type[Base], Type[SQABase]] = field(
Expand Down
24 changes: 0 additions & 24 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,7 @@
update_runner_on_experiment,
)
from ax.storage.sqa_store.sqa_classes import (
AnalysisType,
SQAAbandonedArm,
SQAAnalysis,
SQAAnalysisReport,
SQAArm,
SQAExperiment,
SQAGeneratorRun,
Expand Down Expand Up @@ -1929,24 +1926,3 @@ def test_CreateAllTablesException(self) -> None:
engine.dialect.default_schema_name = "ax"
with self.assertRaises(ValueError):
create_all_tables(engine)

def test_CreateAnalysisRecords(self) -> None:

sqa_analysis = SQAAnalysis(
analysis_class_name="CrossValidationPlot",
experiment_analysis_type=AnalysisType.PLOTLY_VISUALIZATION,
time_analysis_start=datetime.now(),
time_analysis_completed=datetime.now(),
dataframe_json="none",
)
with session_scope() as session:
_ = session.merge(sqa_analysis)
session.flush()

def test_CreateAnalysisReport(self) -> None:
sqa_analysis_report = SQAAnalysisReport(
time_report_start=datetime.now(),
)
with session_scope() as session:
_ = session.merge(sqa_analysis_report)
session.flush()
Loading

0 comments on commit a3debca

Please sign in to comment.