From ccad99630b0e344cb92d32e5c8baaa7a5091f2b6 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Thu, 29 Sep 2022 11:34:52 -0700
Subject: [PATCH 01/10] fix: Implement duality support for evaluators

Signed-off-by: Dheeraj Peri
---
 core/conversion/conversion.cpp | 2 +-
 .../evaluators/NodeEvaluatorRegistry.cpp | 4 +-
 core/conversion/evaluators/aten.cpp | 170 +-
 core/conversion/evaluators/eval_macros.h | 6 +-
 core/conversion/evaluators/evaluators.h | 6 +-
 core/conversion/evaluators/prim.cpp | 40 +-
 .../conversion/converters/test_select.cpp | 2092 +++++++++--------
 7 files changed, 1166 insertions(+), 1154 deletions(-)

diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp
index 5f4b20e1b3..43f4f604c9 100644
--- a/core/conversion/conversion.cpp
+++ b/core/conversion/conversion.cpp
@@ -68,7 +68,7 @@ c10::optional EvaluateNode(ConversionCtx* ctx, const torch::
 return {};
 }
 }
- auto eval = evaluators::EvalNode(n, eval_args);
+ auto eval = evaluators::EvalNode(ctx, n, eval_args);
 return eval;
 }

diff --git a/core/conversion/evaluators/NodeEvaluatorRegistry.cpp b/core/conversion/evaluators/NodeEvaluatorRegistry.cpp
index 053e08a84e..36a2ff80bf 100644
--- a/core/conversion/evaluators/NodeEvaluatorRegistry.cpp
+++ b/core/conversion/evaluators/NodeEvaluatorRegistry.cpp
@@ -114,9 +114,9 @@ std::vector getEvaluatorList() {
 return get_evaluator_registry().GetRegisteredEvaluatorList();
 }

-c10::optional EvalNode(const torch::jit::Node* n, kwargs& args) {
+c10::optional EvalNode(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
 auto evaluator = get_evaluator_registry().GetEvaluator(n);
- return
evaluator(n, args); + return evaluator(ctx, n, args); } void register_node_evaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) { diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index b24222be26..2726db2701 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -130,7 +130,7 @@ auto aten_registrations TORCHTRT_UNUSED = {c10::Symbol::fromQualString("aten::zeros"), // aten::zeros(int[] size, *, int? dtype=None, int? layout=None, // Device? device=None, bool? pin_memory=None) -> (Tensor) - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA); // Input 1 here is the dtype @@ -145,7 +145,7 @@ auto aten_registrations TORCHTRT_UNUSED = {c10::Symbol::fromQualString("aten::ones"), // aten::ones(int[] size, *, int? dtype=None, int? layout=None, // Device? device=None, bool? pin_memory=None) -> (Tensor) - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA); // Input 1 here is the dtype @@ -160,7 +160,7 @@ auto aten_registrations TORCHTRT_UNUSED = {c10::Symbol::fromQualString("aten::full"), // aten::full(int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None, // Device? device=None, bool? pin_memory=None) -> (Tensor) - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA); // Input 2 here is the dtype @@ -174,7 +174,7 @@ auto aten_registrations TORCHTRT_UNUSED = }}) .evaluator( {c10::Symbol::fromQualString("aten::slice"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { c10::List list = args.at(n->input(0)).IValue()->to>(); int64_t start = 0; @@ -210,64 +210,64 @@ auto aten_registrations TORCHTRT_UNUSED = {"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])"})}) .evaluator( {c10::Symbol::fromQualString("aten::len"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { c10::List list = args.at(n->input(0)).IValue()->to>(); return static_cast(list.size()); }, EvalOptions().validSchemas({"aten::len.t(t[] a) -> (int)"})}) - .evaluator( - {c10::Symbol::fromQualString("aten::size"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { - LOG_WARNING("There may be undefined behavior using dynamic shape and aten::size"); - auto tensor_var = args.at(n->input(0)); - if (n->inputs().size() == 1) { - if (tensor_var.isITensor()) { - auto tensor = tensor_var.ITensor(); - return util::toVec(tensor->getDimensions()); - } else if (tensor_var.IValue()->isTensor()) { - auto tensor = tensor_var.unwrapToTensor(); - return tensor.sizes(); - } else if (tensor_var.IValue()->isCustomClass()) { - auto tensor = tensor_var.IValue()->toCustomClass()->tensor(); - return util::toVec(tensor->getDimensions()); - } else { - TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. 
Found: " << tensor_var.IValue()->type()); - } - } else { - auto dim = args.at(n->input(1)).unwrapToInt(); - if (tensor_var.isITensor()) { - auto tensor = tensor_var.ITensor(); - auto dims = util::toVec(tensor->getDimensions()); - auto nbDims = tensor->getDimensions().nbDims; - if (dim < 0) { - dim += nbDims; - } - return dims[dim]; - } else if (tensor_var.IValue()->isTensor()) { - auto tensor = tensor_var.unwrapToTensor(); - auto nbDims = tensor.sizes().size(); - if (dim < 0) { - dim += nbDims; - } - return tensor.sizes()[dim]; - } else if (tensor_var.IValue()->isCustomClass()) { - auto tensor = tensor_var.IValue()->toCustomClass()->tensor(); - auto dims = util::toVec(tensor->getDimensions()); - auto nbDims = tensor->getDimensions().nbDims; - if (dim < 0) { - dim += nbDims; - } - return dims[dim]; - } else { - TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type()); - } - } - }, - EvalOptions().validSchemas( - {"aten::size(Tensor self) -> (int[])", "aten::size.int(Tensor self, int dim) -> (int)"})}) + // .evaluator( + // {c10::Symbol::fromQualString("aten::size"), + // [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + // LOG_WARNING("There may be undefined behavior using dynamic shape and aten::size"); + // auto tensor_var = args.at(n->input(0)); + // if (n->inputs().size() == 1) { + // if (tensor_var.isITensor()) { + // auto tensor = tensor_var.ITensor(); + // return util::toVec(tensor->getDimensions()); + // } else if (tensor_var.IValue()->isTensor()) { + // auto tensor = tensor_var.unwrapToTensor(); + // return tensor.sizes(); + // } else if (tensor_var.IValue()->isCustomClass()) { + // auto tensor = tensor_var.IValue()->toCustomClass()->tensor(); + // return util::toVec(tensor->getDimensions()); + // } else { + // TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type()); + // } + // } else { + // auto dim = args.at(n->input(1)).unwrapToInt(); + // if (tensor_var.isITensor()) { + // auto tensor = tensor_var.ITensor(); + // auto dims = util::toVec(tensor->getDimensions()); + // auto nbDims = tensor->getDimensions().nbDims; + // if (dim < 0) { + // dim += nbDims; + // } + // return dims[dim]; + // } else if (tensor_var.IValue()->isTensor()) { + // auto tensor = tensor_var.unwrapToTensor(); + // auto nbDims = tensor.sizes().size(); + // if (dim < 0) { + // dim += nbDims; + // } + // return tensor.sizes()[dim]; + // } else if (tensor_var.IValue()->isCustomClass()) { + // auto tensor = tensor_var.IValue()->toCustomClass()->tensor(); + // auto dims = util::toVec(tensor->getDimensions()); + // auto nbDims = tensor->getDimensions().nbDims; + // if (dim < 0) { + // dim += nbDims; + // } + // return dims[dim]; + // } else { + // TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. 
Found: " << tensor_var.IValue()->type()); + // } + // } + // }, + // EvalOptions().validSchemas( + // {"aten::size(Tensor self) -> (int[])", "aten::size.int(Tensor self, int dim) -> (int)"})}) .evaluator( {c10::Symbol::fromQualString("aten::__getitem__"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto list = args.at(n->input(0)).IValue()->to>(); auto idx = args.at(n->input(1)).unwrapToInt(); @@ -282,7 +282,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::append"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto list = args.at(n->input(0)).IValue()->to>(); if (args.at(n->input(1)).isITensor()) { @@ -302,7 +302,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::extend"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isList() && args.at(n->input(1)).IValue()->isList()) { c10::IValue* self_ptr = args.at(n->input(0)).IValueMut(); auto self = self_ptr->to>(); @@ -328,7 +328,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::neg"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto el = args.at(n->input(0)).unwrapToInt(); return el * -1; @@ -338,7 +338,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::add"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isInt()) { auto a = args.at(n->input(0)).unwrapToInt(); auto b = args.at(n->input(1)).unwrapToInt(); @@ -364,7 +364,7 @@ auto aten_registrations TORCHTRT_UNUSED = "aten::add.str(str a, str b) -> (str)"})}) .evaluator( {c10::Symbol::fromQualString("aten::add_"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isList()) { auto a = args.at(n->input(0)).IValue()->toListRef(); auto b = args.at(n->input(1)).IValue()->toListRef(); @@ -394,7 +394,7 @@ auto aten_registrations TORCHTRT_UNUSED = EvalOptions().validSchemas({"aten::add_.t(t[](a!) 
self, t[] b) -> (t[])"})}) .evaluator( {c10::Symbol::fromQualString("aten::mul"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isInt()) { auto a = args.at(n->input(0)).unwrapToInt(); auto b = args.at(n->input(1)).unwrapToInt(); @@ -414,7 +414,7 @@ auto aten_registrations TORCHTRT_UNUSED = {"aten::mul.int(int a, int b) -> (int)", "aten::mul.float(float a, float b) -> (float)"})}) .evaluator( {c10::Symbol::fromQualString("aten::sub"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isInt()) { auto a = args.at(n->input(0)).unwrapToInt(); auto b = args.at(n->input(1)).unwrapToInt(); @@ -436,7 +436,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::Bool"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isInt()) { auto a = args.at(n->input(0)).unwrapToInt(); return (bool)a; @@ -453,7 +453,7 @@ auto aten_registrations TORCHTRT_UNUSED = EvalOptions().validSchemas({"aten::Bool.int(int a) -> (bool)", "aten::Bool.float(float b) -> (bool)"})}) .evaluator( {c10::Symbol::fromQualString("aten::Float"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isInt()) { auto a = args.at(n->input(0)).unwrapToInt(); return (float)a; @@ -477,7 +477,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::Int"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isInt()) { auto a = args.at(n->input(0)).unwrapToInt(); return (int)a; @@ -502,7 +502,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::__not__"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto el = args.at(n->input(0)).unwrapToBool(); return !el; @@ -512,7 +512,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::__is__"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto self = args.at(n->input(0)).IValue(); auto obj = args.at(n->input(1)).IValue(); @@ -523,7 +523,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::__isnot__"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto self = args.at(n->input(0)).IValue(); auto obj = args.at(n->input(1)).IValue(); @@ -534,7 +534,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::numel"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { LOG_WARNING("There may be undefined behavior using dynamic shape and aten::numel"); auto tensor_var = 
args.at(n->input(0)); if (tensor_var.isITensor()) { @@ -550,7 +550,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::dim"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto tensor_var = args.at(n->input(0)); if (tensor_var.isITensor()) { auto tensor = tensor_var.ITensor(); @@ -565,7 +565,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::div"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isInt()) { auto a = args.at(n->input(0)).unwrapToInt(); auto b = args.at(n->input(1)).unwrapToInt(); @@ -587,7 +587,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::floordiv"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isInt()) { auto a = args.at(n->input(0)).unwrapToInt(); auto b = args.at(n->input(1)).unwrapToInt(); @@ -609,7 +609,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::floor"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isInt()) { auto el = args.at(n->input(0)).unwrapToInt(); return static_cast(std::floor(el)); @@ -629,7 +629,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::sqrt"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isInt()) { auto a = args.at(n->input(0)).unwrapToInt(); return std::sqrt(static_cast(a)); @@ -649,7 +649,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::warn"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto warning = args.at(n->input(0)).IValue(); LOG_WARNING("Warning from TorchScript: " << *warning); return {}; @@ -657,7 +657,7 @@ auto aten_registrations TORCHTRT_UNUSED = EvalOptions()}) .evaluator( {c10::Symbol::fromQualString("aten::is_floating_point"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto tensor_var = args.at(n->input(0)); if (tensor_var.isITensor()) { auto tensor = tensor_var.ITensor(); @@ -674,7 +674,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::tensor"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto data = args.at(n->input(0)).IValue(); auto dtype = args.at(n->input(1)).IValue(); auto device = args.at(n->input(2)).IValue(); @@ -685,7 +685,7 @@ auto aten_registrations TORCHTRT_UNUSED = {"aten::tensor(t[] data, *, int? dtype=None, Device? 
device=None, bool requires_grad=False) -> (Tensor)"})})
 .evaluator(
 {c10::Symbol::fromQualString("aten::arange"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 auto schema = n->maybeSchema();
 TORCHTRT_CHECK(schema, "Unable to get schema for node: " << *n);
 auto name = schema->operator_name();
@@ -736,7 +736,7 @@ auto aten_registrations TORCHTRT_UNUSED =
 })})
 .evaluator(
 {c10::Symbol::fromQualString("aten::clone"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 if (args.at(n->input(0)).isITensor()) {
 auto source_tensor = args.at(n->input(0)).ITensor();
 auto tensor_holder = TensorContainer();
@@ -754,7 +754,7 @@ auto aten_registrations TORCHTRT_UNUSED =
 })})
 .evaluator(
 {c10::Symbol::fromQualString("aten::copy_"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 if (args.at(n->input(1)).isITensor()) {
 auto source_tensor = args.at(n->input(1)).ITensor();
 auto tensor_holder = TensorContainer();
@@ -773,7 +773,7 @@ auto aten_registrations TORCHTRT_UNUSED =
 })})
 .evaluator(
 {c10::Symbol::fromQualString("aten::format"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 int64_t input_num = n->inputs().size();
 std::vector stack;
 for (auto v : n->inputs()) {
@@ -790,7 +790,7 @@ auto aten_registrations TORCHTRT_UNUSED =
 EvalOptions().validSchemas({"aten::format(str self, ...) -> (str)"})})
 .evaluator(
 {c10::Symbol::fromQualString("aten::__range_length"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 auto lo = args.at(n->input(0)).unwrapToInt();
 auto hi = args.at(n->input(1)).unwrapToInt();
 auto step = args.at(n->input(2)).unwrapToInt();
@@ -809,7 +809,7 @@ auto aten_registrations TORCHTRT_UNUSED =
 EvalOptions().validSchemas({"aten::__range_length(int lo, int hi, int step) -> int"})})
 .evaluator(
 {c10::Symbol::fromQualString("aten::__derive_index"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 auto idx = args.at(n->input(0)).unwrapToInt();
 auto start = args.at(n->input(1)).unwrapToInt();
 auto step = args.at(n->input(2)).unwrapToInt();
@@ -821,4 +821,4 @@
 } // namespace evaluators
 } // namespace conversion
 } // namespace core
-} // namespace torch_tensorrt
\ No newline at end of file
+} // namespace torch_tensorrt
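Every registration in aten.cpp above changes in the same mechanical way: the evaluator lambda gains a leading ConversionCtx* parameter. As a minimal sketch of the new shape of a registration (illustrative only, modeled on the aten::neg registration already in this file; not part of the patch):

  auto example_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator(
      {c10::Symbol::fromQualString("aten::neg"),
       [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
         // ctx is unused by purely static evaluators like this one, but evaluators
         // that need to emit TensorRT layers can now reach ctx->net.
         auto el = args.at(n->input(0)).unwrapToInt();
         return el * -1;
       },
       EvalOptions().validSchemas({"aten::neg.int(int a) -> (int)"})});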
diff --git a/core/conversion/evaluators/eval_macros.h b/core/conversion/evaluators/eval_macros.h
index 2bb126c1e9..f15046afa9 100644
--- a/core/conversion/evaluators/eval_macros.h
+++ b/core/conversion/evaluators/eval_macros.h
@@ -5,7 +5,7 @@
 #define DEFINE_GENERIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \
 auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
 {c10::Symbol::fromQualString(node_kind), \
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional { \
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { \
 if (args.at(n->input(0)).IValue()->isInt()) { \
 auto a = args.at(n->input(0)).unwrapToInt(); \
 if (args.at(n->input(1)).IValue()->isInt()) { \
@@ -80,7 +80,7 @@
 #define DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \
 auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
 {c10::Symbol::fromQualString(node_kind), \
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional { \
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { \
 if (args.at(n->input(0)).IValue()->isInt()) { \
 auto a = args.at(n->input(0)).unwrapToInt(); \
 if (args.at(n->input(1)).IValue()->isInt()) { \
@@ -127,7 +127,7 @@
 #define DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(node_kind, node_name, operation, type, schemas) \
 auto node_kind##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
 {c10::Symbol::fromQualString(node_name), \
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional { \
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { \
 auto a = args.at(n->input(0)).unwrapTo(); \
 auto b = args.at(n->input(1)).unwrapTo(); \
 return operation; \
diff --git a/core/conversion/evaluators/evaluators.h b/core/conversion/evaluators/evaluators.h
index 2211fbc3e2..7c2eaf588f 100644
--- a/core/conversion/evaluators/evaluators.h
+++ b/core/conversion/evaluators/evaluators.h
@@ -7,6 +7,8 @@
 #include "torch/csrc/jit/ir/ir.h"

 #include "core/conversion/tensorcontainer/TensorContainer.h"
+#include "core/conversion/conversionctx/ConversionCtx.h"
+#include "core/conversion/converters/converter_util.h"
 #include "core/conversion/var/Var.h"

 namespace torch_tensorrt {
@@ -33,7 +35,7 @@ inline bool constTypesOnly(kwargs& args) {
 // to use the node itself to pull out arguments.
 // This means that you should iterate over node inputs vs. the args
 // when writing evaluators
-typedef std::function(const torch::jit::Node*, kwargs&)> NodeEvaluator;
+typedef std::function(ConversionCtx*, const torch::jit::Node*, kwargs&)> NodeEvaluator;

 struct EvalOptions {
 std::set blacklisted_output_types;
@@ -72,7 +74,7 @@ struct EvalRegistration {
 : kind(_kind), evaluator(_evaluator), options(_options){};
 };

-c10::optional EvalNode(const torch::jit::Node* n, kwargs& args);
+c10::optional EvalNode(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args);
 bool shouldEvalAtConversionTime(const torch::jit::Node* n);
 std::vector getEvaluatorList();
 void register_node_evaluator(torch::jit::NodeKind node_kind, NodeEvaluator evaluator);
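The NodeEvaluator typedef change above is what enables the duality in this patch's title: with the ConversionCtx in hand, an evaluator can either fold a value at conversion time or materialize it in the TensorRT network. The prim::ListConstruct change below applies that pattern; a condensed sketch (assuming the tensor_to_const helper from converter_util.h, which evaluators.h now includes):

  // Freeze a static int into the network as a constant ITensor (sketch only).
  auto itensor = converters::tensor_to_const(ctx, torch::tensor({4}));
  auto holder = TensorContainer();
  holder.hold_tensor(itensor);
  // Wrap the ITensor in a custom-class IValue so it can travel the evaluator
  // path; downstream converters unwrap it (e.g. via Var::ITensorOrFreeze).
  auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(holder)));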
diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp
index 81a7bb9991..e6fe0bab56 100644
--- a/core/conversion/evaluators/prim.cpp
+++ b/core/conversion/evaluators/prim.cpp
@@ -24,7 +24,7 @@ auto prim_registrations =
 RegisterNodeEvaluators()
 .evaluator(
 {torch::jit::prim::Constant,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 if (n->output()->type()->kind() == at::FunctionType::Kind) {
 return {};
 }
@@ -32,12 +32,12 @@ auto prim_registrations =
 }})
 .evaluator(
 {torch::jit::prim::NumToTensor,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 return evaluators::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
 }})
 .evaluator(
 {torch::jit::prim::ListUnpack,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
 const torch::jit::IValue* outputs = args.at(n->input()).IValue();
 auto outputVec = outputs->toList().vec();
@@ -45,7 +45,7 @@ auto prim_registrations =
 }})
 .evaluator(
 {torch::jit::prim::ListConstruct,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 const auto num_inputs = n->inputs().size();
 if (constTypesOnly(args)) {
 c10::ListTypePtr lt = n->output()->type()->expect();
@@ -103,8 +103,16 @@ auto prim_registrations =
 if (args.at(in).IValue()->isNone()) {
 auto ival = torch::jit::IValue();
 list.emplace_back(std::move(ival));
+ } else if (args.at(in).IValue()->isInt()) {
+ // Evaluator/converter duality: a static int in the list may feed a TensorRT layer,
+ // so freeze it as a network constant and carry the ITensor in a TensorContainer IValue.
+ auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, torch::tensor(args.at(in).unwrapToInt()));
+ auto tensor_holder = TensorContainer();
+ tensor_holder.hold_tensor(itensor);
+ auto ival = c10::IValue(std::move(c10::make_intrusive(tensor_holder)));
+ list.emplace_back(std::move(ival));
 } else {
- list.emplace_back(std::move(args.at(in).unwrapToTensor()));
+ list.emplace_back(std::move(args.at(in).unwrapToTensor()));
 }
 }
 }
@@ -113,7 +121,7 @@ auto prim_registrations =
 }})
 .evaluator(
 {c10::Symbol::fromQualString("prim::dtype"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 auto input = args.at(n->input(0));
 if (input.isITensor()) {
 auto trt_dtype = input.ITensor()->getType();
@@ -136,7 +144,7 @@ auto prim_registrations =
 })})
 .evaluator(
 {c10::Symbol::fromQualString("prim::min"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 if (n->inputs().size() == 1) {
 auto a = args.at(n->input(0)).unwrapToIntList();
 int64_t min = std::numeric_limits::max();
@@ -198,7 +206,7 @@ auto prim_registrations =
 })})
 .evaluator(
 {c10::Symbol::fromQualString("prim::max"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 if (n->inputs().size() == 1) {
 auto a = args.at(n->input(0)).unwrapToIntList();
 int64_t max = std::numeric_limits::min();
@@ -260,7 +268,7 @@ auto prim_registrations =
 })})
 .evaluator(
 {c10::Symbol::fromQualString("prim::shape"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 LOG_WARNING("There may be undefined behavior using dynamic shape and prim::shape");
 auto tensor_var = args.at(n->input(0));
 if (tensor_var.isITensor()) {
@@ -274,7 +282,7 @@ auto prim_registrations =
 EvalOptions().validSchemas({"prim::shape(Tensor a) -> (int[])"})})
 .evaluator(
 {torch::jit::prim::TupleConstruct,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 c10::IValue tuple = c10::ivalue::Tuple::create();
 std::vector elems;
 for (auto in : n->inputs()) {
@@ -292,7 +300,7 @@ auto prim_registrations =
 }})
 .evaluator(
 {torch::jit::prim::TupleIndex,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional {
 // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
 auto tuple = args.at(n->input(0)).IValue()->toTuple();
 int64_t idx = args.at(n->input(1)).IValue()->toInt();
@@ -302,24 +310,24 @@ auto
prim_registrations = EvalOptions().validSchemas({"prim::TupleIndex(Any tup, int i) -> (Any)"})}) .evaluator( {torch::jit::prim::TupleUnpack, - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map auto output = args.at(n->input()).IValue()->toTuple(); return c10::optional(std::move(output)); }}) .evaluator( {c10::Symbol::fromQualString("prim::unchecked_cast"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { return *(args.at(n->input(0)).IValue()); }}) .evaluator( {c10::Symbol::fromQualString("prim::Uninitialized"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { return c10::IValue::uninitialized(); }}) .evaluator( {c10::Symbol::fromQualString("prim::RaiseException"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto exception = args.at(n->input(0)).IValue(); TORCHTRT_THROW_ERROR("Error from TorchScript: " << *exception); return {}; @@ -328,4 +334,4 @@ auto prim_registrations = } // namespace evaluators } // namespace conversion } // namespace core -} // namespace torch_tensorrt \ No newline at end of file +} // namespace torch_tensorrt diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp index 1285c24dd6..26bfcebac9 100644 --- a/tests/core/conversion/converters/test_select.cpp +++ b/tests/core/conversion/converters/test_select.cpp @@ -6,951 +6,951 @@ #include "tests/util/util.h" #include "torch/csrc/jit/ir/irparser.h" -TEST(Converters, ATenSelectIntConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=0]() - %3 : Tensor = aten::select(%0, %2, %2) - return (%3))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSelectIntDimIsOneConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=1]() - %3 : int = prim::Constant[value=0]() - %4 : Tensor = aten::select(%0, %2, %3) - return (%4))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - // In order to 
check whether shape match that we don't do reshape. - // E.g. x = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}), then aten::select(x, 1, 0). We should get a tensor y with - // shape {4, 4} instead of a tensor with shape {4, 1, 4}. - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} - -TEST(Converters, ATenSelectIntDimNegativeConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=-2]() - %3 : int = prim::Constant[value=0]() - %4 : Tensor = aten::select(%0, %2, %3) - return (%4))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} - -TEST(Converters, ATenSelectIntNegIndexConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=0]() - %3 : int = prim::Constant[value=-1]() - %4 : Tensor = aten::select(%0, %3, %2) - return (%4))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = torch::tensor({2, 20, 768}).to(at::kFloat).to(at::kCUDA); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSelectIntTwiceConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=0]() - %3 : int = prim::Constant[value=3]() - %4 : Tensor = aten::select(%0, %2, %2) - %5 : Tensor = aten::select(%4, %2, %3) - return (%5))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenNarrowStartScalarConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : int = prim::Constant[value=2]() - %3 : int = prim::Constant[value=0]() - %4 : Tensor = aten::narrow(%x.1, %3, %3, %2) - return (%4))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {3, 2, 2, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = 
torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenEmbeddingConvertsCorrectly) { - const auto graph = R"IR( - graph(%1 : Tensor, %emb_weight : Float(10, 3, strides=[3, 1])): - %2 : bool = prim::Constant[value=0]() - %3 : int = prim::Constant[value=-1]() - %5 : Tensor = aten::embedding(%emb_weight, %1, %3, %2, %2) - return (%5))IR"; - - auto g = std::make_shared(); - - // Run Pytorch - torch::jit::parseIR(graph, g.get()); - auto options_pyt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kLong); - auto jit_in = at::tensor({0, 1, 2}, options_pyt); - auto embWeight = at::randn({10, 3}, {at::kCUDA}); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {embWeight}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - // Run TensorRT - auto options_trt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kFloat); - auto trt_in = at::tensor({0, 1, 2}, options_trt); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenRollConvertsCorrectly) { - const auto graph = R"IR( - graph(%1 : Tensor): - %2 : int[] = prim::Constant[value=[1, 0, 3, 7]]() - %3 : int[] = prim::Constant[value=[0, 1, 2, 3]]() - %4 : Tensor = aten::roll(%1, %2, %3) - return (%4))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - // Run Pytorch - auto in = at::randint(1, 10, {2, 3, 4, 5}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenRollShiftsNegativeConvertsCorrectly) { - const auto graph = R"IR( - graph(%1 : Tensor): - %2 : int[] = prim::Constant[value=[0, -3, -3]]() - %3 : int[] = prim::Constant[value=[1, 2, 3]]() - %4 : Tensor = aten::roll(%1, %2, %3) - return (%4))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - // Run Pytorch - auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenRollDimsNegativeConvertsCorrectly) { - const auto graph = R"IR( - graph(%1 : Tensor): - %2 : int[] = 
prim::Constant[value=[0, -3, -3]]() - %3 : int[] = prim::Constant[value=[1, 2, -1]]() - %4 : Tensor = aten::roll(%1, %2, %3) - return (%4))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - // Run Pytorch - auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : None = prim::Constant() - %3 : int = prim::Constant[value=2]() - %4 : int = prim::Constant[value=4]() - %5 : int = prim::Constant[value=1]() - %6 : int = prim::Constant[value=0]() - %7 : Tensor = aten::select(%x.1, %6, %6) - %8 : Tensor = aten::select(%7, %6, %5) - %9 : Tensor = aten::slice(%8, %6, %5, %4, %3) - %10 : Tensor = aten::slice(%9, %5, %2, %2, %5) - return (%10))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {1, 3, 5, 5}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceNegStartIndexConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : int = prim::Constant[value=1]() - %3 : int = prim::Constant[value=9223372036854775807]() - %4 : int = prim::Constant[value=-2]() - %5 : int = prim::Constant[value=0]() - %6 : Tensor = aten::slice(%x.1, %5, %4, %3, %2) - %7 : Tensor = aten::slice(%6, %2, %5, %3, %2) - return (%7))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {6, 3}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceNegEndIndexConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : int = prim::Constant[value=3]() - %3 : int = prim::Constant[value=9223372036854775807]() - %4 : int = prim::Constant[value=2]() - %5 : int = prim::Constant[value=-3]() - %6 : int = prim::Constant[value=1]() - %7 : int = prim::Constant[value=-2]() - %8 : int = prim::Constant[value=0]() - %9 : Tensor = aten::slice(%x.1, %8, %8, %7, %6) - %10 : Tensor = aten::slice(%9, %6, %8, %5, %6) - %11 : Tensor = aten::slice(%10, %4, %8, %3, %6) - %12 : Tensor = aten::slice(%11, %2, %8, %3, %6) - return (%12))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = 
at::randint(1, 10, {6, 5, 3, 3}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceListConvertsCorrectly) { - const auto graph = R"IR( - graph(%x : Tensor): - %1 : NoneType = prim::Constant() - %2 : int = prim::Constant[value=2]() - %3 : int = prim::Constant[value=1]() - %4 : int = prim::Constant[value=3]() - %list : Tensor[] = aten::unbind(%x, %4) - %slice : Tensor[] = aten::slice(%list, %1, %2, %3) - %out.1 : Tensor, %out.2 : Tensor = prim::ListUnpack(%slice) - return (%out.1, %out.2))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in_x = at::randint(1, 10, {6, 5, 3, 3}, {at::kCUDA}); - - auto jit_in_x = at::clone(in_x); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in_x}); - - auto trt_in_x = at::clone(in_x); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in_x}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenSliceDynamicBatchConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : None = prim::Constant() - %dim : int = prim::Constant[value=0]() - %start : int = prim::Constant[value=1]() - %end : int = prim::Constant[value=15]() - %step : int = prim::Constant[value=2]() - %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) - return (%9))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - // dynamic shape in batch - auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceDynamicBatchLargeEndConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : None = prim::Constant() - %dim : int = prim::Constant[value=0]() - %start : int = prim::Constant[value=1]() - %end : int = prim::Constant[value=9223372036854775807]() - %step : int = prim::Constant[value=2]() - %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) - return (%9))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - // dynamic shape in batch - auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); - auto 
trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceDynamicNegStartBatchConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : None = prim::Constant() - %dim : int = prim::Constant[value=0]() - %start : int = prim::Constant[value=-15]() - %end : int = prim::Constant[value=15]() - %step : int = prim::Constant[value=2]() - %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) - return (%9))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - // dynamic shape in batch - auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceDynamicNegEndBatchConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : None = prim::Constant() - %dim : int = prim::Constant[value=0]() - %start : int = prim::Constant[value=1]() - %end : int = prim::Constant[value=-2]() - %step : int = prim::Constant[value=3]() - %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) - return (%9))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - // dynamic shape in batch - auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceDynamicNoneBatchConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %dim : int = prim::Constant[value=0]() - %start : None = prim::Constant() - %end : None = prim::Constant() - %step : int = prim::Constant[value=3]() - %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) - return (%9))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - // dynamic shape in batch - auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceDynamicConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : None = prim::Constant() - %dim : int = prim::Constant[value=1]() - %start : int = prim::Constant[value=3]() - %end : int = prim::Constant[value=32]() - %step : int = prim::Constant[value=3]() - %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, 
%step) - return (%9))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - // dynamic shape in dim 1, slice in dim 1 - auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, false); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSliceDynamic2ConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : None = prim::Constant() - %dim : int = prim::Constant[value=1]() - %start : int = prim::Constant[value=3]() - %end : int = prim::Constant[value=17]() - %step : int = prim::Constant[value=3]() - %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) - return (%9))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - // dynamic shape in batch, slice in dim 1 - auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenSplitSizesInScriptingConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : int[] = prim::Constant[value=[1, 2]]() - %3 : int = prim::Constant[value=1]() - %4 : Tensor[] = aten::split(%x.1, %2, %3) - %x1.1 : Tensor, %x2.1 : Tensor = prim::ListUnpack(%4) - return (%x1.1, %x2.1))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenSplitSizesinTracingConvertsCorrectly) { - const auto graph = R"IR( - graph(%argument_1.1 : Tensor): - %2 : int[] = prim::Constant[value=[1, 2]]() - %3 : int = prim::Constant[value=1]() - %4 : Tensor[] = aten::split_with_sizes(%argument_1.1, %2, %3) - %5 : Tensor, %6 : Tensor = prim::ListUnpack(%4) - return (%5, %6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto 
trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenSplitFixedConvertsCorrectly) { - const auto graph = R"IR( - graph(%argument_1.1 : Tensor): - %2 : int = prim::Constant[value=1]() - %3 : Tensor[] = aten::split(%argument_1.1, %2, %2) - %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::ListUnpack(%3) - return (%4, %5, %6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenSplitFixedHasRemainderConvertsCorrectly) { - const auto graph = R"IR( - graph(%argument_1.1 : Tensor): - %2 : int = prim::Constant[value=2]() - %2.1 : int = prim::Constant[value=1]() - %3 : Tensor[] = aten::split(%argument_1.1, %2, %2.1) - %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::ListUnpack(%3) - return (%4, %5, %6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - auto in = at::randint(1, 10, {1, 5, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenSplitAndAddConvertsCorrectly) { - const auto graph = R"IR( - graph(%argument_1.1 : Tensor): - %2 : int = prim::Constant[value=2]() - %2.1 : int = prim::Constant[value=1]() - %3 : Tensor[] = aten::split(%argument_1.1, %2, %2.1) - %4 : Tensor, %5 : Tensor = prim::ListUnpack(%3) - %6 : Tensor = aten::add(%4, %5, %2.1) - return (%6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - auto in = at::randint(1, 10, {1, 4, 4, 4}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %44 : Device = prim::Constant[value="cuda"]() - %8 : bool = prim::Constant[value=0]() - %7 : None = prim::Constant() - %f32_dtype: int = prim::Constant[value=11]() - %1 : int = prim::Constant[value=0]() # bert.py:5:26 - %2 : int = prim::Constant[value=1]() # bert.py:5:32 - %33 : int = prim::Constant[value=2]() # bert.py:6:31 - %3 : 
int[] = prim::ListConstruct(%1, %1, %2) - %4 : int[] = prim::ListConstruct(%2, %2, %1) - %5 : int[][] = prim::ListConstruct(%3, %4) - %9 : Tensor = aten::tensor(%5, %f32_dtype, %7, %8) # bert.py:5:11 - %mask.1 : Tensor = aten::to(%9, %44, %7, %8, %8) # bert.py:5:11 - %mask.2 : Tensor = trt::const(%mask.1) - %34 : Tensor = aten::masked_fill(%x.1, %mask.1, %33) # bert.py:6:11 - return (%34, %mask.2))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - auto in = at::zeros({1, 2, 3}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - torch_tensorrt::core::lowering::passes::RemoveNOPs(g); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); -} - -TEST(Converters, ATenIndexTensorOneIndiceConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, - %index : Tensor): - %18 : Tensor?[] = prim::ListConstruct(%index) - %19 : Tensor = aten::index(%x.1, %18) - return (%19))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto in1 = at::randint(1, 10, {5, 10}, {at::kCUDA}); - auto in2 = at::full({2}, 4, {at::kCUDA}); - auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - auto in2_trt = at::full({2}, 4, {options}); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2_trt}); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); -} - -TEST(Converters, ATenIndexTensorFullIndicesConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, - %index0 : Tensor, - %index1 : Tensor, - %index2 : Tensor): - %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2) - %19 : Tensor = aten::index(%x.1, %18) - return (%19))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); - auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); - auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong); - auto index2 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong); - auto index0_trt = index0.to(torch::kInt32); - auto index1_trt = index1.to(torch::kInt32); - auto index2_trt = index2.to(torch::kInt32); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt}); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); -} - -TEST(Converters, ATenIndexTensorIdx0Idx1NoneConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, - %index0 : Tensor, - %index1 : Tensor): - %5 : NoneType = prim::Constant() - %18 : Tensor?[] = 
prim::ListConstruct(%index0, %index1, %5) - %19 : Tensor = aten::index(%x.1, %18) - return (%19))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); - auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); - auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong); - auto index0_trt = index0.to(torch::kInt32); - auto index1_trt = index1.to(torch::kInt32); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt}); - LOG_DEBUG(trt_results); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); -} - -TEST(Converters, ATenIndexTensorIdx0NoneIdx1ConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, - %index0 : Tensor, - %index1 : Tensor): - %5 : NoneType = prim::Constant() - %18 : Tensor?[] = prim::ListConstruct(%index0, %5, %index1) - %19 : Tensor = aten::index(%x.1, %18) - return (%19))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); - auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); - auto index1 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong); - auto index0_trt = index0.to(torch::kInt32); - auto index1_trt = index1.to(torch::kInt32); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt}); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); -} - -TEST(Converters, ATenIndexTensorNoneIdx0Idx1ConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, - %index0 : Tensor, - %index1 : Tensor): - %5 : NoneType = prim::Constant() - %18 : Tensor?[] = prim::ListConstruct(%5, %index0, %index1) - %19 : Tensor = aten::index(%x.1, %18) - return (%19))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); - auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); - auto index1 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong); - auto index0_trt = index0.to(torch::kInt32); - auto index1_trt = index1.to(torch::kInt32); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt}); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); -} - -TEST(Converters, ATenIndexTensorIdxsNoneConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor, - %index0 : Tensor, - %index1 : Tensor, - %index2 : Tensor): - %5 : NoneType = prim::Constant() - %18 : Tensor?[] = 
prim::ListConstruct(%index0, %index1, %index2, %5) - %19 : Tensor = aten::index(%x.1, %18) - return (%19))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto in1 = at::randint(1, 10, {4, 8, 8, 4}, {at::kCUDA}); - auto index0 = at::full({4, 13, 1}, 1, {at::kCUDA}).to(torch::kLong); - auto index1 = at::full({4, 13, 1}, 2, {at::kCUDA}).to(torch::kLong); - auto index2 = at::full({4, 13, 1}, 3, {at::kCUDA}).to(torch::kLong); - auto index0_trt = index0.to(torch::kInt32); - auto index1_trt = index1.to(torch::kInt32); - auto index2_trt = index2.to(torch::kInt32); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt}); - - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); -} +// TEST(Converters, ATenSelectIntConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%0 : Tensor): +// %2 : int = prim::Constant[value=0]() +// %3 : Tensor = aten::select(%0, %2, %2) +// return (%3))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenSelectIntDimIsOneConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%0 : Tensor): +// %2 : int = prim::Constant[value=1]() +// %3 : int = prim::Constant[value=0]() +// %4 : Tensor = aten::select(%0, %2, %3) +// return (%4))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, &*g); +// +// auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// +// // In order to check whether shape match that we don't do reshape. +// // E.g. x = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}), then aten::select(x, 1, 0). We should get a tensor y with +// // shape {4, 4} instead of a tensor with shape {4, 1, 4}. 
+// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +// } +// +// TEST(Converters, ATenSelectIntDimNegativeConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%0 : Tensor): +// %2 : int = prim::Constant[value=-2]() +// %3 : int = prim::Constant[value=0]() +// %4 : Tensor = aten::select(%0, %2, %3) +// return (%4))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, &*g); +// +// auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +// } +// +// TEST(Converters, ATenSelectIntNegIndexConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%0 : Tensor): +// %2 : int = prim::Constant[value=0]() +// %3 : int = prim::Constant[value=-1]() +// %4 : Tensor = aten::select(%0, %3, %2) +// return (%4))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = torch::tensor({2, 20, 768}).to(at::kFloat).to(at::kCUDA); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenSelectIntTwiceConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%0 : Tensor): +// %2 : int = prim::Constant[value=0]() +// %3 : int = prim::Constant[value=3]() +// %4 : Tensor = aten::select(%0, %2, %2) +// %5 : Tensor = aten::select(%4, %2, %3) +// return (%5))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenNarrowStartScalarConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor): +// %2 : int = prim::Constant[value=2]() +// %3 : int = prim::Constant[value=0]() +// %4 : Tensor = aten::narrow(%x.1, %3, %3, %2) +// return (%4))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {3, 2, 2, 4}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = 
torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenEmbeddingConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%1 : Tensor, %emb_weight : Float(10, 3, strides=[3, 1])): +// %2 : bool = prim::Constant[value=0]() +// %3 : int = prim::Constant[value=-1]() +// %5 : Tensor = aten::embedding(%emb_weight, %1, %3, %2, %2) +// return (%5))IR"; +// +// auto g = std::make_shared(); +// +// // Run Pytorch +// torch::jit::parseIR(graph, g.get()); +// auto options_pyt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kLong); +// auto jit_in = at::tensor({0, 1, 2}, options_pyt); +// auto embWeight = at::randn({10, 3}, {at::kCUDA}); +// +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {embWeight}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// // Run TensorRT +// auto options_trt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kFloat); +// auto trt_in = at::tensor({0, 1, 2}, options_trt); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenRollConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%1 : Tensor): +// %2 : int[] = prim::Constant[value=[1, 0, 3, 7]]() +// %3 : int[] = prim::Constant[value=[0, 1, 2, 3]]() +// %4 : Tensor = aten::roll(%1, %2, %3) +// return (%4))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// // Run Pytorch +// auto in = at::randint(1, 10, {2, 3, 4, 5}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenRollShiftsNegativeConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%1 : Tensor): +// %2 : int[] = prim::Constant[value=[0, -3, -3]]() +// %3 : int[] = prim::Constant[value=[1, 2, 3]]() +// %4 : Tensor = aten::roll(%1, %2, %3) +// return (%4))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// // Run Pytorch +// auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// 
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenRollDimsNegativeConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%1 : Tensor): +// %2 : int[] = prim::Constant[value=[0, -3, -3]]() +// %3 : int[] = prim::Constant[value=[1, 2, -1]]() +// %4 : Tensor = aten::roll(%1, %2, %3) +// return (%4))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// // Run Pytorch +// auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenSliceConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor): +// %2 : None = prim::Constant() +// %3 : int = prim::Constant[value=2]() +// %4 : int = prim::Constant[value=4]() +// %5 : int = prim::Constant[value=1]() +// %6 : int = prim::Constant[value=0]() +// %7 : Tensor = aten::select(%x.1, %6, %6) +// %8 : Tensor = aten::select(%7, %6, %5) +// %9 : Tensor = aten::slice(%8, %6, %5, %4, %3) +// %10 : Tensor = aten::slice(%9, %5, %2, %2, %5) +// return (%10))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {1, 3, 5, 5}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenSliceNegStartIndexConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor): +// %2 : int = prim::Constant[value=1]() +// %3 : int = prim::Constant[value=9223372036854775807]() +// %4 : int = prim::Constant[value=-2]() +// %5 : int = prim::Constant[value=0]() +// %6 : Tensor = aten::slice(%x.1, %5, %4, %3, %2) +// %7 : Tensor = aten::slice(%6, %2, %5, %3, %2) +// return (%7))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {6, 3}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenSliceNegEndIndexConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor): +// %2 : int = prim::Constant[value=3]() +// %3 : int = prim::Constant[value=9223372036854775807]() +// %4 : int = prim::Constant[value=2]() +// %5 : int = prim::Constant[value=-3]() +// %6 : int = 
prim::Constant[value=1]() +// %7 : int = prim::Constant[value=-2]() +// %8 : int = prim::Constant[value=0]() +// %9 : Tensor = aten::slice(%x.1, %8, %8, %7, %6) +// %10 : Tensor = aten::slice(%9, %6, %8, %5, %6) +// %11 : Tensor = aten::slice(%10, %4, %8, %3, %6) +// %12 : Tensor = aten::slice(%11, %2, %8, %3, %6) +// return (%12))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {6, 5, 3, 3}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenSliceListConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x : Tensor): +// %1 : NoneType = prim::Constant() +// %2 : int = prim::Constant[value=2]() +// %3 : int = prim::Constant[value=1]() +// %4 : int = prim::Constant[value=3]() +// %list : Tensor[] = aten::unbind(%x, %4) +// %slice : Tensor[] = aten::slice(%list, %1, %2, %3) +// %out.1 : Tensor, %out.2 : Tensor = prim::ListUnpack(%slice) +// return (%out.1, %out.2))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in_x = at::randint(1, 10, {6, 5, 3, 3}, {at::kCUDA}); +// +// auto jit_in_x = at::clone(in_x); +// +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in_x}); +// +// auto trt_in_x = at::clone(in_x); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in_x}); +// +// for (size_t i = 0; i < jit_results.size(); i++) { +// auto trt = trt_results[i].reshape(jit_results[i].sizes()); +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); +// } +// } +// +// TEST(Converters, ATenSliceDynamicBatchConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor): +// %2 : None = prim::Constant() +// %dim : int = prim::Constant[value=0]() +// %start : int = prim::Constant[value=1]() +// %end : int = prim::Constant[value=15]() +// %step : int = prim::Constant[value=2]() +// %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) +// return (%9))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// // dynamic shape in batch +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenSliceDynamicBatchLargeEndConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor): +// %2 : None = prim::Constant() +// %dim : int = prim::Constant[value=0]() +// %start : int = prim::Constant[value=1]() +// %end : int = prim::Constant[value=9223372036854775807]() +// 
%step : int = prim::Constant[value=2]() +// %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) +// return (%9))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// // dynamic shape in batch +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenSliceDynamicNegStartBatchConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor): +// %2 : None = prim::Constant() +// %dim : int = prim::Constant[value=0]() +// %start : int = prim::Constant[value=-15]() +// %end : int = prim::Constant[value=15]() +// %step : int = prim::Constant[value=2]() +// %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) +// return (%9))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// // dynamic shape in batch +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenSliceDynamicNegEndBatchConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor): +// %2 : None = prim::Constant() +// %dim : int = prim::Constant[value=0]() +// %start : int = prim::Constant[value=1]() +// %end : int = prim::Constant[value=-2]() +// %step : int = prim::Constant[value=3]() +// %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) +// return (%9))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// // dynamic shape in batch +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenSliceDynamicNoneBatchConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor): +// %dim : int = prim::Constant[value=0]() +// %start : None = prim::Constant() +// %end : None = prim::Constant() +// %step : int = prim::Constant[value=3]() +// %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) +// return (%9))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = 
torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// // dynamic shape in batch +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenSliceDynamicConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor): +// %2 : None = prim::Constant() +// %dim : int = prim::Constant[value=1]() +// %start : int = prim::Constant[value=3]() +// %end : int = prim::Constant[value=32]() +// %step : int = prim::Constant[value=3]() +// %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) +// return (%9))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// // dynamic shape in dim 1, slice in dim 1 +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, false); +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenSliceDynamic2ConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor): +// %2 : None = prim::Constant() +// %dim : int = prim::Constant[value=1]() +// %start : int = prim::Constant[value=3]() +// %end : int = prim::Constant[value=17]() +// %step : int = prim::Constant[value=3]() +// %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) +// return (%9))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// // dynamic shape in batch, slice in dim 1 +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); +// auto trt = trt_results[0].reshape(jit_results[0].sizes()); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +// } +// +// TEST(Converters, ATenSplitSizesInScriptingConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor): +// %2 : int[] = prim::Constant[value=[1, 2]]() +// %3 : int = prim::Constant[value=1]() +// %4 : Tensor[] = aten::split(%x.1, %2, %3) +// %x1.1 : Tensor, %x2.1 : Tensor = prim::ListUnpack(%4) +// return (%x1.1, %x2.1))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// +// for (size_t i = 0; i < jit_results.size(); i++) { +// 
auto trt = trt_results[i].reshape(jit_results[i].sizes()); +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); +// } +// } +// +// TEST(Converters, ATenSplitSizesinTracingConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%argument_1.1 : Tensor): +// %2 : int[] = prim::Constant[value=[1, 2]]() +// %3 : int = prim::Constant[value=1]() +// %4 : Tensor[] = aten::split_with_sizes(%argument_1.1, %2, %3) +// %5 : Tensor, %6 : Tensor = prim::ListUnpack(%4) +// return (%5, %6))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// +// for (size_t i = 0; i < jit_results.size(); i++) { +// auto trt = trt_results[i].reshape(jit_results[i].sizes()); +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); +// } +// } +// +// TEST(Converters, ATenSplitFixedConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%argument_1.1 : Tensor): +// %2 : int = prim::Constant[value=1]() +// %3 : Tensor[] = aten::split(%argument_1.1, %2, %2) +// %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::ListUnpack(%3) +// return (%4, %5, %6))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// +// for (size_t i = 0; i < jit_results.size(); i++) { +// auto trt = trt_results[i].reshape(jit_results[i].sizes()); +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); +// } +// } +// +// TEST(Converters, ATenSplitFixedHasRemainderConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%argument_1.1 : Tensor): +// %2 : int = prim::Constant[value=2]() +// %2.1 : int = prim::Constant[value=1]() +// %3 : Tensor[] = aten::split(%argument_1.1, %2, %2.1) +// %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::ListUnpack(%3) +// return (%4, %5, %6))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, &*g); +// +// auto in = at::randint(1, 10, {1, 5, 4, 4}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// +// for (size_t i = 0; i < jit_results.size(); i++) { +// auto trt = trt_results[i].reshape(jit_results[i].sizes()); +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); +// } +// } +// +// TEST(Converters, ATenSplitAndAddConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%argument_1.1 : Tensor): +// %2 : int = prim::Constant[value=2]() +// %2.1 : int = prim::Constant[value=1]() +// %3 : Tensor[] = aten::split(%argument_1.1, %2, %2.1) +// %4 : 
Tensor, %5 : Tensor = prim::ListUnpack(%3) +// %6 : Tensor = aten::add(%4, %5, %2.1) +// return (%6))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, &*g); +// +// auto in = at::randint(1, 10, {1, 4, 4, 4}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// +// for (size_t i = 0; i < jit_results.size(); i++) { +// auto trt = trt_results[i].reshape(jit_results[i].sizes()); +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); +// } +// } +// +// TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor): +// %44 : Device = prim::Constant[value="cuda"]() +// %8 : bool = prim::Constant[value=0]() +// %7 : None = prim::Constant() +// %f32_dtype: int = prim::Constant[value=11]() +// %1 : int = prim::Constant[value=0]() # bert.py:5:26 +// %2 : int = prim::Constant[value=1]() # bert.py:5:32 +// %33 : int = prim::Constant[value=2]() # bert.py:6:31 +// %3 : int[] = prim::ListConstruct(%1, %1, %2) +// %4 : int[] = prim::ListConstruct(%2, %2, %1) +// %5 : int[][] = prim::ListConstruct(%3, %4) +// %9 : Tensor = aten::tensor(%5, %f32_dtype, %7, %8) # bert.py:5:11 +// %mask.1 : Tensor = aten::to(%9, %44, %7, %8, %8) # bert.py:5:11 +// %mask.2 : Tensor = trt::const(%mask.1) +// %34 : Tensor = aten::masked_fill(%x.1, %mask.1, %33) # bert.py:6:11 +// return (%34, %mask.2))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, &*g); +// +// auto in = at::zeros({1, 2, 3}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// torch_tensorrt::core::lowering::passes::RemoveNOPs(g); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// +// ASSERT_TRUE( +// torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +// } +// +// TEST(Converters, ATenIndexTensorOneIndiceConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor, +// %index : Tensor): +// %18 : Tensor?[] = prim::ListConstruct(%index) +// %19 : Tensor = aten::index(%x.1, %18) +// return (%19))IR"; +// +// auto g = std::make_shared(); +// torch::jit::parseIR(graph, g.get()); +// +// auto in1 = at::randint(1, 10, {5, 10}, {at::kCUDA}); +// auto in2 = at::full({2}, 4, {at::kCUDA}); +// auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); +// auto in2_trt = at::full({2}, 4, {options}); +// +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); +// +// params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2_trt}); +// +// ASSERT_TRUE( +// torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +// } +// +// TEST(Converters, ATenIndexTensorFullIndicesConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor, +// %index0 : Tensor, +// 
%index1 : Tensor, +// %index2 : Tensor): +// %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2) +// %19 : Tensor = aten::index(%x.1, %18) +// return (%19))IR"; +// +// auto g = std::make_shared(); +// torch::jit::parseIR(graph, g.get()); +// +// auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); +// auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); +// auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong); +// auto index2 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong); +// auto index0_trt = index0.to(torch::kInt32); +// auto index1_trt = index1.to(torch::kInt32); +// auto index2_trt = index2.to(torch::kInt32); +// +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2}); +// +// params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt}); +// +// ASSERT_TRUE( +// torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +// } +// +// TEST(Converters, ATenIndexTensorIdx0Idx1NoneConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor, +// %index0 : Tensor, +// %index1 : Tensor): +// %5 : NoneType = prim::Constant() +// %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %5) +// %19 : Tensor = aten::index(%x.1, %18) +// return (%19))IR"; +// +// auto g = std::make_shared(); +// torch::jit::parseIR(graph, g.get()); +// +// auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); +// auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); +// auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong); +// auto index0_trt = index0.to(torch::kInt32); +// auto index1_trt = index1.to(torch::kInt32); +// +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1}); +// +// params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt}); +// LOG_DEBUG(trt_results); +// +// ASSERT_TRUE( +// torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +// } +// +// TEST(Converters, ATenIndexTensorIdx0NoneIdx1ConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor, +// %index0 : Tensor, +// %index1 : Tensor): +// %5 : NoneType = prim::Constant() +// %18 : Tensor?[] = prim::ListConstruct(%index0, %5, %index1) +// %19 : Tensor = aten::index(%x.1, %18) +// return (%19))IR"; +// +// auto g = std::make_shared(); +// torch::jit::parseIR(graph, g.get()); +// +// auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); +// auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); +// auto index1 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong); +// auto index0_trt = index0.to(torch::kInt32); +// auto index1_trt = index1.to(torch::kInt32); +// +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1}); +// +// params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, 
index1_trt}); +// +// ASSERT_TRUE( +// torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +// } +// +// TEST(Converters, ATenIndexTensorNoneIdx0Idx1ConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor, +// %index0 : Tensor, +// %index1 : Tensor): +// %5 : NoneType = prim::Constant() +// %18 : Tensor?[] = prim::ListConstruct(%5, %index0, %index1) +// %19 : Tensor = aten::index(%x.1, %18) +// return (%19))IR"; +// +// auto g = std::make_shared(); +// torch::jit::parseIR(graph, g.get()); +// +// auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); +// auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); +// auto index1 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong); +// auto index0_trt = index0.to(torch::kInt32); +// auto index1_trt = index1.to(torch::kInt32); +// +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1}); +// +// params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt}); +// +// ASSERT_TRUE( +// torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +// } +// +// TEST(Converters, ATenIndexTensorIdxsNoneConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor, +// %index0 : Tensor, +// %index1 : Tensor, +// %index2 : Tensor): +// %5 : NoneType = prim::Constant() +// %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2, %5) +// %19 : Tensor = aten::index(%x.1, %18) +// return (%19))IR"; +// +// auto g = std::make_shared(); +// torch::jit::parseIR(graph, g.get()); +// +// auto in1 = at::randint(1, 10, {4, 8, 8, 4}, {at::kCUDA}); +// auto index0 = at::full({4, 13, 1}, 1, {at::kCUDA}).to(torch::kLong); +// auto index1 = at::full({4, 13, 1}, 2, {at::kCUDA}).to(torch::kLong); +// auto index2 = at::full({4, 13, 1}, 3, {at::kCUDA}).to(torch::kLong); +// auto index0_trt = index0.to(torch::kInt32); +// auto index1_trt = index1.to(torch::kInt32); +// auto index2_trt = index2.to(torch::kInt32); +// +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2}); +// +// params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt}); +// +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +// } TEST(Converters, ATenUnbindConvertsCorrectly) { const auto graph = R"IR( @@ -974,107 +974,109 @@ TEST(Converters, ATenUnbindConvertsCorrectly) { auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : int = prim::Constant[value=-1]() - %3 : Tensor[] = aten::unbind(%x.1, %2) - %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%3) - return (%o1.1, %o2.1))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto in = 
at::randint(1, 10, {5, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} - -TEST(Converters, ScatterValueConvertsCorrectly) { - const auto graph = R"IR( - graph(%data : Tensor, - %index.1 : Tensor): - %value : int = prim::Constant[value=100]() - %dim : int = prim::Constant[value=1]() - %5 : NoneType = prim::Constant() - %6 : bool = prim::Constant[value=0]() - %7 : int = prim::Constant[value=4]() - %index : Tensor = aten::to(%index.1, %7, %6, %6, %5) - %10 : Tensor = aten::scatter(%data, %dim, %index, %value) - return (%10))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto index = at::randint(0, 5, {2, 2}, {at::kCUDA}); - auto data = at::randn({5, 5}, {at::kCUDA}); - - auto jit_index = at::clone(index); - auto jit_data = at::clone(data); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_index}); - - auto trt_index = at::clone(index); - auto trt_data = at::clone(data); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_index}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); + auto trt = trt_results[i]; //.reshape(jit_results[i].sizes()); + LOG_DEBUG("==== TRT shape: " << trt.sizes()); + LOG_DEBUG("==== JIT shape: " << jit_results[i].sizes()); ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); } } -TEST(Converters, ScatterSrcConvertsCorrectly) { - const auto graph = R"IR( - graph(%data : Tensor, - %src : Tensor, - %index.1 : Tensor): - %dim : int = prim::Constant[value=1]() - %5 : NoneType = prim::Constant() - %6 : bool = prim::Constant[value=0]() - %7 : int = prim::Constant[value=4]() - %index : Tensor = aten::to(%index.1, %7, %6, %6, %5) - %10 : Tensor = aten::scatter(%data, %dim, %index, %src) - return (%10))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, g.get()); - - auto index = at::randint(0, 4, {2, 2}, {at::kCUDA}); - auto data = at::randn({5, 5}, {at::kCUDA}); - auto src = at::randn({2, 2}, {at::kCUDA}); - - auto jit_index = at::clone(index); - auto jit_data = at::clone(data); - auto jit_src = at::clone(src); - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_src, jit_index}); - - auto trt_index = at::clone(index); - auto trt_data = at::clone(data); - auto trt_src = at::clone(src); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_src, trt_index}); - - for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); - } -} \ No newline at end of file +// TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%x.1 : Tensor): +// %2 : int = prim::Constant[value=-1]() +// %3 : 
Tensor[] = aten::unbind(%x.1, %2) +// %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%3) +// return (%o1.1, %o2.1))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto in = at::randint(1, 10, {5, 2}, {at::kCUDA}); +// +// auto jit_in = at::clone(in); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); +// +// auto trt_in = at::clone(in); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); +// +// for (size_t i = 0; i < jit_results.size(); i++) { +// auto trt = trt_results[i].reshape(jit_results[i].sizes()); +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); +// } +// } +// +// TEST(Converters, ScatterValueConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%data : Tensor, +// %index.1 : Tensor): +// %value : int = prim::Constant[value=100]() +// %dim : int = prim::Constant[value=1]() +// %5 : NoneType = prim::Constant() +// %6 : bool = prim::Constant[value=0]() +// %7 : int = prim::Constant[value=4]() +// %index : Tensor = aten::to(%index.1, %7, %6, %6, %5) +// %10 : Tensor = aten::scatter(%data, %dim, %index, %value) +// return (%10))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto index = at::randint(0, 5, {2, 2}, {at::kCUDA}); +// auto data = at::randn({5, 5}, {at::kCUDA}); +// +// auto jit_index = at::clone(index); +// auto jit_data = at::clone(data); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_index}); +// +// auto trt_index = at::clone(index); +// auto trt_data = at::clone(data); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_index}); +// +// for (size_t i = 0; i < jit_results.size(); i++) { +// auto trt = trt_results[i].reshape(jit_results[i].sizes()); +// ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); +// } +// } +// +// TEST(Converters, ScatterSrcConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%data : Tensor, +// %src : Tensor, +// %index.1 : Tensor): +// %dim : int = prim::Constant[value=1]() +// %5 : NoneType = prim::Constant() +// %6 : bool = prim::Constant[value=0]() +// %7 : int = prim::Constant[value=4]() +// %index : Tensor = aten::to(%index.1, %7, %6, %6, %5) +// %10 : Tensor = aten::scatter(%data, %dim, %index, %src) +// return (%10))IR"; +// +// auto g = std::make_shared(); +// +// torch::jit::parseIR(graph, g.get()); +// +// auto index = at::randint(0, 4, {2, 2}, {at::kCUDA}); +// auto data = at::randn({5, 5}, {at::kCUDA}); +// auto src = at::randn({2, 2}, {at::kCUDA}); +// +// auto jit_index = at::clone(index); +// auto jit_data = at::clone(data); +// auto jit_src = at::clone(src); +// auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); +// auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_src, jit_index}); +// +// auto trt_index = at::clone(index); +// auto trt_data = at::clone(data); +// auto trt_src = at::clone(src); +// auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_src, trt_index}); +// +// for (size_t i = 0; i < jit_results.size(); i++) { +// auto trt = trt_results[i].reshape(jit_results[i].sizes()); +// 
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
+//   }
+// }

From 04ded550f3895189c2d1a83179324f39f93c6ef3 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Mon, 24 Oct 2022 22:36:58 -0700
Subject: [PATCH 02/10] chore: play around with aten::size

Signed-off-by: Dheeraj Peri

---
 core/conversion/converters/impl/shuffle.cpp |  5 +++--
 core/conversion/var/Var.cpp                 | 18 ++++++++++++++++++
 core/conversion/var/Var.h                   |  2 ++
 3 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp
index c4b1b1d442..52943ecb5a 100644
--- a/core/conversion/converters/impl/shuffle.cpp
+++ b/core/conversion/converters/impl/shuffle.cpp
@@ -73,8 +73,9 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
            std::cout << "====2====" << std::endl;
            std::vector<int64_t> new_shape;
            if (ctx->input_is_dynamic) {
-             std::cout << "====3====" << std::endl;
-             new_shape = util::toVec(args[1].unwrapToIntList().vec());
+             std::cout << "====3====: " << args[1].size() << std::endl;
+             // new_shape = util::toVec(args[1].unwrapToIntList().vec());
+             new_shape = util::toVec(args[1].unwrapToITensorList());
              std::cout << "====4====" << std::endl;
              int nbDynamicDims = 0;
              for (size_t i = 0; i < new_shape.size(); i++) {
diff --git a/core/conversion/var/Var.cpp b/core/conversion/var/Var.cpp
index ff68590e3e..062cea7a64 100644
--- a/core/conversion/var/Var.cpp
+++ b/core/conversion/var/Var.cpp
@@ -146,6 +146,24 @@ bool Var::isITensor() const {
   }
 }
 
+bool Var::isITensorList() const {
+  LOG_DEBUG("===== TYPE NAME: " << type_name());
+  if (type_ == Type::kITensor) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+bool Var::unwrapToITensorList() {
+  TORCHTRT_CHECK(
+      isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
+  LOG_DEBUG("===== TYPE NAME: " << type_name());
+  auto ivalue = ptr_.ivalue;
+  return false;
+  // return ptr_.ivalue->to();
+}
+
 bool Var::isIValue() const {
   if (type_ == Type::kIValue) {
     return true;
diff --git a/core/conversion/var/Var.h b/core/conversion/var/Var.h
index 6d7edcecde..4167246449 100644
--- a/core/conversion/var/Var.h
+++ b/core/conversion/var/Var.h
@@ -43,6 +43,7 @@ class Var : torch::CustomClassHolder {
   c10::Scalar unwrapToScalar();
   c10::List<int64_t> unwrapToIntList(c10::List<int64_t> default_val);
   c10::List<int64_t> unwrapToIntList();
+  c10::List<nvinfer1::ITensor*> unwrapToITensorList();
   c10::List<double> unwrapToDoubleList(c10::List<double> default_val);
   c10::List<double> unwrapToDoubleList();
   c10::List<bool> unwrapToBoolList(c10::List<bool> default_val);
   c10::List<bool> unwrapToBoolList();
@@ -58,6 +59,7 @@ class Var : torch::CustomClassHolder {
 
   bool isIValue() const;
   bool isITensor() const;
+  bool isITensorList() const;
   bool isNone() const;
   Var::Type type() const;
   std::string type_name() const;
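Note: as committed, Var::unwrapToITensorList() above is a stub; it logs, always returns false, and does not yet match the c10::List return type declared in Var.h. The next patch reworks Var.cpp again (its diffstat touches the file), and its reshape converter calls .data() and .size() on the unwrap result, which suggests the return type ultimately needs to be a contiguous container of raw ITensor pointers rather than a c10::List. A sketch of where this is plausibly headed, assuming each list element wraps an ITensor in the TensorContainer custom class used elsewhere in this codebase; the names and the std::vector return type are illustrative, not the committed implementation:

std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList() {
  TORCHTRT_CHECK(
      isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
  std::vector<nvinfer1::ITensor*> tensors;
  for (auto elem : ptr_.ivalue->toListRef()) {
    // Each element is assumed to hold an ITensor inside a TensorContainer custom-class IValue
    tensors.push_back(elem.toCustomClass<TensorContainer>()->tensor());
  }
  return tensors;
}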

From 4c17994fef5b7aecb40a62441ce876f935733dda Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 1 Feb 2023 10:56:39 -0800
Subject: [PATCH 03/10] feat: Implement dynamic version of aten::size

Signed-off-by: Dheeraj Peri

---
 core/conversion/converters/BUILD            |   3 -
 core/conversion/converters/impl/shuffle.cpp |  60 ++++----
 core/conversion/evaluators/aten.cpp         | 162 +++++++++++++-------
 core/conversion/evaluators/prim.cpp         |   5 +-
 core/conversion/var/Var.cpp                 |  14 +-
 core/conversion/var/Var.h                   |   2 +-
 py/requirements.txt                         |   5 +-
 7 files changed, 147 insertions(+), 104 deletions(-)

diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD
index eaf15b0c9a..ce8d1dbbb7 100755
--- a/core/conversion/converters/BUILD
+++ b/core/conversion/converters/BUILD
@@ -62,11 +62,8 @@ cc_library(
         "impl/constant_pad.cpp",
         "impl/conv_deconv.cpp",
         "impl/cumsum.cpp",
-<<<<<<< HEAD
         "impl/dual_ops.cpp",
-=======
         "impl/einsum.cpp",
->>>>>>> main
         "impl/element_wise.cpp",
         "impl/expand.cpp",
         "impl/interpolate.cpp",
diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp
index 52943ecb5a..18033c1d07 100644
--- a/core/conversion/converters/impl/shuffle.cpp
+++ b/core/conversion/converters/impl/shuffle.cpp
@@ -67,39 +67,33 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
     .pattern(
         {"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-           auto in = args[0].ITensorOrFreeze(ctx);
-           std::cout << "====1====" << std::endl;
-           auto in_shape = util::toVec(in->getDimensions());
-           std::cout << "====2====" << std::endl;
-           std::vector<int64_t> new_shape;
-           if (ctx->input_is_dynamic) {
-             std::cout << "====3====: " << args[1].size() << std::endl;
-             // new_shape = util::toVec(args[1].unwrapToIntList().vec());
-             new_shape = util::toVec(args[1].unwrapToITensorList());
-             std::cout << "====4====" << std::endl;
-             int nbDynamicDims = 0;
-             for (size_t i = 0; i < new_shape.size(); i++) {
-               if (in_shape[i] == -1)
-                 nbDynamicDims++;
-             }
-             if (nbDynamicDims > 1) {
-               TORCHTRT_THROW_ERROR(
-                   "Resize is currently not supported when target shape contains more than one dynamic dimension");
-             }
-           } else {
-             std::cout << "====5====" << std::endl;
-             new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
-           }
-           std::cout << "====6====" << std::endl;
-           auto shuffle = ctx->net->addShuffle(*in);
-           TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
-           shuffle->setReshapeDimensions(util::toDims(new_shape));
-           shuffle->setName(util::node_info(n).c_str());
-
-           auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
-           LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
-
-           return true;
+           auto in = args[0].ITensorOrFreeze(ctx);
+           auto in_shape = util::toVec(in->getDimensions());
+           std::vector<int64_t> new_shape;
+           nvinfer1::ITensor* shape_tensor;
+           if (ctx->input_is_dynamic) {
+             // Build a 1-D shape tensor by concatenating the per-dimension ITensors
+             auto shape_itensors = args[1].unwrapToITensorList();
+             auto concat_layer = ctx->net->addConcatenation(shape_itensors.data(), shape_itensors.size());
+             TORCHTRT_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n);
+             concat_layer->setAxis(static_cast<int>(0));
+             shape_tensor = concat_layer->getOutput(0);
+           } else {
+             // Assign, not re-declare: the computed shape must be visible to setReshapeDimensions below
+             new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
+           }
+           auto shuffle = ctx->net->addShuffle(*in);
+           TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
+           shuffle->setName(util::node_info(n).c_str());
+
+           if (ctx->input_is_dynamic) {
+             shuffle->setInput(1, *shape_tensor);
+           } else {
+             shuffle->setReshapeDimensions(util::toDims(new_shape));
+           }
+
+           auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
+           LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+
+           return true;
         }})
     .pattern(
         {"aten::view(Tensor(a) self, int[] size) -> (Tensor(a))",
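The two branches above map onto the two ways TensorRT's IShuffleLayer accepts a reshape target: setReshapeDimensions() requires dimensions known at engine build time, while setting the layer's second input lets the target shape be a tensor computed at runtime. That is why the dynamic path concatenates the per-dimension ITensors (produced by the dynamic aten::size evaluator below) into a single 1-D shape tensor and wires it in with setInput(1, ...). A minimal sketch of the pattern against the TensorRT C++ API, with network, input, dims, and shape_tensor standing in for the converter's state:

// Static shape: dimensions are baked into the engine at build time.
auto static_shuffle = network->addShuffle(*input);
static_shuffle->setReshapeDimensions(dims);   // nvinfer1::Dims fixed at build time

// Dynamic shape: the reshape target is itself a tensor in the network.
auto dynamic_shuffle = network->addShuffle(*input);
dynamic_shuffle->setInput(1, *shape_tensor);  // 1-D Int32 tensor computed at runtime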
dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args){ + LOG_DEBUG("Using dynamic version of aten::size evaluator"); + auto in = args.at(n->input(0)).ITensorOrFreeze(ctx); + LOG_DEBUG("Input dimensions: " << in->getDimensions()); + auto shape_layer = ctx->net->addShape(*in); + auto shape_1d_tensor = shape_layer->getOutput(0); + + if (n->inputs().size() != 1){ + auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims); + auto dim = args.at(n->input(1)).unwrapToInt(); + // Handle negative axis by referring to nbDims of input Tensor + dim = dim < 0 ? dim + maxDim : dim; + LOG_DEBUG("Dimension to select: " << dim); + + // index to access needs to be an at::Tensor + at::Tensor indices = torch::tensor({dim}).to(torch::kI32); + auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, indices); + + auto gather_layer = ctx->net->addGather(*shape_1d_tensor, *indices_out, 0); + shape_1d_tensor = gather_layer->getOutput(0); + } + + LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions()); + + auto tensor_holder = TensorContainer(); + tensor_holder.hold_tensor(shape_1d_tensor); + auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder))); + + return shape_1d_ivalue; +} + DEFINE_GENERIC_TWO_INPUT_EVALUATOR( eq, "aten::eq", @@ -176,7 +211,7 @@ auto aten_registrations TORCHTRT_UNUSED = {c10::Symbol::fromQualString("aten::full_like"), // aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, // Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> (Tensor) - [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { // Override options related to layout and device for TensorRT auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA); auto input_tensor_var = args.at(n->input(0)); @@ -262,67 +297,80 @@ auto aten_registrations TORCHTRT_UNUSED = return static_cast<int64_t>(list.size()); }, EvalOptions().validSchemas({"aten::len.t(t[] a) -> (int)"})}) - // .evaluator( - // {c10::Symbol::fromQualString("aten::size"), - // [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { - // LOG_WARNING("There may be undefined behavior using dynamic shape and aten::size"); - // auto tensor_var = args.at(n->input(0)); - // if (n->inputs().size() == 1) { - // if (tensor_var.isITensor()) { - // auto tensor = tensor_var.ITensor(); - // return util::toVec(tensor->getDimensions()); - // } else if (tensor_var.IValue()->isTensor()) { - // auto tensor = tensor_var.unwrapToTensor(); - // return tensor.sizes(); - // } else if (tensor_var.IValue()->isCustomClass()) { - // auto tensor = tensor_var.IValue()->toCustomClass<TensorContainer>()->tensor(); - // return util::toVec(tensor->getDimensions()); - // } else { - // TORCHTRT_THROW_ERROR("IValue is not some class of Tensor.
Found: " << tensor_var.IValue()->type()); - // } - // } else { - // auto dim = args.at(n->input(1)).unwrapToInt(); - // if (tensor_var.isITensor()) { - // auto tensor = tensor_var.ITensor(); - // auto dims = util::toVec(tensor->getDimensions()); - // auto nbDims = tensor->getDimensions().nbDims; - // if (dim < 0) { - // dim += nbDims; - // } - // return dims[dim]; - // } else if (tensor_var.IValue()->isTensor()) { - // auto tensor = tensor_var.unwrapToTensor(); - // auto nbDims = tensor.sizes().size(); - // if (dim < 0) { - // dim += nbDims; - // } - // return tensor.sizes()[dim]; - // } else if (tensor_var.IValue()->isCustomClass()) { - // auto tensor = tensor_var.IValue()->toCustomClass()->tensor(); - // auto dims = util::toVec(tensor->getDimensions()); - // auto nbDims = tensor->getDimensions().nbDims; - // if (dim < 0) { - // dim += nbDims; - // } - // return dims[dim]; - // } else { - // TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type()); - // } - // } - // }, - // EvalOptions().validSchemas( - // {"aten::size(Tensor self) -> (int[])", "aten::size.int(Tensor self, int dim) -> (int)"})}) + .evaluator( + {c10::Symbol::fromQualString("aten::size"), + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { + auto tensor_var = args.at(n->input(0)); + if (n->inputs().size() == 1) { + if (tensor_var.isITensor()) { + auto tensor = tensor_var.ITensor(); + if (ctx->input_is_dynamic){ + return dynamic_size_layer(ctx, n, args); + } + return util::toVec(tensor->getDimensions()); + } else if (tensor_var.IValue()->isTensor()) { + auto tensor = tensor_var.unwrapToTensor(); + return tensor.sizes(); + } else if (tensor_var.IValue()->isCustomClass()) { + auto tensor = tensor_var.IValue()->toCustomClass()->tensor(); + return util::toVec(tensor->getDimensions()); + } else { + TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type()); + } + } else { + auto dim = args.at(n->input(1)).unwrapToInt(); + if (tensor_var.isITensor()) { + if (ctx->input_is_dynamic){ + return dynamic_size_layer(ctx, n, args); + } + auto tensor = tensor_var.ITensor(); + auto dims = util::toVec(tensor->getDimensions()); + auto nbDims = tensor->getDimensions().nbDims; + if (dim < 0) { + dim += nbDims; + } + return dims[dim]; + } else if (tensor_var.IValue()->isTensor()) { + auto tensor = tensor_var.unwrapToTensor(); + auto nbDims = tensor.sizes().size(); + if (dim < 0) { + dim += nbDims; + } + return tensor.sizes()[dim]; + } else if (tensor_var.IValue()->isCustomClass()) { + auto tensor = tensor_var.IValue()->toCustomClass()->tensor(); + auto dims = util::toVec(tensor->getDimensions()); + auto nbDims = tensor->getDimensions().nbDims; + if (dim < 0) { + dim += nbDims; + } + return dims[dim]; + } else { + TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. 
Found: " << tensor_var.IValue()->type()); + } + } + }, + EvalOptions().validSchemas( + {"aten::size(Tensor self) -> (int[])", "aten::size.int(Tensor self, int dim) -> (int)"})}) .evaluator( {c10::Symbol::fromQualString("aten::__getitem__"), [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { - auto list = args.at(n->input(0)).IValue()->to>(); + auto list_input = args.at(n->input(0)); auto idx = args.at(n->input(1)).unwrapToInt(); + if (list_input.isIValue()){ + auto list = args.at(n->input(0)).IValue()->to>(); + const int64_t list_size = list.size(); + const int64_t normalized_idx = normalizeIndex(idx, list_size); + TORCHTRT_CHECK( + normalized_idx >= 0 || normalized_idx < list_size, "List index out of range (aten::__getitem__)"); + return list.get(normalized_idx); + } elif (list_input.isITensor()){ + return dynamic_size_layer(ctx, n, args); + } + + - const int64_t list_size = list.size(); - const int64_t normalized_idx = normalizeIndex(idx, list_size); - TORCHTRT_CHECK( - normalized_idx >= 0 || normalized_idx < list_size, "List index out of range (aten::__getitem__)"); - return list.get(normalized_idx); + }, EvalOptions().validSchemas({ "aten::__getitem__.t(t[](a) list, int idx) -> (t(*))", diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp index e6fe0bab56..68e1917724 100644 --- a/core/conversion/evaluators/prim.cpp +++ b/core/conversion/evaluators/prim.cpp @@ -48,6 +48,7 @@ auto prim_registrations = [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { const auto num_inputs = n->inputs().size(); if (constTypesOnly(args)) { + LOG_DEBUG("==== CONST TYPES ARGS ==== "); c10::ListTypePtr lt = n->output()->type()->expect(); if (torch::jit::IntType::get() == lt->getElementType()) { c10::List list; @@ -89,6 +90,7 @@ auto prim_registrations = return c10::optional(std::move(torch::jit::IValue(list))); } } else { + LOG_DEBUG("==== NON CONST TYPES ==== "); c10::ListTypePtr lt = n->output()->type()->expect(); c10::TypePtr elementType = lt->getElementType(); auto list = c10::impl::GenericList(elementType); @@ -104,7 +106,8 @@ auto prim_registrations = auto ival = torch::jit::IValue(); list.emplace_back(std::move(ival)); } else if (args.at(in).IValue()->isInt()) { - auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, torch::tensor(args.at(in).unwrapToInt())); + LOG_DEBUG("==== INT TYPE ITENSOR ==== "); + auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, torch::tensor({args.at(in).unwrapToInt()})); auto tensor_holder = TensorContainer(); tensor_holder.hold_tensor(itensor); auto ival = c10::IValue(std::move(c10::make_intrusive(tensor_holder))); diff --git a/core/conversion/var/Var.cpp b/core/conversion/var/Var.cpp index 062cea7a64..2d525fc809 100644 --- a/core/conversion/var/Var.cpp +++ b/core/conversion/var/Var.cpp @@ -147,7 +147,6 @@ bool Var::isITensor() const { } bool Var::isITensorList() const { - LOG_DEBUG("===== TYPE NAME: " << type_name()); if (type_ == Type::kITensor) { return true; } else { @@ -155,13 +154,16 @@ bool Var::isITensorList() const { } } -bool Var::unwrapToITensorList() { +std::vector Var::unwrapToITensorList() { TORCHTRT_CHECK( isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name()); - LOG_DEBUG("===== TYPE NAME: " << type_name()); - auto ivalue = ptr_.ivalue; - return false; - // return ptr_.ivalue->to(); + auto ivalue_list = ptr_.ivalue->toList(); + std::vector outputs; + 
for (int i=0; i < ivalue_list.size(); i++){ + auto element = ivalue_list.get(i).toCustomClass<TensorContainer>()->tensor(); + outputs.push_back(std::move(element)); + } + return outputs; } bool Var::isIValue() const { diff --git a/core/conversion/var/Var.h b/core/conversion/var/Var.h index 4167246449..ec9281cffb 100644 --- a/core/conversion/var/Var.h +++ b/core/conversion/var/Var.h @@ -43,7 +43,7 @@ class Var : torch::CustomClassHolder { c10::Scalar unwrapToScalar(); c10::List<int64_t> unwrapToIntList(c10::List<int64_t> default_val); c10::List<int64_t> unwrapToIntList(); - c10::List<nvinfer1::ITensor*> unwrapToITensorList(); + std::vector<nvinfer1::ITensor*> unwrapToITensorList(); c10::List<double> unwrapToDoubleList(c10::List<double> default_val); c10::List<double> unwrapToDoubleList(); c10::List<bool> unwrapToBoolList(c10::List<bool> default_val); diff --git a/py/requirements.txt b/py/requirements.txt index 6d0916af9d..6292e6c38b 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -1,7 +1,6 @@ numpy pybind11==2.6.2 ---extra-index-url https://download.pytorch.org/whl/nightly/cu117 -torch==2.0.0.dev20230103+cu117 -torchvision==0.15.0.dev20230103+cu117 +torch==1.13.0 +torchvision==0.14.0 --extra-index-url https://pypi.ngc.nvidia.com tensorrt==8.5.1.7
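PATCH 03 is the heart of the series: once input shapes are dynamic, aten::size can no longer be folded to a constant at conversion time, so the evaluator has to emit TensorRT layers and thread the resulting ITensor through the evaluator machinery. A minimal sketch of the layer pattern the patch relies on, assuming a ConversionCtx whose net member is the INetworkDefinition under construction (the helper name select_size_dim is illustrative, not part of the codebase):

    // Build shape(x)[dim] as a one-element Int32 ITensor; its value is only
    // known at inference time, so it must stay a tensor inside the network.
    nvinfer1::ITensor* select_size_dim(ConversionCtx* ctx, nvinfer1::ITensor* in, int64_t dim) {
      nvinfer1::ITensor* shape = ctx->net->addShape(*in)->getOutput(0); // 1-D shape tensor
      at::Tensor idx = torch::tensor({dim}).to(torch::kI32);            // gather index
      nvinfer1::ITensor* idx_const = converters::tensor_to_const(ctx, idx);
      return ctx->net->addGather(*shape, *idx_const, 0)->getOutput(0);  // rank 1, size 1
    }

Null checks on the returned layer pointers are elided here for brevity; the patch itself guards each addX call with TORCHTRT_CHECK.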
From b2c8f59e2f9a7b08d295ec2ff50c6a4cbd50b995 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 2 Feb 2023 11:32:40 -0800 Subject: [PATCH 04/10] chore: refactor code Signed-off-by: Dheeraj Peri --- core/conversion/evaluators/aten.cpp | 34 +++++++++++++++-------------- core/conversion/evaluators/prim.cpp | 4 +--- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index efc9a8832e..6d33f474e8 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -19,8 +19,15 @@ namespace conversion { namespace evaluators { namespace { -nvinfer1::ITensor* index_layer(){ - +nvinfer1::ITensor* index_layer(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* input_tensor, int64_t index){ + // index to access needs to be an at::Tensor + at::Tensor indices = torch::tensor({index}).to(torch::kI32); + auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, indices); + + auto gather_layer = ctx->net->addGather(*input_tensor, *indices_out, 0); + TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n); + auto indexed_tensor = gather_layer->getOutput(0); + return indexed_tensor; } c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kw auto in = args.at(n->input(0)).ITensorOrFreeze(ctx); LOG_DEBUG("Input dimensions: " << in->getDimensions()); auto shape_layer = ctx->net->addShape(*in); + TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n); auto shape_1d_tensor = shape_layer->getOutput(0); @@ -36,15 +44,9 @@ c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kw auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims); auto dim = args.at(n->input(1)).unwrapToInt(); // Handle negative axis by referring to nbDims of input Tensor dim = dim < 0 ? dim + maxDim : dim; LOG_DEBUG("Dimension to select: " << dim); - - // index to access needs to be an at::Tensor - at::Tensor indices = torch::tensor({dim}).to(torch::kI32); - auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, indices); - - auto gather_layer = ctx->net->addGather(*shape_1d_tensor, *indices_out, 0); - shape_1d_tensor = gather_layer->getOutput(0); + shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim); } - + LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions()); auto tensor_holder = TensorContainer(); @@ -364,13 +366,13 @@ auto aten_registrations TORCHTRT_UNUSED = TORCHTRT_CHECK( normalized_idx >= 0 || normalized_idx < list_size, "List index out of range (aten::__getitem__)"); return list.get(normalized_idx); - } elif (list_input.isITensor()){ - return dynamic_size_layer(ctx, n, args); + } else if(list_input.isITensor()){ + auto indexed_tensor = index_layer(ctx, n, list_input.ITensorOrFreeze(ctx), idx); + auto tensor_holder = TensorContainer(); + tensor_holder.hold_tensor(indexed_tensor); + auto indexed_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder))); + return indexed_ivalue; } - - - - }, EvalOptions().validSchemas({ "aten::__getitem__.t(t[](a) list, int idx) -> (t(*))", diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp index 68e1917724..caf7498fb3 100644 --- a/core/conversion/evaluators/prim.cpp +++ b/core/conversion/evaluators/prim.cpp @@ -48,7 +48,6 @@ auto prim_registrations = [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { const auto num_inputs = n->inputs().size(); if (constTypesOnly(args)) { - LOG_DEBUG("==== CONST TYPES ARGS ==== "); c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>(); if (torch::jit::IntType::get() == lt->getElementType()) { c10::List<int64_t> list; @@ -106,8 +105,7 @@ auto prim_registrations = auto ival = torch::jit::IValue(); list.emplace_back(std::move(ival)); } else if (args.at(in).IValue()->isInt()) { - LOG_DEBUG("==== INT TYPE ITENSOR ==== "); - auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, torch::tensor({args.at(in).unwrapToInt()})); + auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, torch::tensor({args.at(in).unwrapToInt()}).to(torch::kI32)); auto tensor_holder = TensorContainer(); tensor_holder.hold_tensor(itensor); auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
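PATCH 04 leans on a single convention throughout: an ITensor produced while evaluating one node is smuggled through the IValue-based evaluator interface inside a TensorContainer custom class, then recovered by the node that consumes it. Both directions of that round trip, sketched on the assumption that TensorContainer is registered as a TorchScript custom class (as it is elsewhere in this codebase); the helper names are illustrative:

    // Wrap: let an evaluator return a graph-time tensor as a c10::IValue.
    c10::IValue wrap_itensor(nvinfer1::ITensor* t) {
      auto holder = TensorContainer();
      holder.hold_tensor(t);
      return c10::IValue(c10::make_intrusive<TensorContainer>(holder));
    }

    // Unwrap: recover the ITensor on the consuming side (aten::__getitem__ above).
    nvinfer1::ITensor* unwrap_itensor(const c10::IValue& iv) {
      return iv.toCustomClass<TensorContainer>()->tensor();
    }

This is also why the prim::ListConstruct change above turns plain ints into one-element kI32 constants first: every element of the resulting list must be expressible as an ITensor-bearing IValue before the reshape converter can concatenate them into a shape tensor.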
From 47ae984a96d2b3e3f1ca1fc93fbbd40a71ce89db Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 3 Feb 2023 14:40:44 -0800 Subject: [PATCH 05/10] chore: linter fixes Signed-off-by: Dheeraj Peri --- core/conversion/converters/BUILD | 1 - core/conversion/converters/impl/shuffle.cpp | 54 ++--- core/conversion/evaluators/aten.cpp | 30 ++- core/conversion/evaluators/eval_macros.h | 250 ++++++++++---------- core/conversion/evaluators/evaluators.h | 5 +- core/conversion/evaluators/prim.cpp | 5 +- core/conversion/var/Var.cpp | 2 +- py/requirements.txt | 5 +- 8 files changed, 179 insertions(+), 173 deletions(-) diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD index ce8d1dbbb7..354f17a734 100755 --- a/core/conversion/converters/BUILD +++ b/core/conversion/converters/BUILD @@ -62,7 +62,6 @@ cc_library( "impl/constant_pad.cpp", "impl/conv_deconv.cpp", "impl/cumsum.cpp", - "impl/dual_ops.cpp", "impl/einsum.cpp", "impl/element_wise.cpp", "impl/expand.cpp", diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp index 18033c1d07..810af2e352 100644 --- a/core/conversion/converters/impl/shuffle.cpp +++ b/core/conversion/converters/impl/shuffle.cpp @@ -67,33 +67,33 @@ static auto shuffle_registrations TORCHTRT_UNUSED = .pattern( {"aten::reshape(Tensor self, int[] shape) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto in = args[0].ITensorOrFreeze(ctx); - auto in_shape = util::toVec(in->getDimensions()); - std::vector<int64_t> new_shape; - nvinfer1::ITensor* shape_tensor; - if (ctx->input_is_dynamic) { - auto new_shape = args[1].unwrapToITensorList(); - auto concat_layer = ctx->net->addConcatenation(new_shape.data(), new_shape.size()); - TORCHTRT_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n); - concat_layer->setAxis(static_cast<int32_t>(0)); - shape_tensor = concat_layer->getOutput(0); - } else { - auto new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec(); - } - auto shuffle = ctx->net->addShuffle(*in); - shuffle->setName(util::node_info(n).c_str()); - TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n); - - if (ctx->input_is_dynamic){ - shuffle->setInput(1, *shape_tensor); - } else { - shuffle->setReshapeDimensions(util::toDims(new_shape)); - } - - auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0)); - LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); - - return true; + auto in = args[0].ITensorOrFreeze(ctx); + auto in_shape = util::toVec(in->getDimensions()); + std::vector<int64_t> new_shape; + nvinfer1::ITensor* shape_tensor; + if (ctx->input_is_dynamic) { + auto new_shape = args[1].unwrapToITensorList(); + auto concat_layer = ctx->net->addConcatenation(new_shape.data(), new_shape.size()); + TORCHTRT_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n); + concat_layer->setAxis(static_cast<int32_t>(0)); + shape_tensor = concat_layer->getOutput(0); + } else { + auto new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec(); + } + auto shuffle = ctx->net->addShuffle(*in); + shuffle->setName(util::node_info(n).c_str()); + TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n); + + if (ctx->input_is_dynamic) { + shuffle->setInput(1, *shape_tensor); + } else { + shuffle->setReshapeDimensions(util::toDims(new_shape)); + } + + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0)); + LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); + + return true; }}) .pattern( {"aten::view(Tensor(a) self, int[] size) -> (Tensor(a))", diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 6d33f474e8..d7a2479ffe 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -19,7 +19,11 @@ namespace conversion { namespace evaluators { namespace { -nvinfer1::ITensor* index_layer(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* input_tensor, int64_t index){ +nvinfer1::ITensor* index_layer( + ConversionCtx* ctx, + const torch::jit::Node* n, + nvinfer1::ITensor* input_tensor, + int64_t index) { // index to access needs to be an at::Tensor at::Tensor indices = torch::tensor({index}).to(torch::kI32); auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, indices); @@ -30,7 +34,7 @@ nvinfer1::ITensor* index_layer(ConversionCtx* ctx, const torch::jit::Node* n, nv return
indexed_tensor; } -c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args){ +c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) { LOG_DEBUG("Using dynamic version of aten::size evaluator"); auto in = args.at(n->input(0)).ITensorOrFreeze(ctx); LOG_DEBUG("Input dimensions: " << in->getDimensions()); @@ -38,7 +42,7 @@ c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kw TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n); auto shape_1d_tensor = shape_layer->getOutput(0); - if (n->inputs().size() != 1){ + if (n->inputs().size() != 1) { auto maxDim = static_cast(in->getDimensions().nbDims); auto dim = args.at(n->input(1)).unwrapToInt(); // Handle negative axis by refering to nbDims of input Tensor @@ -306,7 +310,7 @@ auto aten_registrations TORCHTRT_UNUSED = if (n->inputs().size() == 1) { if (tensor_var.isITensor()) { auto tensor = tensor_var.ITensor(); - if (ctx->input_is_dynamic){ + if (ctx->input_is_dynamic) { return dynamic_size_layer(ctx, n, args); } return util::toVec(tensor->getDimensions()); @@ -322,7 +326,7 @@ auto aten_registrations TORCHTRT_UNUSED = } else { auto dim = args.at(n->input(1)).unwrapToInt(); if (tensor_var.isITensor()) { - if (ctx->input_is_dynamic){ + if (ctx->input_is_dynamic) { return dynamic_size_layer(ctx, n, args); } auto tensor = tensor_var.ITensor(); @@ -359,14 +363,14 @@ auto aten_registrations TORCHTRT_UNUSED = [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto list_input = args.at(n->input(0)); auto idx = args.at(n->input(1)).unwrapToInt(); - if (list_input.isIValue()){ - auto list = args.at(n->input(0)).IValue()->to>(); - const int64_t list_size = list.size(); - const int64_t normalized_idx = normalizeIndex(idx, list_size); - TORCHTRT_CHECK( - normalized_idx >= 0 || normalized_idx < list_size, "List index out of range (aten::__getitem__)"); - return list.get(normalized_idx); - } else if(list_input.isITensor()){ + if (list_input.isIValue()) { + auto list = args.at(n->input(0)).IValue()->to>(); + const int64_t list_size = list.size(); + const int64_t normalized_idx = normalizeIndex(idx, list_size); + TORCHTRT_CHECK( + normalized_idx >= 0 || normalized_idx < list_size, "List index out of range (aten::__getitem__)"); + return list.get(normalized_idx); + } else if (list_input.isITensor()) { auto indexed_tensor = index_layer(ctx, n, list_input.ITensorOrFreeze(ctx), idx); auto tensor_holder = TensorContainer(); tensor_holder.hold_tensor(indexed_tensor); diff --git a/core/conversion/evaluators/eval_macros.h b/core/conversion/evaluators/eval_macros.h index f15046afa9..5a0328663b 100644 --- a/core/conversion/evaluators/eval_macros.h +++ b/core/conversion/evaluators/eval_macros.h @@ -2,134 +2,134 @@ #include "core/conversion/evaluators/evaluators.h" -#define DEFINE_GENERIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \ - auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \ - {c10::Symbol::fromQualString(node_kind), \ - [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { \ - if (args.at(n->input(0)).IValue()->isInt()) { \ - auto a = args.at(n->input(0)).unwrapToInt(); \ - if (args.at(n->input(1)).IValue()->isInt()) { \ - auto b = args.at(n->input(1)).unwrapToInt(); \ - return operation; \ - } else if (args.at(n->input(1)).IValue()->isDouble()) { \ - auto b = args.at(n->input(1)).unwrapToDouble(); \ - return operation; \ - } else if 
(args.at(n->input(1)).IValue()->isBool()) { \ - auto b = args.at(n->input(1)).unwrapToBool(); \ - return operation; \ - } else { \ - TORCHTRT_THROW_ERROR( \ - "Unimplemented data type for " \ - << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \ - return {}; \ - } \ - } else if (args.at(n->input(0)).IValue()->isDouble()) { \ - auto a = args.at(n->input(0)).unwrapToDouble(); \ - if (args.at(n->input(1)).IValue()->isInt()) { \ - auto b = args.at(n->input(1)).unwrapToInt(); \ - return operation; \ - } else if (args.at(n->input(1)).IValue()->isDouble()) { \ - auto b = args.at(n->input(1)).unwrapToDouble(); \ - return operation; \ - } else if (args.at(n->input(1)).IValue()->isBool()) { \ - auto b = args.at(n->input(1)).unwrapToBool(); \ - return operation; \ - } else { \ - TORCHTRT_THROW_ERROR( \ - "Unimplemented data type for " \ - << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \ - return {}; \ - } \ - } else if (args.at(n->input(0)).IValue()->isBool()) { \ - auto a = args.at(n->input(0)).unwrapToBool(); \ - if (args.at(n->input(1)).IValue()->isInt()) { \ - auto b = args.at(n->input(1)).unwrapToInt(); \ - return operation; \ - } else if (args.at(n->input(1)).IValue()->isDouble()) { \ - auto b = args.at(n->input(1)).unwrapToDouble(); \ - return operation; \ - } else if (args.at(n->input(1)).IValue()->isBool()) { \ - auto b = args.at(n->input(1)).unwrapToBool(); \ - return operation; \ - } else { \ - TORCHTRT_THROW_ERROR( \ - "Unimplemented data type for " \ - << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \ - return {}; \ - } \ - } else if (args.at(n->input(0)).IValue()->isString()) { \ - auto a = args.at(n->input(0)).unwrapToString(); \ - if (args.at(n->input(1)).IValue()->isString()) { \ - auto b = args.at(n->input(1)).unwrapToString(); \ - return operation; \ - } else { \ - TORCHTRT_THROW_ERROR( \ - "Unimplemented data type for " \ - << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \ - return {}; \ - } \ - } else { \ - TORCHTRT_THROW_ERROR( \ - "Unimplemented data type for " \ - << node_kind << " evaluator a arg: " << args.at(n->input(0)).IValue()->type()->str()); \ - return {}; \ - } \ - }, \ +#define DEFINE_GENERIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \ + auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \ + {c10::Symbol::fromQualString(node_kind), \ + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { \ + if (args.at(n->input(0)).IValue()->isInt()) { \ + auto a = args.at(n->input(0)).unwrapToInt(); \ + if (args.at(n->input(1)).IValue()->isInt()) { \ + auto b = args.at(n->input(1)).unwrapToInt(); \ + return operation; \ + } else if (args.at(n->input(1)).IValue()->isDouble()) { \ + auto b = args.at(n->input(1)).unwrapToDouble(); \ + return operation; \ + } else if (args.at(n->input(1)).IValue()->isBool()) { \ + auto b = args.at(n->input(1)).unwrapToBool(); \ + return operation; \ + } else { \ + TORCHTRT_THROW_ERROR( \ + "Unimplemented data type for " \ + << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \ + return {}; \ + } \ + } else if (args.at(n->input(0)).IValue()->isDouble()) { \ + auto a = args.at(n->input(0)).unwrapToDouble(); \ + if (args.at(n->input(1)).IValue()->isInt()) { \ + auto b = args.at(n->input(1)).unwrapToInt(); \ + return operation; \ + } else if (args.at(n->input(1)).IValue()->isDouble()) { \ + auto b = 
args.at(n->input(1)).unwrapToDouble(); \ + return operation; \ + } else if (args.at(n->input(1)).IValue()->isBool()) { \ + auto b = args.at(n->input(1)).unwrapToBool(); \ + return operation; \ + } else { \ + TORCHTRT_THROW_ERROR( \ + "Unimplemented data type for " \ + << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \ + return {}; \ + } \ + } else if (args.at(n->input(0)).IValue()->isBool()) { \ + auto a = args.at(n->input(0)).unwrapToBool(); \ + if (args.at(n->input(1)).IValue()->isInt()) { \ + auto b = args.at(n->input(1)).unwrapToInt(); \ + return operation; \ + } else if (args.at(n->input(1)).IValue()->isDouble()) { \ + auto b = args.at(n->input(1)).unwrapToDouble(); \ + return operation; \ + } else if (args.at(n->input(1)).IValue()->isBool()) { \ + auto b = args.at(n->input(1)).unwrapToBool(); \ + return operation; \ + } else { \ + TORCHTRT_THROW_ERROR( \ + "Unimplemented data type for " \ + << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \ + return {}; \ + } \ + } else if (args.at(n->input(0)).IValue()->isString()) { \ + auto a = args.at(n->input(0)).unwrapToString(); \ + if (args.at(n->input(1)).IValue()->isString()) { \ + auto b = args.at(n->input(1)).unwrapToString(); \ + return operation; \ + } else { \ + TORCHTRT_THROW_ERROR( \ + "Unimplemented data type for " \ + << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \ + return {}; \ + } \ + } else { \ + TORCHTRT_THROW_ERROR( \ + "Unimplemented data type for " \ + << node_kind << " evaluator a arg: " << args.at(n->input(0)).IValue()->type()->str()); \ + return {}; \ + } \ + }, \ EvalOptions().validSchemas(schemas)}); -#define DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \ - auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \ - {c10::Symbol::fromQualString(node_kind), \ - [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { \ - if (args.at(n->input(0)).IValue()->isInt()) { \ - auto a = args.at(n->input(0)).unwrapToInt(); \ - if (args.at(n->input(1)).IValue()->isInt()) { \ - auto b = args.at(n->input(1)).unwrapToInt(); \ - return operation; \ - } else if (args.at(n->input(1)).IValue()->isDouble()) { \ - auto b = args.at(n->input(1)).unwrapToDouble(); \ - return operation; \ - } else if (args.at(n->input(1)).IValue()->isBool()) { \ - auto b = args.at(n->input(1)).unwrapToBool(); \ - return operation; \ - } else { \ - TORCHTRT_THROW_ERROR( \ - "Unimplemented data type for " \ - << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \ - return {}; \ - } \ - } else if (args.at(n->input(0)).IValue()->isDouble()) { \ - auto a = args.at(n->input(0)).unwrapToDouble(); \ - if (args.at(n->input(1)).IValue()->isInt()) { \ - auto b = args.at(n->input(1)).unwrapToInt(); \ - return operation; \ - } else if (args.at(n->input(1)).IValue()->isDouble()) { \ - auto b = args.at(n->input(1)).unwrapToDouble(); \ - return operation; \ - } else if (args.at(n->input(1)).IValue()->isBool()) { \ - auto b = args.at(n->input(1)).unwrapToBool(); \ - return operation; \ - } else { \ - TORCHTRT_THROW_ERROR( \ - "Unimplemented data type for " \ - << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \ - return {}; \ - } \ - } else { \ - TORCHTRT_THROW_ERROR( \ - "Unimplemented data type for " \ - << node_kind << " evaluator a arg: " << args.at(n->input(0)).IValue()->type()->str()); \ - return {}; \ - } \ - 
}, \ +#define DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \ + auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \ + {c10::Symbol::fromQualString(node_kind), \ + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { \ + if (args.at(n->input(0)).IValue()->isInt()) { \ + auto a = args.at(n->input(0)).unwrapToInt(); \ + if (args.at(n->input(1)).IValue()->isInt()) { \ + auto b = args.at(n->input(1)).unwrapToInt(); \ + return operation; \ + } else if (args.at(n->input(1)).IValue()->isDouble()) { \ + auto b = args.at(n->input(1)).unwrapToDouble(); \ + return operation; \ + } else if (args.at(n->input(1)).IValue()->isBool()) { \ + auto b = args.at(n->input(1)).unwrapToBool(); \ + return operation; \ + } else { \ + TORCHTRT_THROW_ERROR( \ + "Unimplemented data type for " \ + << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \ + return {}; \ + } \ + } else if (args.at(n->input(0)).IValue()->isDouble()) { \ + auto a = args.at(n->input(0)).unwrapToDouble(); \ + if (args.at(n->input(1)).IValue()->isInt()) { \ + auto b = args.at(n->input(1)).unwrapToInt(); \ + return operation; \ + } else if (args.at(n->input(1)).IValue()->isDouble()) { \ + auto b = args.at(n->input(1)).unwrapToDouble(); \ + return operation; \ + } else if (args.at(n->input(1)).IValue()->isBool()) { \ + auto b = args.at(n->input(1)).unwrapToBool(); \ + return operation; \ + } else { \ + TORCHTRT_THROW_ERROR( \ + "Unimplemented data type for " \ + << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \ + return {}; \ + } \ + } else { \ + TORCHTRT_THROW_ERROR( \ + "Unimplemented data type for " \ + << node_kind << " evaluator a arg: " << args.at(n->input(0)).IValue()->type()->str()); \ + return {}; \ + } \ + }, \ EvalOptions().validSchemas(schemas)}); -#define DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(node_kind, node_name, operation, type, schemas) \ - auto node_kind##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \ - {c10::Symbol::fromQualString(node_name), \ +#define DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(node_kind, node_name, operation, type, schemas) \ + auto node_kind##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \ + {c10::Symbol::fromQualString(node_name), \ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { \ - auto a = args.at(n->input(0)).unwrapTo(); \ - auto b = args.at(n->input(1)).unwrapTo(); \ - return operation; \ - }, \ + auto a = args.at(n->input(0)).unwrapTo(); \ + auto b = args.at(n->input(1)).unwrapTo(); \ + return operation; \ + }, \ EvalOptions().validSchemas(schemas)}); diff --git a/core/conversion/evaluators/evaluators.h b/core/conversion/evaluators/evaluators.h index 7c2eaf588f..ba9610fac7 100644 --- a/core/conversion/evaluators/evaluators.h +++ b/core/conversion/evaluators/evaluators.h @@ -6,9 +6,9 @@ #include "torch/csrc/jit/ir/ir.h" -#include "core/conversion/tensorcontainer/TensorContainer.h" #include "core/conversion/conversionctx/ConversionCtx.h" #include "core/conversion/converters/converter_util.h" +#include "core/conversion/tensorcontainer/TensorContainer.h" #include "core/conversion/var/Var.h" namespace torch_tensorrt { @@ -35,7 +35,8 @@ inline bool constTypesOnly(kwargs& args) { // to use the node itself to pull out arguments. // This means that you should iterate over node inputs vs. 
the args // when writing evaluators -typedef std::function<c10::optional<torch::jit::IValue>(ConversionCtx*, const torch::jit::Node*, kwargs&)> NodeEvaluator; +typedef std::function<c10::optional<torch::jit::IValue>(ConversionCtx*, const torch::jit::Node*, kwargs&)> + NodeEvaluator; struct EvalOptions { std::set<c10::TypePtr> blacklisted_output_types; diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp index caf7498fb3..3230942ce1 100644 --- a/core/conversion/evaluators/prim.cpp +++ b/core/conversion/evaluators/prim.cpp @@ -105,13 +105,14 @@ auto prim_registrations = auto ival = torch::jit::IValue(); list.emplace_back(std::move(ival)); } else if (args.at(in).IValue()->isInt()) { - auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, torch::tensor({args.at(in).unwrapToInt()}).to(torch::kI32)); + auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const( + ctx, torch::tensor({args.at(in).unwrapToInt()}).to(torch::kI32)); auto tensor_holder = TensorContainer(); tensor_holder.hold_tensor(itensor); auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder))); list.emplace_back(std::move(ival)); } else { - list.emplace_back(std::move(args.at(in).unwrapToTensor())); + list.emplace_back(std::move(args.at(in).unwrapToTensor())); } } } diff --git a/core/conversion/var/Var.cpp b/core/conversion/var/Var.cpp index 2d525fc809..915c26da31 100644 --- a/core/conversion/var/Var.cpp +++ b/core/conversion/var/Var.cpp @@ -159,7 +159,7 @@ std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList() { isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name()); auto ivalue_list = ptr_.ivalue->toList(); std::vector<nvinfer1::ITensor*> outputs; - for (int i=0; i < ivalue_list.size(); i++){ + for (int i = 0; i < ivalue_list.size(); i++) { auto element = ivalue_list.get(i).toCustomClass<TensorContainer>()->tensor(); outputs.push_back(std::move(element)); } diff --git a/py/requirements.txt b/py/requirements.txt index 6292e6c38b..6d0916af9d 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -1,6 +1,7 @@ numpy pybind11==2.6.2 -torch==1.13.0 -torchvision==0.14.0 +--extra-index-url https://download.pytorch.org/whl/nightly/cu117 +torch==2.0.0.dev20230103+cu117 +torchvision==0.15.0.dev20230103+cu117 --extra-index-url https://pypi.ngc.nvidia.com tensorrt==8.5.1.7
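Most of PATCH 05 is mechanical clang-format churn, but the realigned eval_macros.h reads more easily with its expansion in mind. Schematically, an instantiation such as DEFINE_GENERIC_TWO_INPUT_EVALUATOR(eq, "aten::eq", a == b, ...) registers a lambda of roughly this shape (simplified here to the int/int branch; the real macro cross-products int, double, bool and, for the generic variant, string):

    [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args)
        -> c10::optional<torch::jit::IValue> {
      auto a = args.at(n->input(0)).unwrapToInt();
      auto b = args.at(n->input(1)).unwrapToInt();
      return a == b; // the `operation` token pasted verbatim by the macro
    }

Note the ctx parameter: these constant-folding branches never touch it, but it is now part of the NodeEvaluator signature, which is what lets the same registry dispatch to layer-emitting evaluators when shapes are dynamic.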
From 6754c794e02de3602ef84e6622aec3da4ca849ae Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 3 Feb 2023 16:28:39 -0800 Subject: [PATCH 06/10] chore: Add testcase Signed-off-by: Dheeraj Peri --- core/conversion/converters/impl/shuffle.cpp | 2 +- core/conversion/evaluators/prim.cpp | 17 ++++++++-- py/requirements.txt | 5 ++- tests/cpp/BUILD | 14 +++++++++ tests/cpp/test_dynamic_size.cpp | 35 +++++++++++++++++++++ 5 files changed, 66 insertions(+), 7 deletions(-) create mode 100644 tests/cpp/test_dynamic_size.cpp diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp index 810af2e352..16c0604aab 100644 --- a/core/conversion/converters/impl/shuffle.cpp +++ b/core/conversion/converters/impl/shuffle.cpp @@ -78,7 +78,7 @@ static auto shuffle_registrations TORCHTRT_UNUSED = concat_layer->setAxis(static_cast<int32_t>(0)); shape_tensor = concat_layer->getOutput(0); } else { - auto new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec(); + new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec(); } auto shuffle = ctx->net->addShuffle(*in); shuffle->setName(util::node_info(n).c_str()); diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp index 3230942ce1..8962e1b856 100644 --- a/core/conversion/evaluators/prim.cpp +++ b/core/conversion/evaluators/prim.cpp @@ -1,12 +1,11 @@ #include <limits> -#include "torch/csrc/jit/ir/ir.h" -//#include "torch/csrc/jit/ir/constants.h" #include "ATen/core/List.h" #include "ATen/core/functional.h" #include "ATen/core/ivalue.h" #include "ATen/core/stack.h" #include "c10/util/intrusive_ptr.h" +#include "torch/csrc/jit/ir/ir.h" #include "torch/torch.h" #include "core/conversion/evaluators/eval_macros.h" @@ -111,8 +110,20 @@ auto prim_registrations = tensor_holder.hold_tensor(itensor); auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder))); list.emplace_back(std::move(ival)); + } else if (args.at(in).IValue()->isDouble()) { + auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const( + ctx, torch::tensor({args.at(in).unwrapToDouble()}).to(torch::kFloat)); + auto tensor_holder = TensorContainer(); + tensor_holder.hold_tensor(itensor); + auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder))); + list.emplace_back(std::move(ival)); } else { - list.emplace_back(std::move(args.at(in).unwrapToTensor())); + auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const( + ctx, std::move(args.at(in).unwrapToTensor())); + auto tensor_holder = TensorContainer(); + tensor_holder.hold_tensor(itensor); + auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder))); + list.emplace_back(std::move(ival)); } } } diff --git a/py/requirements.txt b/py/requirements.txt index 6d0916af9d..6292e6c38b 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -1,7 +1,6 @@ numpy pybind11==2.6.2 ---extra-index-url https://download.pytorch.org/whl/nightly/cu117 -torch==2.0.0.dev20230103+cu117 -torchvision==0.15.0.dev20230103+cu117 +torch==1.13.0 +torchvision==0.14.0 --extra-index-url https://pypi.ngc.nvidia.com tensorrt==8.5.1.7 diff --git a/tests/cpp/BUILD b/tests/cpp/BUILD index c34aa09372..709187e1b2 100644 --- a/tests/cpp/BUILD +++ b/tests/cpp/BUILD @@ -16,6 +16,7 @@ test_suite( ":test_compiled_modules", ":test_default_input_types", ":test_dynamic_fallback", + ":test_dynamic_size", ":test_example_tensors", ":test_module_fallback", ":test_modules_as_engines", @@ -32,6 +33,7 @@ test_suite( ":test_compiled_modules", ":test_default_input_types", ":test_dynamic_fallback", + ":test_dynamic_size", ":test_example_tensors", ":test_module_fallback", ":test_modules_as_engines", @@ -142,6 +144,18 @@ cc_test( }), ) +cc_test( + name = "test_dynamic_size", + srcs = ["test_dynamic_size.cpp"], + deps = [ + "//tests/util", + "@googletest//:gtest_main", + ] + select({ + ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], + "//conditions:default": ["@libtorch//:libtorch"], + }), +) + cc_test( name = "test_collections", srcs = ["test_collections.cpp"], diff --git a/tests/cpp/test_dynamic_size.cpp b/tests/cpp/test_dynamic_size.cpp new file mode 100644 index 0000000000..968611deda --- /dev/null +++ b/tests/cpp/test_dynamic_size.cpp @@ -0,0 +1,35 @@ +#include +#include +#include "core/compiler.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" + +TEST(Converters, ATenResizeDynamicInputCorrectly) { + const auto graph = R"IR( + graph(%x : Tensor): + %3 : int = prim::Constant[value=0]() + %2 : int = prim::Constant[value=-1]() + %28 : int = aten::size(%x, %3) + %30 : int[] = prim::ListConstruct(%28, %2) + %6 : Tensor = aten::reshape(%x, %30) + return (%6))IR"; + + auto g = std::make_shared<torch::jit::Graph>();
+ + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 3, 2}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true); + + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +}
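The new test exercises reshape with a data-dependent target shape. In eager PyTorch, the TorchScript graph under test corresponds to this reference (illustrative only):

    // Equivalent of the IR in test_dynamic_size.cpp: reshape x to
    // (x.size(0), -1), where the leading dimension is read off the input.
    at::Tensor reference_reshape(const at::Tensor& x) {
      return x.reshape({x.size(0), -1});
    }

RunGraphEngineDynamic is what builds the TensorRT engine with dynamic input ranges, so the aten::size evaluator actually takes the new shape-tensor path here; under a static build the whole expression would simply be constant-folded.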
From 6c69b419d886ff4736a34f05b6810b921883cb7b Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 8 Feb 2023 17:01:42 -0800 Subject: [PATCH 07/10] chore: Add ivalue type detections Signed-off-by: Dheeraj Peri --- core/conversion/converters/impl/shuffle.cpp | 22 +++++-- core/conversion/evaluators/prim.cpp | 1 - core/conversion/var/Var.cpp | 65 ++++++++++++++++++++- core/conversion/var/Var.h | 9 ++- tests/cpp/test_dynamic_size.cpp | 28 ++++++++- 5 files changed, 116 insertions(+), 9 deletions(-) diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp index 16c0604aab..9817e9a6d3 100644 --- a/core/conversion/converters/impl/shuffle.cpp +++ b/core/conversion/converters/impl/shuffle.cpp @@ -72,11 +72,23 @@ static auto shuffle_registrations TORCHTRT_UNUSED = std::vector<int64_t> new_shape; nvinfer1::ITensor* shape_tensor; if (ctx->input_is_dynamic) { - auto new_shape = args[1].unwrapToITensorList(); - auto concat_layer = ctx->net->addConcatenation(new_shape.data(), new_shape.size()); - TORCHTRT_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n); - concat_layer->setAxis(static_cast<int32_t>(0)); - shape_tensor = concat_layer->getOutput(0); + LOG_DEBUG("Using dynamic version of reshape layer"); + if (args[1].isITensorList()) { + LOG_DEBUG("Shape tensor is an ITensorList"); + auto new_shape = args[1].unwrapToITensorList(); + auto concat_layer = ctx->net->addConcatenation(new_shape.data(), new_shape.size()); + TORCHTRT_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n); + concat_layer->setAxis(static_cast<int32_t>(0)); + shape_tensor = concat_layer->getOutput(0); + } else if (args[1].isIntList()) { + LOG_DEBUG("Shape tensor is an IntList"); + auto shape_vec = args[1].unwrapToIntList().vec(); + shape_tensor = tensor_to_const(ctx, torch::tensor(shape_vec).to(torch::kI32)); + } else { + LOG_ERROR( + "Invalid IValue type of " << args[1].ivalue_type() << " detected for shape tensor from node: " << *n); + } } else { new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec(); } diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp index 8962e1b856..c4f2acd7bd 100644 --- a/core/conversion/evaluators/prim.cpp +++ b/core/conversion/evaluators/prim.cpp @@ -88,7 +88,6 @@ auto prim_registrations = return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list))); } } else { - LOG_DEBUG("==== NON CONST TYPES ==== "); c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>(); c10::TypePtr elementType = lt->getElementType(); auto list = c10::impl::GenericList(elementType); diff --git a/core/conversion/var/Var.cpp b/core/conversion/var/Var.cpp index 915c26da31..08fb63bc4e 100644 --- a/core/conversion/var/Var.cpp +++ b/core/conversion/var/Var.cpp @@ -21,6 +21,28 @@ Var::Var(nvinfer1::ITensor* p) : type_(Type::kITensor) { ptr_.tensor = p; } +Var::IValueType Var::determineIValueType(torch::jit::IValue* p) { + if (p->isInt()) { + return IValueType::kInt; + } else if (p->isDouble()) { + return IValueType::kDouble; + } else if (p->isBool()) { + return IValueType::kBool; + } else if (p->isTensor()) { + return IValueType::kTensor; + } else if (p->isIntList()) { + return IValueType::kIntList; + } else if (p->isDoubleList()) { + return IValueType::kDoubleList; + } else if (p->isBoolList()) { + return IValueType::kBoolList; + } else if (p->isTensorList()) { + return IValueType::kTensorList; + } else if (p->isList()) { + return IValueType::kITensorList; + } +} + Var::Var(const Var& a) { switch (a.type_) { case Type::kITensor: @@ -30,6 +52,7 @@ Var::Var(const Var& a) { case Type::kIValue: ptr_.ivalue = a.ptr_.ivalue; type_ = Type::kIValue; + ivalue_type_ = determineIValueType(ptr_.ivalue); break; case Type::kNone: default: @@ -47,6 +70,7 @@ Var& Var::operator=(const Var& a) { case Type::kIValue: ptr_.ivalue = a.ptr_.ivalue; type_ = Type::kIValue; + ivalue_type_ = determineIValueType(ptr_.ivalue); break; case Type::kNone: default: @@ -59,6 +83,7 @@ Var& Var::operator=(const Var& a) { Var& Var::operator=(torch::jit::IValue* in) { ptr_.ivalue = in; type_ = Type::kIValue; + ivalue_type_ = determineIValueType(ptr_.ivalue); return (*this); } @@ -72,6 +97,10 @@ Var::Type Var::type() const { return type_; } +Var::IValueType Var::ivalue_type() const { + return ivalue_type_; +} + std::string Var::type_name() const { switch (type_) { case Type::kITensor: @@ -147,7 +176,39 @@ bool Var::isITensor() const { } bool Var::isITensorList() const { - if (type_ == Type::kITensor) { + if (ivalue_type_ == IValueType::kITensorList) { return true; } else { return false; } } + +bool Var::isIntList() const { + if (ivalue_type_ == IValueType::kIntList) { + return true; + } else { + return false; + } +} + +bool Var::isDoubleList() const { + if (ivalue_type_ == IValueType::kDoubleList) { + return true; + } else { + return false; + } +} + +bool Var::isTensorList() const { + if (ivalue_type_ == IValueType::kTensorList) { + return true; + } else { + return false; + } +} + +bool Var::isBoolList() const { + if (ivalue_type_ == IValueType::kBoolList) { return true; } else { return false; @@ -157,6 +218,8 @@ bool Var::isBoolList() const { std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList() { TORCHTRT_CHECK( isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name()); + LOG_DEBUG(" === Is INT list: " << ptr_.ivalue->isIntList()); + LOG_DEBUG(" === Is List: " << ptr_.ivalue->isList()); auto ivalue_list = ptr_.ivalue->toList(); std::vector<nvinfer1::ITensor*> outputs; for (int i = 0; i < ivalue_list.size(); i++) { diff --git a/core/conversion/var/Var.h b/core/conversion/var/Var.h index ec9281cffb..eb8f46b0e7 100644 --- a/core/conversion/var/Var.h +++ b/core/conversion/var/Var.h @@ -15,7 +15,7 @@ namespace conversion { class Var : torch::CustomClassHolder { public: enum Type { kITensor, kIValue, kNone }; - + enum IValueType { kInt, kDouble, kBool, kTensor, kIntList, kDoubleList, kBoolList, kTensorList, kITensorList }; Var(); Var(torch::jit::IValue* p); Var(nvinfer1::ITensor* p); @@ -60,9 +60,15 @@ class Var : torch::CustomClassHolder { bool isIValue() const; bool isITensor() const; bool isITensorList() const; + bool isTensorList() const; + bool isDoubleList() const; + bool isIntList() const; + bool isBoolList() const; bool isNone() const; Var::Type type() const; + Var::IValueType ivalue_type() const; std::string type_name() const; +
Var::IValueType determineIValueType(torch::jit::IValue* p); private: union VarContainer { @@ -73,6 +79,7 @@ class Var : torch::CustomClassHolder { VarContainer ptr_; Type type_; + IValueType ivalue_type_; }; } // namespace conversion diff --git a/tests/cpp/test_dynamic_size.cpp b/tests/cpp/test_dynamic_size.cpp index 968611deda..b3d0289631 100644 --- a/tests/cpp/test_dynamic_size.cpp +++ b/tests/cpp/test_dynamic_size.cpp @@ -5,7 +5,7 @@ #include "tests/util/util.h" #include "torch/csrc/jit/ir/irparser.h" -TEST(Converters, ATenResizeDynamicInputCorrectly) { +TEST(Converters, ATenResizeDynamicShapeCorrectly) { const auto graph = R"IR( graph(%x : Tensor): %3 : int = prim::Constant[value=0]() %2 : int = prim::Constant[value=-1]() %28 : int = aten::size(%x, %3) %30 : int[] = prim::ListConstruct(%28, %2) %6 : Tensor = aten::reshape(%x, %30) @@ -33,3 +33,29 @@ TEST(Converters, ATenResizeDynamicInputCorrectly) { ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } + +TEST(Converters, ATenResizeDynamicInputCorrectly) { + const auto graph = R"IR( + graph(%x : Tensor): + %2 : int[] = prim::Constant[value=[-1, 4, 64]]() + %3 : Tensor = aten::reshape(%x, %2) + return (%3))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 16, 16}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true); + + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +}
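A subtle point in determineIValueType: isList() is asked only after all the concrete list checks, because an "ITensor list" is, as far as TorchScript is concerned, just a generic list whose elements are custom-class IValues; test the generic case first and every int[] or float[] would be tagged kITensorList. A hypothetical predicate drawing the same distinction without the enum (names are illustrative):

    bool holds_itensor_list(const torch::jit::IValue& iv) {
      return iv.isList() && !iv.isIntList() && !iv.isDoubleList() &&
             !iv.isBoolList() && !iv.isTensorList() &&
             iv.toList().size() > 0 && iv.toList().get(0).isCustomClass();
    }

Note also that determineIValueType has no fallback branch: an IValue outside these nine categories falls off the end of a non-void function, which is undefined behavior, so a final else that raises an error would be safer. PATCH 09 below removes the enum again in favor of direct predicates.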
From 10aaaf4aac965708e275e26f045bba35e29e117e Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 9 Feb 2023 00:36:24 -0800 Subject: [PATCH 08/10] chore: Refactor utilities and support new Var utils and testcase Signed-off-by: Dheeraj Peri --- core/conversion/evaluators/aten.cpp | 41 --------------------- core/conversion/evaluators/eval_util.cpp | 45 +++++++++++++++++++++++- core/conversion/evaluators/eval_util.h | 9 +++++ core/conversion/evaluators/prim.cpp | 5 ++- core/conversion/var/Var.cpp | 6 ++-- tests/cpp/test_dynamic_size.cpp | 30 ++++++++++++++++ 6 files changed, 89 insertions(+), 47 deletions(-) diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index d7a2479ffe..d78bf9878c 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -19,47 +19,6 @@ namespace conversion { namespace evaluators { namespace { -nvinfer1::ITensor* index_layer( - ConversionCtx* ctx, - const torch::jit::Node* n, - nvinfer1::ITensor* input_tensor, - int64_t index) { - // index to access needs to be an at::Tensor - at::Tensor indices = torch::tensor({index}).to(torch::kI32); - auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, indices); - - auto gather_layer = ctx->net->addGather(*input_tensor, *indices_out, 0); - TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n); - auto indexed_tensor = gather_layer->getOutput(0); - return indexed_tensor; -} - -c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) { - LOG_DEBUG("Using dynamic version of aten::size evaluator"); - auto in = args.at(n->input(0)).ITensorOrFreeze(ctx); - LOG_DEBUG("Input dimensions: " << in->getDimensions()); - auto shape_layer = ctx->net->addShape(*in); - TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n); - auto shape_1d_tensor = shape_layer->getOutput(0); - - if (n->inputs().size() != 1) { - auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims); - auto dim = args.at(n->input(1)).unwrapToInt(); - // Handle negative axis by referring to nbDims of input Tensor - dim = dim < 0 ? dim + maxDim : dim; - LOG_DEBUG("Dimension to select: " << dim); - shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim); - } - - LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions()); - - auto tensor_holder = TensorContainer(); - tensor_holder.hold_tensor(shape_1d_tensor); - auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder))); - - return shape_1d_ivalue; -} - DEFINE_GENERIC_TWO_INPUT_EVALUATOR( eq, "aten::eq", diff --git a/core/conversion/evaluators/eval_util.cpp b/core/conversion/evaluators/eval_util.cpp index c14f9a6714..fcd8f0c910 100644 --- a/core/conversion/evaluators/eval_util.cpp +++ b/core/conversion/evaluators/eval_util.cpp @@ -1,3 +1,4 @@ +#include "core/conversion/evaluators/eval_util.h" #include #include "ATen/InitialTensorOptions.h" #include "ATen/core/List.h" @@ -6,12 +7,54 @@ #include "ATen/core/jit_type.h" #include "c10/util/irange.h" #include "core/util/prelude.h" +#include "torch/torch.h" namespace torch_tensorrt { namespace core { namespace conversion { namespace evaluators { +nvinfer1::ITensor* index_layer( + ConversionCtx* ctx, + const torch::jit::Node* n, + nvinfer1::ITensor* input_tensor, + int64_t index) { + // index to access needs to be an at::Tensor + at::Tensor indices = torch::tensor({index}).to(torch::kI32); + auto indices_out = converters::tensor_to_const(ctx, indices); + + auto gather_layer = ctx->net->addGather(*input_tensor, *indices_out, 0); + TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n); + auto indexed_tensor = gather_layer->getOutput(0); + return indexed_tensor; +} + +c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) { + LOG_DEBUG("Using dynamic version of aten::size evaluator"); + auto in = args.at(n->input(0)).ITensorOrFreeze(ctx); + LOG_DEBUG("Input dimensions: " << in->getDimensions()); + auto shape_layer = ctx->net->addShape(*in); + TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n); + auto shape_1d_tensor = shape_layer->getOutput(0); + + if (n->inputs().size() != 1) { + auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims); + auto dim = args.at(n->input(1)).unwrapToInt(); + // Handle negative axis by referring to nbDims of input Tensor + dim = dim < 0 ?
dim + maxDim : dim; + LOG_DEBUG("Dimension to select: " << dim); + shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim); + } + + LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions()); + + auto tensor_holder = TensorContainer(); + tensor_holder.hold_tensor(shape_1d_tensor); + auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive(tensor_holder))); + + return shape_1d_ivalue; +} + int64_t normalizeIndex(int64_t idx, int64_t list_size) { if (idx < 0) { // Handle negative indexing @@ -128,7 +171,7 @@ void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) { } // TODO: Conditionally enable truncation based on user setting -at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU) { +at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device) { // This function is basically same with the one in // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h, what different here is that Int and Float // won't be upgraded to kDouble or kLong since we don't support these 2 types in conversion diff --git a/core/conversion/evaluators/eval_util.h b/core/conversion/evaluators/eval_util.h index c63ead7461..5d0f050981 100644 --- a/core/conversion/evaluators/eval_util.h +++ b/core/conversion/evaluators/eval_util.h @@ -1,5 +1,6 @@ #pragma once +#include "core/conversion/evaluators/evaluators.h" #include "torch/csrc/jit/ir/ir.h" namespace torch_tensorrt { @@ -7,6 +8,14 @@ namespace core { namespace conversion { namespace evaluators { +nvinfer1::ITensor* index_layer( + ConversionCtx* ctx, + const torch::jit::Node* n, + nvinfer1::ITensor* input_tensor, + int64_t index); + +c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args); + c10::optional toIValue(const torch::jit::Value* v); at::Tensor createTensorFromList( const torch::jit::IValue& data, diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp index c4f2acd7bd..cbbc109982 100644 --- a/core/conversion/evaluators/prim.cpp +++ b/core/conversion/evaluators/prim.cpp @@ -88,9 +88,8 @@ auto prim_registrations = return c10::optional(std::move(torch::jit::IValue(list))); } } else { - c10::ListTypePtr lt = n->output()->type()->expect(); - c10::TypePtr elementType = lt->getElementType(); - auto list = c10::impl::GenericList(elementType); + // List would be of IValues (with ITensors embedded in them) + auto list = c10::impl::GenericList(c10::AnyType::get()); list.reserve(num_inputs); for (auto in : n->inputs()) { if (args.at(in).isITensor()) { diff --git a/core/conversion/var/Var.cpp b/core/conversion/var/Var.cpp index 08fb63bc4e..68f5b9c5be 100644 --- a/core/conversion/var/Var.cpp +++ b/core/conversion/var/Var.cpp @@ -218,8 +218,10 @@ bool Var::isBoolList() const { std::vector Var::unwrapToITensorList() { TORCHTRT_CHECK( isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name()); - LOG_DEBUG(" === Is INT list: " << ptr_.ivalue->isIntList()); - LOG_DEBUG(" === Is List: " << ptr_.ivalue->isList()); + TORCHTRT_CHECK( + isITensorList(), + "Expected IValue to be an ITensorList, however the type is " + << static_cast::type>(ivalue_type_)); auto ivalue_list = ptr_.ivalue->toList(); std::vector outputs; for (int i = 0; i < ivalue_list.size(); i++) { diff --git a/tests/cpp/test_dynamic_size.cpp b/tests/cpp/test_dynamic_size.cpp index b3d0289631..202b4f5ddc 100644 --- a/tests/cpp/test_dynamic_size.cpp +++ b/tests/cpp/test_dynamic_size.cpp @@ -59,3 +59,33 @@ 
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
+
+TEST(Converters, ATenResizeGetItemDynShapeCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+            %3 : int = prim::Constant[value=-1]()
+            %2 : int = prim::Constant[value=0]()
+            %size.1 : int[] = aten::size(%x.1)
+            %37 : int = aten::__getitem__(%size.1, %2)
+            %39 : int[] = prim::ListConstruct(%37, %3)
+            %7 : Tensor = aten::reshape(%x.1, %39)
+            return (%7))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {16, 16, 16}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
\ No newline at end of file

From 34e7b098d99cc29bf2d4d012266a89a99dba96fb Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Tue, 14 Feb 2023 14:05:15 -0800
Subject: [PATCH 09/10] refactor: Implement a macro for pytorch type checking

Signed-off-by: Dheeraj Peri
---
 core/conversion/converters/impl/shuffle.cpp |  2 +-
 core/conversion/var/Var.cpp                 | 70 +--------------------
 core/conversion/var/Var.h                   | 24 ++++---
 core/conversion/var/Var_inl.h               | 19 ++++++
 4 files changed, 38 insertions(+), 77 deletions(-)

diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp
index 9817e9a6d3..f758c0cc47 100644
--- a/core/conversion/converters/impl/shuffle.cpp
+++ b/core/conversion/converters/impl/shuffle.cpp
@@ -86,7 +86,7 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
                  shape_tensor = tensor_to_const(ctx, torch::tensor(shape_vec).to(torch::kI32));
                } else {
                  LOG_ERROR(
-                     "Invalid IValue type of " << args[1].ivalue_type()
+                     "Invalid IValue type of " << args[1].IValue()->type()
                                                << " detected for shape tensor from node: " << *n);
                }
              } else {
diff --git a/core/conversion/var/Var.cpp b/core/conversion/var/Var.cpp
index 68f5b9c5be..3ad234b15d 100644
--- a/core/conversion/var/Var.cpp
+++ b/core/conversion/var/Var.cpp
@@ -21,28 +21,6 @@ Var::Var(nvinfer1::ITensor* p) : type_(Type::kITensor) {
   ptr_.tensor = p;
 }
 
-Var::IValueType Var::determineIValueType(torch::jit::IValue* p) {
-  if (p->isInt()) {
-    return IValueType::kInt;
-  } else if (p->isDouble()) {
-    return IValueType::kDouble;
-  } else if (p->isBool()) {
-    return IValueType::kBool;
-  } else if (p->isTensor()) {
-    return IValueType::kTensor;
-  } else if (p->isIntList()) {
-    return IValueType::kIntList;
-  } else if (p->isDoubleList()) {
-    return IValueType::kDoubleList;
-  } else if (p->isBoolList()) {
-    return IValueType::kBoolList;
-  } else if (p->isTensorList()) {
-    return IValueType::kTensorList;
-  } else if (p->isList()) {
-    return IValueType::kITensorList;
-  }
-}
-
 Var::Var(const Var& a) {
   switch (a.type_) {
     case Type::kITensor:
@@ -52,7 +30,6 @@ Var::Var(const Var& a) {
     case Type::kIValue:
      ptr_.ivalue = a.ptr_.ivalue;
      type_ = Type::kIValue;
-      ivalue_type_ = determineIValueType(ptr_.ivalue);
      break;
    case Type::kNone:
    default:
@@ -70,7 +47,6 @@ Var& Var::operator=(const Var& a) {
    case Type::kIValue:
      ptr_.ivalue = a.ptr_.ivalue;
      type_ = Type::kIValue;
-      ivalue_type_ = determineIValueType(ptr_.ivalue);
      break;
    case Type::kNone:
    default:
@@ -83,7 +59,6 @@ Var& Var::operator=(const Var& a) {
 Var& Var::operator=(torch::jit::IValue* in) {
   ptr_.ivalue = in;
   type_ = Type::kIValue;
-  ivalue_type_ = determineIValueType(ptr_.ivalue);
   return (*this);
 }
 
@@ -97,10 +72,6 @@ Var::Type Var::type() const {
   return type_;
 }
 
-Var::IValueType Var::ivalue_type() const {
-  return ivalue_type_;
-}
-
 std::string Var::type_name() const {
   switch (type_) {
     case Type::kITensor:
@@ -175,40 +146,8 @@ bool Var::isITensor() const {
   }
 }
 
-bool Var::isITensorList() const {
-  if (ivalue_type_ == IValueType::kITensorList) {
-    return true;
-  } else {
-    return false;
-  }
-}
-
-bool Var::isIntList() const {
-  if (ivalue_type_ == IValueType::kIntList) {
-    return true;
-  } else {
-    return false;
-  }
-}
-
-bool Var::isDoubleList() const {
-  if (ivalue_type_ == IValueType::kDoubleList) {
-    return true;
-  } else {
-    return false;
-  }
-}
-
-bool Var::isTensorList() const {
-  if (ivalue_type_ == IValueType::kTensorList) {
-    return true;
-  } else {
-    return false;
-  }
-}
-
-bool Var::isBoolList() const {
-  if (ivalue_type_ == IValueType::kBoolList) {
+bool Var::isITensorList() {
+  if (isList() && ptr_.ivalue->isCustomClass()) {
     return true;
   } else {
     return false;
@@ -218,10 +157,7 @@ bool Var::isBoolList() const {
 std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList() {
   TORCHTRT_CHECK(
       isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
-  TORCHTRT_CHECK(
-      isITensorList(),
-      "Expected IValue to be an ITensorList, however the type is "
-          << static_cast<std::underlying_type<IValueType>::type>(ivalue_type_));
+  TORCHTRT_CHECK(isITensorList(), "Expected IValue to be an ITensorList");
   auto ivalue_list = ptr_.ivalue->toList();
   std::vector<nvinfer1::ITensor*> outputs;
   for (int i = 0; i < ivalue_list.size(); i++) {
diff --git a/core/conversion/var/Var.h b/core/conversion/var/Var.h
index eb8f46b0e7..41889cbbbb 100644
--- a/core/conversion/var/Var.h
+++ b/core/conversion/var/Var.h
@@ -15,7 +15,7 @@ namespace conversion {
 class Var : torch::CustomClassHolder {
  public:
   enum Type { kITensor, kIValue, kNone };
-  enum IValueType { kInt, kDouble, kBool, kTensor, kIntList, kDoubleList, kBoolList, kTensorList, kITensorList };
+
   Var();
   Var(torch::jit::IValue* p);
   Var(nvinfer1::ITensor* p);
@@ -59,16 +59,23 @@ class Var : torch::CustomClassHolder {
 
   bool isIValue() const;
   bool isITensor() const;
-  bool isITensorList() const;
-  bool isTensorList() const;
-  bool isDoubleList() const;
-  bool isIntList() const;
-  bool isBoolList() const;
   bool isNone() const;
+
+  bool isInt();
+  bool isDouble();
+  bool isBool();
+  bool isString();
+  bool isScalar();
+  bool isTensor();
+  bool isIntList();
+  bool isDoubleList();
+  bool isBoolList();
+  bool isTensorList();
+  bool isITensorList();
+  bool isList();
+
   Var::Type type() const;
-  Var::IValueType ivalue_type() const;
   std::string type_name() const;
-  Var::IValueType determineIValueType(torch::jit::IValue* p);
 
  private:
   union VarContainer {
@@ -79,7 +86,6 @@ class Var : torch::CustomClassHolder {
 
   VarContainer ptr_;
   Type type_;
-  IValueType ivalue_type_;
 };
 
 } // namespace conversion
diff --git a/core/conversion/var/Var_inl.h b/core/conversion/var/Var_inl.h
index 13760a908c..a98519abe1 100644
--- a/core/conversion/var/Var_inl.h
+++ b/core/conversion/var/Var_inl.h
@@ -4,6 +4,13 @@ namespace torch_tensorrt {
 namespace core {
 namespace conversion {
 
+#define DEFINE_IS_IVAL_TYPE(method_variant)                                                                        \
+  inline bool Var::is##method_variant() {                                                                          \
+    TORCHTRT_CHECK(                                                                                                \
+        isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name()); \
+    return ptr_.ivalue->is##method_variant();                                                                      \
+  }
+
 #define DEFINE_UNWRAP_TO(ival_type, method_variant) \
   template <>                                       \
   inline ival_type Var::unwrapTo() {                \
@@ -34,6 +41,18 @@ namespace conversion {
     return this->unwrapTo<ival_type>(); \
   }
 
+DEFINE_IS_IVAL_TYPE(Int)
+DEFINE_IS_IVAL_TYPE(Double)
+DEFINE_IS_IVAL_TYPE(Bool)
+DEFINE_IS_IVAL_TYPE(String)
+DEFINE_IS_IVAL_TYPE(Scalar)
+DEFINE_IS_IVAL_TYPE(Tensor)
+DEFINE_IS_IVAL_TYPE(IntList)
+DEFINE_IS_IVAL_TYPE(DoubleList)
+DEFINE_IS_IVAL_TYPE(BoolList)
+DEFINE_IS_IVAL_TYPE(TensorList)
+DEFINE_IS_IVAL_TYPE(List)
+
 DEFINE_UNWRAP_TO(at::Tensor, Tensor)
 DEFINE_UNWRAP_TO(int64_t, Int)
 DEFINE_UNWRAP_TO(double, Double)

From 76dc8046972c7ff368d36147388b093b6d401946 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 5 Apr 2023 15:34:18 -0700
Subject: [PATCH 10/10] fix: Fix how ITensorList is detected

Signed-off-by: Dheeraj Peri
---
 core/conversion/var/Var.cpp | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/core/conversion/var/Var.cpp b/core/conversion/var/Var.cpp
index 3ad234b15d..df8dc2e544 100644
--- a/core/conversion/var/Var.cpp
+++ b/core/conversion/var/Var.cpp
@@ -147,11 +147,15 @@ bool Var::isITensor() const {
 }
 
 bool Var::isITensorList() {
-  if (isList() && ptr_.ivalue->isCustomClass()) {
-    return true;
-  } else {
-    return false;
+  // Unpack the Var as a List and check if each entry is a custom class since
+  // ITensors are stored in CustomClassHolder
+  auto ival_list = ptr_.ivalue->toList();
+  for (int i = 0; i < ival_list.size(); i++) {
+    if (!ival_list.get(i).isCustomClass()) {
+      return false;
+    }
   }
+  return true;
 }
 
 std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList() {