Skip to content

Converting Faster-RCNN from PyTorch to CoreML #2479

Open
@gizzleon

Description

@gizzleon

🐞Describing the bug

Hi, I am converting a PyTorch Faster R-CNN model to CoreML and encountered a data type mismatch, which may be related to #2440.

The model I'm converting is torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn_v2.

The first issue was the unsupported torchvision::roi_align operator. With the implementation from this PR, I was able to convert a single RoIAlign layer.

However, when converting the whole Faster R-CNN model, the second input variable, rois, has an unexpected shape (0, 1) and dtype int32, whereas it is supposed to be an (N, 5) float tensor.

Stack Trace

ERROR - converting 'torchvision::roi_align' op (located at: 'network/roi_heads/box_roi_pool/result_idx_in_level.1'):

Converting PyTorch Frontend ==> MIL Ops:  81%|████████▏ | 1374/1686 [00:00<00:00, 6381.85 ops/s]
Traceback (most recent call last):
  File "./bug_report.py", line 104, in <module>
    convert_faster_rcnn_model()
  File "./bug_report.py", line 101, in convert_faster_rcnn_model
    ct.convert(traced_model, inputs=[ct.TensorType(name="Input", shape=input_.shape)])
  File "./venv/lib/python3.12/site-packages/coremltools/converters/_converters_entry.py", line 635, in convert
    mlmodel = mil_convert(
              ^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
                         ^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 288, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 88, in load
    return _perform_torch_convert(converter, debug)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 151, in _perform_torch_convert
    prog = converter.convert()
           ^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 1387, in convert
    convert_nodes(self.context, self.graph, early_exit=not has_states)
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 116, in convert_nodes
    raise e     # re-raise exception
    ^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 111, in convert_nodes
    convert_single_node(context, node)
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 175, in convert_single_node
    add_op(context, node)
  File "./bug_report.py", line 46, in roi_align
    x = mb.crop_resize(
        ^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/ops/registry.py", line 183, in add_op
    return cls._add_op(op_cls_to_add, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/builder.py", line 217, in _add_op
    new_op = op_cls(**kwargs)
             ^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/operation.py", line 195, in __init__
    self._validate_and_set_inputs(input_kv)
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/operation.py", line 511, in _validate_and_set_inputs
    self.input_spec.validate_inputs(self.name, self.op_type, input_kvs)
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/input_type.py", line 138, in validate_inputs
    raise ValueError(msg)
ValueError: In op, of type crop_resize, named crop_resize_0, the named input `roi` must have the same data type as the named input `x`. However, roi has dtype int32 whereas x has dtype fp32.

To Reproduce

import coremltools as ct
import torch
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import (
    register_torch_op,
)
from coremltools.converters.mil.mil import Builder as mb
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2

from torchvision.ops.roi_align import RoIAlign


@register_torch_op(torch_alias=["torchvision::roi_align"])
def roi_align(context, node):
    """Lower a ``torchvision::roi_align`` node to MIL ``crop_resize`` + ``squeeze``.

    Node inputs read here, by position: 0 = feature map, 1 = rois,
    2 = spatial scale, 3 = output height, 4 = output width. Inputs past
    index 4 are not consulted by this converter.

    The result is registered in ``context`` under the node's first output
    name.

    Raises:
        ValueError: if the feature map does not have rank 4.
    """
    inputs = _get_inputs(context, node)

    feature_map = context[node.inputs[0]]
    if len(feature_map.shape) != 4:
        raise ValueError(
            f'"CropResize" op: expected input rank 4, got {feature_map.rank}'
        )

    rois_var = context[node.inputs[1]]
    spatial_scale = context[node.inputs[2]].val

    # The rois are handled as a compile-time constant only when both the rois
    # and the spatial scale carry a known value; otherwise a reshape op is
    # emitted on the symbolic rois variable.
    if rois_var.val is not None and spatial_scale is not None:
        rois = rois_var.val
        # CoreML expects boxes/ROI in [N, 1, 5, 1, 1] format
        rois = rois.reshape(rois.shape[0], 1, rois.shape[1], 1, 1)
    else:
        rois = inputs[1]
        roi_layout = [rois.shape[0], 1, rois.shape[1], 1, 1]
        rois = mb.reshape(x=rois, shape=roi_layout)

    # Height and width of each crop
    out_height = inputs[3]
    out_width = inputs[4]

    # Torch and CoreML both use [B, C, h_in, w_in] for the feature map, so no
    # transpose is needed before cropping.
    cropped = mb.crop_resize(
        x=feature_map,
        roi=rois,
        target_height=out_height.val,
        target_width=out_width.val,
        normalized_coordinates=True,
        spatial_scale=spatial_scale,
        box_coordinate_mode="CORNERS_HEIGHT_FIRST",
        sampling_mode="OFFSET_CORNERS",
    )

    # CoreML output format: [N, 1, C, h_out, w_out]
    # Torch output format:  [N, C, h_out, w_out]
    result = mb.squeeze(x=cropped, axes=[1])

    context.add(result, torch_name=node.outputs[0])


def convert_roi_align_layer():
    """Trace a standalone ``RoIAlign`` layer and pass it to ``ct.convert``.

    The layer is traced with a random (1, 3, 400, 800) feature map and two
    five-element rois, and both are declared as model inputs for conversion.
    The converted model is not returned.
    """
    layer = RoIAlign(
        output_size=(7, 7), spatial_scale=1.0, sampling_ratio=1, aligned=False
    )

    feature_map = torch.randn((1, 3, 400, 800))
    rois = torch.tensor(
        [[0, 0, 0, 10, 10], [0, 5, 5, 20, 20]], dtype=torch.float32
    )

    layer.eval()
    traced = torch.jit.trace(layer, (feature_map, rois))

    input_specs = [
        ct.TensorType(name="Input", shape=feature_map.shape),
        ct.TensorType(name="Rois", shape=rois.shape),
    ]
    ct.convert(traced, inputs=input_specs)


def convert_faster_rcnn_model():
    """Trace a Faster R-CNN without pretrained weights and pass it to ``ct.convert``.

    The detector is wrapped so that tracing sees a plain tuple of tensors
    (boxes, labels, scores) for the first image instead of a list of dicts.
    The converted model is not returned.
    """
    # ``weights=None`` is the current spelling of the legacy
    # ``pretrained=False`` keyword, which torchvision deprecated in 0.13 and
    # which emits a deprecation warning. The backbone argument is left at its
    # default so the constructed model is the same as before.
    model = fasterrcnn_resnet50_fpn_v2(weights=None)

    class ModelWrapper(torch.nn.Module):
        """Expose the first image's detections as a tuple of tensors."""

        def __init__(self, network: torch.nn.Module):
            super().__init__()
            self.network = network

        def forward(self, x):
            # The wrapped network's output is indexed per image; only the
            # first entry is kept.
            output = self.network(x)[0]
            return output["boxes"], output["labels"], output["scores"]

    wrapped_model = ModelWrapper(model)

    input_ = torch.randn((1, 3, 400, 800))
    wrapped_model.eval()

    traced_model = torch.jit.trace(wrapped_model, input_)

    ct.convert(traced_model, inputs=[ct.TensorType(name="Input", shape=input_.shape)])


# Run the two reproduction steps only when executed as a script, so that
# importing this module merely registers the custom op and triggers no
# conversion work.
if __name__ == "__main__":
    # Converting the RoIAlign layer on its own.
    convert_roi_align_layer()
    # Converting the full detector; this is the call that appears in the
    # reported traceback.
    convert_faster_rcnn_model()

System environment:

  • coremltools version: 8.1
  • OS (e.g. MacOS version or Linux type): MacOS 15.3.2
  • Any other relevant version information (e.g. PyTorch or TensorFlow version):
    • torch==2.5.1
    • torchvision==0.20.1
    • numpy==1.26.4

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Unexpected behaviour that should be corrected (type)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions