Skip to content

Commit

Permalink
Merge branch 'master' into dict_construct
Browse files Browse the repository at this point in the history
  • Loading branch information
peri044 committed Apr 8, 2022
2 parents d7d1511 + 8ec296c commit 43a53ce
Show file tree
Hide file tree
Showing 28 changed files with 1,023 additions and 189 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ These are the following dependencies used to verify the testcases. Torch-TensorR
- Bazel 4.2.1
- Libtorch 1.10.0 (built with CUDA 11.3)
- CUDA 11.3 (10.2 on Jetson)
- cuDNN 8.2
- TensorRT 8.0.3.4 (TensorRT 8.0.1.6 on Jetson)
- cuDNN 8.2.1
- TensorRT 8.2.4.2 (TensorRT 8.2.1 on Jetson)

## Prebuilt Binaries and Wheel files

Expand Down
6 changes: 3 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ http_archive(
http_archive(
name = "tensorrt",
build_file = "@//third_party/tensorrt/archive:BUILD",
sha256 = "da130296ac6636437ff8465812eb55dbab0621747d82dc4fe9b9376f00d214af",
strip_prefix = "TensorRT-8.2.2.1",
sha256 = "826180eaaecdf9a7e76116855b9f1f3400ea9b06e66b06a3f6a0747ba6f863ad",
strip_prefix = "TensorRT-8.2.4.2",
urls = [
"https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.2.2.1/tars/tensorrt-8.2.2.1.linux.x86_64-gnu.cuda-11.4.cudnn8.2.tar.gz",
"https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.2.4/tars/tensorrt-8.2.4.2.linux.x86_64-gnu.cuda-11.4.cudnn8.2.tar.gz",
],
)

Expand Down
13 changes: 6 additions & 7 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,10 @@ void MapInputsAndDetermineDTypes(
ss << "- Disable partial compilation by setting require_full_compilation to True";
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
}
// Overwrite type map with user settings
// We use this map for partitioning since we need c10::ScalarTypes not nvinfer::DataTypes
first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
}
} else {
// The user defined the type so no changes are necessary
Expand Down Expand Up @@ -417,18 +418,16 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
if (cfg.partition_info.enabled &&
(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partition_info.forced_fallback_operators.size() == 0 &&
conversion::VerifyConverterSupportForBlock(g->block(), true))) {
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
LOG_INFO("Skipping partitioning since model is fully supported");
}

if (cfg.partition_info.enabled &&
!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partition_info.forced_fallback_operators.size() == 0 &&
conversion::VerifyConverterSupportForBlock(g->block(), true))) {
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params);
new_g = graph_and_mapping.first;
Expand Down
24 changes: 17 additions & 7 deletions core/conversion/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
// Node input has not been converted yet or is a prim op
TORCHTRT_THROW_ERROR(
"Unable to retrieve all node inputs for node: "
<< util::node_info(n) << " (ctx.AddLayer)\nSpecifically failed to retrieve value for input: " << *input_node);
<< util::node_info(n) << " (ctx.AddLayer)\nSpecifically failed to retrieve value for input: %"
<< input->debugName());
}
}

Expand Down Expand Up @@ -533,18 +534,22 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_er
if (unsupported_ops.size() != 0) {
std::stringstream unsupported_msg;
unsupported_msg
<< "Method requested cannot be compiled by Torch-TensorRT.TorchScript.\nUnsupported operators listed below:"
<< "Method requested cannot be compiled end to end by Torch-TensorRT.TorchScript.\nUnsupported operators listed below:"
<< std::endl;
for (auto s : unsupported_ops) {
unsupported_msg << " - " << s.second << std::endl;
}
unsupported_msg << "You can either implement converters for these ops in your application or request implementation"
<< std::endl;
unsupported_msg << "https://www.github.com/nvidia/Torch-TensorRT/issues" << std::endl;
unsupported_msg << std::endl << "In Module:" << std::endl;

if (!suppress_errors) {
unsupported_msg
<< "You can either implement converters for these ops in your application or request implementation"
<< std::endl;
unsupported_msg << "https://www.github.com/nvidia/Torch-TensorRT/issues" << std::endl;
unsupported_msg << std::endl << "In Module:" << std::endl;

LOG_ERROR(unsupported_msg.str());
} else {
LOG_INFO(unsupported_msg.str());
}

std::unordered_map<std::string, std::unordered_set<std::string>> unsupported_node_locations;
Expand All @@ -570,8 +575,13 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_er
for (const auto& str : type.second) {
traceback << str;
}

auto tb_str = traceback.str();
LOG_ERROR(tb_str);
if (!suppress_errors) {
LOG_ERROR(tb_str);
} else {
LOG_DEBUG(tb_str);
}
}

