From 1c9f28ca79e9bafa218721971a9cc2aac8ff09b2 Mon Sep 17 00:00:00 2001 From: koparasy Date: Mon, 12 Aug 2024 10:01:31 -0700 Subject: [PATCH] Pushes model updates to the stager --- examples/bnm_opt/binomial_options.cpp | 47 +++- examples/tools/deploy_sequential.py | 301 +------------------------- src/AMSWorkflow/ams/orchestrator.py | 3 + src/AMSWorkflow/ams/rmq.py | 5 +- src/AMSWorkflow/ams/stage.py | 230 ++++++++++++++------ 5 files changed, 217 insertions(+), 369 deletions(-) diff --git a/examples/bnm_opt/binomial_options.cpp b/examples/bnm_opt/binomial_options.cpp index b38f51c8..a2fa1158 100644 --- a/examples/bnm_opt/binomial_options.cpp +++ b/examples/bnm_opt/binomial_options.cpp @@ -219,6 +219,26 @@ extern "C" void binomialOptionsEntry(real *callValue, real *_T, size_t optN); +void compute_start_end(float g_start, + float g_end, + float *l_start, + float *l_end, + float num_fractions, + int fraction_id, + int id, + int total_size) +{ + + float range = g_end - g_start; + float chunk_width = range / (float)num_fractions; + *l_start = g_start + fraction_id * chunk_width; + *l_end = *l_start + chunk_width; + + float local_range = *l_end - *l_start; + float distributed_chunk = local_range / total_size; + *l_start = *l_start + distributed_chunk * id; + *l_end = *l_start + distributed_chunk; +} int main(int argc, char **argv) { @@ -232,7 +252,7 @@ int main(int argc, char **argv) MPI_Comm_rank(MPI_COMM_WORLD, &rank); #endif - if (argc != 3) { + if (argc != 5) { std::cout << "USAGE: " << argv[0] << " num-options batch_size"; return EXIT_FAILURE; } @@ -244,6 +264,25 @@ int main(int argc, char **argv) size_t numOptions = std::atoi(argv[1]); size_t batch_size = std::atoi(argv[2]); + float num_fractions = std::atof(argv[3]); + int fraction_id = std::atoi(argv[4]); + + float S_start; + float S_end; + compute_start_end( + 5, 30, &S_start, &S_end, num_fractions, fraction_id, rank, size); + + float X_start, X_end; + compute_start_end( + 1.0f, 100.0f, &X_start, &X_end, num_fractions, fraction_id, rank, size); + + float T_start, T_end; + compute_start_end( + 0.25f, 10.0f, &T_start, &T_end, num_fractions, fraction_id, rank, size); + + printf("Rank Id %d S-Start:%f S-End %f\n", rank, S_start, S_end); + printf("Rank Id %d X-Start:%f X-End %f\n", rank, X_start, X_end); + printf("Rank Id %d T-Start:%f T-End %f\n", rank, T_start, T_end); bool write_output = false; @@ -302,9 +341,9 @@ int main(int argc, char **argv) for (size_t i = 0; i < numOptions; i += batch_size) { for (int j = 0; j < std::min(numOptions - i * batch_size, batch_size); j++) { - S[j] = randData(5.0f, 30.0f); - X[j] = randData(1.0f, 100.0f); - T[j] = randData(0.25f, 10.0f); + S[j] = randData(S_start, S_end); + X[j] = randData(X_start, X_end); + T[j] = randData(T_start, T_end); R[j] = 0.06f; V[j] = 0.10f; BlackScholesCall(callValueBS[j], S[j], X[j], T[j], R[j], V[j]); diff --git a/examples/tools/deploy_sequential.py b/examples/tools/deploy_sequential.py index e37e882b..1926dfd8 100644 --- a/examples/tools/deploy_sequential.py +++ b/examples/tools/deploy_sequential.py @@ -9,303 +9,6 @@ from ams.store import AMSDataStore -from typing import Optional -from dataclasses import dataclass -from ams import util - - -def constuct_cli_cmd(executable, *args, **kwargs): - command = [executable] - for k, v in kwargs.items(): - command.append(str(k)) - command.append(str(v)) - - for a in args: - command.append(str(a)) - - return command - - -@dataclass(kw_only=True) -class AMSJobResources: - nodes: int - tasks_per_node: int - cores_per_task: int = 1 - exclusive: Optional[bool] = True - gpus_per_task: Optional[int] = 0 - - -class AMSJob: - """ - Class Modeling a Job scheduled by AMS. - """ - - @classmethod - def generate_formatting(self, store): - return {"AMS_STORE_PATH": store.root_path} - - def __init__( - self, - name, - executable, - environ={}, - resources=None, - stdout=None, - stderr=None, - ams_log=False, - cli_args=[], - cli_kwargs={}, - ): - self._name = name - self._executable = executable - self._resources = resources - self.environ = environ - self._stdout = stdout - self._stderr = stderr - self._cli_args = [] - self._cli_kwargs = {} - if cli_args is not None: - self._cli_args = list(cli_args) - if cli_kwargs is not None: - self._cli_kwargs = dict(cli_kwargs) - - def generate_cli_command(self): - return constuct_cli_cmd(self.executable, *self._cli_args, **self._cli_kwargs) - - def __str__(self): - data = {} - data["name"] = self._name - data["executable"] = self._executable - data["stdout"] = self._stdout - data["stderr"] = self._stderr - data["cli_args"] = self._cli_args - data["cli_kwargs"] = self._cli_kwargs - data["resources"] = self._resources - return f"{self.__class__.__name__}({data})" - - def precede_deploy(self, store): - pass - - @property - def resources(self): - """The resources property.""" - return self._resources - - @resources.setter - def resources(self, value): - self._resources = value - - @property - def executable(self): - """The executable property.""" - return self._executable - - @executable.setter - def executable(self, value): - self._executable = value - - @property - def environ(self): - """The environ property.""" - return self._environ - - @environ.setter - def environ(self, value): - if isinstance(value, type(os.environ)): - self._environ = dict(value) - return - elif not isinstance(value, dict) and value is not None: - raise RuntimeError(f"Unknwon type {type(value)} to set job environment") - - self._environ = value - - @property - def stdout(self): - """The stdout property.""" - return self._stdout - - @stdout.setter - def stdout(self, value): - self._stdout = value - - @property - def stderr(self): - """The stderr property.""" - return self._stderr - - @stderr.setter - def stderr(self, value): - self._stderr = value - - @property - def name(self): - """The name property.""" - return self._name - - @name.setter - def name(self, value): - self._name = value - - -class AMSDomainJob(AMSJob): - def _generate_ams_object(self, store): - ams_object = dict() - if self.stage_dir is None: - ams_object["db"] = {"fs_path": str(store.get_candidate_path()), "dbType": "hdf5"} - else: - ams_object["db"] = {"fs_path": self.stage_dir, "dbType": "hdf5"} - - ams_object["ml_models"] = dict() - ams_object["domain_models"] = dict() - - for i, name in enumerate(self.domain_names): - models = store.search(domain_name=name, entry="models", version="latest") - print(json.dumps(models, indent=6)) - # This is the case in which we do not have any model - # Thus we create a data gathering entry - if len(models) == 0: - model_entry = { - "uq_type": "random", - "model_path": "", - "uq_aggregate": "mean", - "threshold": 1, - "db_label": name, - } - else: - model = models[0] - model_entry = { - "uq_type": model["uq_type"], - "model_path": model["file"], - "uq_aggregate": "mean", - "threshold": model["threshold"], - "db_label": name, - } - - ams_object["ml_models"][f"model_{i}"] = model_entry - ams_object["domain_models"][name] = f"model_{i}" - return ams_object - - def __init__(self, domain_names, stage_dir, *args, **kwargs): - self.domain_names = domain_names - self.stage_dir = stage_dir - self._ams_object = None - self._ams_object_fn = None - super().__init__(*args, **kwargs) - - @classmethod - def from_descr(cls, stage_dir, descr): - domain_job_resources = AMSJobResources(**descr["resources"]) - return cls( - name=descr["name"], - stage_dir=stage_dir, - domain_names=descr["domain_names"], - environ=None, - resources=domain_job_resources, - **descr["cli"], - ) - - def precede_deploy(self, store): - self._ams_object = self._generate_ams_object(store) - tmp_path = util.mkdir(store.root_path, "tmp") - self._ams_object_fn = f"{tmp_path}/{util.get_unique_fn()}.json" - with open(self._ams_object_fn, "w") as fd: - json.dump(self._ams_object, fd) - self.environ["AMS_OBJECTS"] = str(self._ams_object_fn) - - -class AMSMLJob(AMSJob): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - @classmethod - def from_descr(cls, store, descr): - formatting = AMSJob.generate_formatting(store) - resources = AMSJobResources(**descr["resources"]) - cli_kwargs = descr["cli"].get("cli_kwargs", None) - if cli_kwargs is not None: - for k, v in cli_kwargs.items(): - if isinstance(v, str): - cli_kwargs[k] = v.format(**formatting) - cli_args = descr["cli"].get("cli_args", None) - if cli_args is not None: - for i, v in enumerate(cli_args): - cli_args[i] = v.format(**formatting) - - return cls( - name=descr["name"], - environ=None, - stdout=descr["cli"].get("stdout", None), - stderr=descr["cli"].get("stderr", None), - executable=descr["cli"]["executable"], - resources=resources, - cli_kwargs=cli_kwargs, - cli_args=cli_args, - ) - - -class AMSMLTrainJob(AMSMLJob): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - -class AMSSubSelectJob(AMSMLJob): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - -class AMSFSStageJob(AMSJob): - def __init__( - self, - store_dir, - src_dir, - dest_dir, - resources, - environ=None, - stdout=None, - stderr=None, - prune_module_path=None, - prune_class=None, - cli_args=[], - cli_kwargs={}, - ): - _cli_args = list(cli_args) - _cli_args.append("--store") - _cli_kwargs = dict(cli_kwargs) - _cli_kwargs["--dest"] = dest_dir - _cli_kwargs["--src"] = src_dir - _cli_kwargs["--pattern"] = "*.h5" - _cli_kwargs["--db-type"] = "dhdf5" - _cli_kwargs["--mechanism"] = "fs" - _cli_kwargs["--policy"] = "process" - _cli_kwargs["--persistent-db-path"] = store_dir - _cli_kwargs["--src"] = src_dir - - if prune_module_path is not None: - assert Path(prune_module_path).exists(), "Module path to user pruner does not exist" - _cli_kwargs["--load"] = prune_module_path - _cli_kwargs["--class"] = prune_class - - super().__init__( - name="AMSStage", - executable="AMSDBStage", - environ=environ, - resources=resources, - stdout=stdout, - stderr=stderr, - cli_args=_cli_args, - cli_kwargs=_cli_kwargs, - ) - - @staticmethod - def resources_from_domain_job(domain_job): - return AMSJobResources( - nodes=domain_job.resources.nodes, - tasks_per_node=1, - cores_per_task=5, - exclusive=False, - gpus_per_task=domain_job.resources.gpus_per_task, - ) - class FluxJobStatus: """Simple class to get job info from active Flux handle""" @@ -387,19 +90,21 @@ def submit_ams_job( def submit_cb(fut, flux_handle, store, jobs): jobid = fut.get_id() + print(type(fut)) tmp = FluxJobStatus(flux_handle) result_fut = flux.job.result_async(flux_handle, jobid) result_fut.then(result_cb, flux_handle, store, jobs) def result_cb(fut, flux_handle, store, jobs): + print(type(fut)) job = fut.get_info() result = job.result.lower() tmp = FluxJobStatus(flux_handle) current_job = jobs.pop(0) print(f"{job.id}: {current_job.name } finished with {result} and returned {job.returncode}") print(current_job) - push_next_job(flux_handle, store, jobs) + push_next_job(flux_handle, submit_futurtore, jobs) def push_next_job(flux_handle, store, jobs): diff --git a/src/AMSWorkflow/ams/orchestrator.py b/src/AMSWorkflow/ams/orchestrator.py index c76fdc93..2acbc7d6 100644 --- a/src/AMSWorkflow/ams/orchestrator.py +++ b/src/AMSWorkflow/ams/orchestrator.py @@ -547,6 +547,7 @@ def __call__(self): elif request.is_process(): print("Received", json.dumps(request.data())) + class AMSFakeRMQUpdate: def __init__( self, @@ -571,8 +572,10 @@ def __call__(self): requests = json.load(fd) for r in requests: item = [r] + print(f"publishing {item}") producer.send_message(json.dumps(item)) + class AMSRMQMessagePrinter(RMQLoaderTask): """ A AMSJobReceiver receives job specifications for existing domains running on the diff --git a/src/AMSWorkflow/ams/rmq.py b/src/AMSWorkflow/ams/rmq.py index f59b7b9a..7408e696 100644 --- a/src/AMSWorkflow/ams/rmq.py +++ b/src/AMSWorkflow/ams/rmq.py @@ -715,6 +715,8 @@ def open(self): self.channel = self.connection.channel() result = self.channel.queue_declare(queue=self._publish_queue, exclusive=False) + # TODO: assert if publish_queue is different than method.queue. + # Verify if this is guaranteed by the RMQ specification. self._publish_queue = result.method.queue self._connected = True return self @@ -732,9 +734,10 @@ def send_message(self, message): self._num_sent_messages += 1 try: self.channel.basic_publish(exchange="", routing_key=self._publish_queue, body=message) - self._num_confirmed_messages += 1 except pika.exceptions.UnroutableError: print(f" [{self._num_sent_messages}] Message could not be confirmed") + else: + self._num_confirmed_messages += 1 @dataclass diff --git a/src/AMSWorkflow/ams/stage.py b/src/AMSWorkflow/ams/stage.py index 209e3918..88e07e38 100644 --- a/src/AMSWorkflow/ams/stage.py +++ b/src/AMSWorkflow/ams/stage.py @@ -24,7 +24,7 @@ from ams.config import AMSInstance from ams.faccessors import get_reader, get_writer from ams.monitor import AMSMonitor -from ams.rmq import AMSMessage, AsyncConsumer +from ams.rmq import AMSMessage, AsyncConsumer, AMSRMQConfiguration from ams.store import AMSDataStore from ams.util import get_unique_fn @@ -119,19 +119,21 @@ class ForwardTask(Task): callback: A callback to be applied on every message before pushing it to the next stage. """ - def __init__(self, i_queue, o_queue, callback): + def __init__(self, db_path, db_store, name, i_queue, o_queue, user_obj): """ initializes a ForwardTask class with the queues and the callback. """ - if not isinstance(callback, Callable): - raise TypeError(f"{callback} argument is not Callable") + + self._db_path = db_path + self._db_store = db_store + self._db_name = name self.i_queue = i_queue self.o_queue = o_queue - self.callback = callback + self.user_obj = user_obj self.datasize = 0 - def _action(self, data): + def _data_cb(self, data): """ Apply an 'action' to the incoming data @@ -141,12 +143,18 @@ def _action(self, data): Returns: A pair of inputs, outputs of the data after the transformation """ - inputs, outputs = self.callback(data.inputs, data.outputs) + inputs, outputs = self.user_obj.data_cb(data.inputs, data.outputs) # This can be too conservative, we may want to relax it later if not (isinstance(inputs, np.ndarray) and isinstance(outputs, np.ndarray)): raise TypeError(f"{self.callback.__name__} did not return numpy arrays") return inputs, outputs + def _model_update_cb(self, db, msg): + domain = msg["domain"] + model = db.search(domain, "models", version="latest") + _updated = self.user_obj.update_model_cb(domain, model) + print(f"Model update status: {_updated}") + @AMSMonitor(record=["datasize"]) def __call__(self): """ @@ -155,23 +163,26 @@ def __call__(self): the tasks waiting on the output queues about the terminations and returns from the function. """ - while True: - # This is a blocking call - item = self.i_queue.get(block=True) - if item.is_terminate(): - self.o_queue.put(QueueMessage(MessageType.Terminate, None)) - break - elif item.is_process(): - data = item.data() - inputs, outputs = self._action(data) - self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(inputs, outputs, data.domain_name))) - self.datasize += inputs.nbytes + outputs.nbytes - elif item.is_delete(): - print(f"Sending Delete Message Type {self.__class__.__name__}") - self.o_queue.put(item) - elif item.is_new_model(): - # This is not handled yet - continue + with AMSDataStore(self.db_path, self.db_store, self.name) as db: + while True: + # This is a blocking call + item = self.i_queue.get(block=True) + if item.is_terminate(): + self.o_queue.put(QueueMessage(MessageType.Terminate, None)) + break + elif item.is_process(): + data = item.data() + inputs, outputs = self._data_cb(data) + self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(inputs, outputs, data.domain_name))) + self.datasize += inputs.nbytes + outputs.nbytes + elif item.is_new_model(): + self._model_update_cb(db, data) + elif item.is_delete(): + print(f"Sending Delete Message Type {self.__class__.__name__}") + self.o_queue.put(item) + elif item.is_new_model(): + # This is not handled yet + continue return @@ -223,9 +234,9 @@ def __call__(self): print(f"Spend {end - start} at {self.__class__.__name__}") -class RMQLoaderTask(Task): +class RMQDomainDataLoaderTask(Task): """ - A RMQLoaderTask consumes data from RabbitMQ bundles the data of + A RMQDomainDataLoaderTask consumes 'AMSMessages' from RabbitMQ bundles the data of the files into batches and forwards them to the next task waiting on the output queuee. @@ -264,7 +275,7 @@ def __init__( # Signals can only be used within the main thread if self.policy != "thread": # We ignore SIGTERM, SIGUSR1, SIGINT by default so later - # we can override that handler only for RMQLoaderTask + # we can override that handler only for RMQDomainDataLoaderTask for s in self.signals: self.orig_sig_handlers[s] = signal.getsignal(s) signal.signal(s, signal.SIG_IGN) @@ -326,13 +337,42 @@ def __call__(self): """ Busy loop of consuming messages from RMQ queue """ - # Installing signal callbacks only for RMQLoaderTask + # Installing signal callbacks only for RMQDomainDataLoaderTask if self.policy != "thread": for s in self.signals: signal.signal(s, self.signal_wrapper(self.__class__.__name__, os.getpid())) self.rmq_consumer.run() +class RMQControlMessageTask(RMQDomainDataLoaderTask): + """ + A RMQControlMessageTask consumes JSON-messages from RabbitMQ and forwards them to + the o_queue of the pruning Task. + + Attributes: + o_queue: The output queue to write the transformed messages + credentials: A JSON file with the credentials to log on the RabbitMQ server. + certificates: TLS certificates + rmq_queue: The RabbitMQ queue to listen to. + prefetch_count: Number of messages prefected by RMQ (impact performance) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def callback_message(self, ch, basic_deliver, properties, body): + """ + Callback that will be called each time a message will be consummed. + the connection (or if a problem happened with the connection). + """ + start_time = time.time() + data = json.loads(body) + if data["request_type"] == "done-training": + self.o_queue.put(QueueMessage(MessageType.NewModel, data)) + + self.total_time += time.time() - start_time + + class FSWriteTask(Task): """ A Class representing a task flushing data in the specified output directory @@ -517,7 +557,7 @@ def __init__(self, db_dir, store, dest_dir=None, stage_dir=None, db_type="hdf5") if stage_dir is not None: self.stage_dir = stage_dir - self.actions = list() + self.user_action = None self.db_type = db_type @@ -525,17 +565,21 @@ def __init__(self, db_dir, store, dest_dir=None, stage_dir=None, db_type="hdf5") self.store = store - def add_data_action(self, callback): + def add_user_action(self, obj): """ Adds an action to be performed at the data before storing them in the filesystem Args: callback: A callback to be called on every input, output. """ - if not callable(callback): - raise TypeError(f"{self.__class__.__name__} requires a callable as an argument") - self.actions.append(callback) + if not (hasattr(obj, "data_cb") and callable(getattr(obj, "data_cb"))): + raise TypeError(f"User provided object {obj} does not have data_cb") + + if not (hasattr(obj, "update_model_cb") and callable(getattr(obj, "update_model_cb"))): + raise TypeError(f"User provided object {obj} does not have data_cb") + + self.user_action = obj def _seq_execute(self): """ @@ -593,12 +637,22 @@ def _link_pipeline(self, policy): # Every action requires 1 input and one output q. But the output # q is used as an inut q on the next action thus we need num actions -1. # 2 extra queues to store to data-store and publish on kosh - num_queues = 1 + len(self.actions) - 1 + 2 - self._queues = [_qType() for i in range(num_queues)] + self._queues = [_qType() for i in range(4)] self._tasks = [self.get_load_task(self._queues[0], policy)] - for i, a in enumerate(self.actions): - self._tasks.append(ForwardTask(self._queues[i], self._queues[i + 1], a)) + assert len(self.actions) == 1, "We only support a single user action" + self._tasks.append( + ForwardTask( + self.ams_config.db_path, + self.ams_config.db_store, + self.ams_config.name, + self._queues[0], + self._queues[1], + self.user_action, + ) + ) + if self.requires_model_update: + self._tasks.append(self.get_model_update_task(self._queues[0], policy)) # After user actions we store into a file self._tasks.append(FSWriteTask(self._queues[-2], self._queues[-1], self._writer, self.stage_dir)) @@ -622,6 +676,17 @@ def execute(self, policy): # Execute them self._execute_tasks(policy) + @abstractmethod + def requires_model_update(self): + """ + Returns whether the pipeline provides a model-update message parsing mechanism + """ + pass + + @abstractmethod + def get_model_update_task(self, o_queue, policy): + pass + @abstractmethod def get_load_task(self, o_queue, policy): """ @@ -731,6 +796,12 @@ def from_cli(cls, args): args.pattern, ) + def requires_model_update(self): + return False + + def get_model_update_task(self, o_queue, policy): + raise RuntimeError("FSPipeline does not support model update") + class RMQPipeline(Pipeline): """ @@ -746,7 +817,22 @@ class RMQPipeline(Pipeline): rmq_queue: The RMQ queue to listen to. """ - def __init__(self, db_dir, store, dest_dir, stage_dir, db_type, host, port, vhost, user, password, cert, rmq_queue): + def __init__( + self, + db_dir, + store, + dest_dir, + stage_dir, + db_type, + host, + port, + vhost, + user, + password, + cert, + data_queue, + model_update_queue=None, + ): """ Initialize a RMQPipeline that will write data to the 'dest_dir' and optionally publish these files to the kosh-store 'store' by using the stage_dir as an intermediate directory. @@ -758,7 +844,8 @@ def __init__(self, db_dir, store, dest_dir, stage_dir, db_type, host, port, vhos self._user = user self._password = password self._cert = Path(cert) - self._rmq_queue = rmq_queue + self._data_queue = data_queue + self._model_update_queue = model_update_queue def get_load_task(self, o_queue, policy): """ @@ -767,10 +854,32 @@ def get_load_task(self, o_queue, policy): Args: o_queue: The queue the load task will push read data. - Returns: An RMQLoaderTask instance reading data from the + Returns: An RMQDomainDataLoaderTask instance reading data from the filesystem and forwarding the values to the o_queue. """ - return RMQLoaderTask( + return RMQDomainDataLoaderTask( + o_queue, + self._host, + self._port, + self._vhost, + self._user, + self._password, + self._cert, + self._data_queue, + policy, + ) + + def get_model_update_task(self, o_queue, policy): + """ + Return a Task receives messages from the training job regarding the status of new models + + Args: + o_queue: The queue to push the model update message. + + Returns: An RMQControlMessageTask instance reading data from self._model_update_queue + and forwarding the values to the o_queue. + """ + return RMQControlMessageTask( o_queue, self._host, self._port, @@ -778,7 +887,7 @@ def get_load_task(self, o_queue, policy): self._user, self._password, self._cert, - self._rmq_queue, + self._model_update_queue, policy, ) @@ -788,9 +897,8 @@ def add_cli_args(parser): Add cli arguments to the parser required by this Pipeline. """ Pipeline.add_cli_args(parser) - parser.add_argument("-c", "--creds", help="Credentials file (JSON)", required=True) - parser.add_argument("-t", "--cert", help="TLS certificate file", required=True) - parser.add_argument("-q", "--queue", help="On which queue to receive messages", required=True) + parser.add_argument("-c", "--creds", help="AMS credentials file (JSON)", required=True) + parser.add_argument("-u", "--update-rmq-models", help="Update-rmq-models", action="store_true") return @classmethod @@ -799,13 +907,7 @@ def from_cli(cls, args): Create RMQPipeline from the user provided CLI. """ - # TODO: implement an interface so users can plug any parser for RMQ credentials - config = cls.parse_credentials(cls, args.creds) - host = config["service-host"] - port = config["service-port"] - vhost = config["rabbitmq-vhost"] - user = config["rabbitmq-user"] - password = config["rabbitmq-password"] + config = AMSRMQConfiguration.from_json(args.creds) return cls( args.persistent_db_path, @@ -813,22 +915,18 @@ def from_cli(cls, args): args.dest_dir, args.stage_dir, args.db_type, - host, - port, - vhost, - user, - password, - args.cert, - args.queue, + config.service_host, + config.service_port, + config.rabbitmq_vhost, + config.rabbimq_user, + config.rabbitmq_password, + config.rabbitmq_cert, + config.rabbitmq_inbound_queue, + config.rabbitmq_ml_status_queue if args.update_rmq_models else None, ) - @staticmethod - def parse_credentials(self, json_file: str) -> dict: - """Internal method to parse the credentials file""" - data = {} - with open(json_file, "r") as f: - data = json.load(f) - return data + def requires_model_update(self): + return self._model_update_queue != None def get_pipeline(src_mechanism="fs"):