Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TensorCore Support using Intrinsic #4136

Merged
merged 12 commits into from
Oct 24, 2019
Merged

TensorCore Support using Intrinsic #4136

merged 12 commits into from
Oct 24, 2019

Conversation

Hzfengsy
Copy link
Member

This is the code for the RFC #4052 with a TensorCore tutorial
Any comments and reviews are welcomed.

@Hzfengsy
Copy link
Member Author

@minminsun Can you help to review this?

@Laurawly
Copy link
Contributor

Can we also add int support in this pr? Also is there any performance updates compared with the ones mentioned in RFC #4052?

@Hzfengsy
Copy link
Member Author

@Laurawly Thank you for your comments. It is my negligence forgetting to support int type. But one thing worths noting is that int type TensorCores are only supported on Turing GPUs. I can't access any Turing GPUs. It should work but need more testing.

As for the performance. Currently, we are faster than CUBLAS/CUDNN on small shapes but slower on large shapes. In fact, the performance depends on the scheduling but not the tvm implement. I believe that I have the same performance as Minmin's. Even the CUDA code is the same. I have to admit there should be a better schedule on specific shapes. Any better schedule is welcomed. However, I believe that it should be done by autotvm rather than human beings.

@minminsun
Copy link
Contributor

@minminsun Can you help to review this?

Sure, and we are refactoring the auto tc codegen code to reuse these intrinsics.

@Laurawly
Copy link
Contributor

Laurawly commented Oct 16, 2019

@Laurawly Thank you for your comments. It is my negligence forgetting to support int type. But one thing worths noting is that int type TensorCores are only supported on Turing GPUs. I can't access any Turing GPUs. It should work but need more testing.

As for the performance. Currently, we are faster than CUBLAS/CUDNN on small shapes but slower on large shapes. In fact, the performance depends on the scheduling but not the tvm implement. I believe that I have the same performance as Minmin's. Even the CUDA code is the same. I have to admit there should be a better schedule on specific shapes. Any better schedule is welcomed. However, I believe that it should be done by autotvm rather than human beings.

Volta GPUs have int support as well. I tested int8 GEMM using wmma instructions on a V100 GPU. It has even better support than Turing architecture. You can also find some info here: https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9926-tensor-core-performance-the-ultimate-guide.pdf.

Performance-wise if you believe auto tvm could give us better results than cublas, could you expose the autotvm interface for people to do auto search?

