-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Turn off visual encoder #68
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -52,9 +52,12 @@ def __init__( | |||||||||||||||||
**kwargs, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
self.out_layer = torch.nn.Linear( | ||||||||||||||||||
self.encoder_dim(self.feature_extractor), self.d_model | ||||||||||||||||||
) | ||||||||||||||||||
if self.model_name in ["off", "", None]: | ||||||||||||||||||
self.out_layer = torch.nn.Identity() | ||||||||||||||||||
else: | ||||||||||||||||||
self.out_layer = torch.nn.Linear( | ||||||||||||||||||
self.encoder_dim(self.feature_extractor), self.d_model | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
def select_feature_extractor( | ||||||||||||||||||
self, model_name: str, in_chans: int, backend: str, **kwargs: Any | ||||||||||||||||||
|
@@ -70,45 +73,50 @@ def select_feature_extractor( | |||||||||||||||||
Returns: | ||||||||||||||||||
a CNN encoder based on the config and backend selected. | ||||||||||||||||||
""" | ||||||||||||||||||
if "timm" in backend.lower(): | ||||||||||||||||||
feature_extractor = timm.create_model( | ||||||||||||||||||
model_name=self.model_name, | ||||||||||||||||||
in_chans=self.in_chans, | ||||||||||||||||||
num_classes=0, | ||||||||||||||||||
**kwargs, | ||||||||||||||||||
) | ||||||||||||||||||
elif "torch" in backend.lower(): | ||||||||||||||||||
if model_name.lower() == "resnet18": | ||||||||||||||||||
feature_extractor = torchvision.models.resnet18(**kwargs) | ||||||||||||||||||
|
||||||||||||||||||
elif model_name.lower() == "resnet50": | ||||||||||||||||||
feature_extractor = torchvision.models.resnet50(**kwargs) | ||||||||||||||||||
if model_name in ["", "off", None]: | ||||||||||||||||||
feature_extractor = lambda lambda tensor: torch.zeros( | ||||||||||||||||||
(tensor.shape[0], self.d_model), dtype=tensor.dtype, device=tensor.device | ||||||||||||||||||
) # turn off visual features by returning zeros | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix the syntax error in the lambda function. The lambda function definition is missing a colon. - feature_extractor = lambda lambda tensor: torch.zeros(
+ feature_extractor = lambda tensor: torch.zeros( Committable suggestion
Suggested change
ToolsRuff
|
||||||||||||||||||
else: | ||||||||||||||||||
if "timm" in backend.lower(): | ||||||||||||||||||
feature_extractor = timm.create_model( | ||||||||||||||||||
model_name=self.model_name, | ||||||||||||||||||
in_chans=self.in_chans, | ||||||||||||||||||
num_classes=0, | ||||||||||||||||||
**kwargs, | ||||||||||||||||||
) | ||||||||||||||||||
elif "torch" in backend.lower(): | ||||||||||||||||||
if model_name.lower() == "resnet18": | ||||||||||||||||||
feature_extractor = torchvision.models.resnet18(**kwargs) | ||||||||||||||||||
|
||||||||||||||||||
elif model_name.lower() == "resnet50": | ||||||||||||||||||
feature_extractor = torchvision.models.resnet50(**kwargs) | ||||||||||||||||||
|
||||||||||||||||||
else: | ||||||||||||||||||
raise ValueError( | ||||||||||||||||||
f"Only `[resnet18, resnet50]` are available when backend is {backend}. Found {model_name}" | ||||||||||||||||||
) | ||||||||||||||||||
feature_extractor = torch.nn.Sequential( | ||||||||||||||||||
*list(feature_extractor.children())[:-1] | ||||||||||||||||||
) | ||||||||||||||||||
input_layer = feature_extractor[0] | ||||||||||||||||||
if in_chans != 3: | ||||||||||||||||||
feature_extractor[0] = torch.nn.Conv2d( | ||||||||||||||||||
in_channels=in_chans, | ||||||||||||||||||
out_channels=input_layer.out_channels, | ||||||||||||||||||
kernel_size=input_layer.kernel_size, | ||||||||||||||||||
stride=input_layer.stride, | ||||||||||||||||||
padding=input_layer.padding, | ||||||||||||||||||
dilation=input_layer.dilation, | ||||||||||||||||||
groups=input_layer.groups, | ||||||||||||||||||
bias=input_layer.bias, | ||||||||||||||||||
padding_mode=input_layer.padding_mode, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
else: | ||||||||||||||||||
raise ValueError( | ||||||||||||||||||
f"Only `[resnet18, resnet50]` are available when backend is {backend}. Found {model_name}" | ||||||||||||||||||
) | ||||||||||||||||||
feature_extractor = torch.nn.Sequential( | ||||||||||||||||||
*list(feature_extractor.children())[:-1] | ||||||||||||||||||
) | ||||||||||||||||||
input_layer = feature_extractor[0] | ||||||||||||||||||
if in_chans != 3: | ||||||||||||||||||
feature_extractor[0] = torch.nn.Conv2d( | ||||||||||||||||||
in_channels=in_chans, | ||||||||||||||||||
out_channels=input_layer.out_channels, | ||||||||||||||||||
kernel_size=input_layer.kernel_size, | ||||||||||||||||||
stride=input_layer.stride, | ||||||||||||||||||
padding=input_layer.padding, | ||||||||||||||||||
dilation=input_layer.dilation, | ||||||||||||||||||
groups=input_layer.groups, | ||||||||||||||||||
bias=input_layer.bias, | ||||||||||||||||||
padding_mode=input_layer.padding_mode, | ||||||||||||||||||
f"Only ['timm', 'torch'] backends are available! Found {backend}." | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
else: | ||||||||||||||||||
raise ValueError( | ||||||||||||||||||
f"Only ['timm', 'torch'] backends are available! Found {backend}." | ||||||||||||||||||
) | ||||||||||||||||||
return feature_extractor | ||||||||||||||||||
|
||||||||||||||||||
def encoder_dim(self, model: torch.nn.Module) -> int: | ||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -50,6 +50,14 @@ def test_encoder_timm(): | |||||||||||||||||||||||||||||
input_tensor = torch.rand(b, c, h, w) | ||||||||||||||||||||||||||||||
backend = "timm" | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
encoder = VisualEncoder( | ||||||||||||||||||||||||||||||
model_name="off", in_chans=c, d_model=features, backend=backend | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
output = encoder(input_tensor) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
assert output.shape == (b, features) | ||||||||||||||||||||||||||||||
assert not torch.is_nonzero.any() | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix the syntax error in the assertion for non-zero values. The - assert not torch.is_nonzero.any()
+ assert not output.any() Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||
encoder = VisualEncoder( | ||||||||||||||||||||||||||||||
model_name="resnet18", in_chans=c, d_model=features, backend=backend | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
@@ -93,6 +101,14 @@ def test_encoder_torch(): | |||||||||||||||||||||||||||||
input_tensor = torch.rand(b, c, h, w) | ||||||||||||||||||||||||||||||
backend = "torch" | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
encoder = VisualEncoder( | ||||||||||||||||||||||||||||||
model_name="off", in_chans=c, d_model=features, backend=backend | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
output = encoder(input_tensor) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
assert output.shape == (b, features) | ||||||||||||||||||||||||||||||
assert not torch.is_nonzero.any() | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix the syntax error in the assertion for non-zero values. The - assert not torch.is_nonzero.any()
+ assert not output.any() Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||
encoder = VisualEncoder( | ||||||||||||||||||||||||||||||
model_name="resnet18", in_chans=c, d_model=features, backend=backend | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove the extraneous
f
prefix in the logging message.The
f
prefix is unnecessary as there are no placeholders in the string.Committable suggestion
Tools
Ruff