diff --git a/core/conversion/evaluators/BUILD b/core/conversion/evaluators/BUILD
index cbc5318342..038408771d 100644
--- a/core/conversion/evaluators/BUILD
+++ b/core/conversion/evaluators/BUILD
@@ -16,7 +16,9 @@ cc_library(
         "NodeEvaluatorRegistry.cpp",
         "prim.cpp",
         "aten.cpp",
-        "eval_macros.h"
+        "eval_macros.h",
+        "eval_util.h",
+        "eval_util.cpp"
     ],
     deps = [
         "//core/util:prelude",
diff --git a/core/conversion/evaluators/eval_util.cpp b/core/conversion/evaluators/eval_util.cpp
new file mode 100644
index 0000000000..486585f805
--- /dev/null
+++ b/core/conversion/evaluators/eval_util.cpp
@@ -0,0 +1,105 @@
+#include "ATen/core/ivalue.h"
+#include "ATen/core/List.h"
+#include "core/util/prelude.h"
+#include "ATen/core/functional.h"
+
+namespace trtorch {
+namespace core {
+namespace conversion {
+namespace evaluators {
+
+// TODO: Switch back to PyTorch canonical implementation
+c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
+  if (v->node()->kind() != torch::jit::prim::Constant || v->type()->cast<c10::FunctionType>()) {
+    return c10::nullopt;
+  }
+  const torch::jit::Node* node = v->node();
+  const c10::TypePtr& type = v->type();
+  if (type->isSubtypeOf(c10::TensorType::get())) {
+    return node->t(c10::attr::value);
+  } else if (type->isSubtypeOf(c10::BoolType::get())) {
+    return (bool)node->i(c10::attr::value);
+  } else if (
+      type->isSubtypeOf(c10::NumberType::get()) &&
+      node->kindOf(c10::attr::value) == torch::jit::AttributeKind::i) {
+    return node->i(c10::attr::value);
+  } else if (
+      type->isSubtypeOf(c10::NumberType::get()) &&
+      node->kindOf(c10::attr::value) == torch::jit::AttributeKind::f) {
+    return node->f(c10::attr::value);
+  } else if (type->isSubtypeOf(c10::ListType::ofInts())) {
+    try {
+      const auto& is = node->is(c10::attr::value);
+      return is;
+    } catch (const std::exception& ex) {
+      const auto& ival = node->ival(c10::attr::value);
+      return ival;
+    }
+  } else if (type->isSubtypeOf(c10::ListType::ofFloats())) {
+    try {
+      const auto& fs = node->fs(c10::attr::value);
+      return fs;
+    } catch (const std::exception& ex) {
+      const auto& ival = node->ival(c10::attr::value);
+      return ival;
+    }
+  } else if (type->isSubtypeOf(c10::ListType::ofBools())) {
+    const auto bs = c10::fmap<bool>(node->is(c10::attr::value));
+    return bs;
+  } else if (type->isSubtypeOf(c10::ListType::ofTensors())) {
+    try {
+      const auto& ts = node->ts(c10::attr::value);
+      return ts;
+    } catch (const std::exception& ex) {
+      const auto& ival = node->ival(c10::attr::value);
+      return ival;
+    }
+  } else if (type->isSubtypeOf(c10::ListType::ofStrings())) {
+    try {
+      const auto& ss = node->ss(c10::attr::value);
+      auto vals = c10::impl::GenericList(c10::StringType::get());
+      for (const auto& str : ss) {
+        vals.push_back(str);
+      }
+      return vals;
+    } catch (const std::exception& ex) {
+      const auto& ival = node->ival(c10::attr::value);
+      return ival;
+    }
+  } else if (
+      type->cast<c10::ListType>() &&
+      node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) {
+    const auto& list = node->ival(c10::attr::value);
+    TRTORCH_ASSERT(list.isList(), "Is not a list");
+    return list;
+  } else if (
+      type->cast<c10::DictType>() &&
+      node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) {
+    const auto& dict = node->ival(c10::attr::value);
+    TRTORCH_ASSERT(dict.isGenericDict(), "Is not a dict");
+    return dict;
+  } else if (
+      type->cast<c10::TupleType>() &&
+      node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) {
+    const auto& tup = node->ival(c10::attr::value);
+    TRTORCH_ASSERT(tup.isTuple(), "Is not a tuple");
+    return tup;
+  } else if (type == c10::StringType::get()) {
+    const auto& s = node->s(c10::attr::value);
+    return s;
+  } else if (type == c10::DeviceObjType::get()) {
+    auto d = c10::Device(node->s(c10::attr::value));
+    return d;
+  } else if (node->mustBeNone()) {
+    return torch::jit::IValue();
+  } else {
+    std::stringstream ss;
+    ss << "constant literal not supported for: " << type->str();
+    throw std::runtime_error(ss.str());
+  }
+}
+
+} // namespace evaluators
+} // namespace conversion
+} // namespace core
+} // namespace trtorch
diff --git a/core/conversion/evaluators/eval_util.h b/core/conversion/evaluators/eval_util.h
new file mode 100644
index 0000000000..1e31ddfe46
--- /dev/null
+++ b/core/conversion/evaluators/eval_util.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#include "torch/csrc/jit/ir/ir.h"
+
+namespace trtorch {
+namespace core {
+namespace conversion {
+namespace evaluators {
+
+c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v);
+
+} // namespace evaluators
+} // namespace conversion
+} // namespace core
+} // namespace trtorch
\ No newline at end of file
diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp
index d40a33f4e7..385f6e0345 100644
--- a/core/conversion/evaluators/prim.cpp
+++ b/core/conversion/evaluators/prim.cpp
@@ -1,7 +1,7 @@
 #include <limits>
 
 #include "torch/csrc/jit/ir/ir.h"
-#include "torch/csrc/jit/ir/constants.h"
+//#include "torch/csrc/jit/ir/constants.h"
 #include "ATen/core/functional.h"
 #include "ATen/core/ivalue.h"
 #include "ATen/core/List.h"
@@ -11,6 +11,7 @@
 
 #include "core/conversion/evaluators/evaluators.h"
 #include "core/conversion/evaluators/eval_macros.h"
+#include "core/conversion/evaluators/eval_util.h"
 
 namespace trtorch {
 namespace core {
@@ -25,7 +26,7 @@ auto prim_registrations = RegisterNodeEvaluators()
             if (n->output()->type()->kind() == at::FunctionType::Kind) {
                 return {};
             }
-            return torch::jit::toIValue(n->output());
+            return evaluators::toIValue(n->output());
         }
     }).evaluator({
         torch::jit::prim::NumToTensor,
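
For context, the new `evaluators::toIValue` helper is a local copy of PyTorch's constant-folding utility: it only resolves `prim::Constant` outputs into concrete `IValue`s, which is what the `prim::Constant` evaluator in `prim.cpp` now relies on. Below is a minimal sketch (not part of the patch) of how the helper could be exercised on a hand-built graph; the `example()` wrapper and the IR snippet are illustrative assumptions, not code from the repository.

```cpp
#include <memory>

#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"

#include "core/conversion/evaluators/eval_util.h"

void example() {
  // Build a tiny TorchScript graph whose only output is a prim::Constant.
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph():
      %0 : int = prim::Constant[value=3]()
      return (%0))IR", g.get());

  // Fold the constant Value into an IValue, mirroring what the
  // prim::Constant evaluator in prim.cpp now does for each constant node.
  auto ival = trtorch::core::conversion::evaluators::toIValue(g->outputs()[0]);
  if (ival.has_value() && ival->isInt()) {
    int64_t val = ival->toInt(); // val == 3
    (void)val;
  }
}
```

Non-constant values (or function-typed constants) return `c10::nullopt`, which is why the evaluator in `prim.cpp` still short-circuits on `at::FunctionType::Kind` before calling the helper.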