diff --git a/hubconf.py b/hubconf.py index d841d47c3..20446c389 100644 --- a/hubconf.py +++ b/hubconf.py @@ -12,7 +12,7 @@ def _make_detr(backbone_name: str, dilation=False, num_classes=91, mask=False): hidden_dim = 256 - backbone = Backbone(backbone_name, train_backbone=True, return_interm_layers=False, dilation=dilation) + backbone = Backbone(backbone_name, train_backbone=True, return_interm_layers=mask, dilation=dilation) pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True) backbone_with_pos_enc = Joiner(backbone, pos_enc) backbone_with_pos_enc.num_channels = backbone.num_channels