From eb0a91c887266a329ed8ca22434174f9d964abbb Mon Sep 17 00:00:00 2001 From: Roy Velich <6681067+royvelich@users.noreply.github.com> Date: Fri, 9 Aug 2024 10:11:30 +0300 Subject: [PATCH 1/2] Hydra initialization fix Support using existing initialization of hydra. Support loading config files from a project's local config folder --- sam2/__init__.py | 4 +++- sam2/build_sam.py | 4 +++- sam2/sam2_image_predictor.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sam2/__init__.py b/sam2/__init__.py index ff90d1042..bcbaca8a1 100644 --- a/sam2/__init__.py +++ b/sam2/__init__.py @@ -5,5 +5,7 @@ # LICENSE file in the root directory of this source tree. from hydra import initialize_config_module +from hydra.core.global_hydra import GlobalHydra -initialize_config_module("sam2_configs", version_base="1.2") +if not GlobalHydra().is_initialized(): + initialize_config_module("sam2_configs", version_base="1.2") diff --git a/sam2/build_sam.py b/sam2/build_sam.py index e5911d490..17f7b66ff 100644 --- a/sam2/build_sam.py +++ b/sam2/build_sam.py @@ -42,6 +42,7 @@ def build_sam2( def build_sam2_video_predictor( config_file, + config_dir=None, ckpt_path=None, device="cuda", mode="eval", @@ -66,7 +67,8 @@ def build_sam2_video_predictor( hydra_overrides.extend(hydra_overrides_extra) # Read config and init model - cfg = compose(config_name=config_file, overrides=hydra_overrides) + config_name = f'{config_dir}/{config_file}' if config_dir is not None else config_file + cfg = compose(config_name=config_name, overrides=hydra_overrides) OmegaConf.resolve(cfg) model = instantiate(cfg.model, _recursive_=True) _load_checkpoint(model, ckpt_path) diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py index f6f9a5a1e..c70ccef29 100644 --- a/sam2/sam2_image_predictor.py +++ b/sam2/sam2_image_predictor.py @@ -77,7 +77,7 @@ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": from sam2.build_sam import build_sam2_hf sam_model = build_sam2_hf(model_id, **kwargs) - return cls(sam_model) + return sam_model @torch.no_grad() def set_image( From bc658baa1eac4daf9982530e7b72358cfee95d1b Mon Sep 17 00:00:00 2001 From: Roy Velich <6681067+royvelich@users.noreply.github.com> Date: Fri, 9 Aug 2024 10:34:39 +0300 Subject: [PATCH 2/2] Update sam2_video_predictor.py Remove cls call --- sam2/sam2_video_predictor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index b5a6bdf4b..562a66b71 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -119,7 +119,7 @@ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": from sam2.build_sam import build_sam2_video_predictor_hf sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) - return cls(sam_model) + return sam_model def _obj_id_to_idx(self, inference_state, obj_id): """Map client-side object id to model-side object index."""