Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule][M4a] Schedule Rule: Random-Compute-Location #9940

Merged
merged 1 commit into from
Jan 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
Optional<Array<Integer>> decision = NullOpt) = 0;
/*!
 * \brief Sample a compute-at location of the given block
 * \param block_rv The block whose compute-at location is to be sampled
 * \param decision The sampling decision. Defaults to NullOpt (presumably meaning that no decision
 * is fixed in advance -- confirm against the concrete implementation)
 * \return The sampled loop at which the input block is to be computed
 */
virtual LoopRV SampleComputeLocation(const BlockRV& block_rv,
                                     Optional<Integer> decision = NullOpt) = 0;

/******** Schedule: Get blocks & loops ********/
/*!
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,13 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_
*/
constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";

/*! \brief Mark the tiling structure of blocks to which rule Multi-Level-Tiling has been applied */
constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";

/*!
 * \brief Mark a block whose producer needs to be processed by rule Random-Compute-Location
 */
constexpr const char* meta_schedule_random_compute_producer =
    "meta_schedule.random_compute_producer";

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/schedule_rule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
blocks in a schedule. See also PostOrderApply.
"""
from .schedule_rule import PyScheduleRule, ScheduleRule
from .random_compute_location import RandomComputeLocation
31 changes: 31 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/random_compute_location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Rule that randomly select a compute-at location for a free block"""
from tvm._ffi import register_object

from .. import _ffi_api
from .schedule_rule import ScheduleRule


@register_object("meta_schedule.RandomComputeLocation")
class RandomComputeLocation(ScheduleRule):
    """A rule that randomly selects a compute-at location for a free block."""

    def __init__(self) -> None:
        # The handle is initialized through the FFI constructor, which takes no arguments.
        self.__init_handle_by_constructor__(
            _ffi_api.ScheduleRuleRandomComputeLocation, # type: ignore # pylint: disable=no-member
        )
33 changes: 33 additions & 0 deletions python/tvm/meta_schedule/testing/space_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
from typing import List

from tvm.tir import Schedule
from tvm.tir.schedule import Trace


def check_trace(spaces: List[Schedule], expected: List[List[str]]):
    """Assert that the traces of the given schedules match the expected traces.

    Parameters
    ----------
    spaces : List[Schedule]
        The schedules whose traces are checked. Each trace is rebuilt from its
        instructions, simplified with post-processing removed, and rendered as text.
    expected : List[List[str]]
        The expected traces, each given as a list of lines.

    Raises
    ------
    AssertionError
        If a rendered trace is not among the expected ones, or if the number of
        distinct rendered traces differs from the number of distinct expected traces.
    """
    wanted = set(map("\n".join, expected))
    seen = set()
    for sch in spaces:
        simplified = Trace(sch.trace.insts, {}).simplified(remove_postproc=True)
        lines = str(simplified).strip().splitlines()
        rendered = "\n".join(lines)
        seen.add(rendered)
        # On mismatch, the offending trace is shown on its own line in the assertion message.
        assert rendered in wanted, "\n" + rendered
    assert len(wanted) == len(seen)
26 changes: 26 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,32 @@ def sample_perfect_tile(
)
)

@type_checked
def sample_compute_location(
    self,
    block: BlockRV,
    decision: Optional[int] = None,
) -> LoopRV:
    """Sample a compute-at location of the given block

    Parameters
    ----------
    block : BlockRV
        The block whose compute-at location is to be sampled
    decision : Optional[int]
        The sampling decision. Defaults to None; the value is forwarded unchanged
        to the underlying FFI function.

    Returns
    -------
    result : LoopRV
        The sampled loop at which the input block is to be computed
    """
    # Thin wrapper: the schedule itself and both arguments are passed straight to the FFI call.
    return _ffi_api.ScheduleSampleComputeLocation( # type: ignore # pylint: disable=no-member
        self,
        block,
        decision,
    )

########## Schedule: Get blocks & loops ##########
@type_checked
def get_block(
Expand Down
123 changes: 123 additions & 0 deletions src/meta_schedule/schedule_rule/random_compute_location.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

class RandomComputeLocationNode : public ScheduleRuleNode {
Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved
public:
// Inherited from ScheduleRuleNode
void InitializeWithTuneContext(const TuneContext& context) final {}

// Inherited from ScheduleRuleNode
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
if (!CheckConditions(sch, block_rv)) {
return {sch};
}

// Step 1. If the producer of the input block needs a random compute-at location (specified by
// the annotation), we collect the producer first, and transform the producer block later.
// - The reason we collect the producer before transforming the input block is that, if the
// decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
// access the input block. Hence we collect its producer ahead of time.
// - Note that only single producer is allowed in this case.
Array<tir::BlockRV> producers{nullptr};
if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
true)) {
producers = sch->GetProducers(block_rv);
sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer);
ICHECK_EQ(producers.size(), 1);
}

