diff --git a/scripts/figures/plot_experiment_result_confusion.py b/scripts/figures/plot_experiment_result_confusion.py index 096c053..62d8667 100644 --- a/scripts/figures/plot_experiment_result_confusion.py +++ b/scripts/figures/plot_experiment_result_confusion.py @@ -10,6 +10,55 @@ PROJECT_ROOT_PATH = pathlib.Path(__file__).parent.parent.resolve().parent +## --------------------------------------- DEFINE PLOTS ---------------------------------------- ## + + +def get_probability_figure(part: str, policy: str, plot_type: str, + reconstructions: list[str]) -> list[tuple[plt.Figure, plt.Axes]]: + """ + Returns multiple matplotlib figures with different titles. + """ + plots = [] + for label in reconstructions: + fig, ax = plt.subplots(figsize=[5.5, 3.75], dpi=200) + + # fig_title = f"{label} Reconstruction of {part} using {policy}".replace("_", " ") + # fig.suptitle(fig_title) + + ax.set_title(f"Reconstruction {plot_type}") + # ax.set_xlabel("Views Added") + ax.set_ylim(0, 1.1) + ax.grid(axis='y', which='major', color='0.80') + ax.grid(axis='y', which='minor', color='0.95') + + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + ax.yaxis.set_ticks(list(np.linspace(0.2, 1, 5))) + ax.yaxis.set_minor_locator(AutoMinorLocator()) + + plots.append((fig, ax)) + return plots + + +def get_quantity_figure(part: str, policy: str, plot_type: str, + reconstructions: list[str]) -> list[tuple[plt.Figure, plt.Axes]]: + """ + """ + plots = [] + for label in reconstructions: + fig, ax = plt.subplots(figsize=[5.5, 3.75], dpi=200) + + # fig_title = f"{label} Reconstruction of {part} using {policy}".replace("_", " ") + # fig.suptitle(fig_title) + + ax.set_title(f"Reconstruction {plot_type}") + ax.grid(axis='y', which='major', color='0.80') + ax.grid(axis='y', which='minor', color='0.95') + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + + plots.append((fig, ax)) + return plots + + ## ------------------------------- CONFUSION MATRIX CALCULATIONS ------------------------------- ## # Index location in the extracted Nx4 matrix @@ -111,14 +160,45 @@ def arr_miss_rate(arr: np.ndarray): return miss_rate(arr[TP], arr[FN]) +def arr_true_positive(arr: np.ndarray): + """ + Returns the count of voxels labeled as true positive. + """ + return arr[TP] + +def arr_true_negative(arr: np.ndarray): + """ + Returns the count of voxels labeled as true negative. + """ + return arr[TN] + +def arr_false_positive(arr: np.ndarray): + """ + Returns the count of voxels labeled as false positive. + """ + return arr[FP] + +def arr_false_negative(arr: np.ndarray): + """ + Returns the count of voxels labeled as false negative. + """ + return arr[FN] + + + PLOT_OPTIONS = { - "accuracy" : arr_accuracy, - "precision" : arr_precision, - "sensitivity" : arr_sensitivity, - "specificity" : arr_specificity, - "balanced-accuracy" : arr_balanced_accuracy, - "fall-out" : arr_fall_out, - "miss-rate" : arr_miss_rate, + "accuracy" : (get_probability_figure, arr_accuracy), + "precision" : (get_probability_figure, arr_precision), + "sensitivity" : (get_probability_figure, arr_sensitivity), + "specificity" : (get_probability_figure, arr_specificity), + "balanced-accuracy" : (get_probability_figure, arr_balanced_accuracy), + "fall-out" : (get_probability_figure, arr_fall_out), + "miss-rate" : (get_probability_figure, arr_miss_rate), + + "true-positive" : (get_quantity_figure, arr_true_positive), + "true-negative" : (get_quantity_figure, arr_true_negative), + "false-positive" : (get_quantity_figure, arr_false_positive), + "false-negative" : (get_quantity_figure, arr_false_negative), } @@ -134,27 +214,6 @@ def get_metric_group(hdf5_path: pathlib.Path) -> h5py.Group: return h5_file["Metric"] -def get_figures(part: str, policy: str, plot_type: str, reconstructions: list[str]) -> list[tuple[plt.Figure, plt.Axes]]: - """ - Returns multiple matplotlib figures with different titles. - """ - plots = [] - for label in reconstructions: - fig, ax = plt.subplots(figsize=[5.5, 3.75], dpi=200) - - # fig_title = f"{label} Reconstruction of {part} using {policy}".replace("_", " ") - # fig.suptitle(fig_title) - - ax.set_title(f"Reconstruction {plot_type}") - # ax.set_xlabel("Views Added") - ax.set_ylim(0, 1.1) - ax.grid(axis='y', which='major', color='0.80') - ax.grid(axis='y', which='minor', color='0.95') - - plots.append((fig, ax)) - return plots - - def select_experiment(dir : pathlib.Path) -> pathlib.Path: """ Selects one from the possible experiment directories. @@ -185,7 +244,7 @@ def plot_policy_sweep(policy_path: pathlib.Path, figure_root: pathlib.Path, plot part_name = policy_path.parent.name.capitalize() policy_name = policy_path.name - plots = get_figures(part_name, policy_name, plot_key, confusion_grid_labels) + plots = PLOT_OPTIONS[plot_key][0](part_name, policy_name, plot_key, confusion_grid_labels) try: # Sort in increasing order on integer number. @@ -211,7 +270,7 @@ def plot_policy_sweep(policy_path: pathlib.Path, figure_root: pathlib.Path, plot for g, group in enumerate(confusion_groups): confusion_group = metric_group[group] data[g][:, :, i] = confusion_group.get("data")[:, 1:5] - result[g][1:, i] = np.apply_along_axis(PLOT_OPTIONS[plot_key], 1, data[g][:, :, i]) + result[g][1:, i] = np.apply_along_axis(PLOT_OPTIONS[plot_key][1], 1, data[g][:, :, i]) if len(reps_dirs) > 1: line_label += " (Average)" @@ -230,9 +289,7 @@ def plot_policy_sweep(policy_path: pathlib.Path, figure_root: pathlib.Path, plot for g in range(n_group): # plots[g][1].legend(loc='right') - plots[g][1].xaxis.set_major_locator(MaxNLocator(integer=True)) - plots[g][1].yaxis.set_ticks(list(np.linspace(0.2, 1, 5))) - plots[g][1].yaxis.set_minor_locator(AutoMinorLocator()) + # plots[g][1].tick_params(axis='y', which='minor', bottom=False) # plots[g][1].yaxis.set_minor_locator(MultipleLocator(3))