From 65e29849413994a41136fcd1d060ec891437ea3a Mon Sep 17 00:00:00 2001 From: JieguangZhou Date: Wed, 26 Jun 2024 15:02:21 +0800 Subject: [PATCH] Fixed listener dependencies (#2222) --- CHANGELOG.md | 1 + superduperdb/backends/sqlalchemy/metadata.py | 3 +-- superduperdb/components/listener.py | 10 ++++++---- test/unittest/component/test_listener.py | 12 ++++++------ 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a6f6bf96..681b108fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -116,6 +116,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix ibis cdc and cdc config - Fixed 'objectmodel' and 'predict_one' in doc. - Fixed ray dependencies bug. +- Fixed listener dependencies bug. - Fixed cluster bug. ## [0.1.3](https://github.com/SuperDuperDB/superduperdb/compare/0.1.1...0.1.3]) (2024-Jun-20) diff --git a/superduperdb/backends/sqlalchemy/metadata.py b/superduperdb/backends/sqlalchemy/metadata.py index 9c5cf340b..45f6a47af 100644 --- a/superduperdb/backends/sqlalchemy/metadata.py +++ b/superduperdb/backends/sqlalchemy/metadata.py @@ -519,8 +519,7 @@ def show_jobs( # Execute the query and collect results res = self.query_results(self.job_table, stmt, session) - # Return the list of identifiers - return [r['identifier'] for r in res] + return res def update_job(self, job_id: str, key: str, value: t.Any): """Update the job with the given key and value. diff --git a/superduperdb/components/listener.py b/superduperdb/components/listener.py index 234dc6197..d5f504d6b 100644 --- a/superduperdb/components/listener.py +++ b/superduperdb/components/listener.py @@ -11,7 +11,7 @@ from superduperdb.misc.server import request_server from ..jobs.job import Job -from .component import Component, ComponentTuple +from .component import Component from .model import Model, ModelInputType if t.TYPE_CHECKING: @@ -121,7 +121,7 @@ def create_output_dest(cls, db: "Datalayer", uuid, model: Model): db.add(output_table) @property - def dependencies(self) -> t.List[ComponentTuple]: + def dependencies(self): """Listener model dependencies.""" args, kwargs = self.mapping.mapping all_ = list(args) + list(kwargs.values()) @@ -168,8 +168,10 @@ def schedule_jobs( assert not isinstance(self.model, str) dependencies_ids = [] - for model_name in self.dependencies: - jobs = self.db.metadata.show_jobs(str(model_name), 'model') or [] + for predict_id in self.dependencies: + upstream_listener = db.load(uuid=predict_id) + upstream_model = upstream_listener.model + jobs = self.db.metadata.show_jobs(upstream_model.identifier, 'model') or [] job_ids = [job['job_id'] for job in jobs] dependencies_ids.extend(job_ids) diff --git a/test/unittest/component/test_listener.py b/test/unittest/component/test_listener.py index ac5a40035..cf7e5132f 100644 --- a/test/unittest/component/test_listener.py +++ b/test/unittest/component/test_listener.py @@ -67,23 +67,23 @@ def insert_random(): listener2 = Listener( model=m2, - select=MongoQuery(table='_outputs.listener1::0').find(), - key='_outputs.listener1::0', + select=MongoQuery(table=listener1.outputs).find(), + key=listener1.outputs, identifier='listener2', ) db.add(listener1) db.add(listener2) - docs = list(db.execute(MongoQuery(table='_outputs.listener1::0').find({}))) + docs = list(db.execute(MongoQuery(table=listener1.outputs).find({}))) - assert all("listener2::0" in r["_outputs"] for r in docs) + assert all(listener2.predict_id in r["_outputs"] for r in docs) insert_random() - docs = list(db.execute(MongoQuery(table="_outputs.listener1::0").find({}))) + docs = list(db.execute(MongoQuery(table=listener1.outputs).find({}))) - assert all(["listener2::0" in d["_outputs"] for d in docs]) + assert all([listener2.predict_id in d["_outputs"] for d in docs]) @pytest.mark.parametrize(