
Enable Transpose operation #1882

Merged: rdspring1 merged 8 commits into devel from transpose_python on Sep 13, 2022
Conversation

@rdspring1 (Collaborator) commented Jul 31, 2022

Enable permute, transpose, and t in the Python frontend.

Requires: Transpose Scheduler
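
For context, here is a minimal TorchScript usage sketch of the ops this PR enables; it is an assumed illustration, not code from this PR, and it needs a CUDA build with the nvfuser profiling executor active for a fusion group to actually form.

```python
import torch

@torch.jit.script
def fused_ops(x: torch.Tensor) -> torch.Tensor:
    y = x.permute([0, 2, 1])   # aten::permute, now parsable by the CUDA fuser
    z = y.transpose(1, 2)      # aten::transpose
    return torch.relu(z)       # a pointwise op for the fuser to merge with

if torch.cuda.is_available():
    x = torch.randn(2, 3, 4, device="cuda")
    for _ in range(3):         # the profiling executor needs warm-up runs
        out = fused_ops(x)
```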

@zasdfgbnm mentioned this pull request on Aug 12, 2022

@csarofeen (Owner) left a comment

LGTM

@csarofeen (Owner)

@zasdfgbnm are we ready for this?

@rdspring1 (Collaborator, Author)

Note: This PR only enables transpose for TorchScript. We'll need another PR for primTorch.

@zasdfgbnm (Collaborator)

Seeing a segfault at test_nvfuser_correctness_permute_cuda_bool; looking into it.

gdb shows the following stack trace; I am not sure if it is related:

#0  0x00007fffc2c6ef91 in __cxxabiv1::__cxa_throw (obj=0x5556040f1d10, tinfo=0x7fffc167d0b8 <typeinfo for c10::Error>, 
    dest=0x7fffd7929030 <c10::Error::~Error()>) at /usr/src/debug/gcc/libstdc++-v3/libsupc++/eh_throw.cc:81
#1  0x00007fffd7975e9f in c10::detail::torchCheckFail (func=0x7fff94e0b15f "getDebugState", 
    file=0x7fff94e0a6a5 "/home/gaoxiang/nvfuser/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp", line=698, 
    msg=0x7fff94e0b16d "optimized_plan_ INTERNAL ASSERT FAILED at \"/home/gaoxiang/nvfuser/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp\":698, please report a bug to PyTorch. ") at /home/gaoxiang/nvfuser2/c10/util/Exception.cpp:94
#2  0x00007fffc007375b in c10::detail::torchInternalAssertFail (func=0x7fff94e0b15f "getDebugState", 
    file=0x7fff94e0a6a5 "/home/gaoxiang/nvfuser/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp", line=698, 
    condMsg=0x7fff94e0b16d "optimized_plan_ INTERNAL ASSERT FAILED at \"/home/gaoxiang/nvfuser/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp\":698, please report a bug to PyTorch. ") at /home/gaoxiang/nvfuser/c10/util/Exception.h:441
#3  0x00007fff8f8a0bcf in torch::jit::ProfilingGraphExecutorImpl::getDebugState (this=0x5556037e08f0)
    at /home/gaoxiang/nvfuser/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp:698
#4  0x00007fff8f8508c1 in torch::jit::GraphExecutor::getDebugState (this=0x55560071fbf0)
    at /home/gaoxiang/nvfuser/torch/csrc/jit/runtime/graph_executor.cpp:842
#5  0x00007fffc09eafc0 in torch::jit::initJitScriptBindings(_object*)::$_59::operator()(torch::jit::StrongFunctionPtr const&) const (this=0x55555905cfb8, self=...) at /home/gaoxiang/nvfuser/torch/csrc/jit/python/script_init.cpp:1500
#6  0x00007fffc09eaf61 in pybind11::detail::argument_loader<torch::jit::StrongFunctionPtr const&>::call_impl<torch::jit::GraphExecutorState, torch::jit::initJitScriptBindings(_object*)::$_59&, 0ul, pybind11::detail::void_type>(torch::jit::initJitScriptBindings(_object*)::$_59&, std::integer_sequence<unsigned long, 0ul>, pybind11::detail::void_type&&) && (this=0x7fffffff9968, 
    f=...) at /home/gaoxiang/nvfuser/cmake/../third_party/pybind11/include/pybind11/cast.h:1441
#7  0x00007fffc09eaed6 in pybind11::detail::argument_loader<torch::jit::StrongFunctionPtr const&>::call<torch::jit::GraphExecutorState, pybind11::detail::void_type, torch::jit::initJitScriptBindings(_object*)::$_59&>(torch::jit::initJitScriptBindings(_object*)::$_59&) && (this=0x7fffffff9968, f=...)
    at /home/gaoxiang/nvfuser/cmake/../third_party/pybind11/include/pybind11/cast.h:1409
#8  0x00007fffc09ead98 in pybind11::cpp_function::initialize<torch::jit::initJitScriptBindings(_object*)::$_59, torch::jit::GraphExecutorState, torch::jit::StrongFunctionPtr const&, pybind11::name, pybind11::is_method, pybind11::sibling>(torch::jit::initJitScriptBindings(_object*)::$_59&&, torch::jit::GraphExecutorState (*)(torch::jit::StrongFunctionPtr const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(pybind11::detail::function_call&)#1}::operator()(pybind11::detail::function_call&) const (this=0x7fffffff9fd0, call=...)
    at /home/gaoxiang/nvfuser/cmake/../third_party/pybind11/include/pybind11/pybind11.h:249
#9  0x00007fffc09eac92 in pybind11::cpp_function::initialize<torch::jit::initJitScriptBindings(_object*)::$_59, torch::jit::GraphExecutorState, torch::jit::StrongFunctionPtr const&, pybind11::name, pybind11::is_method, pybind11::sibling>(torch::jit::initJitScriptBindings(_object*)::$_59&&, torch::jit::GraphExecutorState (*)(torch::jit::StrongFunctionPtr const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(pybind11::detail::function_call&)#1}::__invoke(pybind11::detail::function_call&) (call=...) at /home/gaoxiang/nvfuser/cmake/../third_party/pybind11/include/pybind11/pybind11.h:224
#10 0x00007fffbfe1ef07 in pybind11::cpp_function::dispatcher (self=0x7ffff6613660, args_in=0x7fff10918520, kwargs_in=0x0)
    at /home/gaoxiang/nvfuser/cmake/../third_party/pybind11/include/pybind11/pybind11.h:929

@rdspring1 (Collaborator, Author)

Hmm. I am seeing different errors between runs of the test.

RuntimeError: expanded_extent_ == nullptr || isBroadcast() 
INTERNAL ASSERT FAILED at "../torch/csrc/jit/codegen/cuda/ir_internal_nodes.h":879, please report a bug to PyTorch.
Expanded extent is only relevant for strided broadcast dimensions yet found an expanded extent 
without a strided broadcast iter type.
RuntimeError: extent_ != nullptr INTERNAL ASSERT FAILED at "../torch/csrc/jit/codegen/cuda/ir_internal_nodes.h":874, 
please report a bug to PyTorch.

SegFault:

0x00007fffaeac9bda in torch::jit::fuser::cuda::IterDomainBuilder::IterDomainBuilder(torch::jit::fuser::cuda::IterDomain const*) () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so
(gdb) bt
#0  0x00007fffaeac9bda in torch::jit::fuser::cuda::IterDomainBuilder::IterDomainBuilder(torch::jit::fuser::cuda::IterDomain const*) () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so
#1  0x00007fffaeacafc7 in torch::jit::fuser::cuda::IterDomain::cloneWithoutRFactor() const () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so
#2  0x00007fffaeb91dad in torch::jit::fuser::cuda::permute(torch::jit::fuser::cuda::TensorView*, std::vector<long, std::allocator<long> > const&) () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so
#3  0x00007fffaebb8187 in torch::jit::fuser::cuda::(anonymous namespace)::IrParser::registerJitOperator()::{lambda(torch::jit::Node const*, std::unordered_map<unsigned long, torch::jit::fuser::cuda::(anonymous namespace)::ValueHolder, std::hash<unsigned long>, std::equal_to<unsigned long>, std::allocator<std::pair<unsigned long const, torch::jit::fuser::cuda::(anonymous namespace)::ValueHolder> > >&)#89}::operator()(torch::jit::Node const*, std::unordered_map<unsigned long, torch::jit::fuser::cuda::(anonymous namespace)::ValueHolder, std::hash<unsigned long>, std::equal_to<unsigned long>, std::allocator<std::pair<unsigned long const, torch::jit::fuser::cuda::(anonymous namespace)::ValueHolder> > >&) const [clone .isra.0] () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so
#4  0x00007fffaebc44c9 in torch::jit::fuser::cuda::(anonymous namespace)::IrParser::parse() () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so
#5  0x00007fffaebc5124 in torch::jit::fuser::cuda::parseJitIR(std::shared_ptr<torch::jit::Graph> const&) () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so
#6  0x00007fffaeafd776 in torch::jit::fuser::cuda::GraphCache::createFusion(std::shared_ptr<torch::jit::Graph> const&) () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so
#7  0x00007fffaeafdb7b in torch::jit::fuser::cuda::GraphCache::GraphCache(std::shared_ptr<torch::jit::Graph> const&) () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so
#8  0x00007fffaeb81d3a in torch::jit::fuser::cuda::compileCudaFusionGroup(torch::jit::Node*)::{lambda()#1}::operator()() const [clone .isra.0] () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so
#9  0x00007fffaeb82389 in torch::jit::fuser::cuda::compileCudaFusionGroup(torch::jit::Node*) () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so
#10 0x00007fffaea6aeb8 in torch::jit::fuser::cuda::(anonymous namespace)::compileFusionRecursive(torch::jit::Block*) () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so
#11 0x00007fffaea6ae52 in torch::jit::fuser::cuda::(anonymous namespace)::compileFusionRecursive(torch::jit::Block*) () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so
#12 0x00007fffaea7e072 in torch::jit::fuser::cuda::CudaFuseGraph(std::shared_ptr<torch::jit::Graph>&) () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so
#13 0x00007fffba7f9c3c in torch::jit::ProfilingGraphExecutorImpl::runNoGradOptimizations(std::shared_ptr<torch::jit::Graph>&, unsigned long) () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#14 0x00007fffba7fe7b0 in torch::jit::ProfilingGraphExecutorImpl::runProfilingOptimizations(std::shared_ptr<torch::jit::Graph>&, unsigned long) () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#15 0x00007fffba7ff58d in torch::jit::ProfilingGraphExecutorImpl::getOptimizedPlanFor(std::vector<c10::IValue, std::allocator<c10::IValue> >&, c10::optional<unsigned long>) () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#16 0x00007fffba7ffc41 in torch::jit::ProfilingGraphExecutorImpl::getPlanFor(std::vector<c10::IValue, std::allocator<c10::IValue> >&, c10::optional<unsigned long>) () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#17 0x00007fffba7b9afd in torch::jit::GraphExecutorImplBase::run(std::vector<c10::IValue, std::allocator<c10::IValue> >&) () from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#18 0x00007fffc167d63e in torch::jit::runAndInsertCall(torch::jit::Function&, torch::jit::tuple_slice const&, pybind11::kwargs const&, c10::optional<c10::IValue>, std::function<torch::jit::Value* (torch::jit::Graph&, torch::jit::MatchedSchema const&)> const&) ()
   from /home/rds4/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_python.so

@rdspring1 (Collaborator, Author) commented Sep 12, 2022

Ah, it is failing on these permute dimensions:

  1. ((0, -2, -1, 1),)
  2. ((),)

I will push a fix.
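
To illustrate the first case, a hedged sketch of the dimension normalization such a fix needs (illustrative only; the actual change is presumably in the C++ parser, and normalize_dims is a hypothetical helper):

```python
def normalize_dims(dims, ndim):
    # Map negative axes to their positive equivalents, e.g. for ndim=4:
    # (0, -2, -1, 1) -> (0, 2, 3, 1). An empty dims tuple stays empty.
    return tuple(d % ndim for d in dims)

assert normalize_dims((0, -2, -1, 1), 4) == (0, 2, 3, 1)
assert normalize_dims((), 0) == ()
```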

@zasdfgbnm (Collaborator)

Thanks for fixing!

Commit: Add nullptr input tensor checks
@rdspring1 (Collaborator, Author)

@zasdfgbnm I fixed the segmentation faults, but now I am seeing this error:

RuntimeError: entry_it != disjointSetMap().end() 
INTERNAL ASSERT FAILED at "../torch/csrc/jit/codegen/cuda/disjoint_set.h":259, 
please report a bug to PyTorch. 
Strict mapping failed on element: T0_g[ 0 ] either an error occured, or non strict mapping should have been used.

@zasdfgbnm (Collaborator)

I will take a look

@rdspring1 (Collaborator, Author)

I forgot to mention that the failure occurs when running all the permute correctness tests.

python test_jit_cuda_fuser.py -k test_nvfuser_correctness_permute

@zasdfgbnm (Collaborator)

The entry_it != disjointSetMap().end() error is fixed, but there is still a failure. The root cause seems to be that the fusion is empty:

Inputs:
  T0_g[ 0 ], bool
Outputs:
  T0_g[ 0 ], bool

%kernel_math {
}
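
For context, an eager-mode sketch (assumed from the 0-d bool sample input, not taken from the PR) of why there is nothing left to generate: permuting a scalar tensor with an empty dims list moves no axes, so the output is just the input.

```python
import torch

x = torch.tensor(True)       # 0-d bool tensor, as in the failing test
y = torch.permute(x, ())     # no axes to reorder: a pure no-op
print(y.shape, y.dtype)      # torch.Size([]) torch.bool
```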

@zasdfgbnm (Collaborator)

Is it possible to skip empty fusions from the TorchScript side?

@zasdfgbnm (Collaborator) left a comment


Requesting changes to avoid an accidental merge.

@rdspring1 (Collaborator, Author)

There isn't a pass to skip empty fusions yet. Do we know why empty fusion groups are happening?

@zasdfgbnm (Collaborator)

> There isn't a pass to skip empty fusions yet. Do we know why empty fusion groups are happening?

I think it is explicitly written in the tests:

def sample_inputs_permute(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    cases = [((1, 2, 3, 4), (0, 2, 3, 1)),
             ((1, 2, 3, 4), (0, -2, -1, 1)),
             ((), ()),
             ((1, 2, 3, 4), (2, 1, 3, 0))]
    for shape, args in cases:
        yield SampleInput(make_arg(shape), args=(args,))

@zasdfgbnm (Collaborator)

Related issues/PRs:
#1808
#1752
#1835
Let me check if I can do anything to unblock this test.

@rdspring1 (Collaborator, Author)

@zasdfgbnm I fixed the empty fusion by returning set(x) instead of x when the input is a scalar tensor.
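
A Python analogue of the intent (the actual change is in the C++ IrParser; permute_or_copy is a hypothetical name, and clone() plays the role that nvfuser's set() plays in the fusion IR):

```python
import torch

def permute_or_copy(x: torch.Tensor, dims) -> torch.Tensor:
    # A 0-d (scalar) tensor has no axes to permute. Returning an explicit
    # copy keeps the fusion non-empty instead of forwarding the input
    # straight to the output.
    if x.dim() == 0 or len(dims) == 0:
        return x.clone()
    return x.permute(dims)
```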

@zasdfgbnm (Collaborator) commented Sep 13, 2022

@rdspring1 please merge if you think it is OK to do so

@rdspring1 rdspring1 merged commit e96aacf into devel Sep 13, 2022
@rdspring1 rdspring1 deleted the transpose_python branch September 13, 2022 20:31
jjsjann123 added a commit that referenced this pull request Nov 9, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Codegen changes include:

* codegen improvement:
    i. allow non-root trivial reductions, allow empty/no-op fusion
    ii. fixes vectorization checks and size calculation
    iii. bank conflict handle improvement
    iv. enables transpose scheduler

* misc:
    i. CI tests failure fixes
    ii. cpp tests file clean up
    iii. trivial forwarding supports added in codegen runtime
    iv. added factory methods support in codegen

Commits in this PR from the devel branch:

```
7117a7e patching nvfuser conv cudnn test numerics mismatch (#2048)
65af1a4 Inserting sync for redundant parallel types is already done at the (#2023)
6ac74d1 Fix sync map (#2047)
f5bca33 Bank conflict checker improvements (#2032)
d2ca7e3 Minor update on cp.async code generation. (#1901)
d36cf61 Test file cleanup (#2040)
0b8e83f Allow non-root trivial reductions (#2037)
a2dfe40 Fix vectorize size calculation (#2035)
e040676 Use withPredicate to replace setPredicate to maintain Exprs immutable (#2025)
197221b removing ci workflow (#2034)
40e2703 Reduction rand like patch (#2031)
bc77266 Add utility for checking bank conflict of shared memory (#2029)
ddd1cf7 Add back FusionReductionWithTrivialReduction_CUDA (#2030)
fbd97e5 Revert "Cleanup trivial reduction workarounds (#2006)" (#2024)
bca20c1 Cleanup trivial reduction workarounds (#2006)
e4b6585 Trivial forwarding (#1995)
1a0e355 Fix contiguity analysis of predicates to match updated contiguity. (#1991)
a4effa6 Enable output allocation cache (#2010)
35440b7 Patching bn inference (#2016)
0f9f0b4 Add matmul benchmark (#2007)
45045cd Enable tests previously disabled due to an aliasing bug (#2005)
967aa77 Contiguous indexing for View operations (#1990)
a43cb20 Make inlining even more modular (#2004)
dc45835 Test util cleanup (#2003)
3ca21eb More strict validation (#2000)
a7a7d57 Fix build problem (#1999)
fc235b0 Just fixes comments (#1998)
482386c cleanup (#1997)
4cbe0db Improve divisible split detection (#1970)
42ccc52 Minor build fix. (#1996)
fcf8c09 Cleanup of lower_utils.cpp: Isolate out GpuLower usage (#1989)
15f2f6d Move ConcretizedBroadcastDomains to shared_ptr in GpuLower. (#1988)
8f1c7f5 Minor cleanup lower_unroll.cpp (#1994)
1d9858c Minor cleanup (#1992)
f262d9c Add support for uniform RNG (#1986)
eb1dad1 Remove non-const functions, remove GpuLower instance on build, pass in ca_map. (#1987)
634820c Add support for some empty fusion (#1981)
eabe8d8 Segment self mapping fusions (#1954)
e96aacf Enable Transpose operation (#1882)
425dce2 Add a null scheduler that helps segmenting away no-op schedules (#1835)
306d4a6 Fix canScheduleCompileTime check of transpose scheduler (#1969)
b1bd32c Minor fix (#1967)
bd93578 Enable transpose scheduler (#1927)
b7a206e Move scheduler vectorize utilities into their own file (#1959)
d9420e4 View scheduling (#1928)
c668e13 Upstream push ci fixes (#1965)
c40202b Fix dump effective bandwidth (#1962)
93505bc WAR on index mapping when exact and permissive maps differ (#1960)
45e95fd Allow splitting inner-most ID to create virtual innermost ID in transpose scheduler (#1930)
a3ecb33 Improve the comments at the beginning of index_compute.h (#1946)
f7bc341 Remove unused variables (#1955)
df3393a Some cleanup (#1957)
7d1d7c8 TVDomainGuard factory (#1953)
357ba22 Fill allocation with nan on tests (#1956)
8eafc54 Fix detection of unmappable root domains (#1952)
90a51f2 Some indexing cleanups, Add eye support (#1940)
ddc01e4 Exclude unsupported data types (#1951)
992e17c test the groups the same order as they are merged (#1949)
208262b Move detection of self mapping IDs to IterDomainGraph from (#1941)
ac4de38 Merge pull request #1945 from csarofeen/master_merge_0828
6310948 Add full, full_like, zeros, zeros_like, ones, ones_like (#1943)
aab10bc Merge remote-tracking branch 'upstream/viable/strict' into HEAD
4c254c0 Fix arange when step is negative (#1942)
89330aa Tensor factories must set the output shape as its input (#1939)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D40869846](https://our.internmc.facebook.com/intern/diff/D40869846)
Pull Request resolved: pytorch#87779
Approved by: https://github.com/davidberard98