Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Make lower_warp_memory support extent(threadIdx.x) < warp_size #5307

Merged
merged 2 commits into from
Apr 17, 2020

Conversation

roastduck
Copy link
Contributor

Pass lower_warp_memory lowers memory bound to "warp" scope into the warp shuffle intrinsic. Currently, this pass only supports the situation where the extent of threadIdx.x equals to the warp size. However, CUDA's __shfl has a 3rd parameter width to shuffle variables in half (or 1/4, 1/8, 1/16) of a warp. This PR uses this extra parameter to enable Pass lower_warp_memory when the extent of threadIdx.x is less than the warp size.

Changes:

  1. Add a 3rd parameter width and a 4th parameter warp_size to TVM intrinsic tvm_warp_shuffle. The 4th parameter warp_size is used to help a Code Generator to decide whether a width is legal. For example, the OpenCL backend dose not support the width parameter, so it has to check whether width == warp_size. Since currently lower_warp_memory is the only pass that utilize tvm_warp_shuffle, this change will not break any dependencies.
  2. Code Generators that lowers tvm_warp_shuffle are modified. Currently, the only two affected Code Generators are CUDA and OpenCL.
  3. In lower_warp_memory, find the value of width first, and then alter the IR base on width, instead of based on warp_size. Then, it generate the modified tvm_warp_shuffle intrinsic.
  4. A test which runs lower_warp_memory with 1/2 warp size is added.

Can @tqchen, @ZihengJiang or @ajtulloch make a review or suggest any other reviewers?

@tqchen
Copy link
Member

tqchen commented Apr 11, 2020

@tqchen
Copy link
Member

tqchen commented Apr 11, 2020

Thanks @roastduck. I wonder if we can also discuss the alternative abstractions. Right now the abstraction seems to suggest that conceptually the size of the warp is reduced to half(as the shuffle size). However, another way to view it would be to keep the size of the warp to be fixed(32), but support the index access pattern of the subgroups, for example, the canonical form below describes a shuffle in the group of 4

A[wi] = B[(wi/4)*4+ ((wi % 4) +1) %4]

@roastduck
Copy link
Contributor Author

Thanks @roastduck. I wonder if we can also discuss the alternative abstractions. Right now the abstraction seems to suggest that conceptually the size of the warp is reduced to half(as the shuffle size). However, another way to view it would be to keep the size of the warp to be fixed(32), but support the index access pattern of the subgroups, for example, the canonical form below describes a shuffle in the group of 4

A[wi] = B[(wi/4)*4+ ((wi % 4) +1) %4]

In the alternative approach, __shfl(x, (threadIdx.x + 1) % 4, 4) becomes __shfl(x, threadIdx.y * 4 + (threadIdx.x + 1) % 4) (and threadIdx.z might also be involved). Is my understanding right?

I think the good thing is better compatibility for OpenCL. By this approach, we can support OpenCL using the old 2-parameter intrinsic.

And I think the bad thing comes with CUDA's new shuffle API. __shfl has actually been deprecated by CUDA, and we will have to switch to the new __shfl_sync API. The new API requires an explicit argument mask to specify which threads are active during this shuffle. In the approach of this PR, we will only need to calculate the activeness within a partial (say 1/2, 1/4, etc) warp, which can be calculated from the if nest, given the thread indices of the current partial warp. But in the alternative approach, we will need to calculate the activeness within a whole warp, which means other threads outside the current partial warp will be involved. It may bring a lot of complexity, and even run time overhead when there are dynamic conditions.

To better support both CUDA and OpenCL, maybe we can use both of the approaches.

@tqchen
Copy link
Member

tqchen commented Apr 12, 2020

Can we still translate to __shfl(x, (threadIdx.x + 1) % 4, 4) in the alternative approach? given that the pattern(wi= warp_index) directly corrresponds to the related pattern

I see, in this case, it would be great to give a bit more discussion and thought about the new programming model.

  • Does that mean we are using a smaller "virtual warp"?(which brings the restriction of not being able to use larger amount of the shuffle)

@roastduck
Copy link
Contributor Author

I see, so you are talking about the programming model instead of the code generation. In the alternative model, a user doesn't cache a buffer to "warp" scope. Instead, a user accesses variables from other threads directly through indices. Have I correctly understood your idea?

