
vision model's input size specified on the cmd line is overridden by the pretrained model config #2035

Open
waterdropw opened this issue Sep 29, 2024 · 7 comments
Labels
exporters Issue related to exporters onnx Related to the ONNX export

Comments

@waterdropw

optimum-cli export onnx --no-dynamic-axes --batch_size 1 --sequence_length 16 --width 224 --height 224  --num_channels 3  --model google/owlv2-base-patch16-ensemble owlv2-base-patch16-ensemble-onnx

The exported ONNX model's input size is still 960x960. I found that the dummy input generator uses the 960 from the pretrained model config (normalized_config) instead of the 224 specified on the command line:

# Some vision models can take any input sizes, in this case we use the values provided as parameters.

Is this a bug?

@waterdropw
Author

waterdropw commented Sep 29, 2024

If not, how can I export an ONNX model with a 224x224 (or other) input size different from the pretrained 960x960?

@ghost

This comment was marked as off-topic.

@dacorvo dacorvo added the onnx Related to the ONNX export label Oct 8, 2024
@IlyasMoutawwakil
Member

If I understand correctly, you want the model to not use dynamic axes and to be statically exported at 224x224?

@IlyasMoutawwakil
Member

IlyasMoutawwakil commented Oct 11, 2024

I think I see what's happening here: the --no-dynamic-axes feature was added to let users export static models, and the input-shape arguments were added to let users pass shapes when they cannot be inferred from the config, not to force a shape.

So I don't think this is a bug, since the intention was to fix specific edge cases, but yes, it would make sense to support the feature you're requesting here.

All the generators will have to be updated from something like:

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
        width: int = DEFAULT_DUMMY_SHAPES["width"],
        height: int = DEFAULT_DUMMY_SHAPES["height"],
        **kwargs,
    ):
        self.task = task

        # Some vision models can take any input sizes, in this case we use the values provided as parameters.
        if normalized_config.has_attribute("num_channels"):
            self.num_channels = normalized_config.num_channels
        else:
            self.num_channels = num_channels

to

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        **input_shapes,
    ):
        self.task = task

        if input_shapes.get("num_channels", None) is not None:
            self.num_channels = input_shapes.pop("num_channels")
        elif normalized_config.has_attribute("num_channels"):
            self.num_channels = normalized_config.num_channels
        else:
            self.num_channels = DEFAULT_DUMMY_SHAPES.get("num_channels")

where user input shapes take precedence over normalized config.
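A minimal, self-contained sketch of that precedence rule (illustrative names and values, not optimum's actual classes): user-supplied input shapes override the model config, which overrides library defaults.

```python
# Hypothetical defaults, for illustration only.
DEFAULT_DUMMY_SHAPES = {"batch_size": 2, "num_channels": 3, "width": 64, "height": 64}

def resolve_shape(name: str, config: dict, user_shapes: dict) -> int:
    """Resolve one dummy-input dimension with user > config > default precedence."""
    if user_shapes.get(name) is not None:
        return user_shapes[name]          # CLI flags like --width take precedence
    if name in config:
        return config[name]               # fall back to the pretrained config
    return DEFAULT_DUMMY_SHAPES[name]     # last resort: library defaults

config = {"image_size": 960, "num_channels": 3}
print(resolve_shape("num_channels", config, {"width": 224}))  # 3, from the config
print(resolve_shape("width", config, {"width": 224}))         # 224, from the user
```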

@echarlaix wdyt, since static export is probably something OpenVINO models offer

@IlyasMoutawwakil IlyasMoutawwakil added the exporters Issue related to exporters label Oct 11, 2024
@waterdropw
Author


@IlyasMoutawwakil Yes, I want to deploy the model on an edge device such as a phone or an IoT device.
It needs to be exported as a static graph, and more importantly with a smaller input size, since the vision input is the performance bottleneck of a VLM: the VisionEmbeddings use a large conv kernel (e.g. 16x16), and a larger input feeds roughly (960/224)^2 more tokens into the attention computation.
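A quick back-of-the-envelope check of the token counts above, assuming the 16x16 patch embedding of owlv2-base-patch16:

```python
# Number of vision tokens = (image_size / patch_size) ** 2 for a ViT-style model.
patch = 16
tokens_960 = (960 // patch) ** 2   # 60 * 60 = 3600 vision tokens
tokens_224 = (224 // patch) ** 2   # 14 * 14 = 196 vision tokens
print(tokens_960, tokens_224, tokens_960 / tokens_224)  # ratio is about 18.4x
```

Attention cost scales quadratically in token count, so the gap in compute is even larger than 18x.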

@waterdropw
Author


I have dug into this issue and found that the position_embedding needs to be interpolated for any input size other than the pretrained one; I tested this, and the precision is acceptable for deployment.
Maybe I will open a PR for optimum to support this feature.
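As a rough illustration of that interpolation step (a hypothetical NumPy sketch, not optimum's or transformers' actual code): the pretrained position embeddings are viewed as a 2D grid and bilinearly resized to the new patch grid.

```python
import numpy as np

def interpolate_pos_embed(pos_embed: np.ndarray, old_grid: int, new_grid: int) -> np.ndarray:
    """Bilinearly resize a (old_grid**2, dim) table of position embeddings
    to (new_grid**2, dim) by treating it as a 2D grid."""
    dim = pos_embed.shape[-1]
    grid = pos_embed.reshape(old_grid, old_grid, dim)
    # New sample coordinates mapped into the old grid's index space.
    coords = np.linspace(0.0, old_grid - 1.0, new_grid)
    lo = np.floor(coords).astype(int)
    hi = np.minimum(lo + 1, old_grid - 1)
    frac = coords - lo
    # Interpolate along rows, then along columns.
    rows = grid[lo] * (1 - frac)[:, None, None] + grid[hi] * frac[:, None, None]
    out = rows[:, lo] * (1 - frac)[None, :, None] + rows[:, hi] * frac[None, :, None]
    return out.reshape(new_grid * new_grid, dim)

# e.g. resize owlv2's 60x60 grid (960/16) to 14x14 (224/16):
resized = interpolate_pos_embed(np.random.randn(60 * 60, 768), 60, 14)
print(resized.shape)  # (196, 768)
```

A real PR would also need to handle any extra class-token embedding separately and keep the export path in the model's dtype, but the core resizing is this simple.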

@IlyasMoutawwakil
Member

@waterdropw I would love to review a PR 🤗
