diff --git a/model_evaluation.py b/model_evaluation.py index c3b8513..462e443 100644 --- a/model_evaluation.py +++ b/model_evaluation.py @@ -120,8 +120,7 @@ def main(args, rel_path_to_this_script_dir): reporter = Reporter(output_filepath_template, parsed_args.whole_country) return reporter.report(predictions_list, metrics_list , horizons_list=parsed_args.horizons_list - , date_selector=parsed_args.date_selector - , compare_diff_with_actual=parsed_args.compare_diff_with_actual) + , date_selector=parsed_args.date_selector) if __name__ == '__main__': diff --git a/reporter.py b/reporter.py index 4b21869..ae79918 100644 --- a/reporter.py +++ b/reporter.py @@ -439,7 +439,8 @@ def generate_forecasting_plot(self, dap, save_path_template axe_x = non_nan_plot_info.columns.get_level_values(axis_x) + sdate axe_y = non_nan_plot_info.values.reshape(-1) color = COLORS[(line_n - 1) % len(COLORS)] - axe.plot(axe_x, axe_y, '-' + PLOT_MARKERS[2 * (line_n - 1) % len(PLOT_MARKERS)] + axe.plot(axe_x, axe_y + , '-' + PLOT_MARKERS[2 * (line_n - 1) % len(PLOT_MARKERS)] , label=line_label, c=color) axe.axvline(x=sdate.to_pydatetime(), ymin=0.0, ymax=1.0 @@ -469,7 +470,7 @@ def generate_compairing_plot(dap, save_path_template , line_labels_basis=None , line_filter=None , horizons_list=None - , compare_diff_with_actual=None): + , compare_with_actual_by=None): if file_splitter is None: file_splitter = 'PredType' if axis_x is None: @@ -492,7 +493,7 @@ def generate_compairing_plot(dap, save_path_template new_columns = info.columns.values int_1 = lambda x: (x[0], int(x[1]), x[2]) new_columns = [int_1(col.split('_')) if '_' in col else (col, 0, 'Actual') - for col in new_columns] + for col in new_columns] new_columns = pd.MultiIndex.from_tuples(new_columns , names=['PredType', 'Horizon', 'Model']) @@ -502,7 +503,7 @@ def generate_compairing_plot(dap, save_path_template if not(horizons_list is None): info = info.loc[:, pd.IndexSlice[:, horizons_list, :]] - + for file_name, file_info in info.groupby(level=file_splitter, axis=1): fig, axes = plt.subplots( file_info.columns.get_level_values(axis_y_splitter).nunique() @@ -524,31 +525,29 @@ def generate_compairing_plot(dap, save_path_template )): if isinstance(line_label, tuple): line_label = '_'.join(map(str, line_label)) - + color = COLORS[(line_n - 1) % len(COLORS)] plot_info = plot_info[plot_info.columns[0]] mask = plot_info.notna() dates_mask |= mask if mask.sum().sum() == 0: - #line_label += ' (No data)' + line_label += ' (No data)' continue non_nan_plot_info = plot_info[mask] axe_x = non_nan_plot_info.index.get_level_values(axis_x) axe_y = non_nan_plot_info.values.reshape(-1) - if compare_diff_with_actual: - # print(location, file_name, axis_y_name, line_label, compare_diff_with_actual.SHORT_TRANSFORM_NAME) - axe_y = compare_diff_with_actual.transform( + if compare_with_actual_by: + axe_y = compare_with_actual_by.transform( data.loc[non_nan_plot_info.index , pd.IndexSlice[file_name, 0, 'Actual']].values , axe_y ) - #axe_y = axe_y - data.loc[non_nan_plot_info.index - # , pd.IndexSlice[file_name, 0, 'Actual']] - axe.plot(axe_x, axe_y, '-' + PLOT_MARKERS[2 * ((line_n - 1) // 3) % len(PLOT_MARKERS)] - , label=line_label, c=color) + axe.plot(axe_x, axe_y + , '-' + PLOT_MARKERS[2 * ((line_n - 1) // 3) % len(PLOT_MARKERS)] + , label=line_label, c=color) - if not(compare_diff_with_actual): + if not(compare_with_actual_by): sel_data = data.loc[:, pd.IndexSlice[file_name, 0, 'Actual']][dates_mask] data_x = sel_data.index.get_level_values(axis_x) data_y = sel_data.values.reshape(-1) @@ -561,19 +560,20 @@ def generate_compairing_plot(dap, save_path_template axe.grid() axe.set_title( - "Plots over predict dates of covid-19 for {} {} and {}".format(location, file_name - , axis_y_displayname) + "Plots over predict dates of covid-19 for {} {} and {}".format( + location, file_name, axis_y_displayname + ) , fontsize=25 ) - pretext = (compare_diff_with_actual.TRANSFORM_NAME - if compare_diff_with_actual else 'Cumulative cases') + pretext = (compare_with_actual_by.TRANSFORM_NAME + if compare_with_actual_by else 'Cumulative cases') axe.set_ylabel(pretext + ' for {}'.format(axis_y_displayname) , fontsize=20) plt.tight_layout() path_components = [save_path_template, location, file_name] - if compare_diff_with_actual: - path_components.append(compare_diff_with_actual.SHORT_TRANSFORM_NAME) + if compare_with_actual_by: + path_components.append(compare_with_actual_by.SHORT_TRANSFORM_NAME) image_name = '_'.join(path_components) + '.png' fig.savefig(image_name) plt.close() @@ -582,7 +582,7 @@ def generate_compairing_plot(dap, save_path_template return generated_files def report(self, predictions_list, metrics_list - , horizons_list=None, date_selector=None, compare_diff_with_actual=None): + , horizons_list=None, date_selector=None): print('Preparing data for report') merged_data_and_predictions, merged_orig_dap = self.prepare_data_and_prediction_for_report( predictions_list @@ -622,11 +622,11 @@ def report(self, predictions_list, metrics_list generated_files += self.generate_forecasting_plot(original_dap_for_report , self.opt + '_forecast' , date_selector=date_selector) - + generated_files += self.generate_compairing_plot( dap_for_report , self.opt + '_comparison' - , compare_diff_with_actual=None + , compare_with_actual_by=None , horizons_list=horizons_list ) @@ -635,7 +635,7 @@ def report(self, predictions_list, metrics_list generated_files += self.generate_compairing_plot( dap_for_report , self.opt + '_comparison' - , compare_diff_with_actual=metric + , compare_with_actual_by=metric , horizons_list=horizons_list )