Skip to content

Commit

Permalink
Fixed listener dependencies (#2222)
Browse files Browse the repository at this point in the history
  • Loading branch information
jieguangzhou authored Jun 26, 2024
1 parent 396580b commit 65e2984
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix ibis cdc and cdc config
- Fixed 'objectmodel' and 'predict_one' in doc.
- Fixed ray dependencies bug.
- Fixed listener dependencies bug.
- Fixed cluster bug.

## [0.1.3](https://github.com/SuperDuperDB/superduperdb/compare/0.1.1...0.1.3]) (2024-Jun-20)
Expand Down
3 changes: 1 addition & 2 deletions superduperdb/backends/sqlalchemy/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,7 @@ def show_jobs(
# Execute the query and collect results
res = self.query_results(self.job_table, stmt, session)

# Return the list of identifiers
return [r['identifier'] for r in res]
return res

def update_job(self, job_id: str, key: str, value: t.Any):
"""Update the job with the given key and value.
Expand Down
10 changes: 6 additions & 4 deletions superduperdb/components/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from superduperdb.misc.server import request_server

from ..jobs.job import Job
from .component import Component, ComponentTuple
from .component import Component
from .model import Model, ModelInputType

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -121,7 +121,7 @@ def create_output_dest(cls, db: "Datalayer", uuid, model: Model):
db.add(output_table)

@property
def dependencies(self) -> t.List[ComponentTuple]:
def dependencies(self):
"""Listener model dependencies."""
args, kwargs = self.mapping.mapping
all_ = list(args) + list(kwargs.values())
Expand Down Expand Up @@ -168,8 +168,10 @@ def schedule_jobs(
assert not isinstance(self.model, str)

dependencies_ids = []
for model_name in self.dependencies:
jobs = self.db.metadata.show_jobs(str(model_name), 'model') or []
for predict_id in self.dependencies:
upstream_listener = db.load(uuid=predict_id)
upstream_model = upstream_listener.model
jobs = self.db.metadata.show_jobs(upstream_model.identifier, 'model') or []
job_ids = [job['job_id'] for job in jobs]
dependencies_ids.extend(job_ids)

Expand Down
12 changes: 6 additions & 6 deletions test/unittest/component/test_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,23 +67,23 @@ def insert_random():

listener2 = Listener(
model=m2,
select=MongoQuery(table='_outputs.listener1::0').find(),
key='_outputs.listener1::0',
select=MongoQuery(table=listener1.outputs).find(),
key=listener1.outputs,
identifier='listener2',
)

db.add(listener1)
db.add(listener2)

docs = list(db.execute(MongoQuery(table='_outputs.listener1::0').find({})))
docs = list(db.execute(MongoQuery(table=listener1.outputs).find({})))

assert all("listener2::0" in r["_outputs"] for r in docs)
assert all(listener2.predict_id in r["_outputs"] for r in docs)

insert_random()

docs = list(db.execute(MongoQuery(table="_outputs.listener1::0").find({})))
docs = list(db.execute(MongoQuery(table=listener1.outputs).find({})))

assert all(["listener2::0" in d["_outputs"] for d in docs])
assert all([listener2.predict_id in d["_outputs"] for d in docs])


@pytest.mark.parametrize(
Expand Down

0 comments on commit 65e2984

Please sign in to comment.