Skip to content

Commit

Permalink
Add drawbase & MAE plotting palettes, main scripts for plots drawing …
Browse files Browse the repository at this point in the history
…(draw_all_mae, draw_all_time), pysindy time stats script
  • Loading branch information
LisIva committed Dec 8, 2023
1 parent bb83610 commit b1f5044
Show file tree
Hide file tree
Showing 12 changed files with 620 additions and 0 deletions.
60 changes: 60 additions & 0 deletions draw_all_mae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from drawbase.plot_mae import plot_mae

'''
The variable DATA has the following possible states:
wave;
burgers;
kdv;
kdv_sindy;
burgers_sindy;
which corresponds to 5 input datasets and equations
'''

if __name__ == '__main__':
DATA = "wave"

if DATA == "wave":
path = 'data_wave/'
names = ["0", "2e-5", "2.5e-5", "3e-5", "3.2e-5", "3.47e-5"]
core_values = [0.0087, 0.0094, 0.0449, 0.0453]
core_colors = ["#385623", "#89BF65", "#C5E0B3", "#E2EFD9"]
log_transform = False
decimals = 4
n_df = 2
elif DATA == "burgers":
path = 'data_burg/'
names = ["0", "1e-5", "1.5e-5", "2e-5", "2.5e-5", "3e-5", "3.67e-5"]
core_values = [0.0006, 0.0008, 0.001, 0.0014, 0.0132, 0.0137, 0.014, 0.0146]
core_colors = ["#385623", "#538135", "#669D41", "#71AE48", "#89BF65", "#A8D08D", "#C5E0B3", "#E2EFD9"]
log_transform = False
decimals = 4
n_df = 2
elif DATA == "burgers_sindy":
path = 'data_burg_sindy/'
names = ["0", "0.001", "0.005", "0.01", "0.02", "0.03"]
core_values = [8.0e-05, 2.0e-03, 6.0e-03, 1.3e-02, 2.1e-02, 1.0e-01, 2.0e-01]
core_colors = ["#385623", "#43682A", "#538135", "#71AE48", "#C5E0B3", "#E2EFD9", "#F4F9F1"]
log_transform = False
decimals = [3, 1]
n_df = 3
elif DATA == "kdv_sindy":
path = 'data_kdv_sindy/'
names = ["0", "1e-5", "3.5e-5", "5.5e-5", "8e-5", "0.0001", "2.26e-4"]
core_values = [3.0e-03, 8.0e-03, 1.0e-02, 1.4e-02, 1.9e-02, 2.1e-02, 4.4e-02, 1.6, 3.2]
core_colors = ["#385623", "#43682A", "#5A8B39", "#669D41", "#89BF65", "#A8D08D", "#C5E0B3", "#E2EFD9", "#FDFEFC"]
log_transform = True
decimals = [3, 1]
n_df = 3
elif DATA == "kdv":
path = 'data_kdv/'
names = ["0", "0.001", "0.01", "0.07", "0.08", "0.09", "0.092"]
core_values = [1.2e-05, 1.0e-04, 1.0e-03, 1.4e-03]
core_colors = ["#385623", "#669D41", "#A8D08D", "#E2EFD9"]
log_transform = False
decimals = 4
n_df = 2
else:
raise NameError('Unknown equation type')
plot_mae(path, names, core_values, core_colors, decimals=decimals, n_df=n_df, log_transform=log_transform)
39 changes: 39 additions & 0 deletions draw_all_time.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import seaborn as sns
from matplotlib import rcParams
from drawbase.plot_time import plot_time
rcParams.update({'figure.autolayout': True})
sns.set(style="whitegrid", color_codes=True)

'''
The variable DATA has the following possible states:
wave;
burgers;
kdv;
kdv_sindy;
burgers_sindy;
which corresponds to 5 input datasets and equations
'''
if __name__ == '__main__':
DATA = "wave"

