From 27d42ab9446839eccf1473921408d0aa520709d4 Mon Sep 17 00:00:00 2001 From: huajsj Date: Mon, 13 Sep 2021 17:43:54 -0700 Subject: [PATCH] address review comments. --- python/tvm/contrib/pipeline_executor.py | 117 ++++++++++++------- src/runtime/pipeline/pipeline_executor.h | 6 +- tests/python/relay/test_pipeline_executor.py | 43 +++---- 3 files changed, 96 insertions(+), 70 deletions(-) diff --git a/python/tvm/contrib/pipeline_executor.py b/python/tvm/contrib/pipeline_executor.py index 9759d80c94f66..d393892f57520 100644 --- a/python/tvm/contrib/pipeline_executor.py +++ b/python/tvm/contrib/pipeline_executor.py @@ -34,7 +34,7 @@ def pipeline_executor_enabled(): def build(pipe_configs): - """Use pipe_config to build and return module list and module dependency configuration. + """Use pipe_config to build and return PipelineExecutorFactoryModule. Parameters ---------- @@ -44,7 +44,8 @@ def build(pipe_configs): Returns ------- ret: PipelineExecutorFactoryModule - A class that wraps module list and module dependency configuration. + A factory class is responsible for receiving module and configuration + information and maintaining executor module """ mods = {} mod_n_configs = pipe_configs.get_config() @@ -85,7 +86,8 @@ def create(pipe_executor_factory_module): Parameters ---------- pipe_executor_factory_module : PipelineExecutorFactoryModule - It is wrapper class which include IRModule list and pipeline configuration. + A factory class is responsible for receiving module and configuration + information and maintaining executor module. Returns ------- @@ -116,7 +118,7 @@ class PipelineConfig(object): """ class Binding: - """This class define the module connection information. + """This class defines the module connection information. The type can only be either "input" or "output". Parameters @@ -128,8 +130,7 @@ class Binding: The type of this interface. It can only be either "input" or "output". 
name : str/integer - Name, for input it is string such as "data0", for output it is the - idx integer such as 0. + Name, for input it is string such as "data0", for output it is an integer such as 0. """ def __init__(self, owner, io_type, name, data_type=None): @@ -144,7 +145,7 @@ def __init__(self, owner, io_type, name, data_type=None): self.data_type = data_type def get_name(self): - """Return the interface name and name of owner who owns this interface.""" + # Return the interface name and name of owner who owns this interface. owner_name = "" if isinstance(self.io_owner, PipelineConfig.ModuleWrapper): owner_name = self.io_owner.name @@ -152,7 +153,7 @@ def get_name(self): return owner_name, self.name def get_owner_idx(self): - """Return owner idex if owner is ModuleWrapper, if not return 0.""" + # Return owner idex if owner is ModuleWrapper, if not return 0. if isinstance(self.io_owner, PipelineConfig.ModuleWrapper): return self.io_owner.idx @@ -161,15 +162,15 @@ def get_owner_idx(self): return 0 def is_global_interface(self): - """The global interface is the interface visible to the caller which use pipeline + """The global interface is the interface visible to the caller which use a pipeline executor, the global input interface is responsible for passing parameters to the internal module interface, and the global output interface is responsible for - outputting the pipeline executor compute results to caller. + outputting the results computed by the pipeline executor to a caller. """ return not isinstance(self.io_owner, PipelineConfig.ModuleWrapper) def __repr__(self): - """Get all binding(input data) informations that looks like '|data_0: mod1:data_0'.""" + # Get all binding information. 
ret = " |{}: ".format(self.name) for binding in self.bindings: mname, dname = binding.get_name() @@ -177,7 +178,19 @@ def __repr__(self): return ret def check_dag_acyclic(self, start, inputs): - """It is to check whether the DAG that contains the inputs interfaces is acyclic.""" + """This is to check whether the DAG containing these input interfaces is acyclic. + Parameters + ---------- + start: ModuleWrapper + The starting node of the cycle check algorithm. + + inputs: Binding + These interfaces are used to connect to each other to build DAG. + + Return + ------ + Return True if there is no cycle in DAG. + """ for binding in inputs.values(): if start == binding.io_owner: return False @@ -188,12 +201,17 @@ def check_dag_acyclic(self, start, inputs): return True def connect(self, binding): - """Check whether the binding settings is correct or not. - correct connection are following - 1. global input to module input - 2. module output to global output - 3. module output to module input + """Connect the current interface to the destination interface. + correct connections as following 1. global input to module input, + 2. module output to global output, 3. module output to module input + + Parameters + ---------- + binding: Binding + The destination of this connection. """ + + # Check whether the binding setting is correct or not. if self.io_owner == binding.io_owner: raise RuntimeError(f"Can not bind itself.") @@ -222,7 +240,7 @@ def connect(self, binding): self.bindings.append(binding) if not self.is_global_interface(): - # check if the source and target data_type same + # Check whether the data types of the source and destination are the same. 
if ( isinstance(binding.io_owner, PipelineConfig.ModuleWrapper) and self.data_type != binding.data_type @@ -274,8 +292,8 @@ def __getitem__(self, key): return self.bindings[key] class ModuleWrapper: - """This class is a wrapper that represent the module, contains module informations, - binding informations and building information. + """This class is a wrapper that represents the module, contains module information, + binding information and building information. """ def __init__(self, mod=None): @@ -288,7 +306,7 @@ def __init__(self, mod=None): self.idx = None self.mod = mod self.input_params = InferType()(mod)["main"].params - self.output_values = InferType()(mod)["main"].checked_type.ret_type + self.output_type = InferType()(mod)["main"].checked_type.ret_type self.input_bindings = PipelineConfig.BindingList(self, "input") self.output_bindings = PipelineConfig.BindingList(self, "output") @@ -308,34 +326,45 @@ def __getitem__(self, key): raise RuntimeError(f"{key} not found!") - def get_data_type(self, key, interface_type): - """Get module interface data type.""" + def get_data_type(self, name, interface_type): + """Get module interface data type. + Parameters + ---------- + name: str + The interface name. + interface_type: + The interface type. + + Return + ------- + Return data type. 
+        """
         if interface_type == "input":
             for param in self.input_params:
