@@ -61,8 +61,6 @@
 // https://github.com/pytorch/pytorch/blob/v2.2.1-rc3/aten/src/ATen/Parallel.h#L133
 #include <ATen/Parallel.h>

-// Default forward method to call on PyTorch modules
-const std::string DEFAULT_MODULE_METHOD_NAME = "forward";

 //
 // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.

@@ -113,7 +111,6 @@ class ModelState : public BackendModel {
   {
     return model_outputs_;
   }
-  const std::string& ModuleMethodName() { return module_method_name_; }

  private:
   ModelState(TRITONBACKEND_Model* triton_model);

@@ -156,10 +153,6 @@ class ModelState : public BackendModel {
   // is specified both in the output section and state section, it indicates
   // that the backend must return the output state to the client too.
   std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;
-
-  // Method to call on PyTorch Module.
-  // Defaults to DEFAULT_MODULE_METHOD_NAME.
-  std::string module_method_name_;
 };

 TRITONSERVER_Error*

@@ -237,8 +230,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       enable_inference_mode_(true), enable_cache_cleaning_(false),
       enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
-      enable_jit_executor_pair_({false, true}),
-      module_method_name_(DEFAULT_MODULE_METHOD_NAME)
+      enable_jit_executor_pair_({false, true})
 {
 }

@@ -527,30 +519,6 @@ ModelState::ParseParameters()
               .c_str());
       }
     }
-
-    // If 'MODULE_METHOD_NAME' is not present in 'parameters' then
-    // 'module_method_name_' is set to 'DEFAULT_MODULE_METHOD_NAME' ('forward').
-    std::string module_method_name = DEFAULT_MODULE_METHOD_NAME;
-    err = GetParameterValue(params, "MODULE_METHOD_NAME", &module_method_name);
-    if (err != nullptr) {
-      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
-        return err;
-      } else {
-        LOG_MESSAGE(
-            TRITONSERVER_LOG_INFO,
-            (std::string("module_method_name is not specified") +
-             " for model instance '" + Name() + "'")
-                .c_str());
-        TRITONSERVER_ErrorDelete(err);
-      }
-    } else {
-      module_method_name_ = module_method_name;
-      LOG_MESSAGE(
-          TRITONSERVER_LOG_INFO,
-          (std::string("module_method_name is ") + module_method_name_ +
-           " for model instance '" + Name() + "'")
-              .c_str());
-    }
   }

   return nullptr;

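Aside: the deleted block read an optional string from the model's config.pbtxt parameters section (an entry of the form parameters { key: "MODULE_METHOD_NAME" value { string_value: "..." } }) and treated TRITONSERVER_ERROR_NOT_FOUND as "keep the default". A minimal sketch of that optional-parameter pattern, assuming it sits inside a function returning TRITONSERVER_Error* with 'params' already located, as in ParseParameters above:

    // Sketch: read an optional string parameter, falling back to a default.
    // 'GetParameterValue' is the Triton backend-common helper the deleted
    // code called; NOT_FOUND is expected when the key is simply absent.
    std::string method_name = "forward";  // default used when key is absent
    TRITONSERVER_Error* err =
        GetParameterValue(params, "MODULE_METHOD_NAME", &method_name);
    if (err != nullptr) {
      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
        return err;  // malformed parameter: propagate the real error
      }
      TRITONSERVER_ErrorDelete(err);  // absent key: keep the default
    }
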
@@ -972,20 +940,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   // configuration specifies only those.
   std::vector<std::string> allowed_inputs;

-  // First check if method exists in the model and throw an error if absent
-  const auto methodNameToExecute = model_state_->ModuleMethodName();
-  const auto optionalMethodHandle =
-      torch_model_->find_method(methodNameToExecute);
-  if (!optionalMethodHandle.has_value()) {
-    return TRITONSERVER_ErrorNew(
-        TRITONSERVER_ERROR_INVALID_ARG,
-        (std::string("unable to find method '") + methodNameToExecute +
-         "' in model '" + model_path_ + "'")
-            .c_str());
-  }
-
-  // Get the method schema and validate the inputs
-  const torch::jit::Method& method = optionalMethodHandle.value();
+  const torch::jit::Method& method = torch_model_->get_method("forward");
   const auto& schema = method.function().getSchema();
   const std::vector<c10::Argument>& arguments = schema.arguments();

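The replacement line above trades LibTorch's optional-returning method lookup for the throwing one. A standalone sketch of the two APIs, with a placeholder model path:

    // Sketch: the two TorchScript method lookups involved in this hunk.
    #include <torch/script.h>

    int main() {
      torch::jit::Module module = torch::jit::load("model.pt");  // placeholder

      // find_method() returns c10::optional<Method>: the caller decides how
      // to report a missing method (what the removed code did).
      if (auto method = module.find_method("forward")) {
        const auto& schema = method->function().getSchema();
        (void)schema;
      }

      // get_method() returns the Method directly and throws c10::Error when
      // the name does not exist (what the replacement relies on).
      torch::jit::Method forward = module.get_method("forward");
      return 0;
    }
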
@@ -1628,24 +1583,18 @@ ModelInstanceState::Execute(
   torch::NoGradGuard no_grad;

   // If input is a dictionary, prepare dictionary from 'input_tensors'.
-  std::string module_method_name = model_state_->ModuleMethodName();
-  std::vector<c10::IValue> inputs;
   if (is_dict_input_) {
-    c10::Dict<std::string, at::Tensor> dict;
+    torch::Dict<std::string, torch::Tensor> input_dict;
     for (auto& input_index : input_index_map_) {
       torch::jit::IValue ival = (*input_tensors)[input_index.second];
-      dict.insert(input_index.first, ival.toTensor());
+      input_dict.insert(input_index.first, ival.toTensor());
     }
-    inputs.push_back(dict);
+    std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
+    model_outputs_ = torch_model_->forward(input_dict_ivalue);
   } else {
-    for (auto& input_tensor : *input_tensors) {
-      inputs.push_back(input_tensor.toTensor());
-    }
+    model_outputs_ = torch_model_->forward(*input_tensors);
   }

-  // Actually run the method on the model.
-  model_outputs_ = torch_model_->get_method(module_method_name)(inputs);
-
   if (model_outputs_.isTuple()) {
     auto model_outputs_tuple = model_outputs_.toTuple();
     size_t op_index = 0;
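For reference, Module::forward(inputs) is shorthand for get_method("forward") invoked with the same argument vector, and a Dict of tensors is itself a single IValue argument, which is why the dict branch above wraps it in a one-element vector. A minimal sketch with placeholder model path and shapes (a real model would accept only one of these signatures):

    // Sketch: the call styles this hunk switches between.
    #include <torch/script.h>

    int main() {
      torch::jit::Module module = torch::jit::load("model.pt");  // placeholder

      // Positional-tensor path: forward() delegates to get_method("forward").
      std::vector<torch::jit::IValue> inputs;
      inputs.push_back(torch::ones({1, 3}));  // placeholder input
      torch::jit::IValue out1 = module.forward(inputs);
      torch::jit::IValue out2 = module.get_method("forward")(inputs);

      // Dictionary path: the whole Dict is one IValue, so it becomes a
      // single-element argument vector, as in the '+' lines above.
      torch::Dict<std::string, torch::Tensor> dict;
      dict.insert("INPUT0", torch::ones({1, 3}));
      torch::jit::IValue out3 = module.forward({dict});
      return 0;
    }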