[Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass #8069
Conversation
Force-pushed from 6aed619 to 86744b0
Force-pushed from 2c6c2c1 to 8bb3ad6
Thanks for the useful feature. Is this ready for review?
Hey Animesh, it'll be ready for review soon. Probably by Monday morning (PST time). There are still some misc. improvements that should be made, but I've decided to push those down to later PRs.
Force-pushed from 8d02038 to 483dc29
This is ready for review.
""" | ||
return _ffi_api.AnnotateSpans() | ||
|
||
|
||
def RewriteFP16(): |
Do you want to call it AMPRewriter?
Good idea. Done.
Overall the structure looks good. Will do a more detailed review later on. The only suggestion for now is to think about naming. Should we call it AMP? Later on we can reuse this for BF16
src/relay/transforms/fp32_to_fp16.h
Outdated
// GREEN colored ops should always be done in FP16 due to the speed and memory savings
// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
// RED colored ops should not be done in FP16 due to numerical reasons.
enum FP16ConversionCategory { RED, GRAY, GREEN };
What happens if there is an op that is not associated with any of the colors? Is the default RED?
By default it would be RED and a warning would be emitted.
Some suggestions:
- These are more like an attribute than a category.
- Use straightforward terms, such as ALWAYS/FOLLOW/NEVER, instead of RED/GRAY/GREEN.
- Emitting warnings for non-specified ops may result in tedious messages. We could make it configurable to let users decide whether to print out these ops.
I've implemented the suggestions listed.
src/relay/transforms/fp32_to_fp16.h
Outdated
if (color == op_to_initial_color.end()) {
  if (ignore_missing) {
    LOG(WARNING) << "Op name " << op_name << " not in included in fp16 conversion lists!.";
Remove the period at the end.
Done
src/relay/transforms/fp32_to_fp16.h
Outdated
    LOG(WARNING) << "Op name " << op_name << " not in included in fp16 conversion lists!.";
    return RED;
  } else {
    LOG(FATAL) << "Op name " << op_name << " not in included in fp16 lists!.";
Remove the period at the end.
Done
Can we get a few more initial reviews - @mbrookhart, @csullivan? @AndrewZhaoLuo I would also suggest testing a dynamic model like SSD or Mask-RCNN. Your current list of object detection models involves Yolo, which is a static model.
A few minor changes for a first pass. I like the big idea of the pass; I still want to dig into a couple of the details, but overall it's looking very good.
src/relay/transforms/fp32_to_fp16.cc
Outdated
auto h1 = std::hash<T1>()(pair.first);
auto h2 = std::hash<T2>()(pair.second);

return h1 ^ (h2 << 1);
If I remember correctly, xor hash combine is pretty prone to hash conflicts? Maybe use the boost approach? `return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2))`
Done.
src/relay/transforms/fp32_to_fp16.cc
Outdated
public:
 explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
     : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
Since you're using the recursive mutator here, you might run into stack overflows on larger models. I haven't looked at this pass in much detail yet; is it possible to do this with a post-order traversal (or MixedModeMutator)?
Yes, right now I actually depend on a post-order traversal (since we want all arguments to call nodes to be mutated before we decide whether to convert a call node to fp16). I'll look into MixedModeMutator to solve this issue.
Done
np.random.seed(90)
Tianqi prefers we don't set random seeds to try to find intermittent bugs across CI runs
Oh ok that's an interesting idea. I had a failure where the passing rtol was 1.05e-5 so I'm just going to increase the tolerance.
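(For a sense of scale, here is a toy NumPy check of the tolerance half precision can realistically meet; the arrays and the rtol/atol values below are illustrative, not the PR's test code.)

```python
import numpy as np

# Round-trip an fp32 result through fp16 to simulate mixed-precision error,
# then compare with a tolerance loose enough for half precision (~3 decimal digits).
fp32_result = np.random.uniform(size=(1, 1000)).astype("float32")
fp16_result = fp32_result.astype("float16").astype("float32")
np.testing.assert_allclose(fp16_result, fp32_result, rtol=1e-3, atol=1e-3)
```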
I tried it on an SSD model and it seems to work fine. For Mask-RCNN, I haven't found a model file which can convert well and be run normally in FP32.
Force-pushed from dd03c23 to 391b15a
TF SSD is good enough. Thanks @AndrewZhaoLuo
Thanks for this great PR! Would it be too much to ask for AMPRewrite and the corresponding infra to support mixed precision with generic reduced-precision floating point types? I notice the main assumption is downcasting to float16, though TVM has support for other reduced-precision fp types for which mixed precision is useful, e.g. float32 + bfloat16, as well as possible user-defined floating point types.
Thanks for the RFC and PR. The overall idea LGTM, and I believe this would be an important feature in the future. I just have some concerns about the current implementation.
@@ -1199,3 +1198,20 @@ def FakeQuantizationToInteger():
        The registered SimplifyExpr pass.
    """
    return _ffi_api.FakeQuantizationToInteger()


def AMPRewrite():
Since this is the user API, we might need to think a bit more to make it more straightforward. For example, `AMP` or `AutoCast` are better names IMHO.
I disagree. All the passes have names which are verbs describing what they do, while `AMP` is a noun. Maybe `AutoCast` would be better, but it doesn't capture the mixed precision nature. Maybe `ToMixedPrecision` would be a better name?
I don't think it's true that `AutoCast` doesn't capture the nature. For example: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast
Hmm, I would still prefer `ToMixedPrecision`, if that is fine with you. The example you list only works for me because it exists under the `amp` namespace. `AutoCast` by itself, without being part of `torch.cuda.amp`, does not show mixed precision.
I'm fine if no others complain about this naming.
src/relay/transforms/fp32_to_fp16.h
Outdated
using OpStringSet = std::unordered_set<std::string>;

// Default lists inspired from TF's classifications:
I'd prefer not to specify op lists in a pass. It means we need to maintain this pass every time we add a new op. It would be better to follow the logic of other similar passes: register an attribute to each op, and if an op doesn't have this attribute registered, use the default behavior. It is also impossible for this implementation to accept user-defined rules from Python.
Good advice.
I'll use better terms instead of RED/GRAY/GREEN.
I'll also make the warning messages configurable for the user.
As for registering attributes to each op, I think it's probably a good idea, but do you have an example of this strategy I could look at?
User-defined rules from Python are a goal I will try for. It might take a little longer though.
You can refer to the design document of layout conversion pass: https://tvm.apache.org/docs/dev/convert_layout.html. It's actually not hard to take rules from Python for this design.
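For illustration, a per-op rule registered from Python could look roughly like the sketch below. This is only a sketch of the "attribute per op" idea using `tvm.ir.register_op_attr`; the attribute key, the callback signature, and the return convention are assumptions for illustration, not the pass's actual interface.

```python
import tvm

# Hypothetical attribute key and return convention, shown only to illustrate
# registering a mixed-precision rule on an op from Python. Ops without such a
# registration would fall back to a default behavior inside the pass.
# level=11 so the illustrative rule would override any default registration.
@tvm.ir.register_op_attr("nn.conv2d", "FTVMMixedPrecisionConversionType", level=11)
def conv2d_mixed_precision_rule(call_node, mixed_precision_type):
    # (conversion category, accumulation dtype, output dtype)
    return ["always", "float32", mixed_precision_type]
```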
This is now done.
src/relay/transforms/fp32_to_fp16.cc
Outdated
// Determine the final color.
FP16ConversionCategory final_color;
if (initial_color == GRAY) {
  final_color = all_args_fp16_compatible ? GREEN : RED;
Can you provide an example of an FP16-incompatible case?
An example with concat:
We have two branches whose outputs are fed into concat.
The first branch has a RED operation and returns an FP32 tensor.
The second branch returns an FP16 tensor.
Now that I say this, it might be better to be a bit smarter about GRAY ops when we have heterogeneous floating point types coming in.
E.g. let's say we had a concat with 10 fp16 args and 1 fp32 arg. It would be wasteful to convert everything to fp32 by default and set the color as RED in this case.
I will change this so the number of fp16/fp32 args is taken into account: if there is a majority of fp16 args or a tie, we color GREEN, else we color RED. Thoughts?
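In other words, something like this toy rule (plain Python just to pin down the proposal; not code from the pass):

```python
def decide_gray_op_color(num_fp16_args: int, num_fp32_args: int) -> str:
    # Majority (or tie) of fp16 arguments -> treat the GRAY op as GREEN,
    # otherwise fall back to RED and keep the computation in fp32.
    return "GREEN" if num_fp16_args >= num_fp32_args else "RED"
```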
The workaround sounds fine to me. Again, I'd suggest putting these op-specific heuristics into op attributes instead of this pass.
On closer thought, I will leave things as is, since only some ops will benefit from the trick I described. In the future, exposing this to op attributes might be worthwhile, but I cannot think of a major saving that comes from it.
src/relay/transforms/fp32_to_fp16.cc
Outdated
Attrs new_attrs = GetNewAttrs(call_node, output_dtypes.accumulation_dtype);
Expr output = Call(new_op, call_args, new_attrs, call_arg_types, call_node->span);
if (output_dtypes.accumulation_dtype != output_dtypes.output_dtype) {
  output = CastArg(output, GetType(output), output_dtypes.output_dtype);
Wouldn't this introduce unnecessary cast ops? For example, the accumulation dtype is FP32 and the following op is RED. Will this make it `A(GREEN) - cast_to_fp16 - cast_to_fp32 - B(RED)`?
Hmm, I don't believe so, since `CachedCast` also caches the reverse result. E.g. `CachedCast(A(GREEN), FP16)` would produce `A(GREEN) - cast_to_fp16`, but internally it would cache both `(A(GREEN), FP16) --> cast_to_fp16` and `(cast_to_fp16, FP32) --> A(GREEN)` (keyed on node and wanted dtype). So attempting to cast `cast_to_fp16` to `fp32` would return `A(GREEN)`. It would be worth having a test case to cover this, however, to make sure.
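A toy Python model of the two-way cache being described (an illustration of the idea only, not the C++ implementation):

```python
# Keys are (node identity, wanted dtype). Creating a cast also records the
# reverse entry, so asking to cast the cast node back to its source dtype
# returns the original node instead of stacking a second cast on top of it.
cast_nodes_cache = {}

def cached_cast(node, node_dtype, wanted_dtype):
    if node_dtype == wanted_dtype:
        return node
    key = (id(node), wanted_dtype)
    if key in cast_nodes_cache:
        return cast_nodes_cache[key]
    cast = ("cast", node, wanted_dtype)  # stand-in for building a relay Cast node
    cast_nodes_cache[key] = cast
    cast_nodes_cache[(id(cast), node_dtype)] = node  # reverse mapping
    return cast
```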
I see... this mechanism is interesting and I hadn't paid too much attention to it. At first glance, I would worry that the cache might blow up when the model is too large, but I'll probably take a deeper look at this mechanism later.
src/relay/transforms/fp32_to_fp16.h
Outdated
// TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well.
// Batched matmul has inconsistent support for mixed precision operations.
// Many schedules ignore the out_dtype attribute which leads to errors when
// input types do not match the out_dtype. Therefore, accumulate to fp16 if green.
if (auto op_node = call->op.as<OpNode>()) {
  if (op_node->name == "nn.batch_matmul") {
    return {DataType::Float(16), DataType::Float(16)};
  }
}
This again illustrates the importance of registering a casting function to each op instead of hard-coding it here.
This is now functionally done in the Python interface.
Hey folks, I've covered the simple changes requested; here is the list of more involved changes along with the associated reviewer. Several of these changes were planned as future PRs, but it might be best to just commit this correctly the first time (since it doesn't really touch other files):
Let me know if I missed anything.
   conv2d might take in fp16 and give a fp32 result.
   Attrs is const because we get it as a const.
*/
T* mutable_attrs = const_cast<T*>(attrs);
Can we create and return a new attribute?
Done
@AndrewZhaoLuo I can suggest trying out the DETR model. This is an interesting transformer-based object detection model that consists exclusively of conv2d / matmul (no NMS). I believe this is a great fit for fp16 quantization. I have a tuning and benchmark script at https://github.com/masahi/torchscript-to-tvm/blob/master/detr/detr_test.py (should work with PT 1.7). I'm interested in both fp32 and fp16 performance on M1. I also have FasterRCNN, but it requires a good BLAS lib for reasonable performance (due to dynamic batch dense), so I don't recommend it on M1. MaskRCNN is even worse; it has dynamic batch dense/conv2d/conv2d_transpose. Another model that could be interesting is TF2 ssd mobilenet with combined NMS. Many people are interested in this variant of SSD and I have a model that cleanly exports to ONNX. Ask me if you want to try this out, I can give you the model and the script.
@AndrewZhaoLuo oh, by M1 do you mean its CPU or GPU (Metal)?
I think it's CPU. Here's the benchmarking + tuning script I used: https://github.com/AndrewZhaoLuo/TVM-Sandbox/blob/a3c4b6b2235afb1826b237af1136bbb9539c9ff9/fp16_pass/benchmark_m1_mac_fp16.py. The other models you have are interesting; I think the SSD model I used has combined NMS. At least, it returns variable-length tensors representing different numbers of objects detected.
I see, interesting to see that fp16 makes things faster on CPU. So you've tested only on LLVM? Does this work on other targets? Later I can test it on CUDA (tensorcore) and OpenCL (intel), and hopefully @Lunderberg for vulkan.
Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
cast_nodes_cache[{expr_node, wanted_dtype}] = result;
I reviewed the cache mechanism and I think I got the idea. Here is the example I went through.

Consider the op `A (out: fp32, want: fp16)`. After processing A's output, the cache will contain `(A, fp16): cast` and `(cast, fp32): A`.

Now consider the following op `B`:

Case 1. If `B` wants fp32, then like you mentioned before, we query `(cast, fp32)` and get `A`, so it becomes `A -> B`.
Case 2. If `B` wants fp16, then we query `(cast, fp16)`, which misses, so a new entry `(cast, fp16): cast` is created and returned, and it becomes `A -> cast -> B`.

This mechanism seems to work well, and the cache size should be reasonable as it only keeps pointers. Two possible improvements:

- Apparently, the cache entry `(cast, fp16): cast` in the example is not necessary. I think we can simply return `expr` when `expr_dtype == wanted_dtype`?
- The created `cast` ops may be useless, such as the one in case 1. Is it possible to create this op lazily? For example, when casting the output, we only create a cache entry but don't really create the node. Once the entry is queried by a following op for the first time, we create the cast node and update the cache.

Another direction I would actually recommend is removing the cache and letting this pass generate as many cast ops as it wants, then running the SimplifyExpr pass afterward to cancel back-to-back cast ops (ref: #8081). IIUC, this should generate the same IR as the current pass, so it doesn't hurt the final performance (please correct me if I missed something).
- Right now that is what functionally happens with this line: `Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);`. It still creates a cache entry though, so I reorganized it to be clearer and to not insert into the cache when `expr_dtype == wanted_dtype`.
- Hmm, I believe creating the op lazily will not have any benefit, since there aren't any useless casts (e.g. refer to point 1).

The idea of having another pass handle back-to-back casts is appealing, as the tool can be used in many other situations. The main concern I have is about correctness, e.g. does it handle weird edge cases well? I'll take a closer look at the existing PR and think a little more about this.

I do agree that this is a better direction to go, however, and will refactor the pass when a sufficiently correct cast-folding pass exists and is checked into main.
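For reference, a minimal sketch of that alternative flow; whether a particular cast pair is actually cancelled depends on `SimplifyExpr`'s own safety checks, so treat this as an illustration rather than a guarantee:

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 16), dtype="float16")
y = relay.cast(relay.cast(x, "float32"), "float16")  # back-to-back casts
mod = tvm.IRModule.from_expr(relay.Function([x], y))
mod = relay.transform.InferType()(mod)
mod = relay.transform.SimplifyExpr()(mod)  # ideally folds the fp16 -> fp32 -> fp16 round trip
print(mod)
```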
The Metal backend supports fp16, and as far as I know @elvin-n has run fp16 models with our Metal backend and collected some performance metrics. I think he'll add some information about it. As for M1, we haven't tried to run fp16 models on Metal on M1 yet. Theoretically it should work, but we should check it.
@anijain2305 PTAL. I believe I've addressed all the major points.
We have not tried to tune Metal on M1 yet. Have you tried with AutoTVM or AutoScheduler?
No, we didn't try it. I'll take a look at it.
I don't have other major comments. I think as the first PR of the AMP support, this PR is basically ready to be merged once the comments from @mbrookhart are addressed. The other improvements can be done in the follow-up PRs.
@mbrookhart PTAL. I'm going to push ADT support down to a future PR.
I'm happy with this; we have a couple of TODOs, but I think the core of it is in a great place.
@masahi @anijain2305 Any more comments? Otherwise I'll plan to merge later this afternoon.
Currently, I can run all the tests. On the CUDA side, it's failing a check that requires 16-bit floats to be used in pairs.
On the Vulkan side, it's something similar, with the validation checks failing an alignment rule.
I don't think either of these are reasons not to merge, and I've added the Vulkan errors to my todo list for the ongoing Vulkan work.
Thanks @AndrewZhaoLuo for the great work, and everyone for reviews!! I'll follow up with CUDA and OpenCL support.
Added some tracking issues for CUDA and Vulkan:
Hi, I tried this:

```python
def compile_model(mod, params, target, logfile, save_path):
    tvm.relay.backend.compile_engine.get().clear()
    mod = tvm.relay.transform.ToMixedPrecision(
        mixed_precision_type='float16')(mod)
    with tvm.autotvm.apply_history_best(logfile):
        with tvm.transform.PassContext(opt_level=3):
            lib = tvm.relay.build(mod, target=target, params=params)
    lib.export_library(save_path)  # save the compiled model; the path must end in .so or C++ won't recognize it
```

But I got the error:

Did I miss any key point of using this feature?
Hi, I am not sure whether it is a usage question or whether the code can be refined. I am using a quite new commit pulled from GitHub, and I built from source following the steps on the docs website. Only changes in the
Should I still go to the discuss forum for help?
Yes. Please go to the discuss forum. You can refer to this PR and tag relevant people in the post.
…pache#8069) * Initial skeleton for fp16 pass. initial green gray and red lists move fp16 conversion to own fodler second pass example split up files a bit more cool nodes bro initial transofmr pass * Working python version of fp16 pass. fix topi conv2d not casting kernel to output type working resnet, but conv2d topi intrinsics need work tests for resnet add more tests, extend coverage for converter update tests, ensure red ops convert back to fp32 clean up code a bit simplify fp16 output dtype examination fix pass update tests initial coloring * Rewrite python passes in C++ inspect arg fields add propagate colors pass" private -> public inheritance" rewrite draft full transformation in c++ remove prints fp16 pass the proper wrapping insert extra cast to pass type checking fix previously broken test by removing cast in wrong scenario remove old python_files * Extend support to things besides CallNodes. E.g. tuples and lets fp32 invalidate typing instead of cast adding basic tests skeleton code out Stash work -- casting based on checked types working let statements add more ops, handle functions more generally add multiply, fix broken case support TupleNodes properly, move hash function for datatypes into data_type.h" update simple let test with structural expectation cleanup p1 remove old file * Rewrite how and when casting is done by checking types directly. add support for GPT2, BERT add some more comments new single pass version formatting make a lot of things const references clean up tests more cleanup more comments final comment add newline * linting and formatting * add AST header * remove todo * lint errors2 * remove i386 incompatible features * Trigger CI again * set seed * lint * address animesh's initial comments * mutate attributes only if they were originally floats * initial comments from matthew * add comment on hashing strat * add missing ; * edge case when mutating attrs * Cody's easy to address comments * add test to show green-red casting works * remove np.random seed from each test * remove as many references to fp16 types in favor of generic mixed types * rename RED, GREEN, GRAY to MIXED_PRECISION_ALLOW, etc. * skeleton for supporting arbitrary mixed types * cool tests * Using MixedModeMutator * rename things ToMixedPrecision * rename passes to amp.cc * rename tests to match transform * clean up typos * rename even better to_mixed_precision * don't insert into cache when dtypes equal * new python interface for registering ops * cleaner registering ops * add fp64 structural test * clean up and comments * make copy of attributes * asf header * pylint * remove TODO which is solved * Apply nits from code review (comaniac) Co-authored-by: Cody Yu <comaniac0422@gmail.com> * change cast_node_cache --> cast_node_cache_ * add check for returned vals * better error msg * docstring for pass in python * fix default behavior to be proper * better error reporting via single flag * priority to 0 * address more nits * fix story telling slightly * restart * correct docstring * change class fields to have _ at end * add class docstring * add comment on accumulation dtype hack * ADT warnings * add todo * fix linter Co-authored-by: Cody Yu <comaniac0422@gmail.com>
Please do not ask questions in the PR directly. Weights have cast ops because they are parameters instead of constants. You have to bind parameters first, run ToMixedPrecision, and then run FoldConstant to remove the casts.
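A minimal sketch of that order of operations, assuming `mod` and `params` come from a Relay frontend (e.g. `relay.frontend.from_onnx`):

```python
from tvm import relay

# Bind the weights into the graph as constants first, then convert, then fold
# away the casts that the conversion inserted around the (now constant) weights.
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
mod = relay.transform.InferType()(mod)
mod = relay.transform.ToMixedPrecision(mixed_precision_type="float16")(mod)
mod = relay.transform.FoldConstant()(mod)
```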
This implements a pass to convert an fp32 relay graph into an fp16 version. The RFC is described here.
Changes:
Testing
Some unittests.
Models tested (onnx):
Image Classification
Object Detection
Embedding Models
Super resolution:
NLP:
Models tested (relay native):
By "tested" I mean I confirmed the pass did some transformation on the graph and that a forward pass could be run on CPU and roughly matched the fp32 output. I have nothing on performance metrics or other devices yet.
Future PRs (in order of priority)
Known issues
cc @mbrookhart , @csullivan please take a look and add relevant reviewers
Speedups (add as I go along)
BERT w/ input shape [1, 128] on M1 Mac (based on https://github.com/octoml/Apple-M1-BERT) and 10000 tuning trials:
FP32 version - Mean inference time (std dev): 107.82 ms (3.39 ms)
FP16 version - Mean inference time (std dev): 80.04 ms (6.19 ms)
~26% lower latency (a ~1.35x speedup)!
Yolov2 (https://github.com/onnx/models) w/ 10000 tuning trials on M1 Mac
FP32 version - Mean inference time (std dev): 112.21 ms (3.75 ms)
FP16 version - Mean inference time (std dev): 71.05 ms (4.04 ms)
~37% lower latency (a ~1.58x speedup)!
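For context, here is a rough sketch of how such latency numbers can be collected with TVM's time evaluator. The toy conv2d network, shapes, and `llvm` target below are untuned placeholders, not the setup used for the figures above.

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

def make_toy_module():
    # Tiny placeholder network: a single conv2d + relu.
    data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
    weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="float32")
    out = relay.nn.relu(relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1)))
    params = {"weight": np.random.uniform(-1, 1, (16, 3, 3, 3)).astype("float32")}
    return tvm.IRModule.from_expr(relay.Function([data, weight], out)), params

def mean_latency_ms(use_fp16):
    mod, params = make_toy_module()
    if use_fp16:
        mod = relay.transform.InferType()(mod)
        mod = relay.transform.ToMixedPrecision(mixed_precision_type="float16")(mod)
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)
    dev = tvm.cpu()
    m = graph_executor.GraphModule(lib["default"](dev))
    m.set_input("data", np.random.uniform(size=(1, 3, 224, 224)).astype("float32"))
    return m.module.time_evaluator("run", dev, number=10, repeat=3)().mean * 1e3

print("fp32: %.2f ms, fp16: %.2f ms" % (mean_latency_ms(False), mean_latency_ms(True)))
```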