return false;
Expand Down
28 changes: 18 additions & 10 deletions core/conversion/converters/impl/batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,35 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
auto orig_shape = input->getDimensions();
auto shape = util::toVec(orig_shape);
auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
auto options = torch::TensorOptions().dtype(tensor_type);
auto options =
torch::TensorOptions().dtype(tensor_type).device(torch::kCUDA, ctx->settings.device.gpu_id);

torch::Tensor gamma, beta, mean, var;
LOG_DEBUG("Input :" << orig_shape << "/" << input->getType());
// affine=True
LOG_DEBUG("Args[1] gamma : " << args[1].isIValue() << " / " << args[1].IValue()->isNone());
LOG_DEBUG("Args[2] beta : " << args[2].isIValue() << " / " << args[2].IValue()->isNone());
// track_running_stats=True
LOG_DEBUG("Args[3] mean : " << args[3].isIValue() << " / " << args[3].IValue()->isNone());
LOG_DEBUG("Args[4] var : " << args[4].isIValue() << " / " << args[4].IValue()->isNone());
LOG_DEBUG("use_input_stats, momemtum, cudnn_enabled disregarded");
LOG_DEBUG("ctx->input_is_dynamic : " << ctx->input_is_dynamic);

auto channel_dim = shape[1];
if (ctx->input_is_dynamic) {
gamma = args[1].unwrapToTensor();
beta = args[2].unwrapToTensor();
gamma = args[1].unwrapToTensor(at::full(channel_dim, 1, options));
beta = args[2].unwrapToTensor(at::full(channel_dim, 0, options));
mean = args[3].unwrapToTensor();
var = args[4].unwrapToTensor();
} else {
gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
gamma = args[1].unwrapToTensor(at::full(channel_dim, 1, options));
beta = args[2].unwrapToTensor(at::full(channel_dim, 0, options));
mean = args[3].unwrapToTensor(at::full(channel_dim, 0, options));
var = args[4].unwrapToTensor(at::full(channel_dim, 0, options));
}

auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));

LOG_DEBUG("momentum disregarded");
LOG_DEBUG("training disregarded");
LOG_DEBUG("cudnn disregarded");
TORCHTRT_CHECK(orig_shape.nbDims >= 2, "Unable to create batch normalization layer from node: " << *n);

// Expand spatial dims from 1D to 2D if needed
Expand Down
91 changes: 91 additions & 0 deletions core/conversion/converters/impl/select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,34 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s
return true;
}

// Emulates torch.roll along a single axis of a statically shaped tensor by
// gathering the elements of that axis in rotated order.
//
// ctx      - conversion context; the new layer is added to ctx->net
// in       - tensor to rotate
// shift    - number of positions to rotate by; may be negative or larger than
//            the axis extent
// dim      - axis to rotate; used directly as an index into in_shape, so the
//            caller must pass a normalized (non-negative, in-range) value
// in_shape - dimensions of `in`
//
// Returns the output tensor of the layer that was added to the network.
nvinfer1::ITensor* roll(
    ConversionCtx* ctx,
    nvinfer1::ITensor* in,
    int shift,
    int dim,
    const std::vector<int64_t>& in_shape) {
  auto in_dim = in_shape[dim];

  // A zero-extent axis has nothing to rotate, and the modulo below would
  // divide by zero. Pass the tensor through an identity layer so the caller
  // still receives a layer output, as it does on the regular path.
  if (in_dim == 0) {
    auto identity_layer = ctx->net->addIdentity(*in);
    return identity_layer->getOutput(0);
  }

  // Index of the input element that ends up at output position 0
  auto start = (in_dim - shift) % in_dim;
  // Behavior of % is different in C++ vs Python for negative numbers. This
  // corrects the difference.
  if (start < 0) {
    start = start + in_dim;
  }

  // Gather indices are [start, in_dim) followed by [0, start)
  at::Tensor index0 = at::arange(start, in_dim, 1, torch::kInt32);
  at::Tensor index;
  if (start == 0) {
    index = index0;
  } else {
    at::Tensor index1 = at::arange(start, torch::kInt32);
    index = at::cat({index0, index1}, 0);
  }
  auto index_tensor = tensor_to_const(ctx, index);
  auto gather_layer = ctx->net->addGather(*in, *index_tensor, dim);
  auto out = gather_layer->getOutput(0);
  return out;
}

