[SparseTIR] Sparse format/block/buffers (#465)
* upd

* normalize

* fix

* upd

* upd

* upd

* Readd instruction SampleShapeGenericTiles (#459)

* [Backport] MatchBuffer, BufferLocator & GetBlockReadWriteRegion (#460)

* [TensorIR] Support for match_buffer from subregion (#8585)

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
# Conflicts:
#	python/tvm/script/special_stmt.py
#	python/tvm/tir/transform/transform.py
#	src/tir/analysis/block_access_region_detector.cc
#	src/tir/analysis/buffer_access_lca_detector.cc
#	src/tir/transforms/lower_match_buffer.cc
#	tests/python/integration/test_lower.py
#	tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
#	tests/python/unittest/test_tir_analysis_get_block_access_region.py
#	tests/python/unittest/test_tir_lower_match_buffer.py
#	tests/python/unittest/test_tir_transform_compact_buffer_region.py
#	tests/python/unittest/test_tvmscript_error_report.py

* [TIR] Fix opaque access in buffer locator pass and match_buffer in region detector (#8855)

* init

* fix

* Update src/tir/transforms/plan_update_buffer_allocation_location.cc

Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

* Update src/tir/transforms/plan_update_buffer_allocation_location.cc

Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

* address

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

* [TIR] GetBlockReadWriteRegion (#8875)

* [TIR] GetBlockReadWriteRegion

* Fix black issue

* Use constant reference for the interface

* Fix lint issue

* Catch the correct error class in logical layout test

Co-authored-by: Siyuan Feng <hzfengsy@vip.qq.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>

* [BugFix] Fix Conv2d TensorCore Demo (#461)

* [Backport] LowerWarpMemory: remove unneeded shuffle when accessing from the same thread (#464)

* Lower logical intrin and end-to-end demo (#448)

* [WIP] Logical Layout lowering

* add intrin

* Logical intrin lowering

* e2e demo

* LowerLogicalIntrin

* Remove num groups

* remove old demo

* rebase

* fix

* lower intrin pass

* Support nested software pipelining (#463)

* Support nested software pipelining

* Update test_schedule_software_pipeline.py

* upd

* upd

Co-authored-by: Bojian Zheng <bojian.zheng@mail.utoronto.ca>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Siyuan Feng <hzfengsy@vip.qq.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
6 people authored Sep 4, 2021
1 parent 5135eb9 commit 23caab4
Showing 63 changed files with 2,331 additions and 647 deletions.
2 changes: 1 addition & 1 deletion include/tvm/arith/int_set.h
@@ -101,7 +101,7 @@ class IntSet : public ObjectRef {
/*!
* \brief Try to match IntSet with range r.
*
* \note It is guanrateed that IntSet::FromRange(r).MatchRange(r) == true
* \note It is guaranteed that IntSet::FromRange(r).MatchRange(r) == true
* \return true if we can prove they are the same.
*/
bool MatchRange(const tvm::Range& r) const;
2 changes: 1 addition & 1 deletion include/tvm/arith/iter_affine_map.h
@@ -191,7 +191,7 @@ class IterSplitExpr : public IterMapExpr {
*/
TVM_DLL explicit IterSplitExpr(IterMark source);
/*!
* \brief constructor from just source.
* \brief constructor from source and scale.
* \param source The source expression.
* \param scale The additional scaling factor.
*/
19 changes: 15 additions & 4 deletions include/tvm/tir/analysis.h
@@ -156,8 +156,8 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);

/*!
* \brief Auto detect the block read/write region according to body stmt
* It will detect the read/write region as an array in order of appearance in AST
* \brief Auto detect the block access region according to its body stmt
* It will detect the access region as an array in order of appearance in AST
* \param block The block to be detected
* \param buffer_var_map The outside buffers which may be accessed by the block.
* It is a map from buffer var to the buffer.
@@ -167,8 +167,19 @@ TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constrain
* - second: write regions
* - third: opaque regions
*/
Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);
TVM_DLL Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);

/*!
* \brief Auto detect the block read/write region according to its body stmt. An opaque access will
* be counted as both a read and a write access
* \param block The block to be detected
* \param buffer_var_map The outside buffers which may be accessed by the block.
* It is a map from buffer var to the buffer
* \return An array only consisting of the read regions and write regions of the input block
*/
TVM_DLL Array<Array<BufferRegion>> GetBlockReadWriteRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);

/*!
* \brief Calculate the expression complexity based on number of symbols it contains.
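The doc comment for GetBlockReadWriteRegion says an opaque access is counted as both a read and a write. A minimal sketch of that folding step, assuming plain strings stand in for tir::BufferRegion and that the three arrays arrive in the order GetBlockAccessRegion documents (reads, writes, opaque). The names Region, AccessRegions and ToReadWriteSketch are hypothetical and not part of TVM; the real pass may additionally merge regions that touch the same buffer, which this sketch does not attempt.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Stand-in for tir::BufferRegion (assumption): a string such as "A[0:16]".
using Region = std::string;

// Mirrors the three arrays GetBlockAccessRegion is documented to return.
struct AccessRegions {
  std::vector<Region> reads;
  std::vector<Region> writes;
  std::vector<Region> opaque;
};

// Hypothetical helper: append every opaque region to both the read list and
// the write list, returning {reads, writes} only.
std::vector<std::vector<Region>> ToReadWriteSketch(const AccessRegions& r) {
  std::vector<Region> reads = r.reads;
  std::vector<Region> writes = r.writes;
  reads.insert(reads.end(), r.opaque.begin(), r.opaque.end());
  writes.insert(writes.end(), r.opaque.begin(), r.opaque.end());
  return {reads, writes};
}
```

A caller that only distinguishes reads from writes can then treat the opaque buffer conservatively, since it shows up in both lists.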
5 changes: 5 additions & 0 deletions include/tvm/tir/builtin.h
@@ -671,6 +671,11 @@ TVM_DLL const Op& atomic_add();
*/
TVM_DLL const Op& tvm_memcpy_async();

/*!
* \brief tvm intrinsic for mfma instruction
*/
TVM_DLL const Op& tvm_mfma_sync();

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
15 changes: 15 additions & 0 deletions include/tvm/tir/schedule/schedule.h
@@ -23,6 +23,9 @@
#include <tvm/tir/schedule/state.h>

namespace tvm {

class Target;

namespace tir {

using TRandState = support::LinearCongruentialEngine::TRandState;
@@ -239,6 +242,18 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
Optional<Array<Integer>> decision = NullOpt) = 0;
/*!
* \brief Sample the factors to tile a list of LoopRVs
* \param loop_rvs The loops to be tiled
* \param ns The number of loops after tiling, one entry per loop in loop_rvs
* \param target The compilation target
* \param max_innermost_factor The maximum factor in the innermost loop, -1 if disabled
* \param decision The sampling decision
* \return For each loop, an array of random variables holding the sampled factors
*/
virtual Array<Array<ExprRV>> SampleShapeGenericTiles(const Array<LoopRV>& loop_rvs,
const std::vector<int>& ns,
const Target& target, int max_innermost_factor,
Optional<Array<Array<Integer>>> decision = NullOpt) = 0;
/*!
* \brief Sample an integer given the probability distribution
* \param candidates The candidates
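The tile samplers declared in this header take a loop, a number of resulting loops, and a cap on the innermost factor (-1 to disable). A minimal sketch of that idea for a single loop, under an assumption the diff does not state: that the sampled factors must multiply back to the loop extent. SamplePerfectTileSketch is a hypothetical name, the uniform choice among divisors is an illustrative policy, and SampleShapeGenericTiles (which also takes a Target) may well sample differently.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical helper, not TVM's implementation: split `extent` into `n`
// positive factors whose product is exactly `extent`. The innermost factor
// (last position) is capped by `max_innermost_factor` unless it is -1.
// With n == 1 the single factor must equal the extent, so no cap applies.
std::vector<int64_t> SamplePerfectTileSketch(int64_t extent, int n,
                                             int64_t max_innermost_factor,
                                             std::mt19937* rng) {
  assert(extent >= 1 && n >= 1);
  assert(max_innermost_factor == -1 || max_innermost_factor >= 1);
  std::vector<int64_t> factors(n, 1);
  int64_t remaining = extent;
  // Pick factors from the innermost position outward; each pick is a
  // divisor of what is left, so the product stays exact.
  for (int i = n - 1; i >= 1; --i) {
    std::vector<int64_t> candidates;
    for (int64_t d = 1; d <= remaining; ++d) {
      bool capped = (i == n - 1) && max_innermost_factor != -1;
      if (remaining % d == 0 && (!capped || d <= max_innermost_factor)) {
        candidates.push_back(d);
      }
    }
    // `candidates` always contains 1, so it is never empty.
    std::uniform_int_distribution<std::size_t> pick(0, candidates.size() - 1);
    factors[i] = candidates[pick(*rng)];
    remaining /= factors[i];
  }
  factors[0] = remaining;  // the outermost loop absorbs the rest
  return factors;
}
```

Passing the generator by pointer keeps the sketch deterministic for a fixed seed, similar in spirit to how the schedule carries its own random state.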
66 changes: 66 additions & 0 deletions include/tvm/tir/sparse/block.h
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/tir/sparse/block.h
* \brief Sparse Block in Sparse TIR.
*/

#ifndef TVM_TIR_SPARSE_BLOCK_H_
#define TVM_TIR_SPARSE_BLOCK_H_

#include <tvm/tir/sparse/format.h>

namespace tvm {

namespace tir {

namespace sparse {

/*!
* \brief Class of sparse block.
* \example
* with tir.sp.block([i, j, k], [False, False, True]) as [vi, vj, vk]:
* pass
* with tir.sp.block([i, j, k], [False, False, True], [(0, 1), (2,)]) as [vi, vj, vk]:
* pass
*/
class SparseBlockNode : public Object {
public:
AxisRef root;
Array<Axis> axes;
Array<Array<int>> fused_groups;
Array<bool> is_reduce_axis;
static constexpr const char* _type_key = "tir.sp.SparseBlockNode";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockNode, Object);
};

class SparseBlock : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, ObjectRef, SparseBlockNode);
};


} // namespace sparse

} // namespace tir

} // namespace tvm

#endif // TVM_TIR_SPARSE_BLOCK_H_
73 changes: 73 additions & 0 deletions include/tvm/tir/sparse/buffer.h
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/tir/sparse/buffer.h
* \brief Sparse buffer data structure in Sparse TIR.
*/
#ifndef TVM_TIR_SPARSE_BUFFER_H_
#define TVM_TIR_SPARSE_BUFFER_H_

#include <tvm/tir/sparse/format.h>
#include <tvm/tir/buffer.h>

namespace tvm {

namespace tir {

namespace sparse {

/*!
* \brief Class of sparse buffer.
*/
class SparseBufferNode : public Object {
public:
/* Root of Axis Dependency Tree. */
AxisRef root;
/* Axes */
Array<Axis> axes;
/* Number of dimensions */
int ndim;
/* Buffer corresponding to flattened value */
Buffer data;
/* Buffer corresponding to indices pointer */
Array<Buffer> indptr;
/* Buffer of column indices */
Array<Buffer> indices;

static constexpr const char* _type_key = "tir.sp.SparseBufferNode";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferNode, Object);
};

/*!
* \brief Managed reference to SparseBufferNode.
* \sa SparseBufferNode
*/
class SparseBuffer : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
};

} // namespace sparse

} // namespace tir

} // namespace tvm

#endif // TVM_TIR_SPARSE_BUFFER_H_
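SparseBufferNode groups a flattened value buffer with indptr and column-index buffers, which is the layout used by the classic CSR format. A minimal sketch of how those three arrays cooperate in a lookup, assuming a single sparse axis and plain std::vector storage. CsrMatrixSketch and CsrAt are hypothetical names; the header itself holds Array<Buffer> for indptr and indices, presumably to cover more than one sparse axis, and that case is not modelled here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CSR container; field names echo SparseBufferNode, but the
// types are plain vectors rather than tir::Buffer.
struct CsrMatrixSketch {
  std::vector<int> indptr;   // row i owns entries [indptr[i], indptr[i + 1])
  std::vector<int> indices;  // column index of each stored entry
  std::vector<float> data;   // flattened stored values
};

// Return the value at (row, col), or 0 when the entry is not stored.
float CsrAt(const CsrMatrixSketch& m, int row, int col) {
  assert(row >= 0 && static_cast<std::size_t>(row) + 1 < m.indptr.size());
  for (int p = m.indptr[row]; p < m.indptr[row + 1]; ++p) {
    if (m.indices[p] == col) return m.data[p];
  }
  return 0.0f;
}
```

For the 2x3 matrix [[1, 0, 2], [0, 0, 3]] the arrays would be indptr = {0, 2, 3}, indices = {0, 2, 2} and data = {1, 2, 3}.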