Skip to content

Commit

Permalink
Merge pull request huggingface#9 from Superb-AI-Suite/develop
Browse files Browse the repository at this point in the history
add num_nms parameters and set to 100
  • Loading branch information
SangbumChoi authored Apr 29, 2024
2 parents d3fa1d9 + 16cf17a commit 9d54198
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 7 deletions.
5 changes: 4 additions & 1 deletion src/transformers/models/deta/image_processing_deta.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,7 @@ def post_process_object_detection(
threshold: float = 0.5,
target_sizes: Union[TensorType, List[Tuple]] = None,
nms_threshold: float = 0.7,
num_nms: int = 100,
):
"""
Converts the output of [`DetaForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
Expand All @@ -1111,6 +1112,8 @@ def post_process_object_detection(
(height, width) of each image in the batch. If left to None, predictions will not be resized.
nms_threshold (`float`, *optional*, defaults to 0.7):
NMS threshold.
num_nms (`int`, *optional*, defaults to 100):
Maximum number of detections to keep per image after NMS.
Returns:
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
Expand Down Expand Up @@ -1158,7 +1161,7 @@ def post_process_object_detection(
lbls = lbls[pre_topk]

# apply NMS
keep_inds = batched_nms(box, score, lbls, nms_threshold)[:100]
keep_inds = batched_nms(box, score, lbls, nms_threshold)[:num_nms]
score = score[keep_inds]
lbls = lbls[keep_inds]
box = box[keep_inds]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1121,10 +1121,6 @@ def post_process_instance_segmentation(
pred_scores = scores_per_image * mask_scores_per_image
pred_classes = labels_per_image

mask_pred, pred_scores, pred_classes = remove_low_and_no_objects(
mask_pred, pred_scores, pred_classes, threshold, num_classes
)

segmentation = torch.zeros((384, 384)) - 1
if target_sizes is not None:
size = target_sizes[i] if isinstance(target_sizes[i], tuple) else target_sizes[i].cpu().tolist()
Expand Down
9 changes: 8 additions & 1 deletion src/transformers/models/yolov6/image_processing_yolov6.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,7 @@ def post_process_object_detection(
target_sizes: Union[TensorType, List[Tuple]] = None,
nms_threshold: float = 0.65,
max_nms: int = 30000,
num_nms: int = 100,
):
"""
Converts the raw output of [`Yolov6ForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
Expand All @@ -1022,6 +1023,12 @@ def post_process_object_detection(
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
`(height, width)` of each image in the batch. If unset, predictions will not be resized.
nms_threshold (`float`, *optional*, defaults to 0.65):
NMS threshold used to filter out duplicated object detection predictions.
max_nms (`int`, *optional*, defaults to 30000):
Maximum number of outputs kept for intermediate results.
num_nms (`int`, *optional*, defaults to 100):
Maximum number of final outputs kept after NMS.
Returns:
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
in the batch as predicted by the model.
Expand Down Expand Up @@ -1067,7 +1074,7 @@ def post_process_object_detection(
score = score[indices]

# apply NMS
keep_inds = nms(box, score, nms_threshold)[:300]
keep_inds = nms(box, score, nms_threshold)[:num_nms]
score = score[keep_inds]
lbls = lbls[keep_inds]
box = box[keep_inds]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/yolov6/modeling_yolov6.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,7 +1077,7 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
Expand Down

0 comments on commit 9d54198

Please sign in to comment.