
[Metaschedule] Auto tensorization for CPU / GPU dot product #11088

Merged (20 commits) on Apr 26, 2022

Conversation

@masahi (Member) commented Apr 21, 2022:

Building on #11075, this adds the MultiLevelTilingWithIntrin schedule rule and the RewriteTensorize postproc, which can be used for auto-tensorization with a single intrinsic, such as a CPU / GPU dot product. This is the simplest but still non-trivial use of auto-tensorization.

The diff looks large, but most of it is test boilerplate. The actual change to enable auto-tensorization is about 300 lines.

MultiLevelTilingWithIntrin can be used to auto-tensorize schedules with the following intrinsics (a sketch of what such an intrinsic description looks like follows the list). We should be able to deprecate the corresponding manual templates in AutoTVM, but a detailed perf analysis is yet to be done.

  • VNNI conv2d / dense
  • ARM NCHWc conv2d (with or without sdot) (cc @tkonolige)
  • DP4A for CUDA, SPIR-V integer dot product for Vulkan, and AMDGPU gfx10 sdot4 for ROCm
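To make the pattern-matching target concrete, here is a hypothetical minimal TIR description of a 4-element int8 dot product, written purely for illustration (the real VNNI / DP4A / sdot4 intrinsics are registered under python/tvm/tir/tensor_intrin/). Note that the reduction loop in a description is serial, which comes up later in this thread:

```python
from tvm.script import tir as T

# Illustrative sketch only: the kind of computation an intrin
# description declares. MultiLevelTilingWithIntrin tiles the workload
# so that an inner loop nest matches a pattern like this.
@T.prim_func
def dot_product_4x1_desc(
    A: T.Buffer((4,), "int8", offset_factor=1),
    B: T.Buffer((4,), "int8", offset_factor=1),
    C: T.Buffer((1,), "int32", offset_factor=1),
) -> None:
    with T.block("root"):
        T.reads(C[0], A[0:4], B[0:4])
        T.writes(C[0])
        for i in T.serial(4):
            with T.block("update"):
                vi = T.axis.reduce(4, i)
                C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")
```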

As a demonstration, I've added integration tests in tests/python/integration/test_meta_schedule_auto_tensorize.py, one of which is E2E auto-tensorization on quantized bert-base x {VNNI, DP4A}. The DP4A tests can also run on AMDGPU via the vulkan or rocm backends (@mei-ye @tmoreau89).
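For readers who want to try this, here is a minimal sketch of how the two new pieces fit together when building a search space. The intrinsic name "dot_16x4_vnni" and the ReuseType arguments are assumptions modeled on the tests, not a verbatim excerpt:

```python
from tvm.meta_schedule import postproc, schedule_rule

# Sketch: tile with a single intrinsic, then rewrite matched blocks.
sch_rules = [
    schedule_rule.MultiLevelTilingWithIntrin(
        "dot_16x4_vnni",          # assumed registered VNNI intrin name
        structure="SSRSRS",       # CPU tiling structure
        tile_binds=None,
        max_innermost_factor=64,
        vector_load_lens=None,
        reuse_read=None,
        reuse_write=schedule_rule.ReuseType(
            req="may", levels=[1, 2], scope="global"
        ),
    ),
]

postprocs = [
    postproc.DisallowDynamicLoop(),
    postproc.RewriteParallelVectorizeUnroll(),
    postproc.RewriteReductionBlock(),
    # New in this PR: turn blocks annotated by the rule above into
    # calls to the tensor intrinsic.
    postproc.RewriteTensorize(),
]
```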

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>

@junrushao1994 @vinx13 @comaniac @mbrookhart @spectrometerHBH @Hzfengsy @MasterJH5574 @jinhongyii

if (Optional<String> intrin_name =
        tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
  // DecomposeReduction copies the annotation to the init block too, so
  // only tensorize blocks whose name does not contain "init".
  std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
  if (block_name.find("init") == std::string::npos) {
@masahi (Member, Author):
DecomposeReduction, applied before this postproc, copies the meta_schedule_auto_tensorize attribute to the init block as well. So we need to make sure that we don't try to tensorize the init block even though it carries the meta_schedule_auto_tensorize annotation.

Member:

There is target-specific handling here; ideally we could make the init block behavior configurable in the meta schedule rule, but it is fine for now.

// After DecomposeReduction there is exactly one init block with a
// single loop; vectorize that loop.
ICHECK(child_blocks.size() == 1);
Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
ICHECK(init_loops.size() == 1);
sch->Vectorize(init_loops[0]);
@masahi (Member, Author):

Related to the above: since DecomposeReduction introduces a new loop that should be vectorized on CPU, for now I'm applying vectorization to the decomposed init loop here. This could also be done in RewriteReductionBlock.

Member:
Does postproc::RewriteParallelVectorizeUnroll work for this case?

@masahi (Member, Author) commented Apr 21, 2022:

I hoped it would, but it doesn't. Also, since parallelization etc. is supposed to be applied before DecomposeReduction, I don't think running RewriteParallelVectorizeUnroll after RewriteReductionBlock() is a good idea. So vectorization of the init loop has to be done manually somehow.

I'd prefer vectorizing the init loop right after we run DecomposeReduction during RewriteReductionBlock, since vectorization of the init loop should be done on CPU regardless of tensorization. cc @MasterJH5574

Contributor:

Interesting! What’s the order of post-processors being applied now? Perhaps we should reflect this order by adding this post-processor to tune.py

@staticmethod
def _postproc() -> List[Postproc]:
    from tvm.meta_schedule import postproc as M

    return [
        M.DisallowDynamicLoop(),
        M.RewriteCooperativeFetch(),
        M.RewriteUnboundBlock(),
        M.RewriteParallelVectorizeUnroll(),
        M.RewriteReductionBlock(),
        M.VerifyGPUCode(),
    ]

@masahi (Member, Author) commented Apr 22, 2022:

The issue in question is vectorization for CPU targets. I'm using the default postprocs:

def _postproc() -> List[Postproc]:
    from tvm.meta_schedule import postproc as M

    return [
        M.DisallowDynamicLoop(),
        M.RewriteParallelVectorizeUnroll(),
        M.RewriteReductionBlock(),
    ]

Since loop parallelization and vectorization check the "compact dataflow" constraint (via CheckSubtreeCompactDataflow(self, loop_sref)), they need to be applied before DecomposeReduction in RewriteReductionBlock(). So having RewriteParallelVectorizeUnroll before RewriteReductionBlock() in the default postprocs makes sense.

However, this is not sufficient to vectorize the init loop of a reduction block, since that loop is generated during RewriteReductionBlock(). I don't think we should run RewriteParallelVectorizeUnroll again after RewriteReductionBlock() (and it doesn't work anyway), so we need to manually vectorize the decomposed init loop in either RewriteReductionBlock or the new RewriteTensorize postproc I added (see the sketch below). I prefer the former.
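For concreteness, a sketch of the CPU postproc list with the new postproc appended, assuming the RewriteTensorize signature from this PR, where vectorize_init_loop handles the decomposed init loop:

```python
from typing import List

from tvm.meta_schedule import postproc as M
from tvm.meta_schedule.postproc import Postproc


def _postproc() -> List[Postproc]:
    # Ordering matters: RewriteParallelVectorizeUnroll must run before
    # RewriteReductionBlock (compact-dataflow check); RewriteTensorize
    # runs last and also vectorizes the decomposed init loop on CPU.
    return [
        M.DisallowDynamicLoop(),
        M.RewriteParallelVectorizeUnroll(),
        M.RewriteReductionBlock(),
        M.RewriteTensorize(vectorize_init_loop=True),
    ]
```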

@masahi (Member, Author):

In this case I want to tensorize the reduction block. So before DecomposeReduction is called, the loop kind of the reduction is serial, which makes the decomposed init loop serial as well.

Contributor:

I see. So the schedule rule ParallelVectorizeUnroll wasn't applied to the block we want to tensorize either 🤔?

@masahi (Member, Author) commented Apr 23, 2022:

Ah yes (otherwise tensorize pattern matching fails, because an intrin desc is always serial). I'm not exactly sure what prevents ParallelVectorizeUnroll from tampering with the block we want to tensorize (which is a good thing); maybe the Blockize I do at

    tir::BlockRV outer_block = sch->Blockize(tiled_loop_rv.value());

(after tiling the inner loop nests to be tensorized) is helping?
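As a toy illustration of that effect (a hypothetical 16x4 int8 workload, not code from this PR): blockizing the tiled inner loops hides them behind a new outer block, so rules that visit the outer block no longer rewrite the inner loop nest:

```python
import tvm
from tvm import te, tir

# Hypothetical workload: 16 independent 4-element int8 dot products.
A = te.placeholder((16, 4), name="A", dtype="int8")
B = te.placeholder((16, 4), name="B", dtype="int8")
k = te.reduce_axis((0, 4), name="k")
C = te.compute(
    (16,),
    lambda i: te.sum(A[i, k].astype("int32") * B[i, k].astype("int32"), axis=k),
    name="C",
)

sch = tir.Schedule(te.create_prim_func([A, B, C]))
i, _ = sch.get_loops(sch.get_block("C"))
io, ii = sch.split(i, [4, 4])
# Wrap the inner loop nest (ii, k) in a new outer block; the inner
# loops are now isolated from rules applied at the outer level.
outer_block = sch.blockize(ii)
print(sch.mod.script())
```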

@MasterJH5574 (Contributor) commented Apr 23, 2022:

Quite interesting. So the case here is: on one hand we don't want the block annotated by rule ParallelVectorizeUnroll, but on the other hand we do want its init block to be vectorized after the decomposition. Am I right?

Since the block wasn't annotated by ParallelVectorizeUnroll before decomposition, the decomposed init block isn't vectorized, which makes sense. In addition, the decomposed init block doesn't carry any information indicating that it's supposed to be vectorized (e.g., it doesn't have a "needs vectorization" annotation). In this case, whether we vectorize the init block loop in RewriteReductionBlock or in RewriteTensorize, it's all down to our human knowledge, which I don't think is perfect.

For upstreaming, it might be okay to do manual vectorization in RewriteTensorize (how does the vectorization in RewriteTensorize bypass the compact dataflow issue, BTW?). But in the long term I suppose we should enhance the compact dataflow check to allow such vectorization. After all, such vectorization won't incur any incorrectness.

cc @junrushao1994 @spectrometerHBH

@masahi (Member, Author) commented Apr 23, 2022:

> Quite interesting. So the case here is: on one hand we don't want the block annotated by rule ParallelVectorizeUnroll, but on the other hand we do want its init block to be vectorized after the decomposition. Am I right?

Exactly.

> how does the vectorization in RewriteTensorize bypass the compact dataflow issue, BTW?

That's a great question! Until recently, vectorization of the init loop after DecomposeReduction was rejected by the compact dataflow check. I brought this topic to @Hzfengsy, and the team came up with a relaxation of the constraint that allows vectorizing the init loop; this is PR #10705.

Yeah, ideally all outer loop parallelization and inner loop vectorization could be done by one pass of ParallelVectorizeUnroll, meaning we run it after DecomposeReduction. Currently, outer loop parallelization after DecomposeReduction would be rejected by the compact dataflow check, but I think this is still too restrictive.

@junrushao (Member):

I'm super excited to see this PR!! Would love to have some helping hands review this PR :-) CC: @vinx13 @spectrometerHBH

@masahi (Member, Author) commented Apr 22, 2022:

Some perf numbers on int8 bert-base:

VNNI, Rocket Lake, 6 cores

 ID |                                                  Name |       FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated
----------------------------------------------------------------------------------------------------------------------------------------------------------------
  0 |   fused_nn_batch_matmul_multiply_expand_dims_subtract |  228266496 |     12 |      2570.6195 |      88.7982 |             1065.5789 |    256 |
  1 | fused_nn_batch_matmul_multiply_expand_dims_subtract_1 |  226788096 |     12 |      2354.2875 |      96.3298 |             1155.9579 |    256 |
  2 |                  fused_nn_contrib_dense_pack_subtract |  453279744 |     48 |      2630.4608 |     172.3195 |             8271.3371 |    256 |          Y
  3 |                fused_nn_contrib_dense_pack_subtract_1 | 1813118976 |     12 |      2773.5020 |     653.7291 |             7844.7493 |    256 |          Y
  4 |                fused_nn_contrib_dense_pack_subtract_2 | 1812234240 |     12 |      2775.5088 |     652.9377 |             7835.2520 |    256 |          Y
----------------------------------------------------------------------------------------------------------------------------------------------------------------

RTX 3070 with DP4A (FP32 peak around 16 TFLOPS)

 ID |                    Name |       FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated
----------------------------------------------------------------------------------------------------------------------------------
  0 |   fused_nn_batch_matmul |  226492416 |     12 |     10978.4995 |      20.6305 |              247.5665 |    512 |          Y
  1 | fused_nn_batch_matmul_1 |  226492416 |     12 |     14038.9348 |      16.1332 |              193.5979 |    512 |
  2 |          fused_nn_dense |  452984832 |     48 |     17875.0444 |      25.3417 |             1216.4038 |    512 |          Y
  3 |        fused_nn_dense_1 | 1811939328 |     12 |     25448.8947 |      71.1991 |              854.3896 |    512 |          Y
  4 |        fused_nn_dense_2 | 1811939328 |     12 |     21945.3012 |      82.5662 |              990.7940 |    512 |          Y
----------------------------------------------------------------------------------------------------------------------------------

AMDGPU RX6600xt with DP4A (FP32 peak around 10 TFLOPS)

 ID |                    Name |       FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated
----------------------------------------------------------------------------------------------------------------------------------
  0 |   fused_nn_batch_matmul |  226492416 |     12 |     10589.5889 |      21.3882 |              256.6586 |    512 |
  1 | fused_nn_batch_matmul_1 |  226492416 |     12 |      9998.6694 |      22.6523 |              271.8271 |    512 |          Y
  2 |          fused_nn_dense |  452984832 |     48 |     13374.8473 |      33.8684 |             1625.6837 |    512 |          Y
  3 |        fused_nn_dense_1 | 1811939328 |     12 |     13873.1209 |     130.6079 |             1567.2949 |    512 |          Y
  4 |        fused_nn_dense_2 | 1811939328 |     12 |     17295.8264 |     104.7617 |             1257.1398 |    512 |          Y
----------------------------------------------------------------------------------------------------------------------------------

@MasterJH5574 (Contributor) left a comment:

Thanks for the efforts! Excited to see auto-tensorization happening!

@MasterJH5574 (Contributor) left a comment:

Should we update the list of post-processors here as well?

class Postproc : public runtime::ObjectRef {

@masahi masahi force-pushed the auto-tensorize-dot branch 3 times, most recently from 6d6c3b4 to e104593 Compare April 22, 2022 21:13
@vinx13 vinx13 merged commit 6846484 into apache:main Apr 26, 2022
shtinsa pushed a commit to Deelvin/tvm that referenced this pull request May 17, 2022
[Metaschedule] Auto tensorization for CPU / GPU dot product (#11088)

* [Metaschedule] Auto-tensorization for CPU / GPU dot product

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>

* doc update

* add vnni conv2d test

* add dp4a test

* adding tests for rewrite_tensorize

* add rewrite_tensorize test

* add missing pydoc

* black

* more doc

* adding auto tensorize integration test

* add dp4a test

* fix target name

* fix dtype in test

* skip bert test

* replace hard-coded llvm intrinsic id in test with look up

* remove unnecessary include, add doc for the rest of params

* update postproc.h

* update doc

* fix shape in te matmul workload

* fix newline in cppdoc

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>