-                if param.name_hint == key:
+                if param.name_hint == name:
                     return param._checked_type_
         if interface_type == "output":
-            if isinstance(self.output_values, tvm.ir.type.TupleType):
-                if int(key) < len(self.output_values.fields):
-                    return self.output_values.fields[int(key)]
-            elif int(key) == 0:
-                return self.output_values
+            if isinstance(self.output_type, tvm.ir.type.TupleType):
+                if int(name) < len(self.output_type.fields):
+                    return self.output_type.fields[int(name)]
+            elif int(name) == 0:
+                return self.output_type
 
         return None
 
     def set_idx_name(self, idx):
-        """Sepecify the index value and generate the module name."""
+        # Set the index value and generate the module name.
         self.idx = idx
         self.name = "mod{}".format(str(idx))
 
     def is_root_mod(self):
-        """Check whether it is root node, this function used by DAG topological sort."""
+        # Check whether it is root node and is used by DAG topological sort.
         return all([not b.parents for b in self.input_bindings.bindings.values()])
 
     def remove_self_from_bindings(self):
         """Remove itself from child reference to reduce child node in-degree.
-        This function used by DAG topological sort.
+        This function is used by DAG topological sort.
         """
         for binding in self.output_bindings.bindings.values():
             for child in binding.bindings:
@@ -348,16 +377,17 @@ def __init__(self):
         self.output_bindings = self.BindingList(self, "output")
 
     def __str__(self):
-        """ Get configuration as string"""
-        # topological sort to get correct module order in list.
+        # Get configuration as string.
+
+        # Use topological sort to get correct module order.
         self.dag_topology_sort()
 
