From f55e37a1d050df0be79942126f7c7579df67a8bb Mon Sep 17 00:00:00 2001 From: aaprasad Date: Tue, 23 Jul 2024 10:51:04 -0700 Subject: [PATCH 1/5] turn off visual features at model level instead of inference level --- dreem/inference/tracker.py | 25 +--------- dreem/models/visual_encoder.py | 84 +++++++++++++++++++--------------- 2 files changed, 48 insertions(+), 61 deletions(-) diff --git a/dreem/inference/tracker.py b/dreem/inference/tracker.py index f7c29b4d..b497e86f 100644 --- a/dreem/inference/tracker.py +++ b/dreem/inference/tracker.py @@ -21,7 +21,6 @@ class Tracker: def __init__( self, window_size: int = 8, - use_vis_feats: bool = True, overlap_thresh: float = 0.01, mult_thresh: bool = True, decay_time: float | None = None, @@ -36,7 +35,6 @@ def __init__( Args: window_size: the size of the window used during sliding inference. - use_vis_feats: Whether or not to use visual feature extractor. overlap_thresh: the trajectory overlap threshold to be used for assignment. mult_thresh: Whether or not to use weight threshold. decay_time: weight for `decay_time` postprocessing. @@ -52,7 +50,6 @@ def __init__( self.track_queue = TrackQueue( window_size=window_size, max_gap=max_gap, verbose=verbose ) - self.use_vis_feats = use_vis_feats self.overlap_thresh = overlap_thresh self.mult_thresh = mult_thresh self.decay_time = decay_time @@ -112,17 +109,7 @@ def track( for frame in frames: if frame.has_instances(): - if not self.use_vis_feats: - for instance in frame.instances: - instance.features = torch.zeros(1, model.d_model) - # frame["features"] = torch.randn( - # num_frame_instances, self.model.d_model - # ) - - # comment out to turn encoder off - - # Assuming the encoder is already trained or train encoder jointly. - elif not frame.has_features(): + if not frame.has_features(): with torch.no_grad(): crops = frame.get_crops() z = model.visual_encoder(crops) @@ -130,18 +117,10 @@ def track( for i, z_i in enumerate(z): frame.instances[i].features = z_i - # I feel like this chunk is unnecessary: - # reid_features = torch.cat( - # [frame["features"] for frame in instances], dim=0 - # ).unsqueeze(0) - - # asso_preds, pred_boxes, pred_time, embeddings = self.model( - # instances, reid_features - # ) instances_pred = self.sliding_inference(model, frames) if not self.persistent_tracking: - logger.debug(f"Clearing Queue after tracking") + logger.debug(f"Clearing queue after tracking single batch") self.track_queue.end_tracks() return instances_pred diff --git a/dreem/models/visual_encoder.py b/dreem/models/visual_encoder.py index cab5f191..18dd5edb 100644 --- a/dreem/models/visual_encoder.py +++ b/dreem/models/visual_encoder.py @@ -52,9 +52,12 @@ def __init__( **kwargs, ) - self.out_layer = torch.nn.Linear( - self.encoder_dim(self.feature_extractor), self.d_model - ) + if self.model_name in ["off", "", None]: + self.out_layer = torch.nn.Identity() + else: + self.out_layer = torch.nn.Linear( + self.encoder_dim(self.feature_extractor), self.d_model + ) def select_feature_extractor( self, model_name: str, in_chans: int, backend: str, **kwargs: Any @@ -70,45 +73,50 @@ def select_feature_extractor( Returns: a CNN encoder based on the config and backend selected. """ - if "timm" in backend.lower(): - feature_extractor = timm.create_model( - model_name=self.model_name, - in_chans=self.in_chans, - num_classes=0, - **kwargs, - ) - elif "torch" in backend.lower(): - if model_name.lower() == "resnet18": - feature_extractor = torchvision.models.resnet18(**kwargs) - - elif model_name.lower() == "resnet50": - feature_extractor = torchvision.models.resnet50(**kwargs) + if model_name in ["", "off", None]: + feature_extractor = lambda lambda tensor: torch.zeros( + (tensor.shape[0], self.d_model), dtype=tensor.dtype, device=tensor.device + ) # turn off visual features by returning zeros + else: + if "timm" in backend.lower(): + feature_extractor = timm.create_model( + model_name=self.model_name, + in_chans=self.in_chans, + num_classes=0, + **kwargs, + ) + elif "torch" in backend.lower(): + if model_name.lower() == "resnet18": + feature_extractor = torchvision.models.resnet18(**kwargs) + + elif model_name.lower() == "resnet50": + feature_extractor = torchvision.models.resnet50(**kwargs) + + else: + raise ValueError( + f"Only `[resnet18, resnet50]` are available when backend is {backend}. Found {model_name}" + ) + feature_extractor = torch.nn.Sequential( + *list(feature_extractor.children())[:-1] + ) + input_layer = feature_extractor[0] + if in_chans != 3: + feature_extractor[0] = torch.nn.Conv2d( + in_channels=in_chans, + out_channels=input_layer.out_channels, + kernel_size=input_layer.kernel_size, + stride=input_layer.stride, + padding=input_layer.padding, + dilation=input_layer.dilation, + groups=input_layer.groups, + bias=input_layer.bias, + padding_mode=input_layer.padding_mode, + ) else: raise ValueError( - f"Only `[resnet18, resnet50]` are available when backend is {backend}. Found {model_name}" - ) - feature_extractor = torch.nn.Sequential( - *list(feature_extractor.children())[:-1] - ) - input_layer = feature_extractor[0] - if in_chans != 3: - feature_extractor[0] = torch.nn.Conv2d( - in_channels=in_chans, - out_channels=input_layer.out_channels, - kernel_size=input_layer.kernel_size, - stride=input_layer.stride, - padding=input_layer.padding, - dilation=input_layer.dilation, - groups=input_layer.groups, - bias=input_layer.bias, - padding_mode=input_layer.padding_mode, + f"Only ['timm', 'torch'] backends are available! Found {backend}." ) - - else: - raise ValueError( - f"Only ['timm', 'torch'] backends are available! Found {backend}." - ) return feature_extractor def encoder_dim(self, model: torch.nn.Module) -> int: From 3e4af0a0d8d46f9757c5b7552ef08c3128fe819e Mon Sep 17 00:00:00 2001 From: aaprasad Date: Tue, 23 Jul 2024 10:58:41 -0700 Subject: [PATCH 2/5] test turned off visual encoder --- tests/test_models.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_models.py b/tests/test_models.py index 3eaf9c22..ce042799 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -50,6 +50,14 @@ def test_encoder_timm(): input_tensor = torch.rand(b, c, h, w) backend = "timm" + encoder = VisualEncoder( + model_name="off", in_chans=c, d_model=features, backend=backend + ) + output = encoder(input_tensor) + + assert output.shape == (b, features) + assert not torch.is_nonzero.any() + encoder = VisualEncoder( model_name="resnet18", in_chans=c, d_model=features, backend=backend ) @@ -93,6 +101,14 @@ def test_encoder_torch(): input_tensor = torch.rand(b, c, h, w) backend = "torch" + encoder = VisualEncoder( + model_name="off", in_chans=c, d_model=features, backend=backend + ) + output = encoder(input_tensor) + + assert output.shape == (b, features) + assert not torch.is_nonzero.any() + encoder = VisualEncoder( model_name="resnet18", in_chans=c, d_model=features, backend=backend ) From 80ccf47a25625781ea1d70e711d7a0db8a9b3677 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Tue, 23 Jul 2024 11:02:57 -0700 Subject: [PATCH 3/5] fix syntax error --- dreem/models/visual_encoder.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dreem/models/visual_encoder.py b/dreem/models/visual_encoder.py index 18dd5edb..f564e087 100644 --- a/dreem/models/visual_encoder.py +++ b/dreem/models/visual_encoder.py @@ -74,9 +74,11 @@ def select_feature_extractor( a CNN encoder based on the config and backend selected. """ if model_name in ["", "off", None]: - feature_extractor = lambda lambda tensor: torch.zeros( - (tensor.shape[0], self.d_model), dtype=tensor.dtype, device=tensor.device - ) # turn off visual features by returning zeros + feature_extractor = lambda tensor: torch.zeros( + (tensor.shape[0], self.d_model), + dtype=tensor.dtype, + device=tensor.device, + ) # turn off visual features by returning zeros else: if "timm" in backend.lower(): feature_extractor = timm.create_model( From 0e32ed38d06cab584646a07abeabf4c36394ad59 Mon Sep 17 00:00:00 2001 From: aaprasad Date: Tue, 23 Jul 2024 11:05:49 -0700 Subject: [PATCH 4/5] fix missing method call --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index ce042799..8d5b28df 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -107,7 +107,7 @@ def test_encoder_torch(): output = encoder(input_tensor) assert output.shape == (b, features) - assert not torch.is_nonzero.any() + assert not torch.is_nonzero().any() encoder = VisualEncoder( model_name="resnet18", in_chans=c, d_model=features, backend=backend From 5b1806b2a43d249c6a12a3600137cb074338c32e Mon Sep 17 00:00:00 2001 From: aaprasad Date: Tue, 23 Jul 2024 11:21:45 -0700 Subject: [PATCH 5/5] fix tests --- dreem/training/configs/base.yaml | 1 - tests/configs/base.yaml | 1 - tests/test_config.py | 1 - tests/test_models.py | 4 ++-- 4 files changed, 2 insertions(+), 5 deletions(-) diff --git a/dreem/training/configs/base.yaml b/dreem/training/configs/base.yaml index 7779cd13..2abc7a70 100644 --- a/dreem/training/configs/base.yaml +++ b/dreem/training/configs/base.yaml @@ -47,7 +47,6 @@ scheduler: tracker: window_size: 8 - use_vis_feats: true overlap_thresh: 0.01 mult_thresh: true decay_time: null diff --git a/tests/configs/base.yaml b/tests/configs/base.yaml index 51078cee..e86fded1 100644 --- a/tests/configs/base.yaml +++ b/tests/configs/base.yaml @@ -48,7 +48,6 @@ scheduler: tracker: window_size: 8 - use_vis_feats: true overlap_thresh: 0.01 mult_thresh: true decay_time: null diff --git a/tests/test_config.py b/tests/test_config.py index 0b1c8267..7121d2d4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -70,7 +70,6 @@ def test_getters(base_config): assert set( [ "window_size", - "use_vis_feats", "overlap_thresh", "mult_thresh", "decay_time", diff --git a/tests/test_models.py b/tests/test_models.py index 8d5b28df..50fb4a81 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -56,7 +56,7 @@ def test_encoder_timm(): output = encoder(input_tensor) assert output.shape == (b, features) - assert not torch.is_nonzero.any() + assert not output.any() encoder = VisualEncoder( model_name="resnet18", in_chans=c, d_model=features, backend=backend @@ -107,7 +107,7 @@ def test_encoder_torch(): output = encoder(input_tensor) assert output.shape == (b, features) - assert not torch.is_nonzero().any() + assert not output.any() encoder = VisualEncoder( model_name="resnet18", in_chans=c, d_model=features, backend=backend