[Relay] A set of utilities that allows a model to be run efficiently on tensorcores. #6748
Merged

Commits (10) by jwfromm:
ad0ac5f  HWNC layout conversion support and better tensorcore strategy checking.
7e71d4b  Add first draft at recast pass.
b68f515  Layer count pass now working and tested.
92da625  Recast pass now working as expected.
f8d2493  Recast tests added.
079d8f4  Formatting applied.
b8710b7  Style fixes.
4f8e477  Another style fix.
bb523f0  Merge branch 'main' into hwnc_tensorcore
0055631  Remove extra newline.
@@ -29,3 +29,6 @@
 # Feature
 from . import feature
 from . import sparse_dense
+
+# Utilities
+from .count_layers import count_layers
@@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utilities that enable counting the number of layers in a graph."""
import tvm
from tvm import relay
from ..expr_functor import ExprVisitor


class LayerCounter(ExprVisitor):
    """A visitor pass that computes the deepest chain of specified ops in graph."""

    def __init__(self, valid_ops):
        self.depth_count = 0
        self.deepest_count = 0
        self.valid_ops = [relay.op.get(op) for op in valid_ops]
        super().__init__()

    def visit_call(self, call):
        # Only calls to ops in valid_ops deepen the chain.
        if call.op in self.valid_ops:
            self.depth_count += 1
        current_count = self.depth_count
        self.deepest_count = max(self.deepest_count, current_count)
        for arg in call.args:
            self.visit(arg)
        # Restore the depth after recursing so sibling branches are
        # measured from the same starting point.
        self.depth_count = current_count

    def count(self):
        return self.deepest_count


def count_layers(expr, valid_ops):
    """Determine the number of layers of specified ops in a graph.
    This pass computes only the deepest chain of ops rather than the
    total number of ops in a graph. Thus, if there are two parallel
    convolutions (for example), they would be considered a single layer.

    Parameters
    ----------
    expr : tvm.relay.Expr, tvm.relay.Function, or tvm.ir.IRModule
        The input expression.

    valid_ops : List[str]
        A list of the operations that should be included in the count.

    Returns
    -------
    layer_count : int
        The number of layers of the specified operations found in the graph.
    """
    if isinstance(expr, tvm.ir.IRModule):
        expr = expr["main"]
    count_pass = LayerCounter(valid_ops)
    count_pass.visit(expr)
    return count_pass.count()
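For reference, here is a minimal usage sketch (the graph and variable names are illustrative, not from the PR) showing the deepest-chain behavior the docstring describes: the two parallel convolutions below count as a single layer.

import tvm
from tvm import relay
from tvm.relay.analysis import count_layers

# Build a small graph: conv0 feeds two parallel convolutions whose
# outputs are added together.
data = relay.var("data", shape=(1, 3, 32, 32))
w0 = relay.var("w0", shape=(8, 3, 3, 3))
w1 = relay.var("w1", shape=(8, 8, 3, 3))
w2 = relay.var("w2", shape=(8, 8, 3, 3))
conv0 = relay.nn.conv2d(data, w0, padding=(1, 1))
left = relay.nn.conv2d(conv0, w1, padding=(1, 1))
right = relay.nn.conv2d(conv0, w2, padding=(1, 1))
out = relay.add(left, right)
func = relay.Function(relay.analysis.free_vars(out), out)

# The deepest chain is conv0 -> left (or conv0 -> right), so the
# parallel branch contributes only one layer to the count.
assert count_layers(func, valid_ops=["nn.conv2d"]) == 2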
@@ -0,0 +1,139 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Relay type recasting pass"""
import tvm
from tvm import relay
from tvm.ir import IRModule
from .transform import InferType
from ..analysis import count_layers
from ..expr_functor import ExprMutator, Call


class RecastMutator(ExprMutator):
    """Cast operations to the target type."""

    def __init__(self, dtype, out_dtype, valid_ops, valid_op_count, skip_layers):
        self.dtype = dtype
        self.out_dtype = out_dtype
        self.depth_count = 0
        self.valid_ops = [relay.op.get(op) for op in valid_ops]
        self.valid_op_count = valid_op_count
        self.skip_layers = skip_layers
        # Convert negative indices to positive ones.
        for i, layer in enumerate(skip_layers):
            if layer < 0:
                skip_layers[i] = self.valid_op_count + layer
        super().__init__()

    def visit_call(self, call):
        # Keep track of our current depth and layer count
        # so we can know whether to skip this layer or not.
        current_depth = self.depth_count
        current_layer = self.valid_op_count - current_depth - 1
        if call.op in self.valid_ops:
            self.depth_count += 1
        # Visit current call operation
        new_fn = self.visit(call.op)
        # Visit current arguments
        args = []
        for arg in call.args:
            args.append(self.visit(arg))
        self.depth_count = current_depth

        # Downcast this op if it's the correct type and not skipped.
        if call.op in self.valid_ops and current_layer not in self.skip_layers:
            # Recast inputs to specified type.
            args = [self.visit(arg) for arg in call.args]
            new_args = list()
            for arg in args:
                new_args.append(relay.cast(arg, dtype=self.dtype))

            # If out_dtype is in the attributes, we need to update it.
            orig_dtype = None
            if "out_dtype" in call.attrs.keys():
                new_attr_dict = {}
                for attr in call.attrs.keys():
                    attr_value = call.attrs[attr]
                    if isinstance(attr_value, tvm.ir.container.Array):
                        attr_value = tuple(attr_value)
                    new_attr_dict[str(attr)] = attr_value
                new_attr_dict["out_dtype"] = self.out_dtype
                attr_type = str(call.attrs).split("(")[0]
                new_attrs = tvm.ir.make_node(attr_type, **new_attr_dict)
                if call.attrs["out_dtype"] != "":
                    orig_dtype = call.attrs["out_dtype"]
            else:
                new_attrs = call.attrs

            if orig_dtype is None:
                # Perform type inference to determine the original type.
                new_mod = IRModule.from_expr(call)
                new_mod = InferType()(new_mod)
                checked_arg = new_mod["main"].body
                orig_dtype = checked_arg.checked_type.dtype
            # Recast the output for compatibility with other graph operations.
            return relay.cast(Call(new_fn, new_args, new_attrs), orig_dtype)

        # Otherwise return the unchanged call.
        return Call(new_fn, args, call.attrs)


def recast(expr, dtype, out_dtype, ops=None, skip_layers=None):
    """Convert the types of operations in a graph to a new value.
    Note that this is primarily useful for testing performance of individual
    operations at the new datatype. In a real setting, this pass will
    almost certainly do a poor job converting from one datatype to another
    as it just applies hard casting. For example, when recasting from float
    to integer, many small values will simply be set to 0. Although this will
    allow autotuning and benchmarking to produce proper timings at the new
    data type, the output of the model will of course be heavily impacted.

    Parameters
    ----------
    expr : tvm.relay.Expr, tvm.relay.Function, or tvm.ir.IRModule
        The original function that will have its type changed.

    dtype : str
        The target type to cast to.

    out_dtype : str
        The output type to cast to.

    ops : List[str]
        A list of operations that should have their type changed,
        others will be left as is.

    skip_layers : List[int]
        A list of integers indicating operations that should
        not have their type changed, counted starting with the
        first valid operation encountered. Negative indices are
        allowed and indicate starting at the last layer.

    Returns
    -------
    output_expr : tvm.relay.Expr, tvm.relay.Function, or tvm.ir.IRModule
        The graph after recasting to the specified datatype.
    """
    return_mod = False
    if isinstance(expr, tvm.ir.IRModule):
        expr = expr["main"]
        return_mod = True
    if ops is None:
        ops = ["nn.conv2d"]
    if skip_layers is None:
        skip_layers = []
    layer_depth = count_layers(expr, ops)
    recast_pass = RecastMutator(dtype, out_dtype, ops, layer_depth, skip_layers)
    expr = recast_pass.visit(expr)
    if return_mod:
        return tvm.IRModule.from_expr(expr)
    return expr
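As a usage sketch (the shapes and dtypes are illustrative, and this assumes recast is exported from tvm.relay.transform, where this file lives), recasting a small float32 network to int8 inputs with int32 accumulation while leaving the first convolution untouched might look like:

import tvm
from tvm import relay
from tvm.relay.transform import recast

# Two stacked float32 convolutions.
data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
w0 = relay.var("w0", shape=(8, 3, 3, 3), dtype="float32")
w1 = relay.var("w1", shape=(8, 8, 3, 3), dtype="float32")
conv0 = relay.nn.conv2d(data, w0, padding=(1, 1))
conv1 = relay.nn.conv2d(conv0, w1, padding=(1, 1))
func = relay.Function(relay.analysis.free_vars(conv1), conv1)

# Cast conv2d inputs to int8 and accumulate in int32, skipping the first
# convolution (layer 0). Each recast op is followed by a cast back to its
# original dtype, so the rest of the graph still type-checks.
recast_func = recast(func, "int8", "int32", ops=["nn.conv2d"], skip_layers=[0])
print(recast_func)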
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm.relay.testing import resnet
from tvm.relay.analysis import count_layers


def test_layer_count():
    def verify(num_layers):
        # Load a resnet with a known number of layers.
        mod, _ = resnet.get_workload(num_layers=num_layers)
        # Count the number of conv and dense layers.
        count = count_layers(mod, valid_ops=["nn.conv2d", "nn.dense"])
        assert count == num_layers

    verify(18)
    verify(50)


if __name__ == "__main__":
    test_layer_count()
When have you found it useful to skip a specific layer of a given operator type / how do you envision it being used? Mainly for debugging and performance tests?
In this case, the first layer of most networks does not have a sufficient number of channels for our tensorcore schedules to be applied. Although this wouldn't be a problem in theory, there aren't HWNC schedules for GPU. So if you blindly apply ConvertLayout to all layers, you end up with a first layer that can't be executed. Skipping it during conversion is an elegant way to avoid this issue. I imagine a similar pathology could apply to other situations.
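For context, a minimal sketch of the blind conversion being described (the desired-layout strings are illustrative and assume the HWNC conv2d conversion support this PR adds; "default" asks the op to pick its preferred kernel layout):

import tvm
from tvm import relay
from tvm.relay.testing import resnet

mod, params = resnet.get_workload(num_layers=18)

# Convert every conv2d to the HWNC data layout. Applied blindly like
# this, the first layer is converted too, which is the failure mode
# described above; hence the value of being able to skip it.
desired_layouts = {"nn.conv2d": ["HWNC", "default"]}
with tvm.transform.PassContext(opt_level=3):
    mod = relay.transform.ConvertLayout(desired_layouts)(mod)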