-        # get input
+        # Get input.
         input_dump = "Inputs\n"
         for input_name in self.input_bindings.bindings:
            inf = self.input_bindings.bindings[input_name]
            input_dump += str(inf) + "\n"
 
-        # get connections.
+        # Get connections.
output = {} connections_dump = "\nconnections\n" for mod in self.mod_wrapper: @@ -373,7 +403,7 @@ def __str__(self): else: output[dep_dname] = f"{mname}.output({dname})" - # get output + # Get output. output_dump = "\noutput\n" for name in sorted(output.keys()): output_dump += f" |output({name}) : {output[name]}\n" @@ -395,7 +425,7 @@ def __getitem__(self, key): raise RuntimeError(f"{key} not found.") def get_config(self): - """ Get configuration information in dictionary form.""" + """Get configuration information in dictionary form.""" # Use topological sort to get correct order of modules. self.dag_topology_sort() @@ -416,7 +446,7 @@ def get_config(self): dep_item["input_name"] = dname dep_conf.append(dep_item) - # Ouput_idx start from 0. + # The ouput_idx start from 0. output["output_idx"] = int(binding.name) output["dependent"] = dep_conf output_conf.append(output) @@ -438,7 +468,7 @@ def get_config(self): return mconfig def dag_topology_sort(self): - """ Use topological sort to get order of pipeline modules.""" + """Use topological sort to get order of pipeline modules.""" mlist = [] mod_wrapper = self.mod_wrapper.copy() while mod_wrapper: @@ -457,21 +487,22 @@ def dag_topology_sort(self): self.mod_wrapper[mod].set_idx_name(i + 1) def get_mod_idx(self, mod): - """Return module index.""" + # Return module index. idx = self.mod_wrapper[mod].idx return idx def pipe_input(self, name): - """Return the corresponding input binding interface accordding to the name.""" + # Return the corresponding input binding interface according to the name. return self.input_bindings[name] def pipe_output(self, idx): - """Return the corresponding output binding interface accordding to the name.""" + # Return the corresponding output binding interface according to the name. return self.output_bindings[idx] class PipelineExecutorFactoryModule(object): - """Common interface for pipeline executor factory modules. 
+ """A factory class is responsible for receiving module and configuration + information and maintaining executor module Parameters ---------- diff --git a/src/runtime/pipeline/pipeline_executor.h b/src/runtime/pipeline/pipeline_executor.h index d0d3fb87299ac..f42f2e916c3c3 100644 --- a/src/runtime/pipeline/pipeline_executor.h +++ b/src/runtime/pipeline/pipeline_executor.h @@ -30,9 +30,9 @@ namespace tvm { namespace runtime { /*! * \brief pipeline executor. - * This executor class use module list and dependency relations of modules as - * the parameters and executes these modules on heterogeneous targets in pipeline - * parallel to improve throughput. + * This executor class use a module list and dependency relations of modules as + * the parameters and executes these modules on heterogeneous targets in a pipeline + * parallel manner to improve throughput. * * This executor can be accessed by various language via * TVM runtime PackedFunc API. diff --git a/tests/python/relay/test_pipeline_executor.py b/tests/python/relay/test_pipeline_executor.py index 4d2fefec009dd..18986fa1e6eae 100644 --- a/tests/python/relay/test_pipeline_executor.py +++ b/tests/python/relay/test_pipeline_executor.py @@ -25,7 +25,7 @@ def get_mannual_mod(): - """Get list of module that represent a subgraph.""" + # Get list of module that represent a subgraph. mods = [] dshape = (3, 3) data = relay.var("data_0", relay.TensorType(dshape, "float32")) @@ -40,7 +40,7 @@ def get_mannual_mod(): mv2 = relay.Constant(tvm.nd.array(mvalue2)) mv3 = relay.Constant(tvm.nd.array(mvalue3)) - """The first model has three output.""" + # The first model has three output. net1_output1 = relay.add(data, mv1) net1_output2 = relay.subtract(data, mv2) @@ -74,7 +74,7 @@ def get_mannual_mod(): def get_manual_conf(mods, target): - """This function is used to generate manual pipeline configuration.""" + # This function is used to generate manual pipeline configuration. 
mod_config = {} """The third output is the final output, the second output is for mod3, the first is for mod2 input. @@ -130,13 +130,12 @@ def get_manual_conf(mods, target): def test_pipe_config_check(): - """This function is used to trigger runtime error by appling wrong logic connection.""" + # This function is used to trigger runtime error by applying wrong logic connection. - """Get three pipeline modules here. - """ + # Get three pipeline modules here. (mod1, mod2, mod3), dshape = get_mannual_mod() - """The input/output name is illegal and expects a runtime error. - """ + + # The input/output name is illegal and expects a runtime error. pipe_error = pipeline_executor.PipelineConfig() with pytest.raises(RuntimeError): pipe_error[mod1]["output"][9] @@ -144,14 +143,12 @@ def test_pipe_config_check(): with pytest.raises(RuntimeError): pipe_error[mod1]["input"]["data_9"] - """The connection will cause a circle in DAG and exepects runtime error. - """ + # The connection will cause a cycle in DAG and expects runtime error. with pytest.raises(RuntimeError): pipe_error[mod1]["output"][0].connect(pipe_error[mod2]["input"]["data_0"]) pipe_error[mod2]["output"][0].connect(pipe_error[mod1]["input"]["data_0"]) - """The module connection is illegal and expects runtime error. - """ + # The module connection is illegal and expects runtime error. with pytest.raises(RuntimeError): pipe_error[mod1]["output"][0].connect(pipe_error[mod1]["input"]["data_0"]) @@ -186,8 +183,9 @@ def test_pipeline(): pipe_config = pipeline_executor.PipelineConfig() - # The global input named "data_0" will be connected to a input - # named "data_0" of mod1. + """ The global input named "data_0" will be connected to a input + named "data_0" of mod1. + """ pipe_config["input"]["data_0"].connect(pipe_config[mod1]["input"]["data_0"]) # The global Input named "data_1" will be connected to a input named "data_1" of mod2. 
@@ -196,16 +194,16 @@ def test_pipeline(): # The mod1 output[0] will be connected to a input named "data_0" of mod2. pipe_config[mod1]["output"][0].connect(pipe_config[mod2]["input"]["data_0"]) - # Mod1 output[1] will be connected to a input named "data_0" of mod3. + # The mod1 output[1] will be connected to a input named "data_0" of mod3. pipe_config[mod1]["output"][1].connect(pipe_config[mod3]["input"]["data_0"]) - # Mod2 output[2] will be connected to a input named "data_1" of mod3. + # The mod2 output[2] will be connected to a input named "data_1" of mod3. pipe_config[mod2]["output"][0].connect(pipe_config[mod3]["input"]["data_1"]) - # Mod1 output[2] will be connected to global output[1]. + # The mod1 output[2] will be connected to global output[1]. pipe_config[mod1]["output"][2].connect(pipe_config["output"]["0"]) - # Mod3 output[0] will be connected to global output[2]. + # The mod3 output[0] will be connected to global output[2]. pipe_config[mod3]["output"][0].connect(pipe_config["output"]["1"]) """Print configueration (print(pipe_config)), the result looks like following. @@ -223,8 +221,7 @@ def test_pipeline(): |mod2.output(0)-> mod3.data_1 """ - """Set other parameter. - """ + # Set other parameter. pipe_config[mod1].target = target[0] pipe_config[mod1].dev = target[1] @@ -234,12 +231,10 @@ def test_pipeline(): pipe_config[mod3].target = "llvm" pipe_config[mod3].dev = tvm.cpu(0) - """Here is to check correctness for configuration generated by API. - """ + # Here is to check correctness for configuration generated by API. assert pipe_config.get_config() == get_manual_conf([mod1, mod2, mod3], target) - """Build and create pipeline module. - """ + # Build and create pipeline module. with tvm.transform.PassContext(opt_level=3): pipeline_mod_config = pipeline_executor.build(pipe_config)