Skip to content

Commit

Permalink
[onnx] Import onnx constants as onnx.Constant instead of literals (
Browse files Browse the repository at this point in the history
…llvm#2831)

To handle the conversion from raw bytes to `DenseElementsAttr` we need
to handle the endianness conversion during `torch-onnx-to-torch`.
Therefore when importing `onnx.Constant` it is better to represent using
the `onnx` constant operation so that only one location requires the
endianness correction.
  • Loading branch information
rsuderman authored Jan 31, 2024
1 parent 3500523 commit 54e2587
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/torch_mlir/extras/onnx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,14 @@ def import_initializer(self, initializer: onnx.TensorProto, extern_name: str = N
with InsertionPoint(self._b), Location.name(iname):
value_attr = self._cc.tensor_proto_to_attr(initializer)
vtensor_type = self._cc.tensor_proto_to_type(initializer)
attrs = {
"name": StringAttr.get(f"onnx.Constant"),
"torch.onnx.value": value_attr,
}
literal_op = Operation.create(
name="torch.vtensor.literal",
name="torch.operator",
results=[vtensor_type],
attributes={"value": value_attr},
attributes=attrs,
)
self._nv_map[iname] = literal_op.result
return literal_op.result
Expand Down

0 comments on commit 54e2587

Please sign in to comment.