n_df = 2
if DATA == "wave":
path = 'data_wave/'
names = ["0", "2e-5", "2.5e-5", "3e-5", "3.2e-5", "3.47e-5"]
elif DATA == "burgers":
path = 'data_burg/'
names = ["0", "1e-5", "1.5e-5", "2e-5", "2.5e-5", "3e-5", "3.67e-5"]
elif DATA == "burgers_sindy":
path = 'data_burg_sindy/'
names = ["0", "0.001", "0.005", "0.01", "0.02", "0.03"]
elif DATA == "kdv_sindy":
path = 'data_kdv_sindy/'
names = ["0", "1e-5", "3.5e-5", "5.5e-5", "8e-5", "0.0001", "2.26e-4"]
elif DATA == "kdv":
path = 'data_kdv/'
names = ["0", "0.001", "0.01", "0.07", "0.08", "0.09", "0.092"]
else:
raise NameError('Unknown equation type')
plot_time(path, names, n_df)
98 changes: 98 additions & 0 deletions drawbase/plot_mae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import matplotlib.pyplot as plt
import matplotlib.colors
import matplotlib as mpl
import seaborn as sns
import numpy as np
from drawbase.read_compile_df import read_compile_mae_df
from drawbase.preprocess_df import melt_count_mae
sns.set(style="whitegrid", color_codes=True)


def plot_mae(path, names, core_values, core_colors, decimals=4, n_df=2, log_transform=False):
count_ls = read_compile_mae_df(path, names, decimals=decimals, n_df=n_df)
dfall = melt_count_mae(count_ls)

idxs, color_keys = [], []
cmap, norm, color_dict, mae = _create_cmap(core_values, core_colors, count_ls, log_transform)
for i, g in enumerate(dfall.groupby("variable")):
ax = sns.barplot(data=g[1],
x="index",
y="vcs",
hue="Name",
zorder=-i,
edgecolor="k")
ax.set_axisbelow(True)
color_keys += [g[0]] * len(g[1])
idxs += list(g[1].Name.values)

ax.legend_.remove()
plt.grid(False)
# hatches = ['-', '+', 'x', '\\', '*', 'o']
for j, thisbar in enumerate(ax.patches):
thisbar._facecolor = tuple(color_dict.get(color_keys[j]))
if idxs[j] == "Modified":
thisbar.set_hatch("\\")
elif idxs[j] == "Pysindy":
thisbar.set_hatch("-")

keys_sm = list(color_dict.keys())[:len(color_dict) - 1] # listed
vals_sm = list(color_dict.values())[:len(color_dict) - 1] # colors
cmap_sm = mpl.colors.ListedColormap(vals_sm)
bounds = keys_sm.copy()
bounds.append(keys_sm[-1] + keys_sm[-1] - keys_sm[-2])
norm_sm = mpl.colors.BoundaryNorm(bounds, cmap_sm.N)

sm = plt.cm.ScalarMappable(cmap=cmap_sm, norm=norm_sm)
sm.set_array([])
cbar = plt.colorbar(sm, ticks=keys_sm, boundaries=bounds)

locations = []
for i in range(0, len(bounds) - 1):
locations.append((bounds[i] + bounds[i + 1]) / 2)
cbar.set_ticks(locations)
cbar.ax.set_yticklabels(keys_sm, fontsize=24)

n = []
for i in range(len(dfall.Name.unique())):
if i == 2:
n.append(ax.bar(0, 0, color="#BFBFBF", hatch="-", edgecolor="k"))
elif i == 1:
n.append(ax.bar(0, 0, color="#BFBFBF", hatch="\\", edgecolor="k"))
else:
n.append(ax.bar(0, 0, color="#BFBFBF", edgecolor="k"))

plt.legend(n, dfall.Name.unique(), loc=[0.83, 0.84])
plt.setp(ax.get_legend().get_texts(), fontsize='24')
ax.set_ylabel("No. of runs", fontsize=24)
ax.set_xlabel("Magnitude", fontsize=24)
plt.xticks(fontsize=24, rotation=0)
plt.yticks(fontsize=24)
plt.show()


def _create_cmap(core_values, core_colors, count_ls, log_transform, plot_map=False):
listed = list(count_ls[0].columns)
listed.pop()
listed.pop()
if log_transform:
cat_lg = np.log(listed)
core_values_lg = np.log(core_values)
norm = plt.Normalize(min(cat_lg), max(cat_lg))
tuples = list(zip(map(norm, core_values_lg), core_colors))
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples)
colors = cmap(norm(cat_lg))
col_dict = dict(zip(listed, colors))
else:
norm = plt.Normalize(min(listed), max(listed))
tuples = list(zip(map(norm, core_values), core_colors))
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples)
colors = cmap(norm(listed))
col_dict = dict(zip(listed, colors))
col_dict["n/a"] = np.array([191. / 255, 191. / 255, 191. / 255, 1.0])

