diff --git a/projects/PointRend/point_rend/mask_head.py b/projects/PointRend/point_rend/mask_head.py index 58652e5d8c..7118b0a5b2 100644 --- a/projects/PointRend/point_rend/mask_head.py +++ b/projects/PointRend/point_rend/mask_head.py @@ -219,7 +219,7 @@ def _roi_pooler(self, features: List[Tensor], boxes: List[Boxes]): roi_features, _ = point_sample_fine_grained_features( features_list, features_scales, boxes, point_coords ) - return roi_features.view(num_boxes, -1, output_size, output_size) + return roi_features.view(num_boxes, roi_features.shape[1], output_size, output_size) def _forward_mask_point(self, features, mask_coarse_logits, instances): """