Skip to content

Commit

Permalink
[TENSORRT] Improvements and fixes for TensorRT (#11203)
Browse files Browse the repository at this point in the history
A number of small fixes and refactors to improve the robustness of
the TensorRT integration.

Co-authored-by: Mark Shields <mbs@octoml.ai>

Co-authored-by: Mark Shields <mbs@octoml.ai>
  • Loading branch information
mbaret and mbs-octoml authored May 10, 2022
1 parent 8d4f4dd commit be2ae94
Show file tree
Hide file tree
Showing 12 changed files with 901 additions and 1,105 deletions.
1,123 changes: 511 additions & 612 deletions python/tvm/relay/op/contrib/tensorrt.py

Large diffs are not rendered by default.

178 changes: 124 additions & 54 deletions src/relay/backend/contrib/tensorrt/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,51 +70,28 @@ class TensorRTCompilerConfig : public Attrs {
TVM_REGISTER_NODE_TYPE(TensorRTCompilerConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.tensorrt.options", TensorRTCompilerConfig);

using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr;
using OpAttrExtractor = backend::contrib::OpAttrExtractor;
using JSONSerializer = backend::contrib::JSONSerializer;

class TensorRTJSONSerializer;

/*!
* \brief Generates an TensorRTModule from a relay expression by serializing the expression to a
* json representation. TensorRT is not required here because use of TensorRT APIs is deferred until
* runtime.
* \brief Collect the constants and attributes from all operator calls in the body
* of a "Composite" function.
*/
class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;

class CollectFromCompositeFunctionBody : public ExprVisitor {
public:
TensorRTJSONSerializer(const std::string& symbol, const Expr& expr)
: JSONSerializer(symbol, expr) {}

std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* cn) {
std::string name;
if (const auto* op_node = cn->op.as<OpNode>()) {
name = op_node->name;
} else {
return JSONSerializer::VisitExpr_(cn);
}
explicit CollectFromCompositeFunctionBody(TensorRTJSONSerializer* serializer)
: serializer_(serializer), node_(std::make_shared<JSONGraphNode>()) {}

std::vector<JSONGraphNodeEntry> inputs;
for (const auto& arg : cn->args) {
auto res = VisitExpr(arg);
inputs.insert(inputs.end(), res.begin(), res.end());
}
auto node = std::make_shared<JSONGraphNode>(name, /* name_ */
"kernel", /* op_type_ */
inputs, 1 /* num_outputs_ */);
if (name == "nn.pad") {
SetPadNodeAttribute(node, cn);
} else if (name == "strided_slice") {
SetStridedSliceNodeAttribute(node, cn);
} else if (name == "split") {
SetSplitNodeAttribute(node, cn);
} else {
SetCallNodeAttribute(node, cn);
}
// These attributes are global to the whole module.
SaveGlobalAttributes(node);
return AddNode(node, GetRef<Expr>(cn));
}
void VisitExpr_(const ConstantNode* constant_node) final;
void VisitExpr_(const CallNode* call_node) final;

void SetPadNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* cn) {
const auto* pad_attr = cn->attrs.as<PadAttrs>();
void SetPadNodeAttribute(const CallNode* call_node) {
const auto* pad_attr = call_node->attrs.as<PadAttrs>();
ICHECK(pad_attr);
auto p = pad_attr->pad_width;
const int dim_h = (p.size() == 5) ? 3 : 2;
Expand All @@ -125,16 +102,16 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
std::to_string(p[dim_w][1].as<IntImmNode>()->value)};
std::vector<dmlc::any> padding_attr;
padding_attr.emplace_back(padding);
node->SetAttr("padding", padding_attr);
node_->SetAttr("padding", padding_attr);
}

void SetStridedSliceNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* cn) {
const auto* attrs = cn->attrs.as<StridedSliceAttrs>();
void SetStridedSliceNodeAttribute(const CallNode* call_node) {
const auto* attrs = call_node->attrs.as<StridedSliceAttrs>();
ICHECK(attrs && attrs->begin && attrs->end && attrs->strides)
<< "StridedSlice must have static begin, end, and strides.";
const bool default_strides =
!attrs->strides.value().defined() || attrs->strides.value().size() == 0;
auto ishape = backend::GetShape(cn->args[0]->checked_type());
auto ishape = backend::GetShape(call_node->args[0]->checked_type());

auto process_slice_index = [](Integer x, int default_value, int dim_value) {
if (!x.defined()) return default_value;
Expand Down Expand Up @@ -173,19 +150,19 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
start_attr.emplace_back(start);
size_attr.emplace_back(size);
strides_attr.emplace_back(strides);
node->SetAttr("start", start_attr);
node->SetAttr("size", size_attr);
node->SetAttr("strides", strides_attr);
node_->SetAttr("start", start_attr);
node_->SetAttr("size", size_attr);
node_->SetAttr("strides", strides_attr);
}

void SetSplitNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* cn) {
const auto* split_attr = cn->attrs.as<SplitAttrs>();
void SetSplitNodeAttribute(const CallNode* call_node) {
const auto* split_attr = call_node->attrs.as<SplitAttrs>();
ICHECK(split_attr);

std::vector<std::string> indices_or_sections;
std::vector<std::string> mode;
std::vector<std::string> axis = {std::to_string(split_attr->axis)};
if (const IntImmNode* sections = split_attr->indices_or_sections.as<IntImmNode>()) {
if (const auto* sections = split_attr->indices_or_sections.as<IntImmNode>()) {
mode.emplace_back("sections");
indices_or_sections.emplace_back(std::to_string(sections->value));
} else {
Expand All @@ -202,12 +179,80 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
indices_or_sections_attr.emplace_back(indices_or_sections);
mode_attr.emplace_back(mode);
axis_attr.emplace_back(axis);
node->SetAttr("indices_or_sections", indices_or_sections_attr);
node->SetAttr("mode", mode_attr);
node->SetAttr("axis", axis_attr);
node_->SetAttr("indices_or_sections", indices_or_sections_attr);
node_->SetAttr("mode", mode_attr);
node_->SetAttr("axis", axis_attr);
}

void SetGenericAttributes(const CallNode* call_node) {
OpAttrExtractor extractor(node_);
const Object* attr_obj = call_node->attrs.get();
extractor.Extract(const_cast<Object*>(attr_obj));
}

void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
TensorRTJSONSerializer* serializer_;
/*! \brief Accumulated translated arguments. */
std::vector<JSONGraphNodeEntry> args_;
/*!
* \brief Temporary node into which we'll accumulate attributes. Ideally this would be the
* final JSONGraphNode however we don't yet know how many inputs that will have.
*/
JSONGraphObjectPtr node_;
};

/*!
* \brief Generates an TensorRTModule from a relay expression by serializing the expression to a
* json representation. TensorRT is not required here because use of TensorRT APIs is deferred until
* runtime.
*/
class TensorRTJSONSerializer : public JSONSerializer {
public:
TensorRTJSONSerializer(const std::string& symbol, const Expr& expr)
: JSONSerializer(symbol, expr) {}

using JSONSerializer::VisitExpr_;

std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* call_node) final {
// The call must be to an inline "Composite" function
const auto* function_node = call_node->op.as<FunctionNode>();
ICHECK(function_node != nullptr);
auto opt_composite = function_node->GetAttr<String>(attr::kComposite);
ICHECK(opt_composite.defined());
std::string name = opt_composite.value();

// Collect the constants and attributes of all operator calls inside the composite body.
CollectFromCompositeFunctionBody collector(this);
collector.VisitExpr(function_node->body);

// Capture the args to the "Composite" function as inputs for this node.
std::vector<JSONGraphNodeEntry> inputs;
for (const auto& arg : call_node->args) {
auto res = VisitExpr(arg);
inputs.insert(inputs.end(), res.begin(), res.end());
}

// Capture constants from the composite function body as additional inputs for this node.
for (const auto& node : collector.args_) {
inputs.emplace_back(node);
}

// Create the final node.
auto node = std::make_shared<JSONGraphNode>(name,
/*op_type=*/"kernel", inputs,
/*num_output=*/1);

// Transfer attributes from the collector's node to the final node.
node->CaptureAttrs(*collector.node_);

// Capture global settings on the JSON node.
SaveGlobalAttributes(node);

VLOG(1) << name << " has " << node->GetInputs().size() << " inputs";

return AddNode(node, GetRef<Expr>(call_node));
}

static void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
auto ctx = transform::PassContext::Current();
auto cfg = ctx->GetConfig<TensorRTCompilerConfig>("relay.ext.tensorrt.options");
if (!cfg.defined()) {
Expand Down Expand Up @@ -236,6 +281,28 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
}
};

void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode* constant_node) {
for (const auto& entry : serializer_->VisitExpr(GetRef<Constant>(constant_node))) {
args_.emplace_back(entry);
}
}

void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) {
const auto* op_node = call_node->op.as<OpNode>();
ICHECK(op_node != nullptr);
std::string name = op_node->name;
if (name == "nn.pad") {
SetPadNodeAttribute(call_node);
} else if (name == "strided_slice") {
SetStridedSliceNodeAttribute(call_node);
} else if (name == "split") {
SetSplitNodeAttribute(call_node);
} else {
SetGenericAttributes(call_node);
}
ExprVisitor::VisitExpr_(call_node);
}

/*!
* \brief Create a runtime module for TensorRT.
* \param ref The ext_func Relay expression/module to be executed using extern ops.
Expand All @@ -246,12 +313,15 @@ runtime::Module TensorRTCompiler(const ObjectRef& ref) {
Function func = Downcast<Function>(ref);
std::string func_name = backend::GetExtSymbol(func);

VLOG(1) << "TensorRT partition:" << std::endl << PrettyPrint(func);
TensorRTJSONSerializer serializer(func_name, func);
serializer.serialize();
std::string graph_json = serializer.GetJSON();
VLOG(1) << "TensorRT JSON:" << std::endl << graph_json;
auto param_names = serializer.GetParams();
const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create");
ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function.";
VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'";
runtime::Module lib = (*pf)(func_name, graph_json, param_names);
return lib;
}
Expand Down
94 changes: 0 additions & 94 deletions src/relay/transforms/inline_composites.cc

This file was deleted.

6 changes: 6 additions & 0 deletions src/runtime/contrib/json/json_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,12 @@ class JSONGraphNode {
*/
bool HasAttr(const std::string& key) const { return attrs_.find(key) != attrs_.end(); }

void CaptureAttrs(const JSONGraphNode& that) {
for (const auto& kv : that.attrs_) {
attrs_[kv.first] = kv.second;
}
}

virtual ~JSONGraphNode() {}

private:
Expand Down
Loading

0 comments on commit be2ae94

Please sign in to comment.