forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MetaSchedule][M4a] Schedule Rule: Multi-Level-Tiling (apache#10043)
* multi level tiling * remove tensor core related code * pylint * fix Co-authored-by: Junru Shao <junrushao1994@gmail.com>
- Loading branch information
Showing
10 changed files
with
898 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
84 changes: 84 additions & 0 deletions
84
python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Multi-level tiling with reuse.""" | ||
from typing import Any, Dict, List, NamedTuple, Optional | ||
|
||
from tvm._ffi import register_object | ||
|
||
from .. import _ffi_api | ||
from .schedule_rule import ScheduleRule | ||
|
||
|
||
class ReuseType(NamedTuple): | ||
"""Reuse type.""" | ||
|
||
req: str | ||
levels: List[int] | ||
scope: str | ||
|
||
def as_dict(self) -> Dict[str, Any]: | ||
"""Return the dict representation of the reuse type.""" | ||
return { | ||
"req": self.req, | ||
"levels": self.levels, | ||
"scope": self.scope, | ||
} | ||
|
||
|
||
@register_object("meta_schedule.MultiLevelTiling") | ||
class MultiLevelTiling(ScheduleRule): | ||
"""Multi-level tiling with reuse. | ||
Parameters | ||
---------- | ||
structure : str | ||
The tiling structure. Recommended: | ||
- 'SSRSRS' on CPU | ||
- 'SSSRRSRS' on GPU | ||
tile_bind : Optional[List[str]] | ||
For each level of tiles, which thread axis it is bound to. Recommended: | ||
- None on CPU | ||
- [blockIdx.x, vthread.x, threadIdx.x] on GPU | ||
max_innermost_factor : Optional[int] | ||
The maximum size of the innermost factor. None means no limit | ||
vector_load_lens : Optional[List[int]] | ||
The length of vector lane in vectorized cooperative fetching. | ||
None means disable vectorization | ||
reuse_read : Optional[ReuseType] | ||
Data reuse configuration for reading. None means no reuse. | ||
reuse_write : Optional[ReuseType] | ||
Data reuse configuration for writing. None means no reuse. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
structure: str, | ||
tile_binds: Optional[List[str]] = None, | ||
max_innermost_factor: Optional[int] = None, | ||
vector_load_lens: Optional[List[int]] = None, | ||
reuse_read: Optional[ReuseType] = None, | ||
reuse_write: Optional[ReuseType] = None, | ||
) -> None: | ||
self.__init_handle_by_constructor__( | ||
_ffi_api.ScheduleRuleMultiLevelTiling, # type: ignore # pylint: disable=no-member | ||
structure, | ||
tile_binds, | ||
max_innermost_factor, | ||
vector_load_lens, | ||
reuse_read.as_dict() if reuse_read is not None else None, | ||
reuse_write.as_dict() if reuse_write is not None else None, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.