diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 2e6edcde5d7..e2ba67192a7 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -163,8 +163,7 @@ class UserCodeStatusCollectionV1(SyncableSyftObject): # this is empty in the case of l0 status_dict: dict[ServerIdentity, tuple[UserCodeStatus, str]] = {} - - user_code_link: LinkedObject + user_code_link: LinkedObject[UserCode] @serializable() @@ -434,7 +433,7 @@ class UserCodeV1(SyncableSyftObject): user_unique_func_name: str code_hash: str signature: inspect.Signature - status_link: LinkedObject | None = None + status_link: LinkedObject[UserCodeStatusCollection] | None = None input_kwargs: list[str] submit_time: DateTime | None = None # tracks if the code calls datasite.something, variable is set during parsing diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index cdf567a5955..31c782d8245 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -1,5 +1,6 @@ # stdlib from typing import ClassVar +from typing import TYPE_CHECKING # third party from pydantic import model_validator @@ -26,15 +27,21 @@ from ..user.user_roles import GUEST_ROLE_LEVEL +if TYPE_CHECKING: + # relative + from ..code.user_code import UserCode + from ..job.job_stash import Job + + @serializable() class ExecutionOutput(SyncableSyftObject): __canonical_name__ = "ExecutionOutput" __version__ = SYFT_OBJECT_VERSION_1 executing_user_verify_key: SyftVerifyKey - user_code_link: LinkedObject + user_code_link: "LinkedObject[UserCode]" output_ids: list[UID] | dict[str, UID] | None = None - job_link: LinkedObject | None = None + job_link: "LinkedObject[Job] | None" = None created_at: DateTime = DateTime.now() input_ids: dict[str, UID] | None = None diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py index 5ed21f5007d..17513bc7e32 100644 --- a/packages/syft/src/syft/service/project/project.py +++ b/packages/syft/src/syft/service/project/project.py @@ -264,14 +264,14 @@ class ProjectRequest(ProjectEventAddObject): __canonical_name__ = "ProjectRequest" __version__ = SYFT_OBJECT_VERSION_1 - linked_request: LinkedObject + linked_request: LinkedObject[Request] allowed_sub_types: list[type] = [ProjectRequestResponse] @field_validator("linked_request", mode="before") @classmethod - def _validate_linked_request(cls, v: Any) -> LinkedObject: + def _validate_linked_request(cls, v: Any) -> LinkedObject[Request]: if isinstance(v, Request): - linked_request = LinkedObject.from_obj(v, server_uid=v.server_uid) + linked_request = LinkedObject[Request].from_obj(v, server_uid=v.server_uid) linked_request.syft_server_location = v.syft_server_location return linked_request elif isinstance(v, LinkedObject): @@ -1028,7 +1028,9 @@ def add_request( self, request: Request, ) -> SyftSuccess: - linked_request = LinkedObject.from_obj(request, server_uid=request.server_uid) + linked_request = LinkedObject[Request].from_obj( + request, server_uid=request.server_uid + ) request_event = ProjectRequest(linked_request=linked_request) self.add_event(request_event) return SyftSuccess(message="Request created successfully") diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py index c43caa302be..21fdc16829f 100644 --- a/packages/syft/src/syft/service/queue/queue_stash.py +++ b/packages/syft/src/syft/service/queue/queue_stash.py @@ -21,6 +21,7 @@ from ...types.transforms import TransformContext from ...types.uid import UID from ..action.action_permissions import ActionObjectPermission +from ..worker.worker_pool import WorkerPool __all__ = ["QueueItem"] @@ -77,7 +78,7 @@ class QueueItem(SyftObject): job_id: UID | None = None worker_settings: WorkerSettings | None = None has_execute_permissions: bool = False - worker_pool: LinkedObject + worker_pool: LinkedObject[WorkerPool] def __repr__(self) -> str: return f": {self.status}" diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 7c5495a756e..e2e8238530b 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -108,7 +108,7 @@ class ActionStoreChange(Change): __canonical_name__ = "ActionStoreChange" __version__ = SYFT_OBJECT_VERSION_1 - linked_obj: LinkedObject + linked_obj: LinkedObject[ActionObject] apply_permission_type: ActionPermission __repr_attrs__ = ["linked_obj", "apply_permission_type"] @@ -1370,8 +1370,8 @@ class UserCodeStatusChange(Change): __version__ = SYFT_OBJECT_VERSION_1 value: UserCodeStatus - linked_obj: LinkedObject - linked_user_code: LinkedObject + linked_obj: LinkedObject[UserCodeStatusCollection] + linked_user_code: LinkedObject[UserCode] nested_solved: bool = False match_type: bool = True __repr_attrs__ = [ @@ -1523,7 +1523,7 @@ def link(self) -> SyftObject | None: class SyncedUserCodeStatusChange(UserCodeStatusChange): __canonical_name__ = "SyncedUserCodeStatusChange" __version__ = SYFT_OBJECT_VERSION_1 - linked_obj: LinkedObject | None = None # type: ignore + linked_obj: LinkedObject[UserCodeStatusCollection] | None = None # type: ignore @property def approved(self) -> bool: diff --git a/packages/syft/src/syft/service/sync/sync_state.py b/packages/syft/src/syft/service/sync/sync_state.py index 0e00070fe64..1c5559c32bf 100644 --- a/packages/syft/src/syft/service/sync/sync_state.py +++ b/packages/syft/src/syft/service/sync/sync_state.py @@ -121,7 +121,7 @@ class SyncState(SyftObject): objects: dict[UID, SyncableSyftObject] = {} dependencies: dict[UID, list[UID]] = {} created_at: DateTime = Field(default_factory=DateTime.now) - previous_state_link: LinkedObject | None = None + previous_state_link: "LinkedObject[SyncState] | None" = None permissions: dict[UID, set[str]] = {} storage_permissions: dict[UID, set[UID]] = {} ignored_batches: dict[UID, int] = {} diff --git a/packages/syft/src/syft/store/linked_obj.py b/packages/syft/src/syft/store/linked_obj.py index 2343dc0b9a6..4f7853c24e9 100644 --- a/packages/syft/src/syft/store/linked_obj.py +++ b/packages/syft/src/syft/store/linked_obj.py @@ -1,6 +1,10 @@ # stdlib import logging from typing import Any +from typing import Generic +from typing import TypeVar +from typing import Union +from typing import get_args # third party from typing_extensions import Self @@ -15,19 +19,21 @@ from ..types.result import as_result from ..types.syft_object import SYFT_OBJECT_VERSION_1 from ..types.syft_object import SyftObject +from ..types.syft_object import SyftObjectVersioned from ..types.uid import UID +T = TypeVar("T", bound=SyftObject) logger = logging.getLogger(__name__) @serializable() -class LinkedObject(SyftObject): +class LinkedObject(SyftObjectVersioned, Generic[T]): __canonical_name__ = "LinkedObject" __version__ = SYFT_OBJECT_VERSION_1 server_uid: UID service_type: type[Any] - object_type: type[SyftObject] + object_type: type[T] object_uid: UID _resolve_cache: SyftObject | None = None @@ -40,6 +46,15 @@ def __str__(self) -> str: ) return f"{resolved_obj_type.__name__}: {self.object_uid} @ Server {self.server_uid}" + @classmethod + def get_generic_type(cls: type[Self]) -> type[T]: + args = cls.__pydantic_generic_metadata__["args"] + if len(args) != 1: + raise ValueError( + "Cannot infer LinkedObject type, generic argument not provided" + ) + return args[0] # type: ignore + @property def resolve(self) -> SyftObject: return self._resolve() @@ -105,10 +120,10 @@ def update_with_context( @classmethod def from_obj( cls, - obj: SyftObject | type[SyftObject], + obj: T | type[T], service_type: type[Any] | None = None, server_uid: UID | None = None, - ) -> Self: + ) -> "LinkedObject[T]": # type: ignore if service_type is None: # relative from ..service.action.action_object import ActionObject @@ -129,7 +144,7 @@ def from_obj( if server_uid is None: raise Exception(f"{cls} Requires an object UID") - return LinkedObject( + return LinkedObject[type(obj)]( # type: ignore server_uid=server_uid, service_type=service_type, object_type=type(obj), @@ -140,11 +155,11 @@ def from_obj( @classmethod def with_context( cls, - obj: SyftObject, + obj: T, context: ServerServiceContext, object_uid: UID | None = None, service_type: type[Any] | None = None, - ) -> Self: + ) -> "LinkedObject[T]": if service_type is None: # relative from ..service.service import TYPE_TO_SERVICE @@ -160,7 +175,7 @@ def with_context( raise ValueError(f"context {context}'s server is None") server_uid = context.server.id - return LinkedObject( + return LinkedObject[type(obj)]( # type: ignore server_uid=server_uid, service_type=service_type, object_type=type(obj), @@ -171,13 +186,85 @@ def with_context( def from_uid( cls, object_uid: UID, - object_type: type[SyftObject], + object_type: type[T], service_type: type[Any], server_uid: UID, - ) -> Self: - return cls( + ) -> "LinkedObject[T]": + return cls[object_type]( # type: ignore server_uid=server_uid, service_type=service_type, object_type=object_type, object_uid=object_uid, ) + + +def _unwrap_optional(type_: Any) -> Any: + try: + if type_ | None == type_: + args = get_args(type_) + return Union[tuple(arg for arg in args if arg != type(None))] # noqa + return type_ + except Exception: + return type_ + + +def _annotation_issubclass(type_: Any, cls: type) -> bool: + try: + return issubclass(type_, cls) + except Exception: + return False + + +def _resolve_syftobject_forward_refs(raise_errors: bool = False) -> None: + # relative + from ..types.syft_object_registry import SyftObjectRegistry + + type_names = [ + t.__name__ for t in SyftObjectRegistry.__type_to_canonical_name__.keys() + ] + if len(type_names) != len(set(type_names)): + raise ValueError( + "Duplicate names in SyftObjectRegistry, cannot resolve forward references" + ) + + types_namespace = { + k.__name__: k for k in SyftObjectRegistry.__type_to_canonical_name__.keys() + } + syft_objects = [v for v in types_namespace.values() if issubclass(v, SyftObject)] + + for so in syft_objects: + so.model_rebuild(raise_errors=raise_errors, _types_namespace=types_namespace) + + +def find_unannotated_linked_objects() -> None: + # Utility method to find LinkedObjects that are not annotated with a generic type + + # relative + from ..types.syft_object_registry import SyftObjectRegistry + + # Need to resolve forward references to find LinkedObjects + _resolve_syftobject_forward_refs() + + annotated = [] + unannotated = [] + + for cls in SyftObjectRegistry.__type_to_canonical_name__.keys(): + if not issubclass(cls, SyftObject): + continue + + for name, field in cls.model_fields.items(): + type_ = _unwrap_optional(field.annotation) + if _annotation_issubclass(type_, LinkedObject): + try: + type_.get_generic_type() + annotated.append((cls, name)) + except Exception: + unannotated.append((cls, name)) + + print("Annotated LinkedObjects:") + for cls, name in annotated: + print(f"{cls.__name__}.{name}") + + print("\n\nUnannotated LinkedObjects:") + for cls, name in unannotated: + print(f"{cls.__name__}.{name}")