[PFTO] Support ONNX export with list inputs #572

Open
xuzijian629 opened this issue Jul 13, 2022 · 4 comments

@xuzijian629
Contributor

xuzijian629 commented Jul 13, 2022

To support list inputs and models with list operators in ppe.onnx.export_testcase, there are two topics to address.

How to handle list inputs

Roughly, we have two choices:

  • Unroll lists and assume that ppe.onnx.export always exports an ONNX model and test case whose inputs are all Tensors.
  • Allow list inputs (i.e., Sequence-typed inputs in ONNX).

I think we should go with the first option because that is what torch.onnx.export does.
In fact, torch.onnx.export accepts calls like torch.onnx.export(model, (list_arg,), input_names=["a", "b", "c"], ...), where each element of the list gets its own input name.

This is because torch.onnx.export calls torch.jit._get_trace_graph, which automatically unrolls list inputs. However, this API is internal, and with the public torch.jit.trace, list inputs are kept as lists. (For more detail, see #572 (comment)).

Since PFTO uses torch.jit.trace, one viable way is to create a wrapper model that accepts unrolled inputs.
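The wrapper idea can be sketched in a framework-agnostic way. The snippet below is only an illustration under assumptions: the helper names (flatten_inputs, unflatten_inputs, make_unrolled_wrapper) are hypothetical and not part of PFTO or PyTorch, and a plain Python callable stands in for the model. In PFTO the wrapper would presumably be a torch.nn.Module whose forward takes the flat tensors, so that torch.jit.trace only ever sees Tensor inputs.

```python
from typing import Any, Callable, List, Tuple


def flatten_inputs(args: Any) -> Tuple[List[Any], Any]:
    """Unroll nested lists/tuples into a flat list of leaves plus a spec.

    The spec records the container structure so it can be rebuilt later.
    (Hypothetical helper, for illustration only.)
    """
    flat: List[Any] = []

    def build_spec(obj: Any) -> Any:
        if isinstance(obj, (list, tuple)):
            # Remember the container type and recurse into its children.
            return (type(obj), [build_spec(o) for o in obj])
        flat.append(obj)
        return None  # None marks a leaf position

    spec = build_spec(args)
    return flat, spec


def unflatten_inputs(flat: Any, spec: Any) -> Any:
    """Rebuild the original nested structure from flat leaves and a spec."""
    leaves = iter(flat)

    def rebuild(s: Any) -> Any:
        if s is None:
            return next(leaves)
        container, children = s
        return container(rebuild(c) for c in children)

    return rebuild(spec)


def make_unrolled_wrapper(fn: Callable[..., Any], spec: Any) -> Callable[..., Any]:
    """Return a callable that takes unrolled inputs and calls fn with lists."""

    def wrapper(*flat: Any) -> Any:
        args = unflatten_inputs(flat, spec)
        return fn(*args)

    return wrapper


# Stand-in for a model whose first argument is a list.
def model(xs, y):
    return sum(xs) + y


flat, spec = flatten_inputs(([1, 2, 3], 10))
wrapped = make_unrolled_wrapper(model, spec)
result = wrapped(*flat)
```

With this shape, the exported graph would have one input per list element, which matches the first choice above (all inputs are Tensors).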

Add support for more list operators

We have to implement custom symbolic execution for prim::ListUnpack, prim::TupleConstruct, etc. (prim::ListConstruct, which is the most commonly used one, is already implemented).

I haven't fully understood the symbolic execution of prim::ListConstruct-like nodes in torch.onnx.export.
For future investigation, here are some notes:

  • Symbolic functions for list ops were introduced in symbolic_opset9.py as of torch v1.12.0 (although export also works with v1.11.0 and earlier).
  • torch.onnx.export runs the onnx_peephole optimization, which includes eraseListConstruct and eraseListUnpack, after _C._jit_pass_onnx(graph, operator_export_type). In my understanding, prim::ListConstruct and similar ops are replaced with ONNX operators by symbolic execution. Why do they remain after _C._jit_pass_onnx?

It seems to me that the handling of prim::ListConstruct-like ops has not been stable. Maybe we should wait a little bit to stabilize our implementation.

@xuzijian629
Contributor Author

We can add a custom handler for ListUnpack here, in the same way as for ListConstruct:

handler: Dict[str, Callable] = {
    "prim::Constant": handle_constant,
    "prim::GetAttr": handle_getattr,
    "prim::ListConstruct": handle_list_construct,
    "prim::If": handle_if,
}
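As a rough sketch of what registering a ListUnpack handler next to ListConstruct could look like, the toy example below uses a dispatch table keyed by node kind. Everything here is an assumption made for illustration: ToyNode, the (env, node) handler signature, and toy_handler are hypothetical and do not reflect PFTO's real handler interface, which operates on TorchScript graph nodes.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass
class ToyNode:
    """Hypothetical stand-in for a graph node (not a real torch._C.Node)."""
    kind: str
    inputs: List[str]
    outputs: List[str]


def handle_list_construct(env: Dict[str, Any], node: ToyNode) -> None:
    # Bind the single output name to a list of the input values.
    env[node.outputs[0]] = [env[name] for name in node.inputs]


def handle_list_unpack(env: Dict[str, Any], node: ToyNode) -> None:
    # Bind each output name to the corresponding element of the input list.
    values = env[node.inputs[0]]
    if len(values) != len(node.outputs):
        raise ValueError("ListUnpack output count does not match list length")
    for name, value in zip(node.outputs, values):
        env[name] = value


# Dispatch table in the same spirit as the handler dict above.
toy_handler: Dict[str, Callable[[Dict[str, Any], ToyNode], None]] = {
    "prim::ListConstruct": handle_list_construct,
    "prim::ListUnpack": handle_list_unpack,
}


def run(nodes: List[ToyNode], env: Dict[str, Any]) -> Dict[str, Any]:
    """Apply the registered handler for each node in order."""
    for node in nodes:
        toy_handler[node.kind](env, node)
    return env


env = run(
    [
        ToyNode("prim::ListConstruct", ["a", "b"], ["lst"]),
        ToyNode("prim::ListUnpack", ["lst"], ["x", "y"]),
    ],
    {"a": 1, "b": 2},
)
```

Because the list length is known at trace time, the unpack handler can resolve each output statically, which is why a construct followed by an unpack needs no Sequence op in this toy model.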

@xuzijian629
Contributor Author

xuzijian629 commented Jul 14, 2022

PFTO seems to enable onnx_peephole by default, so the eraseListConstruct and eraseListUnpack peephole optimizations also run in PFTO. However, the onnx_peephole optimization currently runs after run_symbolic_function (in generate_onnx_node).

@xuzijian629
Contributor Author

torch.onnx.export seems to automatically unroll list inputs.

This is done in torch.jit._get_trace_graph
https://github.com/pytorch/pytorch/blob/05ce013338b3882136eea394c37c57e29e43df1a/torch/jit/_trace.py#L95

This API is meant to be internal, and torch.jit.trace is the recommended public API. However, torch.jit.trace doesn't unroll list inputs.

@xuzijian629
Contributor Author

xuzijian629 commented Jul 19, 2022

It seems torch.onnx.export exports sequence inputs when the model is scripted (since script modules don't know the number of tensors in a list).
So, sequence inputs are essential for scripted models.
