Skip to content

Commit

Permalink
feat: Enable sparsity support in TRTorch
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
  • Loading branch information
peri044 authored and narendasan committed Aug 7, 2021
1 parent 2f23d6e commit f9e1f2b
Show file tree
Hide file tree
Showing 11 changed files with 32 additions and 8 deletions.
4 changes: 4 additions & 0 deletions core/conversion/conversionctx/ConversionCtx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
cfg->clearFlag(nvinfer1::BuilderFlag::kTF32);
}

if (settings.sparse_weights) {
cfg->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
}

if (settings.refit) {
cfg->setFlag(nvinfer1::BuilderFlag::kREFIT);
}
Expand Down
2 changes: 1 addition & 1 deletion core/conversion/conversionctx/ConversionCtx.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct Device {

struct BuilderSettings {
std::set<nvinfer1::DataType> enabled_precisions = {nvinfer1::DataType::kFLOAT};
std::vector<nvinfer1::DataType> input_dtypes;
bool sparse_weights = false;
bool disable_tf32 = false;
bool refit = false;
bool debug = false;
Expand Down
5 changes: 5 additions & 0 deletions cpp/api/include/trtorch/trtorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,11 @@ struct TRTORCH_API CompileSpec {
*/
bool disable_tf32 = false;

/**
* Enable sparsity for weights of conv and FC layers
*/
bool sparse_weights = false;

/**
* Build a refittable engine
*/
Expand Down
1 change: 1 addition & 0 deletions cpp/api/src/compile_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
}
}

internal.convert_info.engine_settings.sparse_weights = external.sparse_weights;
internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
internal.convert_info.engine_settings.refit = external.refit;
internal.convert_info.engine_settings.debug = external.debug;
Expand Down
11 changes: 6 additions & 5 deletions cpp/trtorchexec/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ int main(int argc, const char* argv[]) {

auto compile_spec = trtorch::CompileSpec(dims);
compile_spec.workspace_size = 1 << 24;
compile_spec.sparse_weights = true;

std::cout << "Checking operator support" << std::endl;
if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
std::cerr << "Method is not currently supported by TRTorch" << std::endl;
return -1;
}
// std::cout << "Checking operator support" << std::endl;
// if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
// std::cerr << "Method is not currently supported by TRTorch" << std::endl;
// return -1;
// }

std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec);
Expand Down
8 changes: 6 additions & 2 deletions py/trtorch/_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
if "calibrator" in compile_spec:
info.ptq_calibrator = compile_spec["calibrator"]

if "sparse_weights" in compile_spec:
assert isinstance(compile_spec["sparse_weights"], bool)
info.sparse_weights = compile_spec["sparse_weights"]

if "disable_tf32" in compile_spec:
assert isinstance(compile_spec["disable_tf32"], bool)
info.disable_tf32 = compile_spec["disable_tf32"]
Expand Down Expand Up @@ -282,8 +286,8 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
"dla_core": 0, # (DLA only) Target dla core id to run engine
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
},
"op_precision": torch.half, # Operating precision set to FP16
# List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
"enabled_precisions": {torch.half}, # Operating precision set to FP16
"sparse_weights": False, # Enable sparsity for convolution and fully connected layers
"disable_tf32": False, # Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas
"refit": False, # enable refit
"debug": False, # enable debuggable engine
Expand Down
3 changes: 3 additions & 0 deletions py/trtorch/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
},
"enabled_precisions": {torch.float, torch.half}, # Enabling FP16 kernels
"disable_tf32": False, # Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas
"sparse_weights": False, # Enable sparsity for convolution and fully connected layers
"refit": false, # enable refit
"debug": false, # enable debuggable engine
"strict_types": false, # kernels should strictly run in operating precision
Expand Down Expand Up @@ -113,6 +115,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
"enabled_precisions": {torch.float, torch.half}, # Enabling FP16 kernels
# List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
"disable_tf32": False, # Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas
"sparse_weights": False, # Enable sparsity for convolution and fully connected layers
"refit": false, # enable refit
"debug": false, # enable debuggable engine
"strict_types": false, # kernels should strictly run in operating precision
Expand Down
1 change: 1 addition & 0 deletions py/trtorch/csrc/register_tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ void RegisterTRTCompileSpec() {
.def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);

ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, sparse_weights);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, disable_tf32);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, debug);
Expand Down
2 changes: 2 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
}

info.convert_info.engine_settings.calibrator = ptq_calibrator;
info.convert_info.engine_settings.sparse_weights = sparse_weights;
info.convert_info.engine_settings.disable_tf32 = disable_tf32;
info.convert_info.engine_settings.refit = refit;
info.convert_info.engine_settings.debug = debug;
Expand Down Expand Up @@ -222,6 +223,7 @@ std::string CompileSpec::stringify() {
}
ss << " ]" << std::endl;
ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl;
ss << " \"Sparsity\": " << sparse_weights << std::endl;
ss << " \"Refit\": " << refit << std::endl;
ss << " \"Debug\": " << debug << std::endl;
ss << " \"Strict Types\": " << strict_types << std::endl;
Expand Down
2 changes: 2 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ struct CompileSpec : torch::CustomClassHolder {
}

ADD_FIELD_GET_SET(disable_tf32, bool);
ADD_FIELD_GET_SET(sparse_weights, bool);
ADD_FIELD_GET_SET(refit, bool);
ADD_FIELD_GET_SET(debug, bool);
ADD_FIELD_GET_SET(strict_types, bool);
Expand All @@ -155,6 +156,7 @@ struct CompileSpec : torch::CustomClassHolder {
std::vector<Input> inputs;
nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
std::set<DataType> enabled_precisions = {DataType::kFloat};
bool sparse_weights = false;
bool disable_tf32 = false;
bool refit = false;
bool debug = false;
Expand Down
1 change: 1 addition & 0 deletions py/trtorch/csrc/trtorch_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ PYBIND11_MODULE(_C, m) {
.def_readwrite("enabled_precisions", &CompileSpec::enabled_precisions)
.def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator)
.def_readwrite("refit", &CompileSpec::refit)
.def_readwrite("sparse_weights", &CompileSpec::sparse_weights)
.def_readwrite("disable_tf32", &CompileSpec::disable_tf32)
.def_readwrite("debug", &CompileSpec::debug)
.def_readwrite("strict_types", &CompileSpec::strict_types)
Expand Down

0 comments on commit f9e1f2b

Please sign in to comment.