Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable tests previously disabled due to an aliasing bug #2005

Merged
merged 2 commits into from
Sep 28, 2022

Conversation

naoyam
Copy link
Collaborator

@naoyam naoyam commented Sep 28, 2022

The bug was fixed by #1792

Both tests just require about 33 KB of SMEM, not 70 KB anymore, so they should be fine on most devices. Still making sure not exceeding the available capacity.

Comment on lines +9713 to +9714
const size_t required_smem_size =
(dimy - static_size) * sizeof(float) + TIDX * sizeof(float);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessary in this PR, but it would be good to use the inferred shared memory size from FusionExecutor::computeSharedMemory, instead of manually calculating it here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's part of the check actually. We want to make sure that the aliasing is accurately analyzed and only this size is required. If aliasing analysis fails somehow, FusionExecutor::computeSharedMemory could just return a larger size, and the test would be just skipped, but in that case we would want the test to fail. So in this case checking with this required_smem_size is more appropriate.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, makes sense

auto properties = at::cuda::getDeviceProperties(0);
const size_t required_smem_size =
(dimy - static_size) * sizeof(float) + TIDX * sizeof(float);
if (properties->sharedMemPerBlock < required_smem_size) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question, what is the difference between sharedMemPerBlock and sharedMemPerBlockOptin? I didn't find the official documentation helpful:

size_t cudaDeviceProp::sharedMemPerBlock [inherited]
  Shared memory available per block in bytes

size_t cudaDeviceProp::sharedMemPerBlockOptin [inherited]
  Per device maximum shared memory per block usable by special opt in

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't check thoroughly, but I assume the former is the default size, whereas the latter is the maximum size that can be configured. Since we don't do any smem size configuration, I think the former is the right limit.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it related to this part of the code?

// Check that requested smem size can be dynamically allocated.
// This check is only done once a kernel has been compiled, since
// maybe_available_dynamic_smem_ needs to be evaluated on
// a compiled kernel.
if (maybe_available_dynamic_smem_.has_value()) {
// Dynamic shared memory space that we can allocate without
// carving more space from L1.
const uint64_t available_dynamic_smem_without_reconfiguration =
maybe_available_dynamic_smem_.value();
// Maximum additional shared memory size we could request
// if we do re-configuration.
const uint64_t additional_dynamic_smem_available_through_reconfiguration =
device_smem_limit_ - configured_device_smem_;
TORCH_INTERNAL_ASSERT(
(dynamic_smem_size) <
(available_dynamic_smem_without_reconfiguration +
additional_dynamic_smem_available_through_reconfiguration),
"The total shared memory allocation is larger than available memory.",
" Dynamic size: ",
dynamic_smem_size,
". Available size: ",
maybe_available_dynamic_smem_.value(),
". Configured smem size: ",
configured_device_smem_,
". Device limit size: ",
device_smem_limit_);
}
launch_params.setSmem(dynamic_smem_size);

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I didn't know we actually do configure the size when necessary.

// Increase limit of dynamic shared memory if needed.

I'll change the check to use Optin.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, those values on Titan RTX are:

sharedMemPerBlock: 49152 
sharedMemPerBlockOptin: 65536

@naoyam naoyam merged commit 45045cd into devel Sep 28, 2022
@naoyam naoyam deleted the reenable_disabled_tests branch September 28, 2022 23:58
jjsjann123 added a commit that referenced this pull request Nov 9, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Codegen changes include:

* codegen improvement:
    i. allow non-root trivial reductions, allow empty/no-op fusion
    ii. fixes vectorization checks and size calculation
    iii. bank conflict handle improvement
    iv. enables transpose scheduler

* misc:
    i. CI tests failure fixes
    ii. cpp tests file clean up
    iii. trivial forwarding supports added in codegen runtime
    iv. added factory methods support in codegen

Commits that's in this PR from the devel branch:

```
7117a7e patching nvfuser conv cudnn test numerics mismatch (#2048)
65af1a4 Inserting sync for redundant parallel types is already done at the (#2023)
6ac74d1 Fix sync map (#2047)
f5bca33 Bank conflict checker improvements (#2032)
d2ca7e3 Minor update on cp.async code generation. (#1901)
d36cf61 Test file cleanup (#2040)
0b8e83f Allow non-root trivial reductions (#2037)
a2dfe40 Fix vectorize size calculation (#2035)
e040676 Use withPredicate to replace setPredicate to maintain Exprs immutable (#2025)
197221b removing ci workflow (#2034)
40e2703 Reduction rand like patch (#2031)
bc77266 Add utility for checking bank conflict of shared memory (#2029)
ddd1cf7 Add back FusionReductionWithTrivialReduction_CUDA (#2030)
fbd97e5 Revert "Cleanup trivial reduction workarounds (#2006)" (#2024)
bca20c1 Cleanup trivial reduction workarounds (#2006)
e4b6585 Trivial forwarding (#1995)
1a0e355 Fix contiguity analysis of predicates to match updated contiguity. (#1991)
a4effa6 Enable output allocation cache (#2010)
35440b7 Patching bn inference (#2016)
0f9f0b4 Add matmul benchmark (#2007)
45045cd Enable tests previously disabled due to an aliasing bug (#2005)
967aa77 Contiguous indexing for View operations (#1990)
a43cb20 Make inlining even more modular (#2004)
dc45835 Test util cleanup (#2003)
3ca21eb More strict validation (#2000)
a7a7d57 Fix build problem (#1999)
fc235b0 Just fixes comments (#1998)
482386c cleanup (#1997)
4cbe0db Improve divisible split detection (#1970)
42ccc52 Minor build fix. (#1996)
fcf8c09 Cleanup of lower_utils.cpp: Isolate out GpuLower usage (#1989)
15f2f6d Move ConcretizedBroadcastDomains to shared_ptr in GpuLower. (#1988)
8f1c7f5 Minor cleanup lower_unroll.cpp (#1994)
1d9858c Minor cleanup (#1992)
f262d9c Add support for uniform RNG (#1986)
eb1dad1 Remove non-const functions, remove GpuLower instance on build, pass in ca_map. (#1987)
634820c Add support for some empty fusion (#1981)
eabe8d8 Segment self mapping fusions (#1954)
e96aacf Enable Transpose operation (#1882)
425dce2 Add a null scheduler that helps segmenting away no-op schedules (#1835)
306d4a6 Fix canScheduleCompileTime check of transpose scheduler (#1969)
b1bd32c Minor fix (#1967)
bd93578 Enable transpose scheduler (#1927)
b7a206e Move scheduler vectorize utilities into their own file (#1959)
d9420e4 View scheduling (#1928)
c668e13 Upstream push ci fixes (#1965)
c40202b Fix dump effective bandwidth (#1962)
93505bc WAR on index mapping when exact and permissive maps differ (#1960)
45e95fd Allow splitting inner-most ID to create virtual innermost ID in transpose scheduler (#1930)
a3ecb33 Improve the comments at the beginning of index_compute.h (#1946)
f7bc341 Remove unused variables (#1955)
df3393a Some cleanup (#1957)
7d1d7c8 TVDomainGuard factory (#1953)
357ba22 Fill allocation with nan on tests (#1956)
8eafc54 Fix detection of unmappable root domains (#1952)
90a51f2 Some indexing cleanups, Add eye support (#1940)
ddc01e4 Exclude unsupported data types (#1951)
992e17c test the groups the same order as they are merged (#1949)
208262b Move detection of self mapping IDs to IterDomainGraph from (#1941)
ac4de38 Merge pull request #1945 from csarofeen/master_merge_0828
6310948 Add full, full_like, zeros, zeros_like, ones, ones_like (#1943)
aab10bc Merge remote-tracking branch 'upstream/viable/strict' into HEAD
4c254c0 Fix arange when step is negative (#1942)
89330aa Tensor factories must set the output shape as its input (#1939)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D40869846](https://our.internmc.facebook.com/intern/diff/D40869846)
Pull Request resolved: pytorch#87779
Approved by: https://github.com/davidberard98
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants