From 7f55a1b3a0296f97abaf55294857ece7c66be918 Mon Sep 17 00:00:00 2001 From: Riza Velioglu <40141130+rizavelioglu@users.noreply.github.com> Date: Thu, 25 Jan 2024 18:02:24 +0100 Subject: [PATCH] fix/update RCNN-family docstrings (#8231) --- torchvision/models/detection/faster_rcnn.py | 3 +-- torchvision/models/detection/keypoint_rcnn.py | 3 +-- torchvision/models/detection/mask_rcnn.py | 3 +-- torchvision/models/detection/rpn.py | 1 + 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index de32f3453bd..0dc9a580ffe 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -96,8 +96,7 @@ class FasterRCNN(GeneralizedRCNN): for computing the loss rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training of the RPN - rpn_score_thresh (float): during inference, only return proposals with a classification score - greater than rpn_score_thresh + rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in the locations indicated by the bounding boxes box_head (nn.Module): module that takes the cropped feature maps as input diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 1ef0c1950d1..987df6603b8 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -83,8 +83,7 @@ class KeypointRCNN(FasterRCNN): for computing the loss rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training of the RPN - rpn_score_thresh (float): during inference, only return proposals with a classification score - greater than rpn_score_thresh + rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in the locations indicated by the bounding boxes box_head (nn.Module): module that takes the cropped feature maps as input diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 695dd4d63ec..862eee49fda 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -84,8 +84,7 @@ class MaskRCNN(FasterRCNN): for computing the loss rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training of the RPN - rpn_score_thresh (float): during inference, only return proposals with a classification score - greater than rpn_score_thresh + rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in the locations indicated by the bounding boxes box_head (nn.Module): module that takes the cropped feature maps as input diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 07a8b931150..f103181e4c6 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -133,6 +133,7 @@ class RegionProposalNetwork(torch.nn.Module): contain two fields: training and testing, to allow for different values depending on training or evaluation nms_thresh (float): NMS threshold used for postprocessing the RPN proposals + score_thresh (float): only return proposals with an objectness score greater than score_thresh """