[Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass #8069

Merged Jun 21, 2021 (59 commits)

Commits
425471d
Initial skeleton for fp16 pass.
AndrewZhaoLuo May 14, 2021
2bd5311
Working python version of fp16 pass.
AndrewZhaoLuo May 17, 2021
9fda090
Rewrite python passes in C++
AndrewZhaoLuo May 19, 2021
4903a31
Extend support to things besides CallNodes. E.g. tuples and lets
AndrewZhaoLuo May 25, 2021
41ac568
Rewrite how and when casting is done by checking types directly.
AndrewZhaoLuo Jun 3, 2021
bde1c58
linting and formatting
AndrewZhaoLuo Jun 5, 2021
2101e6e
add AST header
AndrewZhaoLuo Jun 5, 2021
8e82c40
remove todo
AndrewZhaoLuo Jun 5, 2021
399121b
lint errors2
AndrewZhaoLuo Jun 5, 2021
c8f7428
remove i386 incompatible features
AndrewZhaoLuo Jun 5, 2021
42b0c04
Trigger CI again
AndrewZhaoLuo Jun 6, 2021
65b8d6c
set seed
AndrewZhaoLuo Jun 6, 2021
8860b1c
lint
AndrewZhaoLuo Jun 6, 2021
b3b8776
address animesh's initial comments
AndrewZhaoLuo Jun 7, 2021
479124b
mutate attributes only if they were originally floats
AndrewZhaoLuo Jun 8, 2021
22ae9e7
initial comments from matthew
AndrewZhaoLuo Jun 8, 2021
d956848
add comment on hashing strat
AndrewZhaoLuo Jun 9, 2021
cb39e0f
add missing ;
AndrewZhaoLuo Jun 9, 2021
a00fd8b
edge case when mutating attrs
AndrewZhaoLuo Jun 9, 2021
e25c40c
Cody's easy to address comments
AndrewZhaoLuo Jun 9, 2021
70436f5
add test to show green-red casting works
AndrewZhaoLuo Jun 9, 2021
2c78317
remove np.random seed from each test
AndrewZhaoLuo Jun 9, 2021
44b9782
remove as many references to fp16 types in favor of generic mixed types
AndrewZhaoLuo Jun 9, 2021
4911d4f
rename RED, GREEN, GRAY to MIXED_PRECISION_ALLOW, etc.
AndrewZhaoLuo Jun 9, 2021
47c2cf8
skeleton for supporting arbitrary mixed types
AndrewZhaoLuo Jun 9, 2021
239dbfb
cool tests
AndrewZhaoLuo Jun 10, 2021
33e286f
Using MixedModeMutator
AndrewZhaoLuo Jun 10, 2021
418f873
rename things ToMixedPrecision
AndrewZhaoLuo Jun 10, 2021
7d62fe1
rename passes to amp.cc
AndrewZhaoLuo Jun 10, 2021
b4ebd06
rename tests to match transform
AndrewZhaoLuo Jun 10, 2021
8968cda
clean up typos
AndrewZhaoLuo Jun 10, 2021
180b556
rename even better to_mixed_precision
AndrewZhaoLuo Jun 10, 2021
528ef7b
don't insert into cache when dtypes equal
AndrewZhaoLuo Jun 14, 2021
5ca1462
new python interface for registering ops
AndrewZhaoLuo Jun 14, 2021
9e77cff
cleaner registering ops
AndrewZhaoLuo Jun 15, 2021
e691e4f
add fp64 structural test
AndrewZhaoLuo Jun 15, 2021
37200fd
clean up and comments
AndrewZhaoLuo Jun 15, 2021
4c93545
make copy of attributes
AndrewZhaoLuo Jun 15, 2021
6aa727d
asf header
AndrewZhaoLuo Jun 15, 2021
173801b
pylint
AndrewZhaoLuo Jun 15, 2021
f4da2df
remove TODO which is solved
AndrewZhaoLuo Jun 15, 2021
7698920
Apply nits from code review (comaniac)
AndrewZhaoLuo Jun 15, 2021
177f9c4
change cast_node_cache --> cast_node_cache_
AndrewZhaoLuo Jun 15, 2021
8ddabda
add check for returned vals
AndrewZhaoLuo Jun 15, 2021
78b5b31
better error msg
AndrewZhaoLuo Jun 15, 2021
54d7c3d
docstring for pass in python
AndrewZhaoLuo Jun 15, 2021
3331224
fix default behavior to be proper
AndrewZhaoLuo Jun 15, 2021
c781bf2
better error reporting via single flag
AndrewZhaoLuo Jun 15, 2021
b513fee
priority to 0
AndrewZhaoLuo Jun 15, 2021
4fea978
address more nits
AndrewZhaoLuo Jun 16, 2021
25d8a1d
fix story telling slightly
AndrewZhaoLuo Jun 16, 2021
a063994
restart
AndrewZhaoLuo Jun 16, 2021
22841f1
correct docstring
AndrewZhaoLuo Jun 17, 2021
7a933a5
change class fields to have _ at end
AndrewZhaoLuo Jun 17, 2021
a1dbb68
add class docstring
AndrewZhaoLuo Jun 17, 2021
97fbd89
add comment on accumulation dtype hack
AndrewZhaoLuo Jun 17, 2021
64408ee
ADT warnings
AndrewZhaoLuo Jun 17, 2021
98e9cea
add todo
AndrewZhaoLuo Jun 17, 2021
2634182
fix linter
AndrewZhaoLuo Jun 18, 2021
15 changes: 15 additions & 0 deletions include/tvm/runtime/data_type.h
@@ -389,4 +389,19 @@ inline DLDataType String2DLDataType(std::string s) {
using DataType = runtime::DataType;

} // namespace tvm

namespace std {
template <>
struct hash<tvm::DataType> {
  inline int cantor_pairing_function(int a, int b) const { return (a + b) * (a + b + 1) / 2 + b; }
  std::size_t operator()(tvm::DataType const& dtype) const {
    int a = dtype.code();
    int b = dtype.bits();
    int c = dtype.lanes();
    int d = cantor_pairing_function(a, b);
    return cantor_pairing_function(c, d);
  }
};
}  // namespace std

#endif // TVM_RUNTIME_DATA_TYPE_H_
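The hash above combines a dtype's code, bits, and lanes by applying the Cantor pairing function twice. A quick arithmetic sketch in plain Python (illustrative only, not part of the patch):

def cantor(a, b):
    # Cantor pairing: maps each pair of non-negative ints to a unique non-negative int
    return (a + b) * (a + b + 1) // 2 + b

# e.g. float16 is (code=2 for kDLFloat, bits=16, lanes=1)
inner = cantor(2, 16)    # pair code with bits  -> 187
print(cantor(1, inner))  # pair lanes with that -> 17953, the resulting hash value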
1 change: 1 addition & 0 deletions python/tvm/relay/op/__init__.py
@@ -29,6 +29,7 @@
debug,
register_external_compiler,
register_fake_quantization_to_integer,
register_mixed_precision_conversion,
)
from . import strategy

33 changes: 30 additions & 3 deletions python/tvm/relay/op/op.py
@@ -18,10 +18,11 @@
"""The base node types for the Relay language."""
import tvm._ffi
import tvm.ir
from tvm.driver import lower, build
from tvm.target import get_native_generic_func, GenericFunc
from tvm.runtime import Object
import tvm.ir._ffi_api
from tvm.driver import build, lower
from tvm.runtime import Object
from tvm.target import GenericFunc, get_native_generic_func

from . import _make


@@ -457,6 +458,32 @@ def register_fake_quantization_to_integer(op_name, func=None, level=10):
return tvm.ir.register_op_attr(op_name, "FTVMFakeQuantizationToInteger", func, level)


def register_mixed_precision_conversion(op_name, func=None, level=10):
    """Register a mixed precision conversion function for an op.

    Given an op, the registered function describes how calls to it should be
    converted. Specifically, the function takes a call node and the target
    mixed precision datatype (e.g. FP16) and returns the conversion category
    (see python/tvm/relay/transform/mixed_precision.py) as well as the accumulation
    and output datatypes of the operation in the mixed precision dtype space.

    Parameters
    ----------
    op_name : str
        The name of the operator

    func: function (call_node: relay.Call, target_dtype: string)
    -> [conversion category, accumulation dtype, output dtype]: [int, string, string]
        A function which, given a call_node and target_dtype (e.g. FP16), returns the
        conversion category and the associated accumulation/output dtypes of the operation
        when transformed into the mixed precision dtype space.

    level : int
        The priority level
    """
    return tvm.ir.register_op_attr(op_name, "FTVMMixedPrecisionConversionType", func, level)


@tvm._ffi.register_func("relay.op.compiler._lower")
def _lower(name, schedule, inputs, outputs):
return lower(schedule, list(inputs) + list(outputs), name=name)
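For reference, a hedged sketch of how this hook can be used to override a default rule for an op (the choice of nn.softmax and the level value are illustrative only; MIXED_PRECISION_ALWAYS comes from the defaults file added below):

from tvm import relay
from tvm.relay.op import register_mixed_precision_conversion
from tvm.relay.transform.mixed_precision import MIXED_PRECISION_ALWAYS

# level=11 should take precedence over the default registrations at level=10
@register_mixed_precision_conversion("nn.softmax", level=11)
def softmax_mixed_precision(call_node: relay.Call, mixed_precision_type: str):
    # [conversion category, accumulation dtype, output dtype]
    return [MIXED_PRECISION_ALWAYS, "float32", mixed_precision_type]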
195 changes: 195 additions & 0 deletions python/tvm/relay/transform/mixed_precision.py
@@ -0,0 +1,195 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=line-too-long,unused-argument
"""Default behavior for ops in mixed_precision pass. Import this file to use."""
from typing import List

from tvm import relay
from tvm.relay.op import register_mixed_precision_conversion

# MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
# savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
# justify a cast. MIXED_PRECISION_NEVER ops should not be done in lower precision due to
# numerical accuracy reasons.
MIXED_PRECISION_ALWAYS = 0
MIXED_PRECISION_FOLLOW = 1
MIXED_PRECISION_NEVER = 2

# Default lists are inspired by TF's classifications:
# github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
# They are biased toward Nvidia Tensor Cores, so adjust the lists for your hardware.
DEFAULT_ALWAYS_LIST = [
"nn.conv1d",
"nn.conv2d",
"nn.conv3d",
"nn.conv1d_transpose",
"nn.conv2d_transpose",
"nn.conv3d_transpose",
"nn.dense",
# "nn.batch_matmul", # Handled by a special case
]
DEFAULT_FOLLOW_LIST = [
# These ops add new data or change shape
"nn.pad",
"nn.batch_flatten",
"concatenate",
"zeros",
"split",
"squeeze",
"transpose",
"expand_dims",
"reshape",
"dyn.reshape",
"broadcast_to_like",
"dyn.broadcast_to",
"strided_slice",
"dyn.strided_slice",
"take",
"argwhere",
"where",
"tile",
"dyn.tile",
"scatter",
"full",
"dyn.full",
# Comparison
"less",
"greater",
"less_equal",
"greater_equal",
# By definition copy and cast will depend on inputs for output.
"copy",
"cast",
"cast_like",
# Simple arithmetic
"add",
"subtract",
"multiply",
"divide",
"nn.bias_add",
"nn.batch_norm",
"sum",
"mean",
"sqrt",
"shape_of",
# Simple activations
"max",
"min",
"maximum",
"minimum",
"nn.relu",
"nn.leaky_relu",
"nn.prelu",
"nn.dropout",
# Complicated activations which saturate in a narrow range
"sigmoid",
"tanh",
# Pooling operations
"nn.max_pool1d",
"nn.max_pool2d",
"nn.max_pool3d",
"nn.avg_pool1d",
"nn.avg_pool2d",
"nn.avg_pool3d",
# "nn.global_max_pool1d", # does not exist yet
"nn.global_max_pool2d",
# "nn.global_max_pool3d", # does not exist yet
# "nn.global_avg_pool1d", # does not exist yet
"nn.global_avg_pool2d",
# "nn.global_avg_pool3d", # does not exist yet
"nn.adaptive_max_pool1d",
"nn.adaptive_max_pool2d",
"nn.adaptive_max_pool3d",
"nn.adaptive_avg_pool1d",
"nn.adaptive_avg_pool2d",
"nn.adaptive_avg_pool3d",
]
DEFAULT_NEVER_LIST = [
# In general if |f(x)| >> |x| for expected inputs then put the op here.
"exp",
"power",
"nn.cross_entropy",
"nn.cross_entropy_with_logits",
"nn.softmax",
"nn.l2_normalize",
# Error function doesn't seem to be able to be lowered into fp16 version in llvm.
# Move to follow list when it does.
"erf",
]


# Returns a decorator which registers the decorated function under
# FTVMMixedPrecisionConversionType for every op in the given list.
def register_func_to_op_list(list_ops: List):
    def decorator(func):
        for op_name in list_ops:
            register_mixed_precision_conversion(op_name, func=func)

    return decorator


def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]:
    """A function which returns output dtypes in a way which works for most ops.

    Parameters
    ----------
    call_node: relay.Call
        The call node containing the op.
    mixed_precision_type: str
        The target type to run the operation in.

    Returns
    -------
    output_dtypes : [str, str]
        A list of two strings. The first represents the datatype used for accumulation
        in the operation. The second represents the actual output datatype.
    """
    # Assume an op supports a distinct accumulation dtype iff it has an out_dtype attr.
    # This is because there is no better way right now to tell which ops support accumulating
    # at different data types.
    # Some discussion about making this better is here:
    # https://discuss.tvm.apache.org/t/rfc-relay-fp32-fp16-model-support/9994/4?u=andrewzhaoluo
    if hasattr(call_node.attrs, "out_dtype"):
        return ["float32", mixed_precision_type]

    # [accumulation_dtype, output_dtype] for the operation
    return [mixed_precision_type, mixed_precision_type]


# Functions for FTVMMixedPrecisionConversionType which take in a CallNode and a dtype string
# and return a conversion category, an accumulation dtype, and an output dtype.
@register_func_to_op_list(list_ops=DEFAULT_ALWAYS_LIST)
def generic_always_op(call_node: relay.Call, mixed_precision_type: str) -> List:
    return [MIXED_PRECISION_ALWAYS] + get_generic_out_dtypes(call_node, mixed_precision_type)


@register_func_to_op_list(list_ops=DEFAULT_FOLLOW_LIST)
def generic_follow_op(call_node: relay.Call, mixed_precision_type: str) -> List:
    return [MIXED_PRECISION_FOLLOW] + get_generic_out_dtypes(call_node, mixed_precision_type)


@register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST)
def generic_never_op(call_node: relay.Call, mixed_precision_type: str) -> List:
    return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type)


@register_mixed_precision_conversion("nn.batch_matmul")
def nn_batch_matmul(call_node: relay.Call, mixed_precision_type: str) -> List:
    # TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well.
    # Batched matmul has inconsistent support for mixed precision operations.
    # Many schedules ignore the out_dtype attribute which leads to errors when
    # input types do not match the out_dtype. Therefore, accumulate to output_dtype.
    return [MIXED_PRECISION_ALWAYS, "float16", "float16"]
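A small sanity-check sketch of the helper above: ops whose attrs carry an out_dtype keep a float32 accumulator, while everything else simply uses the mixed precision type (the shapes and ops here are arbitrary, for illustration):

from tvm import relay
from tvm.relay.transform.mixed_precision import get_generic_out_dtypes

x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")
w = relay.var("w", shape=(8, 3, 3, 3), dtype="float32")

conv = relay.nn.conv2d(x, w)                    # Conv2DAttrs has an out_dtype field
print(get_generic_out_dtypes(conv, "float16"))  # ["float32", "float16"]

add = relay.add(x, x)                           # add has no out_dtype attribute
print(get_generic_out_dtypes(add, "float16"))   # ["float16", "float16"]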
35 changes: 30 additions & 5 deletions python/tvm/relay/transform/transform.py
@@ -18,16 +18,15 @@
"""
Relay pass transformation infrastructure.
"""
import types
import inspect
import functools
import inspect
import types
import warnings

import tvm.ir
from tvm import te
from tvm import relay, te
from tvm.runtime import ndarray as _nd

from tvm import relay
from . import _ffi_api


@@ -1168,7 +1167,7 @@ def AnnotateSpans():
Returns
-------
ret : tvm.transform.Pass
The regsistered AnnotateSpans pass.
The registered AnnotateSpans pass.
"""
return _ffi_api.AnnotateSpans()

@@ -1199,3 +1198,29 @@ def FakeQuantizationToInteger():
The registered SimplifyExpr pass.
"""
return _ffi_api.FakeQuantizationToInteger()


def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1):
    """
    Automatic mixed precision rewriter. Rewrites an FP32 Relay graph into a version
    where as many operations as possible are in the target mixed_precision_type.

    Parameters
    ----------
    mixed_precision_type: str
        The target datatype to transform operations in the graph to use.

    missing_op_mode: int
        Determines how to handle ops not registered with FTVMMixedPrecisionConversionType:
        0: Do not allow any missing ops. Raise errors when encountering any.
        1: Allow missing ops but emit warnings.
        2: Allow missing ops and silently ignore them.

    Returns
    -------
    ret : tvm.transform.Pass
        The registered pass.
    """
    if missing_op_mode < 0 or missing_op_mode > 2:
        raise ValueError("Missing op mode is either 0, 1, or 2")
    return _ffi_api.ToMixedPrecision(mixed_precision_type, missing_op_mode)
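For context, a minimal end-to-end sketch of invoking the new pass (the tiny conv/relu graph is arbitrary; the pass is assumed to expect a type-inferred module, so InferType is run first):

import tvm
from tvm import relay
import tvm.relay.transform.mixed_precision  # importing this registers the default conversion rules
from tvm.relay.transform import InferType, ToMixedPrecision

x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")
w = relay.var("w", shape=(8, 3, 3, 3), dtype="float32")
out = relay.nn.relu(relay.nn.conv2d(x, w))
mod = tvm.IRModule.from_expr(relay.Function([x, w], out))

mod = InferType()(mod)
fp16_mod = ToMixedPrecision("float16")(mod)
print(fp16_mod)  # conv2d should now take fp16 inputs and accumulate in fp32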
10 changes: 7 additions & 3 deletions python/tvm/topi/nn/conv2d.py
@@ -18,13 +18,15 @@
# pylint: disable=unused-argument, redefined-builtin
"""Conv2D operators"""
from __future__ import absolute_import as _abs

from collections import namedtuple

import tvm
from tvm import te, auto_scheduler
from tvm import auto_scheduler, te

from ..utils import get_const_int, get_const_tuple, simplify, tag
from .pad import pad
from .utils import get_pad_tuple
from ..utils import simplify, get_const_tuple, get_const_int, tag
from .winograd_util import winograd_transform_matrices

# workload description of conv2d
@@ -548,7 +550,9 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
ow * WSTR + kw * dilation_w,
idxmod(ic, ic_bn),
].astype(out_dtype)
* kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block],
* kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block].astype(
out_dtype
),
axis=[ic, kh, kw],
),
name="conv2d_NCHWc",