Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TENSORRT] Improvements and fixes for TensorRT #11203

Merged
merged 1 commit into from
May 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,123 changes: 511 additions & 612 deletions python/tvm/relay/op/contrib/tensorrt.py

Large diffs are not rendered by default.

178 changes: 124 additions & 54 deletions src/relay/backend/contrib/tensorrt/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,51 +70,28 @@ class TensorRTCompilerConfig : public Attrs {
TVM_REGISTER_NODE_TYPE(TensorRTCompilerConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.tensorrt.options", TensorRTCompilerConfig);

using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr;
using OpAttrExtractor = backend::contrib::OpAttrExtractor;
using JSONSerializer = backend::contrib::JSONSerializer;

class TensorRTJSONSerializer;

/*!
* \brief Generates an TensorRTModule from a relay expression by serializing the expression to a
* json representation. TensorRT is not required here because use of TensorRT APIs is deferred until
* runtime.
* \brief Collect the constants and attributes from all operator calls in the body
* of a "Composite" function.
*/
class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;

class CollectFromCompositeFunctionBody : public ExprVisitor {
public:
TensorRTJSONSerializer(const std::string& symbol, const Expr& expr)
: JSONSerializer(symbol, expr) {}

std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* cn) {
std::string name;
if (const auto* op_node = cn->op.as<OpNode>()) {
name = op_node->name;
} else {
return JSONSerializer::VisitExpr_(cn);
}
explicit CollectFromCompositeFunctionBody(TensorRTJSONSerializer* serializer)
: serializer_(serializer), node_(std::make_shared<JSONGraphNode>()) {}

std::vector<JSONGraphNodeEntry> inputs;
for (const auto& arg : cn->args) {
auto res = VisitExpr(arg);
inputs.insert(inputs.end(), res.begin(), res.end());
}
auto node = std::make_shared<JSONGraphNode>(name, /* name_ */
"kernel", /* op_type_ */
inputs, 1 /* num_outputs_ */);
if (name == "nn.pad") {
SetPadNodeAttribute(node, cn);
} else if (name == "strided_slice") {
SetStridedSliceNodeAttribute(node, cn);
} else if (name == "split") {
SetSplitNodeAttribute(node, cn);
} else {
SetCallNodeAttribute(node, cn);
}
// These attributes are global to the whole module.
SaveGlobalAttributes(node);
return AddNode(node, GetRef<Expr>(cn));
}
void VisitExpr_(const ConstantNode* constant_node) final;
void VisitExpr_(const CallNode* call_node) final;

void SetPadNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* cn) {
const auto* pad_attr = cn->attrs.as<PadAttrs>();
void SetPadNodeAttribute(const CallNode* call_node) {
const auto* pad_attr = call_node->attrs.as<PadAttrs>();
ICHECK(pad_attr);
auto p = pad_attr->pad_width;
const int dim_h = (p.size() == 5) ? 3 : 2;
Expand All @@ -125,16 +102,16 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
std::to_string(p[dim_w][1].as<IntImmNode>()->value)};
std::vector<dmlc::any> padding_attr;
padding_attr.emplace_back(padding);
node->SetAttr("padding", padding_attr);
node_->SetAttr("padding", padding_attr);
}

void SetStridedSliceNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* cn) {
const auto* attrs = cn->attrs.as<StridedSliceAttrs>();
void SetStridedSliceNodeAttribute(const CallNode* call_node) {
const auto* attrs = call_node->attrs.as<StridedSliceAttrs>();
ICHECK(attrs && attrs->begin && attrs->end && attrs->strides)
<< "StridedSlice must have static begin, end, and strides.";
const bool default_strides =
!attrs->strides.value().defined() || attrs->strides.value().size() == 0;
auto ishape = backend::GetShape(cn->args[0]->checked_type());
auto ishape = backend::GetShape(call_node->args[0]->checked_type());

auto process_slice_index = [](Integer x, int default_value, int dim_value) {
if (!x.defined()) return default_value;
Expand Down Expand Up @@ -173,19 +150,19 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
start_attr.emplace_back(start);
size_attr.emplace_back(size);
strides_attr.emplace_back(strides);
node->SetAttr("start", start_attr);
node->SetAttr("size", size_attr);
node->SetAttr("strides", strides_attr);
node_->SetAttr("start", start_attr);
node_->SetAttr("size", size_attr);
node_->SetAttr("strides", strides_attr);
}

void SetSplitNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* cn) {
const auto* split_attr = cn->attrs.as<SplitAttrs>();
void SetSplitNodeAttribute(const CallNode* call_node) {
const auto* split_attr = call_node->attrs.as<SplitAttrs>();
ICHECK(split_attr);

std::vector<std::string> indices_or_sections;
std::vector<std::string> mode;
std::vector<std::string> axis = {std::to_string(split_attr->axis)};
if (const IntImmNode* sections = split_attr->indices_or_sections.as<IntImmNode>()) {
if (const auto* sections = split_attr->indices_or_sections.as<IntImmNode>()) {
mode.emplace_back("sections");
indices_or_sections.emplace_back(std::to_string(sections->value));
} else {
Expand All @@ -202,12 +179,80 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
indices_or_sections_attr.emplace_back(indices_or_sections);
mode_attr.emplace_back(mode);
axis_attr.emplace_back(axis);
node->SetAttr("indices_or_sections", indices_or_sections_attr);
node->SetAttr("mode", mode_attr);
node->SetAttr("axis", axis_attr);
node_->SetAttr("indices_or_sections", indices_or_sections_attr);
node_->SetAttr("mode", mode_attr);
node_->SetAttr("axis", axis_attr);
}

void SetGenericAttributes(const CallNode* call_node) {
OpAttrExtractor extractor(node_);
const Object* attr_obj = call_node->attrs.get();
extractor.Extract(const_cast<Object*>(attr_obj));
}

void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
TensorRTJSONSerializer* serializer_;
/*! \brief Accumulated translated arguments. */
std::vector<JSONGraphNodeEntry> args_;
/*!
* \brief Temporary node into which we'll accumulate attributes. Ideally this would be the
* final JSONGraphNode however we don't yet know how many inputs that will have.
*/
JSONGraphObjectPtr node_;
};

/*!
* \brief Generates an TensorRTModule from a relay expression by serializing the expression to a
* json representation. TensorRT is not required here because use of TensorRT APIs is deferred until
* runtime.
*/
class TensorRTJSONSerializer : public JSONSerializer {
public:
TensorRTJSONSerializer(const std::string& symbol, const Expr& expr)
: JSONSerializer(symbol, expr) {}

using JSONSerializer::VisitExpr_;

std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* call_node) final {
// The call must be to an inline "Composite" function
const auto* function_node = call_node->op.as<FunctionNode>();
ICHECK(function_node != nullptr);
auto opt_composite = function_node->GetAttr<String>(attr::kComposite);
ICHECK(opt_composite.defined());
std::string name = opt_composite.value();

// Collect the constants and attributes of all operator calls inside the composite body.
CollectFromCompositeFunctionBody collector(this);
collector.VisitExpr(function_node->body);

// Capture the args to the "Composite" function as inputs for this node.
std::vector<JSONGraphNodeEntry> inputs;
for (const auto& arg : call_node->args) {
auto res = VisitExpr(arg);
inputs.insert(inputs.end(), res.begin(), res.end());
}

// Capture constants from the composite function body as additional inputs for this node.
for (const auto& node : collector.args_) {
inputs.emplace_back(node);
}

// Create the final node.
auto node = std::make_shared<JSONGraphNode>(name,
/*op_type=*/"kernel", inputs,
/*num_output=*/1);

// Transfer attributes from the collector's node to the final node.
node->CaptureAttrs(*collector.node_);

// Capture global settings on the JSON node.
SaveGlobalAttributes(node);

VLOG(1) << name << " has " << node->GetInputs().size() << " inputs";

return AddNode(node, GetRef<Expr>(call_node));
}

static void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
auto ctx = transform::PassContext::Current();
auto cfg = ctx->GetConfig<TensorRTCompilerConfig>("relay.ext.tensorrt.options");
if (!cfg.defined()) {
Expand Down Expand Up @@ -236,6 +281,28 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
}
};

void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode* constant_node) {
for (const auto& entry : serializer_->VisitExpr(GetRef<Constant>(constant_node))) {
args_.emplace_back(entry);
}
}

void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) {
const auto* op_node = call_node->op.as<OpNode>();
ICHECK(op_node != nullptr);
std::string name = op_node->name;
if (name == "nn.pad") {
SetPadNodeAttribute(call_node);
} else if (name == "strided_slice") {
SetStridedSliceNodeAttribute(call_node);
} else if (name == "split") {
SetSplitNodeAttribute(call_node);
} else {
SetGenericAttributes(call_node);
}
ExprVisitor::VisitExpr_(call_node);
}

/*!
* \brief Create a runtime module for TensorRT.
* \param ref The ext_func Relay expression/module to be executed using extern ops.
Expand All @@ -246,12 +313,15 @@ runtime::Module TensorRTCompiler(const ObjectRef& ref) {
Function func = Downcast<Function>(ref);
std::string func_name = backend::GetExtSymbol(func);

VLOG(1) << "TensorRT partition:" << std::endl << PrettyPrint(func);
TensorRTJSONSerializer serializer(func_name, func);
serializer.serialize();
std::string graph_json = serializer.GetJSON();
VLOG(1) << "TensorRT JSON:" << std::endl << graph_json;
auto param_names = serializer.GetParams();
const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create");
ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function.";
VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'";
runtime::Module lib = (*pf)(func_name, graph_json, param_names);
return lib;
}
Expand Down
94 changes: 0 additions & 94 deletions src/relay/transforms/inline_composites.cc

This file was deleted.

6 changes: 6 additions & 0 deletions src/runtime/contrib/json/json_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,12 @@ class JSONGraphNode {
*/
bool HasAttr(const std::string& key) const { return attrs_.find(key) != attrs_.end(); }

void CaptureAttrs(const JSONGraphNode& that) {
for (const auto& kv : that.attrs_) {
attrs_[kv.first] = kv.second;
}
}

virtual ~JSONGraphNode() {}

private:
Expand Down
Loading