Skip to content

Commit

Permalink
llvm/scheduler: Drop tracking of the number of node executions
Browse files Browse the repository at this point in the history
Not used anywhere.

Signed-off-by: Jan Vesely <jan.vesely@rutgers.edu>
  • Loading branch information
jvesely committed Oct 31, 2024
1 parent 68b75f6 commit a320087
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 30 deletions.
2 changes: 1 addition & 1 deletion psyneulink/core/llvm/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset):
args.append(cond)

builder.call(node_f, args)
cond_gen.generate_update_after_run(builder, cond, node)
cond_gen.generate_update_after_node_execution(builder, cond, node)

builder.block.name = "post_invoke_" + node_f.name

Expand Down
51 changes: 22 additions & 29 deletions psyneulink/core/llvm/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,19 +469,12 @@ def get_private_condition_struct_type(self, composition):

assert len(time_stamp_struct) == len(self.TimeIndex)

status_struct = ir.LiteralStructType([
self.ctx.int32_ty, # number of executions in this run
time_stamp_struct # time stamp of last execution
])
structure = ir.LiteralStructType([
time_stamp_struct, # current time stamp
ir.ArrayType(status_struct, len(composition.nodes)) # for each node
])
return structure
nodes_time_stamps_array = ir.ArrayType(time_stamp_struct, len(composition.nodes))

return ir.LiteralStructType((time_stamp_struct, nodes_time_stamps_array))

def get_private_condition_initializer(self, composition):
return ((0, 0, 0),
tuple((0, (-1, -1, -1)) for _ in composition.nodes))
return ((0, 0, 0), tuple((-1, -1, -1) for _ in composition.nodes))

def get_condition_struct_type(self, node=None):
node = self.composition if node is None else node
Expand All @@ -507,14 +500,14 @@ def bump_ts(self, builder, cond_ptr, count=(0, 0, 1)):
"""
Increments the time structure of the composition.
Count should be a tuple where there is a number in only one spot, and zeroes elsewhere.
Indices greater than that of the one are zeroed.
Indices greater than the incremented one are zeroed.
"""

# Only one element should be non-zero
assert count.count(0) == len(count) - 1

# Get timestruct pointer
ts_ptr = builder.gep(cond_ptr, [self._zero, self._zero, self._zero])
ts_ptr = self.__get_global_ts_ptr(builder, cond_ptr)
ts = builder.load(ts_ptr)

assert len(ts.type) == len(count)
Expand Down Expand Up @@ -556,13 +549,20 @@ def ts_compare(self, builder, ts1, ts2, comp):

return result

def __get_node_status_ptr(self, builder, cond_ptr, node):
def __get_global_ts_ptr(self, builder, cond_ptr):
# derefence the structure, the first element (private structure),
# and the first element of the private strucutre is the global ts.
return builder.gep(cond_ptr, [self._zero, self._zero, self._zero])

def __get_node_ts_ptr(self, builder, cond_ptr, node):
node_idx = self.ctx.int32_ty(self.composition.nodes.index(node))

# derefence the structure, the first element (private structure), the
# second element is the node time stamp array, use index in the array
return builder.gep(cond_ptr, [self._zero, self._zero, self.ctx.int32_ty(1), node_idx])

def __get_node_ts(self, builder, cond_ptr, node):
status_ptr = self.__get_node_status_ptr(builder, cond_ptr, node)
ts_ptr = builder.gep(status_ptr, [self.ctx.int32_ty(0), self.ctx.int32_ty(1)])
ts_ptr = self.__get_node_ts_ptr(builder, cond_ptr, node)
return builder.load(ts_ptr)

def get_global_ts(self, builder, cond_ptr):
Expand All @@ -582,20 +582,13 @@ def get_global_pass(self, builder, cond_ptr):
def get_global_step(self, builder, cond_ptr):
return self._extract_global_time(builder, cond_ptr, self.TimeIndex.STEP)

def generate_update_after_run(self, builder, cond_ptr, node):
status_ptr = self.__get_node_status_ptr(builder, cond_ptr, node)
status = builder.load(status_ptr)

# Update total number of runs
runs = builder.extract_value(status, 0)
runs = builder.add(runs, runs.type(1))
status = builder.insert_value(status, runs, 0)

# Update time stamp
ts = self.get_global_ts(builder, cond_ptr)
status = builder.insert_value(status, ts, 1)
def generate_update_after_node_execution(self, builder, cond_ptr, node):
# Update time stamp of the last execution
global_ts_ptr = self.__get_global_ts_ptr(builder, cond_ptr)
global_ts = builder.load(global_ts_ptr)

builder.store(status, status_ptr)
node_ts_ptr = self.__get_node_ts_ptr(builder, cond_ptr, node)
builder.store(global_ts, node_ts_ptr)

def generate_ran_this_pass(self, builder, cond_ptr, node):
global_trial = self.get_global_trial(builder, cond_ptr)
Expand Down

0 comments on commit a320087

Please sign in to comment.