forked from pytorch/pytorch
Silent wrong result for RNG + transpose non-square tile #1926
Comments
cc: @csarofeen

Looks like T6 should not share the same predicate with other tensors.
Minimal repro:

```cpp
TEST_F(NVFuserTest, FusionPredicateUnshare_CUDA) {
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  auto fusion = fusion_ptr.get();
  FusionGuard fg(fusion);

  TensorView* tv0 = makeSymbolicTensor(2);
  fusion->addInput(tv0);
  auto tv1 = set(tv0);
  auto tv2 = set(tv1);
  fusion->addOutput(tv2);

  tv1->setMemoryType(MemoryType::Shared);
  for (auto tv : {tv1, tv2}) {
    tv->split(0, 4);
    tv->reorder({{1, -1}});
    tv->split(1, 8);
    tv->merge(0);
    tv->split(0, 1);
    tv->axis(0)->parallelize(ParallelType::BIDx);
    tv->axis(1)->parallelize(ParallelType::Unswitch);
  }
  tv1->merge(2);
  tv2->reorder({{2, 3}});
  tv2->merge(2);
  for (auto tv : {tv1, tv2}) {
    tv->axis(-1)->parallelize(ParallelType::TIDx);
  }
  InlinePropagator propagator(tv2, -1, ComputeAtMode::MostInlined);
  MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator);

  auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({5, 5}, options);

  FusionExecutor fe;
  fe.compileFusion(fusion, {t0});
  auto cg_outputs = fe.runFusion({t0});
  auto out = cg_outputs[0];

  testValidate(fusion, {out}, {t0}, {t0}, __LINE__, __FILE__);
}
```
The following is a quick hack to "fix" the issue:

```cpp
bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) {
  return false;
}
```

Will take a deeper look to find a better fix.
🐛 Describe the bug
Output:
Fusion:
CUDA:
Versions
TOT devel