Skip to content

Commit

Permalink
Merge pull request #8 from bigladder/update-layout-md-plots
Browse files Browse the repository at this point in the history
Update layout md plots
  • Loading branch information
nealkruis authored May 10, 2024
2 parents 4b4109f + 976e9a7 commit 6e62beb
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 30 deletions.
88 changes: 75 additions & 13 deletions dimes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,28 @@
from dataclasses import dataclass
import warnings
from datetime import datetime
import math
import bisect

from plotly.graph_objects import Figure, Scatter # type: ignore

import koozie

WHITE = "white"
BLACK = "black"
GREY = "rgba(128,128,128,0.3)"


@dataclass
class LineProperties:
color: Union[str, None] = None
line_type: Union[str, None] = None
line_width: Union[int, None] = None
line_width: Union[SupportsFloat, None] = None
marker_symbol: Union[str, None] = None
marker_size: Union[int, None] = None
marker_size: Union[SupportsFloat, None] = None
marker_line_color: Union[str, None] = None
marker_line_width: Union[int, None] = None
marker_line_width: Union[SupportsFloat, None] = None
marker_fill_color: Union[str, None] = None
is_visible: bool = True

def get_line_mode(self):

Expand Down Expand Up @@ -49,12 +54,12 @@ def get_line_mode(self):

@dataclass
class MarkersOnly(LineProperties):
line_width: Union[int, None] = 0
line_width: Union[SupportsFloat, None] = 0


@dataclass
class LinesOnly(LineProperties):
marker_size: Union[int, None] = 0
marker_size: Union[SupportsFloat, None] = 0


