Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

codemod pep585_imports #2648

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions ax/analysis/old/analysis_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# pyre-strict

from typing import List, Optional, Tuple
from typing import Optional

import pandas as pd
import plotly.graph_objects as go
Expand All @@ -24,7 +24,7 @@ class AnalysisReport:
set of data from an experiment.
"""

analyses: List[BaseAnalysis] = []
analyses: list[BaseAnalysis] = []
experiment: Experiment

time_started: Optional[int] = None
Expand All @@ -33,7 +33,7 @@ class AnalysisReport:
def __init__(
self,
experiment: Experiment,
analyses: List[BaseAnalysis],
analyses: list[BaseAnalysis],
time_started: Optional[int] = None,
time_completed: Optional[int] = None,
) -> None:
Expand Down Expand Up @@ -61,8 +61,8 @@ def report_completed(self) -> bool:

def run_analysis_report(
self,
) -> List[
Tuple[
) -> list[
tuple[
BaseAnalysis,
pd.DataFrame,
Optional[go.Figure],
Expand Down
16 changes: 8 additions & 8 deletions ax/analysis/old/cross_validation_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# pyre-strict

from copy import deepcopy
from typing import Any, Dict, List, Optional
from typing import Any, Optional

import pandas as pd

Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__(
self,
experiment: Experiment,
model: ModelBridge,
label_dict: Optional[Dict[str, str]] = None,
label_dict: Optional[dict[str, str]] = None,
caption: str = CROSS_VALIDATION_CAPTION,
) -> None:
"""
Expand All @@ -54,13 +54,13 @@ def __init__(
caption: text to display below the plot
"""
self.model = model
self.cv: List[CVResult] = cross_validate(model=model)
self.cv: list[CVResult] = cross_validate(model=model)

self.label_dict: Optional[Dict[str, str]] = label_dict
self.label_dict: Optional[dict[str, str]] = label_dict
if self.label_dict:
self.cv = self.remap_label(cv_results=self.cv, label_dict=self.label_dict)

self.metric_names: List[str] = list(
self.metric_names: list[str] = list(
set().union(*(cv_result.predicted.metric_names for cv_result in self.cv))
)
self.caption = caption
Expand Down Expand Up @@ -100,7 +100,7 @@ def get_df(self) -> pd.DataFrame:
@staticmethod
def compose_annotation(
caption: str, x: float = 0.0, y: float = -0.15
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Composes an annotation dict for use in Plotly figure.
args:
caption: str to use for dropdown text
Expand All @@ -126,8 +126,8 @@ def compose_annotation(

@staticmethod
def remap_label(
cv_results: List[CVResult], label_dict: Dict[str, str]
) -> List[CVResult]:
cv_results: list[CVResult], label_dict: dict[str, str]
) -> list[CVResult]:
"""Remaps labels in cv_results according to label_dict.

Args:
Expand Down
3 changes: 1 addition & 2 deletions ax/analysis/old/helpers/color_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@


from numbers import Real
from typing import Tuple

# type aliases
TRGB = Tuple[Real, ...]
TRGB = tuple[Real, ...]


def rgba(rgb_tuple: TRGB, alpha: float = 1) -> str:
Expand Down
20 changes: 10 additions & 10 deletions ax/analysis/old/helpers/cross_validation_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# pyre-strict

from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional

import numpy as np
import pandas as pd
Expand All @@ -20,9 +20,9 @@


def error_scatter_data_from_cv_results(
cv_results: List[CVResult],
cv_results: list[CVResult],
metric_name: str,
) -> Tuple[List[float], List[float], List[float], List[float]]:
) -> tuple[list[float], list[float], list[float], list[float]]:
"""Extract mean and error from CVResults

Args:
Expand Down Expand Up @@ -51,7 +51,7 @@ def error_scatter_data_from_cv_results(


def cv_results_to_df(
cv_results: List[CVResult],
cv_results: list[CVResult],
metric_name: str,
) -> pd.DataFrame:
"""Create a dataframe with error scatterplot data
Expand Down Expand Up @@ -95,8 +95,8 @@ def cv_results_to_df(

# Helper functions for plotting model fits
def get_min_max_with_errors(
x: List[float], y: List[float], se_x: List[float], se_y: List[float]
) -> Tuple[float, float]:
x: list[float], y: list[float], se_x: list[float], se_y: list[float]
) -> tuple[float, float]:
"""Get min and max of a bivariate dataset (across variables).

Args:
Expand All @@ -120,8 +120,8 @@ def get_min_max_with_errors(


def get_plotting_limit_ignore_outliers(
x: List[float], y: List[float], se_x: List[float], se_y: List[float]
) -> Tuple[List[float], Tuple[float, float]]:
x: list[float], y: list[float], se_x: list[float], se_y: list[float]
) -> tuple[list[float], tuple[float, float]]:
"""Get a range for a bivarite dataset based on the 25th and 75th percentiles
Used as plotting limit to ignore outliers.

