
Refactor Object Detection for pretrained weights #33

Open
oke-aditya opened this issue Nov 15, 2020 · 9 comments
Labels
enhancement New feature or request known issue Problem that is known and need workaround. Low Priority Should be done soon

Comments

@oke-aditya
Owner

🚀 Feature

Similar to what we did for classification, we should probably provide something for detection.

This will allow loading pretrained weights from datasets such as KITTI, COCO, etc.

oke-aditya added enhancement New feature or request High Priority Should be addressed ASAP. labels Nov 15, 2020
@hassiahk
Contributor

Hello @oke-aditya.

Can I work on this?

If yes, can you give me a little more info on what exactly needs to be done? I am not sure what you did for classification.

@oke-aditya
Owner Author

This is really tricky.
Let me explain in a bit more detail.

@oke-aditya
Owner Author

This is actually something we should look at in the longer run, as part of a major refactor.

For classification, torchvision provides CNNs (backbones) trained on ImageNet.
We extended this to support any backbone, including weights taken from other hub models; e.g., we can now use SSL weights.

For this, I created a dictionary in the pretrained folder and simply load these weights from URLs.

We instantiate the model with NO pretrained weights and then load these weights as needed.
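A minimal sketch of that pattern, assuming a hypothetical `PRETRAINED_WEIGHTS` dictionary keyed by `(architecture, weight source)` with placeholder URLs (the real dictionary in the pretrained folder may be laid out differently). The model is created with no pretrained weights, and the chosen state dict is loaded afterwards. It uses the 2020-era torchvision `pretrained=` argument, and torch is imported only when a model is actually loaded.

```python
from typing import Dict, Tuple

# Hypothetical registry: (architecture, weight source) -> checkpoint URL.
# The URLs are placeholders, not real checkpoints.
PRETRAINED_WEIGHTS: Dict[Tuple[str, str], str] = {
    ("resnet50", "ssl"): "https://example.com/weights/resnet50_ssl.pth",
    ("resnet50", "swsl"): "https://example.com/weights/resnet50_swsl.pth",
}


def get_weight_url(arch: str, source: str) -> str:
    """Look up the checkpoint URL for an architecture / weight-source pair."""
    try:
        return PRETRAINED_WEIGHTS[(arch, source)]
    except KeyError:
        available = sorted(src for a, src in PRETRAINED_WEIGHTS if a == arch)
        raise KeyError(f"No {source!r} weights for {arch!r}; available: {available}") from None


def load_pretrained_backbone(arch: str, source: str):
    """Instantiate a torchvision model with NO pretrained weights, then load the chosen ones."""
    # Imported lazily so the lookup above works without torch installed.
    import torchvision
    from torch.hub import load_state_dict_from_url

    model = getattr(torchvision.models, arch)(pretrained=False)
    state_dict = load_state_dict_from_url(get_weight_url(arch, source), progress=True)
    model.load_state_dict(state_dict)
    return model
```

Keeping the URLs in plain data means a new weight source is a one-line addition, and the model constructor itself never has to know about it.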

Challenges for detection

  1. For detection, torchvision provides models trained on COCO (with ImageNet-trained backbones) for resnet50_fpn.

Detection has a huge number of configurations.

  1. FPN and non-FPN models. Supporting both is necessary.
  2. Various CNNs as backbones, which we already support through the refactor from classification. Currently, users can easily use these backbones with other pretrained weights, such as SSL.

This is how the detection API currently works:

```python
from quickvision.models.detection.faster_rcnn import create_fastercnn_backbone, create_vision_fastercnn

# Non-FPN ResNet-50 backbones with SSL and ImageNet weights
frcnn_bbone1 = create_fastercnn_backbone(backbone="resnet50", fpn=False, pretrained="ssl")
frcnn_bbone2 = create_fastercnn_backbone(backbone="resnet50", fpn=False, pretrained="imagenet")

# Wrap each backbone in a Faster R-CNN model with 10 classes
frcnn_model1 = create_vision_fastercnn(num_classes=10, backbone=frcnn_bbone1)
frcnn_model2 = create_vision_fastercnn(num_classes=10, backbone=frcnn_bbone2)
```

Note that this creates a Faster R-CNN model without FPN, but it supports backbones with other pretrained weights.

For FPN models we use torchvision's resnet_fpn_backbone, which creates FPN backbones only with ImageNet weights.

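```python
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

# `pretrained` is a bool: True loads ImageNet weights for the ResNet body
backbone = resnet_fpn_backbone(backbone, pretrained=True,
                               trainable_layers=trainable_backbone_layers, **kwargs)
```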

Now you might ask: how do we get a model trained on COCO?

For that, after we create the Faster R-CNN model, we need to load the COCO weights. Copying the code from torchvision:

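```python
from torch.hub import load_state_dict_from_url
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.faster_rcnn import FasterRCNN, model_urls  # model_urls: older torchvision only

def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=91,
                            pretrained_backbone=True, trainable_backbone_layers=3, **kwargs):
    if pretrained:
        # no need to download the backbone if pretrained is set
        pretrained_backbone = False
    backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers)
    model = FasterRCNN(backbone, num_classes, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
```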

That's how we get ResNet50-FPN trained on COCO:

  1. We do not initialize the backbone with pretrained weights.
  2. We simply load the COCO weights into the whole model.

@oke-aditya
Owner Author

oke-aditya commented Nov 15, 2020

In short, we need to support the following:

  • All non-FPN backbones with ImageNet and other weights (through torchvision backbones).
  • ResNet FPN-based backbones with ImageNet weights.
  • ResNet FPN backbones with other weights.
  • COCO-based models for all ResNet FPNs.
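The four cases can be written down as a small lookup, which makes it easy to see which case a given request falls into. This is a hypothetical helper (`classify_request` and `CASES` are not part of quickvision) that only encodes the list above; it assumes every non-FPN request belongs to case 1, whatever the weight source.

```python
CASES = {
    1: "Non-FPN backbone with ImageNet or other weights",
    2: "ResNet FPN backbone with ImageNet weights",
    3: "ResNet FPN backbone with other weights",
    4: "COCO-trained model for a ResNet FPN",
}


def classify_request(fpn: bool, pretrained: str) -> int:
    """Return which of the four cases a (fpn, weight source) request falls into."""
    if not fpn:
        return 1  # torchvision / hub backbones without FPN
    if pretrained == "imagenet":
        return 2  # what resnet_fpn_backbone(pretrained=True) gives us
    if pretrained == "coco":
        return 4  # full detector weights, not just the backbone
    return 3  # e.g. "ssl" weights under an FPN


print(CASES[classify_request(True, "ssl")])
```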

P.S. Let me start an initial refactor; things will become clearer with that.

@hassiahk
Contributor

I got the gist of what should be done, and with an initial refactor it will be clearer. Thanks.

@hassiahk
Contributor

hassiahk commented Nov 15, 2020

@oke-aditya This is what I understood from your previous comment; correct me if I am wrong:

  • For ResNet FPN backbones with other weights, you would want something like the following, but resnet_fpn_backbone only supports ImageNet.

    frcnn_bbone = create_fastercnn_backbone(backbone="resnet50", fpn=True, pretrained="ssl")
    
  • For COCO-based models for all ResNet FPNs, you would want something like the following, and the code you mentioned will work in this case.

    frcnn_bbone = create_fastercnn_backbone(backbone="resnet50", fpn=True, pretrained="coco")
    

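One way the `fpn=True` branch of such a function could dispatch on `pretrained` is sketched below. It targets the 2020-era torchvision API (`resnet_fpn_backbone(name, pretrained_bool, ...)` and `fasterrcnn_resnet50_fpn(pretrained=True)`). `create_fpn_backbone_sketch` is a hypothetical name, and it raises `NotImplementedError` for the cases that `resnet_fpn_backbone` cannot express.

```python
def create_fpn_backbone_sketch(backbone: str, pretrained: str = "imagenet", trainable_layers: int = 3):
    """Hypothetical fpn=True branch: dispatch on the weight source (2020-era torchvision API)."""
    if pretrained not in ("imagenet", "coco"):
        # resnet_fpn_backbone's `pretrained` is a bool that can only mean ImageNet weights
        raise NotImplementedError(f"FPN backbones with {pretrained!r} weights are not supported")
    if pretrained == "coco" and backbone != "resnet50":
        # torchvision only ships a COCO-trained Faster R-CNN for resnet50_fpn
        raise NotImplementedError(f"No COCO-trained model available for {backbone!r}")

    # Imported lazily so the checks above run without torchvision installed.
    from torchvision.models.detection import fasterrcnn_resnet50_fpn
    from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

    if pretrained == "imagenet":
        return resnet_fpn_backbone(backbone, True, trainable_layers=trainable_layers)
    # "coco": build the full COCO model and keep its backbone (ResNet body + FPN)
    coco_model = fasterrcnn_resnet50_fpn(pretrained=True, trainable_backbone_layers=trainable_layers)
    return coco_model.backbone
```

Taking `.backbone` from the COCO model keeps both the ResNet body and the FPN layers trained on COCO, which is why only `resnet50` can be offered in this sketch.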
@oke-aditya
Owner Author

Hmm, let me start an initial refactor. This refactor is a little tricky.

@oke-aditya
Owner Author

These two are super hard to support.

ResNet FPN backbones with other weights.

  1. ResNet FPN backbones with other weights are not directly possible because torchvision hardcodes the weights in resnet_fpn_backbone: its pretrained argument is a boolean that can only load ImageNet weights.

COCO-based models for all ResNet FPNs

  1. COCO-based models need training, and the only one we have is resnet50_fpn. If people contribute such models, we can easily add them by modifying the backbone code.

The second feature is quite possible, but it needs training. It would be great if people could provide such weights.
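If contributed COCO checkpoints are kept in a registry, adding one becomes a data change rather than a code change. The sketch below assumes a hypothetical `COCO_FASTERRCNN_WEIGHTS` dictionary and helper names (`register_coco_weights`, `create_coco_fasterrcnn`); only the resnet50 entry corresponds to a real torchvision checkpoint. It follows the same two steps as the torchvision code above (no backbone weights, then load COCO weights) and uses the 2020-era torchvision API.

```python
from typing import Dict

# Hypothetical registry of COCO-trained Faster R-CNN (ResNet-FPN) checkpoints.
# Only resnet50 was available upstream at the time of this issue.
COCO_FASTERRCNN_WEIGHTS: Dict[str, str] = {
    "resnet50": "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
}


def register_coco_weights(backbone: str, url: str) -> None:
    """Add a contributed COCO checkpoint for a ResNet-FPN backbone."""
    if backbone in COCO_FASTERRCNN_WEIGHTS:
        raise ValueError(f"COCO weights for {backbone!r} are already registered")
    COCO_FASTERRCNN_WEIGHTS[backbone] = url


def create_coco_fasterrcnn(backbone: str, trainable_layers: int = 3):
    """Build a ResNet-FPN Faster R-CNN without backbone weights, then load COCO weights."""
    if backbone not in COCO_FASTERRCNN_WEIGHTS:
        raise NotImplementedError(f"No COCO-trained weights for {backbone!r} yet")
    # Imported lazily so the registry can be used without torch installed.
    from torch.hub import load_state_dict_from_url
    from torchvision.models.detection import FasterRCNN
    from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

    # pretrained=False: the COCO checkpoint already contains the backbone weights
    fpn_backbone = resnet_fpn_backbone(backbone, False, trainable_layers=trainable_layers)
    model = FasterRCNN(fpn_backbone, num_classes=91)  # COCO checkpoints use a 91-class head
    state_dict = load_state_dict_from_url(COCO_FASTERRCNN_WEIGHTS[backbone], progress=True)
    model.load_state_dict(state_dict)
    return model
```

A contributor would then only need to call `register_coco_weights("resnet101", url)` with a trained checkpoint whose layout matches the model built here.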

oke-aditya added known issue Problem that is known and need workaround. Low Priority Should be done soon and removed High Priority Should be addressed ASAP. labels Nov 18, 2020
@oke-aditya
Owner Author

The above PR reduces the urgency of this for some time. There may be a better solution, but we need training for most weights.

2 participants