From 64461490e4f39e5b8cee78fd75fac3b3ea471e04 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Tue, 30 Apr 2024 17:28:15 +0800 Subject: [PATCH] Fix error in gene_random_walk (#375) train_X_ori and val_X_ori didn't get standardized, fix in this PR --- pypots/data/generating.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/pypots/data/generating.py b/pypots/data/generating.py index 1330128d..5d452374 100644 --- a/pypots/data/generating.py +++ b/pypots/data/generating.py @@ -274,8 +274,6 @@ def gene_random_walk( # create random missing values train_X_ori = train_X train_X = mcar(train_X, missing_rate) - val_X_ori = val_X - val_X = mcar(val_X, missing_rate) # test set is left to mask after normalization train_X = train_X.reshape(-1, n_features) @@ -305,18 +303,21 @@ def gene_random_walk( if missing_rate > 0: # mask values in the test set as ground truth - test_X_ori = test_X - test_X = mcar(test_X, missing_rate) - - data["train_X"] = train_X + train_X_ori = scaler.transform(train_X_ori.reshape(-1, n_features)).reshape( + -1, n_steps, n_features + ) data["train_X_ori"] = train_X_ori + + val_X_ori = val_X + val_X = mcar(val_X, missing_rate) data["val_X"] = val_X data["val_X_ori"] = val_X_ori - # test_X is for model input + test_X_ori = test_X + test_X = mcar(test_X, missing_rate) data["test_X"] = test_X - data["test_X_ori"] = test_X_ori - data["test_X_indicating_mask"] = ~np.isnan(test_X_ori) ^ ~np.isnan(test_X) + data["test_X_ori"] = np.nan_to_num(test_X_ori) # fill NaNs for later error calc + data["test_X_indicating_mask"] = np.isnan(test_X_ori) ^ np.isnan(test_X) return data @@ -421,7 +422,7 @@ def gene_physionet2012(artificially_missing_rate: float = 0.1): # test_X is for model input data["test_X"] = test_X # test_X_ori is for error calc, not for model input, hence mustn't have NaNs - data["test_X_ori"] = np.nan_to_num(test_X_ori) - data["test_X_indicating_mask"] = ~np.isnan(test_X_ori) ^ ~np.isnan(test_X) + data["test_X_ori"] = np.nan_to_num(test_X_ori) # fill NaNs for later error calc + data["test_X_indicating_mask"] = np.isnan(test_X_ori) ^ np.isnan(test_X) return data