Expand Down Expand Up @@ -166,7 +166,7 @@ def get_plotting_limit_ignore_outliers(
return (layout_range, diagonal_trace_range)


def diagonal_trace(min_: float, max_: float, visible: bool = True) -> Dict[str, Any]:
def diagonal_trace(min_: float, max_: float, visible: bool = True) -> dict[str, Any]:
"""Diagonal line trace from (min_, min_) to (max_, max_).

Args:
Expand All @@ -185,7 +185,7 @@ def diagonal_trace(min_: float, max_: float, visible: bool = True) -> Dict[str,
)


def default_value_se_raw(se_raw: Optional[List[float]], out_length: int) -> List[float]:
def default_value_se_raw(se_raw: Optional[list[float]], out_length: int) -> list[float]:
"""
Takes a list of standard errors and maps edge cases to default list
of floats.
Expand Down
10 changes: 5 additions & 5 deletions ax/analysis/old/helpers/layout_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

# pyre-strict

from typing import Any, Dict, List, Tuple, Type
from typing import Any

import plotly.graph_objs as go


def updatemenus_format(metric_dropdown: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def updatemenus_format(metric_dropdown: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Formats for use in the cross validation plot
"""
Expand Down Expand Up @@ -58,11 +58,11 @@ def updatemenus_format(metric_dropdown: List[Dict[str, Any]]) -> List[Dict[str,


def layout_format(
layout_axis_range_value: Tuple[float, float],
layout_axis_range_value: tuple[float, float],
xlabel: str,
ylabel: str,
updatemenus: List[Dict[str, Any]],
) -> Type[go.Figure]:
updatemenus: list[dict[str, Any]],
) -> type[go.Figure]:
"""
Constructs a layout object for a CrossValidation figure.
args:
Expand Down
5 changes: 2 additions & 3 deletions ax/analysis/old/helpers/plot_data_df_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

# pyre-strict

from typing import List, Set

import numpy as np
import pandas as pd
Expand All @@ -19,7 +18,7 @@

def get_plot_data_in_sample_arms_df(
model: ModelBridge,
metric_names: Set[str],
metric_names: set[str],
) -> pd.DataFrame:
"""Get in-sample arms from a model with observed and predicted values
for specified metrics.
Expand All @@ -46,7 +45,7 @@ def get_plot_data_in_sample_arms_df(
}
"""
observations = model.get_training_data()
training_in_design: List[bool] = model.training_in_design
training_in_design: list[bool] = model.training_in_design

# Merge multiple measurements within each Observation with IVW to get
# un-modeled prediction
Expand Down
8 changes: 4 additions & 4 deletions ax/analysis/old/helpers/plot_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


from logging import Logger
from typing import Dict, List, Optional, Tuple, Union
from typing import Optional, Union

from ax.analysis.old.helpers.constants import DECIMALS, Z

Expand All @@ -23,9 +23,9 @@
logger: Logger = get_logger(__name__)

# Typing alias
RawData = List[Dict[str, Union[str, float]]]
RawData = list[dict[str, Union[str, float]]]

TNullableGeneratorRunsDict = Optional[Dict[str, GeneratorRun]]
TNullableGeneratorRunsDict = Optional[dict[str, GeneratorRun]]


def _format_dict(param_dict: TParameterization, name: str = "Parameterization") -> str:
Expand Down Expand Up @@ -76,7 +76,7 @@ def resize_subtitles(figure: go.Figure, size: int) -> go.Figure:
return figure


def arm_name_to_sort_key(arm_name: str) -> Tuple[str, int, int]:
def arm_name_to_sort_key(arm_name: str) -> tuple[str, int, int]:
"""Parses arm name into tuple suitable for reverse sorting by key

Example:
Expand Down
12 changes: 6 additions & 6 deletions ax/analysis/old/helpers/scatter_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numbers

from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -98,7 +98,7 @@ def _relativize_filtered_dataframe(

def extract_mean_and_error_from_df(
df: pd.DataFrame,
) -> Tuple[List[float], List[float], List[float], List[float]]:
) -> tuple[list[float], list[float], list[float], list[float]]:
"""Extract mean and error from dataframe.

Args:
Expand All @@ -119,8 +119,8 @@ def extract_mean_and_error_from_df(

def make_label(
arm_name: str,
x_axis_values: Optional[Tuple[str, float, float]],
y_axis_values: Tuple[str, float, float],
x_axis_values: Optional[tuple[str, float, float]],
y_axis_values: tuple[str, float, float],
param_blob: TParameterization,
rel: bool,
) -> str:
Expand Down Expand Up @@ -179,7 +179,7 @@ def error_dot_plot_trace_from_df(
df: pd.DataFrame,
show_CI: bool = True,
visible: bool = True,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Creates trace for dot plot with confidence intervals.
Categorizes by arm name.

Expand Down Expand Up @@ -241,7 +241,7 @@ def error_scatter_trace_from_df(
visible: bool = True,
y_axis_label: Optional[str] = None,
x_axis_label: Optional[str] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Plot scatterplot with error bars.

Args:
Expand Down
8 changes: 4 additions & 4 deletions ax/analysis/old/predicted_outcomes_dot_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# pyre-strict

from typing import Any, Dict, Set
from typing import Any

import numpy as np

Expand Down Expand Up @@ -45,7 +45,7 @@ def __init__(
"""

self.model = model
self.metrics: Set[str] = model.metric_names
self.metrics: set[str] = model.metric_names
if model.status_quo is None or model.status_quo.arm_name is None:
raise UnsupportedPlotError(
"status quo must be specified for PredictedOutcomesDotPlot"
Expand Down Expand Up @@ -82,15 +82,15 @@ def get_fig(
For each metric, we plot the predicted values for each arm along with its CI
These values are relativized with respect to the status quo.
"""
name_order_axes: Dict[str, Dict[str, Any]] = {}
name_order_axes: dict[str, dict[str, Any]] = {}

in_sample_df = self.get_df()
traces = []
metric_dropdown = []

for i, metric in enumerate(self.metrics):
filtered_df = in_sample_df.loc[in_sample_df["metric_name"] == metric]
data_single: Dict[str, Any] = error_dot_plot_trace_from_df(
data_single: dict[str, Any] = error_dot_plot_trace_from_df(
df=filtered_df, show_CI=True, visible=(i == 0)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# pyre-strict

from typing import Any, Dict, Optional
from typing import Any, Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -149,7 +149,7 @@ def _find_mean_by_arm_name(
return series.item()


def _get_parameter_dimension(series: pd.Series) -> Dict[str, Any]:
def _get_parameter_dimension(series: pd.Series) -> dict[str, Any]:
# For numeric parameters allow Plotly to infer tick attributes. Note: booleans are
# considered numeric, but in this case we want to treat them as categorical.
if pd.api.types.is_numeric_dtype(series) and not pd.api.types.is_bool_dtype(series):
Expand Down
Loading