diff --git a/autodistill_grounded_sam_2/grounded_sam_2.py b/autodistill_grounded_sam_2/grounded_sam_2.py index b924f77..550f023 100644 --- a/autodistill_grounded_sam_2/grounded_sam_2.py +++ b/autodistill_grounded_sam_2/grounded_sam_2.py @@ -13,7 +13,11 @@ from autodistill.helpers import load_image from autodistill_florence_2 import Florence2 -from autodistill_grounded_sam_2.helpers import load_SAM, load_grounding_dino, combine_detections +from autodistill_grounded_sam_2.helpers import ( + combine_detections, + load_grounding_dino, + load_SAM, +) HOME = os.path.expanduser("~") DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -22,16 +26,25 @@ SUPPORTED_GROUNDING_MODELS = ["Florence 2", "Grounding DINO"] + @dataclass class GroundedSAM2(DetectionBaseModel): ontology: CaptionOntology box_threshold: float text_threshold: float - def __init__(self, ontology: CaptionOntology, model = "Florence 2", grounding_dino_box_threshold = 0.35, grounding_dino_text_threshold = 0.25): + def __init__( + self, + ontology: CaptionOntology, + model="Florence 2", + grounding_dino_box_threshold=0.35, + grounding_dino_text_threshold=0.25, + ): if model not in SUPPORTED_GROUNDING_MODELS: - raise ValueError(f"Grounding model {model} is not supported. Supported models are {SUPPORTED_GROUNDING_MODELS}") - + raise ValueError( + f"Grounding model {model} is not supported. Supported models are {SUPPORTED_GROUNDING_MODELS}" + ) + self.ontology = ontology if model == "Florence 2": self.florence_2_predictor = Florence2(ontology=ontology) diff --git a/autodistill_grounded_sam_2/helpers.py b/autodistill_grounded_sam_2/helpers.py index 04e1e3d..dfa7497 100644 --- a/autodistill_grounded_sam_2/helpers.py +++ b/autodistill_grounded_sam_2/helpers.py @@ -2,17 +2,18 @@ import subprocess import sys import urllib.request -from groundingdino.util.inference import Model -import torch import numpy as np import supervision as sv +import torch +from groundingdino.util.inference import Model DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") if not torch.cuda.is_available(): print("WARNING: CUDA not available. GroundingDINO will run very slowly.") + def load_grounding_dino(): AUTODISTILL_CACHE_DIR = os.path.expanduser("~/.cache/autodistill") @@ -56,6 +57,7 @@ def load_grounding_dino(): return grounding_dino_model + def load_SAM(): cur_dir = os.getcwd() @@ -63,6 +65,9 @@ def load_SAM(): SAM_CACHE_DIR = os.path.join(AUTODISTILL_CACHE_DIR, "segment_anything_2") SAM_CHECKPOINT_PATH = os.path.join(SAM_CACHE_DIR, "sam2_hiera_base_plus.pth") + SAM_REPOSITORY_NAME = "segment-anything-2" + SAM_REPOSITORY_DIR = os.path.join(SAM_CACHE_DIR, SAM_REPOSITORY_NAME) + url = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt" # Create the destination directory if it doesn't exist @@ -70,7 +75,7 @@ def load_SAM(): os.chdir(SAM_CACHE_DIR) - if not os.path.isdir("~/.cache/autodistill/segment_anything_2/segment-anything-2"): + if not os.path.isdir(SAM_REPOSITORY_DIR): subprocess.run( [ "git", @@ -79,11 +84,11 @@ def load_SAM(): ] ) - os.chdir("segment-anything-2") + os.chdir(SAM_REPOSITORY_NAME) subprocess.run(["pip", "install", "-e", "."]) - sys.path.append("~/.cache/autodistill/segment_anything_2/segment-anything-2") + sys.path.append(SAM_REPOSITORY_DIR) # Download the file if it doesn't exist if not os.path.isfile(SAM_CHECKPOINT_PATH): @@ -92,17 +97,14 @@ def load_SAM(): from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor - checkpoint = "~/.cache/autodistill/segment_anything_2/sam2_hiera_base_plus.pth" - checkpoint = os.path.expanduser(checkpoint) model_cfg = "sam2_hiera_b+.yaml" - predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint)) + predictor = SAM2ImagePredictor(build_sam2(model_cfg, SAM_CHECKPOINT_PATH)) os.chdir(cur_dir) return predictor - def combine_detections(detections_list, overwrite_class_ids): if len(detections_list) == 0: return sv.Detections.empty() @@ -156,4 +158,3 @@ def combine_detections(detections_list, overwrite_class_ids): class_id=class_id, tracker_id=tracker_id, ) -