Turn off visual encoder #68

Open · wants to merge 5 commits into base: main

Changes from 2 commits
25 changes: 2 additions & 23 deletions dreem/inference/tracker.py
@@ -21,7 +21,6 @@ class Tracker:
def __init__(
self,
window_size: int = 8,
use_vis_feats: bool = True,
overlap_thresh: float = 0.01,
mult_thresh: bool = True,
decay_time: float | None = None,
@@ -36,7 +35,6 @@ def __init__(

Args:
window_size: the size of the window used during sliding inference.
use_vis_feats: Whether or not to use visual feature extractor.
overlap_thresh: the trajectory overlap threshold to be used for assignment.
mult_thresh: Whether or not to use weight threshold.
decay_time: weight for `decay_time` postprocessing.
@@ -52,7 +50,6 @@ def __init__(
self.track_queue = TrackQueue(
window_size=window_size, max_gap=max_gap, verbose=verbose
)
self.use_vis_feats = use_vis_feats
self.overlap_thresh = overlap_thresh
self.mult_thresh = mult_thresh
self.decay_time = decay_time
@@ -112,36 +109,18 @@ def track(

for frame in frames:
if frame.has_instances():
if not self.use_vis_feats:
for instance in frame.instances:
instance.features = torch.zeros(1, model.d_model)
# frame["features"] = torch.randn(
# num_frame_instances, self.model.d_model
# )

# comment out to turn encoder off

# Assuming the encoder is already trained or train encoder jointly.
elif not frame.has_features():
if not frame.has_features():
with torch.no_grad():
crops = frame.get_crops()
z = model.visual_encoder(crops)

for i, z_i in enumerate(z):
frame.instances[i].features = z_i

# I feel like this chunk is unnecessary:
# reid_features = torch.cat(
# [frame["features"] for frame in instances], dim=0
# ).unsqueeze(0)

# asso_preds, pred_boxes, pred_time, embeddings = self.model(
# instances, reid_features
# )
instances_pred = self.sliding_inference(model, frames)

if not self.persistent_tracking:
logger.debug(f"Clearing Queue after tracking")
logger.debug(f"Clearing queue after tracking single batch")
Contributor (review comment):

Remove the extraneous f prefix in the logging message. The f prefix is unnecessary because the string contains no placeholders.

Suggested change:
-            logger.debug(f"Clearing queue after tracking single batch")
+            logger.debug("Clearing queue after tracking single batch")

Ruff F541 (line 123): f-string without any placeholders; remove the extraneous f prefix.

self.track_queue.end_tracks()

return instances_pred
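Note (not part of the diff): with the use_vis_feats flag removed from Tracker, turning visual features off is now expressed through the encoder itself. A minimal caller-side sketch, mirroring the tests added later in this PR and assuming the lambda syntax fix flagged below is applied; the sizes are placeholders:

import torch
from dreem.models.visual_encoder import VisualEncoder

# "off" (or "" / None) selects the zero-feature extractor inside VisualEncoder,
# so the tracker no longer needs its own use_vis_feats switch.
encoder = VisualEncoder(model_name="off", in_chans=3, d_model=128, backend="timm")

crops = torch.rand(4, 3, 64, 64)  # hypothetical batch of instance crops
features = encoder(crops)         # expected: shape (4, 128), all zeros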
84 changes: 46 additions & 38 deletions dreem/models/visual_encoder.py
@@ -52,9 +52,12 @@ def __init__(
**kwargs,
)

self.out_layer = torch.nn.Linear(
self.encoder_dim(self.feature_extractor), self.d_model
)
if self.model_name in ["off", "", None]:
self.out_layer = torch.nn.Identity()
else:
self.out_layer = torch.nn.Linear(
self.encoder_dim(self.feature_extractor), self.d_model
)

def select_feature_extractor(
self, model_name: str, in_chans: int, backend: str, **kwargs: Any
@@ -70,45 +73,50 @@ def select_feature_extractor(
Returns:
a CNN encoder based on the config and backend selected.
"""
if "timm" in backend.lower():
feature_extractor = timm.create_model(
model_name=self.model_name,
in_chans=self.in_chans,
num_classes=0,
**kwargs,
)
elif "torch" in backend.lower():
if model_name.lower() == "resnet18":
feature_extractor = torchvision.models.resnet18(**kwargs)

elif model_name.lower() == "resnet50":
feature_extractor = torchvision.models.resnet50(**kwargs)
if model_name in ["", "off", None]:
feature_extractor = lambda lambda tensor: torch.zeros(
(tensor.shape[0], self.d_model), dtype=tensor.dtype, device=tensor.device
) # turn off visual features by returning zeros
Contributor (review comment):

Fix the syntax error in the lambda assignment: the lambda keyword is duplicated, so the expression does not parse.

Suggested change:
-            feature_extractor = lambda lambda tensor: torch.zeros(
+            feature_extractor = lambda tensor: torch.zeros(

Ruff (line 77): SyntaxError: Expected ':', found 'lambda'.
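Side note (an alternative, not part of this PR): instead of a lambda, the zero extractor can be a small named module, which keeps the encoder picklable and easier to inspect. A sketch under the same zero-features assumption:

import torch

class ZeroFeatures(torch.nn.Module):
    """Stand-in feature extractor that returns all-zero embeddings."""

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        # Same contract as the lambda: one d_model-sized zero vector per crop.
        return torch.zeros(
            (tensor.shape[0], self.d_model), dtype=tensor.dtype, device=tensor.device
        )

Once the duplicated keyword is removed, the plain lambda works too; the module form mainly helps with serialization and debugging.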

else:
if "timm" in backend.lower():
feature_extractor = timm.create_model(
model_name=self.model_name,
in_chans=self.in_chans,
num_classes=0,
**kwargs,
)
elif "torch" in backend.lower():
if model_name.lower() == "resnet18":
feature_extractor = torchvision.models.resnet18(**kwargs)

elif model_name.lower() == "resnet50":
feature_extractor = torchvision.models.resnet50(**kwargs)

else:
raise ValueError(
f"Only `[resnet18, resnet50]` are available when backend is {backend}. Found {model_name}"
)
feature_extractor = torch.nn.Sequential(
*list(feature_extractor.children())[:-1]
)
input_layer = feature_extractor[0]
if in_chans != 3:
feature_extractor[0] = torch.nn.Conv2d(
in_channels=in_chans,
out_channels=input_layer.out_channels,
kernel_size=input_layer.kernel_size,
stride=input_layer.stride,
padding=input_layer.padding,
dilation=input_layer.dilation,
groups=input_layer.groups,
bias=input_layer.bias,
padding_mode=input_layer.padding_mode,
)

else:
raise ValueError(
f"Only `[resnet18, resnet50]` are available when backend is {backend}. Found {model_name}"
)
feature_extractor = torch.nn.Sequential(
*list(feature_extractor.children())[:-1]
)
input_layer = feature_extractor[0]
if in_chans != 3:
feature_extractor[0] = torch.nn.Conv2d(
in_channels=in_chans,
out_channels=input_layer.out_channels,
kernel_size=input_layer.kernel_size,
stride=input_layer.stride,
padding=input_layer.padding,
dilation=input_layer.dilation,
groups=input_layer.groups,
bias=input_layer.bias,
padding_mode=input_layer.padding_mode,
f"Only ['timm', 'torch'] backends are available! Found {backend}."
)

else:
raise ValueError(
f"Only ['timm', 'torch'] backends are available! Found {backend}."
)
return feature_extractor

def encoder_dim(self, model: torch.nn.Module) -> int:
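The first hunk swaps out_layer for torch.nn.Identity(), presumably because in the "off" case the stand-in extractor already emits d_model-sized vectors, so there is no backbone dimension for a Linear projection to map from. A minimal sketch of the dimension bookkeeping, with placeholder sizes (an illustration, not code from the PR):

import torch

d_model = 128
crops = torch.rand(4, 3, 64, 64)  # placeholder crop batch

# "off" path: the stand-in extractor returns (batch, d_model) zeros directly...
zero_features = torch.zeros((crops.shape[0], d_model), dtype=crops.dtype)

# ...so an Identity out_layer leaves them untouched, whereas a real backbone's
# features would still be projected from encoder_dim(feature_extractor) to d_model.
out_layer = torch.nn.Identity()
assert out_layer(zero_features).shape == (4, d_model)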
16 changes: 16 additions & 0 deletions tests/test_models.py
@@ -50,6 +50,14 @@ def test_encoder_timm():
input_tensor = torch.rand(b, c, h, w)
backend = "timm"

encoder = VisualEncoder(
model_name="off", in_chans=c, d_model=features, backend=backend
)
output = encoder(input_tensor)

assert output.shape == (b, features)
assert not torch.is_nonzero.any()

Contributor (review comment):

Fix the incorrect assertion for non-zero values. torch.is_nonzero is a function that takes a single-element tensor, so torch.is_nonzero.any() fails at runtime; assert on the encoder output instead.

Suggested change:
-    assert not torch.is_nonzero.any()
+    assert not output.any()
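As an aside (an alternative the reviewer did not suggest): either of the checks below verifies that every feature is exactly zero; output here stands in for the encoder output in the test:

import torch

output = torch.zeros(2, 128)  # placeholder for encoder(input_tensor) in the test

assert not output.any()
assert torch.equal(output, torch.zeros_like(output))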

encoder = VisualEncoder(
model_name="resnet18", in_chans=c, d_model=features, backend=backend
)
@@ -93,6 +101,14 @@ def test_encoder_torch():
input_tensor = torch.rand(b, c, h, w)
backend = "torch"

encoder = VisualEncoder(
model_name="off", in_chans=c, d_model=features, backend=backend
)
output = encoder(input_tensor)

assert output.shape == (b, features)
assert not torch.is_nonzero.any()

Contributor (review comment):

Same issue as in test_encoder_timm: replace the incorrect torch.is_nonzero.any() call with an assertion on the encoder output.

Suggested change:
-    assert not torch.is_nonzero.any()
+    assert not output.any()

encoder = VisualEncoder(
model_name="resnet18", in_chans=c, d_model=features, backend=backend
)