Commit 9f2ffd0

feat: Add aten::full converter, quantization ops testcases
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
1 parent 339919d commit 9f2ffd0

10 files changed: +145 -54 lines changed

core/conversion/conversionctx/ConversionCtx.cpp (+3 -2)

@@ -71,8 +71,9 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
       }
       input_type = nvinfer1::DataType::kFLOAT;
       // Networks trained with Quantization aware training approach don't need a calibrator as they have Q/DQ nodes.
-      if (!settings.calibrator){
-        LOG_WARNING("Int8 precision has been enabled but no calibrator provided. This assumes the network has Q/DQ nodes obtained from Quantization aware training. For more details, refer to https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work-with-qat-networks");
+      if (!settings.calibrator) {
+        LOG_WARNING(
+            "Int8 precision has been enabled but no calibrator provided. This assumes the network has Q/DQ nodes obtained from Quantization aware training. For more details, refer to https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work-with-qat-networks");
       }
       break;
     case nvinfer1::DataType::kFLOAT:
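For context, a minimal sketch (not part of this commit) of the compile path that reaches this warning: enabling INT8 on a QAT-trained TorchScript module through the TRTorch C++ API without attaching a calibrator, as the modified trtorchexec harness further below also does via compile_spec.op_precision = torch::kChar. The input shape, function name, and surrounding structure in the sketch are illustrative assumptions.

  // Sketch: compile a scripted module `mod` that already contains Q/DQ nodes (QAT)
  // and takes a single 1x3x224x224 input. Illustrative only, not the committed code.
  #include "torch/script.h"
  #include "trtorch/trtorch.h"

  torch::jit::Module compile_qat_module(torch::jit::Module& mod) {
    auto compile_spec = trtorch::CompileSpec({{1, 3, 224, 224}});
    compile_spec.op_precision = torch::kChar;  // INT8; deliberately no PTQ calibrator,
                                               // which is what triggers the warning above
    return trtorch::CompileGraph(mod, compile_spec);
  }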

core/conversion/converters/impl/constant.cpp (+13)

@@ -1,3 +1,4 @@
+#include <torch/torch.h>
 #include "core/conversion/converters/converters.h"
 #include "core/util/prelude.h"

@@ -25,6 +26,18 @@ auto constant_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()

                   LOG_DEBUG("Output tensor shape: " << const_out->getDimensions());

+                  return true;
+                }})
+        .pattern({"aten::full(int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    auto size = args[0].unwrapToIntList();
+                    auto scalar = args[1].unwrapToScalar().to<float>();
+                    auto scalar_tensor = torch::full({5}, scalar);
+                    auto full_tensor = tensor_to_const(ctx, scalar_tensor);
+                    auto output = ctx->AssociateValueAndTensor(n->outputs()[0], full_tensor);
+
+                    LOG_DEBUG("Output tensor shape: " << output->getDimensions());
+
                   return true;
                 }});
 // clang-format on
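As committed, the aten::full converter unwraps the size argument but builds the fill tensor with a fixed shape of {5}, which happens to match the per-channel quantization test graph added below. A hypothetical variant that derives the shape from size could look like the following; the std::vector conversion of the unwrapped list is an assumption about the helper's return type, so treat this as a sketch rather than the committed implementation.

  // Hypothetical shape-driven variant of the converter body above (not the committed code).
  auto size = args[0].unwrapToIntList();
  std::vector<int64_t> dims(size.begin(), size.end());  // assumes the unwrapped list iterates as int64_t
  auto scalar = args[1].unwrapToScalar().to<float>();
  auto full_tensor = tensor_to_const(ctx, torch::full(dims, scalar));
  auto output = ctx->AssociateValueAndTensor(n->outputs()[0], full_tensor);
  LOG_DEBUG("Output tensor shape: " << output->getDimensions());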

core/conversion/converters/impl/quantization.cpp (+11 -9)

@@ -13,18 +13,18 @@ namespace {
 auto quantization_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     .pattern({"aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor)",
               [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                // This aten operator is generated from torch.fake_quantize_per_tensor_affine op in Pytorch python API.
+                // Example usage: https://github.com/pytorch/pytorch/blob/master/torch/quantization/fake_quantize.py#L145
                 auto input = args[0].ITensorOrFreeze(ctx);
                 auto scale = args[1].unwrapToScalar().to<float>();
                 auto scaleTensor = tensor_to_const(ctx, torch::tensor({scale}));
-
                 // Add and configure a QuantizeLayer.
                 nvinfer1::IQuantizeLayer* quantize_layer = ctx->net->addQuantize(*input, *scaleTensor);
-                // Set an invalid axis
-                quantize_layer->setAxis(1);
+                quantize_layer->setAxis(0);

-                // Add and configure DequantizeLayer
+                // Add and configure DequantizeLayer following a QuantizeLayer
                 nvinfer1::IDequantizeLayer* dequantize_layer = ctx->net->addDequantize(*quantize_layer->getOutput(0), *scaleTensor);
-                dequantize_layer->setAxis(1);
+                dequantize_layer->setAxis(0);

                 auto qdq_out = ctx->AssociateValueAndTensor(n->outputs()[0], dequantize_layer->getOutput(0));
                 LOG_DEBUG("[fake_quantize_per_tensor_affine] Output tensor shape: " << qdq_out->getDimensions());
@@ -33,17 +33,19 @@ auto quantization_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
               }})
     .pattern({"aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor)",
               [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                // This aten operator is generated from torch.fake_quantize_per_channel_affine op in Pytorch python API.
+                // Example usage: https://github.com/pytorch/pytorch/blob/master/torch/quantization/fake_quantize.py#L141
                 auto input = args[0].ITensorOrFreeze(ctx);
                 auto scale = args[1].ITensorOrFreeze(ctx);
-
+                int64_t axis = args[3].unwrapToScalar().to<int64_t>();
                 // Add and configure a QuantizeLayer.
                 nvinfer1::IQuantizeLayer* quantize_layer = ctx->net->addQuantize(*input, *scale);
-                // Set a channel axis=0 which represents output channels
-                quantize_layer->setAxis(0);
+                // Set a channel axis which represents output channels
+                quantize_layer->setAxis(axis);

                 // Add and configure a DequantizeLayer.
                 nvinfer1::IDequantizeLayer* dequantize_layer = ctx->net->addDequantize(*quantize_layer->getOutput(0), *scale);
-                dequantize_layer->setAxis(0);
+                dequantize_layer->setAxis(axis);
                 auto qdq_out = ctx->AssociateValueAndTensor(n->outputs()[0], dequantize_layer->getOutput(0));

                 LOG_DEBUG("[fake_quantize_per_channel_affine] Ouput tensor shape: " << qdq_out->getDimensions());

core/plugins/impl/interpolate_plugin.cpp (-1)

@@ -206,7 +206,6 @@ bool InterpolatePlugin::supportsFormatCombination(
     const nvinfer1::PluginTensorDesc* inOut,
     int nbInputs,
     int nbOutputs) noexcept {
-
   if (nbInputs != 1) {
     LOG_ERROR("Expected a single tensor as input to interpolate plugin");
   }

core/util/jit_util.h (+2 -1)

@@ -13,7 +13,8 @@ inline std::string node_info(const torch::jit::Node* n) {
   std::stringstream ss;
   ss << *n;
   std::string node_info = ss.str();
-  // Nodes in torchscript graph have file name and line numbers commented for every node. Remove that when returning a node name for easier readability.
+  // Nodes in torchscript graph have file name and line numbers commented for every node. Remove that when returning a
+  // node name for easier readability.
   node_info = node_info.substr(0, node_info.find("#", 0));
   node_info.erase(std::remove(node_info.begin(), node_info.end(), '\n'), node_info.end());
   return node_info;

cpp/trtorchexec/main.cpp (+44 -39)

@@ -56,54 +56,59 @@ int main(int argc, const char* argv[]) {
   }

   auto compile_spec = trtorch::CompileSpec(dims);
+  // compile_spec.torch_fallback = trtorch::CompileSpec::TorchFallback(true);
   compile_spec.workspace_size = 1 << 24;
-
-  std::cout << "Checking operator support" << std::endl;
-  if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
-    std::cerr << "Method is not currently supported by TRTorch" << std::endl;
-    return -1;
-  }
-
-  std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
+  compile_spec.op_precision = torch::kChar;
+  // compile_spec.input_dtypes = {torch::kInt32, torch::kInt32};
+  // std::cout << "===Compile Spec: " << compile_spec << std::endl;
+  // compile_spec.torch_fallback = trtorch::CompileSpec::TorchFallback(true);
+  // compile_spec.torch_fallback.min_block_size = 1;
+  // std::cout << "Checking operator support" << std::endl;
+  // if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
+  //   std::cerr << "Method is not currently supported by TRTorch" << std::endl;
+  //   return -1;
+  // }
+  //
+  // std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
   auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec);
   std::ofstream out("/tmp/engine_converted_from_jit.trt");
   out << engine;
   out.close();

-  std::vector<torch::jit::IValue> jit_inputs_ivalues;
-  std::vector<torch::jit::IValue> trt_inputs_ivalues;
-  auto in = at::randint(5, dims[0], {at::kCUDA});
-  jit_inputs_ivalues.push_back(in.clone());
-  trt_inputs_ivalues.push_back(in.clone());
-
-  torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
-  std::vector<at::Tensor> jit_results;
-  if (jit_results_ivalues.isTensor()) {
-    jit_results.push_back(jit_results_ivalues.toTensor());
-  } else {
-    auto results = jit_results_ivalues.toTuple()->elements();
-    for (auto r : results) {
-      jit_results.push_back(r.toTensor());
-    }
-  }
+  // std::vector<torch::jit::IValue> jit_inputs_ivalues;
+  // std::vector<torch::jit::IValue> trt_inputs_ivalues;
+  // auto in = at::randint(5, dims[0], {at::kCUDA});
+  // jit_inputs_ivalues.push_back(in.clone());
+  // trt_inputs_ivalues.push_back(in.clone());
+  // //
+  // torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
+  // std::vector<at::Tensor> jit_results;
+  // if (jit_results_ivalues.isTensor()) {
+  //   jit_results.push_back(jit_results_ivalues.toTensor());
+  // } else {
+  //   auto results = jit_results_ivalues.toTuple()->elements();
+  //   for (auto r : results) {
+  //     jit_results.push_back(r.toTensor());
+  //   }
+  // }

   std::cout << "Compiling graph as module" << std::endl;
   auto trt_mod = trtorch::CompileGraph(mod, compile_spec);
-  std::cout << "Running TRT module" << std::endl;
-  torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
-  std::vector<at::Tensor> trt_results;
-  if (trt_results_ivalues.isTensor()) {
-    trt_results.push_back(trt_results_ivalues.toTensor());
-  } else {
-    auto results = trt_results_ivalues.toTuple()->elements();
-    for (auto r : results) {
-      trt_results.push_back(r.toTensor());
-    }
-  }
-
-  for (size_t i = 0; i < trt_results.size(); i++) {
-    almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]));
-  }
+  // std::cout << "Running TRT module" << std::endl;
+  // torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
+  // std::vector<at::Tensor> trt_results;
+  // if (trt_results_ivalues.isTensor()) {
+  //   trt_results.push_back(trt_results_ivalues.toTensor());
+  // } else {
+  //   auto results = trt_results_ivalues.toTuple()->elements();
+  //   for (auto r : results) {
+  //     trt_results.push_back(r.toTensor());
+  //   }
+  // }
+  //
+  // for (size_t i = 0; i < trt_results.size(); i++) {
+  //   almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]));
+  // }

   std::cout << "Converted Engine saved to /tmp/engine_converted_from_jit.trt" << std::endl;

tests/core/conversion/converters/BUILD (+4)

@@ -59,6 +59,10 @@ converter_test(
     name = "test_pooling",
 )

+converter_test(
+    name = "test_quantization",
+)
+
 converter_test(
     name = "test_reduce",
 )
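Once this target is registered, the new tests can be run on a machine with CUDA and TensorRT available with a command along the lines of bazel test //tests/core/conversion/converters:test_quantization; the exact target label is an assumption based on how the converter_test macro names targets in this package.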
tests/core/conversion/converters/test_quantization.cpp (new file, +63)

@@ -0,0 +1,63 @@
+#include <string>
+#include "NvInfer.h"
+#include "core/compiler.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+
+TEST(Converters, ATenFakeQuantizePerTensorConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %7 : int = prim::Constant[value=-128]()
+        %3 : float = prim::Constant[value=6.]()
+        %4 : int = prim::Constant[value=0]()
+        %8 : int = prim::Constant[value=127]()
+        %quant_input.1 : Tensor = aten::fake_quantize_per_tensor_affine(%x.1, %3, %4, %7, %8)
+        return (%quant_input.1))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA}).to(at::kFloat);
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}, nvinfer1::DataType::kINT8);
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenFakeQuantizePerChannelConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %22 : int = prim::Constant[value=-128]()
+        %14 : int = prim::Constant[value=4]()
+        %9 : None = prim::Constant()
+        %35 : Device = prim::Constant[value="cuda:0"]()
+        %6 : int = prim::Constant[value=6]()
+        %3 : int = prim::Constant[value=5]()
+        %5 : float = prim::Constant[value=3.5]()
+        %13 : int = prim::Constant[value=1]()
+        %23 : int = prim::Constant[value=127]()
+        %4 : int[] = prim::ListConstruct(%3)
+        %11 : Tensor = aten::full(%4, %5, %6, %9, %35, %9)
+        %12 : int[] = prim::ListConstruct(%3)
+        %19 : Tensor = aten::full(%12, %13, %14, %9, %35, %9)
+        %quant_input.1 : Tensor = aten::fake_quantize_per_channel_affine(%x.1, %11, %19, %13, %22, %23)
+        return (%quant_input.1))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 5, 3, 3}, {at::kCUDA});
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}, nvinfer1::DataType::kINT8);
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
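Note how the per-channel graph materializes its scale and zero_point inputs with aten::full, which is what the new converter in constant.cpp handles, and how both tests request INT8 engines through the op_precision argument added to RunGraphEngine in the test utilities below.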

tests/util/run_graph_engine.cpp (+3 -1)

@@ -63,11 +63,13 @@ std::vector<at::Tensor> RunEngine(std::string& eng, std::vector<at::Tensor> inputs
 std::vector<at::Tensor> RunGraphEngine(
     std::shared_ptr<torch::jit::Graph>& g,
     core::conversion::GraphParams& named_params,
-    std::vector<at::Tensor> inputs) {
+    std::vector<at::Tensor> inputs,
+    nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT) {
   LOG_DEBUG("Running TRT version");
   auto in = toInputRanges(inputs);
   auto info = core::conversion::ConversionInfo(in);
   info.engine_settings.workspace_size = 1 << 20;
+  info.engine_settings.op_precision = op_precision;
   std::string eng = core::conversion::ConvertBlockToEngine(g->block(), info, named_params);
   return RunEngine(eng, inputs);
 }

tests/util/util.h (+2 -1)

@@ -28,7 +28,8 @@ std::vector<at::Tensor> RunGraph(
 std::vector<at::Tensor> RunGraphEngine(
     std::shared_ptr<torch::jit::Graph>& g,
     core::conversion::GraphParams& named_params,
-    std::vector<at::Tensor> inputs);
+    std::vector<at::Tensor> inputs,
+    nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT);

 // Runs an arbitrary JIT graph with dynamic input sizes by converting it to
 // TensorRT and running inference and returns results