 #include "ATen/core/ivalue.h"
 #include "ATen/core/List.h"
 #include "ATen/core/stack.h"
+#include "c10/util/intrusive_ptr.h"
 
 #include "core/conversion/evaluators/evaluators.h"
 
@@ -16,51 +17,65 @@ namespace {
 auto prim_registrations = RegisterNodeEvaluators()
     .evaluator({
         torch::jit::prim::Constant,
-        [](const torch::jit::Node* n, const kwargs& args) -> c10::optional<torch::jit::IValue> {
+        [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
             if (n->output()->type()->kind() == at::FunctionType::Kind) {
                 return {};
             }
             return torch::jit::toIValue(n->output());
         }
     }).evaluator({
         torch::jit::prim::ListConstruct,
-        [](const torch::jit::Node* n, const kwargs& args) -> c10::optional<torch::jit::IValue> {
+        [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
             const auto num_inputs = n->inputs().size();
-            c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
-            if (torch::jit::IntType::get() == lt->getElementType()) {
-                c10::List<int64_t> list;
-                list.reserve(num_inputs);
-                for (auto in : n->inputs()) {
-                    list.emplace_back(std::move(args.at(in)->to<int64_t>()));
-                }
-                return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
-            } else if (torch::jit::FloatType::get() == lt->getElementType()) {
-                c10::List<double> list;
-                list.reserve(num_inputs);
-                for (auto in : n->inputs()) {
-                    list.emplace_back(std::move(args.at(in)->to<double>()));
+            if (constTypesOnly(args)) {
+                c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
+                if (torch::jit::IntType::get() == lt->getElementType()) {
+                    c10::List<int64_t> list;
+                    list.reserve(num_inputs);
+                    for (auto in : n->inputs()) {
+                        list.emplace_back(std::move(args.at(in).unwrapToInt()));
+                    }
+                    return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
+                } else if (torch::jit::FloatType::get() == lt->getElementType()) {
+                    c10::List<double> list;
+                    list.reserve(num_inputs);
+                    for (auto in : n->inputs()) {
+                        list.emplace_back(std::move(args.at(in).unwrapToDouble()));
+                    }
+                    return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
+                } else if (lt->getElementType() == torch::jit::BoolType::get()) {
+                    c10::List<bool> list;
+                    list.reserve(num_inputs);
+                    for (auto in : n->inputs()) {
+                        list.emplace_back(std::move(args.at(in).unwrapToBool()));
+                    }
+                    return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
+                } else if (lt->getElementType()->isSubtypeOf(torch::jit::TensorType::get())) {
+                    c10::List<at::Tensor> list;
+                    list.reserve(num_inputs);
+                    for (auto in : n->inputs()) {
+                        if (args.at(in).isIValue()) {
+                            list.emplace_back(std::move(args.at(in).unwrapToTensor()));
+                        }
+                    }
+                    return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
+                } else {
+                    c10::TypePtr elementType = lt->getElementType();
+                    auto list = c10::impl::GenericList(elementType);
+                    list.reserve(num_inputs);
+                    for (auto in : n->inputs()) {
+                        list.emplace_back(std::move(*(args.at(in).IValue())));
+                    }
+                    return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
                 }
-                return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
-            } else if (lt->getElementType() == torch::jit::BoolType::get()) {
-                c10::List<bool> list;
-                list.reserve(num_inputs);
-                for (auto in : n->inputs()) {
-                    list.emplace_back(std::move(args.at(in)->to<bool>()));
-                }
-                return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
-            } else if (lt->getElementType()->isSubtypeOf(torch::jit::TensorType::get())) {
-                c10::List<at::Tensor> list;
-                list.reserve(num_inputs);
-                for (auto in : n->inputs()) {
-                    list.emplace_back(std::move(args.at(in)->toTensor()));
-                }
-                return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
             } else {
+                c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
                 c10::TypePtr elementType = lt->getElementType();
                 auto list = c10::impl::GenericList(elementType);
                 list.reserve(num_inputs);
                 for (auto in : n->inputs()) {
-                    list.emplace_back(std::move(*(args.at(in))));
+                    auto x = torch::make_custom_class<TensorContainer>(reinterpret_cast<int64_t>(args.at(in).ITensor()));
+                    list.emplace_back(std::move(x));
                 }
                 return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
             }
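
The constant branch added above dispatches on the list's declared element type, builds a concretely typed `c10::List<T>` for int/float/bool/Tensor elements, and falls back to a type-erased `c10::impl::GenericList` for anything else. The following is a minimal, self-contained sketch of that dispatch pattern using only plain ATen/c10 APIs; `buildTypedList` is a hypothetical helper for illustration, not part of the evaluator.

```cpp
#include <ATen/core/ivalue.h>
#include <ATen/core/List.h>
#include <ATen/core/jit_type.h>
#include <iostream>
#include <vector>

// Build a typed c10::List from plain IValues, mirroring the constTypesOnly
// branch: pick a concrete element type when the ListType names one we know,
// otherwise fall back to a type-erased GenericList.
c10::IValue buildTypedList(const c10::ListTypePtr& lt, const std::vector<c10::IValue>& elems) {
    if (lt->getElementType() == c10::IntType::get()) {
        c10::List<int64_t> list;
        list.reserve(elems.size());
        for (const auto& e : elems) {
            list.emplace_back(e.toInt());
        }
        return c10::IValue(list);
    } else if (lt->getElementType() == c10::FloatType::get()) {
        c10::List<double> list;
        list.reserve(elems.size());
        for (const auto& e : elems) {
            list.emplace_back(e.toDouble());
        }
        return c10::IValue(list);
    } else {
        // No dedicated specialization handled here: store the IValues in a
        // type-erased GenericList tagged with the element type instead.
        auto list = c10::impl::GenericList(lt->getElementType());
        list.reserve(elems.size());
        for (const auto& e : elems) {
            list.emplace_back(e);
        }
        return c10::IValue(list);
    }
}

int main() {
    auto iv = buildTypedList(c10::ListType::ofInts(),
                             {c10::IValue(int64_t(1)), c10::IValue(int64_t(2))});
    std::cout << iv.toIntList().size() << std::endl;  // prints 2
    return 0;
}
```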
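The non-constant branch instead wraps each raw ITensor pointer (from `args.at(in).ITensor()`) in the `TensorContainer` custom class so that the pointer can travel inside an `IValue` list. The sketch below shows the general technique with the public `torch::class_` / `torch::make_custom_class` API, assuming an illustrative `PointerContainer` class; it is not the library's actual `TensorContainer` definition.

```cpp
#include <torch/custom_class.h>
#include <ATen/core/List.h>
#include <ATen/core/jit_type.h>
#include <cstdint>

// Holds an arbitrary pointer encoded as int64_t so it can live inside IValues.
struct PointerContainer : torch::CustomClassHolder {
  int64_t ptr_;
  explicit PointerContainer(int64_t ptr) : ptr_(ptr) {}
  int64_t ptr() const { return ptr_; }
};

// Custom classes must be registered before make_custom_class can create them.
static auto registration =
    torch::class_<PointerContainer>("example", "PointerContainer")
        .def(torch::init<int64_t>());

c10::IValue wrapPointer(void* raw) {
  // Same trick as the evaluator: reinterpret the pointer as an integer and
  // wrap it in a registered custom class, producing an ordinary IValue.
  return torch::make_custom_class<PointerContainer>(
      reinterpret_cast<int64_t>(raw));
}

int main() {
  int dummy = 42;
  c10::IValue wrapped = wrapPointer(&dummy);
  // The wrapped pointer can now be stored in a GenericList like any IValue;
  // in the evaluator, converters later recover the ITensor from the container.
  auto list = c10::impl::GenericList(c10::AnyType::get());
  list.emplace_back(std::move(wrapped));
  return list.size() == 1 ? 0 : 1;
}
```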