Skip to content

Commit

Permalink
[BYOC][TensorRT] Fixes for explicit batch mode, Support reduce to sca…
Browse files Browse the repository at this point in the history
…lar, Support split op (apache#7967)
  • Loading branch information
Trevor Morris committed May 6, 2021
1 parent 86b74ff commit 2fd5ddf
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 9 deletions.
30 changes: 25 additions & 5 deletions python/tvm/relay/op/contrib/tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _func_wrapper(expr):

def reduce_annotate_fn(attrs, args, op_name):
"""Helper for reduce operations."""
if not attrs.axis or len(attrs.axis) == 0:
if get_tensorrt_use_implicit_batch_mode() and (not attrs.axis or len(attrs.axis) == 0):
logger.info("%s: cannot reduce to scalar.", op_name)
return False
if attrs.exclude:
Expand Down Expand Up @@ -317,17 +317,18 @@ def add_annotate_fn(expr): # pylint: disable=unused-variable
for arg in args
]

# RelayVM + TRT doesn't support scalar addition yet.
for shape in shapes:
if len(shape) < 1:
return False
# Scalars require explicit batch mode.
if get_tensorrt_use_implicit_batch_mode() and any([len(shape) < 1 for shape in shapes]):
return False

if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if (
not get_tensorrt_use_implicit_batch_mode()
and (isinstance(args[0], Constant) or isinstance(args[1], Constant))
and len(shapes[0]) > 0
and len(shapes[1]) > 0
and shapes[0][0] == shapes[1][0]
and shapes[0][0] != 1
and (len(shapes[0]) > 3 or len(shapes[1]) > 3)
Expand Down Expand Up @@ -552,6 +553,19 @@ def concatenate_annotate_fn(expr): # pylint: disable=unused-variable
return True


@_register_external_dynamic_check_func("split")
def split_annotate_fn(expr):
    """Decide whether a ``split`` call may be offloaded to TensorRT.

    Returns False (after logging the reason) when any argument is not float32,
    or when implicit batch mode is enabled and the split axis is 0 (the batch
    dimension). Returns True otherwise.
    """
    # Collect every argument dtype up front, then look for a non-float32 one.
    arg_dtypes = [arg.checked_type.dtype for arg in expr.args]
    if any(dtype != "float32" for dtype in arg_dtypes):
        logger.info("Only float32 inputs are supported for TensorRT.")
        return False

    # The axis is only inspected when implicit batch mode is active.
    if get_tensorrt_use_implicit_batch_mode():
        if int(expr.attrs.axis) == 0:
            logger.info("split: can't modify batch dimension.")
            return False

    return True


@_register_external_dynamic_check_func("nn.conv2d_transpose")
def conv2d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv2d_transpose is supported by TensorRT."""
Expand Down Expand Up @@ -870,6 +884,11 @@ def visit_call(self, call):
"nn.conv3d_transpose",
"nn.dense",
"nn.batch_matmul",
"sum",
"prod",
"max",
"min",
"mean",
]
)
if isinstance(call.op, tvm.tir.op.Op):
Expand Down Expand Up @@ -968,6 +987,7 @@ def visit_call(self, call):
# Create new pruned module
new_mod = tvm.IRModule(mod.functions, mod.type_definitions)
new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
new_mod = transform.RemoveUnusedFunctions()(new_mod)
return new_mod


Expand Down
31 changes: 31 additions & 0 deletions src/relay/backend/contrib/tensorrt/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
SetPadNodeAttribute(node, cn);
} else if (name == "strided_slice") {
SetStridedSliceNodeAttribute(node, cn);
} else if (name == "split") {
SetSplitNodeAttribute(node, cn);
} else {
SetCallNodeAttribute(node, cn);
}
Expand Down Expand Up @@ -172,6 +174,35 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
node->SetAttr("strides", strides_attr);
}

void SetSplitNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* cn) {
const auto* split_attr = cn->attrs.as<SplitAttrs>();
ICHECK(split_attr);

std::vector<std::string> indices_or_sections;
std::vector<std::string> mode;
std::vector<std::string> axis = {std::to_string(split_attr->axis)};
if (const IntImmNode* sections = split_attr->indices_or_sections.as<IntImmNode>()) {
mode.emplace_back("sections");
indices_or_sections.emplace_back(std::to_string(sections->value));
} else {
mode.emplace_back("indices");
auto indices = Downcast<tvm::Array<Integer>>(split_attr->indices_or_sections);
for (const auto& i : indices) {
indices_or_sections.emplace_back(std::to_string(i->value));
}
}

std::vector<dmlc::any> indices_or_sections_attr;
std::vector<dmlc::any> mode_attr;
std::vector<dmlc::any> axis_attr;
indices_or_sections_attr.emplace_back(indices_or_sections);
mode_attr.emplace_back(mode);
axis_attr.emplace_back(axis);
node->SetAttr("indices_or_sections", indices_or_sections_attr);
node->SetAttr("mode", mode_attr);
node->SetAttr("axis", axis_attr);
}

void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
auto ctx = transform::PassContext::Current();
auto cfg = ctx->GetConfig<TensorRTCompilerConfig>("relay.ext.tensorrt.options");
Expand Down
62 changes: 59 additions & 3 deletions src/runtime/contrib/tensorrt/tensorrt_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,53 @@ class ConcatOpConverter : public TensorRTOpConverter {
}
};

class SplitOpConverter : public TensorRTOpConverter {
public:
SplitOpConverter() : TensorRTOpConverter({kTensor}) {}

void Convert(TensorRTOpConverterParams* params) const {
auto input = params->inputs.at(0).tensor;
auto input_dims = TrtDimsToVector(input->getDimensions());
const int original_axis = std::stoi(params->node.GetAttr<std::vector<std::string>>("axis")[0]);
const int axis = ConvertAxis(params, original_axis, input_dims.size());
auto indices_or_sections =
params->node.GetAttr<std::vector<std::string>>("indices_or_sections");
auto mode = params->node.GetAttr<std::vector<std::string>>("mode")[0];

std::vector<int> split_starts;
std::vector<int> split_sizes;
if (mode == "sections") {
int sections = std::stoi(indices_or_sections[0]);
int size = input_dims[axis] / sections;
for (int i = 0; i < sections; i++) {
split_starts.push_back(i * size);
split_sizes.push_back(size);
}
} else {
int last_index = 0;
for (size_t i = 0; i < indices_or_sections.size(); ++i) {
int index = std::stoi(indices_or_sections[i]);
split_starts.push_back(last_index);
split_sizes.push_back(index - last_index);
last_index = index;
}
split_starts.push_back(last_index);
split_sizes.push_back(input_dims[axis] - last_index);
}

std::vector<int> start(input_dims.size(), 0);
std::vector<int> size(input_dims.begin(), input_dims.end());
std::vector<int> strides(input_dims.size(), 1);
for (int i = 0; i < split_sizes.size(); ++i) {
start[axis] = split_starts[i];
size[axis] = split_sizes[i];
auto slice_layer = params->network->addSlice(*input, VectorToTrtDims(start),
VectorToTrtDims(size), VectorToTrtDims(strides));
params->outputs.push_back(slice_layer->getOutput(0));
}
}
};

class BiasAddOpConverter : public TensorRTOpConverter {
public:
BiasAddOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}
Expand Down Expand Up @@ -970,9 +1017,17 @@ class ReduceOpConverter : public TensorRTOpConverter {
// TODO(trevmorr): Support reduce to scalar.
ICHECK_GT(str_axis.size(), 0);
uint32_t reduce_axes = 0;
for (size_t i = 0; i < str_axis.size(); ++i) {
const int axis = ConvertAxis(params, std::stoi(str_axis[i]), input->getDimensions().nbDims);
reduce_axes |= 1 << axis;

if (str_axis.size() == 1 && str_axis[0].length() == 0) {
// Reduce to scalar
for (int i = 0; i < input->getDimensions().nbDims; ++i) {
reduce_axes |= 1 << i;
}
} else {
for (size_t i = 0; i < str_axis.size(); ++i) {
const int axis = ConvertAxis(params, std::stoi(str_axis[i]), input->getDimensions().nbDims);
reduce_axes |= 1 << axis;
}
}
auto reduce_layer = params->network->addReduce(*input, it->second, reduce_axes, keepdims);
params->outputs.push_back(reduce_layer->getOutput(0));
Expand Down Expand Up @@ -1072,6 +1127,7 @@ GetOpConverters() {
map->emplace("expand_dims", std::make_shared<ExpandDimsOpConverter>());
map->emplace("squeeze", std::make_shared<SqueezeOpConverter>());
map->emplace("concatenate", std::make_shared<ConcatOpConverter>());
map->emplace("split", std::make_shared<SplitOpConverter>());
map->emplace("nn.conv2d_transpose", std::make_shared<Conv2DTransposeOpConverter>());
map->emplace("transpose", std::make_shared<TransposeOpConverter>());
map->emplace("layout_transform", std::make_shared<LayoutTransformOpConverter>());
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/contrib/tensorrt/tensorrt_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ class TensorRTRuntime : public JSONRuntimeBase {
* do nothing.
*/
void BuildEngine() {
batch_size_ = data_entry_[input_var_eid_[0]]->shape[0];
batch_size_ =
data_entry_[input_var_eid_[0]]->ndim == 0 ? 1 : data_entry_[input_var_eid_[0]]->shape[0];
if (trt_engine_cache_.count(std::make_pair(symbol_name_, batch_size_))) return;
DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_
<< " with batch size " << batch_size_;
Expand Down
13 changes: 13 additions & 0 deletions tests/python/contrib/test_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,19 @@ def get_graph(input_shapes, axis):
run_and_verify_func(get_graph([(1, 2, 6, 6), (1, 3, 6, 6)], axis=1))


def test_split():
    """Check relay.split with section counts and with explicit split indices."""

    def get_graph(x_shape, indices_or_sections, axis):
        # Build a function returning every piece of the split as a tuple.
        data = relay.var("x", shape=x_shape, dtype="float32")
        pieces = relay.split(data, indices_or_sections=indices_or_sections, axis=axis)
        func = relay.Function([data], pieces.astuple())
        return func, {"x": x_shape}, []

    # Integers are section counts; lists are explicit split indices along axis 1.
    for split_spec in (2, 4, [8], [2, 3, 6, 10, 14]):
        run_and_verify_func(get_graph((1, 16), indices_or_sections=split_spec, axis=1))


def test_conv2d_transpose():
def get_graph(
x_shape=(1, 32, 8, 8),
Expand Down

0 comments on commit 2fd5ddf

Please sign in to comment.