auto select_registrations TORCHTRT_UNUSED =
RegisterNodeConversionPatterns()
.pattern({"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))",
Expand Down Expand Up @@ -200,6 +228,69 @@ auto select_registrations TORCHTRT_UNUSED =

LOG_DEBUG("Output tensor shape: " << out->getDimensions());

return true;
}})
.pattern({"aten::roll(Tensor self, int[1] shifts, int[1] dims=[]) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // Converter for aten::roll: applies roll() once per (shift, dim) pair,
            // feeding each result into the next rotation.
            auto rolled = args[0].ITensor();
            auto shifts = args[1].unwrapToIntList().vec();
            auto dims = args[2].unwrapToIntList().vec();

            TORCHTRT_CHECK(dims.size() == shifts.size(), "dims.size() should be equal to shifts.size()");

            // Static shapes are the only supported case
            if (!ctx->input_is_dynamic) {
              auto in_shape = util::toVec(rolled->getDimensions());
              for (size_t pair_idx = 0; pair_idx < dims.size(); ++pair_idx) {
                // Negative axes count from the end
                auto dim = dims[pair_idx] < 0 ? (in_shape.size() + dims[pair_idx]) : dims[pair_idx];
                TORCHTRT_CHECK(dim < in_shape.size(), "Dimension out of range");
                rolled = roll(ctx, rolled, shifts[pair_idx], dim, in_shape);
              }

              auto out = ctx->AssociateValueAndTensor(n->outputs()[0], rolled);
              LOG_DEBUG("Output tensor shape: " << out->getDimensions());
              return true;
            }

            TORCHTRT_THROW_ERROR("aten::roll is currently not support in dynamic input shape compilation");
          }})
.pattern(
    {"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       // Converter for aten::index.Tensor, implemented as a gather along dimension 0.
       auto self = args[0].ITensorOrFreeze(ctx);
       auto index_list = args[1].IValue()->toListRef();

       // Turn every entry of the index list into an ITensor: entries that are not
       // torch tensors are read as TensorContainers, torch tensors become constants.
       std::vector<nvinfer1::ITensor*> tensors;
       for (const auto& entry : index_list) {
         if (!entry.isTensor()) {
           auto container = entry.toCustomClass<TensorContainer>();
           tensors.push_back(container->tensor());
         } else {
           auto as_torch = entry.toTensor();
           tensors.push_back(tensor_to_const(ctx, as_torch));
         }
       }

       // TorchScript allows one index tensor per dimension of self. This converter
       // accepts exactly one index tensor and applies it to dimension 0.
       TORCHTRT_CHECK(
           tensors.size() == 1,
           "In this version of Torch-TensorRT, aten::index.Tensor can only receive one index tensor which means it only indexes the self tensor along dimension 0.");

       // Route the indices through an identity layer whose output type is set to INT32
       auto cast_layer = ctx->net->addIdentity(*tensors[0]);
       cast_layer->setOutputType(0, nvinfer1::DataType::kINT32);
       auto index_i32 = cast_layer->getOutput(0);

       // Gather from self along axis 0 using the INT32 indices
       auto gather_layer = ctx->net->addGather(*self, *index_i32, 0);
       TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);

       auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_layer->getOutput(0));

       LOG_DEBUG("Output tensor shape: " << out->getDimensions());
       return true;
     }})
.pattern(
Expand Down
9 changes: 8 additions & 1 deletion core/conversion/converters/impl/stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@ auto stack_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patt
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].IValue()->toListRef();
auto dim = args[1].unwrapToInt();
if (-1 == dim) {
auto first_in = in[0];
if (first_in.isTensor()) {
dim = first_in.toTensor().ndimension();
} else {
dim = first_in.toCustomClass<TensorContainer>()->tensor()->getDimensions().nbDims;
}
}

std::vector<nvinfer1::ITensor*> tensors;

