Skip to content

Commit d351717

Browse files
committed
feat(//core/conversion/evaluators): adding support for common evaluation
ops Adds support for: - prim::shape - aten::neg - aten::add - aten::__getitem__ - aten::append Fixes: - prim::min Removes: - prim::Loop Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 0014b84 commit d351717

File tree

2 files changed

+61
-9
lines changed

2 files changed

+61
-9
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ auto aten_registrations = RegisterNodeEvaluators()
3737
auto out_tensor = torch::zeros(args.at(n->input(0)).unwrapToIntList().vec(), options);
3838
return out_tensor;
3939
}
40+
}).evaluator({
41+
c10::Symbol::fromQualString("aten::add"),
42+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
43+
auto a = args.at(n->input(0)).unwrapToInt();
44+
auto b = args.at(n->input(1)).unwrapToInt();
45+
return a + b;
46+
},
47+
EvalOptions().validSchemas({"aten::add.int(int a, int b) -> (int)"})
4048
}).evaluator({
4149
c10::Symbol::fromQualString("aten::mul"),
4250
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
@@ -128,6 +136,42 @@ auto aten_registrations = RegisterNodeEvaluators()
128136
"aten::size(Tensor self) -> (int[])",
129137
"aten::size.int(Tensor self, int dim) -> (int)"
130138
})
139+
}).evaluator({
140+
c10::Symbol::fromQualString("aten::__getitem__"),
141+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
142+
auto list = args.at(n->input(0)).unwrapToIntList();
143+
auto idx = args.at(n->input(1)).unwrapToInt();
144+
145+
const int64_t list_size = list.size();
146+
const int64_t normalized_idx = normalizeIndex(idx, list_size);
147+
TRTORCH_CHECK(normalized_idx >= 0 || normalized_idx < list_size, "List index out of range (aten::__getitem__)");
148+
return list.get(normalized_idx);
149+
},
150+
EvalOptions().validSchemas({
151+
"aten::__getitem__.t(t[](a) list, int idx) -> (t(*))",
152+
})
153+
}).evaluator({
154+
c10::Symbol::fromQualString("aten::append"),
155+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
156+
auto list = args.at(n->input(0)).unwrapToIntList();
157+
auto el = args.at(n->input(1)).unwrapToInt();
158+
159+
list.push_back(std::move(el));
160+
return list;
161+
},
162+
EvalOptions().validSchemas({
163+
"aten::append.t(t[](a!) self, t(c -> *) el) -> (t[](a!))",
164+
})
165+
}).evaluator({
166+
c10::Symbol::fromQualString("aten::neg"),
167+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
168+
auto el = args.at(n->input(1)).unwrapToInt();
169+
170+
return el * -1;
171+
},
172+
EvalOptions().validSchemas({
173+
"aten::neg.int(int a) -> (int)",
174+
})
131175
});
132176
}
133177
} // namespace evaluators

core/conversion/evaluators/prim.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,6 @@ auto prim_registrations = RegisterNodeEvaluators()
9494
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
9595
}
9696
}
97-
}).evaluator({
98-
torch::jit::prim::Loop,
99-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
100-
std::cout << *n << std::endl;
101-
102-
return {};
103-
},
104-
EvalOptions().blacklistOutputTypes({c10::TensorType::get()})
10597
}).evaluator({
10698
c10::Symbol::fromQualString("prim::min"),
10799
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
@@ -110,13 +102,29 @@ auto prim_registrations = RegisterNodeEvaluators()
110102

111103
for (size_t i = 0; i < a.size(); i++) {
112104
if (a[i] < min) {
113-
min = i;
105+
min = a[i];
114106
}
115107
}
116108

117109
return min;
118110
},
119111
EvalOptions().validSchemas({"prim::min.self_int(int[] self) -> (int)"})
112+
}).evaluator({
113+
c10::Symbol::fromQualString("prim::shape"),
114+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
115+
LOG_WARNING("There may be undefined behavior using dynamic shape and prim::shape");
116+
auto tensor_var = args.at(n->input(0));
117+
if (tensor_var.isITensor()) {
118+
auto tensor = tensor_var.ITensor();
119+
return util::toVec(tensor->getDimensions());
120+
} else {
121+
auto tensor = tensor_var.unwrapToTensor();
122+
return tensor.sizes();
123+
}
124+
},
125+
EvalOptions().validSchemas({
126+
"prim::shape(Tensor a) -> (int[])"
127+
})
120128
});
121129
}
122130
} // namespace evaluators

0 commit comments

Comments
 (0)