Open
Description
🐞Describing the bug
Hi, I am converting a PyTorch Faster R-CNN model to CoreML and encountered data type mismatching issue, which may be related to #2440
The model I'm converting is torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn_v2
.
The first issue was the unsupported torchvision::roi_align
operator. With the implementation from this PR, I was able to convert a single RoIAlign
layer.
However, when converting the whole Faster R-CNN model, the second input variable rois
has unexpected shape (0,1)
and dtype int32
, where it is supposed to be a (N,5)
float tensor.
Stack Trace
ERROR - converting 'torchvision::roi_align' op (located at: 'network/roi_heads/box_roi_pool/result_idx_in_level.1'):
Converting PyTorch Frontend ==> MIL Ops: 81%|████████▏ | 1374/1686 [00:00<00:00, 6381.85 ops/s]
Traceback (most recent call last):
File "./bug_report.py", line 104, in <module>
convert_faster_rcnn_model()
File "./bug_report.py", line 101, in convert_faster_rcnn_model
ct.convert(traced_model, inputs=[ct.TensorType(name="Input", shape=input_.shape)])
File "./venv/lib/python3.12/site-packages/coremltools/converters/_converters_entry.py", line 635, in convert
mlmodel = mil_convert(
^^^^^^^^^^^^
File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
proto, mil_program = mil_convert_to_proto(
^^^^^^^^^^^^^^^^^^^^^
File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 288, in mil_convert_to_proto
prog = frontend_converter(model, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
return load(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 88, in load
return _perform_torch_convert(converter, debug)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 151, in _perform_torch_convert
prog = converter.convert()
^^^^^^^^^^^^^^^^^^^
File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 1387, in convert
convert_nodes(self.context, self.graph, early_exit=not has_states)
File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 116, in convert_nodes
raise e # re-raise exception
^^^^^^^
File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 111, in convert_nodes
convert_single_node(context, node)
File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 175, in convert_single_node
add_op(context, node)
File "./bug_report.py", line 46, in roi_align
x = mb.crop_resize(
^^^^^^^^^^^^^^^
File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/ops/registry.py", line 183, in add_op
return cls._add_op(op_cls_to_add, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/builder.py", line 217, in _add_op
new_op = op_cls(**kwargs)
^^^^^^^^^^^^^^^^
File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/operation.py", line 195, in __init__
self._validate_and_set_inputs(input_kv)
File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/operation.py", line 511, in _validate_and_set_inputs
self.input_spec.validate_inputs(self.name, self.op_type, input_kvs)
File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/input_type.py", line 138, in validate_inputs
raise ValueError(msg)
ValueError: In op, of type crop_resize, named crop_resize_0, the named input `roi` must have the same data type as the named input `x`. However, roi has dtype int32 whereas x has dtype fp32.
To Reproduce
import coremltools as ct
import torch
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import (
register_torch_op,
)
from coremltools.converters.mil.mil import Builder as mb
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2
from torchvision.ops.roi_align import RoIAlign
@register_torch_op(torch_alias=["torchvision::roi_align"])
def roi_align(context, node):
inputs = _get_inputs(context, node)
x = context[node.inputs[0]]
input_shape = x.shape # (B, h_in, w_in, C)
if len(input_shape) != 4:
raise ValueError(
'"CropResize" op: expected input rank 4, got {}'.format(x.rank)
)
const_box_info = True
if context[node.inputs[1]].val is None or context[node.inputs[2]].val is None:
const_box_info = False
extrapolation_value = context[node.inputs[2]].val
# CoreML index information along with boxes
if const_box_info:
boxes = context[node.inputs[1]].val
# CoreML expects boxes/ROI in
# [N, 1, 5, 1, 1] format
boxes = boxes.reshape(boxes.shape[0], 1, boxes.shape[1], 1, 1)
else:
boxes = inputs[1]
boxes = mb.reshape(x=boxes, shape=[boxes.shape[0], 1, boxes.shape[1], 1, 1])
# Get Height and Width of crop
h_out = inputs[3]
w_out = inputs[4]
# Torch input format: [B, C, h_in, w_in]
# CoreML input format: [B, C, h_in, w_in]
# Crop Resize
x = mb.crop_resize(
x=x,
roi=boxes,
target_height=h_out.val,
target_width=w_out.val,
normalized_coordinates=True,
spatial_scale=extrapolation_value,
box_coordinate_mode="CORNERS_HEIGHT_FIRST",
sampling_mode="OFFSET_CORNERS",
)
# CoreML output format: [N, 1, C, h_out, w_out]
# Torch output format: [N, C, h_out, w_out]
x = mb.squeeze(x=x, axes=[1])
context.add(x, torch_name=node.outputs[0])
def convert_roi_align_layer():
roi_align_layer = RoIAlign(
output_size=(7, 7), spatial_scale=1.0, sampling_ratio=1, aligned=False
)
input_tensor = torch.randn((1, 3, 400, 800))
rois_stacked = torch.FloatTensor([[0, 0, 0, 10, 10], [0, 5, 5, 20, 20]])
roi_align_layer.eval()
traced_model = torch.jit.trace(roi_align_layer, (input_tensor, rois_stacked))
ct.convert(
traced_model,
inputs=[
ct.TensorType(name="Input", shape=input_tensor.shape),
ct.TensorType(name="Rois", shape=rois_stacked.shape),
],
)
def convert_faster_rcnn_model():
model = fasterrcnn_resnet50_fpn_v2(pretrained=False)
class ModelWrapper(torch.nn.Module):
def __init__(self, network: torch.nn.Module):
super().__init__()
self.network = network
def forward(self, x):
output = self.network(x)[0]
return output["boxes"], output["labels"], output["scores"]
wrapped_model = ModelWrapper(model)
input_ = torch.randn((1, 3, 400, 800))
wrapped_model.eval()
traced_model = torch.jit.trace(wrapped_model, input_)
ct.convert(traced_model, inputs=[ct.TensorType(name="Input", shape=input_.shape)])
convert_roi_align_layer()
convert_faster_rcnn_model()
System environment:
- coremltools version: 8.1
- OS (e.g. MacOS version or Linux type): MacOS 15.3.2
- Any other relevant version information (e.g. PyTorch or TensorFlow version):
- torch==2.5.1
- torchvision==0.20.1
- numpy==1.26.4