Skip to content

Commit

Permalink
[MetaSchedule] random compute location (apache#9940)
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
  • Loading branch information
7 people authored and yuanfz98 committed Jan 24, 2022
1 parent e34e9e6 commit 20a90ce
Show file tree
Hide file tree
Showing 19 changed files with 699 additions and 0 deletions.
8 changes: 8 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
Optional<Array<Integer>> decision = NullOpt) = 0;
/*!
* \brief Sample a compute-at location of the given block
* \param block_rv The block whose compute-at location is to be sampled
* \param decision The sampling decision
* \return The sampled loop where the input block is to be computed at
*/
virtual LoopRV SampleComputeLocation(const BlockRV& block_rv,
Optional<Integer> decision = NullOpt) = 0;

/******** Schedule: Get blocks & loops ********/
/*!
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,13 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_
*/
constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";

/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */
constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";

/*! \brief Mark the block whose producer needs to be applied by rule Random-Compute-Location */
constexpr const char* meta_schedule_random_compute_producer =
"meta_schedule.random_compute_producer";

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/schedule_rule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
"""
from .auto_inline import AutoInline
from .schedule_rule import PyScheduleRule, ScheduleRule
from .random_compute_location import RandomComputeLocation
31 changes: 31 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/random_compute_location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Rule that randomly select a compute-at location for a free block"""
from tvm._ffi import register_object

from .. import _ffi_api
from .schedule_rule import ScheduleRule


@register_object("meta_schedule.RandomComputeLocation")
class RandomComputeLocation(ScheduleRule):
"""A rule that randomly select a compute-at location for a free block"""

def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleRandomComputeLocation, # type: ignore # pylint: disable=no-member
)
33 changes: 33 additions & 0 deletions python/tvm/meta_schedule/testing/space_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
from typing import List

from tvm.tir import Schedule
from tvm.tir.schedule import Trace


def check_trace(spaces: List[Schedule], expected: List[List[str]]):
expected_traces = {"\n".join(t) for t in expected}
actual_traces = set()
for space in spaces:
trace = Trace(space.trace.insts, {})
trace = trace.simplified(remove_postproc=True)
str_trace = "\n".join(str(trace).strip().splitlines())
actual_traces.add(str_trace)
assert str_trace in expected_traces, "\n" + str_trace
assert len(expected_traces) == len(actual_traces)
26 changes: 26 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,32 @@ def sample_perfect_tile(
)
)

@type_checked
def sample_compute_location(
self,
block: BlockRV,
decision: Optional[int] = None,
) -> LoopRV:
"""Sample a compute-at location of the given block
Parameters
----------
block : BlockRV
The block whose compute-at location is to be sampled
decision : Optional[int]
The sampling decision
Returns
-------
result : LoopRV
The sampled loop where the input block is to be computed at
"""
return _ffi_api.ScheduleSampleComputeLocation( # type: ignore # pylint: disable=no-member
self,
block,
decision,
)

########## Schedule: Get blocks & loops ##########
@type_checked
def get_block(
Expand Down
123 changes: 123 additions & 0 deletions src/meta_schedule/schedule_rule/random_compute_location.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

class RandomComputeLocationNode : public ScheduleRuleNode {
public:
// Inherited from ScheduleRuleNode
void InitializeWithTuneContext(const TuneContext& context) final {}

// Inherited from ScheduleRuleNode
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
if (!CheckConditions(sch, block_rv)) {
return {sch};
}

// Step 1. If the producer of the input block needs a random compute-at location (specified by
// the annotation), we collect the producer first, and transform the producer block later.
// - The reason we collect the producer before transforming the input block is that, if the
// decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
// access the input block. Hence we collect its producer ahead of time.
// - Note that only single producer is allowed in this case.
Array<tir::BlockRV> producers{nullptr};
if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
true)) {
producers = sch->GetProducers(block_rv);
sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer);
ICHECK_EQ(producers.size(), 1);
}

// Step 2. Transform the input block.
tir::Schedule res = RandomlyComputeAt(sch, block_rv);

// Step 3. Transform the producer block if compute-location sampling is needed.
if (producers.defined()) {
res = RandomlyComputeAt(res, producers[0]);
}

return {res};
}

private:
bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
tir::StmtSRef block_sref = sch->GetSRef(block_rv);
const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);

// Cond 1. The block is not the root block.
if (block_sref->parent == nullptr) {
return false;
}
// Cond 2. The block should be the direct child block of the root block.
if (GetScopeRoot(sch->state(), block_sref, //
/*require_stage_pipeline=*/false, //
/*require_subtree_compact_dataflow=*/false)
->parent != nullptr) {
return false;
}
// Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child
// block.
Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref);
if (loop_srefs.empty()) {
return false;
}
if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) {
return false;
}
// Cond 5. The block is not tiled. We check this condition by examine the block's annotation.
if (tir::HasBeenMultiLevelTiled(block_sref)) {
return false;
}
// Cond 6. The block has at lease one consumer.
if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) {
return false;
}
return true;
}

/*!
* \brief Keep sampling a compute-at location for the input block until success.
* \param sch The TIR schedule
* \param block_rv The block whose compute-at location is to be sampled
* \return The TIR schedule after transformation
*/
tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv);
sch->ComputeAt(block_rv, compute_at_loc, true);
return sch;
}

public:
void VisitAttrs(tvm::AttrVisitor* v) {}

static constexpr const char* _type_key = "meta_schedule.RandomComputeLocation";
TVM_DECLARE_FINAL_OBJECT_INFO(RandomComputeLocationNode, ScheduleRuleNode);
};

ScheduleRule ScheduleRule::RandomComputeLocation() {
return ScheduleRule(make_object<RandomComputeLocationNode>());
}

TVM_REGISTER_NODE_TYPE(RandomComputeLocationNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation")
.set_body_typed(ScheduleRule::RandomComputeLocation);
} // namespace meta_schedule
} // namespace tvm
33 changes: 33 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,39 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self
*/
BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref);

/*!
* \brief Get the IterVarType of the specific loop, according to the blocks it's bound to
* \param loop_sref The loop to be checked
* \return The IterVarType of the specific loop
*/
IterVarType GetLoopIterType(const StmtSRef& loop_sref);

/*!
* \brief Get the lowest common ancestor of an array of blocks or loops on the sref tree
* \param srefs The block srefs or loop srefs whose lowest common ancestor is to be queried
* \return The lowest common ancestor of the input block srefs or loop srefs
* \note The input array is required to have at least one sref
*/
StmtSRef GetSRefLowestCommonAncestor(const Array<StmtSRef>& srefs);

/*!
* \brief Checks if the given block has been applied by multi-level tiling. We check this by
* examine the block's annotation.
* \param block_sref The block to be checked
* \return A boolean indicating whether the block has been multi-level tiled.
*/
bool HasBeenMultiLevelTiled(const StmtSRef& block_sref);

/*!
* \brief Collect all the feasible compute-at locations of the input block
* \param self The schedule state
* \param block_sref The block whose compute-at locations are to be collected
* \return All the feasible compute-at locations of the input block, given as an array of loop srefs
* and an array of their indices among the outer loops of the input block
*/
std::pair<Array<StmtSRef>, std::vector<int>> CollectComputeLocation(const ScheduleState& self,
const StmtSRef& block_sref);

/******** Producer-consumer relation ********/

/*!
Expand Down
Loading

0 comments on commit 20a90ce

Please sign in to comment.