diff --git a/src/scripts/inference_custom.py b/src/scripts/inference_custom.py index cb34e58..a2d2f38 100644 --- a/src/scripts/inference_custom.py +++ b/src/scripts/inference_custom.py @@ -28,6 +28,14 @@ from skimage.feature import canny from skimage.morphology import binary_dilation from segment_anything.utils.amg import rle_to_mask +rgb_transform = T.Compose( + [ + T.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ), + ] + ) inv_rgb_transform = T.Compose( [ T.Normalize( @@ -118,7 +126,7 @@ def run_inference(template_dir, rgb_path, num_max_dets, conf_threshold, stabilit templates = proposal_processor(images=templates, boxes=boxes).cuda() save_image(templates, f"{template_dir}/cnos_results/templates.png", nrow=7) ref_feats = model.descriptor_model.compute_features( - templates, token_name="x_norm_clstoken" + rgb_transform(templates), token_name="x_norm_clstoken" ) logging.info(f"Ref feats: {ref_feats.shape}")