class DimensionalData:
Expand Down Expand Up @@ -105,7 +110,9 @@ def __init__(
display_units: Union[str, None] = None,
line_properties: LineProperties = LineProperties(),
is_visible: bool = True,
legend_group: Union[str, None] = None,
x_axis: Union[DimensionalData, TimeSeriesAxis, List[SupportsFloat], List[datetime], None] = None,
y_axis_min: Union[SupportsFloat, None] = 0.0,
):
super().__init__(data_values, name, native_units, display_units)
self.x_axis: Union[DimensionalData, TimeSeriesAxis, None]
Expand All @@ -116,8 +123,10 @@ def __init__(
self.x_axis = DimensionalData(x_axis) # type: ignore[arg-type]
else:
self.x_axis = x_axis
self.y_axis_min = y_axis_min
self.line_properties = line_properties
self.is_visible = is_visible
self.legend_group = legend_group


class DimensionalAxis:
Expand All @@ -128,11 +137,29 @@ def __init__(self, display_data: DisplayData, name: Union[str, None]) -> None:
self.units = display_data.display_units
self.dimensionality = display_data.dimensionality
self.display_data_set: List[DisplayData] = [display_data]
self.range_min: SupportsFloat = float("inf")
self.range_max: SupportsFloat = -float("inf")

def get_axis_label(self) -> str:
"""Make the string that appears as the axis label"""
return f"{self.name} [{self.units}]"

@staticmethod
def get_axis_range(value_min, value_max):
max_ticks = 6
tick_scale_options = [1, 2, 5, 10]

value_range = value_max - value_min
min_tick_size = value_range / max_ticks
magnitude = 10 ** math.floor(math.log(min_tick_size, 10))
residual = min_tick_size / magnitude
tick_size = (
tick_scale_options[bisect.bisect_right(tick_scale_options, residual)] if residual < 10 else 10
) * magnitude
range_min = math.floor(value_min / tick_size) * tick_size
range_max = math.ceil(value_max / tick_size) * tick_size
return [range_min, range_max]


class DimensionalSubplot:
"""Dimensional subplot. May contain multiple `DimensionalAxis` objects."""
Expand Down Expand Up @@ -169,7 +196,11 @@ def add_display_data_to_existing_axis(self, axis_data: DisplayData, axis: Dimens
class DimensionalPlot:
"""Plot of dimensional data."""

def __init__(self, x_axis: Union[DimensionalData, TimeSeriesAxis, List[SupportsFloat], List[datetime]]):
def __init__(
self,
x_axis: Union[DimensionalData, TimeSeriesAxis, List[SupportsFloat], List[datetime]],
title: Union[str, None] = None,
):
self.figure = Figure()
self.x_axis: Union[DimensionalData, TimeSeriesAxis]
if isinstance(x_axis, list):
Expand All @@ -181,6 +212,7 @@ def __init__(self, x_axis: Union[DimensionalData, TimeSeriesAxis, List[SupportsF
self.x_axis = x_axis
self.subplots: List[Union[DimensionalSubplot, None]] = [None]
self.is_finalized = False
self.figure.layout["title"] = title

def add_display_data(
self,
Expand All @@ -204,10 +236,20 @@ def add_display_data(
def finalize_plot(self):
"""Once all DisplayData objects have been added, generate plot and subplots."""
if not self.is_finalized:
grid_line_width = 1.5
at_least_one_subplot = False
number_of_subplots = len(self.subplots)
subplot_domains = get_subplot_domains(number_of_subplots)
absolute_axis_index = 0 # Used to track axes data in the plot
self.figure.layout["plot_bgcolor"] = WHITE
self.figure.layout["font_color"] = BLACK
self.figure.layout["title_x"] = 0.5
xy_common_axis_format = {
"mirror": True,
"linecolor": BLACK,
"linewidth": grid_line_width,
"zeroline": False,
}
x_axis_label = f"{self.x_axis.name}"
if isinstance(self.x_axis, DimensionalData):
x_axis_label += f" [{self.x_axis.display_units}]"
Expand All @@ -221,6 +263,18 @@ def finalize_plot(self):
y_axis_id = absolute_axis_index + 1
for display_data in axis.display_data_set:
at_least_one_subplot = True
y_values = koozie.convert(
display_data.data_values,
display_data.native_units,
axis.units,
)
axis.range_min = min(min(y_values), axis.range_min)
if display_data.y_axis_min is not None:
data_y_axis_min = koozie.convert(
display_data.y_axis_min, display_data.native_units, axis.units
)
axis.range_min = min(data_y_axis_min, axis.range_min)
axis.range_max = max(max(y_values), axis.range_max)
if display_data.x_axis is None:
if isinstance(display_data.x_axis, DimensionalData):
x_axis_values = koozie.convert(
Expand Down Expand Up @@ -248,16 +302,12 @@ def finalize_plot(self):
self.figure.add_trace(
Scatter(
x=x_axis_values,
y=koozie.convert(
display_data.data_values,
display_data.native_units,
axis.units,
),
y=y_values,
name=display_data.name,
yaxis=f"y{y_axis_id}",
xaxis=f"x{x_axis_id}",
mode=display_data.line_properties.get_line_mode(),
visible=("legendonly" if not display_data.line_properties.is_visible else True),
visible=("legendonly" if not display_data.is_visible else True),
line={
"color": display_data.line_properties.color,
"dash": display_data.line_properties.line_type,
Expand All @@ -272,6 +322,8 @@ def finalize_plot(self):
"width": display_data.line_properties.marker_line_width,
},
},
legendgroup=display_data.legend_group,
legendgrouptitle={"text": display_data.legend_group},
),
)
is_base_y_axis = subplot_base_y_axis_id == y_axis_id
Expand All @@ -283,7 +335,12 @@ def finalize_plot(self):
"overlaying": (f"y{subplot_base_y_axis_id}" if not is_base_y_axis else None),
"tickmode": "sync" if not is_base_y_axis else None,
"autoshift": True if axis_number > 1 else None,
"showgrid": True,
"gridcolor": GREY,
"gridwidth": grid_line_width,
"range": axis.get_axis_range(axis.range_min, axis.range_max),
}
self.figure.layout[f"yaxis{y_axis_id}"].update(xy_common_axis_format)
absolute_axis_index += 1
y_axis_side = "right" if y_axis_side == "left" else "left"
is_last_subplot = subplot_number == number_of_subplots
Expand All @@ -293,7 +350,12 @@ def finalize_plot(self):
"domain": [0.0, 1.0],
"matches": (f"x{number_of_subplots}" if subplot_number < number_of_subplots else None),
"showticklabels": None if is_last_subplot else False,
"ticks": None if not is_last_subplot else "outside",
"tickson": None if not is_last_subplot else "boundaries",
"tickcolor": None if not is_last_subplot else BLACK,
"tickwidth": None if not is_last_subplot else grid_line_width,
}
self.figure.layout[f"xaxis{x_axis_id}"].update(xy_common_axis_format)
else:
warnings.warn(f"Subplot {subplot_number} is unused.")
if not at_least_one_subplot:
Expand Down
5 changes: 3 additions & 2 deletions dimes/timeseries.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module for plotting time series data."""

from .common import DimensionalPlot, DisplayData
from typing import Union


class TimeSeriesData(DisplayData):
Expand All @@ -12,7 +13,7 @@ class TimeSeriesData(DisplayData):
class TimeSeriesPlot(DimensionalPlot):
"""Time series plot."""

def __init__(self, time_values: list):
super().__init__(time_values)
def __init__(self, time_values: list, title: Union[str, None] = None):
super().__init__(time_values, title)
self.add_time_series = self.add_display_data
self.time_values = self.x_axis.data_values
79 changes: 64 additions & 15 deletions test/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@

from pathlib import Path
import pytest
from dimes import TimeSeriesPlot, TimeSeriesData, LineProperties
from dimes import TimeSeriesPlot, TimeSeriesData, LineProperties, LinesOnly
from dimes.common import DimensionalAxis

TESTING_DIRECTORY = Path("test_outputs")
TESTING_DIRECTORY.mkdir(exist_ok=True)


def test_basic_plot():
"""Test basic plot"""
plot = TimeSeriesPlot([1, 2, 3, 4, 5])
plot = TimeSeriesPlot([1, 2, 3, 4, 5], "Title Basic Plot")
plot.add_time_series(TimeSeriesData([x**2 for x in plot.time_values]))
plot.add_time_series(TimeSeriesData([x**3 for x in plot.time_values]))

Expand Down Expand Up @@ -65,24 +66,15 @@ def test_multi_plot():
plot = TimeSeriesPlot([1, 2, 3, 4, 5])
# Time series & axis names explicit, subplot default to 1
plot.add_time_series(
TimeSeriesData(
[x**2 for x in plot.time_values], name="Power", native_units="hp", display_units="W"
),
TimeSeriesData([x**2 for x in plot.time_values], name="Power", native_units="hp", display_units="W"),
axis_name="Power or Capacity",
)
# Time series name explicit, axis automatically determined by dimensionality, subplot default to 1
plot.add_time_series(
TimeSeriesData(
[x * 10 for x in plot.time_values],
name="Capacity",
native_units="kBtu/h",
is_visible=False,
)
TimeSeriesData([x * 10 for x in plot.time_values], name="Capacity", native_units="kBtu/h", is_visible=False)
)
# Time series and axis will get name from dimensionality, subplot default to 1, new axis for new dimension on existing subplot
plot.add_time_series(
TimeSeriesData([x for x in plot.time_values], native_units="ft", display_units="cm")
)
plot.add_time_series(TimeSeriesData([x for x in plot.time_values], native_units="ft", display_units="cm"))
# Time series & axis names and subplot number are all explicit
plot.add_time_series(
TimeSeriesData([x**3 for x in plot.time_values], name="Number of Apples"),
Expand Down Expand Up @@ -121,7 +113,7 @@ def test_basic_marker():


def test_missing_marker_symbol():
"""Test missing marker symbol, default symbol should be 'circle'"""
"""Test missing marker symbol, default symbol should be 'circle'."""
plot = TimeSeriesPlot([1, 2, 3, 4, 5])
plot.add_time_series(
TimeSeriesData(
Expand All @@ -135,3 +127,60 @@ def test_missing_marker_symbol():
)
)
plot.write_html_plot(Path(TESTING_DIRECTORY, "missing_marker_symbol.html"))


def test_legend_group():
"""Test legend group and legend group title."""
plot = TimeSeriesPlot([1, 2, 3, 4, 5])
city_data = {
"City_A": {2000: [x**2 for x in plot.time_values], 2010: [x**3 for x in plot.time_values]},
"City_B": {2000: [x**2.5 for x in plot.time_values], 2010: [x**3.5 for x in plot.time_values]},
}
for city, year_data in city_data.items():
for year, data in year_data.items():
plot.add_time_series(
TimeSeriesData(
data,
name=city,
legend_group=str(year),
),
)
plot.write_html_plot(Path(TESTING_DIRECTORY, "legend_group.html"))


def test_is_visible():
"""Test visibility of lines in plot and legend."""
plot = TimeSeriesPlot([1, 2, 3, 4, 5])
plot.add_time_series(
TimeSeriesData(
[x**2 for x in plot.time_values],
line_properties=LineProperties(
color="blue", marker_size=5, marker_line_color="black", marker_fill_color="white", marker_line_width=1.5
),
is_visible=True,
name="Visible",
)
)
plot.add_time_series(
TimeSeriesData(
[x**3 for x in plot.time_values],
line_properties=LinesOnly(
color="green",
marker_size=5,
marker_line_color="black",
marker_fill_color="white",
),
is_visible=False,
name="Legend Only",
)
)
plot.write_html_plot(Path(TESTING_DIRECTORY, "is_visible.html"))


def test_get_axis_range():
checks = [([0, 2], [0, 2]), ([0, 23.5], [0, 25])]

for check in checks:
min_value = check[0][0]
max_value = check[0][1]
assert DimensionalAxis.get_axis_range(min_value, max_value) == check[1]

0 comments on commit 6e62beb

Please sign in to comment.