// Step 2. Transform the input block.
tir::Schedule res = RandomlyComputeAt(sch, block_rv);

// Step 3. Transform the producer block if compute-location sampling is needed.
if (producers.defined()) {
res = RandomlyComputeAt(res, producers[0]);
}
junrushao marked this conversation as resolved.
Show resolved Hide resolved

return {res};
}

private:
bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
tir::StmtSRef block_sref = sch->GetSRef(block_rv);
const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);

// Cond 1. The block is not the root block.
if (block_sref->parent == nullptr) {
return false;
}
// Cond 2. The block should be the direct child block of the root block.
if (GetScopeRoot(sch->state(), block_sref, //
/*require_stage_pipeline=*/false, //
/*require_subtree_compact_dataflow=*/false)
->parent != nullptr) {
return false;
}
// Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child
// block.
Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref);
if (loop_srefs.empty()) {
return false;
}
if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) {
return false;
}
// Cond 5. The block is not tiled. We check this condition by examine the block's annotation.
if (tir::HasBeenMultiLevelTiled(block_sref)) {
return false;
}
// Cond 6. The block has at lease one consumer.
if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) {
return false;
}
return true;
}

/*!
* \brief Keep sampling a compute-at location for the input block until success.
* \param sch The TIR schedule
* \param block_rv The block whose compute-at location is to be sampled
* \return The TIR schedule after transformation
*/
tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv);
sch->ComputeAt(block_rv, compute_at_loc, true);
return sch;
}

public:
void VisitAttrs(tvm::AttrVisitor* v) {}

static constexpr const char* _type_key = "meta_schedule.RandomComputeLocation";
TVM_DECLARE_FINAL_OBJECT_INFO(RandomComputeLocationNode, ScheduleRuleNode);
};

/*! \brief Create a schedule rule backed by a newly allocated RandomComputeLocationNode. */
ScheduleRule ScheduleRule::RandomComputeLocation() {
  auto node = make_object<RandomComputeLocationNode>();
  return ScheduleRule(node);
}

// Register the node type, and expose the factory function as a global function named
// "meta_schedule.ScheduleRuleRandomComputeLocation".
TVM_REGISTER_NODE_TYPE(RandomComputeLocationNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation")
    .set_body_typed(ScheduleRule::RandomComputeLocation);
} // namespace meta_schedule
} // namespace tvm
33 changes: 33 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,39 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self
*/
BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref);

/*!
 * \brief Get the IterVarType of the given loop, according to the blocks it is bound to
 * \param loop_sref The loop to be checked
 * \return The IterVarType of the given loop
 */
IterVarType GetLoopIterType(const StmtSRef& loop_sref);

/*!
 * \brief Get the lowest common ancestor of an array of blocks or loops on the sref tree
 * \param srefs The block srefs or loop srefs whose lowest common ancestor is to be queried
 * \return The lowest common ancestor of the input block srefs or loop srefs
 * \note The input array must contain at least one sref
 */
StmtSRef GetSRefLowestCommonAncestor(const Array<StmtSRef>& srefs);

/*!
 * \brief Check whether multi-level tiling has been applied to the given block. This is checked
 * by examining the block's annotation.
 * \param block_sref The block to be checked
 * \return A boolean indicating whether the block has been multi-level tiled.
 */
bool HasBeenMultiLevelTiled(const StmtSRef& block_sref);

/*!
 * \brief Collect all the feasible compute-at locations of the input block
 * \param self The schedule state
 * \param block_sref The block whose compute-at locations are to be collected
 * \return All the feasible compute-at locations of the input block, given as a pair: an array of
 * loop srefs, and a vector of their indices among the outer loops of the input block
 */
std::pair<Array<StmtSRef>, std::vector<int>> CollectComputeLocation(const ScheduleState& self,
                                                                    const StmtSRef& block_sref);

/******** Producer-consumer relation ********/

/*!
Expand Down
Loading