diff --git a/.gitignore b/.gitignore index 7fce9305c8..05079c76df 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,4 @@ experiments/ py/build/ py/tmp/ py/.eggs - \ No newline at end of file +.vscode/ \ No newline at end of file diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index 6ce3571012..74bf320e1e 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -10,7 +10,7 @@ namespace trtorch { namespace core { namespace conversion { -// Defined in core/conversion/conversion_blacklist.cpp +// Defined in core/conversion/conversion_blacklist.cpp bool isNodeConversionBlacklisted(const torch::jit::Node* n); bool OpSupported(const torch::jit::Node* n) { @@ -24,8 +24,8 @@ c10::optional EvaluateNode(ConversionCtx* ctx, const torch:: // Also probably a better way to deal with the two error cases; TRTORCH_CHECK(level < limit, "Failed to evaluate node: " << *n \ << "Reason: Exceeded evaluation stack limit (limit=" \ - << limit << ")"); - + << limit << ")"); + LOG_DEBUG(ctx->logger, "Evaluating " << util::node_info(n)); evaluators::kwargs eval_args; for (auto eval_in : n->inputs()) { @@ -55,7 +55,7 @@ c10::optional EvaluateNode(ConversionCtx* ctx, const torch:: return eval; } -bool AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) { +void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) { LOG_INFO(ctx->logger, "Adding Layer " << util::node_info(n) << " (ctx.AddLayer)"); converters::args node_args; @@ -87,36 +87,34 @@ bool AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) { TRTORCH_THROW_ERROR("Unable to retrieve all node inputs for node: " \ << util::node_info(n) << " (ctx.AddLayer)\nSpecifically failed to retrieve value for input: " \ << *input_node); - return false; } - } if (n->inputs().size() != node_args.size()) { TRTORCH_THROW_ERROR("Unable to retrieve all node inputs for node: " << *n); - return false; } - + auto schema = n->maybeSchema(); TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \ << " (conversion.AddLayer)"); - + auto converter = converters::get_node_converter_for(schema); TRTORCH_CHECK(converter, "Unable to convert node: " << util::node_info(n) \ << " (conversion.AddLayer)\nSchema: " << *schema << "\nConverter for " << schema->name() << " requested, but no such converter was found.\nIf you need a converter for this operator, you can try implementing one yourself\n" - << "or request a converter: https://www.github.com/NVIDIA/TRTorch/issues"); - converter(ctx, n, node_args); + << "or request a converter: https://www.github.com/NVIDIA/TRTorch/issues"); - return true; + TRTORCH_CHECK(converter(ctx, n, node_args), + "Converter for " << *schema << " failed to convert node: " + << util::node_info(n) << "please report this error to https://www.github.com/NVIDIA/TRTorch/issues"); } -bool AddInputs(ConversionCtx* ctx, +void AddInputs(ConversionCtx* ctx, at::ArrayRef inputs, std::vector& input_dims) { - + auto type_lut = torch::jit::script::string_to_type_lut(); std::vector input_tensors; for (auto in : inputs) { @@ -130,7 +128,7 @@ bool AddInputs(ConversionCtx* ctx, input_tensors.push_back(in); } } - + TRTORCH_CHECK(input_tensors.size() == input_dims.size(), "Expected dimension specifications for all input tensors" \ << ", but found " << input_tensors.size() \ @@ -138,7 +136,7 @@ bool AddInputs(ConversionCtx* ctx, << input_dims.size() << "dimension specs (conversion.AddInputs)"); auto profile = ctx->builder->createOptimizationProfile(); - + for (size_t i = 0; i < input_tensors.size(); 
i++) { auto in = input_tensors[i]; auto dims = input_dims[i]; @@ -158,20 +156,23 @@ bool AddInputs(ConversionCtx* ctx, } TRTORCH_CHECK(profile->isValid(), "Optimization profile is invalid, please check the input range provided (conversion.AddInputs)"); - + ctx->cfg->addOptimizationProfile(profile); - return true; } -bool MarkOutputs(ConversionCtx* ctx, at::ArrayRef outputs) { +void MarkOutputs(ConversionCtx* ctx, at::ArrayRef outputs) { for (auto out : outputs) { - ctx->net->markOutput(*(ctx->value_tensor_map[out])); + auto it = ctx->value_tensor_map.find(out); + // Leaves the potential for unused outputs to be populated with nullptr "safely" + TRTORCH_CHECK(it != ctx->value_tensor_map.end() && it->second, + "No corresponding output TRT Tensor found for TorchScript output: " << out->debugName()); + auto out_tensor = it->second; + ctx->net->markOutput(*out_tensor); LOG_INFO(ctx->logger, "Marking Output " << out->debugName() << " (ctx.MarkOutput)"); } - return true; } - + void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) { for (auto p : params) { ctx->evaluated_value_map[p.first] = torch::jit::IValue(p.second.clone()); @@ -191,13 +192,8 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ExtraI bool to_eval = evaluators::shouldEvalAtConversionTime(n); bool blacklisted = isNodeConversionBlacklisted(n); if (!to_eval && !blacklisted) { - if (!AddLayer(ctx, n)) { - //TODO: Exception things - LOG_ERROR(ctx->logger, - "Failed to add layer: " << *n \ - << " (ctx.AddLayer)"); - return; - } + // Should error out if something fails + AddLayer(ctx, n); } else { std::string reason = ""; if (to_eval) { @@ -207,7 +203,13 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ExtraI reason += " (explicitly blacklisted)"; } LOG_DEBUG(ctx->logger, - "Skipping Node: " << (n->kind().toQualString()) << reason); + "Skipping Node: " << util::node_info(n) << reason); + } + } + + for (const auto n : nodes) { + if (converters::node_is_convertable(n)) { + ctx->CheckLayerAddition(n); } } @@ -218,7 +220,7 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ExtraI // Converts a already lowered block (blocks with no sub blocks) to // a serialized TensorRT engine that can be deserialized and run -// Probably should consolidate these two functions +// Probably should consolidate these two functions std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) { ConversionCtx ctx(build_info.engine_settings); ConvertBlockToNetDef(&ctx, b, build_info, static_params); @@ -247,7 +249,7 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) { for (auto s : unsupported_ops) { unsupported_msg << " - " << s << std::endl; } - unsupported_msg << "You can either implement converters for these ops in your application or file a bug" << std::endl; + unsupported_msg << "You can either implement converters for these ops in your application or request implementation" << std::endl; unsupported_msg << "https://www.github.com/nvidia/TRTorch/issues" << std::endl; LOG_ERROR(unsupported_msg.str()); } diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index a569e202ff..06f6317012 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -37,11 +37,11 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings) switch(settings.op_precision) { case nvinfer1::DataType::kHALF: 
cfg->setFlag(nvinfer1::BuilderFlag::kFP16); - input_type = nvinfer1::DataType::kHALF; + input_type = nvinfer1::DataType::kHALF; break; // case nvinfer1::DataType::kINT8: // cfg->setFlag(nvinfer1::BuilderFlag::kINT8); - // input_type = nvinfer1::DataType::kFLOAT; + // input_type = nvinfer1::DataType::kFLOAT; // break; case nvinfer1::DataType::kFLOAT: default: @@ -80,13 +80,30 @@ ConversionCtx::~ConversionCtx() { free(ptr); } } - + +nvinfer1::ITensor* ConversionCtx::AssociateValueAndTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor) { + tensor->setName(value->debugName().c_str()); + this->value_tensor_map[value] = tensor; + return tensor; +} + std::string ConversionCtx::SerializeEngine() { auto engine = builder->buildEngineWithConfig(*net, *cfg); auto serialized_engine = engine->serialize(); return std::string((const char*)serialized_engine->data(), serialized_engine->size()); } +bool ConversionCtx::CheckLayerAddition(const torch::jit::Node* n) { + for (auto out : n->outputs()) { + auto iter = this->value_tensor_map.find(out); + if (iter == this->value_tensor_map.end()) { + LOG_WARNING("Node " << util::node_info(n) << " output: " << out->debugName() << " does not have a coresponding output, may potentially indicate a defective converter"); + return false; + } + } + return true; +} + } // namespace conversion } // namespace core } // namespace trtorch diff --git a/core/conversion/conversionctx/ConversionCtx.h b/core/conversion/conversionctx/ConversionCtx.h index ff92df6232..06f3755490 100644 --- a/core/conversion/conversionctx/ConversionCtx.h +++ b/core/conversion/conversionctx/ConversionCtx.h @@ -30,12 +30,15 @@ struct BuilderSettings { BuilderSettings() = default; BuilderSettings(const BuilderSettings& other) = default; - friend std::ostream& operator<<(std::ostream& os, const BuilderSettings& s); + friend std::ostream& operator<<(std::ostream& os, const BuilderSettings& s); }; - + struct ConversionCtx { ConversionCtx(BuilderSettings settings); std::string SerializeEngine(); + nvinfer1::ITensor* AssociateValueAndTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor); + bool CheckLayerAddition(const torch::jit::Node* n); + ~ConversionCtx(); nvinfer1::IBuilder* builder; @@ -50,12 +53,12 @@ struct ConversionCtx { // is constructed from a PyTorch Tensor it allocates the data here to store a // copy of the values std::vector builder_resources; - + std::unordered_map value_tensor_map; std::unordered_map evaluated_value_map; }; -} // namespace conversion +} // namespace conversion } // namespace core } // namespace trtorch - + diff --git a/core/conversion/converters/Arg.cpp b/core/conversion/converters/Arg.cpp index 87e0ec794a..af23713ea1 100644 --- a/core/conversion/converters/Arg.cpp +++ b/core/conversion/converters/Arg.cpp @@ -85,9 +85,9 @@ std::string Arg::type_name() const { default: return "None"; } - + } - + const torch::jit::IValue* Arg::IValue() const { if (type_ == Type::kIValue) { return ptr_.ivalue; @@ -150,7 +150,7 @@ double Arg::unwrapToDouble(double default_val) { double Arg::unwrapToDouble() { return this->unwrapTo(); -} +} bool Arg::unwrapToBool(bool default_val) { return this->unwrapTo(default_val); @@ -194,26 +194,41 @@ c10::List Arg::unwrapToBoolList() { template T Arg::unwrapTo(T default_val) { - if (isIValue()) { - // TODO: implement Tag Checking - return ptr_.ivalue->to(); + try { + return this->unwrapTo(); + } catch(trtorch::Error& e) { + LOG_DEBUG("In arg unwrapping, returning default value provided (" << e.what() << ")"); + return default_val; } 
- LOG_DEBUG("In arg unwrapping, returning default value provided"); - return default_val; } - template T Arg::unwrapTo() { - if (isIValue()) { - //TODO: Implement Tag checking - return ptr_.ivalue->to(); - //TODO: Exception - //LOG_INTERNAL_ERROR("Requested unwrapping of arg IValue assuming it was " << typeid(T).name() << " however type is " << ptr_.ivalue->type()); - + TRTORCH_CHECK(isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name()); + auto ivalue = ptr_.ivalue; + bool correct_type = false; + if (typeid(T) == typeid(double)) { + correct_type = ivalue->isDouble(); + } else if (typeid(T) == typeid(bool)) { + correct_type = ivalue->isBool(); + } else if (typeid(T) == typeid(int64_t)) { + correct_type = ivalue->isInt(); + } else if (typeid(T) == typeid(at::Tensor)) { + correct_type = ivalue->isTensor(); + } else if (typeid(T) == typeid(c10::Scalar)) { + correct_type = ivalue->isScalar(); + } else if (typeid(T) == typeid(c10::List)) { + correct_type = ivalue->isIntList(); + } else if (typeid(T) == typeid(c10::List)) { + correct_type = ivalue->isDoubleList(); + } else if (typeid(T) == typeid(c10::List)) { + correct_type = ivalue->isBoolList(); + } else { + TRTORCH_THROW_ERROR("Requested unwrapping of arg to an unsupported type: " << typeid(T).name()); } - TRTORCH_THROW_ERROR("Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name()); - return T(); + + TRTORCH_CHECK(correct_type, "Requested unwrapping of arg IValue assuming it was " << typeid(T).name() << " however type is " << *(ptr_.ivalue->type())); + return ptr_.ivalue->to(); } diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD index 870a054c70..43b30857ee 100644 --- a/core/conversion/converters/BUILD +++ b/core/conversion/converters/BUILD @@ -16,6 +16,7 @@ cc_library( "impl/element_wise.cpp", "impl/linear.cpp", "impl/pooling.cpp", + "impl/reduce.cpp", "impl/softmax.cpp", "impl/unary.cpp", ], diff --git a/core/conversion/converters/NodeConverterRegistry.cpp b/core/conversion/converters/NodeConverterRegistry.cpp index d4fe895392..ff175c6023 100644 --- a/core/conversion/converters/NodeConverterRegistry.cpp +++ b/core/conversion/converters/NodeConverterRegistry.cpp @@ -41,20 +41,20 @@ std::string canonical_schema_string(const torch::jit::FunctionSchema& schema) { } namespace { -using ConverterLUT = std::unordered_map; +using ConverterLUT = std::unordered_map; class NodeConverterRegistry { public: bool RegisterConverter(torch::jit::FunctionSchema* signature, OpConverter& converter) { LOG_DEBUG("Registering Converter for " << canonical_schema_string(*signature)); - auto sym = torch::jit::Symbol::fromQualString(signature->name()); - converter_lut_[sym] = std::move(converter); + auto name = signature->operator_name(); + converter_lut_[name] = std::move(converter); return true; } OpConverter GetConverter(const torch::jit::FunctionSchema* signature) { - auto sym = torch::jit::Symbol::fromQualString(signature->name()); - auto iter = converter_lut_.find(sym); + auto name = signature->operator_name(); + auto iter = converter_lut_.find(name); if (iter == converter_lut_.end()) { LOG_ERROR("Requested converter for " << signature->name() << ", but no such converter was found"); // ASK: Is there a better way than returning a nullptr? 
@@ -66,8 +66,8 @@ class NodeConverterRegistry { bool Convertable(const torch::jit::Node* n) { auto schema = n->maybeSchema(); if (schema) { - auto sym = torch::jit::Symbol::fromQualString(schema->name()); - auto iter = converter_lut_.find(sym); + auto name = schema->operator_name(); + auto iter = converter_lut_.find(name); if (iter == converter_lut_.end()) { return false; } else { @@ -79,7 +79,7 @@ class NodeConverterRegistry { return false; } } - + private: ConverterLUT converter_lut_; }; @@ -111,7 +111,7 @@ OpConverter get_node_converter_for(const torch::jit::FunctionSchema* signature) bool node_is_convertable(const torch::jit::Node* n) { return get_converter_registry().Convertable(n); } - + RegisterNodeConversionPatterns&& RegisterNodeConversionPatterns::pattern(ConversionPattern p) && { register_node_converter(std::move(p)); return std::move(*this); diff --git a/core/conversion/converters/converters.h b/core/conversion/converters/converters.h index bc051f17bd..1b6f3b916f 100644 --- a/core/conversion/converters/converters.h +++ b/core/conversion/converters/converters.h @@ -69,12 +69,12 @@ class Arg { ArgContainer ptr_; Type type_; }; - - + + typedef std::vector args; typedef std::function OpConverter; -struct ConversionPattern { +struct ConversionPattern { std::string signature; OpConverter converter; }; @@ -107,7 +107,7 @@ struct Weights { Weights(); Weights(ConversionCtx* ctx, at::Tensor t); Weights(ConversionCtx* ctx, float val); - friend std::ostream& operator<<(std::ostream& os, const Weights& w); + friend std::ostream& operator<<(std::ostream& os, const Weights& w); }; inline nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t) { diff --git a/core/conversion/converters/impl/activation.cpp b/core/conversion/converters/impl/activation.cpp index ea06ad05a4..e77b3e7a50 100644 --- a/core/conversion/converters/impl/activation.cpp +++ b/core/conversion/converters/impl/activation.cpp @@ -18,11 +18,9 @@ namespace { "Unable to create " #act " layer from node: " << *n); \ \ new_layer->setName(util::node_info(n).c_str()); \ - auto out_value = n->outputs()[0]; \ - auto out_tensor = new_layer->getOutput(0); \ - out_tensor->setName(out_value->debugName().c_str()); \ - ctx->value_tensor_map[out_value] = out_tensor; \ - LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); \ + ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); \ + LOG_DEBUG("Output tensor shape: " \ + << new_layer->getOutput(0)->getDimensions()); \ \ return true; \ } \ @@ -30,17 +28,64 @@ namespace { auto act##_registrations TRTORCH_UNUSED = \ RegisterNodeConversionPatterns() \ .pattern({"aten::" #act "(Tensor input) -> (Tensor)", \ - [](ConversionCtx *ctx, const torch::jit::Node *n, \ - args &args) -> bool { return act(ctx, n, args); }}) \ + [](ConversionCtx* ctx, const torch::jit::Node* n, \ + args& args) -> bool { return act(ctx, n, args); }}) \ .pattern({"aten::" #act "_(Tensor(a!) 
self) -> (Tensor(a!))", \ - [](ConversionCtx *ctx, const torch::jit::Node *n, \ - args &args) -> bool { return act(ctx, n, args); }}); + [](ConversionCtx* ctx, const torch::jit::Node* n, \ + args& args) -> bool { return act(ctx, n, args); }}); + +//TODO: remove support for conversion of implace operators and move to the functionalization pass convert(relu, kRELU); convert(sigmoid, kSIGMOID); convert(tanh, kTANH); #undef convert + +auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns() + .pattern({ + "aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto min = args[1].unwrapToDouble(); + auto max = args[2].unwrapToDouble(); + + auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kCLIP); + TRTORCH_CHECK(new_layer, "Unable to create layer for aten::hardtanh"); + + new_layer->setAlpha(min); + new_layer->setBeta(max); + + new_layer->setName(util::node_info(n).c_str()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); + + LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); + return true; + } + }).pattern({ + //TODO: Remove after functionalization + "aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor(a!))", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto min = args[1].unwrapToDouble(); + auto max = args[2].unwrapToDouble(); + + auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kCLIP); + TRTORCH_CHECK(new_layer, "Unable to create layer for aten::hardtanh"); + + new_layer->setAlpha(min); + new_layer->setBeta(max); + + new_layer->setName(util::node_info(n).c_str()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); + + LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); + return true; + } + }); + + + } // namespace } // namespace impl } // namespace converters diff --git a/core/conversion/converters/impl/batch_norm.cpp b/core/conversion/converters/impl/batch_norm.cpp index 56649c8687..f1ee48d59b 100644 --- a/core/conversion/converters/impl/batch_norm.cpp +++ b/core/conversion/converters/impl/batch_norm.cpp @@ -31,13 +31,12 @@ bool ConvertConvBatchNorm(ConversionCtx* ctx, const torch::jit::Node* n, args& a auto bias = Weights(ctx, b); auto bn_as_conv = ctx->net->addConvolutionNd(*input, weights.num_output_maps, weights.kernel_shape, weights.data, bias.data); - + TRTORCH_CHECK(bn_as_conv, "Unable to create fused batch norm from node: " << *n); + bn_as_conv->setName(util::node_info(n).c_str()); - auto out_value = n->outputs()[0]; - auto out_tensor = bn_as_conv->getOutput(0); - out_tensor->setName(out_value->debugName().c_str()); - ctx->value_tensor_map[out_value] = out_tensor; - LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); + + auto bn_out = ctx->AssociateValueAndTensor(n->outputs()[0], bn_as_conv->getOutput(0)); + LOG_DEBUG("Output tensor shape: " << bn_out->getDimensions()); return true; } @@ -68,26 +67,25 @@ bool ConvertLinearBatchNorm(ConversionCtx* ctx, const torch::jit::Node* n, args& auto bn_biased_out = bn_biased->getOutput(0); bn_biased->setName(util::node_info(n).c_str()); - auto out_value = n->outputs()[0]; - bn_biased_out->setName(out_value->debugName().c_str()); - ctx->value_tensor_map[out_value] = bn_biased_out; + ctx->AssociateValueAndTensor(n->outputs()[0], bn_biased_out); + return true; 
} volatile auto batch_norm_registrations = RegisterNodeConversionPatterns() .pattern({ - R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta, - Tensor? mean, Tensor? var, + R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta, + Tensor? mean, Tensor? var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto input = args[0].ITensor(); auto shape = input->getDimensions(); auto gamma = args[1].unwrapToTensor(); - + if (/*training*/ args[5].unwrapToBool()) { LOG_WARNING("TensorRT only converts forward pass of graphs, but saw training = True, may see undefined behavior, consider placing module in eval mode"); } - + // If gamma is None this fails if (util::volume(shape) == gamma.numel()) { return ConvertLinearBatchNorm(ctx, n, args); @@ -103,4 +101,4 @@ volatile auto batch_norm_registrations = RegisterNodeConversionPatterns() } // namespace converters } // namespace conversion } // namespace core -} // namespace trtorch +} // namespace trtorch diff --git a/core/conversion/converters/impl/constant.cpp b/core/conversion/converters/impl/constant.cpp index 75f8bf98b0..432eb6bf85 100644 --- a/core/conversion/converters/impl/constant.cpp +++ b/core/conversion/converters/impl/constant.cpp @@ -19,12 +19,10 @@ auto constant_registrations = RegisterNodeConversionPatterns() auto t_weights = Weights(ctx, t); auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data); const_layer->setName(util::node_info(n).c_str()); - auto out_value = n->outputs()[0]; - auto out_tensor = const_layer->getOutput(0); - out_tensor->setName(out_value->debugName().c_str()); - ctx->value_tensor_map[out_value] = out_tensor; - LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); - + auto const_out = ctx->AssociateValueAndTensor(n->outputs()[0], const_layer->getOutput(0)); + + LOG_DEBUG("Output tensor shape: " << const_out->getDimensions()); + return true; } }); @@ -33,5 +31,5 @@ auto constant_registrations = RegisterNodeConversionPatterns() } // namespace converters } // namespace conversion } // namespace core -} // namespace trtorch +} // namespace trtorch diff --git a/core/conversion/converters/impl/conv_deconv.cpp b/core/conversion/converters/impl/conv_deconv.cpp index bf5e678ee8..2defe97efd 100644 --- a/core/conversion/converters/impl/conv_deconv.cpp +++ b/core/conversion/converters/impl/conv_deconv.cpp @@ -9,14 +9,14 @@ namespace impl { namespace { auto conv_registrations = RegisterNodeConversionPatterns() .pattern({ - R"SIG(aten::_convolution(Tensor input, Tensor weight, + R"SIG(aten::_convolution(Tensor input, Tensor weight, Tensor? 
bias, int[] stride, int[] padding, - int[] dilation, bool transposed, - int[] output_padding, int groups, bool benchmark, + int[] dilation, bool transposed, + int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor))SIG", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto in = args[0].ITensor(); - + auto w = Weights(ctx, args[1].unwrapToTensor()); auto stride = util::toDimsHW(args[3].unwrapToIntList()); LOG_DEBUG("stride: " << stride); @@ -27,11 +27,11 @@ auto conv_registrations = RegisterNodeConversionPatterns() bool transposed = args[6].unwrapToBool(); auto out_padding = util::toDimsHW(args[7].unwrapToIntList()); LOG_DEBUG("out_padding: " << out_padding); - int64_t groups = args[8].unwrapToInt(); - + int64_t groups = args[8].unwrapToInt(); + nvinfer1::ILayer* new_layer; if (transposed) { - //TODO: Check deconv correctness + //TODO: Check deconv correctness LOG_WARNING(ctx->logger, "Deconvolution converter has not be tested"); nvinfer1::IDeconvolutionLayer* deconv; if (args[2].IValue()->isTensor()) { @@ -54,9 +54,9 @@ auto conv_registrations = RegisterNodeConversionPatterns() } else { conv = ctx->net->addConvolutionNd(*in, w.num_output_maps, w.kernel_shape, w.data, Weights().data); } - + TRTORCH_CHECK(conv, "Unable to create convolution layer from node: " << *n); - + conv->setStrideNd(stride); conv->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN); conv->setPaddingNd(padding); @@ -65,13 +65,11 @@ auto conv_registrations = RegisterNodeConversionPatterns() conv->setNbGroups(groups); new_layer = conv; } - new_layer->setName(util::node_info(n).c_str()); - auto out_value = n->outputs()[0]; - auto out_tensor = new_layer->getOutput(0); - out_tensor->setName(out_value->debugName().c_str()); - ctx->value_tensor_map[out_value] = out_tensor; - LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); + + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); + + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; } @@ -81,4 +79,4 @@ auto conv_registrations = RegisterNodeConversionPatterns() } // namespace converters } // namespace conversion } // namespace core -} // trtorch +} // trtorch diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp index c406ce2469..ee0e60a9d7 100644 --- a/core/conversion/converters/impl/element_wise.cpp +++ b/core/conversion/converters/impl/element_wise.cpp @@ -13,13 +13,13 @@ nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOpera auto other_dims = other->getDimensions(); TRTORCH_CHECK(util::volume(self_dims) == util::volume(other_dims), "Found inputs to elementwise operation do not have the same number of elements:\n Found: self " << self_dims << " other " << other_dims); - + nvinfer1::ILayer* ele; if (scalar != 1) { LOG_WARNING("Please verify scalar handling in add converter, channel axis set to 3 but scaling is uniform"); auto shape = util::toVec(other_dims); - + if (shape.size() < 4) { auto new_shape = util::toDimsPad(shape, 4); LOG_DEBUG("Input shape is less than 4D got: " << util::toDims(shape) << ", inserting shuffle layers to reshape to 4D tensor shape: " << new_shape); @@ -33,7 +33,7 @@ nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOpera self_shuffle->setName(std::string("[Reshape self to " + util::toStr(new_shape) + ']').c_str()); self = self_shuffle->getOutput(0); } - + auto scale = Weights(ctx, scalar); auto 
scaled = ctx->net->addScaleNd(*other, nvinfer1::ScaleMode::kUNIFORM, {}, scale.data, {}, 0); auto scaled_other = scaled->getOutput(0); @@ -45,48 +45,49 @@ nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOpera // shuffle->setName(std::string("[Reshape other to " + util::toStr(util::toDims(shape)) + ']').c_str()); // scaled_other = shuffle->getOutput(0); // } - + ele = ctx->net->addElementWise(*self, *scaled_other, op); } else { ele = ctx->net->addElementWise(*self, *other, op); } + return ele; - + } auto element_wise_registrations = RegisterNodeConversionPatterns() .pattern({ "aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - // Should implement self + alpha * other + // Should implement self + alpha * other auto self = args[0].ITensor(); auto other = args[1].ITensor(); auto scalar = args[2].unwrapToScalar().to(); auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, scalar); + + TRTORCH_CHECK(add, "Unable to create add layer from node: " << *n); + add->setName(util::node_info(n).c_str()); - auto out_value = n->outputs()[0]; - auto out_tensor = add->getOutput(0); - out_tensor->setName(out_value->debugName().c_str()); - ctx->value_tensor_map[out_value] = out_tensor; - LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); - + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], add->getOutput(0)); + + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; } }).pattern({ "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - // Should implement self + alpha * other + // Should implement self + alpha * other auto self = args[0].ITensor(); auto other = args[1].ITensor(); auto scalar = args[2].unwrapToScalar().to(); auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, scalar); + + TRTORCH_CHECK(add, "Unable to create add layer from node: " << *n); + add->setName(util::node_info(n).c_str()); - auto out_value = n->outputs()[0]; - auto out_tensor = add->getOutput(0); - out_tensor->setName(out_value->debugName().c_str()); - ctx->value_tensor_map[out_value] = out_tensor; - LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); - + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], add->getOutput(0)); + + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; } }).pattern({ @@ -97,53 +98,84 @@ auto element_wise_registrations = RegisterNodeConversionPatterns() auto other = args[1].ITensor(); auto scalar = args[2].unwrapToScalar().to(); auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, other, scalar); + + TRTORCH_CHECK(sub, "Unable to create sub layer from node: " << *n); + sub->setName(util::node_info(n).c_str()); - auto out_value = n->outputs()[0]; - auto out_tensor = sub->getOutput(0); - out_tensor->setName(out_value->debugName().c_str()); - ctx->value_tensor_map[out_value] = out_tensor; - LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); - + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], sub->getOutput(0)); + + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; } }).pattern({ - "aten::div(Tensor self, Tensor other) -> Tensor", - [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - // Should implement self / other - auto self = args[0].ITensor(); - auto 
other = args[1].ITensor(); - auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other); - div->setName(util::node_info(n).c_str()); - auto out_value = n->outputs()[0]; - auto out_tensor = div->getOutput(0); - out_tensor->setName(out_value->debugName().c_str()); - ctx->value_tensor_map[out_value] = out_tensor; - LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); - + "aten::div.Tensor(Tensor self, Tensor other) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + // Should implement self / other + auto self = args[0].ITensor(); + auto other = args[1].ITensor(); + auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other); + + TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n); + + div->setName(util::node_info(n).c_str()); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0)); + + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; } }).pattern({ - "aten::mul(Tensor self, Tensor other) -> Tensor", - [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - // Should implement self * other - auto self = args[0].ITensor(); - auto other = args[1].ITensor(); - auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other); - mul->setName(util::node_info(n).c_str()); - auto out_value = n->outputs()[0]; - auto out_tensor = mul->getOutput(0); - out_tensor->setName(out_value->debugName().c_str()); - ctx->value_tensor_map[out_value] = out_tensor; - LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); - + "aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + // TODO: Remove with functionalization + auto self = args[0].ITensor(); + auto other = args[1].ITensor(); + auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other); + + TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n); + + div->setName(util::node_info(n).c_str()); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0)); + + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); + return true; + } + }).pattern({ + "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + // Should implement self * other + auto self = args[0].ITensor(); + auto other = args[1].ITensor(); + auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other); + + TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n); + + mul->setName(util::node_info(n).c_str()); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], mul->getOutput(0)); + + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); + return true; + } + }).pattern({ + "aten::mul_.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + // TODO: Remove with functionalization + auto self = args[0].ITensor(); + auto other = args[1].ITensor(); + auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other); + + TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n); + + mul->setName(util::node_info(n).c_str()); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], mul->getOutput(0)); + + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); return true; } }); - - + } // namespace } // namespace impl } // namespace converters } // namespace conversion } // namespace core -} // trtorch +} // trtorch diff --git a/core/conversion/converters/impl/linear.cpp b/core/conversion/converters/impl/linear.cpp index f1b6aa64e3..0b5b4e957a 100644 --- a/core/conversion/converters/impl/linear.cpp +++ b/core/conversion/converters/impl/linear.cpp @@ -22,13 +22,13 @@ auto linear_registrations = RegisterNodeConversionPatterns() TRTORCH_ASSERT(shape.size() >= 2, "aten::linear expects input tensors to be of shape [N,..., in features], but found input Tensor less than 2D"); if (shape.size() < 4) { - // Flatten + // Flatten std::vector new_shape; new_shape.push_back(shape[0]); new_shape.push_back(1); new_shape.push_back(1); - new_shape.push_back(util::volume(util::toDims(shape))); - + new_shape.push_back(util::volume(util::toDims(shape)) / shape[0]); + auto new_dims = util::toDims(new_shape); LOG_DEBUG("Input shape is less than 4D got: " << util::toDims(shape) << ", inserting shuffle layer to reshape to 4D tensor shape: " << new_dims); auto in_shuffle = ctx->net->addShuffle(*in); @@ -36,7 +36,7 @@ auto linear_registrations = RegisterNodeConversionPatterns() in_shuffle->setName((util::node_info(n) + " [Input Reshape to " + util::toStr(new_dims) + ']').c_str()); in = in_shuffle->getOutput(0); } - + auto w_tensor = args[1].IValue()->toTensor(); Weights w = Weights(ctx, w_tensor); @@ -50,13 +50,10 @@ auto linear_registrations = RegisterNodeConversionPatterns() } TRTORCH_CHECK(new_layer,"Unable to create linear layer from node: " << *n); - + new_layer->setName(util::node_info(n).c_str()); - auto out_value = n->outputs()[0]; - auto out_tensor = new_layer->getOutput(0); - - out_tensor->setName(out_value->debugName().c_str()); - ctx->value_tensor_map[out_value] = out_tensor; + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); + LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); return true; @@ -67,4 +64,4 @@ auto linear_registrations = RegisterNodeConversionPatterns() } // namespace converters } // namespace conversion } // namespace core -} // trtorch +} // trtorch diff --git a/core/conversion/converters/impl/pooling.cpp b/core/conversion/converters/impl/pooling.cpp index 7efca913e9..85a4acf5bf 100644 --- a/core/conversion/converters/impl/pooling.cpp +++ b/core/conversion/converters/impl/pooling.cpp @@ -15,7 +15,7 @@ auto pooling_registrations = RegisterNodeConversionPatterns() auto in = args[0].ITensor(); auto shape = util::toVec(in->getDimensions()); - // Max Pool needs at least 4D input + // Max Pool needs at least 4D input if (shape.size() < 4) { auto new_shape = util::toDimsPad(shape, 4); LOG_DEBUG("Input shape is less than 4D got: " << util::toDims(shape) << ", inserting shuffle layer to reshape to 4D tensor shape: " << new_shape); @@ -24,16 +24,16 @@ auto pooling_registrations = RegisterNodeConversionPatterns() 
shuffle->setName((util::node_info(n) + " [Reshape to " + util::toStr(new_shape) + ']').c_str()); in = shuffle->getOutput(0); } - - + + auto kernel_size = util::toDimsHW(args[1].unwrapToIntList()); LOG_DEBUG("kernel_size: " << kernel_size); auto padding = util::toDimsHW(args[3].unwrapToIntList()); LOG_DEBUG("padding: " << padding); auto dilation = util::toDims(args[4].unwrapToIntList()); - + TRTORCH_ASSERT(dilation == util::toDims(std::vector({1,1})), "Pooling dilation is not supported in TensorRT"); - + LOG_DEBUG("dilation: " << dilation); LOG_WARNING("Dilation not used in max pooling converter"); bool ceil_mode = args[5].IValue()->to(); @@ -47,17 +47,14 @@ auto pooling_registrations = RegisterNodeConversionPatterns() auto stride = util::toDims(args[2].unwrapToIntList()); new_layer->setStrideNd(stride); } - + auto padding_mode = ceil_mode ? nvinfer1::PaddingMode::kEXPLICIT_ROUND_UP : nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN; new_layer->setPaddingMode(padding_mode); - + new_layer->setName(util::node_info(n).c_str()); - auto out_value = n->outputs()[0]; - auto out_tensor = new_layer->getOutput(0); - out_tensor->setName(out_value->debugName().c_str()); - ctx->value_tensor_map[out_value] = out_tensor; - LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); + LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); return true; } }).pattern({ @@ -75,7 +72,7 @@ auto pooling_registrations = RegisterNodeConversionPatterns() in = shuffle->getOutput(0); in_shape = util::toVec(in->getDimensions()); } - + auto out_shape = args[1].IValue()->toIntList(); std::vector stride(out_shape.size()); @@ -90,7 +87,7 @@ auto pooling_registrations = RegisterNodeConversionPatterns() } LOG_DEBUG("Window" << util::toDims(window)); - + auto new_layer = ctx->net->addPoolingNd(*in, nvinfer1::PoolingType::kAVERAGE, util::toDims(window)); if (!new_layer) { LOG_ERROR("Unable to create average pooling layer from node: " << *n); @@ -98,14 +95,11 @@ auto pooling_registrations = RegisterNodeConversionPatterns() } new_layer->setStrideNd(util::toDims(stride)); + new_layer->setName(util::node_info(n).c_str()); - auto out_value = n->outputs()[0]; - auto out_tensor = new_layer->getOutput(0); - out_tensor->setName(out_value->debugName().c_str()); - ctx->value_tensor_map[out_value] = out_tensor; - LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); - + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); + LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); return true; } }); @@ -114,4 +108,4 @@ auto pooling_registrations = RegisterNodeConversionPatterns() } // namespace converters } // namespace conversion } // namespace core -} // trtorch +} // trtorch diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp new file mode 100644 index 0000000000..0127f83285 --- /dev/null +++ b/core/conversion/converters/impl/reduce.cpp @@ -0,0 +1,192 @@ +#include +#include "core/util/prelude.h" +#include "core/conversion/converters/converters.h" + +namespace trtorch { +namespace core { +namespace conversion { +namespace converters { +namespace impl { +namespace { + + + +auto reduce_registrations = RegisterNodeConversionPatterns() + .pattern({ + "aten::mean(Tensor self, *, ScalarType? 
dtype=None) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in_tensor = args[0].ITensor(); + auto in_dims = util::toVec(in_tensor->getDimensions()); + LOG_WARNING("Mean Converter disregards dtype"); + + uint32_t axis_mask = (uint32_t)(((uint64_t)1 << in_dims.size()) - 1); + + auto mean_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kAVG, axis_mask, false); + + TRTORCH_CHECK(mean_layer, "Unable to create mean layer from node: " << *n); + + mean_layer->setName(util::node_info(n).c_str()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mean_layer->getOutput(0)); + + LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); + return true; + } + }).pattern({ + "aten::mean.dim(Tensor self, int[] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in_tensor = args[0].ITensor(); + auto dims = args[1].unwrapToIntList(); + LOG_DEBUG("Dim to reduce:" << util::toDims(dims)); // Some abuse of toDim but just for debug info + + uint32_t axis_mask = 0; + for (size_t d = 0; d < dims.size(); d++) { + axis_mask |= 1 << dims[d]; + } + LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask)); + + auto keepdim = args[2].unwrapToBool(); + LOG_DEBUG("Keep dims :" << keepdim); + + LOG_WARNING("Mean converter disregards dtype"); + auto mean_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kAVG, axis_mask, keepdim); + + TRTORCH_CHECK(mean_layer, "Unable to create mean layer from node: " << *n); + + mean_layer->setName(util::node_info(n).c_str()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mean_layer->getOutput(0)); + + LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); + return true; + } + }).pattern({ + "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in_tensor = args[0].ITensor(); + auto in_dims = util::toVec(in_tensor->getDimensions()); + LOG_WARNING("Sum Converter disregards dtype"); + + uint32_t axis_mask = (uint32_t)(((uint64_t)1 << in_dims.size()) - 1); + + auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, false); + + TRTORCH_CHECK(sum_layer, "Unable to create sum layer from node: " << *n); + + sum_layer->setName(util::node_info(n).c_str()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], sum_layer->getOutput(0)); + + LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); + return true; + } + }).pattern({ + "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in_tensor = args[0].ITensor(); + auto dims = args[1].unwrapToIntList(); + LOG_DEBUG("Dim to reduce:" << util::toDims(dims)); // Some abuse of toDim but just for debug info + + uint32_t axis_mask = 0; + for (size_t d = 0; d < dims.size(); d++) { + axis_mask |= 1 << dims[d]; + } + LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask)); + + auto keepdim = args[2].unwrapToBool(); + LOG_DEBUG("Keep dims :" << keepdim); + + LOG_WARNING("Sum converter disregards dtype"); + auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim); + + TRTORCH_CHECK(sum_layer, "Unable to create sum layer from node: " << *n); + + sum_layer->setName(util::node_info(n).c_str()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], sum_layer->getOutput(0)); + + LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); + return true; + } + }).pattern({ + "aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in_tensor = args[0].ITensor(); + auto in_dims = util::toVec(in_tensor->getDimensions()); + LOG_WARNING("Prod Converter disregards dtype"); + + uint32_t axis_mask = (uint32_t)(((uint64_t)1 << in_dims.size()) - 1); + + auto prod_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kPROD, axis_mask, false); + + TRTORCH_CHECK(prod_layer, "Unable to create sum layer from node: " << *n); + + prod_layer->setName(util::node_info(n).c_str()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], prod_layer->getOutput(0)); + + LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); + return true; + } + }).pattern({ + "aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in_tensor = args[0].ITensor(); + auto dim = args[1].unwrapToInt(); + LOG_DEBUG("Dim to reduce:" << dim); // Some abuse of toDim but just for debug info + + uint32_t axis_mask = 1 << dim; + LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask)); + + auto keepdim = args[2].unwrapToBool(); + LOG_DEBUG("Keep dims :" << keepdim); + + LOG_WARNING("Prod converter disregards dtype"); + auto prod_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kPROD, axis_mask, keepdim); + + TRTORCH_CHECK(prod_layer, "Unable to create mean layer from node: " << *n); + + prod_layer->setName(util::node_info(n).c_str()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], prod_layer->getOutput(0)); + + LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); + return true; + } + }).pattern({ + "aten::max(Tensor self) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in_tensor = args[0].ITensor(); + auto in_dims = util::toVec(in_tensor->getDimensions()); + + uint32_t axis_mask = (uint32_t)(((uint64_t)1 << in_dims.size()) - 1); + + auto max_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kMAX, axis_mask, false); + + TRTORCH_CHECK(max_layer, "Unable to create max layer from node: " << *n); + + max_layer->setName(util::node_info(n).c_str()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], max_layer->getOutput(0)); + + LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); + return true; + } + }).pattern({ + "aten::min(Tensor self) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in_tensor = args[0].ITensor(); + auto in_dims = util::toVec(in_tensor->getDimensions()); + + uint32_t axis_mask = (uint32_t)(((uint64_t)1 << in_dims.size()) - 1); + + auto min_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kMIN, axis_mask, false); + + TRTORCH_CHECK(min_layer, "Unable to create min layer from node: " << *n); + + min_layer->setName(util::node_info(n).c_str()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], min_layer->getOutput(0)); + + LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); + return true; + } + }); +} // namespace +} // namespace impl +} // namespace converters +} // namespace conversion +} // namespace core +} // namespace trtorch diff --git a/core/conversion/converters/impl/softmax.cpp b/core/conversion/converters/impl/softmax.cpp index d52b655bd2..c138c759f0 100644 --- a/core/conversion/converters/impl/softmax.cpp +++ b/core/conversion/converters/impl/softmax.cpp @@ -15,7 +15,7 @@ static auto softmax_registrations = RegisterNodeConversionPatterns() auto in = args[0].ITensor(); auto shape = util::toVec(in->getDimensions()); - // SoftMax needs at least 4D input + // SoftMax needs at least 4D input if (shape.size() < 2) { auto new_shape = util::toDimsPad(shape, 2); LOG_DEBUG("Input shape is less than 2D got: " << util::toDims(shape) << ", inserting shuffle layer to reshape to 2D tensor shape: " << new_shape); @@ -24,9 +24,12 @@ static auto softmax_registrations = RegisterNodeConversionPatterns() shuffle->setName((util::node_info(n) + " [Reshape to " + util::toStr(new_shape) + ']').c_str()); in = shuffle->getOutput(0); } - + int64_t dim = args[1].IValue()->toInt(); auto softmax = ctx->net->addSoftMax(*in); + + TRTORCH_CHECK(softmax, "Unable to create softmax layer from node: " << *n); + if (!softmax) { 
LOG_ERROR("Unable to create softmax layer from node: " << *n); return false; @@ -39,7 +42,7 @@ static auto softmax_registrations = RegisterNodeConversionPatterns() // When there is no batch dimension softmax->setAxes(1 << (dim + 1)); } - + softmax->setName(util::node_info(n).c_str()); auto out_value = n->outputs()[0]; auto out_tensor = softmax->getOutput(0); @@ -57,7 +60,7 @@ static auto softmax_registrations = RegisterNodeConversionPatterns() out_tensor->setName(out_value->debugName().c_str()); ctx->value_tensor_map[out_value] = out_tensor; LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); - + return true; } }); @@ -66,4 +69,4 @@ static auto softmax_registrations = RegisterNodeConversionPatterns() } // namespace converters } // namespace conversion } // namespace core -} // trtorch +} // trtorch diff --git a/core/conversion/converters/impl/unary.cpp b/core/conversion/converters/impl/unary.cpp index 2bc39850dc..dc4e8764c1 100644 --- a/core/conversion/converters/impl/unary.cpp +++ b/core/conversion/converters/impl/unary.cpp @@ -8,29 +8,28 @@ namespace converters { namespace impl { namespace { -#define convert(unary, trt_type) \ - auto unary##_registrations TRTORCH_UNUSED = \ - RegisterNodeConversionPatterns().pattern( \ - {"aten::" #unary "(Tensor self) -> Tensor", \ - [](ConversionCtx *ctx, const torch::jit::Node *n, \ - args &args) -> bool { \ - auto in = args[0].ITensor(); \ - auto unary = \ - ctx->net->addUnary(*in, nvinfer1::UnaryOperation::trt_type); \ - \ - TRTORCH_CHECK( \ - unary, \ - "Unable to create " #unary " layer from node: " << *n); \ - \ - unary->setName(util::node_info(n).c_str()); \ - auto out_value = n->outputs()[0]; \ - auto out_tensor = unary->getOutput(0); \ - out_tensor->setName(out_value->debugName().c_str()); \ - ctx->value_tensor_map[out_value] = out_tensor; \ - LOG_DEBUG( \ - "Output tensor shape: " << out_tensor->getDimensions()); \ - \ - return true; \ +#define convert(unary, trt_type) \ + auto unary##_registrations TRTORCH_UNUSED = \ + RegisterNodeConversionPatterns().pattern( \ + {"aten::" #unary "(Tensor self) -> Tensor", \ + [](ConversionCtx *ctx, const torch::jit::Node *n, \ + args &args) -> bool { \ + auto in = args[0].ITensor(); \ + auto unary = \ + ctx->net->addUnary(*in, nvinfer1::UnaryOperation::trt_type); \ + \ + TRTORCH_CHECK( \ + unary, \ + "Unable to create " #unary " layer from node: " << *n); \ + \ + unary->setName(util::node_info(n).c_str()); \ + auto out_tensor = ctx->AssociateValueAndTensor( \ + n->outputs()[0], \ + unary->getOutput(0)); \ + LOG_DEBUG( \ + "Output tensor shape: " << out_tensor->getDimensions()); \ + \ + return true; \ }}); convert(cos, kCOS); diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000000..d5b1d11356 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,15 @@ +# Tests + +Right now there are two types of tests. Converter level tests and Module level tests. + +The goal of Converter tests are to tests individual converters againsts specific subgraphs. The current tests in `core/conveters` are good examples on how to write these tests. In general every converter should have at least 1 test. More may be required if the operation has switches that change the behavior of the op. + +Module tests are designed to test the compiler against common network architectures and verify the integration of converters together into a single engine. + +You can run the whole test suite with bazel. 
+Module tests are designed to test the compiler against common network architectures and verify the integration of converters together into a single engine. + +You can run the whole test suite with bazel, but be aware that you may exhaust GPU memory if you run the tests naively (this may be seen as a cuDNN initialization error), so you may need to limit the number of concurrent tests. Also, because the inputs to the tests are random, it may make sense to run the tests a few times. Here are some settings that work well for the current test suite on a TITAN V. ``` bazel test //tests --compilation_mode=dbg --test_output=errors --jobs=4 --runs_per_test=5 ```
diff --git a/tests/core/converters/BUILD b/tests/core/converters/BUILD index 95af88cfc6..8cbe9c4e68 100644 --- a/tests/core/converters/BUILD +++ b/tests/core/converters/BUILD @@ -28,6 +28,10 @@ converter_test( name = "test_conv" ) +converter_test( + name = "test_reduce" +) + test_suite( name = "test_converters", tests = [ @@ -38,6 +42,7 @@ test_suite( ":test_linear", ":test_element_wise", ":test_conv", + ":test_reduce" ] )
diff --git a/tests/core/converters/test_activation.cpp b/tests/core/converters/test_activation.cpp index 64a1589282..bae82a4fe3 100644 --- a/tests/core/converters/test_activation.cpp +++ b/tests/core/converters/test_activation.cpp @@ -21,7 +21,7 @@ TEST(Converters, ATenReLUConvertsCorrectly) { params = trtorch::core::conversion::get_named_params(g->inputs(), {}); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0])); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } TEST(Converters, ATenSigmoidConvertsCorrectly) { @@ -41,7 +41,7 @@ TEST(Converters, ATenSigmoidConvertsCorrectly) { params = trtorch::core::conversion::get_named_params(g->inputs(), {}); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0])); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } TEST(Converters, ATenTanhConvertsCorrectly) { @@ -61,5 +61,51 @@ TEST(Converters, ATenTanhConvertsCorrectly) { params = trtorch::core::conversion::get_named_params(g->inputs(), {}); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0])); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } + +//TODO: Seems like the IR parser is not handling negative numbers well, need to follow up with the PyTorch Team +// TEST(Converters, ATenHardTanhConvertsCorrectly) { +// const auto graph = R"IR( +// graph(%0 : Tensor): +// %1 : float = prim::Constant[value=-1.0]() +// %2 : float = prim::Constant[value=1.0]() +// %3 : Tensor = aten::hardtanh(%0, %1, %2) +// return (%3))IR"; + +// auto g = std::make_shared(); +// torch::jit::script::parseIR(graph, &*g); + +// auto in = at::randint(-5, 5, {5}, {at::kCUDA}); +// auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); +// auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + +// in = at::clone(in); +// params = trtorch::core::conversion::get_named_params(g->inputs(), {}); +// auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + +// ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +// } + +TEST(Converters, ATenHardTanhCustomRangeConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : float = prim::Constant[value=0.0]() + %2 : float = prim::Constant[value=6.0]() + %3 : Tensor = aten::hardtanh(%0, %1, %2) + return (%3))IR"; +
auto g = std::make_shared(); + torch::jit::script::parseIR(graph, &*g); + + auto in = at::randint(-5, 5, {5}, {at::kCUDA}); + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + diff --git a/tests/core/converters/test_conv.cpp b/tests/core/converters/test_conv.cpp index 9dc1c362bd..69f040bb43 100644 --- a/tests/core/converters/test_conv.cpp +++ b/tests/core/converters/test_conv.cpp @@ -4,16 +4,16 @@ #include "tests/util/util.h" #include "core/compiler.h" -// aten::_convolution(Tensor input, Tensor weight, +// aten::_convolution(Tensor input, Tensor weight, // Tensor? bias, int[] stride, int[] padding, -// int[] dilation, bool transposed, -// int[] output_padding, int groups, bool benchmark, +// int[] dilation, bool transposed, +// int[] output_padding, int groups, bool benchmark, // bool deterministic, bool cudnn_enabled) -> (Tensor) void conv_test_helper(std::string graph_ir) { auto g = std::make_shared(); torch::jit::script::parseIR(graph_ir, &*g); - + auto in = at::randint(1, 10, {1, 3, 10, 10}, {at::kCUDA}); auto w = at::randint(1, 10, {8, 3, 5, 5}, {at::kCUDA}); auto b = at::randint(1, 10, {8}, {at::kCUDA}); @@ -21,7 +21,7 @@ void conv_test_helper(std::string graph_ir) { auto jit_in = at::clone(in); auto jit_w = at::clone(w); auto jit_b = at::clone(b); - + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b}); auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); @@ -31,11 +31,11 @@ void conv_test_helper(std::string graph_ir) { params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b}); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt)); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } - + TEST(Converters, ATenConvolutionConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor, @@ -45,7 +45,7 @@ TEST(Converters, ATenConvolutionConvertsCorrectly) { %4 : int = prim::Constant[value=0]() %5 : int = prim::Constant[value=1]() %6 : int = prim::Constant[value=0]() - %7 : bool = prim::Constant[value=0]() + %7 : bool = prim::Constant[value=0]() %8 : int[] = prim::ListConstruct(%3, %3) %9 : int[] = prim::ListConstruct(%4, %4) %10 : int[] = prim::ListConstruct(%5, %5) @@ -55,7 +55,7 @@ TEST(Converters, ATenConvolutionConvertsCorrectly) { auto g = std::make_shared(); torch::jit::script::parseIR(graph, &*g); - + auto in = at::randint(1, 10, {1, 3, 10, 10}, {at::kCUDA}); auto w = at::randint(1, 10, {8, 3, 5, 5}, {at::kCUDA}); auto b = at::randint(1, 10, {8}, {at::kCUDA}); @@ -63,7 +63,7 @@ TEST(Converters, ATenConvolutionConvertsCorrectly) { auto jit_in = at::clone(in); auto jit_w = at::clone(w); auto jit_b = at::clone(b); - + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b}); auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); @@ -73,9 +73,9 @@ TEST(Converters, ATenConvolutionConvertsCorrectly) { params = trtorch::core::conversion::get_named_params(g->inputs(), 
{trt_w, trt_b}); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt)); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) { @@ -87,7 +87,7 @@ TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) { %4 : int = prim::Constant[value=0]() %5 : int = prim::Constant[value=1]() %6 : int = prim::Constant[value=0]() - %7 : bool = prim::Constant[value=0]() + %7 : bool = prim::Constant[value=0]() %8 : int[] = prim::ListConstruct(%3, %3) %9 : int[] = prim::ListConstruct(%4, %4) %10 : int[] = prim::ListConstruct(%5, %5) @@ -97,12 +97,12 @@ TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) { auto g = std::make_shared(); torch::jit::script::parseIR(graph, &*g); - + auto in = at::randint(1, 2, {1, 1, 3, 3}, {at::kCUDA}); auto w = at::randint(1, 2, {4, 1, 2, 2}, {at::kCUDA}); auto jit_in = at::clone(in); - auto jit_w = at::clone(w); + auto jit_w = at::clone(w); auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w}); auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); @@ -111,9 +111,9 @@ TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) { params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w}); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt)); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } @@ -126,7 +126,7 @@ TEST(Converters, ATenConvolutionWithStrideConvertsCorrectly) { %4 : int = prim::Constant[value=0]() %5 : int = prim::Constant[value=1]() %6 : int = prim::Constant[value=0]() - %7 : bool = prim::Constant[value=0]() + %7 : bool = prim::Constant[value=0]() %8 : int[] = prim::ListConstruct(%3, %3) %9 : int[] = prim::ListConstruct(%4, %4) %10 : int[] = prim::ListConstruct(%5, %5) @@ -137,7 +137,7 @@ TEST(Converters, ATenConvolutionWithStrideConvertsCorrectly) { auto g = std::make_shared(); torch::jit::script::parseIR(graph, &*g); - + auto in = at::randint(1, 10, {1, 3, 9, 9}, {at::kCUDA}); auto w = at::randint(1, 10, {4, 3, 3, 3}, {at::kCUDA}); auto b = at::randint(1, 10, {4}, {at::kCUDA}); @@ -145,7 +145,7 @@ TEST(Converters, ATenConvolutionWithStrideConvertsCorrectly) { auto jit_in = at::clone(in); auto jit_w = at::clone(w); auto jit_b = at::clone(b); - + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b}); auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); @@ -155,9 +155,9 @@ TEST(Converters, ATenConvolutionWithStrideConvertsCorrectly) { params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b}); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt)); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) { @@ -169,7 +169,7 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) { %4 : int = 
prim::Constant[value=2]() %5 : int = prim::Constant[value=1]() %6 : int = prim::Constant[value=0]() - %7 : bool = prim::Constant[value=0]() + %7 : bool = prim::Constant[value=0]() %8 : int[] = prim::ListConstruct(%3, %3) %9 : int[] = prim::ListConstruct(%4, %4) %10 : int[] = prim::ListConstruct(%5, %5) @@ -180,7 +180,7 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) { auto g = std::make_shared(); torch::jit::script::parseIR(graph, &*g); - + auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); auto w = at::randint(1, 10, {4, 3, 2, 2}, {at::kCUDA}); auto b = at::randint(1, 10, {4}, {at::kCUDA}); @@ -188,7 +188,7 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) { auto jit_in = at::clone(in); auto jit_w = at::clone(w); auto jit_b = at::clone(b); - + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b}); auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); @@ -198,9 +198,9 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) { params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b}); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt)); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } // TEST(Converters, ATenConvolutionWithDialationConvertsCorrectly) { @@ -212,7 +212,7 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) { // %4 : int = prim::Constant[value=0]() // %5 : int = prim::Constant[value=2]() // %6 : int = prim::Constant[value=0]() -// %7 : bool = prim::Constant[value=0]() +// %7 : bool = prim::Constant[value=0]() // %8 : int[] = prim::ListConstruct(%3, %3) // %9 : int[] = prim::ListConstruct(%4, %4) // %10 : int[] = prim::ListConstruct(%5, %5) @@ -233,7 +233,7 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) { // %4 : int = prim::Constant[value=0]() // %5 : int = prim::Constant[value=1]() // %6 : int = prim::Constant[value=2]() -// %7 : bool = prim::Constant[value=0]() +// %7 : bool = prim::Constant[value=0]() // %8 : int[] = prim::ListConstruct(%3, %3) // %9 : int[] = prim::ListConstruct(%4, %4) // %10 : int[] = prim::ListConstruct(%5, %5) @@ -254,7 +254,7 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) { // %4 : int = prim::Constant[value=0]() // %5 : int = prim::Constant[value=1]() // %6 : int = prim::Constant[value=0]() -// %7 : bool = prim::Constant[value=0]() +// %7 : bool = prim::Constant[value=0]() // %8 : int[] = prim::ListConstruct(%3, %3) // %9 : int[] = prim::ListConstruct(%4, %4) // %10 : int[] = prim::ListConstruct(%5, %5) @@ -262,6 +262,6 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) { // %12 : int = prim::Constant[value=2]() // %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7) // return (%13))IR"; - + // conv_test_helper(graph); // } diff --git a/tests/core/converters/test_element_wise.cpp b/tests/core/converters/test_element_wise.cpp index a5c66668ea..db8ea15870 100644 --- a/tests/core/converters/test_element_wise.cpp +++ b/tests/core/converters/test_element_wise.cpp @@ -7,7 +7,7 @@ void pointwise_test_helper(std::string graph_ir) { auto g = std::make_shared(); torch::jit::script::parseIR(graph_ir, &*g); - + auto in0 = at::randint(1, 5, {5}, {at::kCUDA}); auto in1 = at::randint(1, 5, {5}, {at::kCUDA}); 
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); @@ -18,7 +18,7 @@ void pointwise_test_helper(std::string graph_ir) { params = trtorch::core::conversion::get_named_params(g->inputs(), {}); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in0, in1}); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0])); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } diff --git a/tests/core/converters/test_linear.cpp b/tests/core/converters/test_linear.cpp index 65a60446a4..f6d543ab04 100644 --- a/tests/core/converters/test_linear.cpp +++ b/tests/core/converters/test_linear.cpp @@ -4,7 +4,6 @@ #include "tests/util/util.h" #include "core/compiler.h" - TEST(Converters, ATenLinearNoBiasConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor, @@ -28,7 +27,7 @@ TEST(Converters, ATenLinearNoBiasConvertsCorrectly) { params = trtorch::core::conversion::get_named_params(g->inputs(), {w}); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]))); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); } @@ -62,5 +61,5 @@ TEST(Converters, ATenLinearBiasConvertsCorrectly) { auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]))); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); } diff --git a/tests/core/converters/test_pooling.cpp b/tests/core/converters/test_pooling.cpp index 1b13e739df..f278d8b65f 100644 --- a/tests/core/converters/test_pooling.cpp +++ b/tests/core/converters/test_pooling.cpp @@ -10,17 +10,17 @@ TEST(Converters, ATenMaxPool2DConvertsCorrectly) { %1 : int = prim::Constant[value=0]() %2 : int = prim::Constant[value=1]() %3 : int = prim::Constant[value=2]() - %5 : bool = prim::Constant[value=0]() + %5 : bool = prim::Constant[value=0]() %6 : int[] = prim::ListConstruct(%1, %1) %7 : int[] = prim::ListConstruct(%2, %2) %8 : int[] = prim::ListConstruct(%3, %3) - %10 : Tensor = aten::max_pool2d(%0, %8, %7, %6, %7, %5) + %10 : Tensor = aten::max_pool2d(%0, %8, %7, %6, %7, %5) return (%10))IR"; auto g = std::make_shared(); torch::jit::script::parseIR(graph, &*g); - //PyTorch MaxPool needs a 3D input + //PyTorch MaxPool needs a 3D input auto in = at::randint(-5, 5, {1, 4, 4}, at::kCUDA); auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); @@ -28,8 +28,8 @@ TEST(Converters, ATenMaxPool2DConvertsCorrectly) { in = at::clone(in); params = trtorch::core::conversion::get_named_params(g->inputs(), {}); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); - - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0])); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectly) { @@ -38,13 +38,13 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectly) { %2 : int = prim::Constant[value=3]() %3 : int = prim::Constant[value=4]() %6 : int[] = prim::ListConstruct(%2, %3) - %10 : Tensor = aten::adaptive_avg_pool2d(%0, %6) + %10 : Tensor = aten::adaptive_avg_pool2d(%0, %6) return (%10))IR"; auto g = 
std::make_shared(); torch::jit::script::parseIR(graph, &*g); - //PyTorch MaxPool needs a 3D input + //PyTorch MaxPool needs a 3D input auto in = at::randint(-5, 5, {1, 12, 16}, at::kCUDA); auto jit_in = at::clone(in); @@ -54,6 +54,6 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectly) { auto trt_in = at::clone(in); params = trtorch::core::conversion::get_named_params(g->inputs(), {}); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0])); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } diff --git a/tests/core/converters/test_reduce.cpp b/tests/core/converters/test_reduce.cpp new file mode 100644 index 0000000000..81d6e6c606 --- /dev/null +++ b/tests/core/converters/test_reduce.cpp @@ -0,0 +1,154 @@ +#include +#include "gtest/gtest.h" +#include "torch/csrc/jit/irparser.h" +#include "tests/util/util.h" +#include "core/compiler.h" + +namespace { +std::string gen_basic_graph(const std::string& op) { + return R"IR( + graph(%0 : Tensor): + %4 : None = prim::Constant() + %5 : Tensor = aten::)IR" + op + R"IR((%0, %4) + return (%5))IR"; +} + +std::string gen_min_max_graph(const std::string& op) { + return R"IR( + graph(%0 : Tensor): + %5 : Tensor = aten::)IR" + op + R"IR((%0) + return (%5))IR"; +} + +std::string gen_dim_graph(const std::string& op) { + return R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=1]() + %2 : int[] = prim::ListConstruct(%1) + %3 : bool = prim::Constant[value=0]() + %4 : None = prim::Constant() + %5 : Tensor = aten::)IR" + op + R"IR((%0, %2, %3, %4) + return (%5))IR"; +} + +std::string gen_multidim_graph(const std::string& op) { + return R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=0]() + %2 : int = prim::Constant[value=1]() + %3 : int[] = prim::ListConstruct(%1, %2) + %4 : bool = prim::Constant[value=0]() + %5 : None = prim::Constant() + %6 : Tensor = aten::)IR" + op + R"IR((%0, %3, %4, %5) + return (%6))IR"; +} + +std::string gen_keepdim_graph(const std::string& op) { + return R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=1]() + %2 : int[] = prim::ListConstruct(%1) + %3 : bool = prim::Constant[value=1]() + %4 : None = prim::Constant() + %5 : Tensor = aten::)IR" + op + R"IR((%0, %2, %3, %4) + return (%5))IR"; +} + +void test_body(const std::string& graph, at::Tensor& in) { + auto g = std::make_shared(); + torch::jit::script::parseIR(graph, &*g); + + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} +} + +#define converts_correctly(op, name) \ + TEST(Converters, ATen##name##ConvertsCorrectly) { \ + const auto graph = gen_basic_graph(#op); \ + auto in = at::randint(-5, 5, {4, 4}, at::kCUDA); \ + test_body(graph, in); \ + } + +converts_correctly(sum, Sum); +converts_correctly(prod, Prod); +converts_correctly(mean, Mean); + +#undef converts_correctly + +#define min_max_converts_correctly(op, name) \ + TEST(Converters, ATen##name##ConvertsCorrectly) { \ + const auto graph = gen_min_max_graph(#op); \ + auto in = at::randint(-5, 5, {4, 4}, at::kCUDA); \ + test_body(graph, in); \ + } + +min_max_converts_correctly(max, Max); 
+min_max_converts_correctly(min, Min); + +#undef min_max_converts_correctly + +#define converts_dim_correctly(op, name) \ + TEST(Converters, ATen##name##DimConvertsCorrectly) { \ + const auto graph = gen_dim_graph(#op); \ + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); \ + test_body(graph, in); \ + } + +converts_dim_correctly(sum, Sum); +converts_dim_correctly(mean, Mean); + +#undef converts_dim_correctly + +#define converts_multidims_correctly(op, name) \ + TEST(Converters, ATen##name##MultiDimsConvertsCorrectly) { \ + const auto graph = gen_multidim_graph(#op); \ + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); \ + test_body(graph, in); \ + } + +converts_multidims_correctly(sum, Sum); +converts_multidims_correctly(mean, Mean); + +#undef converts_multidims_correctly + +#define converts_keepdims_correctly(op, name) \ +TEST(Converters, ATen##name##KeepDimsConvertsCorrectly) { \ + const auto graph = gen_keepdim_graph(#op); \ + auto in = at::randint(-5, 5, {4, 4}, at::kCUDA); \ + test_body(graph, in); \ +} + +converts_keepdims_correctly(sum, Sum); +converts_keepdims_correctly(mean, Mean); + +#undef converts_keepdims_correctly + +TEST(Converters, ATenProdDimConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=1]() + %3 : bool = prim::Constant[value=0]() + %4 : None = prim::Constant() + %5 : Tensor = aten::prod(%0, %1, %3, %4) + return (%5))IR"; + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + test_body(graph, in); +} + +TEST(Converters, ATenProdKeepDimsConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=1]() + %3 : bool = prim::Constant[value=1]() + %4 : None = prim::Constant() + %5 : Tensor = aten::prod(%0, %1, %3, %4) + return (%5))IR"; + auto in = at::randint(-5, 5, {4, 4}, at::kCUDA); + test_body(graph, in); +} \ No newline at end of file diff --git a/tests/core/converters/test_softmax.cpp b/tests/core/converters/test_softmax.cpp index 7696c863ba..5f8186021c 100644 --- a/tests/core/converters/test_softmax.cpp +++ b/tests/core/converters/test_softmax.cpp @@ -14,7 +14,7 @@ TEST(Converters, ATenSoftmax1DConvertsCorrectly) { auto g = std::make_shared(); torch::jit::script::parseIR(graph, &*g); - + auto in = at::randint(0, 5, {5}, {at::kCUDA}); auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); @@ -22,9 +22,9 @@ TEST(Converters, ATenSoftmax1DConvertsCorrectly) { in = at::clone(in); params = trtorch::core::conversion::get_named_params(g->inputs(), {}); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); - auto trt = trt_results[0].reshape_as(jit_results[0]); + auto trt = trt_results[0].reshape_as(jit_results[0]); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt)); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } TEST(Converters, ATenSoftmaxNDConvertsCorrectlySub3DIndex) { @@ -37,7 +37,7 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlySub3DIndex) { auto g = std::make_shared(); torch::jit::script::parseIR(graph, &*g); - + auto in = at::randint(0, 5, {1,2,2,2,2}, {at::kCUDA}); auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); @@ -46,9 +46,9 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlySub3DIndex) { in = at::clone(in); params = trtorch::core::conversion::get_named_params(g->inputs(), {}); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); - auto trt = 
trt_results[0].reshape_as(jit_results[0]); + auto trt = trt_results[0].reshape_as(jit_results[0]); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt)); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } TEST(Converters, ATenSoftmaxNDConvertsCorrectlyAbove3DIndex) { @@ -61,7 +61,7 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlyAbove3DIndex) { auto g = std::make_shared(); torch::jit::script::parseIR(graph, &*g); - + auto in = at::randint(0, 5, {1,2,2,2,2}, {at::kCUDA}); auto jit_in = at::clone(in); @@ -71,8 +71,8 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlyAbove3DIndex) { auto trt_in = at::clone(in); params = trtorch::core::conversion::get_named_params(g->inputs(), {}); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - - auto trt = trt_results[0].reshape_as(jit_results[0]); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt)); + auto trt = trt_results[0].reshape_as(jit_results[0]); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } diff --git a/tests/core/converters/test_unary.cpp b/tests/core/converters/test_unary.cpp index 00c5599236..a82a4ade9f 100644 --- a/tests/core/converters/test_unary.cpp +++ b/tests/core/converters/test_unary.cpp @@ -14,24 +14,24 @@ std::string gen_test_graph(const std::string &unary) { } } // namespace -#define test_unary(unary, name) \ - TEST(Converters, ATen##name##ConvertsCorrectly) { \ - const auto graph = gen_test_graph(#unary); \ - \ - auto g = std::make_shared(); \ - torch::jit::script::parseIR(graph, &*g); \ - \ - auto in = at::empty({10}, {at::kCUDA}).uniform_(0, 0.5); \ - auto params = \ - trtorch::core::conversion::get_named_params(g->inputs(), {}); \ - auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); \ - \ - in = at::clone(in); \ - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); \ - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); \ - \ - ASSERT_TRUE( \ - trtorch::tests::util::almostEqual(jit_results[0], trt_results[0])); \ +#define test_unary(unary, name) \ + TEST(Converters, ATen##name##ConvertsCorrectly) { \ + const auto graph = gen_test_graph(#unary); \ + \ + auto g = std::make_shared(); \ + torch::jit::script::parseIR(graph, &*g); \ + \ + auto in = at::empty({10}, {at::kCUDA}).uniform_(0, 0.5); \ + auto params = \ + trtorch::core::conversion::get_named_params(g->inputs(), {}); \ + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); \ + \ + in = at::clone(in); \ + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); \ + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); \ + \ + ASSERT_TRUE( \ + trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); \ } test_unary(cos, Cos); diff --git a/tests/modules/test_compiled_modules.cpp b/tests/modules/test_compiled_modules.cpp index 571290233a..5f817dc34d 100644 --- a/tests/modules/test_compiled_modules.cpp +++ b/tests/modules/test_compiled_modules.cpp @@ -13,14 +13,14 @@ TEST_P(ModuleTests, CompiledModuleIsClose) { std::vector jit_results; jit_results.push_back(jit_results_ivalues.toTensor()); - + auto trt_mod = trtorch::CompileGraph(mod, input_shapes); torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues); std::vector trt_results; trt_results.push_back(trt_results_ivalues.toTensor()); for (size_t i = 0; i < trt_results.size(); i++) { - 
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]))); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]), 2e-5)); } } @@ -33,4 +33,6 @@ INSTANTIATE_TEST_SUITE_P(CompiledModuleForwardIsCloseSuite, PathAndInSize({"tests/modules/resnet18.jit.pt", {{1,3,224,224}}}), PathAndInSize({"tests/modules/resnet50.jit.pt", + {{1,3,224,224}}}), + PathAndInSize({"tests/modules/mobilenet_v2.jit.pt", {{1,3,224,224}}}))); diff --git a/tests/modules/test_modules_as_engines.cpp b/tests/modules/test_modules_as_engines.cpp index 04f1c33ac1..3b7287c6c8 100644 --- a/tests/modules/test_modules_as_engines.cpp +++ b/tests/modules/test_modules_as_engines.cpp @@ -12,8 +12,8 @@ TEST_P(ModuleTests, ModuleAsEngineIsClose) { std::vector jit_results; jit_results.push_back(jit_results_ivalues.toTensor()); auto trt_results = trtorch::tests::util::RunModuleForwardAsEngine(mod, inputs); - - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]))); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5)); } INSTANTIATE_TEST_SUITE_P(ModuleAsEngineForwardIsCloseSuite, @@ -24,4 +24,6 @@ INSTANTIATE_TEST_SUITE_P(ModuleAsEngineForwardIsCloseSuite, PathAndInSize({"tests/modules/resnet18.jit.pt", {{1,3,224,224}}}), PathAndInSize({"tests/modules/resnet50.jit.pt", - {{1,3,224,224}}}))); + {{1,3,224,224}}}), + PathAndInSize({"tests/modules/mobilenet_v2.jit.pt", + {{1,3,224,224}}}))); \ No newline at end of file diff --git a/tests/modules/test_multiple_registered_engines.cpp b/tests/modules/test_multiple_registered_engines.cpp index 8bbd82603a..7ce3dbf61f 100644 --- a/tests/modules/test_multiple_registered_engines.cpp +++ b/tests/modules/test_multiple_registered_engines.cpp @@ -17,7 +17,7 @@ TEST(ModuleTests, CanRunMultipleEngines) { } const std::vector> input_shapes = {{1,3,224,224}}; - + std::vector jit1_inputs_ivalues; std::vector trt1_inputs_ivalues; for (auto in_shape : input_shapes) { @@ -42,7 +42,7 @@ TEST(ModuleTests, CanRunMultipleEngines) { std::vector jit2_results; jit2_results.push_back(jit2_results_ivalues.toTensor()); - + auto trt_mod1 = trtorch::CompileGraph(mod1, input_shapes); torch::jit::IValue trt1_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod1, trt1_inputs_ivalues); std::vector trt1_results; @@ -55,10 +55,10 @@ TEST(ModuleTests, CanRunMultipleEngines) { for (size_t i = 0; i < trt1_results.size(); i++) { - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit1_results[i], trt1_results[i].reshape_as(jit1_results[i]))); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit1_results[i], trt1_results[i].reshape_as(jit1_results[i]), 2e-5)); } for (size_t i = 0; i < trt2_results.size(); i++) { - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit2_results[i], trt2_results[i].reshape_as(jit2_results[i]))); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit2_results[i], trt2_results[i].reshape_as(jit2_results[i]), 2e-5)); } } diff --git a/tests/util/util.cpp b/tests/util/util.cpp index 4a38793bbf..e317029a1e 100644 --- a/tests/util/util.cpp +++ b/tests/util/util.cpp @@ -5,18 +5,19 @@ namespace trtorch { namespace tests { namespace util { -bool checkRtol(const at::Tensor& diff, const std::vector inputs) { +bool checkRtol(const at::Tensor& diff, const std::vector inputs, float threshold) { double maxValue = 0.0; for (auto& tensor : inputs) { maxValue = fmax(tensor.abs().max().item(), maxValue); } std::cout 
<< "Max Difference: " << diff.abs().max().item() << std::endl; - return diff.abs().max().item() <= 2e-6 * maxValue; + std::cout << "Acceptable Threshold: " << threshold << std::endl; + return diff.abs().max().item() <= threshold * maxValue; } -bool almostEqual(const at::Tensor& a, const at::Tensor& b) { +bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold) { LOG_DEBUG(a << std::endl << b << std::endl); - return checkRtol(a - b, {a, b}); + return checkRtol(a - b, {a, b}, threshold); } bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) { diff --git a/tests/util/util.h b/tests/util/util.h index ab386f38bb..fa9008bd39 100644 --- a/tests/util/util.h +++ b/tests/util/util.h @@ -11,13 +11,13 @@ namespace trtorch { namespace tests { namespace util { -bool almostEqual(const at::Tensor& a, const at::Tensor& b); +bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold); bool exactlyEqual(const at::Tensor& a, const at::Tensor& b); std::vector RunEngine(std::string& eng, std::vector inputs); -// Runs an arbitrary JIT graph and returns results +// Runs an arbitrary JIT graph and returns results std::vector RunGraph(std::shared_ptr& g, core::conversion::GraphParams& named_params, std::vector inputs);