Skip to content

Commit

Permalink
Merge pull request #67 from lincc-frameworks/gmerz/refactor
Browse files Browse the repository at this point in the history
add z point matching to match_objects
  • Loading branch information
grantmerz authored Nov 21, 2023
2 parents e40f2a3 + 7841e29 commit 700c3bd
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/deepdisc/inference/match_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,46 @@ def get_matched_z_pdfs_new(dataset_dicts, predictor):
return ztrues, zpreds



def get_matched_z_points_new(dataset_dicts, predictor):
"""Returns redshift point estimates for matched pairs of ground truth and detected objects test images
assuming the dataset_dicts have the image HxWxC in the 'image_shaped' field
Parameters
----------
dataset_dicts : list[dict]
The dictionary metadata for a test images
predictor: AstroPredictor
The predictor object used to make predictions on the test set
Returns
-------
z_trues: list(float)
The redshifts of matched objects in the ground truth list
z_preds: list(array(float))
The redshift pdfs of matched objects in the detections list
"""
IOUthresh = 0.5
zs = np.linspace(-1, 5.0, 200)

ztrues = []
zpreds = []

for d in dataset_dicts:
outputs = get_predictions_new(d, predictor)
matched_gts, matched_dts = get_matched_object_inds(d, outputs)

for gti, dti in zip(matched_gts, matched_dts):
ztrue = d["annotations"][int(gti)]["redshift"]
zpred = outputs["instances"].pred_redshift[int(dti)].cpu().numpy()

ztrues.append(ztrue)
zpreds.append(zpred)

return ztrues, zpreds



def run_batched_match_class(dataloader, predictor):
"""
Test function not yet implemented for batch prediction
Expand Down

0 comments on commit 700c3bd

Please sign in to comment.