Skip to content

Commit

Permalink
address review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
huajsj committed Sep 14, 2021
1 parent 73b6c98 commit 27d42ab
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 70 deletions.
117 changes: 74 additions & 43 deletions python/tvm/contrib/pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def pipeline_executor_enabled():


def build(pipe_configs):
"""Use pipe_config to build and return module list and module dependency configuration.
"""Use pipe_config to build and return PipelineExecutorFactoryModule.
Parameters
----------
Expand All @@ -44,7 +44,8 @@ def build(pipe_configs):
Returns
-------
ret: PipelineExecutorFactoryModule
A class that wraps module list and module dependency configuration.
A factory class is responsible for receiving module and configuration
information and maintaining executor module
"""
mods = {}
mod_n_configs = pipe_configs.get_config()
Expand Down Expand Up @@ -85,7 +86,8 @@ def create(pipe_executor_factory_module):
Parameters
----------
pipe_executor_factory_module : PipelineExecutorFactoryModule
It is wrapper class which include IRModule list and pipeline configuration.
A factory class is responsible for receiving module and configuration
information and maintaining executor module.
Returns
-------
Expand Down Expand Up @@ -116,7 +118,7 @@ class PipelineConfig(object):
"""

class Binding:
"""This class define the module connection information.
"""This class defines the module connection information.
The type can only be either "input" or "output".
Parameters
Expand All @@ -128,8 +130,7 @@ class Binding:
The type of this interface. It can only be either "input" or "output".
name : str/integer
Name, for input it is string such as "data0", for output it is the
idx integer such as 0.
Name, for input it is string such as "data0", for output it is an integer such as 0.
"""

def __init__(self, owner, io_type, name, data_type=None):
Expand All @@ -144,15 +145,15 @@ def __init__(self, owner, io_type, name, data_type=None):
self.data_type = data_type

def get_name(self):
"""Return the interface name and name of owner who owns this interface."""
# Return the interface name and name of owner who owns this interface.
owner_name = ""
if isinstance(self.io_owner, PipelineConfig.ModuleWrapper):
owner_name = self.io_owner.name

return owner_name, self.name

def get_owner_idx(self):
"""Return owner idex if owner is ModuleWrapper, if not return 0."""
# Return owner idex if owner is ModuleWrapper, if not return 0.
if isinstance(self.io_owner, PipelineConfig.ModuleWrapper):
return self.io_owner.idx

Expand All @@ -161,23 +162,35 @@ def get_owner_idx(self):
return 0

def is_global_interface(self):
"""The global interface is the interface visible to the caller which use pipeline
"""The global interface is the interface visible to the caller which use a pipeline
executor, the global input interface is responsible for passing parameters to the
internal module interface, and the global output interface is responsible for
outputting the pipeline executor compute results to caller.
outputting the results computed by the pipeline executor to a caller.
"""
return not isinstance(self.io_owner, PipelineConfig.ModuleWrapper)

def __repr__(self):
"""Get all binding(input data) informations that looks like '|data_0: mod1:data_0'."""
# Get all binding information.
ret = " |{}: ".format(self.name)
for binding in self.bindings:
mname, dname = binding.get_name()
ret += "{0}:{1} ".format(mname, dname)
return ret

def check_dag_acyclic(self, start, inputs):
"""It is to check whether the DAG that contains the inputs interfaces is acyclic."""
"""This is to check whether the DAG containing these input interfaces is acyclic.
Parameters
----------
start: ModuleWrapper
The starting node of the cycle check algorithm.
inputs: Binding
These interfaces are used to connect to each other to build DAG.
Return
------
Return True if there is no cycle in DAG.
"""
for binding in inputs.values():
if start == binding.io_owner:
return False
Expand All @@ -188,12 +201,17 @@ def check_dag_acyclic(self, start, inputs):
return True

def connect(self, binding):
"""Check whether the binding settings is correct or not.
correct connection are following
1. global input to module input
2. module output to global output
3. module output to module input
"""Connect the current interface to the destination interface.
correct connections as following 1. global input to module input,
2. module output to global output, 3. module output to module input
Parameters
----------
binding: Binding
The destination of this connection.
"""

# Check whether the binding setting is correct or not.
if self.io_owner == binding.io_owner:
raise RuntimeError(f"Can not bind itself.")

Expand Down Expand Up @@ -222,7 +240,7 @@ def connect(self, binding):

self.bindings.append(binding)
if not self.is_global_interface():
# check if the source and target data_type same
# Check whether the data types of the source and destination are the same.
if (
isinstance(binding.io_owner, PipelineConfig.ModuleWrapper)
and self.data_type != binding.data_type
Expand Down Expand Up @@ -274,8 +292,8 @@ def __getitem__(self, key):
return self.bindings[key]

class ModuleWrapper:
"""This class is a wrapper that represent the module, contains module informations,
binding informations and building information.
"""This class is a wrapper that represents the module, contains module information,
binding information and building information.
"""