namespace tvm {
namespace ir {

class FragmentGetter : public IRVisitor {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add doxygen-style comments to the classes, methods, and members defined in this file?

echo "- python3 tests/lint/add_asf_header.py file_list.txt"
exit 1
fi
#echo "Check ASF license header..."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uncomment?

# :code:`mma_sync` and :code:`store_matrix`. Since :code:`fill_fragment` and :code:`mma_sync`
# are both used in matrix multiplication, so we can just write following three intrinsics.

def intrin_wmma_load_matrix(scope):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These intrinsics look copied from the test - is there a place we can put them so they're easily accessible in tutorials, tests, and at runtime?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think they are general enough. They are the same in tutorials and tests only because I use the same memory layout. There may be so many parameters in these intrinsics if I make them general. The best way I think is that users build their intrinsic in different situations.

@Hzfengsy
Copy link
Member Author

Volta GPUs have int support as well. I tested int8 GEMM using wmma instructions on a V100 GPU. It has even better support than Turing architecture. You can also find some info here: https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9926-tensor-core-performance-the-ultimate-guide.pdf.

I still think Volta GPUs have no such support. This slide from your links may say the same thing. Can you doublecheck it? Or can you share your code and show me how to use it on Volta GPU? Thanks a lot!
image

Performance-wise if you believe auto tvm could give us better results than cublas, could you expose the autotvm interface for people to do auto search?

The autotvm interface is always here. Users can use autotvm to tune TensorCore schedules as they wise. I guess your opinion is adding an auto turning tutorial?

@Laurawly
Copy link
Contributor

Volta GPUs have int support as well. I tested int8 GEMM using wmma instructions on a V100 GPU. It has even better support than Turing architecture. You can also find some info here: https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9926-tensor-core-performance-the-ultimate-guide.pdf.

I still think Volta GPUs have no such support. This slide from your links may say the same thing. Can you doublecheck it? Or can you share your code and show me how to use it on Volta GPU? Thanks a lot!
image

Performance-wise if you believe auto tvm could give us better results than cublas, could you expose the autotvm interface for people to do auto search?

The autotvm interface is always here. Users can use autotvm to tune TensorCore schedules as they wise. I guess your opinion is adding an auto turning tutorial?

Yes, I misread the document and when I did nvprof on my code on V100 it seems that it didn't go though the int8 wmma kernel but it did on Turing GPU.

If users run your test/tutorial, can they reproduce the performance you mentioned? If so, then I think it would be helpful to have an autotvm schedule as well for users to search performance for larger shapes, but it's ok to put them in another PR.

@Hzfengsy
Copy link
Member Author

If users run your test/tutorial, can they reproduce the performance you mentioned? If so, then I think it would be helpful to have an autotvm schedule as well for users to search performance for larger shapes, but it's ok to put them in another PR.

Yes, they can easily reproduce the performance. I agree with you that we should have an autotvm tutorial. @minminsun is working on it. It will be better to have a tutorial after both of our solutions are merged so that we can choose a better way to use autotvm.

@Laurawly
Copy link
Contributor

@were @merrymercy Can you help to review this PR as well?

@Hzfengsy Hzfengsy force-pushed the master branch 2 times, most recently from 02cb8a5 to 198e786 Compare October 20, 2019 22:45
@Laurawly Laurawly added the status: need update need update based on feedbacks label Oct 21, 2019
n = 16
A = tvm.placeholder((n, n), name='A', dtype='float16')
BA = tvm.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256)
C = tvm.compute((n, n), lambda i, j: A[i, j], name='C')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Besides the 16 x 16 x 16 tensor core combination, have you also tested the 32 x 8 x 16 and 8 x 32 x 16 combinations for m, n, k?

@Laurawly
Copy link
Contributor

@soiferj @minminsun @vinx13 Can you update on your reviews?

@Hzfengsy
Copy link
Member Author

@Laurawly @minminsun @vinx13 @soiferj Can please you have another look at it?

@Laurawly
Copy link
Contributor

Overall LGTM, let's wait for other reviewers' updates as well.

Copy link
Contributor

@minminsun minminsun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good to me too.

@Laurawly Laurawly merged commit 324a960 into apache:master Oct 24, 2019
@Laurawly
Copy link
Contributor

Thanks @Hzfengsy @vinx13 @minminsun @soiferj

kevinthesun pushed a commit to kevinthesun/tvm that referenced this pull request Oct 30, 2019
* add tensor core support

* avoid memory bank conflict

* fix thread sync & better performance

* better performance

* add schedule test for conv2d

* extend into BatchMatMul

* support config fragment shape and layout using intrinsic

* add TensorCore tutorial

* add int support and fix lint

* address comment

* add 32*16*8 TensorCore test

* fix wmma include logic
kevinthesun added a commit to neo-ai/tvm that referenced this pull request Oct 31, 2019
* [relay][vm] Separate VM runtime with executable (apache#4100)

* [relay][vm] Separate VM runtime with executable

* Address comments

* move ctx back to vm

* make only vm related fields and methods protected

* integrate seriliaztion/deserialization to executable

* create stream

* [Relay][Frontend][TF] Add tensor array ops (apache#3798)

* [Relay][Frontend][TF] Add tensor array ops

* rename

* delete test

* Move utility function

* Refactor

* fix tensor array ops

* fix test

* fix rebase

* Fix serializer bug

* Improve tf convert name lookup to use prelude api

* Fix lint

* Fix test

* Fix typo (apache#4144)

* [CI] Pin NNPack pthreadtools version (apache#4152)

* [QNN][TFLite] Parsing QNN Add op. Adding MobilenetV2. (apache#4142)

* Add lift_if_then_else pass (apache#3865)

* Add LiftIfThenElse pass

* Add more comments

* Rename and refactor

* Add description for internal data structure

* Rename a test

* Minor change

* Address comments

* Improve update_for

* [CI] Update cpu docker (apache#4153)

* [Refactor] Rename Datatype to ADT (apache#4156)

We think it will reduce the confusion with the meaning.

https://discuss.tvm.ai/t/discuss-consider-rename-vm-datatype/4339

* [Runtime] Enable option to use OpenMP thread pool (apache#4089)

* [REFACTOR][NODE][RUNTIME] Move Node to the new Object protocol. (apache#4161)

* [REFACTOR][NODE][RUNTIME] Move Node to the new Object protocol.

This PR removes the original node system, and make node as a subclass of Object.
This is a major refactor towards a better unified runtime object system.

List of changes in the refactor:

- We now hide data_ field, use Downcast explicitly to get a sub-class object.
- Removed the node system FFI in python.
- Removed the node C API, instead use PackedFunc for list and get attrs.
- Change relay::Op::set_attr_type_key(attr_key_name) to relay::Op::set_attr_type<AttrType>().
  - This change was necessary because of the new Object registration mechanism.
  - Subsequent changes to the op registrations
  - The change revealed a few previous problems that is now fixed.
- Patched up a few missing node type registration.
  - Now we will raise an error if we register object that is not registered.
- The original node.h and container.h are kept in the same location.
- Calling convention: kObjectHandle now equals the old kNodeHandle, kNodeHandle is removed.
- IRFunctor now dispatches on ObjectRef.
- Update to the new type checking API: is_type, derived_from are replaced by IsInstance.
- Removed .hash member function, instead use C++ convention hasher functors.

* Address review comments

* [CI] Move golang tests to the end (apache#4164)

* Add support for quantized multiply to Relay (apache#4141)

This patch adds multiply operator for quantized tensors.
The details of the quantized multiplication are outlined
in the code.

This builds on pull request 3927 and includes the changes
Animesh mentions in the comments on that request.

Change-Id: I555715b53d0266a91d5c03dc3dfe8fc31e7ce4e1

* Fix missspelling (apache#4166)

FIX "After connecting he usb" with "After connecting the usb"

* [Relay][Pass] Count MAC for BatchMatMul (apache#4157)

* count MAC for BatchMatMul

* update doc

* [Relay][QNN] Add unit test for int8 (apache#4159)

* [bugfix][codegen] fix casting bug in llvm codegen

* update example

* retrigger ci

* check llvm version

* [relay][vm] Reuse allocated device memory (apache#4170)

* add missing gradient check to gradient pass (apache#4169)

* merge extract_from_program and extract_from_multiple_progam (apache#4173)

* [TOPI] Added support for Mali Bifrost target (apache#4047)

* [Relay][Frontend][TF] Fix Size operator (apache#4175)

* [Relay][Frontend][TF] Fix Size operator

* Uncomment tests

* [Pass] Remove dead code (apache#4177)

* [rpc] use callback func to do send & recv (apache#4147)

* [rpc] use callback func to do send & recv. don't get fd from sock as it is deprecated in java

* fix java build

* fix min/max macro define in windows

* keep the old rpc setup for py

* add doc for CallbackChannel

* Add support and testing for tf.assert (as no-op) and tf.no_op to TF Relay frontend. (apache#4172)

* [DOCS] Add TensorFlow frontend docs (apache#4154)

* Start to update TF frontend docs

* Add rst

* Remove markdown

* Update wording

* Resolve comments

* Revert "[Relay][QNN] Add unit test for int8 (apache#4159)" (apache#4192)

This reverts commit 6f9d028.

* [cmake][ANTLR] Support setting path to ANTLR jar (apache#4176)

* Support setting path to ANTLR jar

* Update comment

* Split adaptive_pool2d_avg into sum and div (apache#4186)

* [Documentation]Fix example code in comment of tvm.build_module.build() (apache#4195)

* Fix example code in comment of tvm.build_module.build()

* Update build_module.py

* [relay] use time_evaluator for measurement (apache#4191)

* Add parser support for SUM tflite operator (apache#4182)

* [Relay] Fix memory leak in the interpreter (apache#4155)

* save

lint

* address reviewer comment

* [TOPI] Tunable Template for Conv2D HWCN on CUDA (apache#4168)

* support conv2d HWCN in AutoTVM and Relay

* fix lint

* fix comments and unit tests

* TensorCore Support using Intrinsic (apache#4136)

* add tensor core support

* avoid memory bank conflict

* fix thread sync & better performance

* better performance

* add schedule test for conv2d

* extend into BatchMatMul

* support config fragment shape and layout using intrinsic

* add TensorCore tutorial

* add int support and fix lint

* address comment

* add 32*16*8 TensorCore test

* fix wmma include logic

* [NODE][REFACTOR] Refactor reflection system in node. (apache#4189)

* [NODE][REFACTOR] Refactor reflection system in node.

- Removed the old Node, Node is now just an alias of runtime::Object
- Introduce ReflectionVTable, a new columnar dispatcher to support reflection
  - This allows us to remove vtable from most node objects
  - The VisitAttrs are registered via TVM_RESGITER_NODE_TYPE,
    they are no longer virtual.
- Consolidated serialization and reflection features into node.

* Explicit type qualification when calling destructor.

* Fix SPIRV, more comments

* hotfix the ci (apache#4199)

* [TOPI][x86] Legalize - Support int8xint8 convolution to use VNNI instructions. (apache#4196)

* [Relay] crossentropy_with_logits and its gradient (apache#4075)

* save

* lint

* [hotfix] missing include headers (apache#4204)

* [Relay][Training] Add checkpoint annotation for checkpointing memory optimization (apache#4146)

* add checkpoint annotation for checkpointing memory optimization

* add alpha-equivalence checkpoint test and fix gradient type issue

* fix build issues

* ignore checkpoint annotation when checking missing gradients

* refactor, fix checkpoint compute for tuple and add tests

* [Relay][Params] Add APIs for storing and retrieving parameters from individual functions. (apache#4194)

* Add support for attaching params

* Fix types

* Fix test

* [Relay][Frontend][ONNX] Add support for op Where (apache#4184)

* Add support for op Where

* Update impl version

* [VTA][Chisel] TSIM VTA Source Refactor (apache#4163)

* app init push

* fix on readme

* change name, add bit serial explanantion

* rm serialLoadMM, change doc

* syntax change for readme

* add parallel test functionality

* fix readme

* add python doc

* syntax

* init commit

* fix empty line

* fix typo

* [RUNTIME] Separate runtime related contrib into runtime/contrib (apache#4207)

* Fix type var docs (apache#4208)

* [Relay] Setting Legalize opt_level to 1. (apache#4198)

* [TOPI] Fix flaky testcase for check round (apache#4211)

* [Relay][Op] Enhance Upsample Operator to support float scales   (apache#4206)

* :add scale2 for upsample

* update unit test for upsampling

* support latest upsample op for multiple frontend

* fix lint

* fix lint

* fix lint

* fix lint

* update scale description and rebase

* [Relay][Quantize] Use fixed point mulplications (apache#4160)

* Update have_int8 condition to run on compute capability 7.x devices (apache#4214)

* Optimizing autotvm task extraction speed (apache#4138)

* Optimize task extraction speed

* correct pylint errors

* Delete unused function

* remove unnecessary argument

* resolve code review comments

* corrent cpp lint errors

* remove one more graph_json return value

* fix test bugs

* [Relay] Add Python type functor and tests (apache#4209)

* Add Python type functor and tests

* Lint roller

* Fix typo in packed_func.h (apache#4219)

* Improve the lowering of Qnn Dense (apache#4213)

* [QNN] Improving Dense lowering.

* - Moving get_shape method to util
- Finalizing the test cases and the code structure for optimized dense computation.

* - Fixing cpplint.

* - Addressing review comments.

* - Renaming the variables correctly.

* - Renaming the variables correctly.

* [ARITH] Fix the rule y < x && x <= y (apache#4220)

* [PYTHON] Add __init__ to the generated grammar so that it can be installed properly (apache#4223)

* [Relay][Frontend][ONNX] New Operators and Opsets to Support BERT (apache#4197)

* Added slice v10

* Added constantofshape operation and small refactor.

* Finished one_hot implementation.

* Reshape working across all bert layers.

* Fixed constantofshape and removed code duplication.

* onnx model fully ingested.

* Working on improving onnx tests.

* Changed onnx testing to use onnxruntime instead of caffe2, also formatted.

* Add arbitrary output nodes to onnx frontend.

* Added v6 tiling for bert squad 8 support.

* Small syntax fixes

* Reduced code duplication in split opset versions.

* Added batch matmul test

* Added unstack split testing.

* Adde onehot test, needs a little cleanup probably.

* Replaced deprecated constant fill with constantofshape and updated tests accordingly.

* Added tests for new opset version of slice and tile.

* lint clean up

* Lint fixes

* Changed onnx dependency

* Went back to caffe2 runtime for CI integration.

* Rebase and small typo/syntax changes.

* Added hard casting of onehot attributes to int.

* [Relay][Topi][TensorFlow][ONNX][Lang] Add support for Any op (apache#4205)

* Add support for Any op

* Support ONNX frontend

* Add doc

* Add to relay docs

* Dummy change to retrigger CI

*  Update dmlc_tvm_commit_id.txt

* Merge from upstream
tqchen pushed a commit to tqchen/tvm that referenced this pull request Mar 29, 2020
* add tensor core support

* avoid memory bank conflict

* fix thread sync & better performance

* better performance

* add schedule test for conv2d

* extend into BatchMatMul

* support config fragment shape and layout using intrinsic

* add TensorCore tutorial

* add int support and fix lint

* address comment

* add 32*16*8 TensorCore test

* fix wmma include logic
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
status: need review status: need update need update based on feedbacks
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants