Skip to content

Commit

Permalink
fix(//core/conversion/evaluators): Change how schemas are handled
Browse files Browse the repository at this point in the history
in aten::range evaluator

Signed-off-by: Naren Dasan <narens@nvidia.com>
Signed-off-by: Naren Dasan <naren@narendasan.com>
  • Loading branch information
narendasan committed Jan 31, 2022
1 parent 3b1ce7c commit 20e5d41
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,22 +620,19 @@ auto aten_registrations TORCHTRT_UNUSED =
{"aten::tensor(t[] data, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor)"})})
.evaluator({c10::Symbol::fromQualString("aten::arange"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
int input_size = n->inputs().size();
int scalar_count = 0;
for (int i = 0; i < input_size; i++) {
if (args.at(n->input(i)).IValue()->isScalar()) {
scalar_count += 1;
}
}
if (scalar_count == 1) {
auto schema = n->maybeSchema();
TORCHTRT_CHECK(schema, "Unable to get schema for node: " << *n);
auto name = schema->operator_name();

if (c10::toString(name) == "aten::arange") {
if (args.at(n->input(0)).IValue()->isInt()) {
int end_scalar = args.at(n->input(0)).unwrapToInt();
return torch::arange(end_scalar);
} else if (args.at(n->input(0)).IValue()->isDouble()) {
float end_scalar = args.at(n->input(0)).unwrapToScalar().to<float>();
return torch::arange(end_scalar);
}
} else if (scalar_count == 2) {
} else if (c10::toString(name) == "aten::arange.start") {
if (args.at(n->input(0)).IValue()->isDouble() || args.at(n->input(1)).IValue()->isDouble()) {
float start_scalar = args.at(n->input(0)).unwrapToScalar().to<float>();
float end_scalar = args.at(n->input(1)).unwrapToScalar().to<float>();
Expand All @@ -645,7 +642,7 @@ auto aten_registrations TORCHTRT_UNUSED =
int end_scalar = args.at(n->input(1)).unwrapToInt();
return torch::arange(start_scalar, end_scalar);
}
} else if (scalar_count == 3) {
} else if (c10::toString(name) == "aten::arange.start_step") {
if (args.at(n->input(0)).IValue()->isDouble() || args.at(n->input(1)).IValue()->isDouble() ||
args.at(n->input(2)).IValue()->isDouble()) {
float start_scalar = args.at(n->input(0)).unwrapToScalar().to<float>();
Expand All @@ -659,8 +656,7 @@ auto aten_registrations TORCHTRT_UNUSED =
return torch::arange(start_scalar, end_scalar, step_scalar);
}
} else {
TORCHTRT_THROW_ERROR(
"Invalid input argument size for aten::arange, input argument size: " << input_size);
TORCHTRT_THROW_ERROR("Unsupported aten::arange variant: " << name);
}
return {};
},
Expand Down

0 comments on commit 20e5d41

Please sign in to comment.