def __init__(self, mod=None):
Expand All @@ -288,7 +306,7 @@ def __init__(self, mod=None):
self.idx = None
self.mod = mod
self.input_params = InferType()(mod)["main"].params
self.output_values = InferType()(mod)["main"].checked_type.ret_type
self.output_type = InferType()(mod)["main"].checked_type.ret_type
self.input_bindings = PipelineConfig.BindingList(self, "input")
self.output_bindings = PipelineConfig.BindingList(self, "output")

Expand All @@ -308,34 +326,45 @@ def __getitem__(self, key):

raise RuntimeError(f"{key} not found!")

def get_data_type(self, key, interface_type):
"""Get module interface data type."""
def get_data_type(self, name, interface_type):
"""Get module interface data type.
Parameters
----------
name: str
The interface name.
interface_type:
The interface type.
Return
-------
Return data type.
"""
if interface_type == "input":
for param in self.input_params:
if param.name_hint == key:
if param.name_hint == name:
return param._checked_type_

if interface_type == "output":
if isinstance(self.output_values, tvm.ir.type.TupleType):
if int(key) < len(self.output_values.fields):
return self.output_values.fields[int(key)]
if isinstance(self.output_type, tvm.ir.type.TupleType):
if int(name) < len(self.output_type.fields):
return self.output_type.fields[int(key)]
elif int(key) == 0:
return self.output_values
return self.output_type

return None

def set_idx_name(self, idx):
"""Sepecify the index value and generate the module name."""
# Set the index value and generate the module name.
self.idx = idx
self.name = "mod{}".format(str(idx))

def is_root_mod(self):
"""Check whether it is root node, this function used by DAG topological sort."""
# Check whether it is root node and is used by DAG topological sort.
return all([not b.parents for b in self.input_bindings.bindings.values()])

def remove_self_from_bindings(self):
"""Remove itself from child reference to reduce child node in-degree.
This function used by DAG topological sort.
This function is used by DAG topological sort.
"""
for binding in self.output_bindings.bindings.values():
for child in binding.bindings:
Expand All @@ -348,16 +377,17 @@ def __init__(self):
self.output_bindings = self.BindingList(self, "output")

def __str__(self):
""" Get configuration as string"""
# topological sort to get correct module order in list.
# Get configuration as string.

# Use topological sort to get correct module order.
self.dag_topology_sort()
# get input
# Get input.
input_dump = "Inputs\n"
for input_name in self.input_bindings.bindings:
inf = self.input_bindings.bindings[input_name]
input_dump += str(inf) + "\n"

# get connections.
# Get connections.
output = {}
connections_dump = "\nconnections\n"
for mod in self.mod_wrapper:
Expand All @@ -373,7 +403,7 @@ def __str__(self):
else:
output[dep_dname] = f"{mname}.output({dname})"

# get output
# Get output.
output_dump = "\noutput\n"
for name in sorted(output.keys()):
output_dump += f" |output({name}) : {output[name]}\n"
Expand All @@ -395,7 +425,7 @@ def __getitem__(self, key):
raise RuntimeError(f"{key} not found.")

def get_config(self):
""" Get configuration information in dictionary form."""
"""Get configuration information in dictionary form."""

# Use topological sort to get correct order of modules.
self.dag_topology_sort()
Expand All @@ -416,7 +446,7 @@ def get_config(self):
dep_item["input_name"] = dname
dep_conf.append(dep_item)

# Ouput_idx start from 0.
# The ouput_idx start from 0.
output["output_idx"] = int(binding.name)
output["dependent"] = dep_conf
output_conf.append(output)
Expand All @@ -438,7 +468,7 @@ def get_config(self):
return mconfig

def dag_topology_sort(self):
""" Use topological sort to get order of pipeline modules."""
"""Use topological sort to get order of pipeline modules."""
mlist = []
mod_wrapper = self.mod_wrapper.copy()
while mod_wrapper:
Expand All @@ -457,21 +487,22 @@ def dag_topology_sort(self):
self.mod_wrapper[mod].set_idx_name(i + 1)

def get_mod_idx(self, mod):
"""Return module index."""
# Return module index.
idx = self.mod_wrapper[mod].idx
return idx

def pipe_input(self, name):
"""Return the corresponding input binding interface accordding to the name."""
# Return the corresponding input binding interface according to the name.
return self.input_bindings[name]

def pipe_output(self, idx):
"""Return the corresponding output binding interface accordding to the name."""
# Return the corresponding output binding interface according to the name.
return self.output_bindings[idx]


class PipelineExecutorFactoryModule(object):
"""Common interface for pipeline executor factory modules.
"""A factory class is responsible for receiving module and configuration
information and maintaining executor module
Parameters
----------
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/pipeline/pipeline_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ namespace tvm {
namespace runtime {
/*!
* \brief pipeline executor.
* This executor class use module list and dependency relations of modules as
* the parameters and executes these modules on heterogeneous targets in pipeline
* parallel to improve throughput.
* This executor class use a module list and dependency relations of modules as
* the parameters and executes these modules on heterogeneous targets in a pipeline
* parallel manner to improve throughput.
*
* This executor can be accessed by various language via
* TVM runtime PackedFunc API.
Expand Down
Loading

0 comments on commit 27d42ab

Please sign in to comment.