Convert to onnx model #2

Open
yanqi1811 opened this issue on Mar 3, 2021 · 10 comments
Labels: enhancement, help wanted

Comments

@yanqi1811 commented Mar 3, 2021

Your work is very helpful for me, thank you! But when I try to convert this PyTorch model to an ONNX file, I run into some errors. Have you tried this? Thanks!

@lukas-blecher (Owner)

Yes, I tried to trace the model before. At the moment it seems like the timm module is not 100% compatible.
Will look into it in the future.

@yanqi1811 (Author)

Thank you for your reply!

lukas-blecher added the enhancement label on Jun 28, 2021
@Root970103

> Yes, I tried to trace the model before. At the moment it seems like the timm module is not 100% compatible.
> Will look into it in the future.

Hi, did you trace the whole model (encoder+decoder)? What problem did you run into? Maybe we can have a discussion.
In addition, I have a question and hope you can share some ideas. [x_transformers] provides both an encoder and a decoder, so why did you use the encoder from [timm] and the decoder from [x_transformers]? Is there any special reason?

@lukas-blecher (Owner)

  1. The main problem is that the image input size can be dynamic, but that doesn't play well with the tracing/scripting methods (see the sketch after this list). It is not necessarily the fault of the timm module. If you have experience in this area, I'd love to hear some tips. Feel free to open a discussion.
  2. Initially, I was using both encoder and decoder from the x-transformers package but the performance was not very good. I used a pure ViT at the time (6ecc3f4). timm offered some pre-built encoders with CNN backbones which increased the performance significantly.
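
A minimal sketch of what "dynamic input size" means for export (hedged, with hypothetical file names; it assumes a single-channel encoder like the one in this repo and marks the spatial dimensions as dynamic, which is exactly where tracing tends to bake in fixed shapes):

    import torch

    encoder.eval()
    dummy = torch.randn(1, 1, 112, 464)  # any valid H x W works as the trace input
    torch.onnx.export(
        encoder,
        dummy,
        "encoder_dynamic.onnx",
        input_names=['input_image'],
        output_names=['output_context'],
        # declare batch/height/width as dynamic so one export serves many image sizes
        dynamic_axes={'input_image': {0: 'batch', 2: 'height', 3: 'width'}},
    )

Even with dynamic_axes declared, any shape-dependent Python logic inside the model (e.g. positional-embedding indices computed from h and w) gets traced with the dummy shape, which is the incompatibility described above.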

@cgwu1999 commented Apr 5, 2022

It isn't a problem with timm. In the encoder, this part:

        pos_emb_ind = repeat(torch.arange(h)*(self.width//self.patch_size-w), 'h -> (h w)', w=w)+torch.arange(h*w)
        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()

There are two problems here. First, torch.cat concatenates torch.zeros(1) (float) with pos_emb_ind (long). Second, there is a bug in arange under tracing, like this: #1708. Someone said he has fixed it (you can see here), and his branch has been merged, but I still have this problem with the latest version of PyTorch.
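
A dtype-consistent version of that snippet (a sketch of a fix for the first problem only, not the repo's actual patch; repeat is einops.repeat, as in the original snippet):

    # torch.arange returns int64, so build the prepended zero as int64 too
    # instead of concatenating a float tensor and casting afterwards
    pos_emb_ind = repeat(torch.arange(h) * (self.width // self.patch_size - w), 'h -> (h w)', w=w) + torch.arange(h * w)
    pos_emb_ind = torch.cat((torch.zeros(1, dtype=torch.long), pos_emb_ind + 1), dim=0)

The arange tracing bug referenced above (#1708) is a separate issue and is not addressed by this change.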

@tuiiitendinh commented Nov 14, 2022

> Yes, I tried to trace the model before. At the moment it seems like the timm module is not 100% compatible.
> Will look into it in the future.

> Hi, did you trace the whole model (encoder+decoder)? What problem did you run into? Maybe we can have a discussion. In addition, I have a question and hope you can share some ideas. [x_transformers] provides both an encoder and a decoder, so why did you use the encoder from [timm] and the decoder from [x_transformers]? Is there any special reason?

Hi, I've traced the model for my deployment exercise.
First, the encoder:
    encoder.eval()
    img = cv2.imread('path_to_my_example_image')  # image size is 464 x 112
    dummy_img = test_transform(image=img)['image'][:1].unsqueeze(0)  # shape is now [1, 1, 112, 464]

    with torch.no_grad():
        torch.onnx.export(
            encoder,
            dummy_img,
            f="encoder.onnx",
            opset_version=16,
            input_names=['input_image'],
            output_names=['output_context'],
            # note: dynamic_axes keys must match the names in input_names/output_names
            dynamic_axes={
                'input_image': {0: 'batch_size', 1: 'channel', 2: 'height', 3: 'width'},
                'output_context': {0: 'batch_size', 1: 'output_context'}
            },
            export_params=True,
            verbose=True
        )
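
As a quick sanity check on the exported encoder (a minimal sketch, assuming onnxruntime is installed; the feed name matches input_names above):

    import onnxruntime as ort

    sess = ort.InferenceSession("encoder.onnx")
    onnx_out = sess.run(None, {"input_image": dummy_img.numpy()})[0]
    print(onnx_out.shape)  # should match the PyTorch encoder's output shape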

Everything is ok for the encoder. But when it comes to the decoder, here's the code:

    decoder.eval()

    dummy_context = encoder(dummy_img)
    dummy_tgt_seq = torch.rand(1, 512)

    dummy_input = {
        'context': dummy_context,
        'tgt_seq': dummy_tgt_seq
    }

    with torch.no_grad():
        torch.onnx.export(
            decoder,
            args=(
                dummy_input["tgt_seq"],
                dummy_input["context"]
            ),
            f="decoder.onnx",
            opset_version=16,
            input_names=['input_seq', 'input_context'],
            output_names=['output_seq'],
            dynamic_axes={
                'input_context': {0: 'batch_size', 1: 'sequence', 2: 'encoded_context'},
                'output_seq': {0: 'batch_size', 1: 'output_seq'}
            },
            # export_params=True,
            verbose=True
        )

I got an error like this:

Traceback (most recent call last):
  File "d:\service-ml-api-server\flask_app\Im2Tex\pix2tex\convert_onnxmodel.py", line 105, in <module>
    torch.onnx.export(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\__init__.py", line 350, in export
    return utils.export(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\utils.py", line 163, in export
    _export(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\utils.py", line 1074, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\utils.py", line 727, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\utils.py", line 602, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\utils.py", line 517, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\jit\_trace.py", line 1175, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\jit\_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\jit\_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\nn\modules\module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)

TypeError: forward() takes 2 positional arguments but 3 were given

Did my approach go wrong somewhere, or do I have to do this another way? Thanks.
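
The traceback suggests torch.onnx.export is passing both tensors positionally, while the decoder's forward only accepts one positional argument and expects the encoder output as a keyword argument. One possible workaround (a sketch under that assumption, not the repo's official fix) is to export a thin wrapper that maps positional inputs onto the keyword API:

    import torch
    import torch.nn as nn

    class DecoderWrapper(nn.Module):
        """Adapts a keyword-argument decoder signature for positional ONNX export."""
        def __init__(self, decoder):
            super().__init__()
            self.decoder = decoder

        def forward(self, tgt_seq, context):
            # assumption: the wrapped decoder takes the encoder output
            # via a `context=` keyword argument
            return self.decoder(tgt_seq, context=context)

    wrapped = DecoderWrapper(decoder).eval()
    with torch.no_grad():
        torch.onnx.export(
            wrapped,
            (dummy_input["tgt_seq"], dummy_input["context"]),
            f="decoder.onnx",
            opset_version=16,
            input_names=['input_seq', 'input_context'],
            output_names=['output_seq'],
        )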

@RhettTamp

I can't convert to an ONNX model either. Has this issue been solved?

@SWHL commented Jul 13, 2023

I have converted image_resizer.pth and weights.pth to ONNX format successfully, and I am organizing the inference code. Please keep an eye on the RapidLatexOCR repo.

@tranngocduvnvp

Hi all,
I have now converted the LaTeX-OCR model to ONNX successfully.
The encoder and decoder are converted separately.
Details about the code are available via Code.
If it's useful for your work, please ⭐ my repo.

@SWHL commented Feb 27, 2024

@tranngocduvnvp Coincidentally, I also put together conversion code a while ago, but there is currently an issue where dynamic dimensions cannot be inferred.

My code repo is ConvertLaTeXOCRToONNX.

Feel free to reach out so we can discuss it together.