if plot_map:
x = np.linspace(0, 6, len(listed))
y = np.zeros((len(listed),))
plt.scatter(x, y, c=listed, cmap=cmap, norm=norm, s=90)
plt.show()
return cmap, norm, col_dict, listed
68 changes: 68 additions & 0 deletions drawbase/plot_time.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import seaborn as sns
import matplotlib.pyplot as plt
from drawbase.preprocess_df import melt_count_time
sns.set(style="whitegrid", color_codes=True)


def plot_time(path, names, n_df):
dfall = melt_count_time(path, names, n_df)
lbl_y = "s"
if dfall.time.max() > 150:
lbl_y = "m"

ax = sns.boxplot(x=dfall['Magnitude'],
y=dfall['time'],
hue=dfall['Algorithm'],
showfliers=False)
plt.xticks(fontsize=24, rotation=0)
plt.yticks(fontsize=24)
plt.setp(ax.get_legend().get_texts(), fontsize='24')
plt.setp(ax.get_legend().get_title(), fontsize='24')
ax.set_ylabel(f"Time, {lbl_y}", fontsize=24)
ax.set_xlabel("Magnitude", fontsize=24)
plt.show()


def plot_sindy(path, names, n_df):
dfall = melt_count_time(path, names, n_df)

ratio_k = 4
fig, (ax1, ax2) = plt.subplots(figsize=(16, 8), ncols=1, nrows=2, sharex=True,
gridspec_kw={'hspace': 0.07, 'height_ratios': [ratio_k, 1]})
sns.boxplot(x=dfall['Magnitude'],
y=dfall['time'],
hue=dfall['Algorithm'],
orient="v",
showfliers=False,
ax=ax1)
sns.boxplot(x=dfall['Magnitude'],
y=dfall['time'],
hue=dfall['Algorithm'],
orient="v",
showfliers=False,
ax=ax2)
ax1.set_ylim(24.2, 31.7)
ax2.set_ylim(0.022, 0.028)

d = .005
kwargs = dict(transform=ax1.transAxes, color="k", clip_on=False)
ax1.plot((-d, +d), (-d, +d), **kwargs)
ax1.plot((1 - d, 1 + d), (-d, +d), **kwargs)
kwargs.update(transform=ax2.transAxes)
ax2.plot((-d, +d), (1 - ratio_k * d, 1 + ratio_k * d), **kwargs)
ax2.plot((1 - d, 1 + d), (1 - ratio_k * d, 1 + ratio_k * d), **kwargs)

ax1.tick_params(axis='both', which='major', labelsize=24)
ax1.tick_params(axis='both', which='minor', labelsize=24)

ax2.tick_params(axis='both', which='major', labelsize=24)
ax2.tick_params(axis='both', which='minor', labelsize=24)
ax2.tick_params(axis='x', labelrotation=25, bottom=True)

ax1.grid()
ax2.grid()
fig.text(0.065, 0.5, "time, s", va="center", rotation="vertical", fontsize=24)

ax1.xaxis.tick_bottom()
plt.subplots_adjust(bottom=0.189, top=1.)
plt.show()
44 changes: 44 additions & 0 deletions drawbase/preprocess_df.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pandas as pd
import numpy as np
from drawbase.read_compile_df import read_csv


def melt_count_mae(count_ls):

count_ls[0]["Name"] = "Classical"
count_ls[1]["Name"] = "Modified"
if len(count_ls) == 3:
count_ls[2]["Name"] = "Pysindy"

dfall = pd.concat([pd.melt(i.reset_index(),
id_vars=["Name", "index"])
for i in count_ls],
ignore_index=True)
dfall.set_index(["Name", "index", "variable"], inplace=True)
dfall["vcs"] = dfall.groupby(level=["Name", "index"]).cumsum()
dfall.reset_index(inplace=True)

for i in range(len(dfall)):
if np.isnan(dfall.loc[i, "variable"]):
dfall.loc[i, "variable"] = "n/a"
return dfall


def melt_count_time(path, names, n_df):
df_lsls = read_csv(path, names, n_df)
for df in df_lsls[0]:
df["Algorithm"] = "Classical"
for df in df_lsls[1]:
df["Algorithm"] = "Modified"
if n_df == 3:
for df in df_lsls[2]:
df["Algorithm"] = "Pysindy"

subdf_lsls = []
for i in range(len(df_lsls)):
for j in range(len(names)):
temp_df = df_lsls[i][j][["time", "Algorithm"]]
temp_df["Magnitude"] = names[j]
subdf_lsls.append(temp_df)
dfall = pd.concat(subdf_lsls, ignore_index=True)
return dfall
Loading

0 comments on commit b1f5044

Please sign in to comment.