Skip to content

Commit

Permalink
Update test_small_run.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yiwang12 authored Jul 25, 2023
1 parent b19eb41 commit 943c045
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion tests/test_small_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import click
import pandas as pd

import random

from mNSF import process_multiSample, training_multiSample


Expand Down Expand Up @@ -32,9 +34,18 @@ def _run(
fit = process_multiSample.ini_multiSample(D, n_loadings)

# step 2 fit model
listDtrain = process_multiSample.get_listDtrain(D)

for ksample in range(0,len(D)):
random.seed(10)
ninduced=round(D[ksample]['X'].shape[0] * 0.35)
D_tmp=D[ksample]
D[ksample]["Z"]=D_tmp['X'][random.sample(range(0, D_tmp['X'].shape[0]-1), ninduced) ,:]


(pp := (output_dir / "models" / "pp")).mkdir(parents=True, exist_ok=True)
fit = training_multiSample.train_model_mNSF(
fit, pp, process_multiSample.get_listDtrain(D), D, legacy=legacy, num_epochs=epochs
fit, pp, listDtrain, D, legacy=legacy, num_epochs=epochs
)
(output_dir / "list_fit_smallData.pkl").write_bytes(pickle.dumps(fit))

Expand Down

0 comments on commit 943c045

Please sign in to comment.