|
9 | 9 |
|
10 | 10 | #include "c10/util/intrusive_ptr.h" |
11 | 11 | #include "core/conversion/tensorcontainer/TensorContainer.h" |
| 12 | +#include "core/util/trt_util.h" |
| 13 | +#include "core/conversion/converters/converter_util.h" |
12 | 14 |
|
13 | 15 | namespace trtorch { |
14 | 16 | namespace core { |
@@ -210,6 +212,21 @@ void MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outp |
210 | 212 | LOG_INFO( |
211 | 213 | ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)"); |
212 | 214 | ctx->num_outputs += 1; |
| 215 | + } else if(out_ivalue.isTuple()) { |
| 216 | + TRTORCH_THROW_ERROR("Tuple type. Only a single tensor or a TensorList type is supported."); |
| 217 | + } else if(out_ivalue.isList()) { |
| 218 | + TRTORCH_THROW_ERROR("List type. Only a single tensor or a TensorList type is supported."); |
| 219 | + } else if(out_ivalue.isScalar()) { |
| 220 | + TRTORCH_THROW_ERROR("Scalar type. Only a single tensor or a TensorList type is supported."); |
| 221 | + } else if(out_ivalue.isTensor()) { |
| 222 | + // prim::NumToTensor will go to here |
| 223 | + std::string name = std::string("output_") + std::to_string(ctx->num_outputs); |
| 224 | + auto out_tensor = trtorch::core::conversion::converters::tensor_to_const(ctx, out_ivalue.toTensor(), ""); |
| 225 | + out_tensor->setName(name.c_str()); |
| 226 | + ctx->net->markOutput(*out_tensor); |
| 227 | + LOG_INFO( |
| 228 | + ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)"); |
| 229 | + ctx->num_outputs += 1; |
213 | 230 | } else { |
214 | 231 | TRTORCH_THROW_ERROR("Unknown output type. Only a single tensor or a TensorList type is supported."); |
215 | 232 | } |
@@ -361,6 +378,7 @@ void ConvertBlockToNetDef( |
361 | 378 | ConversionInfo build_info, |
362 | 379 | GraphParams& static_params) { |
363 | 380 | LOG_INFO(ctx->logger, "Converting Block"); |
| 381 | + LOG_DEBUG(ctx->logger, *b->owningGraph()); |
364 | 382 |
|
365 | 383 | auto inputs = b->inputs(); |
366 | 384 | AddParamsToCtxValueMap(ctx, static_params); |
|
0 commit comments