-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix vectorize size calculation #2035
Conversation
This reverts commit 31e8d0a.
@@ -26125,6 +26125,160 @@ TEST_F(NVFuserTest, FusionTrivialInputForwarding_CUDA) { | |||
testValidate(fusion, cg_outputs2, {t0, t1}, {t0}, __LINE__, __FILE__); | |||
} | |||
|
|||
namespace { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should split this file. It is too long that GitHub no longer displays syntax highlights.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is anyone against with just naming test files like test_gpu1.cpp
, test_gpu2.cpp
, etc? We should continue have topic-specific test files like test_gpu_view.cpp
, but there are many that don't belong to one specific category. An easy way is to rename test_gpu.cpp
to test_gpu1.cpp
and don't add new tests there anymore. New tests should go into test_gpu2.cpp
until that file becomes too large, then test_gpu3.cpp
should be created. Not super neat, but I think it should be just enough.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM Thanks!
@@ -202,8 +190,16 @@ size_t expandVectorizationToContigMergedDomains( | |||
|
|||
// Merge the domains right of the break point | |||
const auto& ref_root = reference_tv->getMaybeRFactorDomain(); | |||
const int num_merged_domains = | |||
const int max_num_merged_domains = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm going to have to relook at how I reimplemented this. I think it should be safe, but will have to pull in the new tests and check obviously.
Relevant changes I made:
https://github.com/csarofeen/pytorch/pull/2039/files#diff-cb1f12c937ff579f33d244faf2152ecf9a4a28a8ca7203863bcf158de149f8eeR287
* Fix vectorize size calculation (#2035) * Allow non-root trivial reductions (#2037) * Allow non-root trivial reductions Fixes #2008 Co-authored-by: Christian Sarofeen <csarofeen@nvidia.com> * Test file cleanup (#2040) * Move test_gpu.cpp to test_gpu1.cpp * Split test_gpu1.cpp to test_gpu1.cpp, test_gpu2.cpp and test_gpu3.cpp. Each file should be up to 10K LoC. New tests should be added to test_gpu3.cpp until it gets 10K LoC. Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com> Co-authored-by: Christian Sarofeen <csarofeen@nvidia.com>
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/ Codegen changes include: * codegen improvement: i. allow non-root trivial reductions, allow empty/no-op fusion ii. fixes vectorization checks and size calculation iii. bank conflict handle improvement iv. enables transpose scheduler * misc: i. CI tests failure fixes ii. cpp tests file clean up iii. trivial forwarding supports added in codegen runtime iv. added factory methods support in codegen Commits that's in this PR from the devel branch: ``` 7117a7e patching nvfuser conv cudnn test numerics mismatch (#2048) 65af1a4 Inserting sync for redundant parallel types is already done at the (#2023) 6ac74d1 Fix sync map (#2047) f5bca33 Bank conflict checker improvements (#2032) d2ca7e3 Minor update on cp.async code generation. (#1901) d36cf61 Test file cleanup (#2040) 0b8e83f Allow non-root trivial reductions (#2037) a2dfe40 Fix vectorize size calculation (#2035) e040676 Use withPredicate to replace setPredicate to maintain Exprs immutable (#2025) 197221b removing ci workflow (#2034) 40e2703 Reduction rand like patch (#2031) bc77266 Add utility for checking bank conflict of shared memory (#2029) ddd1cf7 Add back FusionReductionWithTrivialReduction_CUDA (#2030) fbd97e5 Revert "Cleanup trivial reduction workarounds (#2006)" (#2024) bca20c1 Cleanup trivial reduction workarounds (#2006) e4b6585 Trivial forwarding (#1995) 1a0e355 Fix contiguity analysis of predicates to match updated contiguity. (#1991) a4effa6 Enable output allocation cache (#2010) 35440b7 Patching bn inference (#2016) 0f9f0b4 Add matmul benchmark (#2007) 45045cd Enable tests previously disabled due to an aliasing bug (#2005) 967aa77 Contiguous indexing for View operations (#1990) a43cb20 Make inlining even more modular (#2004) dc45835 Test util cleanup (#2003) 3ca21eb More strict validation (#2000) a7a7d57 Fix build problem (#1999) fc235b0 Just fixes comments (#1998) 482386c cleanup (#1997) 4cbe0db Improve divisible split detection (#1970) 42ccc52 Minor build fix. (#1996) fcf8c09 Cleanup of lower_utils.cpp: Isolate out GpuLower usage (#1989) 15f2f6d Move ConcretizedBroadcastDomains to shared_ptr in GpuLower. (#1988) 8f1c7f5 Minor cleanup lower_unroll.cpp (#1994) 1d9858c Minor cleanup (#1992) f262d9c Add support for uniform RNG (#1986) eb1dad1 Remove non-const functions, remove GpuLower instance on build, pass in ca_map. (#1987) 634820c Add support for some empty fusion (#1981) eabe8d8 Segment self mapping fusions (#1954) e96aacf Enable Transpose operation (#1882) 425dce2 Add a null scheduler that helps segmenting away no-op schedules (#1835) 306d4a6 Fix canScheduleCompileTime check of transpose scheduler (#1969) b1bd32c Minor fix (#1967) bd93578 Enable transpose scheduler (#1927) b7a206e Move scheduler vectorize utilities into their own file (#1959) d9420e4 View scheduling (#1928) c668e13 Upstream push ci fixes (#1965) c40202b Fix dump effective bandwidth (#1962) 93505bc WAR on index mapping when exact and permissive maps differ (#1960) 45e95fd Allow splitting inner-most ID to create virtual innermost ID in transpose scheduler (#1930) a3ecb33 Improve the comments at the beginning of index_compute.h (#1946) f7bc341 Remove unused variables (#1955) df3393a Some cleanup (#1957) 7d1d7c8 TVDomainGuard factory (#1953) 357ba22 Fill allocation with nan on tests (#1956) 8eafc54 Fix detection of unmappable root domains (#1952) 90a51f2 Some indexing cleanups, Add eye support (#1940) ddc01e4 Exclude unsupported data types (#1951) 992e17c test the groups the same order as they are merged (#1949) 208262b Move detection of self mapping IDs to IterDomainGraph from (#1941) ac4de38 Merge pull request #1945 from csarofeen/master_merge_0828 6310948 Add full, full_like, zeros, zeros_like, ones, ones_like (#1943) aab10bc Merge remote-tracking branch 'upstream/viable/strict' into HEAD 4c254c0 Fix arange when step is negative (#1942) 89330aa Tensor factories must set the output shape as its input (#1939) ``` RUN_TORCHBENCH: nvfuser Differential Revision: [D40869846](https://our.internmc.facebook.com/intern/diff/D40869846) Pull Request resolved: pytorch#87779 Approved by: https://github.com/davidberard98
Fixes #1843