Skip to content

Commit

Permalink
feat(prim::NumToTensor): Implement evaluator for NumToTensor
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Jun 1, 2020
1 parent 17099fa commit 60df888
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions core/conversion/evaluators/prim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "ATen/core/List.h"
#include "ATen/core/stack.h"
#include "c10/util/intrusive_ptr.h"
#include "torch/torch.h"

#include "core/conversion/evaluators/evaluators.h"

Expand All @@ -23,6 +24,11 @@ auto prim_registrations = RegisterNodeEvaluators()
}
return torch::jit::toIValue(n->output());
}
}).evaluator({
torch::jit::prim::NumToTensor,
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
return at::scalar_to_tensor(args.at(&(n->output()[0])).IValue()->toScalar());
}
}).evaluator({
torch::jit::prim::ListConstruct,
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
Expand Down

0 comments on commit 60df888

Please sign in to comment.