What are the advantages of the new model over the current one? You said the current model conceptually reduces a warp to half. Do you mean that it is hard to mix full-warp shuffling and half-warp shuffling for a single buffer? I think this problem can be solved by improving the detection algorithms used in the lower_warp_memory pass, without switching to a new programming model.

If the expressing ability of the two models are equivalent, I prefer the current "bind to warp scope" model, because it is more general and there is no need to introducing a new language feature.

@tqchen
Copy link
Member

tqchen commented Apr 12, 2020

I actually meant the "bind to warp scope mode". What worries me a bit is that in our original convention, threadIdx.x always means warp index, and the previous comment was about whether or not we should keep threadIdx.x still equal to be warp size, but to detect patttern
A[wi] = B[(wi/4)*4+ ((wi % 4) +1) %4] (where wi = threadIdx.x) for the half shuffle case

@roastduck
Copy link
Contributor Author

roastduck commented Apr 12, 2020

Let me sort out the two approaches.

  • In the current approach, the lowering pass directly detects shuffles on threadIdx.x, which has an arbitrary extent. It may be difficult to detect all the complex use cases, for example when there are mixed half-warp and full-warp shuffling.
  • In the alternative approach, we assume threadIdx.x == warp size when detecting shuffles. Therefore, either we require users to set the extent of threadIdx.x to the warp size (and then we close this PR), otherwise we may have to add an additional pass to convert threadIdx.x to match the warp size.

@roastduck
Copy link
Contributor Author

roastduck commented Apr 12, 2020

One way to perform a half-warp shuffle while keeping extent(threadIdx.x) == warp size is like this:

Suppose we have extent(threadIdx.x) == 16, we first split threadIdx.y with factor 2, and then fuse that 2 with threadIdx.x, so the extent of threadIdx.x becomes 32.
2. Then we can perform any lowering procedure with the assumption of threadIdx.x == 32.

Either we require users to write such a schedule, otherwise we modify the schedule for them. I think either way is not intuitive enough for users. It may also hinder debugging, and debugging in TVM is already difficult. Note that we have to modify threadIdx.x in more than one scopes, in order to keep the thread index consistent. Here is a more complex example, which is simplified from the algorithm I am currently working on.

// a is shaped (n)
// b is shaped (16)
// c is shaped (n, 16)
// extent of threadIdx.x == 16
for (i.outer = 0; i.outer < n; i.outer += 16) {
    if (i.outer + threadIdx.x < n) {
        a.warp[i.outer + threadIdx.x] = a[i.outer + threadIdx.x]; // (1)
    }
    for (i.inner = 0; i.inner < min(16, n - i.outer); i.inner++) {
        c[i.outer + i.inner, threadIdx.x] += a.warp[i.outer + i.inner] * b[threadIdx.x]; // (2)
    }
}

threadIdx.x in both statement (1) and (2) will be fused to 32, which is a major modification to the schedule. Users may meet difficulties with this schedule.

@tqchen
Copy link
Member

tqchen commented Apr 12, 2020

Thanks for the great discussions so far, it would be great to have a discussion(perhaps in the forum) about the possible conventions to define th warp, just to clarify further.

  • A0: Enforce threadIdx.x == warp_size, and make it the warp index
    • Add a subwarp shuffle detection directly on to the warp shuffle to detect sub-warp shuffle pattern like B[(wi/4)*4+ ((wi % 4) +1) %4] whcih corresponds to a sub-warp shuffle.
  • A1: Warp virualization -- Allow threadIdx.x to be at subwarp level. (this PR)

From what I see, A0 might introduce a bit more complexity, but allows a mixture of subwarp and full warp shuffle patterns.

A1 is can be viewed as "virtualizing the warps" by having more virtual warps that acted at the sub-warp level, and use a single warp to simulate them, we cannot mix that with full warp shuffle.

It would be great if we can think a bit about how to describe these different clearly and document them in the comment of the code, and possibly in the future developer docs.

If my understanding is correct, and we can document A1's concept clearly, I can go ahead and merge this PR first, then we do followup discussons.

@roastduck
Copy link
Contributor Author

I agree with the description of A0 and A1, and now I think there may be an A2 in the design space.

  • A2: Improve the shuffle detection. We can analyze all the three thread axes, instead of only threadIdx.x. If the detection is strong enough, we can do whatever shuffle we want without limiting threadIdx.x to be the warp axis.

