Skip to content

Commit

Permalink
Merge pull request #9339 from OpenMined/fix_status_l0
Browse files Browse the repository at this point in the history
Fix status for requests high side after syncing
  • Loading branch information
teo-milea authored Oct 8, 2024
2 parents 46c29bc + 69c35a2 commit 99e6659
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 19 deletions.
18 changes: 18 additions & 0 deletions notebooks/scenarios/bigquery/sync/040-do-review-requests.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,24 @@
"widget._sync_all()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Check requests status on the high side"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for job in submitted_jobs_data_should_succeed:\n",
" request = get_request_for_job_info(all_requests, job)\n",
" assert request.status == RequestStatus.APPROVED"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
6 changes: 4 additions & 2 deletions packages/syft/src/syft/client/syncing.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,13 @@ def handle_sync_batch(

logger.debug(f"Decision: Syncing {len(sync_instructions)} objects")

# Apply empty state to source side to signal that we are done syncing
src_client.apply_state(src_resolved_state)
# Apply sync instructions to target side
for sync_instruction in sync_instructions:
tgt_resolved_state.add_sync_instruction(sync_instruction)
src_resolved_state.add_sync_instruction(sync_instruction)
# Apply empty state to source side to signal that we are done syncing
# We also add permissions for users from the low side to mark L0 request as approved
src_client.apply_state(src_resolved_state)
return tgt_client.apply_state(tgt_resolved_state)


Expand Down
62 changes: 49 additions & 13 deletions packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,7 +1558,8 @@ class SyncInstruction(SyftObject):

diff: ObjectDiff
decision: SyncDecision | None
new_permissions_lowside: list[ActionObjectPermission]
new_permissions_lowside: dict[type, list[ActionObjectPermission]]
new_permissions_highside: dict[type, list[ActionObjectPermission]]
new_storage_permissions_lowside: list[StoragePermission]
new_storage_permissions_highside: list[StoragePermission]
unignore: bool = False
Expand All @@ -1575,8 +1576,8 @@ def from_batch_decision(
share_to_user: SyftVerifyKey | None,
) -> Self:
# read widget state
new_permissions_low_side = []

new_permissions_low_side = {}
new_permissions_high_side = {}
# read permissions
if sync_direction == SyncDirection.HIGH_TO_LOW:
# To create read permissions for the object
Expand All @@ -1592,13 +1593,27 @@ def from_batch_decision(
"share_to_user is required to share private data"
)
else:
new_permissions_low_side = [
ActionObjectPermission(
uid=diff.object_id,
permission=ActionPermission.READ,
credentials=share_to_user,
)
]
new_permissions_low_side = {
diff.obj_type: [
ActionObjectPermission(
uid=diff.object_id,
permission=ActionPermission.READ,
credentials=share_to_user,
)
]
}
if diff.obj_type in [Job, SyftLog, Request] or issubclass(
diff.obj_type, ActionObject
):
new_permissions_high_side = {
diff.obj_type: [
ActionObjectPermission(
uid=diff.object_id,
permission=ActionPermission.READ,
credentials=share_to_user,
)
]
}

# storage permissions
new_storage_permissions = []
Expand All @@ -1620,6 +1635,7 @@ def from_batch_decision(
diff=diff,
decision=decision,
new_permissions_lowside=new_permissions_low_side,
new_permissions_highside=new_permissions_high_side,
new_storage_permissions_lowside=new_storage_permissions,
new_storage_permissions_highside=new_storage_permissions,
mockify=mockify,
Expand All @@ -1634,7 +1650,7 @@ class ResolvedSyncState(SyftObject):
create_objs: list[SyncableSyftObject] = []
update_objs: list[SyncableSyftObject] = []
delete_objs: list[SyftObject] = []
new_permissions: list[ActionObjectPermission] = []
new_permissions: dict[type, list[ActionObjectPermission]] = {}
new_storage_permissions: list[StoragePermission] = []
ignored_batches: dict[UID, int] = {} # batch root uid -> hash of the batch
unignored_batches: set[UID] = set()
Expand Down Expand Up @@ -1666,7 +1682,10 @@ def add_sync_instruction(self, sync_instruction: SyncInstruction) -> None:
if sync_instruction.unignore:
self.unignored_batches.add(sync_instruction.batch_diff.root_id)

if diff.status == "SAME":
if (
diff.status == "SAME"
and len(sync_instruction.new_permissions_highside) == 0
):
return

my_obj = diff.low_obj if self.alias == "low" else diff.high_obj
Expand Down Expand Up @@ -1695,11 +1714,28 @@ def add_sync_instruction(self, sync_instruction: SyncInstruction) -> None:
self.delete_objs.append(my_obj)

if self.alias == "low":
self.new_permissions.extend(sync_instruction.new_permissions_lowside)
for obj_type in sync_instruction.new_permissions_lowside.keys():
if obj_type in self.new_permissions:
self.new_permissions[obj_type].extend(
sync_instruction.new_permissions_lowside[obj_type]
)
else:
self.new_permissions[obj_type] = (
sync_instruction.new_permissions_lowside[obj_type]
)
self.new_storage_permissions.extend(
sync_instruction.new_storage_permissions_lowside
)
elif self.alias == "high":
for obj_type in sync_instruction.new_permissions_highside.keys():
if obj_type in self.new_permissions:
self.new_permissions[obj_type].extend(
sync_instruction.new_permissions_highside[obj_type]
)
else:
self.new_permissions[obj_type] = (
sync_instruction.new_permissions_highside[obj_type]
)
self.new_storage_permissions.extend(
sync_instruction.new_storage_permissions_highside
)
Expand Down
33 changes: 29 additions & 4 deletions packages/syft/src/syft/service/sync/sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from ..code.user_code import UserCodeStatusCollection
from ..context import AuthedServiceContext
from ..job.job_stash import Job
from ..log.log import SyftLog
from ..request.request import Request
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import TYPE_TO_SERVICE
Expand Down Expand Up @@ -189,14 +191,38 @@ def sync_items(
self,
context: AuthedServiceContext,
items: list[SyncableSyftObject],
permissions: list[ActionObjectPermission],
permissions: dict[type, list[ActionObjectPermission]],
storage_permissions: list[StoragePermission],
ignored_batches: dict[UID, int],
unignored_batches: set[UID],
) -> SyftSuccess:
permissions_dict = defaultdict(list)
for permission in permissions:
permissions_dict[permission.uid].append(permission)
for permission_list in permissions.values():
for permission in permission_list:
permissions_dict[permission.uid].append(permission)

item_ids = [item.id.id for item in items]

# If we just want to add permissions without having an object
# This should happen only for the high side when we sync results but
# we need to add permissions for the DS to properly show the status of the requests
for obj_type, permission_list in permissions.items():
for permission in permission_list:
if permission.uid in item_ids:
continue
if obj_type not in [Job, SyftLog, Request] and not issubclass(
obj_type, ActionObject
):
raise SyftException(
public_message="Permission for object type not supported!"
)
if issubclass(obj_type, ActionObject):
store = context.server.services.action.stash
else:
service = context.server.get_service(TYPE_TO_SERVICE[obj_type])
store = service.stash # type: ignore[assignment]
if permission.permission == ActionPermission.READ:
store.add_permission(permission)

storage_permissions_dict = defaultdict(list)
for storage_permission in storage_permissions:
Expand All @@ -213,7 +239,6 @@ def sync_items(
else:
item = self.transform_item(context, item) # type: ignore[unreachable]
self.set_object(context, item).unwrap()

self.add_permissions_for_item(context, item, new_permissions)
self.add_storage_permissions_for_item(
context, item, new_storage_permissions
Expand Down

0 comments on commit 99e6659

Please sign in to comment.