From 69bb7800cd1e04ef1a0f76105d823f85556fb335 Mon Sep 17 00:00:00 2001 From: Yi Wang <37149810+yiwang12@users.noreply.github.com> Date: Tue, 25 Jul 2023 09:44:59 -0400 Subject: [PATCH] Update test_small_run.py --- tests/test_small_run.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_small_run.py b/tests/test_small_run.py index 97543b7..f5d37aa 100644 --- a/tests/test_small_run.py +++ b/tests/test_small_run.py @@ -30,19 +30,19 @@ def _run( # step 0 Data loading D, X = load_data(data_dir, n_sample) - # step 1 initialize model - fit = process_multiSample.ini_multiSample(D, n_loadings) - - # step 2 fit model - listDtrain = process_multiSample.get_listDtrain(D) - for ksample in range(0,len(D)): random.seed(10) ninduced=round(D[ksample]['X'].shape[0] * 0.35) D_tmp=D[ksample] D[ksample]["Z"]=D_tmp['X'][random.sample(range(0, D_tmp['X'].shape[0]-1), ninduced) ,:] - + + # step 1 initialize model + fit = process_multiSample.ini_multiSample(D, n_loadings) + + # step 2 fit model + listDtrain = process_multiSample.get_listDtrain(D) + (pp := (output_dir / "models" / "pp")).mkdir(parents=True, exist_ok=True) fit = training_multiSample.train_model_mNSF( fit, pp, listDtrain, D, legacy=legacy, num_epochs=epochs