Skip to content

Commit

Permalink
[TIR] GetBlockReadWriteRegion (apache#8875)
Browse files Browse the repository at this point in the history
* [TIR] GetBlockReadWriteRegion

* Fix black issue

* Use constant reference for the interface

* Fix lint issue
  • Loading branch information
MasterJH5574 authored and ylc committed Jan 13, 2022
1 parent 459ecfc commit 35ea46e
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 20 deletions.
19 changes: 15 additions & 4 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);

/*!
* \brief Auto detect the block read/write region according to body stmt
* It will detect the read/write region as an array in order of appearance in AST
* \brief Auto detect the block access region according to its body stmt
* It will detect the access region as an array in order of appearance in AST
* \param block The block to be detected
* \param buffer_var_map The outside buffers which may be accessed by the block.
* It is a map from buffer var to the buffer.
Expand All @@ -167,8 +167,19 @@ TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constrain
* - second: write regions
* - third: opaque regions
*/
Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);
TVM_DLL Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);

/*!
 * \brief Auto detect the read/write regions of a block according to its body stmt. An opaque
 *        access is counted as both a read access and a write access.
 * \param block The block to be detected
 * \param buffer_var_map The outside buffers which may be accessed by the block.
 *                       It is a map from buffer var to the buffer
 * \return An array consisting only of two elements for the input block:
 *         - first: read regions
 *         - second: write regions
 */
TVM_DLL Array<Array<BufferRegion>> GetBlockReadWriteRegion(const Block& block,
                                                           const Map<Var, Buffer>& buffer_var_map);

/*!
* \brief Calculate the expression complexity based on number of symbols it contains.
Expand Down
24 changes: 23 additions & 1 deletion python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,29 @@ def get_block_access_region(
- second: write regions
- third: opaque regions
"""
return _ffi_api.get_block_access_region(block, buffer_var_map) # type: ignore
return _ffi_api.GetBlockAccessRegion(block, buffer_var_map) # type: ignore


def get_block_read_write_region(
    block: Block, buffer_var_map: Dict[Var, Buffer]
) -> List[List[BufferRegion]]:
    """Detect the read and write regions of a block from its body stmt.

    An opaque access is reported as both a read access and a write access.

    Parameters
    ----------
    block : tvm.tir.Block
        The block whose read/write regions are detected.

    buffer_var_map : Dict[Var, Buffer]
        The outside buffers which may be accessed by the block,
        given as a mapping from buffer var to the buffer.

    Returns
    -------
    result : List[List[BufferRegion]]
        A list consisting only of the read regions and the write regions of the input block.
    """
    regions = _ffi_api.GetBlockReadWriteRegion(block, buffer_var_map)  # type: ignore
    return regions


def calculate_workspace_bytes(func: PrimFunc, workspace_byte_alignment: int) -> int:
Expand Down
34 changes: 33 additions & 1 deletion src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,39 @@ Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
return {detector.CollectReads(), detector.CollectWrites(), detector.CollectOpaques()};
}

TVM_REGISTER_GLOBAL("tir.analysis.get_block_access_region").set_body_typed(GetBlockAccessRegion);
/*!
 * \brief Compute the read and write regions of a block, folding opaque accesses into both.
 * \param block The block to be inspected
 * \param buffer_var_map Map from buffer var to buffer, forwarded to GetBlockAccessRegion
 * \return A two-element array: {read regions, write regions}
 */
Array<Array<BufferRegion>> GetBlockReadWriteRegion(const Block& block,
                                                   const Map<Var, Buffer>& buffer_var_map) {
  // The detector returns three arrays: reads (0), writes (1) and opaque accesses (2).
  const Array<Array<BufferRegion>> detected = GetBlockAccessRegion(block, buffer_var_map);
  const Array<BufferRegion> reads = detected[0];
  const Array<BufferRegion> writes = detected[1];
  const Array<BufferRegion> opaques = detected[2];

  // Remember every buffer that has at least one opaque access.
  std::unordered_set<const BufferNode*> opaque_buffers;
  for (const BufferRegion& region : opaques) {
    opaque_buffers.insert(region->buffer.get());
  }

  // Keep the regions of `src` whose buffer is not opaquely accessed, then append every opaque
  // access, so that an opaquely accessed buffer is represented only by its opaque regions.
  auto merge_with_opaques = [&opaque_buffers, &opaques](const Array<BufferRegion>& src) {
    Array<BufferRegion> merged;
    merged.reserve(src.size() + opaques.size());
    for (const BufferRegion& region : src) {
      if (opaque_buffers.find(region->buffer.get()) == opaque_buffers.end()) {
        merged.push_back(region);
      }
    }
    for (const BufferRegion& region : opaques) {
      merged.push_back(region);
    }
    return merged;
  };

  return {merge_with_opaques(reads), merge_with_opaques(writes)};
}

