diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 080046bfd990..4ecac009041d 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1007,6 +1007,13 @@ def _impl(inputs, input_types): return arr return _impl + +def _tensortonum(): + def _impl(inputs, input_types): + return inputs[0] + return _impl + + def _view(): def _impl(inputs, input_types): data = inputs[0] @@ -1699,6 +1706,7 @@ def _get_convert_map(prelude): "aten::expand" : _expand(), "aten::Int" : _int(), "prim::NumToTensor" : _numtotensor(), + "prim::ImplicitTensorToNum" : _tensortonum(), "aten::constant_pad_nd" : _pad(), "aten::permute" : _transpose(prelude), "aten::sum" : _reduce("sum"),