From a3200879ba2d9a9e6728b49a170bb39d7d11e88c Mon Sep 17 00:00:00 2001 From: Jan Vesely Date: Thu, 31 Oct 2024 17:12:23 -0400 Subject: [PATCH] llvm/scheduler: Drop tracking of the number of node executions Not used anywhere. Signed-off-by: Jan Vesely --- psyneulink/core/llvm/codegen.py | 2 +- psyneulink/core/llvm/helpers.py | 51 ++++++++++++++------------------- 2 files changed, 23 insertions(+), 30 deletions(-) diff --git a/psyneulink/core/llvm/codegen.py b/psyneulink/core/llvm/codegen.py index 0ef30af9ff..e340bfd4fb 100644 --- a/psyneulink/core/llvm/codegen.py +++ b/psyneulink/core/llvm/codegen.py @@ -974,7 +974,7 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset): args.append(cond) builder.call(node_f, args) - cond_gen.generate_update_after_run(builder, cond, node) + cond_gen.generate_update_after_node_execution(builder, cond, node) builder.block.name = "post_invoke_" + node_f.name diff --git a/psyneulink/core/llvm/helpers.py b/psyneulink/core/llvm/helpers.py index 7d4ef3836e..e5dc9f612d 100644 --- a/psyneulink/core/llvm/helpers.py +++ b/psyneulink/core/llvm/helpers.py @@ -469,19 +469,12 @@ def get_private_condition_struct_type(self, composition): assert len(time_stamp_struct) == len(self.TimeIndex) - status_struct = ir.LiteralStructType([ - self.ctx.int32_ty, # number of executions in this run - time_stamp_struct # time stamp of last execution - ]) - structure = ir.LiteralStructType([ - time_stamp_struct, # current time stamp - ir.ArrayType(status_struct, len(composition.nodes)) # for each node - ]) - return structure + nodes_time_stamps_array = ir.ArrayType(time_stamp_struct, len(composition.nodes)) + + return ir.LiteralStructType((time_stamp_struct, nodes_time_stamps_array)) def get_private_condition_initializer(self, composition): - return ((0, 0, 0), - tuple((0, (-1, -1, -1)) for _ in composition.nodes)) + return ((0, 0, 0), tuple((-1, -1, -1) for _ in composition.nodes)) def get_condition_struct_type(self, node=None): node = self.composition if node is None else node @@ -507,14 +500,14 @@ def bump_ts(self, builder, cond_ptr, count=(0, 0, 1)): """ Increments the time structure of the composition. Count should be a tuple where there is a number in only one spot, and zeroes elsewhere. - Indices greater than that of the one are zeroed. + Indices greater than the incremented one are zeroed. """ # Only one element should be non-zero assert count.count(0) == len(count) - 1 # Get timestruct pointer - ts_ptr = builder.gep(cond_ptr, [self._zero, self._zero, self._zero]) + ts_ptr = self.__get_global_ts_ptr(builder, cond_ptr) ts = builder.load(ts_ptr) assert len(ts.type) == len(count) @@ -556,13 +549,20 @@ def ts_compare(self, builder, ts1, ts2, comp): return result - def __get_node_status_ptr(self, builder, cond_ptr, node): + def __get_global_ts_ptr(self, builder, cond_ptr): + # derefence the structure, the first element (private structure), + # and the first element of the private strucutre is the global ts. + return builder.gep(cond_ptr, [self._zero, self._zero, self._zero]) + + def __get_node_ts_ptr(self, builder, cond_ptr, node): node_idx = self.ctx.int32_ty(self.composition.nodes.index(node)) + + # derefence the structure, the first element (private structure), the + # second element is the node time stamp array, use index in the array return builder.gep(cond_ptr, [self._zero, self._zero, self.ctx.int32_ty(1), node_idx]) def __get_node_ts(self, builder, cond_ptr, node): - status_ptr = self.__get_node_status_ptr(builder, cond_ptr, node) - ts_ptr = builder.gep(status_ptr, [self.ctx.int32_ty(0), self.ctx.int32_ty(1)]) + ts_ptr = self.__get_node_ts_ptr(builder, cond_ptr, node) return builder.load(ts_ptr) def get_global_ts(self, builder, cond_ptr): @@ -582,20 +582,13 @@ def get_global_pass(self, builder, cond_ptr): def get_global_step(self, builder, cond_ptr): return self._extract_global_time(builder, cond_ptr, self.TimeIndex.STEP) - def generate_update_after_run(self, builder, cond_ptr, node): - status_ptr = self.__get_node_status_ptr(builder, cond_ptr, node) - status = builder.load(status_ptr) - - # Update total number of runs - runs = builder.extract_value(status, 0) - runs = builder.add(runs, runs.type(1)) - status = builder.insert_value(status, runs, 0) - - # Update time stamp - ts = self.get_global_ts(builder, cond_ptr) - status = builder.insert_value(status, ts, 1) + def generate_update_after_node_execution(self, builder, cond_ptr, node): + # Update time stamp of the last execution + global_ts_ptr = self.__get_global_ts_ptr(builder, cond_ptr) + global_ts = builder.load(global_ts_ptr) - builder.store(status, status_ptr) + node_ts_ptr = self.__get_node_ts_ptr(builder, cond_ptr, node) + builder.store(global_ts, node_ts_ptr) def generate_ran_this_pass(self, builder, cond_ptr, node): global_trial = self.get_global_trial(builder, cond_ptr)