// Register both analyses as global functions under the "tir.analysis" prefix; the Python
// wrappers in python/tvm/tir/analysis/analysis.py call them by these names via _ffi_api.
TVM_REGISTER_GLOBAL("tir.analysis.GetBlockAccessRegion").set_body_typed(GetBlockAccessRegion);
TVM_REGISTER_GLOBAL("tir.analysis.GetBlockReadWriteRegion").set_body_typed(GetBlockReadWriteRegion);

} // namespace tir
} // namespace tvm
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ class BaseInliner : public StmtExprMutator {
Array<BufferRegion> reads = std::move(block->reads);
Array<BufferRegion> writes = std::move(block->writes);
if (!is_scope_root) {
Array<Array<BufferRegion>> inspected = GetBlockAccessRegion(block, buffer_var_map_);
Array<Array<BufferRegion>> inspected = GetBlockReadWriteRegion(block, buffer_var_map_);
reads = std::move(inspected[0]);
writes = std::move(inspected[1]);
}
Expand Down
16 changes: 4 additions & 12 deletions src/tir/transforms/plan_update_buffer_allocation_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,10 @@ class BufferAllocationLocator : public StmtExprMutator {
/*init=*/NullOpt,
/*alloc_buffers=*/alloc_buffers);
ObjectPtr<BlockNode> n = CopyOnWrite(opaque_block.get());
CollectReadWrite(opaque_block, &n->reads, &n->writes);
Array<Array<BufferRegion>> access =
GetBlockReadWriteRegion(opaque_block, buffer_data_to_buffer_);
n->reads = access[0];
n->writes = access[1];
BlockRealize realize({}, Bool(true), Block(n));
return std::move(realize);
}
Expand All @@ -144,17 +147,6 @@ class BufferAllocationLocator : public StmtExprMutator {
return result;
}

void CollectReadWrite(const Block& block, Array<BufferRegion>* reads,
Array<BufferRegion>* writes) const {
Array<Array<BufferRegion>> access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
*reads = access[0];
*writes = access[1];
for (const auto& opaque_access : access[2]) {
reads->push_back(opaque_access);
writes->push_back(opaque_access);
}
}

/*! \brief The map from stmt to the buffers to be allocated under it. */
std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
/*! \brief The buffer already allocated during recursive visiting. */
Expand Down
29 changes: 29 additions & 0 deletions tests/python/unittest/test_tir_analysis_get_block_access_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
from tvm import tir, script
from tvm.ir import Range
Expand Down Expand Up @@ -81,6 +82,20 @@ def opaque_block_func() -> None:
B[i, j] = A[i, j] + 1.0


# TVMScript fixture: the block declares explicit read/write regions for A and B, but its body
# only touches the two buffers through their raw data pointers handed to an extern call.
# NOTE(review): presumably the region detector classifies these pointer uses as opaque
# accesses, which is what test_opaque_access relies on -- confirm against the detector.
@tvm.script.tir
def opaque_access_func() -> None:
    A = tir.alloc_buffer([1024])
    B = tir.alloc_buffer([1024])
    for i in tir.serial(0, 8):
        with tir.block([8]) as [v]:
            tir.bind(v, i)
            # Each of the 8 block instances covers one 128-element slice of A and B.
            tir.reads([A[v * 128 : v * 128 + 128]])
            tir.writes([B[v * 128 : v * 128 + 128]])
            tir.evaluate(
                tir.call_extern("test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32")
            )


def test_block_access_region_detector():
block = func.body.block.body.block
alloc_buffers = func.body.block.alloc_buffers
Expand Down Expand Up @@ -110,6 +125,19 @@ def test_opaque_block():
tvm.ir.assert_structural_equal(block1.writes, ret[1])


def test_opaque_access():
    """The read/write regions must differ from the plain access regions for opaque accesses.

    For the block of ``opaque_access_func``, both the reads (index 0) and the writes (index 1)
    returned by ``get_block_read_write_region`` are expected to be structurally different from
    those returned by ``get_block_access_region``.
    """
    root_block = opaque_access_func.body.block
    inner_block = root_block.body.body.block
    buffer_var_map = {buffer.data: buffer for buffer in root_block.alloc_buffers}

    read_write_regions = tir.analysis.get_block_read_write_region(inner_block, buffer_var_map)
    access_regions = tir.analysis.get_block_access_region(inner_block, buffer_var_map)
    for index in (0, 1):
        with pytest.raises(ValueError):
            tvm.ir.assert_structural_equal(read_write_regions[index], access_regions[index])


def test_match_buffer():
root_block = match_buffer_func.body.block
block = root_block.body.body.body.block
Expand Down Expand Up @@ -141,4 +169,5 @@ def test_match_buffer():
if __name__ == "__main__":
test_block_access_region_detector()
test_opaque_block()
test_opaque_access()
test_match_buffer()
1 change: 0 additions & 1 deletion tests/python/unittest/test_tir_schedule_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# pylint: disable=missing-function-docstring,missing-module-docstring
import sys

import numpy as np
import pytest
import tvm
import tvm.testing
Expand Down

0 comments on commit 35ea46e

Please sign in to comment.