I think A2 may require less complexity than A0, where we should alter the thread indices.

And for documentation, am I suppose to refine the comments in the code, or start a new documentation page for the topic of warp memory?

@tqchen
Copy link
Member

tqchen commented Apr 13, 2020

We can start with refining the comments in the code. More docs about warp semantics is always more than welcomed as followup PRs

@roastduck
Copy link
Contributor Author

Added a bit more descriptions for width. I think it's enough to explain the sub-warp concept.

@wpan11nv
Copy link
Contributor

Are there any real-world usages of this sub-warp scope? Should we upgrade TVM to emit new warp level APIs and then think about extensions?

@tqchen
Copy link
Member

tqchen commented Apr 15, 2020

@wpan11nv can you try to comment about the new warp API and potential backward compact problem of this extension?

@roastduck
Copy link
Contributor Author

Can we merge this PR first before discussing about a new API?

@tqchen tqchen merged commit 4b5f324 into apache:master Apr 17, 2020
@tqchen
Copy link
Member

tqchen commented Apr 17, 2020

OK, this is merged for now, let us open new RFC threads for the new warp API in CUDA. @roastduck I think @wpan11nv meant the API in CUDA instead of TVM's convention.

Thanks @roastduck @wpan11nv

dpankratz pushed a commit to dpankratz/incubator-tvm that referenced this pull request Apr 24, 2020
…pache#5307)

* support extent(threadIdx.x) < warp_size in lower_warp_memory

* more docs for lower_warp_memory
dhruvaray pushed a commit to dhruvaray/incubator-tvm that referenced this pull request Apr 28, 2020
…pache#5307)

* support extent(threadIdx.x) < warp_size in lower_warp_memory

* more docs for lower_warp_memory
dhruvaray added a commit to dhruvaray/incubator-tvm that referenced this pull request Apr 28, 2020
* [Relay][Frontend][TFLite] Add parser support for shape and range

Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>

