diff --git a/py/rvspecfit/make_ccf.py b/py/rvspecfit/make_ccf.py index b0f199b..d182d03 100644 --- a/py/rvspecfit/make_ccf.py +++ b/py/rvspecfit/make_ccf.py @@ -188,7 +188,12 @@ def preprocess_model(logl, lammodel, model0, vsini=None, ccfconf=None): return c_model -def preprocess_model_list(lammodels, models, params, ccfconf, vsinis=None): +def preprocess_model_list(lammodels, + models, + params, + ccfconf, + vsinis=None, + nthreads=1): """Apply preprocessing to the array of models Parameters @@ -218,29 +223,35 @@ def preprocess_model_list(lammodels, models, params, ccfconf, vsinis=None): 3) spectral params 4) list of vsini """ - nthreads = 16 logl = np.linspace(ccfconf.logl0, ccfconf.logl1, ccfconf.npoints) res = [] retparams = [] if vsinis is None: vsinis = [None] vsinisList = [] - pool = mp.Pool(nthreads) + if nthreads > 1: + pool = mp.Pool(nthreads) q = [] for imodel, m0 in enumerate(models): for vsini in vsinis: retparams.append(params[imodel]) - q.append( - pool.apply_async(preprocess_model, - (logl, lammodels, m0, vsini, ccfconf))) + curargs = (logl, lammodels, m0, vsini, ccfconf) + if nthreads > 1: + q.append(pool.apply_async(preprocess_model, curargs)) + else: + q.append(preprocess_model(*curargs)) vsinisList.append(vsini) for ii, curx in enumerate(q): print('Processing : %d / %d' % (ii, len(q))) - c_model = curx.get() + if nthreads > 1: + c_model = curx.get() + else: + c_model = curx res.append(c_model) - pool.close() - pool.join() + if nthreads > 1: + pool.close() + pool.join() retparams = np.array(retparams) vsinisList = np.array(vsinisList) res = np.array(res) @@ -379,7 +390,8 @@ def ccf_executor(spec_setup, oprefix=None, every=10, vsinis=None, - revision=''): + revision='', + nthreads=1): """ Prepare the FFT transformations for the CCF @@ -429,7 +441,8 @@ def ccf_executor(spec_setup, specs, vec, ccfconf, - vsinis=vsinis) + vsinis=vsinis, + nthreads=nthreads) ffts = np.array([np.fft.rfft(x) for x in models]) fft2s = np.array([np.fft.rfft(x**2) for x in models]) savefile = (oprefix + '/' + @@ -479,6 +492,10 @@ def main(args): help='Wavelength endpoint', required=True) + parser.add_argument('--nthreads', + type=int, + help='Number of threads', + default=8) parser.add_argument('--nocontinuum', dest='nocontinuum', action='store_true') @@ -525,7 +542,8 @@ def main(args): args.oprefix, args.every, vsinis, - revision=args.revision) + revision=args.revision, + nthreads=args.nthreads) if __name__ == '__main__': diff --git a/tests/make_templ.sh b/tests/make_templ.sh index ba07fff..abdf5de 100755 --- a/tests/make_templ.sh +++ b/tests/make_templ.sh @@ -23,4 +23,4 @@ $RVS_MAKE_INTERPOL --air --setup $BNAME --lambda0 $BLAM0 --lambda1 $BLAM1 --reso $RVS_MAKE_ND --prefix ${PREFIX}/ --setup $BNAME -$RVS_MAKE_CCF --setup $BNAME --lambda0 $BLAM0 --lambda1 $BLAM1 --every 3 --vsinis $VSINIS --prefix ${PREFIX}/ --oprefix=${PREFIX} --step $BSTEP +$RVS_MAKE_CCF --nthreads 1 --setup $BNAME --lambda0 $BLAM0 --lambda1 $BLAM1 --every 3 --vsinis $VSINIS --prefix ${PREFIX}/ --oprefix=${PREFIX} --step $BSTEP