From 443c6e1f79ee43c09bbea3beb95b8050a5045004 Mon Sep 17 00:00:00 2001
From: Jan Vesely
Date: Wed, 20 Nov 2024 16:31:20 -0500
Subject: [PATCH] llvm/OneHot: Implement support for RANDOM tie resolution

Signed-off-by: Jan Vesely
---
 .../functions/nonstateful/selectionfunctions.py | 16 ++++++++++++++--
 psyneulink/core/llvm/builder_context.py         | 17 +++++++++++++++++
 tests/functions/test_selection.py               |  2 --
 3 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/psyneulink/core/components/functions/nonstateful/selectionfunctions.py b/psyneulink/core/components/functions/nonstateful/selectionfunctions.py
index 28b3830c03..dd652265a8 100644
--- a/psyneulink/core/components/functions/nonstateful/selectionfunctions.py
+++ b/psyneulink/core/components/functions/nonstateful/selectionfunctions.py
@@ -428,9 +428,9 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
         sum_ptr = builder.alloca(ctx.float_ty)
         builder.store(sum_ptr.type.pointee(-0.0), sum_ptr)
 
-        random_draw_ptr = builder.alloca(ctx.float_ty)
         rand_state_ptr = ctx.get_random_state_ptr(builder, self, state, params)
         rng_f = ctx.get_uniform_dist_function_by_state(rand_state_ptr)
+        random_draw_ptr = builder.alloca(rng_f.args[-1].type.pointee)
         builder.call(rng_f, [rand_state_ptr, random_draw_ptr])
         random_draw = builder.load(random_draw_ptr)
 
@@ -534,8 +534,20 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
             extreme_start = num_extremes_ptr.type.pointee(0)
             extreme_stop = builder.load(num_extremes_ptr)
 
+        elif tie == RANDOM:
+            rand_state_ptr = ctx.get_random_state_ptr(builder, self, state, params)
+            rand_f = ctx.get_rand_int_function_by_state(rand_state_ptr)
+            random_draw_ptr = builder.alloca(rand_f.args[-1].type.pointee)
+            num_extremes = builder.load(num_extremes_ptr)
+
+            builder.call(rand_f, [rand_state_ptr, ctx.int32_ty(0), num_extremes, random_draw_ptr])
+
+            extreme_start = builder.load(random_draw_ptr)
+            extreme_start = builder.trunc(extreme_start, ctx.int32_ty)
+            extreme_stop = builder.add(extreme_start, extreme_start.type(1))
+
         else:
-            assert False
+            assert False, "Unknown tie resolution: {}".format(tie)
 
         extreme_val = builder.load(extreme_val_ptr)
 
diff --git a/psyneulink/core/llvm/builder_context.py b/psyneulink/core/llvm/builder_context.py
index 0dcb6bae85..7fcd4224cd 100644
--- a/psyneulink/core/llvm/builder_context.py
+++ b/psyneulink/core/llvm/builder_context.py
@@ -210,29 +210,46 @@ def init_builtins(self):
         if "time_stat" in debug_env:
             print("Time to setup PNL builtins: {}".format(finish - start))
 
+    def get_rand_int_function_by_state(self, state):
+        if len(state.type.pointee) == 5:
+            return self.import_llvm_function("__pnl_builtin_mt_rand_int32_bounded")
+
+        elif len(state.type.pointee) == 7:
+            # we have different versions based on selected FP precision
+            return self.import_llvm_function("__pnl_builtin_philox_rand_int32_bounded")
+
+        else:
+            assert False, "Unknown PRNG type!"
+
     def get_uniform_dist_function_by_state(self, state):
         if len(state.type.pointee) == 5:
             return self.import_llvm_function("__pnl_builtin_mt_rand_double")
+
         elif len(state.type.pointee) == 7:
             # we have different versions based on selected FP precision
             return self.import_llvm_function("__pnl_builtin_philox_rand_{}".format(str(self.float_ty)))
+
         else:
             assert False, "Unknown PRNG type!"
 
     def get_binomial_dist_function_by_state(self, state):
         if len(state.type.pointee) == 5:
             return self.import_llvm_function("__pnl_builtin_mt_rand_binomial")
+
         elif len(state.type.pointee) == 7:
             return self.import_llvm_function("__pnl_builtin_philox_rand_binomial")
+
         else:
             assert False, "Unknown PRNG type!"
 
     def get_normal_dist_function_by_state(self, state):
         if len(state.type.pointee) == 5:
             return self.import_llvm_function("__pnl_builtin_mt_rand_normal")
+
         elif len(state.type.pointee) == 7:
             # Normal exists only for self.float_ty
             return self.import_llvm_function("__pnl_builtin_philox_rand_normal")
+
         else:
             assert False, "Unknown PRNG type!"
diff --git a/tests/functions/test_selection.py b/tests/functions/test_selection.py
index 5ee96801fc..dea0ab9e06 100644
--- a/tests/functions/test_selection.py
+++ b/tests/functions/test_selection.py
@@ -254,8 +254,6 @@ def test_basic(func, variable, params, expected, benchmark, func_mode):
                          ], ids=lambda x: x if isinstance(x, str) else str(getattr(x, 'shape', ''))
                          )
 @pytest.mark.parametrize("indicator", ["indicator", "value"])
 def test_one_hot_mode_deterministic(benchmark, variable, tie, indicator, direction, abs_val, expected, func_mode):
-    if func_mode != "Python" and tie == kw.RANDOM:
-        pytest.skip("not implemented")
     f = pnl.OneHot(default_variable=np.zeros_like(variable),
                    mode=kw.DETERMINISTIC,
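For reference, below is a minimal Python sketch (not part of the patch) of the
semantics the compiled tie == RANDOM path implements: collect the positions
tied for the extreme value, make one bounded random integer draw, and emit the
one-hot result at the chosen position. It assumes a 1-D input, MAX direction,
and the "value" (non-indicator) output variant; the helper name
one_hot_random_tie and the use of NumPy's Generator are illustrative only. The
compiled code instead draws from the MT19937 or Philox builtins selected by
get_rand_int_function_by_state (a 5-field state struct selects MT19937, a
7-field struct selects Philox), so the exact draws will not match NumPy's.

    import numpy as np

    def one_hot_random_tie(variable, rng=None):
        # Illustrative sketch only; not the PsyNeuLink API.
        rng = np.random.default_rng() if rng is None else rng
        v = np.asarray(variable, dtype=float)
        extreme = v.max()                     # extreme value; MAX direction assumed
        ties = np.flatnonzero(v == extreme)   # indices tied for the extreme
        # One bounded integer draw in [0, num_extremes), mirroring the
        # __pnl_builtin_*_rand_int32_bounded call emitted by the patch.
        winner = ties[rng.integers(0, ties.size)]
        out = np.zeros_like(v)
        out[winner] = extreme                 # the "indicator" variant would store 1.0
        return out

With input [1., 3., 3., 2.] this returns either [0., 3., 0., 0.] or
[0., 0., 3., 0.], each with probability 1/2; that single winning position
corresponds to the window [extreme_start, extreme_stop) of length 1 that the
generated LLVM code copies to the output.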