Skip to content

Commit

Permalink
nearest nbh argument corrected
Browse files Browse the repository at this point in the history
  • Loading branch information
abhi0395 committed Nov 3, 2023
1 parent af37f26 commit c1528e1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion py/redrock/archetypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def nearest_neighbour_model(self, target,weights,flux,wflux,dwave,z, n_nearest,
nbasis = tdata[hs].shape[2]
if per_camera:
#Batch placeholder that right now loops over each arch but will be GPU accelerated
(zzchi2, zzcoeff) = per_camera_coeff_with_least_square_batch(spectra, tdata, nleg, method='bvls', n_nbh=1, use_gpu=use_gpu, prior=prior)
(zzchi2, zzcoeff) = per_camera_coeff_with_least_square_batch(spectra, tdata, nleg, method='bvls', n_nbh=n_nearest, use_gpu=use_gpu, prior=prior)
else:
#Use CPU mode for calc_zchi2 since small tdata
(zzchi2, zzcoeff) = calc_zchi2_batch(spectra, tdata, weights, flux, wflux, 1, nbasis, use_gpu=False)
Expand Down
4 changes: 2 additions & 2 deletions py/redrock/zscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,15 +359,15 @@ def per_camera_coeff_with_least_square_batch(spectra, tdata, nleg, narch, method
### PLACEHOLDER for algorithm that will be GPU accelerated
ncam = 3 # number of cameras in DESI: b, r, z
zzchi2 = np.zeros(narch, dtype=np.float64)
zzcoeff = np.zeros((narch, 1+ncam*(nleg)), dtype=np.float64)
zzcoeff = np.zeros((narch, n_nbh+ncam*(nleg)), dtype=np.float64)

tdata_one = dict()
for i in range(narch):
for hs in tdata:
tdata_one[hs] = tdata[hs][i,:,:]
if (use_gpu):
tdata_one[hs] = tdata_one[hs].get()
zzchi2[i], zzcoeff[i]= per_camera_coeff_with_least_square(spectra, tdata_one, nleg, method='bvls', n_nbh=1, prior=prior)
zzchi2[i], zzcoeff[i]= per_camera_coeff_with_least_square(spectra, tdata_one, nleg, method='bvls', n_nbh=n_nbh, prior=prior)
return zzchi2, zzcoeff

def batch_dot_product_sparse(spectra, tdata, nz, use_gpu):
Expand Down

0 comments on commit c1528e1

Please sign in to comment.