llvm/OneHot: Implement support for RANDOM tie resolution
Signed-off-by: Jan Vesely <jan.vesely@rutgers.edu>
jvesely committed Nov 20, 2024
1 parent 157d139 commit 443c6e1
Showing 3 changed files with 31 additions and 4 deletions.
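In user-facing terms, tie=RANDOM makes OneHot break a tie among equally extreme entries by selecting one of them at random; this commit adds that behavior to the compiled (LLVM) path, which the removed test skip below shows was previously Python-only. A minimal usage sketch (constructor arguments inferred from the updated test at the bottom of this diff, not verified against the OneHot documentation):

    import numpy as np
    import psyneulink as pnl
    from psyneulink.core.globals import keywords as kw

    # Hypothetical example: with two tied maxima (the 3.0s), each call
    # should return a one-hot vector for one of them, chosen at random.
    f = pnl.OneHot(default_variable=np.zeros(4),
                   mode=kw.DETERMINISTIC,
                   tie=kw.RANDOM)
    print(f(np.array([1.0, 3.0, 3.0, 2.0])))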
@@ -428,9 +428,9 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
         sum_ptr = builder.alloca(ctx.float_ty)
         builder.store(sum_ptr.type.pointee(-0.0), sum_ptr)
 
-        random_draw_ptr = builder.alloca(ctx.float_ty)
         rand_state_ptr = ctx.get_random_state_ptr(builder, self, state, params)
         rng_f = ctx.get_uniform_dist_function_by_state(rand_state_ptr)
+        random_draw_ptr = builder.alloca(rng_f.args[-1].type.pointee)
         builder.call(rng_f, [rand_state_ptr, random_draw_ptr])
         random_draw = builder.load(random_draw_ptr)

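The one-line swap above is a type-correctness fix: the uniform-draw builtin writes its result through its last (output) pointer argument, and for Philox that pointee type follows the configured FP precision, so the stack slot is now sized from the callee's signature instead of hard-coding ctx.float_ty. A small llvmlite illustration of reading an output type that way (stand-in signature, not the actual builtin):

    from llvmlite import ir

    # Stand-in for an RNG builtin: void rng(i8* state, double* out).
    rng_ty = ir.FunctionType(ir.VoidType(),
                             [ir.IntType(8).as_pointer(),     # opaque RNG state
                              ir.DoubleType().as_pointer()])  # output draw
    out_ty = rng_ty.args[-1].pointee  # exactly the type the callee will store
    print(out_ty)                     # double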
@@ -534,8 +534,20 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
             extreme_start = num_extremes_ptr.type.pointee(0)
             extreme_stop = builder.load(num_extremes_ptr)
 
+        elif tie == RANDOM:
+            rand_state_ptr = ctx.get_random_state_ptr(builder, self, state, params)
+            rand_f = ctx.get_rand_int_function_by_state(rand_state_ptr)
+            random_draw_ptr = builder.alloca(rand_f.args[-1].type.pointee)
+            num_extremes = builder.load(num_extremes_ptr)
+
+            builder.call(rand_f, [rand_state_ptr, ctx.int32_ty(0), num_extremes, random_draw_ptr])
+
+            extreme_start = builder.load(random_draw_ptr)
+            extreme_start = builder.trunc(extreme_start, ctx.int32_ty)
+            extreme_stop = builder.add(extreme_start, extreme_start.type(1))
+
         else:
-            assert False
+            assert False, "Unknown tie resolution: {}".format(tie)
 
 
         extreme_val = builder.load(extreme_val_ptr)
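Restated in plain NumPy, the new RANDOM branch draws an index k uniformly from [0, num_extremes) with the bounded-integer builtin and then keeps only the k-th extreme by treating [extreme_start, extreme_stop) as a one-element window. An illustrative sketch of those semantics (not the generated IR):

    import numpy as np

    def one_hot_random_tie(values, rng):
        extreme = values.max()                  # assuming direction == MAX
        extreme_idx = np.flatnonzero(values == extreme)
        k = rng.integers(0, len(extreme_idx))   # bounded draw, like rand_int32_bounded
        extreme_start, extreme_stop = k, k + 1  # window holding exactly one winner
        out = np.zeros_like(values)
        out[extreme_idx[extreme_start:extreme_stop]] = extreme
        return out

    rng = np.random.default_rng(0)
    print(one_hot_random_tie(np.array([1.0, 3.0, 3.0, 2.0]), rng))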
17 changes: 17 additions & 0 deletions psyneulink/core/llvm/builder_context.py
@@ -210,29 +210,46 @@ def init_builtins(self):
         if "time_stat" in debug_env:
             print("Time to setup PNL builtins: {}".format(finish - start))
 
+    def get_rand_int_function_by_state(self, state):
+        if len(state.type.pointee) == 5:
+            return self.import_llvm_function("__pnl_builtin_mt_rand_int32_bounded")
+
+        elif len(state.type.pointee) == 7:
+            # we have different versions based on selected FP precision
+            return self.import_llvm_function("__pnl_builtin_philox_rand_int32_bounded")
+
+        else:
+            assert False, "Unknown PRNG type!"
+
     def get_uniform_dist_function_by_state(self, state):
         if len(state.type.pointee) == 5:
             return self.import_llvm_function("__pnl_builtin_mt_rand_double")
+
         elif len(state.type.pointee) == 7:
             # we have different versions based on selected FP precision
             return self.import_llvm_function("__pnl_builtin_philox_rand_{}".format(str(self.float_ty)))
+
         else:
             assert False, "Unknown PRNG type!"
 
     def get_binomial_dist_function_by_state(self, state):
         if len(state.type.pointee) == 5:
             return self.import_llvm_function("__pnl_builtin_mt_rand_binomial")
+
         elif len(state.type.pointee) == 7:
             return self.import_llvm_function("__pnl_builtin_philox_rand_binomial")
+
         else:
             assert False, "Unknown PRNG type!"
 
     def get_normal_dist_function_by_state(self, state):
         if len(state.type.pointee) == 5:
             return self.import_llvm_function("__pnl_builtin_mt_rand_normal")
+
         elif len(state.type.pointee) == 7:
             # Normal exists only for self.float_ty
             return self.import_llvm_function("__pnl_builtin_philox_rand_normal")
+
         else:
             assert False, "Unknown PRNG type!"
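All four getters share one dispatch rule: the PRNG flavor is identified by the number of members in the compiled state struct, five for the Mersenne Twister state and seven for Philox. A sketch of that shared shape (helper name hypothetical, not part of the commit):

    def _rng_builtin_name(state, mt_name, philox_name):
        # Dispatch on the state struct's member count, as the getters above do.
        n_members = len(state.type.pointee)
        if n_members == 5:        # MT19937 state layout
            return mt_name
        elif n_members == 7:      # Philox state layout
            return philox_name
        assert False, "Unknown PRNG type!"

Under such a helper each getter would reduce to a single import_llvm_function call; the commit keeps them parallel and explicit instead.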
2 changes: 0 additions & 2 deletions tests/functions/test_selection.py
@@ -254,8 +254,6 @@ def test_basic(func, variable, params, expected, benchmark, func_mode):
 ], ids=lambda x: x if isinstance(x, str) else str(getattr(x, 'shape', '')) )
 @pytest.mark.parametrize("indicator", ["indicator", "value"])
 def test_one_hot_mode_deterministic(benchmark, variable, tie, indicator, direction, abs_val, expected, func_mode):
-    if func_mode != "Python" and tie == kw.RANDOM:
-        pytest.skip("not implemented")
 
     f = pnl.OneHot(default_variable=np.zeros_like(variable),
                    mode=kw.DETERMINISTIC,
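With the skip removed, the RANDOM tie cases are expected to run under the compiled func_mode values as well, e.g. (selection expression illustrative):

    pytest tests/functions/test_selection.py -k one_hot_mode_deterministic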
