
Adds OWLViT to models exportable with ONNX #18588

Merged: 3 commits into huggingface:main from unography:owlvit_onnx on Aug 30, 2022

Conversation

unography (Contributor):
Output for tests on my local machine:

(transformers) ➜  transformers git:(owlvit_onnx) ✗ RUN_SLOW=1 pytest tests/onnx/test_onnx_v2.py -v -k "owlvit" --full-trace
================================================================== test session starts ===================================================================
platform darwin -- Python 3.8.12, pytest-7.1.2, pluggy-1.0.0 -- /Users/dhruv/Documents/code/transformers/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/dhruv/Documents/code/transformers, configfile: setup.cfg
collected 410 items / 408 deselected / 2 selected                                                                                                        

tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_101_owlvit_default PASSED                                                    [ 50%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_101_owlvit_default PASSED                                            [100%]

==================================================================== warnings summary ====================================================================
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_101_owlvit_default
  /Users/dhruv/Documents/code/transformers/src/transformers/image_utils.py:223: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
    def resize(self, image, size, resample=PIL.Image.BILINEAR, default_to_square=True, max_size=None):

tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_101_owlvit_default
  /Users/dhruv/Documents/code/transformers/src/transformers/models/owlvit/feature_extraction_owlvit.py:80: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
    resample=Image.BICUBIC,

tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_101_owlvit_default
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_101_owlvit_default
  /Users/dhruv/Documents/code/transformers/src/transformers/models/owlvit/modeling_owlvit.py:272: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
    if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):

tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_101_owlvit_default
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_101_owlvit_default
  /Users/dhruv/Documents/code/transformers/src/transformers/models/owlvit/modeling_owlvit.py:312: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
    if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):

tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_101_owlvit_default
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_101_owlvit_default
  /Users/dhruv/Documents/code/transformers/src/transformers/models/owlvit/modeling_owlvit.py:709: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
    mask.fill_(torch.tensor(float("-inf")))

tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_101_owlvit_default
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_101_owlvit_default
  /Users/dhruv/Documents/code/transformers/src/transformers/models/owlvit/modeling_owlvit.py:280: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
    if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):

tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_101_owlvit_default
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_101_owlvit_default
  /Users/dhruv/Documents/code/transformers/src/transformers/models/owlvit/modeling_owlvit.py:289: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
    if attention_mask.size() != (bsz, 1, tgt_len, src_len):

tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_101_owlvit_default
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_101_owlvit_default
  /Users/dhruv/Documents/code/transformers/.venv/lib/python3.8/site-packages/torch/onnx/symbolic_opset9.py:4592: UserWarning: Exporting aten::index operator of advanced indexing in opset 14 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================== 2 passed, 408 deselected, 14 warnings in 44.45s =====================================================

Note: I haven't tested this on a GPU yet; I don't currently have a GPU machine available.

Also, this covers only the default task for OWLViT. The object-detection task isn't supported by AutoModel yet, so adding that feature to the ONNX export currently fails. Should I make the change for AutoModel as well?

cc: @chainyo
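
For reference, the default-feature export these tests exercise can be reproduced programmatically with the transformers.onnx API. A minimal sketch, assuming the processor is accepted as the preprocessor (as with the CLIP export); the checkpoint name and output path are illustrative:

    from pathlib import Path

    from transformers import OwlViTModel, OwlViTProcessor
    from transformers.onnx import FeaturesManager, export

    ckpt = "google/owlvit-base-patch32"
    processor = OwlViTProcessor.from_pretrained(ckpt)
    model = OwlViTModel.from_pretrained(ckpt)

    # Look up the ONNX config registered for the "default" feature
    _, onnx_config_cls = FeaturesManager.check_supported_model_or_raise(model, feature="default")
    onnx_config = onnx_config_cls(model.config)

    # Trace the model and write the graph; returns the matched input/output names
    onnx_inputs, onnx_outputs = export(
        processor, model, onnx_config, onnx_config.default_onnx_opset, Path("owlvit.onnx")
    )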

HuggingFaceDocBuilderDev commented Aug 11, 2022:

The documentation is not available anymore as the PR was closed or merged.

regisss (Contributor) commented Aug 13, 2022:

@unography It's strange that it doesn't work for object detection. It should actually work; DETR and YOLOS are exportable to ONNX, for instance (see here). What error do you get when trying to export the model for object detection?

unography (Contributor, Author):

@regisss I think OWLViT just needs to be registered in the AutoModel mapping for object detection here (see the sketch below).
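
For context, a sketch of the registration in src/transformers/models/auto/modeling_auto.py; the existing entries match the error message below, and the owlvit line is the hypothetical addition:

    from collections import OrderedDict

    MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
        [
            ("detr", "DetrForObjectDetection"),
            ("yolos", "YolosForObjectDetection"),
            ("owlvit", "OwlViTForObjectDetection"),  # hypothetical addition
        ]
    )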

This is the stack trace:

_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

cls = <class 'transformers.models.auto.modeling_auto.AutoModelForObjectDetection'>
config = OwlViTConfig {
  "_commit_hash": "7cc55348dae46396474cd94bf00a542167a10f8d",
  "_name_or_path": "google/owlvit-base-pa...nsformers_version": "4.22.0.dev0",
    "typical_p": 1.0,
    "use_bfloat16": false
  },
  "vision_config_dict": null
}

kwargs = {}, trust_remote_code = False

    @classmethod
    def from_config(cls, config, **kwargs):
        trust_remote_code = kwargs.pop("trust_remote_code", False)
        if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
            if not trust_remote_code:
                raise ValueError(
                    "Loading this model requires you to execute the modeling file in that repo "
                    "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
                    "the option `trust_remote_code=True` to remove this error."
                )
            if kwargs.get("revision", None) is None:
                logger.warning(
                    "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
                    "no malicious code has been contributed in a newer revision."
                )
            class_ref = config.auto_map[cls.__name__]
            module_file, class_name = class_ref.split(".")
            model_class = get_class_from_dynamic_module(config.name_or_path, module_file + ".py", class_name, **kwargs)
            return model_class._from_config(config, **kwargs)
        elif type(config) in cls._model_mapping.keys():
            model_class = _get_model_class(config, cls._model_mapping)
            return model_class._from_config(config, **kwargs)
    
>       raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
        )
E       ValueError: Unrecognized configuration class <class 'transformers.models.owlvit.configuration_owlvit.OwlViTConfig'> for this kind of AutoModel: AutoModelForObjectDetection.
E       Model type should be one of DetrConfig, YolosConfig.

src/transformers/models/auto/auto_factory.py:412: ValueError

alaradirik (Contributor) commented Aug 15, 2022:

Hi @unography and @regisss! OWL-ViT is not part of the object detection pipeline because it requires both images and search queries as input.

We are planning to add a zero-shot-object-detection pipeline for OWL-ViT (see this issue).

cc @sgugger @NielsRogge

regisss (Contributor) commented Aug 15, 2022:

Thanks for the information @alaradirik :)

@unography Let's keep only the default pipeline as you did, then. I had to change one .T to .t() in modeling_owlvit.py to make the test pass, as in the CLIP PR 😆 Could you please make this change?
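
To illustrate the swap, a hypothetical before/after, assuming the same spot as the CLIP fix (the text/image similarity logits in OwlViTModel.forward):

    # Before: Tensor.T can break ONNX export on some torch versions
    logits_per_text = torch.matmul(text_embeds, image_embeds.T) * logit_scale

    # After: Tensor.t() traces cleanly to a Transpose op
    logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale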

regisss (Contributor) left a review:
LGTM @unography!!
Thanks for this PR 😃

regisss (Contributor) commented Aug 15, 2022:

Pinging @sgugger for final approval

unography (Contributor, Author):

@regisss Ya, sorry, I missed the .T issue; I was testing on nightly PyTorch. It should be fixed now.

LysandreJik (Member):

Hey @lewtun, would you like to have a look at this and merge if it looks good to you?

regisss (Contributor) commented Aug 24, 2022:

@lewtun Can you take a quick look at this PR and merge it when you approve? 🙂

lewtun (Member) left a review:

Thanks a lot for adding ONNX support for this brand-new model @unography!

I've left a comment about the dynamic axes of pixel_values, but otherwise this is looking great 🔥

lewtun (Member) commented on this snippet from the ONNX config's inputs definition:

    return OrderedDict(
        [
            ("input_ids", {0: "batch", 1: "sequence"}),
            ("pixel_values", {0: "batch"}),

Looking at the forward pass docs, pixel_values has shape [batch_size, num_channels, height, width]

Does OWLViT typically work with dynamic shapes (e.g. different-sized images)? If yes, I think it would make sense to replace this with:

Suggested change:

    - ("pixel_values", {0: "batch"}),
    + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),

(Similar to what we did with YOLOS)

Do you agree @regisss ?

regisss (Contributor) replied:

Yes it works with dynamic shapes, cc @alaradirik

unography (Author) replied:

Sure, makes sense; added the change.
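
The resulting inputs property then plausibly reads as follows. A sketch based on the suggestion above; the attention_mask entry is an assumption, since the reviewed snippet is truncated:

    from collections import OrderedDict
    from typing import Mapping

    from transformers.onnx import OnnxConfig


    class OwlViTOnnxConfig(OnnxConfig):
        @property
        def inputs(self) -> Mapping[str, Mapping[int, str]]:
            return OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "sequence"}),
                    # all four image axes marked dynamic, per the suggestion
                    ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
                    ("attention_mask", {0: "batch", 1: "sequence"}),  # assumed input
                ]
            )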

@@ -687,7 +687,10 @@ def forward(
         last_hidden_state = self.final_layer_norm(last_hidden_state)

         # take features from the end of tokens embedding (end of token is the highest number in each sequence)
-        pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
+        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+        pooled_output = last_hidden_state[
lewtun (Member) commented:
Double checking with @alaradirik whether this change to the modeling code is OK (I think it is)

alaradirik (Contributor) replied:

@lewtun Yes, that's ok.
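
A hypothetical reconstruction of the truncated hunk above, mirroring the CLIP fix; the variable names come from the surrounding forward pass, and the exact cast is an assumption:

    # ArgMax in ONNX opset 14 does not support int64 inputs, so cast input_ids
    # to int32 before locating the end-of-text token in each sequence
    pooled_output = last_hidden_state[
        torch.arange(last_hidden_state.shape[0]),
        input_ids.to(torch.int).argmax(dim=-1),
    ]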

lewtun (Member) left a review:

Thanks for fixing the dynamic shapes @unography, this LGTM 🚀!

Gently pinging @sgugger for final approval

LysandreJik (Member) left a review:

LGTM, thanks to everyone involved!

LysandreJik merged commit 46d0e26 into huggingface:main on Aug 30, 2022.
unography deleted the owlvit_onnx branch on August 30, 2022.