@@ -30,44 +30,44 @@ auto aten_registrations = RegisterNodeEvaluators()
         // aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
         [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
           auto options = torch::TensorOptions()
-                             .dtype(c10::ScalarType(args.at(&(n->output()[1])).unwrapToInt()))
+                             .dtype(c10::ScalarType(args.at(n->output(1)).unwrapToInt()))
                              .layout(torch::kStrided)
                              .device(torch::kCUDA);
 
-          auto out_tensor = torch::zeros(args.at(&(n->input()[0])).unwrapToIntList().vec(), options);
+          auto out_tensor = torch::zeros(args.at(n->input(0)).unwrapToIntList().vec(), options);
           return out_tensor;
         }
       }).evaluator({
         c10::Symbol::fromQualString("aten::mul"),
         [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-          auto a = args.at(&(n->input()[0])).unwrapToInt();
-          auto b = args.at(&(n->input()[1])).unwrapToInt();
+          auto a = args.at(n->input(0)).unwrapToInt();
+          auto b = args.at(n->input(1)).unwrapToInt();
           return a * b;
         },
         EvalOptions().validSchemas({"aten::mul.int(int a, int b) -> (int)"})
       }).evaluator({
         c10::Symbol::fromQualString("aten::sub"),
         [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-          auto a = args.at(&(n->input()[0])).unwrapToInt();
-          auto b = args.at(&(n->input()[1])).unwrapToInt();
+          auto a = args.at(n->input(0)).unwrapToInt();
+          auto b = args.at(n->input(1)).unwrapToInt();
           return a - b;
         },
         EvalOptions().validSchemas({"aten::sub.int(int a, int b) -> (int)"})
       }).evaluator({
         c10::Symbol::fromQualString("aten::__round_to_zero_floordiv"),
         [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-          auto a = args.at(&(n->input()[0])).unwrapToInt();
-          auto b = args.at(&(n->input()[1])).unwrapToInt();
+          auto a = args.at(n->input(0)).unwrapToInt();
+          auto b = args.at(n->input(1)).unwrapToInt();
           return a / b;
         },
         EvalOptions().validSchemas({"aten::__round_to_zero_floordiv(int a, int b) -> (int)"})
       }).evaluator({
         c10::Symbol::fromQualString("aten::slice"),
         [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-          c10::List<c10::IValue> list = args.at(&(n->input()[0])).IValue()->to<c10::List<c10::IValue>>();
-          int64_t start = args.at(&(n->input()[0])).unwrapToInt();
-          int64_t end = args.at(&(n->input()[0])).unwrapToInt();
-          int64_t step = args.at(&(n->input()[0])).unwrapToInt();
+          c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
+          int64_t start = args.at(n->input(1)).unwrapToInt();
+          int64_t end = args.at(n->input(2)).unwrapToInt();
+          int64_t step = args.at(n->input(3)).unwrapToInt();
 
           const int64_t list_size = list.size();
 
@@ -96,10 +96,38 @@ auto aten_registrations = RegisterNodeEvaluators()
       }).evaluator({
         c10::Symbol::fromQualString("aten::len"),
         [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-          c10::List<c10::IValue> list = args.at(&(n->input()[0])).IValue()->to<c10::List<c10::IValue>>();
+          c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
           return static_cast<int64_t>(list.size());
         },
         EvalOptions().validSchemas({"aten::len.t(t[] a) -> (int)"})
+      }).evaluator({
+        c10::Symbol::fromQualString("aten::size"),
+        [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+          LOG_WARNING("There may be undefined behavior using dynamic shape and aten::size");
+          auto tensor_var = args.at(n->input(0));
+          if (n->inputs().size() == 1) {
+            if (tensor_var.isITensor()) {
+              auto tensor = tensor_var.ITensor();
+              return util::toVec(tensor->getDimensions());
+            } else {
+              auto tensor = tensor_var.unwrapToTensor();
+              return tensor.sizes();
+            }
+          } else {
+            auto dim = args.at(n->input(1)).unwrapToInt();
+            if (tensor_var.isITensor()) {
+              auto tensor = tensor_var.ITensor();
+              return util::toVec(tensor->getDimensions())[dim];
+            } else {
+              auto tensor = tensor_var.unwrapToTensor();
+              return tensor.sizes()[dim];
+            }
+          }
+        },
+        EvalOptions().validSchemas({
+          "aten::size(Tensor self) -> (int[])",
+          "aten::size.int(Tensor self, int dim) -> (int)"
+        })
       });
 }
 } // namespace evaluators
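For reference, every entry in this file is registered through the same RegisterNodeEvaluators().evaluator({symbol, lambda, EvalOptions}) chain shown in the diff. The sketch below is illustrative only and not part of this commit: it applies that same pattern, including the n->input(i) accessor this commit switches to, to a hypothetical aten::add entry. The kwargs, EvalOptions, and RegisterNodeEvaluators types are assumed from the surrounding code rather than verified against the full source tree.

// Illustrative sketch only, not part of this commit. It reuses the
// registration pattern of the evaluators above and assumes the kwargs,
// EvalOptions, and RegisterNodeEvaluators types visible in the diff.
auto example_registrations = RegisterNodeEvaluators().evaluator({
    c10::Symbol::fromQualString("aten::add"),
    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
      // Operands are fetched with the n->input(i) accessor introduced by this
      // commit, rather than the old &(n->input()[i]) form.
      auto a = args.at(n->input(0)).unwrapToInt();
      auto b = args.at(n->input(1)).unwrapToInt();
      return a + b;
    },
    EvalOptions().validSchemas({"aten::add.int(int a, int b) -> (int)"})});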