@@ -190,6 +190,57 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
190190 }
191191}
192192
193+ void MapIValues (ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_list, c10::ArrayRef<const torch::jit::Value*> out_list, int64_t in_offset, int64_t out_offset) {
194+ std::vector<std::pair<const torch::jit::Value*, const torch::jit::Value*>> input_output_pairs;
195+ std::transform (in_list.begin () + in_offset, in_list.end (), out_list.begin () + out_offset,
196+ std::back_inserter (input_output_pairs),
197+ [](auto in, auto out){
198+ return std::make_pair (in, out);
199+ });
200+
201+ for (auto p : input_output_pairs) {
202+ auto input = ctx->evaluated_value_map [p.first ];
203+ ctx->evaluated_value_map [p.second ] = torch::jit::IValue (input);
204+ }
205+ }
206+
207+ // TODO: With functionalization pass we may be able to make this into a regular evaluator later
208+ void EvaluateLoopBlock (ConversionCtx* ctx, const torch::jit::Node* n) {
209+ auto max_trip_count = ctx->evaluated_value_map [n->input (0 )];
210+ auto start_cond = ctx->evaluated_value_map [n->input (1 )];
211+ ctx->evaluated_value_map [n->blocks ()[0 ]->inputs ()[0 ]] = torch::jit::IValue (0 );
212+ auto trip_count = ctx->evaluated_value_map [n->blocks ()[0 ]->inputs ()[0 ]];
213+
214+ MapIValues (ctx, n->inputs (), n->outputs (), 2 , 0 );
215+
216+ LOG_DEBUG (" (Loop Evaluation) Evaluating loop " << *n);
217+ LOG_DEBUG (" (Loop Evaluation) Max Trip Count: " << max_trip_count.toInt ());
218+ LOG_DEBUG (" (Loop Evaluation) Start Condition: " << start_cond.toBool ());
219+ LOG_DEBUG (" (Loop Evaluation) Current Trip Count: " << trip_count.toInt ());
220+
221+ while (start_cond.toBool () && trip_count.toInt () < max_trip_count.toInt ()) {
222+ MapIValues (ctx, n->outputs (), n->blocks ()[0 ]->inputs (), 0 , 1 );
223+ for (auto bn : n->blocks ()[0 ]->nodes ()) {
224+ auto eval = EvaluateNode (ctx, bn);
225+ if (eval) {
226+ if (!eval.value ().isTensor ()) {
227+ LOG_DEBUG (ctx->logger , " (Loop Evaluation) Found the value to be: " << eval.value ());
228+ } else {
229+ LOG_DEBUG (ctx->logger , " (Loop Evaluation) Found the value to be a tensor (shape " << eval.value ().toTensor ().sizes () << ' )' );
230+ }
231+ ctx->AssociateValueAndIValue (bn->output (0 ), eval.value ());
232+ }
233+ }
234+
235+ MapIValues (ctx, n->blocks ()[0 ]->outputs (), n->outputs (), 1 , 0 );
236+ start_cond = ctx->evaluated_value_map [n->blocks ()[0 ]->outputs ()[0 ]];
237+ auto new_trip_count = torch::jit::IValue (trip_count.toInt () + 1 );
238+ trip_count.swap (new_trip_count);
239+ LOG_DEBUG (" (Loop Evaluation) Condition: " << start_cond.toBool ());
240+ LOG_DEBUG (" (Loop Evaluation) Current Trip Count: " << trip_count.toInt ());
241+ }
242+ }
243+
193244void ConvertBlockToNetDef (ConversionCtx* ctx, const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
194245 LOG_INFO (ctx->logger , " Converting Block" );
195246
@@ -202,7 +253,19 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver
202253 for (const auto n : nodes) {
203254 bool to_eval = evaluators::shouldEvalAtConversionTime (n);
204255 bool blacklisted = isNodeConversionBlacklisted (n);
205- if (!to_eval && !blacklisted) {
256+ if (n->kind () == torch::jit::prim::Loop) {
257+ EvaluateLoopBlock (ctx, n);
258+ } else if (to_eval) {
259+ auto eval = EvaluateNode (ctx, n);
260+ if (eval) {
261+ if (!eval.value ().isTensor ()) {
262+ LOG_DEBUG (ctx->logger , " Found the value to be: " << eval.value ());
263+ } else {
264+ LOG_DEBUG (ctx->logger , " Found the value to be a tensor (shape " << eval.value ().toTensor ().sizes () << ' )' );
265+ }
266+ ctx->AssociateValueAndIValue (n->output (0 ), eval.value ());
267+ }
268+ } else if (!blacklisted) {
206269 // Should error out if something fails
207270 AddLayer (ctx, n);
208271 } else {
@@ -237,22 +300,29 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil
237300 return engine;
238301}
239302
240- bool VerifyConverterSupportForBlock (const torch::jit::Block* b) {
241- bool supported = true ;
303+ std::set<std::string> GetUnsupportedOpsInBlock (const torch::jit::Block* b ) {
242304 std::set<std::string> unsupported_ops;
243305 for (const auto n : b->nodes ()) {
244- if (!OpSupported (n)) {
306+ if (!OpSupported (n) && n-> kind () != torch::jit::prim::Loop ) {
245307 auto schema = n->maybeSchema ();
246308 TRTORCH_CHECK (schema, " Unable to get schema for Node " << util::node_info (n) \
247309 << " (conversion.VerifyCoverterSupportForBlock" );
248310 std::stringstream ss;
249311 ss << *schema;
250312 unsupported_ops.insert (ss.str ());
251- supported = false ;
313+ }
314+ for (const auto sub_b : n->blocks ()) {
315+ auto sub_b_unsupported_ops = GetUnsupportedOpsInBlock (sub_b);
316+ unsupported_ops.insert (sub_b_unsupported_ops.begin (), sub_b_unsupported_ops.end ());
252317 }
253318 }
319+ return unsupported_ops;
320+ }
321+
322+ bool VerifyConverterSupportForBlock (const torch::jit::Block* b) {
323+ auto unsupported_ops = GetUnsupportedOpsInBlock (b);
254324
255- if (!supported ) {
325+ if (unsupported_ops. size () != 0 ) {
256326 std::stringstream unsupported_msg;
257327 unsupported_msg << " Method requested cannot be compiled by TRTorch.\n Unsupported operators listed below:" << std::endl;
258328 for (auto s : unsupported_ops) {
@@ -261,8 +331,10 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
261331 unsupported_msg << " You can either implement converters for these ops in your application or request implementation" << std::endl;
262332 unsupported_msg << " https://www.github.com/nvidia/TRTorch/issues" << std::endl;
263333 LOG_ERROR (unsupported_msg.str ());
334+ return false ;
335+ } else {
336+ return true ;
264337 }
265- return supported;
266338}
267339
268340} // namespace conversion
0 commit comments