@@ -103,6 +103,7 @@ class ModelState : public BackendModel {
103
103
104
104
bool EnabledWeightSharing () { return enable_weight_sharing_; }
105
105
const std::vector<std::string>& ModelOutputs () { return output_names_; }
106
+ const std::string& MethodToCall () { return method_to_call_; }
106
107
107
108
private:
108
109
ModelState (TRITONBACKEND_Model* triton_model);
@@ -145,6 +146,10 @@ class ModelState : public BackendModel {
145
146
// List of all the outputs specified in the output section of model
146
147
// configuration.
147
148
std::vector<std::string> output_names_;
149
+
150
+ // Method to call on PyTorch Module.
151
+ // Defaults to "forward".
152
+ std::string method_to_call_;
148
153
};
149
154
150
155
TRITONSERVER_Error*
@@ -180,7 +185,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
180
185
enable_weight_sharing_(false ), enable_tensor_fuser_pair_({false , true }),
181
186
enable_jit_profiling_pair_({false , true }),
182
187
enable_jit_executor_pair_({false , true }),
183
- enable_nvfuser_pair_({false , false })
188
+ enable_nvfuser_pair_({false , false }),
189
+ method_to_call_(" forward" )
184
190
{
185
191
output_names_.clear ();
186
192
@@ -454,6 +460,29 @@ ModelState::ParseParameters()
454
460
" for model instance '" + Name () + " '" )
455
461
.c_str ());
456
462
}
463
+
464
+ // If 'ENABLE_NVFUSER' is not present in 'parameters' then no
465
+ // update is made to 'enable_nvfuser'.
466
+ std::string method_to_call = " forward" ;
467
+ err = GetParameterValue (params, " METHOD_TO_CALL" , &method_to_call);
468
+ if (err != nullptr ) {
469
+ if (TRITONSERVER_ErrorCode (err) != TRITONSERVER_ERROR_NOT_FOUND) {
470
+ return err;
471
+ } else {
472
+ LOG_MESSAGE (
473
+ TRITONSERVER_LOG_INFO, (std::string (" method_to_call is not specified" ) +
474
+ " for model instance '" + Name () + " '" )
475
+ .c_str ());
476
+ TRITONSERVER_ErrorDelete (err);
477
+ }
478
+ } else {
479
+ method_to_call_ = std::string (" forward" );
480
+ LOG_MESSAGE (
481
+ TRITONSERVER_LOG_INFO, (std::string (" method_to_call is " ) +
482
+ method_to_call_ +
483
+ " for model instance '" + Name () + " '" )
484
+ .c_str ());
485
+ }
457
486
}
458
487
459
488
return nullptr ;
@@ -764,7 +793,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
764
793
// configuration specifies only those.
765
794
std::vector<std::string> allowed_inputs;
766
795
767
- const torch::jit::Method& method = torch_model_->get_method (" forward " );
796
+ const torch::jit::Method& method = torch_model_->get_method (model_state_-> MethodToCall () );
768
797
const auto & schema = method.function ().getSchema ();
769
798
const std::vector<c10::Argument>& arguments = schema.arguments ();
770
799
@@ -1324,16 +1353,23 @@ ModelInstanceState::Execute(
1324
1353
torch::NoGradGuard no_grad;
1325
1354
1326
1355
// If input is a dictionary, prepare dictionary from 'input_tensors'.
1356
+ std::string method_to_call = model_state_->MethodToCall ();
1327
1357
if (is_dict_input_) {
1328
1358
torch::Dict<std::string, torch::Tensor> input_dict;
1329
1359
for (auto & input_index : input_index_map_) {
1330
1360
torch::jit::IValue ival = (*input_tensors)[input_index.second ];
1331
1361
input_dict.insert (input_index.first , ival.toTensor ());
1332
1362
}
1333
- std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
1334
- model_outputs_ = torch_model_->forward (input_dict_ivalue);
1363
+ auto typ = c10::DictType::create (c10::StringType::get (), c10::TensorType::get ());
1364
+ auto inp = c10::impl::GenericList (typ);
1365
+ inp.emplace_back (input_dict);
1366
+ model_outputs_ = torch_model_->run_method (method_to_call, inp);
1335
1367
} else {
1336
- model_outputs_ = torch_model_->forward (*input_tensors);
1368
+ auto inp = c10::impl::GenericList (c10::TensorType::get ());
1369
+ for (auto & input_tensor : *input_tensors) {
1370
+ inp.emplace_back (input_tensor.toTensor ());
1371
+ }
1372
+ model_outputs_ = torch_model_->run_method (method_to_call, inp);
1337
1373
}
1338
1374
1339
1375
if (model_outputs_.isTuple ()) {
0 commit comments