Skip to content

Commit

Permalink
Add options to plot the voxel label counts
Browse files Browse the repository at this point in the history
  • Loading branch information
Schellenberg3 committed Nov 18, 2023
1 parent 8c1570b commit 6477fac
Showing 1 changed file with 90 additions and 33 deletions.
123 changes: 90 additions & 33 deletions scripts/figures/plot_experiment_result_confusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,55 @@
PROJECT_ROOT_PATH = pathlib.Path(__file__).parent.parent.resolve().parent


## --------------------------------------- DEFINE PLOTS ---------------------------------------- ##


def get_probability_figure(part: str, policy: str, plot_type: str,
reconstructions: list[str]) -> list[tuple[plt.Figure, plt.Axes]]:
"""
Returns multiple matplotlib figures with different titles.
"""
plots = []
for label in reconstructions:
fig, ax = plt.subplots(figsize=[5.5, 3.75], dpi=200)

# fig_title = f"{label} Reconstruction of {part} using {policy}".replace("_", " ")
# fig.suptitle(fig_title)

ax.set_title(f"Reconstruction {plot_type}")
# ax.set_xlabel("Views Added")
ax.set_ylim(0, 1.1)
ax.grid(axis='y', which='major', color='0.80')
ax.grid(axis='y', which='minor', color='0.95')

ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.yaxis.set_ticks(list(np.linspace(0.2, 1, 5)))
ax.yaxis.set_minor_locator(AutoMinorLocator())

plots.append((fig, ax))
return plots


def get_quantity_figure(part: str, policy: str, plot_type: str,
reconstructions: list[str]) -> list[tuple[plt.Figure, plt.Axes]]:
"""
"""
plots = []
for label in reconstructions:
fig, ax = plt.subplots(figsize=[5.5, 3.75], dpi=200)

# fig_title = f"{label} Reconstruction of {part} using {policy}".replace("_", " ")
# fig.suptitle(fig_title)

ax.set_title(f"Reconstruction {plot_type}")
ax.grid(axis='y', which='major', color='0.80')
ax.grid(axis='y', which='minor', color='0.95')
ax.xaxis.set_major_locator(MaxNLocator(integer=True))

plots.append((fig, ax))
return plots


## ------------------------------- CONFUSION MATRIX CALCULATIONS ------------------------------- ##

# Index location in the extracted Nx4 matrix
Expand Down Expand Up @@ -111,14 +160,45 @@ def arr_miss_rate(arr: np.ndarray):
return miss_rate(arr[TP], arr[FN])


def arr_true_positive(arr: np.ndarray):
"""
Returns the count of voxels labeled as true positive.
"""
return arr[TP]

def arr_true_negative(arr: np.ndarray):
"""
Returns the count of voxels labeled as true negative.
"""
return arr[TN]

def arr_false_positive(arr: np.ndarray):
"""
Returns the count of voxels labeled as false positive.
"""
return arr[FP]

def arr_false_negative(arr: np.ndarray):
"""
Returns the count of voxels labeled as false negative.
"""
return arr[FN]



PLOT_OPTIONS = {
"accuracy" : arr_accuracy,
"precision" : arr_precision,
"sensitivity" : arr_sensitivity,
"specificity" : arr_specificity,
"balanced-accuracy" : arr_balanced_accuracy,
"fall-out" : arr_fall_out,
"miss-rate" : arr_miss_rate,
"accuracy" : (get_probability_figure, arr_accuracy),
"precision" : (get_probability_figure, arr_precision),
"sensitivity" : (get_probability_figure, arr_sensitivity),
"specificity" : (get_probability_figure, arr_specificity),
"balanced-accuracy" : (get_probability_figure, arr_balanced_accuracy),
"fall-out" : (get_probability_figure, arr_fall_out),
"miss-rate" : (get_probability_figure, arr_miss_rate),

"true-positive" : (get_quantity_figure, arr_true_positive),
"true-negative" : (get_quantity_figure, arr_true_negative),
"false-positive" : (get_quantity_figure, arr_false_positive),
"false-negative" : (get_quantity_figure, arr_false_negative),
}


Expand All @@ -134,27 +214,6 @@ def get_metric_group(hdf5_path: pathlib.Path) -> h5py.Group:
return h5_file["Metric"]


def get_figures(part: str, policy: str, plot_type: str, reconstructions: list[str]) -> list[tuple[plt.Figure, plt.Axes]]:
"""
Returns multiple matplotlib figures with different titles.
"""
plots = []
for label in reconstructions:
fig, ax = plt.subplots(figsize=[5.5, 3.75], dpi=200)

# fig_title = f"{label} Reconstruction of {part} using {policy}".replace("_", " ")
# fig.suptitle(fig_title)

ax.set_title(f"Reconstruction {plot_type}")
# ax.set_xlabel("Views Added")
ax.set_ylim(0, 1.1)
ax.grid(axis='y', which='major', color='0.80')
ax.grid(axis='y', which='minor', color='0.95')

plots.append((fig, ax))
return plots


def select_experiment(dir : pathlib.Path) -> pathlib.Path:
"""
Selects one from the possible experiment directories.
Expand Down Expand Up @@ -185,7 +244,7 @@ def plot_policy_sweep(policy_path: pathlib.Path, figure_root: pathlib.Path, plot

part_name = policy_path.parent.name.capitalize()
policy_name = policy_path.name
plots = get_figures(part_name, policy_name, plot_key, confusion_grid_labels)
plots = PLOT_OPTIONS[plot_key][0](part_name, policy_name, plot_key, confusion_grid_labels)

try:
# Sort in increasing order on integer number.
Expand All @@ -211,7 +270,7 @@ def plot_policy_sweep(policy_path: pathlib.Path, figure_root: pathlib.Path, plot
for g, group in enumerate(confusion_groups):
confusion_group = metric_group[group]
data[g][:, :, i] = confusion_group.get("data")[:, 1:5]
result[g][1:, i] = np.apply_along_axis(PLOT_OPTIONS[plot_key], 1, data[g][:, :, i])
result[g][1:, i] = np.apply_along_axis(PLOT_OPTIONS[plot_key][1], 1, data[g][:, :, i])

if len(reps_dirs) > 1:
line_label += " (Average)"
Expand All @@ -230,9 +289,7 @@ def plot_policy_sweep(policy_path: pathlib.Path, figure_root: pathlib.Path, plot

for g in range(n_group):
# plots[g][1].legend(loc='right')
plots[g][1].xaxis.set_major_locator(MaxNLocator(integer=True))
plots[g][1].yaxis.set_ticks(list(np.linspace(0.2, 1, 5)))
plots[g][1].yaxis.set_minor_locator(AutoMinorLocator())

# plots[g][1].tick_params(axis='y', which='minor', bottom=False)
# plots[g][1].yaxis.set_minor_locator(MultipleLocator(3))

Expand Down

0 comments on commit 6477fac

Please sign in to comment.