* [RELAY][PYTORCH]isNan, isinf, isfinite, ceil, clamp, round ops (apache#5316)

* [RELAY][PYTORCH]isNan, isinf, isfinite, ceil, clamp, round ops

* Review comments

* [TIR] Refactor MakePackedAPI to target dependent stage. (apache#5326)

Previously MakePackedAPI was in the target independent stage,
but never the less requires the device_type information that will be
binded at a later target dependent stage.

The previous implementation was due to the limitation of LoweredFunc
which can not carry buffer_map info(so they have to be lowered right away).
This is no longer the case after the unified IR refactor.

This PR migrates MakePackedAPI to a target dependent stage
and removes the un-necessary BindDevice pass.

* [RELAY] Remove re-exports of tvm.transform (apache#5337)

* [LLVM] Use llvm::FunctionCallee in IRBuilder::CreateCall with LLVM 11+ (apache#5338)

The older variants of CreateCall have been deprecated and were recently
removed from LLVM. This caused compilation failures.

* [CI] Fix build.sh to propagate --network=host to the docker build command (apache#5336)

* when passing --net=host to build.sh it needs to be also
   sent as --network=host to "docker build", so that both
   build and run will use the same network configuration

* [Runtime][Relay][Cleanup] Clean up for memory pass to enable heterogenous execution support. (apache#5324)

* Cleanup type pack and unpack for tuples.

* Clean up the memory_pass using common helpers

* Clean up memory.cc

* Refactor pass

* Add doc strings

* Fix CPPlint

* Fix PyLint

* Fix

* Apply suggestions from code review

Co-Authored-By: Zhi <5145158+zhiics@users.noreply.github.com>

* Fix typo

Co-authored-by: Zhi <5145158+zhiics@users.noreply.github.com>

* Windows Support for cpp_rpc (apache#4857)

* Windows Support for cpp_rpc

* Add missing patches that fix crashes under Windows

* On Windows, use python to untar vs wsl

* remove some CMakeLists.txt stuff

* more minor CMakeLists.txt changes

* Remove items from CMakeLists.txt

* Minor CMakeLists.txt changes

* More minor CMakeLists.txt changes

* Even more minor CMakeLists.txt changes

* Modify readme

* [PYTORCH]Take, Topk op support (apache#5332)

* [PYTORCH]take, topk op support

* Ci Failure fix

* [TOPI] Using x86 schedules for ARM conv2d. (apache#5334)

* [TOPI] Improve get_valid_count and nms performance for CUDA (apache#5339)

* get_valid_count updated to have correct results

* speedup nms

* update nms

* revert back nms

* recover one test for get_valid_count

* [PYTHON] Enhance with_attr API, cleanup MakeAPILegacy in testcases (apache#5335)

* [TIR] Remove ProducerConsumer and AllocateNode::new_expr (apache#5333)

* [TIR] Remove ProducerConsumer and AllocateNode::new_expr

This PR removes two legacy IR parts in TIR that are deprecated.

ProducerConsumer node only serves as a hint markup and may no longer be
informative after extensive transformations in the pass.
If necessary, we can add related info via AttrStmt.

The new_expr field in the AllocateNode is deprecated since it can just be
replaced by a LetStmt.

- Remove dependencies of passes on ProducerConsumer.
- Remove ProducerConsumer from the IR.
- Remove the deprecated fields (new_expr, free_function) from AllocateNode.

* Fix additional testcases

* [BYOC] Prevent duplicate outputs in subgraph Tuple (apache#5320)

* Fix duplicate output in partitiongraph

* Add test case

* Fix test_annotated_regions with duplicate compiler_end outputs

* Revert "Fix duplicate output in partitiongraph"

This reverts commit e1f8ef3.

* Prevent duplicate outputs in Tuple in PartitionGraph

* Fix lint

* Add another test case for when regions are merged, and when TupleGetItem was duplicated

* Pull GetFunctionOutput out of branch, improve description of GetFunctionOutput

* Use std::move for GetFunctionOutput. Fix typo with testcase name

* Use tvm.transform.Sequential

* [Tutorial, QNN] Add tutorial for loading quantized PyTorch model (apache#5321)

* add pytorch tutorial code and doc stub

* add more docs

* formatting, more docs

* typo fix

* try make sphinx happy

* add performance section

* type and nit fix

* format fix

* [DOCS] Bring relay docs to the top-level flat view (apache#5343)

- Changes most of the relay docs to use autosummary.
- Bring relay API docs to the top-level flat view for easier discovery
- Removed a few cases of re-exports.

* [TOPI][PYTORCH]Logical & Bitwise operator support (apache#5341)

* [RELAY][BYOC] Register pattern tables from external codegens (apache#5262)

* [RELAY][BYOC] Register pattern tables from external codegens

This adds utility functions to support registering
and retrieving pattern tables used by MergeComposite for
external codegens.

Change-Id: I5be165a321440e48b15ff6aff4970e0c67496aaa

* Updated DNNL tests to use pattern table mechanism

* Removed pattern table standalone test

* Change reg to _op

* [RUNTIME][CRT] support DLTensor whose ndim == 0 (apache#5344)

Signed-off-by: windclarion <windclarion@gmail.com>

* [BYOC][FIX] Fix typo in "default" (apache#5348)

Default annotations were incorrectly being named 'defualt'
which results in them not being removed in PartitionGraph.

* enable tsim and fsim for GPU build (apache#5352)

* [CRT]Compilation warnings fixed for 32bit and 64bit compilation (apache#5349)

* [PYTORCH]Tensor creation ops support (apache#5347)

* [Hexagon] Add hexagon_posix.cc to TVM/RT sources in the right place (apache#5346)

This file was added before the variable with TVM/RT was initialized.
The initialization overwrote the addition.

* [TOPI-ARM] Do not alter layout if layout is NHWC (apache#5350)

* [TOPI-ARM] Do not alter layout if layout is NHWC

* Add test.

* [TIR] Make lower_warp_memory support extent(threadIdx.x) < warp_size (apache#5307)

* support extent(threadIdx.x) < warp_size in lower_warp_memory

* more docs for lower_warp_memory

* [RELAY][PYTORCH]GroupNorm op support added (apache#5358)

* docker: Drop caffe2 download progess bars (apache#5359)

Change-Id: Ia15c3c8f41f75423814e559f6fdb062098f19464

* fix fuse over functions that are handled by external codegen (apache#5365)

* [RUNTIME] FastRPC interface for Hexagon runtime (apache#5353)

* [RUNTIME] FastRPC interface for Hexagon runtime

Co-authored-by: Ravishankar Kolachana <quic_rkolacha@quicinc.com>
Co-authored-by: Krzysztof Parzyszek <kparzysz@quicinc.com>

* Explain store offset in a comment in launcher

Co-authored-by: Abhikrant Sharma <quic_abhikran@quicinc.com>
Co-authored-by: Ravishankar Kolachana <quic_rkolacha@quicinc.com>

* [TIR][REFACTOR] Migrate low-level passes in tvm.lower to the Unified IR pass manager. (apache#5364)

- Migrate BoundCheckers and Simplify
- Migrate RewriteUnsafeSelect and RemoveNoOp
- Migrate UnrollLoop and StorageRewrite
- Migrate InjectDoubleBuffer and InjectVirtualThread
- Migrate LoopPartition and Vectorize
- Migrate CoProcSync, LiftAttrScope, InjectCopyIntrin

We still keep ir_pass registerations for now.
Need a separate PR to refactor the parts before the StorageFlatten.

* [TIR] Fix lower_warp_memory when there are >1 warp buffers (apache#5368)

* fix recursion in lower_warp_memory

* post-order mutation

* Add cuda target check to dense tensorcore schedule. (apache#5376)

* Remove developer facing api from frontend exports. (apache#5375)

* [TIR][REFACTOR] Remove te::Tensor dependencies from TIR passes. (apache#5372)

* [TIR][REFACTOR] Remove te::Tensor dependencies from TIR passes.

te::Tensor is an useful object for tensor expression, but brings
un-necessary reverse dependency in TIR nodes such as Provide and Realize.

This PR is a first step to remove this dependency. We will use Buffer in all the places
where the te::Tensor was used. The rough correspondence are:

- Provide -> BufferStore
- Realize -> BufferRealize
- HalideCall -> BufferLoad.

After this change, we can not use IRModule of PrimFuncs cleanly to represent TIR
at any point of the optimizations. Buffer will serve as the abstraction for the TIR data
models to represent the intermediate storages and their constraints.

We still keep Realize/HalideCall and Provide as TIR nodes for now to make the change minimum.
Right after ScheduleOps, we call SchedulePostProcToPrimFunc to canonicalize the temporary IR
generated by TE(which contains these nodes) to the TIR.

The TIR optimizations are now mostly migrated to to the pass manager.
Followup PRs are needed to migrate the remaining few passes.

* Fix dev tutorial

* [PYTORCH]Unary Ops (apache#5378)

* [TIR][REFACTOR] RewriteForTensorCore -> te/schedule (apache#5379)

* [TIR][REFACTIR] RewriteForTensorCore -> te/schedule

RewriteForTensor depends on the schedule information, which makes it differ
from a typical pass(which should get all the information from the input TIR).

As a result, we refactor it as a SchedulePostProc step for now.
We should revisit it later as we introduce more support for tensor core patterns in the TIR.

* Fix VTA to fit the new IR Pattern

* [Blocksparse] Pipeline for lowering dense model to sparse-dense (apache#5377)

* [REFACTOR][TE] Inline -> te/schedule/operation_inline.h (apache#5386)

Rationale: inline is a transformation used in te to
rewrite its internal expressions. It is not a formal IRModule->IRModule transform pass.

Also removed the python test as the test is covered by stage.compute_inline.

* [ARITH] Remove the legacy Simplify, migrate to Analyzer. (apache#5385)

The legacy Simplify/CanonicalSimplify are now a thin wrapper around the Analyzer.
This PR removes these functions and migrated every place that requires
simplification to enforce Analyzer creation.
The new API would encourage more Analyzer sharing and potentially enable
context-aware analyzer-based simplification.

* [ARITH] Remove legacy const pattern functions (apache#5387)

* Add ability to have multiple copies of same input to onnx_inputs. (apache#5389)

* [Topi, ARM] Disbale Winograd for quantized tensors. (apache#5363)

* [Topi, ARM] Disbale Winograd for quantized tensors.

* Relaxing float

* Fix test_ir_type. (apache#5390)

* The void return type is not None/nullptr, it's VoidType or
   TupleType([]).

* Tf2 test fixups (apache#5391)

* Fix oversight in importing tf.compat.v1 as tf.

* Actually disable test for lstm in TF2.1

Since the testing framework actually uses pytest, the version
check needs to be moved.

* [PTYTHON] Migrate VTA TIR passes to the new pass manager. (apache#5397)

* [LLVM] Use ArrayRef<int> in calls to CreateShuffleVector (apache#5399)

This switch was made in LLVM 11. Previously this function was expecting
mask indices of type uint32_t. This variant is now deprecated.

* [KERAS]Minimum & AlphaDropout op support (apache#5380)

* Factor out import of common tflite.Operator in tflite frontend. (apache#5355)

* Restructure imports in tflite frontend.

These python modules are needed for every tflite file parsed.
Factorize out imports of the common most ones.

Now that the import of operator is common, asserts can be commonized.

Loses 473 lines of duplication.

* Only restrict to tflite.Operator

* [Fix] Remove the duplicate PrintIR pass in Relay (apache#5403)

* Update dmlc-core to latest (apache#5401)

* [TIR] Enhance Substitute, python bindings for Substitute/PostOrderVisit/IRTransform. (apache#5400)

Substitute now takes a std::function to customize more replacing behaviors.

Co-authored-by: Siyuan Feng <hzfengsy@sjtu.edu.cn>

Co-authored-by: Siyuan Feng <hzfengsy@sjtu.edu.cn>

* [Relay] Fix memory leak when accessing NDArray (apache#5413)

* Customize SI prefix in logging (apache#5411)

* Customize SI prefix in logging

* Include unit test

* [LLVM] Replace calls to Type::getVectorNumElements (apache#5398)

This function has recently been removed from LLVM 11. Use alternative
way to obtain vector element count (VectorType::getNumElements) which
works for all LLVM versions.

* Don't remove() TempDirectory in __del__ after atexit hook runs. (apache#5414)

* Use atexit to remove TempDirectory before interpreter shutdown.
 * Can't rely on complex functions from __del__ anyway.
 * Fixes warning message on my box:
       Exception ignored in: <function TempDirectory.__del__ at 0x12be10680>
       Traceback (most recent call last):
        File ".../tvm/python/tvm/contrib/util.py", line 55, in __del__
        File ".../tvm/python/tvm/contrib/util.py", line 51, in remove
        File "/usr/local/opt/python/Frameworks/Python.framework/Versions/3.7/lib/python3.7/shutil.py", line 509, in rmtree
        AttributeError: 'NoneType' object has no attribute 'path'

* [TIR][REFACTOR] Remove ir_pass in favor of analysis/transform. (apache#5415)

This PR removes ir_pass(old style pass functions) in favor
of analysis/transform(new style pass manager).

* [RUNTIME][CONTRIB] CoreML Runtime (apache#5283)

* [RUNTIME][CONTRIB] CoreML Runtime

* fix lint

* fix CI

* use xcrun to compile coreml model

* [DOCS] Migrate HLS documents from md to rst (apache#5419)

* fix [RUNTIME][VULKAN] vkBuffer released before memory copy command send to GPU (apache#5388) (apache#5418)

* [Frontend] Asymmetric padding of convolution support (apache#4803)

* [cuDNN] Add cuDNN grouped convolutions support (apache#5319)

Signed-off-by: Wei Pan <weip@nvidia.com>

* [CI] Migrate Tensorflow and Tensorflow lite in CI to  2.1.0 (apache#5392)

* Migrate Tensorflow and TFLite in the CI up to 1.15.2

The latest stable version of Tensorflow and Tensorflow lite
in the 1.x series is 1.15.2. The tflite frontend is receiving
support for versions of tflite > 1.14 but there is no consistent
testing.

There are 2 failures already in the source base with tf 1.15
and I'm concerned this will just get exacerbated over time
if we don't have CI picking this up and I view this as a stepping
stone towards stepping CI to TF2.x.

The test failures that I have commented will get issues raised
for them as issues to be fixed.

* Comment out run of qnn_mobilenet_v3_net

This is another test that fails with TFlite 1.15.2

* Skip the qnn_mobilenet_v3 test in the pytest fashion.

* Switch docker versions to support Tensorflow 2.1.0

* Fix up pytest imports and usage.

* Skip these tests currently for Tensorflow 2.1.0

* [DOCS] Migrate some markdowns to rst, fix sphinx3 warnings (apache#5416)

* [DOCS] Migrate some markdowns to rst, fix sphinx3 warnings

* Add note block

* [BYOC] Use Non-Recursive Visitor/Mutator (apache#5410)

* Non-Recursive AnnotatedTarget and MergeAnnotation

* Non-Recursive AnnotatedRegionSet and RegionMerger

* [RFC] Pytest environment improvements (apache#5421)

* [RFC] Pass pytest options globally.

In many places having a global pytest flag is useful . For me with the
build and test of tvm , I would like to be able to globally pass in
pytest options as part of development flow or CI flows where one would
like to measure other things regularly that need measurements including
pytest coverage data that I would like to experiment with across the stack.

This has been achieved with an additional setup-pytest-env.sh file in
tests/scripts rather than putting in something in every single task test
script and something I would like to avoid.

This now means the -v option to pytest is superfluous. I did consider
having a pytest.ini file but that doesn't allow me to pass any old
environment variable in and this seems to be the compromise.

* Improve other use case documentation

* Rationalize pytest environment.

* Remove the setting from docker/with_same_user.
* Take the opportunity to migrate common PYTHONPATH and
TVM_PATH into the common environment setting.

* Fixup vta fsim

* Be more explicit with common PYTHONPATH

* Fix python path for task_python_vta_fsim.sh properly

* Fix nit in documentation.

* [MXNET]DepthToSpace & SpaceToDepth Operator (apache#5408)

* Add option to specify flatbuffers location (apache#5425)

* [FRONTEND][MXNET] support elemwise logic ops (apache#5361)

* [PY][FFI] Introduce PyNativeObject, enable runtime.String to subclass str (apache#5426)

To make runtime.String to work as naturally as possible in the python side,
we make it sub-class the python's str object. Note that however, we cannot
sub-class Object at the same time due to python's type layout constraint.

We introduce a PyNativeObject class to handle this kind of object sub-classing
and updated the FFI to handle PyNativeObject classes.

* [PYTORCH]where, addcdiv, addcmul op support (apache#5383)

* [PYTORCH]Where, addcdiv, addcmul op support

* Review comments fixed

* [FRONTEND][TFLITE]Gather, StridedSlice op support added (apache#4788)

* [FRONTEND][TFLITE]Gather, StridedSlice op added

* Review comments fixed

* misc fixes for ROCm (pointer lifetime, runtime::String refactor) (apache#5431)

* Corrected TVM autotuning on GPU (apache#5432)

Added missing "tir" in tvm.tir.analysis.verify_gpu_code(f, kwargs)

* [RUNTIME][OBJECT] Introduce static slots for common objects. (apache#5423)

The _type_child_slots can be used to enable quick type checking optimization
by checking the whether the type index is within the bound.

This PR enables these static slots:

- Introduce a static assert to avoid the scenario when a developer forget to
  _type_child_slots when the field is set for the type's parent.
- Revamp and assign static type index to common runtime objects
- Add a DumpTypeTable call to allow developer monitor the current situation
  of type table and offers suggestions for the slots(ideally the slots equals
  the number of children so there is no overflow.

* [RELAY][PYTORCH]cosh,sinh,log2,log10,log1p op support (apache#5395)

* [RELAY][PYTORCH]cosh,sinh,log2,log10,log1p op support

* Review comment fixed

* Gradient testcase added

* [PYTORCH]Rsub, Embedded, OneHot ops support (apache#5434)

* fix miopen pad (apache#5433)

* [TOPI,RELAY][TFLITE] Sparse to dense operator

Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>

* Add TopK to ONNX Frontend (apache#5441)

* Add TopK to ONNX Frontend

* respond to review comments

* [CodeGen] Cleanup generated code (apache#5424)

- remove unnecessary white spaces from storage kind
- do not start a new scope for vectorization as temporary
  variables are alll uniquely generated.

The above two changes make vectorized code much cleaner.

Signed-off-by: Wei Pan <weip@nvidia.com>

* [RELAY] Move frontend utils (apache#5345)

* [RELAY] Move frontend utils

The util file currently under frontend is used from
outside of frontend (in qnn/op/legalizations). This suggests
that the file should be pushed up to a higher level.

The benefit from this change is that importing qnn no longer
also imports all the frontends.

* Inline get_scalar_from_constant

Change-Id: I1cc64e9ecb0eadb6ac0f7b62e6ea174644af4ad4

* Remove util.py from Relay

Change-Id: If9cd7cf3fc0bd1861a3a9b5604f338e084d8db96

* Shorten functions

Change-Id: Ieb537d82e6ee52421ff05a90cd00a03679ffebf2

* Line length

Change-Id: I1d216b7e73a060c4f118f5da50ce58b18eba907f

* [KERAS]Embedding layer (apache#5444)

* [Docs] VTA install doc migration from md to rst (apache#5442)

* Improve IntervalSet's floormod (apache#5367)

* use param name in documentation

Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>

* [ONNX]GatherNd, Round, IsNaN, IsInf (apache#5445)

* [relay][topi] Add operation relay.nn.dilate() which calls topi.nn.dilate() (apache#5331)

* Add operation relay.nn.dilate() which calls topi.nn.dilate().

* Fix typo

* Set op pattern to injective

* sphinx doc errors fixed

Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>

* [Pytorch] fix translation of transpose when axis argument is as a list (apache#5451)

* incorporated code review comments

Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>

* Fixed indentation

Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>

Co-authored-by: Samuel <siju.samuel@huawei.com>
Co-authored-by: Tianqi Chen <tqchen@users.noreply.github.com>
Co-authored-by: Krzysztof Parzyszek <kparzysz@quicinc.com>
Co-authored-by: Leandro Nunes <leandro.nunes@arm.com>
Co-authored-by: Jared Roesch <jroesch@octoml.ai>
Co-authored-by: Zhi <5145158+zhiics@users.noreply.github.com>
Co-authored-by: jmorrill <jeremiah.morrill@gmail.com>
Co-authored-by: Animesh Jain <anijain@umich.edu>
Co-authored-by: Leyuan Wang <laurawly@gmail.com>
Co-authored-by: Trevor Morris <trevmorr@amazon.com>
Co-authored-by: masahi <masahi129@gmail.com>
Co-authored-by: mbaret <55580676+mbaret@users.noreply.github.com>
Co-authored-by: windclarion <windclarion@gmail.com>
Co-authored-by: Tang, Shizhi <rd0x01@gmail.com>
Co-authored-by: Marcus Shawcroft <marcus.shawcroft@arm.com>
Co-authored-by: Abhikrant Sharma <quic_abhikran@quicinc.com>
Co-authored-by: Ravishankar Kolachana <quic_rkolacha@quicinc.com>
Co-authored-by: Josh Fromm <jwfromm@uw.edu>
Co-authored-by: shoubhik <shoubhikbhatti@gmail.com>
Co-authored-by: Bing Xu <antinucleon@gmail.com>
Co-authored-by: Andrew Reusch <areusch@octoml.ai>
Co-authored-by: Ramana Radhakrishnan <ramana.radhakrishnan@arm.com>
Co-authored-by: Haichen Shen <shenhaichen@gmail.com>
Co-authored-by: Siyuan Feng <hzfengsy@sjtu.edu.cn>
Co-authored-by: MORITA Kazutaka <morita.kazutaka@gmail.com>
Co-authored-by: samwyi <samwyi@yahoo.com>
Co-authored-by: Zhao Wu <zhaowu@apache.org>
Co-authored-by: Wei Pan <60017475+wpan11nv@users.noreply.github.com>
Co-authored-by: Cody Yu <comaniac0422@gmail.com>
Co-authored-by: Michal Piszczek <imichaljp@gmail.com>
Co-authored-by: Thomas Viehmann <tv.code@beamnet.de>
Co-authored-by: JishinMaster <francois.turban@gmail.com>
Co-authored-by: Matthew Brookhart <matthewbrookhart@gmail.com>
Co-authored-by: Thierry Moreau <tmoreau@octoml.ai>
Co-authored-by: yongfeng-nv <49211903+yongfeng-nv@users.noreply.github.com>
Co-authored-by: notoraptor <notoraptor@users.noreply.github.com>
Co-authored-by: Nikolay Nez <34389970+n-nez@users.noreply.github.com>
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Jun 9, 2020
…pache#5307)

* support extent(threadIdx.x) < warp_size in lower_warp_memory

* more docs for lower_warp_memory
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Jun 18, 2020
…pache#5307)

* support extent(threadIdx.x) < warp_size in lower_warp_memory

* more docs for lower_warp_memory
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jun 18, 2020
…pache#5307)

* support extent(threadIdx.x) < warp_size in lower_warp_memory

* more docs for lower_warp_memory
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants