From 6d55f59e01d2b56d084a83c2aafeb9e2fe512b83 Mon Sep 17 00:00:00 2001 From: Tuomas Rossi Date: Fri, 31 May 2024 14:19:30 +0300 Subject: [PATCH] Convert from zero- to one-based indices in output --- hmsc/utils/export_rds_utils.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/hmsc/utils/export_rds_utils.py b/hmsc/utils/export_rds_utils.py index f7eaf8e..531e0a2 100644 --- a/hmsc/utils/export_rds_utils.py +++ b/hmsc/utils/export_rds_utils.py @@ -33,10 +33,21 @@ def load_model_from_rds(rds_file_path): def save_chains_postList_to_rds(postList, postList_file_path, nChains, elapsedTime=-1, flag_save_eta=True): data = {} - data["list"] = convert_to_numpy(postList) - if not flag_save_eta: - for i in range(len(data["list"])): - for j in range(len(data["list"][i])): - data["list"][i][j]["Eta"] = None + output_list = convert_to_numpy(postList) + + for i in range(len(output_list)): + for j in range(len(output_list[i])): + item = output_list[i][j] + + # Convert from zero- to one-based indices + item["rhoInd"] += 1 + for k in range(len(item["AlphaInd"])): + item["AlphaInd"][k] += 1 + + # Remove eta if requested + if not flag_save_eta: + item["Eta"] = None + + data["list"] = output_list data["time"] = elapsedTime rdata.write_rds(postList_file_path, data)