for (auto t : in) {
nvinfer1::ITensor* itensor;

Expand Down
13 changes: 13 additions & 0 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <math.h>

#include "ATen/core/List.h"
#include "ATen/core/functional.h"
#include "ATen/core/ivalue.h"
Expand Down Expand Up @@ -98,6 +100,17 @@ DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
"aten::ge.float_int(float a, int b) -> (bool)",
}));

// Registers the evaluator for aten::pow on scalar operands.
// DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR unwraps the two node inputs into locals
// named `a` and `b`; the `pow(a, b)` expression below is evaluated with those
// locals and its result is returned from the evaluator. Only the schemas listed
// here are declared valid for this evaluator.
DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(
pow,
"aten::pow",
pow(a, b),
std::set<std::string>({
"aten::pow.int(int a, int b) -> (float)",
"aten::pow.float(float a, float b) -> (float)",
"aten::pow.int_float(int a, float b) -> (float)",
"aten::pow.float_int(float a, int b) -> (float)",
}));

DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
and,
"aten::__and__",
Expand Down
47 changes: 47 additions & 0 deletions core/conversion/evaluators/eval_macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,53 @@
}, \
EvalOptions().validSchemas(schemas)});

/*
 * DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas)
 *
 * Defines `<name>_registrations`, which registers an evaluator for the node
 * kind `node_kind` (a qualified string such as "aten::pow") restricted to
 * `schemas` via EvalOptions().validSchemas.
 *
 * The evaluator dispatches on the runtime type of the two node inputs:
 *   - input 0 must be an int or a double and is unwrapped into a local `a`
 *   - input 1 must be an int, a double or a bool and is unwrapped into a local `b`
 * and then returns `operation`, an expression that is expected to refer to
 * `a` and `b`.
 *
 * Any other input type is reported through TORCHTRT_THROW_ERROR with the type
 * name of the offending argument. The `return {};` statements that follow
 * those calls are presumably unreachable (assuming TORCHTRT_THROW_ERROR
 * throws -- TODO confirm) and yield an empty optional otherwise.
 *
 * NOTE: keep comments out of the macro body or use block comments there; a
 * line comment would swallow the backslash-continued lines that follow it.
 */
#define DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \
auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
{c10::Symbol::fromQualString(node_kind), \
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
if (args.at(n->input(0)).IValue()->isInt()) { \
auto a = args.at(n->input(0)).unwrapToInt(); \
if (args.at(n->input(1)).IValue()->isInt()) { \
auto b = args.at(n->input(1)).unwrapToInt(); \
return operation; \
} else if (args.at(n->input(1)).IValue()->isDouble()) { \
auto b = args.at(n->input(1)).unwrapToDouble(); \
return operation; \
} else if (args.at(n->input(1)).IValue()->isBool()) { \
auto b = args.at(n->input(1)).unwrapToBool(); \
return operation; \
} else { \
TORCHTRT_THROW_ERROR( \
"Unimplemented data type for " \
<< node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
return {}; \
} \
} else if (args.at(n->input(0)).IValue()->isDouble()) { \
auto a = args.at(n->input(0)).unwrapToDouble(); \
if (args.at(n->input(1)).IValue()->isInt()) { \
auto b = args.at(n->input(1)).unwrapToInt(); \
return operation; \
} else if (args.at(n->input(1)).IValue()->isDouble()) { \
auto b = args.at(n->input(1)).unwrapToDouble(); \
return operation; \
} else if (args.at(n->input(1)).IValue()->isBool()) { \
auto b = args.at(n->input(1)).unwrapToBool(); \
return operation; \
} else { \
TORCHTRT_THROW_ERROR( \
"Unimplemented data type for " \
<< node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
return {}; \
} \
} else { \
TORCHTRT_THROW_ERROR( \
"Unimplemented data type for " \
<< node_kind << " evaluator a arg: " << args.at(n->input(0)).IValue()->type()->str()); \
return {}; \
} \
}, \
EvalOptions().validSchemas(schemas)});

#define DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(node_kind, node_name, operation, type, schemas) \
auto node_kind##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
{c10::Symbol::fromQualString(node_name), \
Expand Down
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
passes::EliminateExceptionOrPassPattern(g);
passes::ReduceToOperation(g);
passes::ReduceGelu(g);
passes::ReduceRemainder(g);
passes::RemoveContiguous(g);
passes::ViewToReshape(g);
passes::RemoveDropout(g);
Expand Down
Loading

0 comments on commit 43a53ce

Please sign in to comment.