feat(//py): Gate partial compilation from to_backend API
We can't run partial compilation on modules compiled through the to_backend API
because to_backend expects us to simply return a handle to a TRT engine rather
than a full graph, so we cannot do graph stitching. An exception is now thrown
if someone tries to use fallback together with to_backend, directing them
towards trtorch.compile

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
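
For context, a minimal sketch of how this gating surfaces to users. The spec keys ("input_shapes", "torch_fallback") and the toy module are illustrative assumptions based on the trtorch API of this era, not taken from this diff:

import torch
import torch.nn as nn
import trtorch  # assumes a trtorch build that includes this commit

class TinyModel(nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

script_model = torch.jit.script(TinyModel().eval())

spec = {
    "input_shapes": [[1, 3, 224, 224]],
    "torch_fallback": {"enabled": True, "min_block_size": 1},
}

# trtorch.compile returns a stitched TorchScript graph, so partial compilation is allowed:
# trt_mod = trtorch.compile(script_model, spec)

# The to_backend path must hand back a single TRT engine, so the same spec is now rejected:
try:
    backend_spec = trtorch.TensorRTCompileSpec(spec)
except RuntimeError as err:
    print(err)  # points the user at trtorch.compile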
narendasan committed May 11, 2021
1 parent 0a3258d commit bf1b2d8
Showing 5 changed files with 107 additions and 72 deletions.
65 changes: 34 additions & 31 deletions py/trtorch/_compile_spec.py
@@ -259,37 +259,40 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
backend_spec = torch.classes.tensorrt.CompileSpec()

for i in parsed_spec.input_ranges:
ir = torch.classes.tensorrt.InputRange()
ir.set_min(i.min)
ir.set_opt(i.opt)
ir.set_max(i.max)
backend_spec.append_input_range(ir)

d = torch.classes.tensorrt.Device()
d.set_device_type(int(parsed_spec.device.device_type))
d.set_gpu_id(parsed_spec.device.gpu_id)
d.set_dla_core(parsed_spec.device.dla_core)
d.set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback)

torch_fallback = torch.classes.tensorrt.TorchFallback()
torch_fallback.set_enabled(parsed_spec.torch_fallback.enabled)
torch_fallback.set_min_block_size(parsed_spec.torch_fallback.min_block_size)
torch_fallback.set_forced_fallback_operators(parsed_spec.torch_fallback.forced_fallback_operators)

backend_spec.set_device(d)
backend_spec.set_torch_fallback(fallback)
backend_spec.set_op_precision(int(parsed_spec.op_precision))
backend_spec.set_disable_tf32(parsed_spec.disable_tf32)
backend_spec.set_refit(parsed_spec.refit)
backend_spec.set_debug(parsed_spec.debug)
backend_spec.set_refit(parsed_spec.refit)
backend_spec.set_strict_types(parsed_spec.strict_types)
backend_spec.set_capability(int(parsed_spec.capability))
backend_spec.set_num_min_timing_iters(parsed_spec.num_min_timing_iters)
backend_spec.set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
backend_spec.set_workspace_size(parsed_spec.workspace_size)
backend_spec.set_max_batch_size(parsed_spec.max_batch_size)
backend_spec.set_truncate_long_and_double(parsed_spec.truncate_long_and_double)
ir = torch.classes.tensorrt._InputRange()
ir._set_min(i.min)
ir._set_opt(i.opt)
ir._set_max(i.max)
backend_spec._append_input_range(ir)

d = torch.classes.tensorrt._Device()
d._set_device_type(int(parsed_spec.device.device_type))
d._set_gpu_id(parsed_spec.device.gpu_id)
d._set_dla_core(parsed_spec.device.dla_core)
d._set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback)

if parsed_spec.torch_fallback.enabled:
raise RuntimeError("Partial module compilation is not currently supported via the PyTorch to_backend API integration. If you need partial compilation, use trtorch.compile")

torch_fallback = torch.classes.tensorrt._TorchFallback()
torch_fallback._set_enabled(parsed_spec.torch_fallback.enabled)
torch_fallback._set_min_block_size(parsed_spec.torch_fallback.min_block_size)
torch_fallback._set_forced_fallback_operators(parsed_spec.torch_fallback.forced_fallback_operators)

backend_spec._set_device(d)
backend_spec._set_torch_fallback(torch_fallback)
backend_spec._set_op_precision(int(parsed_spec.op_precision))
backend_spec._set_disable_tf32(parsed_spec.disable_tf32)
backend_spec._set_refit(parsed_spec.refit)
backend_spec._set_debug(parsed_spec.debug)
backend_spec._set_refit(parsed_spec.refit)
backend_spec._set_strict_types(parsed_spec.strict_types)
backend_spec._set_capability(int(parsed_spec.capability))
backend_spec._set_num_min_timing_iters(parsed_spec.num_min_timing_iters)
backend_spec._set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
backend_spec._set_workspace_size(parsed_spec.workspace_size)
backend_spec._set_max_batch_size(parsed_spec.max_batch_size)
backend_spec._set_truncate_long_and_double(parsed_spec.truncate_long_and_double)
backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle())

return backend_spec
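
For reference, a sketch of where the backend_spec built above ends up. The torch._C._jit_to_backend call and the method-name dict follow the TRTorch to_backend examples of this era and are assumptions, not part of this diff; fallback must stay disabled here or TensorRTCompileSpec now raises:

import torch
import trtorch

scripted = torch.jit.script(torch.nn.ReLU().eval())

to_backend_spec = {
    "forward": trtorch.TensorRTCompileSpec({
        "input_shapes": [[1, 3, 300, 300]],
        # no "torch_fallback": {"enabled": True, ...} here; that now raises RuntimeError
    })
}

# Hands the whole forward method to the TensorRT backend as a single engine.
trt_module = torch._C._jit_to_backend("tensorrt", scripted, to_backend_spec)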
23 changes: 15 additions & 8 deletions py/trtorch/csrc/register_tensorrt_classes.cpp
@@ -5,37 +5,44 @@ namespace backend {
namespace {

#define ADD_FIELD_GET_SET_REGISTRATION(registry, class_name, field_name) \
(registry).def("set_" #field_name, &class_name::set_##field_name); \
(registry).def("get_" #field_name, &class_name::get_##field_name);
(registry).def("_set_" #field_name, &class_name::set_##field_name); \
(registry).def("_get_" #field_name, &class_name::get_##field_name);

void RegisterTRTCompileSpec() {
static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
torch::class_<trtorch::pyapi::InputRange>("tensorrt", "InputRange").def(torch::init<>());
torch::class_<trtorch::pyapi::InputRange>("tensorrt", "_InputRange")
.def(torch::init<>())
.def("__str__", &trtorch::pyapi::InputRange::to_str);

ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, min);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, opt);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, max);

static auto TRTORCH_UNUSED TRTDeviceTSRegistration =
torch::class_<trtorch::pyapi::Device>("tensorrt", "Device").def(torch::init<>());
torch::class_<trtorch::pyapi::Device>("tensorrt", "_Device")
.def(torch::init<>())
.def("__str__", &trtorch::pyapi::Device::to_str);

ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, device_type);
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, gpu_id);
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core);
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback);

static auto TRTORCH_UNUSED TRTFallbackTSRegistration =
torch::class_<trtorch::pyapi::TorchFallback>("tensorrt", "Fallback").def(torch::init<>());
torch::class_<trtorch::pyapi::TorchFallback>("tensorrt", "_TorchFallback")
.def(torch::init<>())
.def("__str__", &trtorch::pyapi::TorchFallback::to_str);

ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, enabled);
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, min_block_size);
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_operators);

static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
.def(torch::init<>())
.def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
.def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
.def("set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive)
.def("_append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
.def("_set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
.def("_set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive)
.def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);

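To illustrate the effect of the underscore rename, here is how the re-registered TorchBind classes would be driven from Python. The class and method names come from the registrations above, but the snippet itself is a sketch of internal API, not something user code should rely on:

import torch
import trtorch  # importing trtorch loads the extension that registers the tensorrt::* classes

ir = torch.classes.tensorrt._InputRange()
ir._set_min([1, 3, 224, 224])
ir._set_opt([1, 3, 512, 512])
ir._set_max([1, 3, 1024, 1024])
print(ir)   # new __str__ binding -> InputRange::to_str()

dev = torch.classes.tensorrt._Device()
dev._set_gpu_id(0)
dev._set_allow_gpu_fallback(False)
print(dev)  # new __str__ binding -> Device::to_str()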
78 changes: 47 additions & 31 deletions py/trtorch/csrc/tensorrt_classes.cpp
@@ -4,7 +4,7 @@
namespace trtorch {
namespace pyapi {

std::string to_str(InputRange& value) {
std::string InputRange::to_str() {
auto vec_to_str = [](std::vector<int64_t> shape) -> std::string {
std::stringstream ss;
ss << '[';
@@ -17,9 +17,9 @@ std::string to_str(InputRange& value) {

std::stringstream ss;
ss << " {" << std::endl;
ss << " min: " << vec_to_str(value.min) << ',' << std::endl;
ss << " opt: " << vec_to_str(value.opt) << ',' << std::endl;
ss << " max: " << vec_to_str(value.max) << ',' << std::endl;
ss << " min: " << vec_to_str(min) << ',' << std::endl;
ss << " opt: " << vec_to_str(opt) << ',' << std::endl;
ss << " max: " << vec_to_str(max) << ',' << std::endl;
ss << " }" << std::endl;
return ss.str();
}
@@ -68,6 +68,18 @@ nvinfer1::DeviceType toTRTDeviceType(DeviceType value) {
}
}

std::string Device::to_str() {
std::stringstream ss;
std::string fallback = allow_gpu_fallback ? "True" : "False";
ss << " {" << std::endl;
ss << " \"device_type\": " << pyapi::to_str(device_type) << std::endl;
ss << " \"allow_gpu_fallback\": " << fallback << std::endl;
ss << " \"gpu_id\": " << gpu_id << std::endl;
ss << " \"dla_core\": " << dla_core << std::endl;
ss << " }" << std::endl;
return ss.str();
}

std::string to_str(EngineCapability value) {
switch (value) {
case EngineCapability::kSAFE_GPU:
@@ -92,6 +92,21 @@ nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) {
}
}

std::string TorchFallback::to_str() {
std::stringstream ss;
std::string e = enabled ? "True" : "False";
ss << " {" << std::endl;
ss << " \"enabled\": " << e << std::endl;
ss << " \"min_block_size\": " << min_block_size << std::endl;
ss << " \"forced_fallback_operators\": [" << std::endl;
for (auto i : forced_fallback_operators) {
ss << " " << i << ',' << std::endl;
}
ss << " ]" << std::endl;
ss << " }" << std::endl;
return ss.str();
}

core::CompileSpec CompileSpec::toInternalCompileSpec() {
std::vector<core::ir::InputRange> internal_input_ranges;
for (auto i : input_ranges) {
@@ -128,36 +155,25 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
std::string CompileSpec::stringify() {
std::stringstream ss;
ss << "TensorRT Compile Spec: {" << std::endl;
ss << " \"Input Shapes\": [" << std::endl;
ss << " \"Input Shapes\": [" << std::endl;
for (auto i : input_ranges) {
ss << to_str(i);
ss << i.to_str();
}
std::string enabled = torch_fallback.enabled ? "True" : "False";
ss << " ]" << std::endl;
ss << " \"Op Precision\": " << to_str(op_precision) << std::endl;
ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl;
ss << " \"Refit\": " << refit << std::endl;
ss << " \"Debug\": " << debug << std::endl;
ss << " \"Strict Types\": " << strict_types << std::endl;
ss << " \"Device Type: " << to_str(device.device_type) << std::endl;
ss << " \"GPU ID: " << device.gpu_id << std::endl;
ss << " \"DLA Core: " << device.dla_core << std::endl;
ss << " \"Allow GPU Fallback\": " << device.allow_gpu_fallback << std::endl;
ss << " \"Engine Capability\": " << to_str(capability) << std::endl;
ss << " \"Num Min Timing Iters\": " << num_min_timing_iters << std::endl;
ss << " \"Num Avg Timing Iters\": " << num_avg_timing_iters << std::endl;
ss << " \"Workspace Size\": " << workspace_size << std::endl;
ss << " \"Max Batch Size\": " << max_batch_size << std::endl;
ss << " \"Truncate long and double\": " << truncate_long_and_double << std::endl;
ss << " \"Torch Fallback: {" << std::endl;
ss << " \"enabled\": " << enabled << std::endl;
ss << " \"min_block_size\": " << torch_fallback.min_block_size << std::endl;
ss << " \"forced_fallback_operators\": [" << std::endl;
for (auto i : torch_fallback.forced_fallback_operators) {
ss << " " << i << ',' << std::endl;
}
ss << " ]" << std::endl;
ss << " }" << std::endl;
ss << " ]" << std::endl;
ss << " \"Op Precision\": " << to_str(op_precision) << std::endl;
ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl;
ss << " \"Refit\": " << refit << std::endl;
ss << " \"Debug\": " << debug << std::endl;
ss << " \"Strict Types\": " << strict_types << std::endl;
ss << " \"Device\": " << device.to_str() << std::endl;
ss << " \"Engine Capability\": " << to_str(capability) << std::endl;
ss << " \"Num Min Timing Iters\": " << num_min_timing_iters << std::endl;
ss << " \"Num Avg Timing Iters\": " << num_avg_timing_iters << std::endl;
ss << " \"Workspace Size\": " << workspace_size << std::endl;
ss << " \"Max Batch Size\": " << max_batch_size << std::endl;
ss << " \"Truncate long and double\": " << truncate_long_and_double << std::endl;
ss << " \"Torch Fallback\": " << torch_fallback.to_str();
ss << "}";
return ss.str();
}
9 changes: 7 additions & 2 deletions py/trtorch/csrc/tensorrt_classes.h
@@ -39,9 +39,9 @@ struct InputRange : torch::CustomClassHolder {
ADD_FIELD_GET_SET(min, std::vector<int64_t>);
ADD_FIELD_GET_SET(opt, std::vector<int64_t>);
ADD_FIELD_GET_SET(max, std::vector<int64_t>);
};

std::string to_str(InputRange& value);
std::string to_str();
};

enum class DataType : int8_t {
kFloat,
@@ -73,6 +73,8 @@ struct Device : torch::CustomClassHolder {
ADD_FIELD_GET_SET(gpu_id, int64_t);
ADD_FIELD_GET_SET(dla_core, int64_t);
ADD_FIELD_GET_SET(allow_gpu_fallback, bool);

std::string to_str();
};

std::string to_str(DeviceType value);
@@ -87,8 +89,11 @@ struct TorchFallback : torch::CustomClassHolder {
ADD_FIELD_GET_SET(enabled, bool);
ADD_FIELD_GET_SET(min_block_size, int64_t);
ADD_FIELD_GET_SET(forced_fallback_operators, std::vector<std::string>);

std::string to_str();
};


enum class EngineCapability : int8_t {
kDEFAULT,
kSAFE_GPU,
4 changes: 4 additions & 0 deletions py/trtorch/csrc/trtorch_py.cpp
@@ -165,6 +165,7 @@ void log(core::util::logging::LogLevel lvl, const std::string& msg) {
PYBIND11_MODULE(_C, m) {
py::class_<InputRange>(m, "InputRange")
.def(py::init<>())
.def("__str__", &trtorch::pyapi::InputRange::to_str)
.def_readwrite("min", &InputRange::min)
.def_readwrite("opt", &InputRange::opt)
.def_readwrite("max", &InputRange::max);
@@ -237,6 +238,7 @@ PYBIND11_MODULE(_C, m) {

py::class_<CompileSpec>(m, "CompileSpec")
.def(py::init<>())
.def("__str__", &trtorch::pyapi::CompileSpec::stringify)
.def("_get_calibrator_handle", &CompileSpec::getPTQCalibratorHandle, "[Internal] gets a handle from a calibrator")
.def_readwrite("input_ranges", &CompileSpec::input_ranges)
.def_readwrite("op_precision", &CompileSpec::op_precision)
@@ -256,13 +258,15 @@ PYBIND11_MODULE(_C, m) {

py::class_<Device>(m, "Device")
.def(py::init<>())
.def("__str__", &trtorch::pyapi::Device::to_str)
.def_readwrite("device_type", &Device::device_type)
.def_readwrite("gpu_id", &Device::gpu_id)
.def_readwrite("dla_core", &Device::dla_core)
.def_readwrite("allow_gpu_fallback", &Device::allow_gpu_fallback);

py::class_<TorchFallback>(m, "TorchFallback")
.def(py::init<>())
.def("__str__", &trtorch::pyapi::TorchFallback::to_str)
.def_readwrite("enabled", &TorchFallback::enabled)
.def_readwrite("min_block_size", &TorchFallback::min_block_size)
.def_readwrite("forced_fallback_operators", &TorchFallback::forced_fallback_operators);
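On the pybind11 side, the added __str__ bindings let the pure-Python structs print themselves. A sketch, assuming the extension module is importable as trtorch._C (per PYBIND11_MODULE(_C, m) above); the field names come from the def_readwrite bindings:

import trtorch._C as _C

fallback = _C.TorchFallback()
fallback.enabled = True
fallback.min_block_size = 3
fallback.forced_fallback_operators = ["aten::max_pool2d"]
print(fallback)  # TorchFallback.__str__ -> to_str()

ir = _C.InputRange()
ir.min = [1, 3, 224, 224]
ir.opt = [1, 3, 512, 512]
ir.max = [1, 3, 1024, 1024]
print(ir)        # InputRange.__str__ -> to_str()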
