From 54289e992f79779e3f4c2aae9650732eaca14623 Mon Sep 17 00:00:00 2001 From: Jack Lashner Date: Tue, 6 Feb 2024 15:42:57 -0500 Subject: [PATCH] HWP Emulation and PID agent lockless restructure (#606) * HWP Emulation and PID restructure * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * More complete PMX emulation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Changes based on Brian's review * fix session.data + threading in PID agent * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixes from further testing of PID agent * remode PID start mode from docs * Fix tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix mutable defaults * Adds traceback to device-emulator exception * adds time to device emulator logs * add debug print to when readline failes * change order of tests * test * change to update_responses * test update_responses * update_responses docstring * more testing race-conditions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Adds session data info to docstring --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/agents/hwp_pid.rst | 3 +- socs/agents/hwp_pid/agent.py | 496 +++++++++--------- socs/common/pmx.py | 10 +- socs/testing/device_emulator.py | 58 +- socs/testing/hwp_emulator.py | 223 ++++++++ .../test_hwp_pid_agent_integration.py | 129 ++--- .../test_hwp_pmx_agent_integration.py | 68 +-- 7 files changed, 572 insertions(+), 415 deletions(-) create mode 100644 socs/testing/hwp_emulator.py diff --git a/docs/agents/hwp_pid.rst b/docs/agents/hwp_pid.rst index d39219603..2743c01f2 100644 --- a/docs/agents/hwp_pid.rst +++ b/docs/agents/hwp_pid.rst @@ -25,8 +25,7 @@ An example site-config-file block:: {'agent-class': 'HWPPIDAgent', 'instance-id': 'hwp-pid', 'arguments': [['--ip', '10.10.10.101'], - ['--port', '2000'], - ['--mode', 'acq']]}, + ['--port', '2000']]}, Docker Compose `````````````` diff --git a/socs/agents/hwp_pid/agent.py b/socs/agents/hwp_pid/agent.py index 369887c9b..5381406d3 100644 --- a/socs/agents/hwp_pid/agent.py +++ b/socs/agents/hwp_pid/agent.py @@ -1,13 +1,96 @@ import argparse +import queue import time +import txaio from ocs import ocs_agent, site_config from ocs.ocs_twisted import TimeoutLock -from twisted.internet import reactor +from twisted.internet import defer, reactor, threads + +txaio.use_twisted() + + +from dataclasses import dataclass import socs.agents.hwp_pid.drivers.pid_controller as pd +def parse_action_result(res): + """ + Parses the result of an action to ensure it is a dictionary so it can be + stored in session.data + """ + if res is None: + return {} + elif isinstance(res, dict): + return res + else: + return {'result': res} + + +def get_pid_state(pid: pd.PID): + return { + "current_freq": pid.get_freq(), + "target_freq": pid.get_target(), + "direction": pid.get_direction(), + } + + +class Actions: + @dataclass + class BaseAction: + def __post_init__(self): + self.deferred = defer.Deferred() + self.log = txaio.make_logger() + + @dataclass + class TuneStop(BaseAction): + def process(self, pid: pd.PID): + pid.tune_stop() + + @dataclass + class TuneFreq(BaseAction): + def process(self, pid: pd.PID): + pid.tune_freq() + + @dataclass + class DeclareFreq(BaseAction): + freq: float + + def process(self, pid: pd.PID): + pid.declare_freq(self.freq) + return {"declared_freq": self.freq} + + @dataclass + class SetPID(BaseAction): + p: float + i: int + d: float + + def process(self, pid: pd.PID): + pid.set_pid([self.p, self.i, self.d]) + + @dataclass + class SetDirection(BaseAction): + direction: str + + def process(self, pid: pd.PID): + pid.set_direction(self.direction) + + @dataclass + class SetScale(BaseAction): + slope: float + offset: float + + def process(self, pid: pd.PID): + pid.set_scale(self.slope, self.offset) + + @dataclass + class GetState(BaseAction): + def process(self, pid: pd.PID): + return get_pid_state(pid) + + class HWPPIDAgent: """Agent to PID control the rotation speed of the CHWP @@ -27,53 +110,81 @@ def __init__(self, agent, ip, port, verbosity): self.ip = ip self.port = port self._verbosity = verbosity > 0 + self.action_queue = queue.Queue() - agg_params = {'frame_length': 60} - self.agent.register_feed( - 'hwppid', record=True, agg_params=agg_params) + agg_params = {"frame_length": 60} + self.agent.register_feed("hwppid", record=True, agg_params=agg_params) - @ocs_agent.param('auto_acquire', default=False, type=bool) - @ocs_agent.param('force', default=False, type=bool) - def init_connection(self, session, params): - """init_connection(auto_acquire=False, force=False) + def _get_data_and_publish(self, pid: pd.PID, session: ocs_agent.OpSession): + data = {"timestamp": time.time(), "block_name": "HWPPID", "data": {}} - **Task** - Initialize connection to PID - Controller. + pid_state = get_pid_state(pid) + data['data'].update(pid_state) + session.data.update(pid_state) + session.data['last_updated'] = time.time() + self.agent.publish_to_feed("hwppid", data) - Parameters: - auto_acquire (bool, optional): Default is False. Starts data - acquisition after initialization if True. - force (bool, optional): Force initialization, even if already - initialized. Defaults to False. + def _process_actions(self, pid): + while not self.action_queue.empty(): + action = self.action_queue.get() + try: + self.log.info(f"Running action {action}") + res = action.process(pid) + threads.blockingCallFromThread( + reactor, action.deferred.callback, res + ) + except Exception as e: + self.log.error(f"Error processing action: {action}") + threads.blockingCallFromThread( + reactor, action.deferred.errback, e + ) + + def _clear_queue(self): + while not self.action_queue.empty(): + action = self.action_queue.get() + action.deferred.errback(Exception("Action cancelled")) + + def main(self, session, params): + """main() + + **Process** - Main Process for PID agent. Periodically queries PID + controller for data, and executes requested actions. + + Notes: + The most recent data collected is stored in the session data in the + structure:: + >>> response.session['data'] + {'current_freq': 0, + 'target_freq': 0, + 'direction': 1, + 'last_updated': 1649085992.719602} """ - if self._initialized and not params['force']: - self.log.info("Connection already initialized. Returning...") - return True, "Connection already initialized" + pid = pd.PID(ip=self.ip, port=self.port, verb=self._verbosity) + self.log.info("Connected to PID controller") - with self.lock.acquire_timeout(10, job='init_connection') as acquired: - if not acquired: - self.log.warn( - 'Could not run init_connection because {} is already running'.format(self.lock.job)) - return False, 'Could not acquire lock' + self._clear_queue() - try: - self.pid = pd.PID(ip=self.ip, port=self.port, - verb=self._verbosity) - self.log.info('Connected to PID controller') - except BrokenPipeError: - self.log.error('Could not establish connection to PID controller') - reactor.callFromThread(reactor.stop) - return False, 'Unable to connect to PID controller' + sample_period = 5.0 + last_sample = 0.0 + session.set_status("running") + while session.status in ["starting", "running"]: + now = time.time() + if now - last_sample > sample_period: + self._get_data_and_publish(pid, session) + last_sample = now - self._initialized = True + self._process_actions(pid) + time.sleep(0.2) - # Start 'acq' Process if requested - if params['auto_acquire']: - self.agent.start('acq') + return True, "Exited main process" - return True, 'Connection to PID controller established' + def _main_stop(self, session, params): + """Stop main process""" + session.set_status("stopping") + return True, "Set main status to stopping" + @defer.inlineCallbacks def tune_stop(self, session, params): """tune_stop() @@ -81,16 +192,13 @@ def tune_stop(self, session, params): optimize the PID parameters for deceleration. """ - with self.lock.acquire_timeout(3, job='tune_stop') as acquired: - if not acquired: - self.log.warn( - 'Could not tune stop because {} is already running'.format(self.lock.job)) - return False, 'Could not acquire lock' - - self.pid.tune_stop() - - return True, 'Reversing Direction' + action = Actions.TuneStop(**params) + self.action_queue.put(action) + res = yield action.deferred + session.data = parse_action_result(res) + return True, f"Completed: {str(action)}" + @defer.inlineCallbacks def tune_freq(self, session, params): """tune_freq() @@ -98,17 +206,14 @@ def tune_freq(self, session, params): and optimize the PID parameters for rotation. """ - with self.lock.acquire_timeout(3, job='tune_freq') as acquired: - if not acquired: - self.log.warn( - 'Could not tune freq because {} is already running'.format(self.lock.job)) - return False, 'Could not acquire lock' - - self.pid.tune_freq() - - return True, 'Tuning to setpoint' - - @ocs_agent.param('freq', default=0., check=lambda x: 0. <= x <= 3.0) + action = Actions.TuneFreq(**params) + self.action_queue.put(action) + res = yield action.deferred + session.data = parse_action_result(res) + return True, f"Completed: {str(action)}" + + @defer.inlineCallbacks + @ocs_agent.param("freq", default=0.0, check=lambda x: 0.0 <= x <= 3.0) def declare_freq(self, session, params): """declare_freq(freq=0) @@ -118,20 +223,22 @@ def declare_freq(self, session, params): Parameters: freq (float): Desired HWP rotation frequency - """ - with self.lock.acquire_timeout(3, job='declare_freq') as acquired: - if not acquired: - self.log.warn( - 'Could not declare freq because {} is already running'.format(self.lock.job)) - return False, 'Could not acquire lock' - - self.pid.declare_freq(params['freq']) - - return True, 'Setpoint at {} Hz'.format(params['freq']) + Notes: + Session data is structured as follows:: - @ocs_agent.param('p', default=0.2, type=float, check=lambda x: 0. < x <= 8.) - @ocs_agent.param('i', default=63, type=int, check=lambda x: 0 <= x <= 200) - @ocs_agent.param('d', default=0., type=float, check=lambda x: 0. <= x < 10.) + >>> response.session['data'] + {'declared_freq': 2.0} + """ + action = Actions.DeclareFreq(**params) + self.action_queue.put(action) + res = yield action.deferred + session.data = parse_action_result(res) + return True, f"Completed: {str(action)}" + + @defer.inlineCallbacks + @ocs_agent.param("p", default=0.2, type=float, check=lambda x: 0.0 < x <= 8.0) + @ocs_agent.param("i", default=63, type=int, check=lambda x: 0 <= x <= 200) + @ocs_agent.param("d", default=0.0, type=float, check=lambda x: 0.0 <= x < 10.0) def set_pid(self, session, params): """set_pid(p=0.2, i=63, d=0.) @@ -143,76 +250,15 @@ def set_pid(self, session, params): p (float): Proportional PID value i (int): Integral PID value d (float): Derivative PID value - - """ - with self.lock.acquire_timeout(3, job='set_pid') as acquired: - if not acquired: - self.log.warn( - 'Could not set pid because {} is already running'.format(self.lock.job)) - return False, 'Could not acquire lock' - - self.pid.set_pid( - [params['p'], params['i'], params['d']]) - - return True, f"Set PID params to p: {params['p']}, i: {params['i']}, d: {params['d']}" - - def get_freq(self, session, params): - """get_freq() - - **Task** - Return the current HWP frequency as seen by the PID - controller. - - """ - with self.lock.acquire_timeout(3, job='get_freq') as acquired: - if not acquired: - self.log.warn( - 'Could not get freq because {} is already running'.format(self.lock.job)) - return False, 'Could not acquire lock' - - freq = self.pid.get_freq() - session.data = { - 'freq': freq, - 'timestamp': time.time(), - } - - return True, 'Current frequency = {}'.format(freq) - - def get_target(self, session, params): - """get_target() - - **Task** - Return the target HWP frequency of the PID - controller. - - """ - with self.lock.acquire_timeout(3, job='get_target') as acquired: - if not acquired: - self.log.warn( - 'Could not get freq because {} is already running'.format(self.lock.job)) - return False, 'Could not acquire lock' - - freq = self.pid.get_target() - - return True, 'Target frequency = {}'.format(freq) - - def get_direction(self, session, params): - """get_direction() - - **Task** - Return the current HWP tune direction as seen by the PID - controller. - """ - with self.lock.acquire_timeout(3, job='get_direction') as acquired: - if not acquired: - self.log.warn( - 'Could not get freq because {} is already running'.format(self.lock.job)) - return False, 'Could not acquire lock' - - direction = self.pid.get_direction() - session.data = {'direction': direction} - - return True, 'Current direction = {}'.format(['Forward', 'Reverse'][direction]) - - @ocs_agent.param('direction', type=str, default='0', choices=['0', '1']) + action = Actions.SetPID(**params) + self.action_queue.put(action) + res = yield action.deferred + session.data = parse_action_result(res) + return True, f"Completed: {str(action)}" + + @defer.inlineCallbacks + @ocs_agent.param("direction", type=str, default="0", choices=["0", "1"]) def set_direction(self, session, params): """set_direction(direction='0') @@ -222,18 +268,17 @@ def set_direction(self, session, params): direction (str): '0' for forward and '1' for reverse. """ - with self.lock.acquire_timeout(3, job='set_direction') as acquired: - if not acquired: - self.log.warn( - 'Could not set direction because {} is already running'.format(self.lock.job)) - return False, 'Could not acquire lock' - - self.pid.set_direction(params['direction']) - - return True, 'Set direction' - - @ocs_agent.param('slope', default=1., type=float, check=lambda x: -10. < x < 10.) - @ocs_agent.param('offset', default=0.1, type=float, check=lambda x: -10. < x < 10.) + action = Actions.SetDirection(**params) + self.action_queue.put(action) + res = yield action.deferred + session.data = parse_action_result(res) + return True, f"Completed: {str(action)}" + + @defer.inlineCallbacks + @ocs_agent.param("slope", default=1.0, type=float, check=lambda x: -10.0 < x < 10.0) + @ocs_agent.param( + "offset", default=0.1, type=float, check=lambda x: -10.0 < x < 10.0 + ) def set_scale(self, session, params): """set_scale(slope=1, offset=0.1) @@ -247,87 +292,31 @@ def set_scale(self, session, params): voltage" relationship """ - with self.lock.acquire_timeout(3, job='set_scale') as acquired: - if not acquired: - self.log.warn( - 'Could not set scale because {} is already running'.format(self.lock.job)) - return False, 'Could not acquire lock' - - self.pid.set_scale(params['slope'], params['offset']) + action = Actions.SetScale(**params) + self.action_queue.put(action) + res = yield action.deferred + session.data = parse_action_result(res) + return True, f"Completed: {str(action)}" - return True, 'Set scale' + @defer.inlineCallbacks + def get_state(self, session, params): + """get_state() - def acq(self, session, params): - """acq() - - **Process** - Start PID data acquisition. + **Task** - Polls hardware for the current the PID state. Notes: - The most recent data collected is stored in the session data in the - structure:: + Session data for this operation is as follows:: >>> response.session['data'] {'current_freq': 0, 'target_freq': 0, - 'direction': 1, - 'last_updated': 1649085992.719602} - - """ - with self.lock.acquire_timeout(timeout=10, job='acq') as acquired: - if not acquired: - self.log.warn('Could not start iv acq because {} is already running' - .format(self.lock.job)) - return False, 'Could not acquire lock' - - session.set_status('running') - last_release = time.time() - self.take_data = True - - while self.take_data: - # Relinquish sampling lock occasionally. - if time.time() - last_release > 1.: - last_release = time.time() - if not self.lock.release_and_acquire(timeout=10): - self.log.warn(f"Failed to re-acquire sampling lock, " - f"currently held by {self.lock.job}.") - continue - - data = {'timestamp': time.time(), - 'block_name': 'HWPPID', 'data': {}} - - try: - current_freq = self.pid.get_freq() - target_freq = self.pid.get_target() - direction = self.pid.get_direction() - - data['data']['current_freq'] = current_freq - data['data']['target_freq'] = target_freq - data['data']['direction'] = direction - except BaseException: - time.sleep(1) - continue - - self.agent.publish_to_feed('hwppid', data) - - session.data = {'current_freq': current_freq, - 'target_freq': target_freq, - 'direction': direction, - 'last_updated': time.time()} - - time.sleep(5) - - self.agent.feeds['hwppid'].flush_buffer() - return True, 'Acqusition exited cleanly' - - def _stop_acq(self, session, params): - """ - Stop acq process. + 'direction': 1} """ - if self.take_data: - self.take_data = False - return True, 'requested to stop taking data' - - return False, 'acq is not currently running' + action = Actions.GetState(**params) + self.action_queue.put(action) + res = yield action.deferred + session.data = parse_action_result(res) + return True, f"Completed: {str(action)}" def make_parser(parser=None): @@ -339,49 +328,44 @@ def make_parser(parser=None): parser = argparse.ArgumentParser() # Add options specific to this agent - pgroup = parser.add_argument_group('Agent Options') - pgroup.add_argument('--ip') - pgroup.add_argument('--port') - pgroup.add_argument('--verbose', '-v', action='count', default=0, - help='PID Controller verbosity level.') - pgroup.add_argument('--mode', type=str, default='acq', - choices=['init', 'acq'], - help="Starting operation for the Agent.") + pgroup = parser.add_argument_group("Agent Options") + pgroup.add_argument("--ip") + pgroup.add_argument("--port") + pgroup.add_argument( + "--verbose", + "-v", + action="count", + default=0, + help="PID Controller verbosity level.", + ) + pgroup.add_argument("--mode") return parser def main(args=None): parser = make_parser() - args = site_config.parse_args(agent_class='HWPPIDAgent', - parser=parser, - args=args) - - init_params = False - if args.mode == 'init': - init_params = {'auto_acquire': False} - elif args.mode == 'acq': - init_params = {'auto_acquire': True} + args = site_config.parse_args(agent_class="HWPPIDAgent", parser=parser, args=args) agent, runner = ocs_agent.init_site_agent(args) - hwppid_agent = HWPPIDAgent(agent, ip=args.ip, - port=args.port, - verbosity=args.verbose) - agent.register_task('init_connection', hwppid_agent.init_connection, - startup=init_params) - agent.register_process('acq', hwppid_agent.acq, - hwppid_agent._stop_acq) - agent.register_task('tune_stop', hwppid_agent.tune_stop) - agent.register_task('tune_freq', hwppid_agent.tune_freq) - agent.register_task('declare_freq', hwppid_agent.declare_freq) - agent.register_task('set_pid', hwppid_agent.set_pid) - agent.register_task('get_freq', hwppid_agent.get_freq) - agent.register_task('get_target', hwppid_agent.get_target) - agent.register_task('get_direction', hwppid_agent.get_direction) - agent.register_task('set_direction', hwppid_agent.set_direction) - agent.register_task('set_scale', hwppid_agent.set_scale) + if args.mode is not None: + agent.log.warn("--mode agrument is deprecated.") + + hwppid_agent = HWPPIDAgent( + agent, ip=args.ip, port=args.port, verbosity=args.verbose + ) + agent.register_process( + "main", hwppid_agent.main, hwppid_agent._main_stop, startup=True + ) + agent.register_task("tune_stop", hwppid_agent.tune_stop, blocking=False) + agent.register_task("tune_freq", hwppid_agent.tune_freq, blocking=False) + agent.register_task("declare_freq", hwppid_agent.declare_freq, blocking=False) + agent.register_task("set_pid", hwppid_agent.set_pid, blocking=False) + agent.register_task("set_direction", hwppid_agent.set_direction, blocking=False) + agent.register_task("set_scale", hwppid_agent.set_scale, blocking=False) + agent.register_task("get_state", hwppid_agent.get_state, blocking=False) runner.run(agent, auto_reconnect=True) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/socs/common/pmx.py b/socs/common/pmx.py index 9da5fd54e..3c6f25665 100644 --- a/socs/common/pmx.py +++ b/socs/common/pmx.py @@ -79,14 +79,16 @@ def check_current(self): self.ser.write("MEAS:CURR?\n\r") self.wait() try: - val = float(self.ser.readline()) - msg = "Measured current = %.3f A" % (val) + val = self.ser.readline() + curr = float(val) + msg = "Measured current = %.3f A" % (curr) # print(msg) except ValueError: - val = -999. + print(f"Could not convert '{val}' to float") + curr = -999. msg = 'WARNING! Could not get correct current value! | Response = "%s"' % (val) print(msg) - return msg, val + return msg, curr def check_voltage_current(self): """Check both the voltage and current.""" diff --git a/socs/testing/device_emulator.py b/socs/testing/device_emulator.py index 3184ce4ae..0999d3672 100644 --- a/socs/testing/device_emulator.py +++ b/socs/testing/device_emulator.py @@ -1,8 +1,12 @@ +import logging import shutil import socket import subprocess import threading import time +import traceback as tb +from copy import deepcopy +from typing import Dict import pytest import serial @@ -80,13 +84,21 @@ class DeviceEmulator: """ def __init__(self, responses, encoding='utf-8'): - self.responses = responses + self.responses = deepcopy(responses) self.default_response = None self.encoding = encoding self._type = None self._read = True self._conn = None + self.logger = logging.getLogger(self.__class__.__name__) + self.logger.setLevel(logging.DEBUG) + if len(self.logger.handlers) == 0: + formatter = logging.Formatter("%(asctime)s - %(name)s: %(message)s") + handler = logging.StreamHandler() + handler.setFormatter(formatter) + self.logger.addHandler(handler) + @staticmethod def _setup_socat(): """Setup a data relay with socat. @@ -132,7 +144,7 @@ def create_serial_relay(self): target=self._read_serial) bkg_read.start() - def _get_response(self, msg): + def get_response(self, msg): """Determine the response to a given message. Args: @@ -155,7 +167,9 @@ def _get_response(self, msg): else: response = self.responses[msg] except Exception as e: - print(f"encountered error {e}") + self.logger.info(f"Responses: {self.responses}") + self.logger.info(f"encountered error {e}") + self.logger.info(tb.format_exc()) response = None return response @@ -172,9 +186,9 @@ def _read_serial(self): msg = self.ser.readline() if self.encoding: msg = msg.strip().decode(self.encoding) - print(f"msg='{msg}'") + self.logger.debug(f"msg='{msg}'") - response = self._get_response(msg) + response = self.get_response(msg) # Avoid user providing bytes-like response if isinstance(response, bytes) and self.encoding is not None: @@ -183,7 +197,7 @@ def _read_serial(self): if response is None: continue - print(f"response='{response}'") + self.logger.debug(f"response='{response}'") if self.encoding: response = (response + '\r\n').encode(self.encoding) self.ser.write(response) @@ -230,26 +244,26 @@ def _read_socket(self, port): self._sock.bind(('127.0.0.1', port)) self._sock_bound = True except OSError: - print(f"Failed to bind to port {port}, trying again...") + self.logger.error(f"Failed to bind to port {port}, trying again...") time.sleep(1) self._sock.listen(1) - print("Device emulator waiting for tcp client connection") + self.logger.info("Device emulator waiting for tcp client connection") self._conn, client_address = self._sock.accept() - print(f"Client connection made from {client_address}") + self.logger.info(f"Client connection made from {client_address}") while self._read: try: msg = self._conn.recv(4096) # Was seeing this on tests in the cryomech agent except ConnectionResetError: - print('Caught connection reset on Agent clean up') + self.logger.info('Caught connection reset on Agent clean up') break if self.encoding: msg = msg.strip().decode(self.encoding) if msg: - print(f"msg='{msg}'") + self.logger.debug(f"msg='{msg}'") - response = self._get_response(msg) + response = self.get_response(msg) # Avoid user providing bytes-like response if isinstance(response, bytes) and self.encoding is not None: @@ -258,7 +272,7 @@ def _read_socket(self, port): if response is None: continue - print(f"response='{response}'") + self.logger.debug(f"response='{response}'") if self.encoding: response = response.encode(self.encoding) self._conn.sendall(response) @@ -298,6 +312,18 @@ def create_tcp_relay(self, port): while not self._sock_bound: time.sleep(0.1) + def update_responses(self, responses: Dict): + """ + Updates the current responses. See ``define_responses`` for more detail. + + Args + ------ + responses: dict + Dict of commands to use to update the current responses. + """ + self.responses.update(responses) + self.logger.info(f"responses set to {self.responses}") + def define_responses(self, responses, default_response=None): """Define what responses are available to reply with on the configured communication relay. @@ -324,7 +350,7 @@ def define_responses(self, responses, default_response=None): ``encoding=None``. """ - print(f"responses set to {responses}") - self.responses = responses - print(f"default response set to '{default_response}'") + self.logger.info(f"responses set to {responses}") + self.responses = deepcopy(responses) + self.logger.info(f"default response set to '{default_response}'") self.default_response = default_response diff --git a/socs/testing/hwp_emulator.py b/socs/testing/hwp_emulator.py new file mode 100644 index 000000000..9a942ec96 --- /dev/null +++ b/socs/testing/hwp_emulator.py @@ -0,0 +1,223 @@ +""" +HWP Emulation module +""" +import logging +import threading +import time +from dataclasses import dataclass, field + +import pytest + +from socs.agents.hwp_pid.drivers.pid_controller import PID +from socs.testing import device_emulator + + +def hex_str_to_dec(hex_value, decimal=3): + """Converts a hex string to a decimal float""" + return float(int(hex_value, 16)) / 10**decimal + + +@dataclass +class PMXState: + """State of the PMX Emulator""" + output: bool = False + current: float = 0 + current_limit: float = 10.0 + voltage_limit: float = 10.0 + voltage: float = 0 + source: str = "volt" + + +@dataclass +class PIDState: + """State of the PID Emulator""" + direction: str = "forward" + freq_setpoint: float = 0.0 + + +@dataclass +class HWPState: + """State of the HWP Emulator""" + cur_freq: float = 0.0 + pmx: PMXState = field(default_factory=PMXState) + pid: PIDState = field(default_factory=PIDState) + lock = threading.Lock() + + +def _create_logger(name, log_level=logging.INFO): + logger = logging.getLogger(name) + logger.setLevel(log_level) + if len(logger.handlers) == 0: + formatter = logging.Formatter("%(name)s: %(message)s") + handler = logging.StreamHandler() + handler.setFormatter(formatter) + logger.addHandler(handler) + return logger + + +def lerp(start, end, t): + return (1 - t) * start + t * end + + +class HWPEmulator: + def __init__(self, pid_port=None, pmx_port=None, log_level=logging.INFO): + self.pid_port = pid_port + self.pmx_port = pmx_port + + self.state = HWPState() + + self.pid_device = device_emulator.DeviceEmulator([]) + self.pid_device.get_response = self.process_pid_msg + self.pid_device.logger = _create_logger("PID", log_level=log_level) + + self.pmx_device = device_emulator.DeviceEmulator([]) + self.pmx_device.get_response = self.process_pmx_msg + self.pmx_device.logger = _create_logger("PMX", log_level=log_level) + + self.update_thread = threading.Thread(target=self.update_loop) + self.run_update = False + + self.logger = _create_logger("HWP", log_level=log_level) + + def start(self): + """Start up TCP Sockets and update loop""" + if self.pid_port is not None: + self.pid_device.create_tcp_relay(self.pid_port) + if self.pmx_port is not None: + self.pmx_device.create_tcp_relay(self.pmx_port) + + self.update_thread.start() + + def shutdown(self): + """Shutdown TCP Sockets and update loop""" + self.run_update = False + self.pid_device.shutdown() + self.pmx_device.shutdown() + self.update_thread.join() + + def update_loop(self): + """Update HWP state""" + self.run_update = True + s = self.state + self.logger.info("Starting update thread") + + while self.run_update: + with s.lock: + if s.pmx.source == "volt": + s.cur_freq = lerp(s.cur_freq, s.pid.freq_setpoint, 0.3) + + def process_pmx_msg(self, data): + """Process messages for PMX emulator""" + cmd = data.split(" ")[0].strip() + self.logger.debug(cmd) + with self.state.lock: + # Output commands + if cmd == "output": + val = int(data.split(" ")[1].strip()) + self.logger.info("Setting output to %d", val) + self.state.pmx.output = bool(val) + elif cmd == "output:protection:clear": + self.logger.info("Commanded to clear alarms") + elif cmd == "output?": + return str(int(self.state.pmx.output)) + + # Current (limit) commands + elif cmd == "curr": + val = float(data.split(" ")[1].strip()) + self.logger.info("Setting current to %.3f", val) + self.state.pmx.current = val + elif cmd == "curr:prot": + val = float(data.split(" ")[1].strip()) + self.logger.info("Setting current limit to %.3f", val) + self.state.pmx.current_limit = val + elif cmd == "curr?": + return f"{self.state.pmx.current}\n" + elif cmd == "curr:prot?": + return f"{self.state.pmx.current_limit}\n" + elif cmd == "meas:curr?": + return f"{self.state.pmx.current}\n" + + # Voltage (limit) commands + elif cmd == "volt": + val = float(data.split(" ")[1].strip()) + self.logger.info("Setting current to %.3f", val) + self.state.pmx.voltage = val + elif cmd == "volt:prot": + val = float(data.split(" ")[1].strip()) + self.logger.info("Setting voltage limit to %.3f", val) + self.state.pmx.voltage_limit = val + elif cmd == "volt:prot?": + return f"{self.state.pmx.voltage_limit}\n" + elif cmd == "volt?": + return f"{self.state.pmx.voltage}\n" + elif cmd == "meas:volt?": + return f"{self.state.pmx.voltage}\n" + + # Error codes + elif cmd == ":system:error?": # Error codes + return '0,"No error"\n' + elif cmd == "stat:ques?": # Status Codes + return "0" + elif cmd == "volt:ext:sour?": + return f"{self.state.pmx.source}\n" + else: + self.logger.info("Unknown cmd: %s", data) + if "?" in cmd: + return "unknown" + + def process_pid_msg(self, data): + """Process messages for PID emulator""" + logger = self.pid_device.logger + cmd = data.split(" ")[0].strip() + with self.state.lock: + # self.logger.debug(cmd) + if cmd == "*W02400000": + self.state.pid.direction = "reverse" + logger.info("Setting direction: reverse") + return "asdfl" + elif cmd == "*W02401388": + self.state.pid.direction = "forward" + logger.info("Setting direction: forward") + return "asdfl" + elif cmd.startswith("*W014"): + setpt = hex_str_to_dec(cmd[5:], 3) + logger.info("SETPOINT %s Hz", setpt) + self.state.pid.freq_setpoint = setpt + return "sdflsf" + elif cmd == "*X01": # Get frequency + return f"X01{self.state.cur_freq:0.3f}" + elif cmd == "*R01": # Get Target + return f"R01{PID._convert_to_hex(self.state.pid.freq_setpoint, 3)}" + elif cmd == "*R02": # Get Direction + if self.state.pid.direction == "forward": + return "1" + else: + return "0" + else: + self.logger.info("Unknown cmd: %s", cmd) + return "unknown" + + +def create_hwp_emulator_fixture(**kwargs): + """ + Creates a fixture for the HWP Emulator to use in tests. + """ + + @pytest.fixture() + def create_emulator(): + em = HWPEmulator(**kwargs) + em.start() + yield em + em.shutdown() + + return create_emulator + + +if __name__ == "__main__": + hwp_em = HWPEmulator() + try: + hwp_em.start() + while True: + time.sleep(1) + finally: + hwp_em.shutdown() diff --git a/tests/integration/test_hwp_pid_agent_integration.py b/tests/integration/test_hwp_pid_agent_integration.py index 48a0107b2..42c7757a5 100644 --- a/tests/integration/test_hwp_pid_agent_integration.py +++ b/tests/integration/test_hwp_pid_agent_integration.py @@ -1,3 +1,6 @@ +import logging +import time + import ocs import pytest from integration.util import docker_compose_file # noqa: F401 @@ -5,7 +8,7 @@ from ocs.base import OpCode from ocs.testing import create_agent_runner_fixture, create_client_fixture -from socs.testing.device_emulator import create_device_emulator +from socs.testing.hwp_emulator import create_hwp_emulator_fixture wait_for_crossbar = create_crossbar_fixture() run_agent = create_agent_runner_fixture( @@ -13,8 +16,15 @@ run_agent_idle = create_agent_runner_fixture( '../socs/agents/hwp_pid/agent.py', 'hwp_pid_agent', args=['--mode', 'init', '--log-dir', './logs/']) client = create_client_fixture('hwp-pid') -pid_emu = create_device_emulator( - {'*W02400000': 'W02\r'}, relay_type='tcp', port=2000) +hwp_emu = create_hwp_emulator_fixture(pid_port=2000, log_level=logging.DEBUG) + + +def wait_for_main(client): + while True: + data = client.main.status().session['data'] + if 'last_updated' in data: + return + time.sleep(0.2) @pytest.mark.integtest @@ -24,67 +34,32 @@ def test_testing(wait_for_crossbar): @pytest.mark.integtest -def test_hwp_rotation_failed_connection_pid(wait_for_crossbar, run_agent_idle, client): - resp = client.init_connection.start() - print(resp) - # We can't really check anything here, the agent's going to exit during the - # init_conneciton task because it cannot connect to the PID controller. - - -@pytest.mark.integtest -def test_hwp_rotation_get_direction(wait_for_crossbar, pid_emu, run_agent, client): - responses = {'*R02': 'R02400000\r'} - pid_emu.define_responses(responses) - - client.init_connection.wait() # wait for connection to be made - resp = client.get_direction() - print(resp) - assert resp.status == ocs.OK - print(resp.session) - assert resp.session['op_code'] == OpCode.SUCCEEDED.value - - # Test when in reverse - responses = {'*W02401388': 'W02\r', - '*R02': 'R02401388\r'} - pid_emu.define_responses(responses) - - client.set_direction(direction='1') - resp = client.get_direction() - print(resp) - assert resp.status == ocs.OK - print(resp.session) - assert resp.session['op_code'] == OpCode.SUCCEEDED.value +def test_hwp_rotation_get_state(wait_for_crossbar, hwp_emu, run_agent, client): + resp = client.get_state() + state = resp.session['data'] + print(state) + assert len(state.keys()) > 0 @pytest.mark.integtest -def test_hwp_rotation_set_direction(wait_for_crossbar, pid_emu, run_agent, client): - responses = {'*W02400000': 'W02\r', - '*W02401388': 'W02\r'} - pid_emu.define_responses(responses) - - client.init_connection.wait() # wait for connection to be made +def test_hwp_rotation_set_direction(wait_for_crossbar, hwp_emu, run_agent, client): + wait_for_main(client) resp = client.set_direction(direction='0') - print(resp) assert resp.status == ocs.OK - print(resp.session) assert resp.session['op_code'] == OpCode.SUCCEEDED.value + data = client.get_state().session['data'] + assert data['direction'] == '0' resp = client.set_direction(direction='1') - print(resp) assert resp.status == ocs.OK - print(resp.session) assert resp.session['op_code'] == OpCode.SUCCEEDED.value + data = client.get_state().session['data'] + assert data['direction'] == '1' @pytest.mark.integtest -def test_hwp_rotation_set_pid(wait_for_crossbar, pid_emu, run_agent, client): - responses = {'*W1700C8': 'W17\r', - '*W18003F': 'W18\r', - '*W190000': 'W19\r', - '*Z02': 'Z02\r'} - pid_emu.define_responses(responses) - - client.init_connection.wait() # wait for connection to be made +def test_hwp_rotation_set_pid(wait_for_crossbar, hwp_emu, run_agent, client): + wait_for_main(client) resp = client.set_pid(p=0.2, i=63, d=0) print(resp) assert resp.status == ocs.OK @@ -93,17 +68,8 @@ def test_hwp_rotation_set_pid(wait_for_crossbar, pid_emu, run_agent, client): @pytest.mark.integtest -def test_hwp_rotation_tune_stop(wait_for_crossbar, pid_emu, run_agent, client): - responses = {'*W0C83': 'W0C\r', - '*W01400000': 'W01\r', - '*R01': 'R01400000\r', - '*Z02': 'Z02\r', - '*W1700C8': 'W17\r', - '*W180000': 'W18\r', - '*W190000': 'W19\r'} - pid_emu.define_responses(responses) - - client.init_connection.wait() # wait for connection to be made +def test_hwp_rotation_tune_stop(wait_for_crossbar, hwp_emu, run_agent, client): + wait_for_main(client) resp = client.tune_stop() print(resp) assert resp.status == ocs.OK @@ -112,26 +78,8 @@ def test_hwp_rotation_tune_stop(wait_for_crossbar, pid_emu, run_agent, client): @pytest.mark.integtest -def test_hwp_rotation_get_freq(wait_for_crossbar, pid_emu, run_agent, client): - responses = {'*X01': 'X010.000\r'} - pid_emu.define_responses(responses) - - client.init_connection.wait() # wait for connection to be made - resp = client.get_freq() - print(resp) - assert resp.status == ocs.OK - print(resp.session) - assert resp.session['op_code'] == OpCode.SUCCEEDED.value - - -@pytest.mark.integtest -def test_hwp_rotation_set_scale(wait_for_crossbar, pid_emu, run_agent, client): - responses = {'*W14102710': 'W14\r', - '*W03302710': 'W03\r', - '*Z02': 'Z02\r'} - pid_emu.define_responses(responses) - - client.init_connection.wait() # wait for connection to be made +def test_hwp_rotation_set_scale(wait_for_crossbar, hwp_emu, run_agent, client): + wait_for_main(client) resp = client.set_scale() print(resp) assert resp.status == ocs.OK @@ -140,8 +88,8 @@ def test_hwp_rotation_set_scale(wait_for_crossbar, pid_emu, run_agent, client): @pytest.mark.integtest -def test_hwp_rotation_declare_freq(wait_for_crossbar, pid_emu, run_agent, client): - client.init_connection.wait() # wait for connection to be made +def test_hwp_rotation_declare_freq(wait_for_crossbar, hwp_emu, run_agent, client): + wait_for_main(client) resp = client.declare_freq(freq=0) print(resp) assert resp.status == ocs.OK @@ -150,17 +98,8 @@ def test_hwp_rotation_declare_freq(wait_for_crossbar, pid_emu, run_agent, client @pytest.mark.integtest -def test_hwp_rotation_tune_freq(wait_for_crossbar, pid_emu, run_agent, client): - responses = {'*W0C81': 'W0C\r', - '*W01400000': 'W01\r', - '*R01': 'R01400000\r', - '*Z02': 'Z02\r', - '*W1700C8': 'W17\r', - '*W18003F': 'W18\r', - '*W190000': 'W19\r'} - pid_emu.define_responses(responses) - - client.init_connection.wait() # wait for connection to be made +def test_hwp_rotation_tune_freq(wait_for_crossbar, hwp_emu, run_agent, client): + wait_for_main(client) resp = client.tune_freq() print(resp) assert resp.status == ocs.OK diff --git a/tests/integration/test_hwp_pmx_agent_integration.py b/tests/integration/test_hwp_pmx_agent_integration.py index 23bf35a97..5fbc88af9 100644 --- a/tests/integration/test_hwp_pmx_agent_integration.py +++ b/tests/integration/test_hwp_pmx_agent_integration.py @@ -12,15 +12,15 @@ run_agent_idle = create_agent_runner_fixture( '../socs/agents/hwp_pmx/agent.py', 'hwp_pmx_agent', args=['--mode', 'idle', '--log-dir', './logs/']) client = create_client_fixture('hwp-pmx') -kikusui_emu = create_device_emulator({}, relay_type='tcp', port=5025) -default_responses = { +responses = { 'meas:volt?': '2', 'meas:curr?': '1', ':system:error?': '+0,"No error"\n', 'stat:ques?': '0', 'volt:ext:sour?': 'source_name' } +kikusui_emu = create_device_emulator(responses, relay_type='tcp', port=5025) @pytest.mark.integtest @@ -31,9 +31,6 @@ def test_testing(wait_for_crossbar): @pytest.mark.integtest def test_hwp_rotation_main(wait_for_crossbar, kikusui_emu, run_agent, client): - responses = default_responses.copy() - responses['meas:curr?'] = '1' - kikusui_emu.define_responses(responses) client.main.stop() resp = client.main.wait() assert resp.session['data']['curr'] == 1.0 @@ -42,14 +39,11 @@ def test_hwp_rotation_main(wait_for_crossbar, kikusui_emu, run_agent, client): @pytest.mark.integtest -def test_hwp_rotation_set_on(wait_for_crossbar, kikusui_emu, run_agent, client): - responses = default_responses.copy() - responses.update({ - 'output 1': '', - 'output?': '1' +def test_hwp_rotation_set_off(wait_for_crossbar, kikusui_emu, run_agent, client): + kikusui_emu.update_responses({ + 'output 0': '', 'output?': '0', }) - kikusui_emu.define_responses(responses) - resp = client.set_on() + resp = client.set_off() print(resp) print(resp.session) assert resp.status == ocs.OK @@ -57,13 +51,11 @@ def test_hwp_rotation_set_on(wait_for_crossbar, kikusui_emu, run_agent, client): @pytest.mark.integtest -def test_hwp_rotation_set_off(wait_for_crossbar, kikusui_emu, run_agent, client): - responses = default_responses.copy() - responses.update({ - 'output 0': '', 'output?': '0' +def test_hwp_rotation_set_on(wait_for_crossbar, kikusui_emu, run_agent, client): + kikusui_emu.update_responses({ + 'output 0': '', 'output?': '0', }) - kikusui_emu.define_responses(responses) - resp = client.set_off() + resp = client.set_on() print(resp) print(resp.session) assert resp.status == ocs.OK @@ -72,10 +64,9 @@ def test_hwp_rotation_set_off(wait_for_crossbar, kikusui_emu, run_agent, client) @pytest.mark.integtest def test_hwp_rotation_set_i(wait_for_crossbar, kikusui_emu, run_agent, client): - responses = default_responses.copy() - responses.update({'curr 1.000000': '', - 'curr?': '1.000000'}) - kikusui_emu.define_responses(responses) + kikusui_emu.update_responses( + {'curr 1.000000': '', 'curr?': '1.000000'} + ) resp = client.set_i(curr=1) print(resp) print(resp.session) @@ -85,10 +76,8 @@ def test_hwp_rotation_set_i(wait_for_crossbar, kikusui_emu, run_agent, client): @pytest.mark.integtest def test_hwp_rotation_set_i_lim(wait_for_crossbar, kikusui_emu, run_agent, client): - responses = default_responses.copy() - responses.update({'curr:prot 2.000000': '', - 'curr:prot?': '2.000000'}) - kikusui_emu.define_responses(responses) + kikusui_emu.update_responses( + {'curr:prot 2.000000': '', 'curr:prot?': '2.000000'}) resp = client.set_i_lim(curr=2) print(resp) print(resp.session) @@ -98,10 +87,8 @@ def test_hwp_rotation_set_i_lim(wait_for_crossbar, kikusui_emu, run_agent, clien @pytest.mark.integtest def test_hwp_rotation_set_v(wait_for_crossbar, kikusui_emu, run_agent, client): - responses = default_responses.copy() - responses.update({'volt 1.000000': '', - 'volt?': '1.000000'}) - kikusui_emu.define_responses(responses) + kikusui_emu.update_responses( + {'volt 1.000000': '', 'volt?': '1.000000'}) resp = client.set_v(volt=1) print(resp) print(resp.session) @@ -111,12 +98,10 @@ def test_hwp_rotation_set_v(wait_for_crossbar, kikusui_emu, run_agent, client): @pytest.mark.integtest def test_hwp_rotation_set_v_lim(wait_for_crossbar, kikusui_emu, run_agent, client): - responses = default_responses.copy() - responses.update({ + kikusui_emu.update_responses({ 'volt:prot 10.0': '', 'volt:prot?': '10.000000' }) - kikusui_emu.define_responses(responses) resp = client.set_v_lim(volt=10) print(resp) print(resp.session) @@ -126,10 +111,10 @@ def test_hwp_rotation_set_v_lim(wait_for_crossbar, kikusui_emu, run_agent, clien @pytest.mark.integtest def test_hwp_rotation_use_ext(wait_for_crossbar, kikusui_emu, run_agent, client): - responses = default_responses.copy() - responses.update({'volt:ext:sour VOLT': '', - 'volt:ext:sour?': 'source_name'}) - kikusui_emu.define_responses(responses) + kikusui_emu.update_responses({ + 'volt:ext:sour VOLT': '', + 'volt:ext:sour?': 'source_name', + }) resp = client.use_ext() print(resp) print(resp.session) @@ -139,11 +124,10 @@ def test_hwp_rotation_use_ext(wait_for_crossbar, kikusui_emu, run_agent, client) @pytest.mark.integtest def test_hwp_rotation_ign_ext(wait_for_crossbar, kikusui_emu, run_agent, client): - responses = default_responses.copy() - responses.update({'volt:ext:sour NONE': '', - 'volt:ext:sour?': 'False'}) - kikusui_emu.define_responses(responses) - + kikusui_emu.update_responses({ + 'volt:ext:sour NONE': '', + 'volt:ext:sour?': 'False', + }) resp = client.ign_ext() print(resp) print(resp.session)