Commit

Fix recursive non-breaking changes

blythed committed Dec 3, 2024
1 parent 7cde03f commit beb2c27
Showing 12 changed files with 312 additions and 135 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
#### Changed defaults / behaviours

- Deprecate vanilla `DataType`
- Remove `_Encodable` from project
- Remove `_Encodable` from project

#### New Features & Functionality

44 changes: 44 additions & 0 deletions superduper/base/apply.py
@@ -4,6 +4,7 @@
from rich.console import Console

from superduper import Component, logging
from superduper.backends.base.query import Query
from superduper.base.document import Document
from superduper.base.event import Create, Signal, Update
from superduper.components.component import Status
@@ -55,6 +56,7 @@ def apply(
context=object.uuid,
job_events={},
global_diff=diff,
non_breaking_changes={},
)
# this flags that the context is not needed anymore
if not create_events:
@@ -123,6 +125,7 @@ def apply(
def _apply(
db: 'Datalayer',
object: 'Component',
non_breaking_changes: t.Dict,
context: str | None = None,
job_events: t.Dict[str, 'Job'] | None = None,
parent: t.Optional[str] = None,
@@ -164,6 +167,7 @@ def wrapper(child):
job_events=job_events,
parent=object.uuid,
global_diff=global_diff,
non_breaking_changes=non_breaking_changes,
)

job_events.update(j)
@@ -179,6 +183,29 @@ def wrapper(child):

serialized = serialized.map(wrapper, lambda x: isinstance(x, Component))

def replace_existing(x):
if isinstance(x, str):
for uuid in non_breaking_changes:
x = x.replace(uuid, non_breaking_changes[uuid])

elif isinstance(x, Query):
r = x.dict()
for uuid in non_breaking_changes:
r['query'] = r['query'].replace(uuid, non_breaking_changes[uuid])
for i, doc in enumerate(r['documents']):
replace = {}
for k in doc:
replace_k = k
for uuid in non_breaking_changes:
replace_k = replace_k.replace(uuid, non_breaking_changes[uuid])
replace[replace_k] = doc[k]
r['documents'][i] = replace
x = Document.decode(r).unpack()

else:
raise TypeError("Unexpected target of substitution in db.apply")
return x

try:
current = db.load(object.type_id, object.identifier)

@@ -187,6 +214,10 @@ def wrapper(child):
current_serialized = current.dict(metadata=False, refs=True)
del current_serialized['uuid']

serialized = serialized.map(
replace_existing, lambda x: isinstance(x, str) or isinstance(x, Query)
)

# finds the fields where there is a difference
this_diff = Document(current_serialized, schema=current_serialized.schema).diff(
serialized
@@ -196,7 +227,11 @@ def wrapper(child):
if not this_diff:
# if no change then update the component
# to have the same info as the "existing" version

non_breaking_changes[object.uuid] = current.uuid

current.handle_update_or_same(object)

return create_events, job_events

elif set(this_diff.keys(deep=True)).intersection(object.breaks):
@@ -237,6 +272,15 @@ def wrapper(child):

else:
apply_status = 'update'

# the non-breaking changes lookup table
# allows components which are downstream
# from this component via references
# (e.g. `Listener` instances which listen to these outputs)
# to understand if they are now referring to the "original"
# or the "new" version.
non_breaking_changes[object.uuid] = current.uuid

current.handle_update_or_same(object)

serialized['version'] = current.version
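The core of the fix is the `non_breaking_changes` dictionary threaded through `_apply`: whenever a component turns out to be identical to a version already stored, its freshly generated UUID is mapped to the stored UUID, and `replace_existing` rewrites any string or `Query` references in the serialized parent before the diff is computed, so unchanged children no longer register as breaking changes in their parents. A minimal standalone sketch of that substitution step (illustrative only; the helper name, UUIDs, and reference string are hypothetical, not part of the diff):

```python
# Sketch: rewrite references to freshly generated UUIDs so that they point
# at the already-stored, unchanged versions recorded in non_breaking_changes.
def substitute_uuids(value: str, non_breaking_changes: dict) -> str:
    for new_uuid, existing_uuid in non_breaking_changes.items():
        value = value.replace(new_uuid, existing_uuid)
    return value


non_breaking_changes = {'uuid-new-1234': 'uuid-stored-5678'}
ref = '&:component:listener:features:uuid-new-1234'  # hypothetical reference string
print(substitute_uuids(ref, non_breaking_changes))
# -> &:component:listener:features:uuid-stored-5678
```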
20 changes: 20 additions & 0 deletions superduper/components/listener.py
@@ -93,6 +93,26 @@ def _set_upstream(self):
f'dependency {deps[it]}'
)

if not self.upstream:
return

from collections import defaultdict

from superduper import Component

# This is to perform deduplication, in case an upstream
# listener has already been provided

huuids = defaultdict(lambda: [])
for x in self.upstream:
if isinstance(x, Component):
huuids['&:component:' + x.huuid].append(x)
else:
huuids[x].append(x)

self.upstream = [x[0] for x in huuids.values()]

@property
def predict_id(self):
"""Predict ID property."""
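The deduplication above groups `self.upstream` entries under a key — `'&:component:' + x.huuid` for `Component` instances, the raw string otherwise — and keeps the first entry per key, so the same upstream dependency is not scheduled twice. A standalone sketch of the same grouping, using plain strings in place of `Component` objects (illustrative, not part of the diff):

```python
from collections import defaultdict

# Group upstream references by key and keep the first occurrence per key.
upstream = [
    '&:component:listener:features:uuid-1',
    '&:component:listener:features:uuid-1',  # duplicate reference
    '&:component:model:embedder:uuid-2',
]

grouped = defaultdict(list)
for ref in upstream:
    grouped[ref].append(ref)

deduplicated = [refs[0] for refs in grouped.values()]
print(deduplicated)
# -> ['&:component:listener:features:uuid-1', '&:component:model:embedder:uuid-2']
```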
5 changes: 1 addition & 4 deletions superduper/components/template.py
@@ -43,11 +43,8 @@ def __post_init__(self, db, substitutions):
self.template = self.template.encode(defaults=True, metadata=False)
self.template = SuperDuperFlatEncode(self.template)
if substitutions is not None:
- databackend_name = db.databackend._backend.__class__.__name__.split(
-     'DataBackend'
- )[0].lower()
substitutions = {
-     databackend_name: 'databackend',
+     db.databackend.backend_name: 'databackend',
CFG.output_prefix: 'output_prefix',
**substitutions,
}
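The template change drops the hand-parsing of the backend class name and reads `db.databackend.backend_name` directly when building the substitution table. A rough sketch of what the removed parsing did versus the new attribute (the class and the attribute value below are hypothetical, chosen only to illustrate the renaming):

```python
# Hypothetical backend class, used only for illustration.
class MongoDBDataBackend:
    backend_name = 'mongodb'  # assumed value; the attribute replaces the parsing below


backend = MongoDBDataBackend()

# Old approach: derive the key by splitting the class name on 'DataBackend'.
parsed = type(backend).__name__.split('DataBackend')[0].lower()

print(parsed)                # -> 'mongodb'
print(backend.backend_name)  # -> 'mongodb'
```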
4 changes: 2 additions & 2 deletions superduper/rest/build.py
@@ -279,8 +279,8 @@ def db_execute(query: t.Dict, db: 'Datalayer' = DatalayerDependency()):
return [{'_base': output}], []

if '_path' not in query:
- plugin = db.databackend.type.__module__.split('.')[0]
- query['_path'] = f'{plugin}.query.parse_query'
+ plugin = db.databackend.backend_name
+ query['_path'] = f'superduper_{plugin}.query.parse_query'

q = Document.decode(query, db=db).unpack()

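With this change the REST handler builds the query-parser path from the backend's `backend_name` and the `superduper_<backend>` plugin naming pattern, rather than from the backend's module path. A small sketch of the resulting behaviour (the backend name and query payload are assumptions for illustration):

```python
# Assume the configured data backend reports backend_name == 'mongodb'.
backend_name = 'mongodb'

query = {'query': 'documents.find({})'}  # hypothetical payload without '_path'

if '_path' not in query:
    query['_path'] = f'superduper_{backend_name}.query.parse_query'

print(query['_path'])
# -> superduper_mongodb.query.parse_query
```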
