diff --git a/conftest.py b/conftest.py index ee72480ac44..a23219aac1d 100644 --- a/conftest.py +++ b/conftest.py @@ -205,24 +205,10 @@ def cuda_param(val): @pytest.helpers.register def get_func_execution(func, func_mode): if func_mode == 'LLVM': - ex = pnlvm.execution.FuncExecution(func) - - # Calling writeback here will replace parameter values - # with numpy instances that share memory with the binary - # structure used by the compiled function - ex.writeback_state_to_pnl() - - return ex.execute + return pnlvm.execution.FuncExecution(func).execute elif func_mode == 'PTX': - ex = pnlvm.execution.FuncExecution(func) - - # Calling writeback here will replace parameter values - # with numpy instances that share memory with the binary - # structure used by the compiled function - ex.writeback_state_to_pnl() - - return ex.cuda_execute + return pnlvm.execution.FuncExecution(func).cuda_execute elif func_mode == 'Python': return func.function @@ -232,29 +218,16 @@ def get_func_execution(func, func_mode): @pytest.helpers.register def get_mech_execution(mech, mech_mode): if mech_mode == 'LLVM': - ex = pnlvm.execution.MechExecution(mech) - - # Calling writeback here will replace parameter values - # with numpy instances that share memory with the binary - # structure used by the compiled function - ex.writeback_state_to_pnl() - - return ex.execute + return pnlvm.execution.MechExecution(mech).execute elif mech_mode == 'PTX': - ex = pnlvm.execution.MechExecution(mech) - - # Calling writeback here will replace parameter values - # with numpy instances that share memory with the binary - # structure used by the compiled function - ex.writeback_state_to_pnl() - - return ex.cuda_execute + return pnlvm.execution.MechExecution(mech).cuda_execute elif mech_mode == 'Python': def mech_wrapper(x): mech.execute(x) return mech.output_values + return mech_wrapper else: assert False, "Unknown mechanism mode: {}".format(mech_mode) diff --git a/psyneulink/core/llvm/execution.py b/psyneulink/core/llvm/execution.py index 3f7318f0d3e..60e501b19da 100644 --- a/psyneulink/core/llvm/execution.py +++ b/psyneulink/core/llvm/execution.py @@ -113,29 +113,21 @@ def _get_compilation_param(self, name, init_method, arg): _pretty_size(ctypes.sizeof(struct_ty)), ")", "for", self._obj.name) - if len(self._execution_contexts) == 1: if name == '_state': - self.writeback_state_to_pnl() + self._copy_params_to_pnl(self._execution_contexts[0], + self._obj, + self._state_struct, + "llvm_state_ids") + elif name == '_param': - self.writeback_params_to_pnl() + self._copy_params_to_pnl(self._execution_contexts[0], + self._obj, + self._param_struct, + "llvm_param_ids") return struct - def writeback_state_to_pnl(self): - - self._copy_params_to_pnl(self._execution_contexts[0], - self._obj, - self._state_struct, - "llvm_state_ids") - - def writeback_params_to_pnl(self): - - self._copy_params_to_pnl(self._execution_contexts[0], - self._obj, - self._param_struct, - "llvm_param_ids") - def _copy_params_to_pnl(self, context, component, params, ids:str): for idx, attribute in enumerate(getattr(component, ids)): @@ -222,12 +214,7 @@ def _enumerate_recurse(elements): except ValueError: pass - pnl_param.set( - value, - context=context, - override=True, - compilation_sync=True, - ) + pnl_param.set(value, context=context, override=True, compilation_sync=True) class CUDAExecution(Execution):