-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add drawbase & MAE plotting palettes, main scripts for plots drawing …
…(draw_all_mae, draw_all_time), pysindy time stats script
- Loading branch information
Showing
12 changed files
with
620 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from drawbase.plot_mae import plot_mae | ||
|
||
''' | ||
The variable DATA has the following possible states: | ||
wave; | ||
burgers; | ||
kdv; | ||
kdv_sindy; | ||
burgers_sindy; | ||
which corresponds to 5 input datasets and equations | ||
''' | ||
|
||
if __name__ == '__main__': | ||
DATA = "wave" | ||
|
||
if DATA == "wave": | ||
path = 'data_wave/' | ||
names = ["0", "2e-5", "2.5e-5", "3e-5", "3.2e-5", "3.47e-5"] | ||
core_values = [0.0087, 0.0094, 0.0449, 0.0453] | ||
core_colors = ["#385623", "#89BF65", "#C5E0B3", "#E2EFD9"] | ||
log_transform = False | ||
decimals = 4 | ||
n_df = 2 | ||
elif DATA == "burgers": | ||
path = 'data_burg/' | ||
names = ["0", "1e-5", "1.5e-5", "2e-5", "2.5e-5", "3e-5", "3.67e-5"] | ||
core_values = [0.0006, 0.0008, 0.001, 0.0014, 0.0132, 0.0137, 0.014, 0.0146] | ||
core_colors = ["#385623", "#538135", "#669D41", "#71AE48", "#89BF65", "#A8D08D", "#C5E0B3", "#E2EFD9"] | ||
log_transform = False | ||
decimals = 4 | ||
n_df = 2 | ||
elif DATA == "burgers_sindy": | ||
path = 'data_burg_sindy/' | ||
names = ["0", "0.001", "0.005", "0.01", "0.02", "0.03"] | ||
core_values = [8.0e-05, 2.0e-03, 6.0e-03, 1.3e-02, 2.1e-02, 1.0e-01, 2.0e-01] | ||
core_colors = ["#385623", "#43682A", "#538135", "#71AE48", "#C5E0B3", "#E2EFD9", "#F4F9F1"] | ||
log_transform = False | ||
decimals = [3, 1] | ||
n_df = 3 | ||
elif DATA == "kdv_sindy": | ||
path = 'data_kdv_sindy/' | ||
names = ["0", "1e-5", "3.5e-5", "5.5e-5", "8e-5", "0.0001", "2.26e-4"] | ||
core_values = [3.0e-03, 8.0e-03, 1.0e-02, 1.4e-02, 1.9e-02, 2.1e-02, 4.4e-02, 1.6, 3.2] | ||
core_colors = ["#385623", "#43682A", "#5A8B39", "#669D41", "#89BF65", "#A8D08D", "#C5E0B3", "#E2EFD9", "#FDFEFC"] | ||
log_transform = True | ||
decimals = [3, 1] | ||
n_df = 3 | ||
elif DATA == "kdv": | ||
path = 'data_kdv/' | ||
names = ["0", "0.001", "0.01", "0.07", "0.08", "0.09", "0.092"] | ||
core_values = [1.2e-05, 1.0e-04, 1.0e-03, 1.4e-03] | ||
core_colors = ["#385623", "#669D41", "#A8D08D", "#E2EFD9"] | ||
log_transform = False | ||
decimals = 4 | ||
n_df = 2 | ||
else: | ||
raise NameError('Unknown equation type') | ||
plot_mae(path, names, core_values, core_colors, decimals=decimals, n_df=n_df, log_transform=log_transform) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import seaborn as sns | ||
from matplotlib import rcParams | ||
from drawbase.plot_time import plot_time | ||
rcParams.update({'figure.autolayout': True}) | ||
sns.set(style="whitegrid", color_codes=True) | ||
|
||
''' | ||
The variable DATA has the following possible states: | ||
wave; | ||
burgers; | ||
kdv; | ||
kdv_sindy; | ||
burgers_sindy; | ||
which corresponds to 5 input datasets and equations | ||
''' | ||
if __name__ == '__main__': | ||
DATA = "wave" | ||
|
||
n_df = 2 | ||
if DATA == "wave": | ||
path = 'data_wave/' | ||
names = ["0", "2e-5", "2.5e-5", "3e-5", "3.2e-5", "3.47e-5"] | ||
elif DATA == "burgers": | ||
path = 'data_burg/' | ||
names = ["0", "1e-5", "1.5e-5", "2e-5", "2.5e-5", "3e-5", "3.67e-5"] | ||
elif DATA == "burgers_sindy": | ||
path = 'data_burg_sindy/' | ||
names = ["0", "0.001", "0.005", "0.01", "0.02", "0.03"] | ||
elif DATA == "kdv_sindy": | ||
path = 'data_kdv_sindy/' | ||
names = ["0", "1e-5", "3.5e-5", "5.5e-5", "8e-5", "0.0001", "2.26e-4"] | ||
elif DATA == "kdv": | ||
path = 'data_kdv/' | ||
names = ["0", "0.001", "0.01", "0.07", "0.08", "0.09", "0.092"] | ||
else: | ||
raise NameError('Unknown equation type') | ||
plot_time(path, names, n_df) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import matplotlib.pyplot as plt | ||
import matplotlib.colors | ||
import matplotlib as mpl | ||
import seaborn as sns | ||
import numpy as np | ||
from drawbase.read_compile_df import read_compile_mae_df | ||
from drawbase.preprocess_df import melt_count_mae | ||
sns.set(style="whitegrid", color_codes=True) | ||
|
||
|
||
def plot_mae(path, names, core_values, core_colors, decimals=4, n_df=2, log_transform=False): | ||
count_ls = read_compile_mae_df(path, names, decimals=decimals, n_df=n_df) | ||
dfall = melt_count_mae(count_ls) | ||
|
||
idxs, color_keys = [], [] | ||
cmap, norm, color_dict, mae = _create_cmap(core_values, core_colors, count_ls, log_transform) | ||
for i, g in enumerate(dfall.groupby("variable")): | ||
ax = sns.barplot(data=g[1], | ||
x="index", | ||
y="vcs", | ||
hue="Name", | ||
zorder=-i, | ||
edgecolor="k") | ||
ax.set_axisbelow(True) | ||
color_keys += [g[0]] * len(g[1]) | ||
idxs += list(g[1].Name.values) | ||
|
||
ax.legend_.remove() | ||
plt.grid(False) | ||
# hatches = ['-', '+', 'x', '\\', '*', 'o'] | ||
for j, thisbar in enumerate(ax.patches): | ||
thisbar._facecolor = tuple(color_dict.get(color_keys[j])) | ||
if idxs[j] == "Modified": | ||
thisbar.set_hatch("\\") | ||
elif idxs[j] == "Pysindy": | ||
thisbar.set_hatch("-") | ||
|
||
keys_sm = list(color_dict.keys())[:len(color_dict) - 1] # listed | ||
vals_sm = list(color_dict.values())[:len(color_dict) - 1] # colors | ||
cmap_sm = mpl.colors.ListedColormap(vals_sm) | ||
bounds = keys_sm.copy() | ||
bounds.append(keys_sm[-1] + keys_sm[-1] - keys_sm[-2]) | ||
norm_sm = mpl.colors.BoundaryNorm(bounds, cmap_sm.N) | ||
|
||
sm = plt.cm.ScalarMappable(cmap=cmap_sm, norm=norm_sm) | ||
sm.set_array([]) | ||
cbar = plt.colorbar(sm, ticks=keys_sm, boundaries=bounds) | ||
|
||
locations = [] | ||
for i in range(0, len(bounds) - 1): | ||
locations.append((bounds[i] + bounds[i + 1]) / 2) | ||
cbar.set_ticks(locations) | ||
cbar.ax.set_yticklabels(keys_sm, fontsize=24) | ||
|
||
n = [] | ||
for i in range(len(dfall.Name.unique())): | ||
if i == 2: | ||
n.append(ax.bar(0, 0, color="#BFBFBF", hatch="-", edgecolor="k")) | ||
elif i == 1: | ||
n.append(ax.bar(0, 0, color="#BFBFBF", hatch="\\", edgecolor="k")) | ||
else: | ||
n.append(ax.bar(0, 0, color="#BFBFBF", edgecolor="k")) | ||
|
||
plt.legend(n, dfall.Name.unique(), loc=[0.83, 0.84]) | ||
plt.setp(ax.get_legend().get_texts(), fontsize='24') | ||
ax.set_ylabel("No. of runs", fontsize=24) | ||
ax.set_xlabel("Magnitude", fontsize=24) | ||
plt.xticks(fontsize=24, rotation=0) | ||
plt.yticks(fontsize=24) | ||
plt.show() | ||
|
||
|
||
def _create_cmap(core_values, core_colors, count_ls, log_transform, plot_map=False): | ||
listed = list(count_ls[0].columns) | ||
listed.pop() | ||
listed.pop() | ||
if log_transform: | ||
cat_lg = np.log(listed) | ||
core_values_lg = np.log(core_values) | ||
norm = plt.Normalize(min(cat_lg), max(cat_lg)) | ||
tuples = list(zip(map(norm, core_values_lg), core_colors)) | ||
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples) | ||
colors = cmap(norm(cat_lg)) | ||
col_dict = dict(zip(listed, colors)) | ||
else: | ||
norm = plt.Normalize(min(listed), max(listed)) | ||
tuples = list(zip(map(norm, core_values), core_colors)) | ||
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples) | ||
colors = cmap(norm(listed)) | ||
col_dict = dict(zip(listed, colors)) | ||
col_dict["n/a"] = np.array([191. / 255, 191. / 255, 191. / 255, 1.0]) | ||
|
||
if plot_map: | ||
x = np.linspace(0, 6, len(listed)) | ||
y = np.zeros((len(listed),)) | ||
plt.scatter(x, y, c=listed, cmap=cmap, norm=norm, s=90) | ||
plt.show() | ||
return cmap, norm, col_dict, listed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import seaborn as sns | ||
import matplotlib.pyplot as plt | ||
from drawbase.preprocess_df import melt_count_time | ||
sns.set(style="whitegrid", color_codes=True) | ||
|
||
|
||
def plot_time(path, names, n_df): | ||
dfall = melt_count_time(path, names, n_df) | ||
lbl_y = "s" | ||
if dfall.time.max() > 150: | ||
lbl_y = "m" | ||
|
||
ax = sns.boxplot(x=dfall['Magnitude'], | ||
y=dfall['time'], | ||
hue=dfall['Algorithm'], | ||
showfliers=False) | ||
plt.xticks(fontsize=24, rotation=0) | ||
plt.yticks(fontsize=24) | ||
plt.setp(ax.get_legend().get_texts(), fontsize='24') | ||
plt.setp(ax.get_legend().get_title(), fontsize='24') | ||
ax.set_ylabel(f"Time, {lbl_y}", fontsize=24) | ||
ax.set_xlabel("Magnitude", fontsize=24) | ||
plt.show() | ||
|
||
|
||
def plot_sindy(path, names, n_df): | ||
dfall = melt_count_time(path, names, n_df) | ||
|
||
ratio_k = 4 | ||
fig, (ax1, ax2) = plt.subplots(figsize=(16, 8), ncols=1, nrows=2, sharex=True, | ||
gridspec_kw={'hspace': 0.07, 'height_ratios': [ratio_k, 1]}) | ||
sns.boxplot(x=dfall['Magnitude'], | ||
y=dfall['time'], | ||
hue=dfall['Algorithm'], | ||
orient="v", | ||
showfliers=False, | ||
ax=ax1) | ||
sns.boxplot(x=dfall['Magnitude'], | ||
y=dfall['time'], | ||
hue=dfall['Algorithm'], | ||
orient="v", | ||
showfliers=False, | ||
ax=ax2) | ||
ax1.set_ylim(24.2, 31.7) | ||
ax2.set_ylim(0.022, 0.028) | ||
|
||
d = .005 | ||
kwargs = dict(transform=ax1.transAxes, color="k", clip_on=False) | ||
ax1.plot((-d, +d), (-d, +d), **kwargs) | ||
ax1.plot((1 - d, 1 + d), (-d, +d), **kwargs) | ||
kwargs.update(transform=ax2.transAxes) | ||
ax2.plot((-d, +d), (1 - ratio_k * d, 1 + ratio_k * d), **kwargs) | ||
ax2.plot((1 - d, 1 + d), (1 - ratio_k * d, 1 + ratio_k * d), **kwargs) | ||
|
||
ax1.tick_params(axis='both', which='major', labelsize=24) | ||
ax1.tick_params(axis='both', which='minor', labelsize=24) | ||
|
||
ax2.tick_params(axis='both', which='major', labelsize=24) | ||
ax2.tick_params(axis='both', which='minor', labelsize=24) | ||
ax2.tick_params(axis='x', labelrotation=25, bottom=True) | ||
|
||
ax1.grid() | ||
ax2.grid() | ||
fig.text(0.065, 0.5, "time, s", va="center", rotation="vertical", fontsize=24) | ||
|
||
ax1.xaxis.tick_bottom() | ||
plt.subplots_adjust(bottom=0.189, top=1.) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import pandas as pd | ||
import numpy as np | ||
from drawbase.read_compile_df import read_csv | ||
|
||
|
||
def melt_count_mae(count_ls): | ||
|
||
count_ls[0]["Name"] = "Classical" | ||
count_ls[1]["Name"] = "Modified" | ||
if len(count_ls) == 3: | ||
count_ls[2]["Name"] = "Pysindy" | ||
|
||
dfall = pd.concat([pd.melt(i.reset_index(), | ||
id_vars=["Name", "index"]) | ||
for i in count_ls], | ||
ignore_index=True) | ||
dfall.set_index(["Name", "index", "variable"], inplace=True) | ||
dfall["vcs"] = dfall.groupby(level=["Name", "index"]).cumsum() | ||
dfall.reset_index(inplace=True) | ||
|
||
for i in range(len(dfall)): | ||
if np.isnan(dfall.loc[i, "variable"]): | ||
dfall.loc[i, "variable"] = "n/a" | ||
return dfall | ||
|
||
|
||
def melt_count_time(path, names, n_df): | ||
df_lsls = read_csv(path, names, n_df) | ||
for df in df_lsls[0]: | ||
df["Algorithm"] = "Classical" | ||
for df in df_lsls[1]: | ||
df["Algorithm"] = "Modified" | ||
if n_df == 3: | ||
for df in df_lsls[2]: | ||
df["Algorithm"] = "Pysindy" | ||
|
||
subdf_lsls = [] | ||
for i in range(len(df_lsls)): | ||
for j in range(len(names)): | ||
temp_df = df_lsls[i][j][["time", "Algorithm"]] | ||
temp_df["Magnitude"] = names[j] | ||
subdf_lsls.append(temp_df) | ||
dfall = pd.concat(subdf_lsls, ignore_index=True) | ||
return dfall |
Oops, something went wrong.