diff --git a/compiler/definitions/ir/dfg_node.py b/compiler/definitions/ir/dfg_node.py
index e3e56631c..587622680 100644
--- a/compiler/definitions/ir/dfg_node.py
+++ b/compiler/definitions/ir/dfg_node.py
@@ -312,6 +312,14 @@ def get_option_implemented_round_robin_parallelizer(self):
                 return parallelizer
         return None
 
+    def get_option_implemented_round_robin_with_unwrap_parallelizer(self):
+        for parallelizer in self.parallelizer_list:
+            splitter = parallelizer.get_splitter()
+            if splitter.is_splitter_round_robin_with_unwrap_flag() and parallelizer.are_all_parts_implemented():
+                return parallelizer
+        return None
+
+
     def get_option_implemented_consecutive_chunks_parallelizer(self):
         for parallelizer in self.parallelizer_list:
             splitter = parallelizer.get_splitter()
diff --git a/compiler/definitions/ir/nodes/r_split.py b/compiler/definitions/ir/nodes/r_split.py
index 68a889f2f..011df0559 100644
--- a/compiler/definitions/ir/nodes/r_split.py
+++ b/compiler/definitions/ir/nodes/r_split.py
@@ -1,7 +1,7 @@
 import os
 
 from datatypes_new.AccessKind import AccessKind
-from datatypes_new.BasicDatatypes import Operand
+from datatypes_new.BasicDatatypes import Operand, Flag
 from datatypes_new.CommandInvocationWithIOVars import CommandInvocationWithIOVars
 
 import config
@@ -24,21 +24,8 @@ def __init__(self,
                          parallelizer_list=parallelizer_list,
                          cmd_related_properties=cmd_related_properties)
 
-    ## TODO: Generalize this code (for this and SortGReduce) to be able to add an option to any command.
     def add_r_flag(self):
-        assert(False)
-        assert(len(self.com_options) <= 1)
-
-        ## Add -r in r_split
-        new_opt = (0, Arg(string_to_argument("-r")))
-        shifted_options = [(i+1, opt) for i, opt in self.com_options]
-        self.com_options = [new_opt] + shifted_options
-
-    ## This is not a proper option check. It just works if the r_flag is added as a separate option.
-    def has_r_flag(self):
-        assert(False)
-        option_strings = [str(opt) for i, opt in self.com_options]
-        return ("-r" in option_strings)
+        self.cmd_invocation_with_io_vars.flag_option_list.append(Flag("-r"))
 
 
 def make_r_split(input_id, out_ids, r_split_batch_size):
@@ -56,3 +43,8 @@ def make_r_split(input_id, out_ids, r_split_batch_size):
         implicit_use_of_streaming_output=None,
         access_map=access_map)
     return RSplit(cmd_inv_with_io_vars)
+
+def make_r_split_with_unwrap_flag(input_id, out_ids, r_split_batch_size):
+    standard_r_split = make_r_split(input_id, out_ids, r_split_batch_size)
+    standard_r_split.add_r_flag()
+    return standard_r_split
diff --git a/compiler/ir.py b/compiler/ir.py
index c2f36543e..e5cd5c423 100644
--- a/compiler/ir.py
+++ b/compiler/ir.py
@@ -771,6 +771,9 @@ def apply_parallelization_to_node(self, node_id, parallelizer, fileIdGen, fan_ou
             # TODO: for both functions, check which parameters are needed
             self.apply_round_robin_parallelization_to_node(node_id, parallelizer, fileIdGen, fan_out,
                                                            batch_size, no_cat_split_vanish, r_split_batch_size)
+        elif splitter.is_splitter_round_robin_with_unwrap_flag():
+            self.apply_round_robin_with_unwrap_flag_parallelization_to_node(node_id, parallelizer, fileIdGen, fan_out,
+                                                                            batch_size, no_cat_split_vanish, r_split_batch_size)
         elif splitter.is_splitter_consec_chunks():
             self.apply_consecutive_chunks_parallelization_to_node(node_id, parallelizer, fileIdGen, fan_out,
                                                                   batch_size, no_cat_split_vanish, r_split_batch_size)
@@ -819,6 +822,45 @@ def apply_round_robin_parallelization_to_node(self, node_id, parallelizer, fileI
         # aggregator
         self.introduce_aggregator_for_round_robin(out_mapper_ids, parallelizer, streaming_output)
 
+    def apply_round_robin_with_unwrap_flag_parallelization_to_node(self, node_id, parallelizer, fileIdGen, fan_out,
+                                                                   batch_size, no_cat_split_vanish, r_split_batch_size):
+        # round robin with unwrap flag is an inferred parallelizer which ensures that
+        # the command is commutative and has an aggregator for consecutive chunks;
+        # thus we can check whether we can re-open a previous "RR"-parallelization ending with `r_merge`
+        node = self.get_node(node_id)
+        streaming_input, streaming_output, configuration_inputs = \
+            node.get_single_streaming_input_single_output_and_configuration_inputs_of_node_for_parallelization()
+        original_cmd_invocation_with_io_vars = node.cmd_invocation_with_io_vars
+
+        prev_nodes = self.get_previous_nodes(node_id)
+        first_pred_node, first_pred_cmd_inv = self.get_first_previous_node_and_first_previous_cmd_invocation(prev_nodes)
+
+        # remove node to be parallelized
+        self.remove_node(node_id) # remove it here already since we need to remove its edge end points; otherwise the graph would be disconnected
+
+        if len(prev_nodes) == 1 and isinstance(first_pred_node, r_merge.RMerge):
+            # node.is_commutative() is implied by how this kind of splitter is inferred
+            self.remove_node(prev_nodes[0]) # also sets respective edge to's and from's to None
+
+            in_unwrap_ids = first_pred_cmd_inv.operand_list
+            out_unwrap_ids = self.introduce_unwraps(fileIdGen, in_unwrap_ids)
+            in_mapper_ids = out_unwrap_ids
+        else:
+            # splitter
+            round_robin_with_unwrap_flag_splitter_generator = lambda input_id, output_ids: r_split.make_r_split_with_unwrap_flag(input_id, output_ids, r_split_batch_size)
+            out_split_ids = self.introduce_splitter(round_robin_with_unwrap_flag_splitter_generator, fan_out, fileIdGen, streaming_input)
+            in_mapper_ids = out_split_ids
+
+        # mappers
+        out_mapper_ids = self.introduce_mappers(fan_out, fileIdGen, in_mapper_ids, original_cmd_invocation_with_io_vars,
+                                                parallelizer)
+
+        in_aggregator_ids = out_mapper_ids
+        out_aggregator_id = streaming_output
+        self.introduce_aggregators_for_consec_chunks(fileIdGen, in_aggregator_ids,
+                                                     original_cmd_invocation_with_io_vars, out_aggregator_id, parallelizer,
+                                                     streaming_output)
+
     def apply_consecutive_chunks_parallelization_to_node(self, node_id, parallelizer, fileIdGen, fan_out,
                                                          batch_size, no_cat_split_vanish, r_split_batch_size):
         # check whether we can fuse with previous node's parallelization:
@@ -841,12 +883,6 @@ def apply_consecutive_chunks_parallelization_to_node(self, node_id, parallelizer
             # can be fused
             self.remove_node(prev_nodes[0]) # also sets respective edge to's and from's to None
             in_mapper_ids = first_pred_cmd_inv.operand_list
-        elif len(prev_nodes) == 1 and isinstance(first_pred_node, r_merge.RMerge) and node.is_commutative():
-            self.remove_node(prev_nodes[0]) # also sets respective edge to's and from's to None
-
-            in_unwrap_ids = first_pred_cmd_inv.operand_list
-            out_unwrap_ids = self.introduce_unwraps(fileIdGen, in_unwrap_ids)
-            in_mapper_ids = out_unwrap_ids
         else: # cannot be fused so introduce splitter
             # splitter
             consec_chunks_splitter_generator = lambda input_id, output_ids: pash_split.make_split_file(input_id,
diff --git a/compiler/pash_runtime.py b/compiler/pash_runtime.py
index d5ec78bca..27217a971 100644
--- a/compiler/pash_runtime.py
+++ b/compiler/pash_runtime.py
@@ -287,9 +287,11 @@ def choose_parallelizing_transformation(curr_id, graph, r_split_flag): # shall r
     curr = graph.get_node(curr_id)
     # we ignore `r_split_flag` here as we want to exploit r_merge followed by commutative command
     # which only works if the a parallelizer for the latter is chosen (sort does not have RR-parallelizer)
-    # we prioritize round robin over consecutive chunks:
-    return return_default_if_none_else_itself(curr.get_option_implemented_round_robin_parallelizer(),
-                                              curr.get_option_implemented_consecutive_chunks_parallelizer())
+    # we prioritize round robin over round robin with unwrap over consecutive chunks:
+    list_all_parallelizers_in_priority = [curr.get_option_implemented_round_robin_parallelizer(),
+                                          curr.get_option_implemented_round_robin_with_unwrap_parallelizer(),
+                                          curr.get_option_implemented_consecutive_chunks_parallelizer()]
+    return next((item for item in list_all_parallelizers_in_priority if item is not None), None)
     # When `r_split_flag` should be used:
     # if r_split_flag:
     #     option_parallelizer = curr.get_option_implemented_round_robin_parallelizer()
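For reviewers unfamiliar with the `next(...)` idiom used in the new `choose_parallelizing_transformation` return, here is a minimal, self-contained sketch of the priority scheme (round robin, then round robin with unwrap, then consecutive chunks). All names below are illustrative stand-ins, not PaSh APIs.

```python
# Sketch only: first-non-None selection over a priority-ordered list of options,
# mirroring the new return statement in choose_parallelizing_transformation.
def choose_first_available(*options_in_priority_order):
    """Return the first option that is not None, or None if every option is missing."""
    return next((opt for opt in options_in_priority_order if opt is not None), None)

# A command like sort offers no plain RR parallelizer, so the RR-with-unwrap option
# (placeholder strings here) wins over consecutive chunks:
assert choose_first_available(None, "rr_with_unwrap", "consec_chunks") == "rr_with_unwrap"
# If no parallelizer is implemented at all, the node simply stays sequential:
assert choose_first_available(None, None, None) is None
```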
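The branching in `apply_round_robin_with_unwrap_flag_parallelization_to_node` can be hard to follow in diff form. The sketch below (all classes and helpers are hypothetical stand-ins, not the compiler's IR types) shows the two cases: re-opening a previous round-robin parallelization whose `r_merge` directly precedes the node, versus introducing a fresh `r_split` with the `-r` (unwrap) flag.

```python
# Hypothetical stand-in for an r_merge IR node; only the field used here is modeled.
class RMergeStub:
    def __init__(self, input_edge_ids):
        self.operand_list = input_edge_ids

def plan_mapper_inputs(prev_nodes, streaming_input, fan_out):
    """Decide which edges feed the parallel mappers (sketch of the diff's branch)."""
    if len(prev_nodes) == 1 and isinstance(prev_nodes[0], RMergeStub):
        # Re-open the earlier RR parallelization: drop the merge and unwrap its inputs.
        return [f"unwrap({edge})" for edge in prev_nodes[0].operand_list]
    # Otherwise split the single streaming input round-robin, with the -r (unwrap) flag.
    return [f"r_split -r chunk {i} of {streaming_input}" for i in range(fan_out)]

print(plan_mapper_inputs([RMergeStub(["e1", "e2"])], "in", 2))  # re-opened case
print(plan_mapper_inputs(["ordinary predecessor"], "in", 2))    # fresh split case
```

In both cases the mappers are followed by the consecutive-chunks aggregator, which is safe because this splitter is only inferred for commutative commands.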