diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py b/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py index ee76072211..01bce47d7f 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py @@ -215,7 +215,7 @@ def single_roi_extractor__forward(ctx, roi_feats = feats[0].new_zeros(rois.shape[0], self.out_channels, *out_size) if num_levels == 1: assert len(rois) > 0, 'The number of rois should be positive' - if backend == Backend.TORCHSCRIPT: + if backend == Backend.TORCHSCRIPT or backend == Backend.COREML: self.roi_layers[0].use_torchvision = True return self.roi_layers[0](feats[0], rois) @@ -241,7 +241,7 @@ def single_roi_extractor__forward(ctx, inds = mask.nonzero(as_tuple=False).squeeze(1) rois_t = rois[inds] # use the roi align in torhcvision - if backend == Backend.TORCHSCRIPT: + if backend == Backend.TORCHSCRIPT or backend == Backend.COREML: self.roi_layers[i].use_torchvision = True roi_feats_t = self.roi_layers[i](feats[i], rois_t) roi_feats[inds] = roi_feats_t