Commit 95128b2
Revert "Support calling custom method names via MODULE_METHOD_NAME (fixes triton-inference-server/server#5209) (#127)" (#128)
This reverts commit 7b63f0f.
1 parent 7b63f0f · commit 95128b2

2 files changed: +7 -72 lines

Diff for: README.md (-14 lines)

````diff
@@ -217,20 +217,6 @@ key: "INTRA_OP_THREAD_COUNT"
     }
     ```
 
-* `MODULE_METHOD_NAME`:
-
-    String flag to specify which method on the PyTorch model is being called.
-    Default value is `forward`.
-
-    ```
-    parameters: {
-        key: "MODULE_METHOD_NAME"
-        value: {
-            string_value:"custom_method"
-        }
-    }
-    ```
-
 * Additional Optimizations: Three additional boolean parameters are available to disable
   certain Torch optimizations that can sometimes cause latency regressions in models with
   complex execution modes and dynamic shapes. If not specified, all are enabled by default.
````
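
For reference, the removed `MODULE_METHOD_NAME` parameter let a model expose an entry point other than `forward`. A minimal standalone LibTorch sketch of the idea (not the backend's code; `model.pt` and `custom_method` are hypothetical names):

```cpp
// Call a named TorchScript method instead of the default forward().
// Assumes a hypothetical model.pt that scripts a module with custom_method.
#include <torch/script.h>

#include <iostream>
#include <vector>

int main()
{
  torch::jit::script::Module module = torch::jit::load("model.pt");

  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({2, 3}));

  // get_method() throws c10::Error if the method does not exist.
  torch::jit::IValue output = module.get_method("custom_method")(inputs);
  std::cout << output.toTensor() << std::endl;
  return 0;
}
```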

Diff for: src/libtorch.cc (+7 -58 lines)
```diff
@@ -61,8 +61,6 @@
 // https://github.com/pytorch/pytorch/blob/v2.2.1-rc3/aten/src/ATen/Parallel.h#L133
 #include <ATen/Parallel.h>
 
-// Default forward method to call on PyTorch modules
-const std::string DEFAULT_MODULE_METHOD_NAME = "forward";
 
 //
 // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
@@ -113,7 +111,6 @@ class ModelState : public BackendModel {
   {
     return model_outputs_;
   }
-  const std::string& ModuleMethodName() { return module_method_name_; }
 
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
@@ -156,10 +153,6 @@ class ModelState : public BackendModel {
   // is specified both in the output section and state section, it indicates
   // that the backend must return the output state to the client too.
   std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;
-
-  // Method to call on PyTorch Module.
-  // Defaults to DEFAULT_MODULE_METHOD_NAME.
-  std::string module_method_name_;
 };
 
 TRITONSERVER_Error*
@@ -237,8 +230,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       enable_inference_mode_(true), enable_cache_cleaning_(false),
       enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
-      enable_jit_executor_pair_({false, true}),
-      module_method_name_(DEFAULT_MODULE_METHOD_NAME)
+      enable_jit_executor_pair_({false, true})
 {
 }
 
@@ -527,30 +519,6 @@ ModelState::ParseParameters()
             .c_str());
      }
    }
-
-    // If 'MODULE_METHOD_NAME' is not present in 'parameters' then
-    // 'module_method_name_' is set to 'DEFAULT_MODULE_METHOD_NAME' ('forward').
-    std::string module_method_name = DEFAULT_MODULE_METHOD_NAME;
-    err = GetParameterValue(params, "MODULE_METHOD_NAME", &module_method_name);
-    if (err != nullptr) {
-      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
-        return err;
-      } else {
-        LOG_MESSAGE(
-            TRITONSERVER_LOG_INFO,
-            (std::string("module_method_name is not specified") +
-             " for model instance '" + Name() + "'")
-                .c_str());
-        TRITONSERVER_ErrorDelete(err);
-      }
-    } else {
-      module_method_name_ = module_method_name;
-      LOG_MESSAGE(
-          TRITONSERVER_LOG_INFO,
-          (std::string("module_method_name is ") + module_method_name_ +
-           " for model instance '" + Name() + "'")
-              .c_str());
-    }
   }
 
   return nullptr;
```
```diff
@@ -972,20 +940,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   // configuration specifies only those.
   std::vector<std::string> allowed_inputs;
 
-  // First check if method exists in the model and throw an error if absent
-  const auto methodNameToExecute = model_state_->ModuleMethodName();
-  const auto optionalMethodHandle =
-      torch_model_->find_method(methodNameToExecute);
-  if (!optionalMethodHandle.has_value()) {
-    return TRITONSERVER_ErrorNew(
-        TRITONSERVER_ERROR_INVALID_ARG,
-        (std::string("unable to find method '") + methodNameToExecute +
-         "' in model '" + model_path_ + "'")
-            .c_str());
-  }
-
-  // Get the method schema and validate the inputs
-  const torch::jit::Method& method = optionalMethodHandle.value();
+  const torch::jit::Method& method = torch_model_->get_method("forward");
   const auto& schema = method.function().getSchema();
   const std::vector<c10::Argument>& arguments = schema.arguments();
 
```
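Note the lookup change in this hunk: the reverted code used `find_method`, which returns an empty `c10::optional` for a missing method and so let the backend return its own `TRITONSERVER_ERROR_INVALID_ARG` naming the model path, while the restored `get_method("forward")` throws `c10::Error` if the method is absent. A small illustrative sketch of the optional-based style (`describe_method` is a hypothetical helper, not backend code):

```cpp
// Probe a TorchScript module for a method without throwing on absence.
#include <torch/script.h>

#include <iostream>
#include <string>

void describe_method(
    torch::jit::script::Module& module, const std::string& name)
{
  // find_method() returns c10::optional<torch::jit::Method>.
  if (auto method = module.find_method(name)) {
    const auto& schema = method->function().getSchema();
    // The argument count includes the implicit 'self' argument.
    std::cout << name << " declares " << schema.arguments().size()
              << " arguments" << std::endl;
  } else {
    std::cout << "method '" << name << "' not found" << std::endl;
  }
}
```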
```diff
@@ -1628,24 +1583,18 @@ ModelInstanceState::Execute(
   torch::NoGradGuard no_grad;
 
   // If input is a dictionary, prepare dictionary from 'input_tensors'.
-  std::string module_method_name = model_state_->ModuleMethodName();
-  std::vector<c10::IValue> inputs;
   if (is_dict_input_) {
-    c10::Dict<std::string, at::Tensor> dict;
+    torch::Dict<std::string, torch::Tensor> input_dict;
     for (auto& input_index : input_index_map_) {
       torch::jit::IValue ival = (*input_tensors)[input_index.second];
-      dict.insert(input_index.first, ival.toTensor());
+      input_dict.insert(input_index.first, ival.toTensor());
     }
-    inputs.push_back(dict);
+    std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
+    model_outputs_ = torch_model_->forward(input_dict_ivalue);
   } else {
-    for (auto& input_tensor : *input_tensors) {
-      inputs.push_back(input_tensor.toTensor());
-    }
+    model_outputs_ = torch_model_->forward(*input_tensors);
   }
 
-  // Actually run the method on the model.
-  model_outputs_ = torch_model_->get_method(module_method_name)(inputs);
-
   if (model_outputs_.isTuple()) {
     auto model_outputs_tuple = model_outputs_.toTuple();
     size_t op_index = 0;
```
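On the restored path, `Module::forward(inputs)` is just the `get_method("forward")(inputs)` call spelled directly, and for dict-input models the whole dictionary is passed as a single `IValue`. A minimal sketch of the two branches (illustrative only; `run_forward` and the `INPUT0` key are hypothetical):

```cpp
// Invoke forward() with either a dict input or a plain tensor input.
#include <torch/script.h>

#include <vector>

torch::jit::IValue run_forward(
    torch::jit::script::Module& module, bool is_dict_input)
{
  if (is_dict_input) {
    // The scripted forward() takes one Dict[str, Tensor] argument.
    torch::Dict<std::string, torch::Tensor> input_dict;
    input_dict.insert("INPUT0", torch::ones({2, 3}));
    std::vector<torch::jit::IValue> inputs = {input_dict};
    return module.forward(inputs);
  }
  // The scripted forward() takes positional tensor arguments.
  std::vector<torch::jit::IValue> inputs = {torch::ones({2, 3})};
  return module.forward(inputs);
}
```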