diff --git a/src/nnsight/__init__.py b/src/nnsight/__init__.py index aa33a181..d3ee2eb1 100644 --- a/src/nnsight/__init__.py +++ b/src/nnsight/__init__.py @@ -1,4 +1,4 @@ -# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # :::: ::: :::: ::: :::::::: ::::::::::: :::::::: ::: ::: ::::::::::: ::::::: :::::::: # # :+:+: :+: :+:+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: # @@ -8,10 +8,10 @@ # #+# #+#+# #+# #+#+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# # # ### #### ### #### ######## ########### ######## ### ### ### ####### ### ######## # # # -# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # import os from functools import wraps -from typing import Dict, Union +from typing import Callable, Dict, Union import torch import yaml @@ -49,11 +49,11 @@ from torch._subclasses.fake_tensor import FakeTensor -def _bool(self): +def fake_bool(self): return True -DEFAULT_PATCHER.add(Patch(FakeTensor, _bool, "__bool__")) +DEFAULT_PATCHER.add(Patch(FakeTensor, fake_bool, "__bool__")) def fake_tensor_new_wrapper(fn): @@ -111,10 +111,11 @@ def noop(input: torch.Tensor, *args, **kwargs): ) import warnings + _str = str +_bool = bool + try: - - from torch.amp.autocast_mode import autocast, is_autocast_available @@ -548,3 +549,88 @@ def set_module_tensor_to_device( apply = GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply log = GlobalTracingContext.GLOBAL_TRACING_CONTEXT.log cond = GlobalTracingContext.GLOBAL_TRACING_CONTEXT.cond + +import inspect + +from . import util +from .intervention import InterventionProxy + + +def trace(fn: Callable): + """Helper decorator to add a function to the intervention graph via `.apply(...)`. + This is opposed to entering the function during tracing and tracing all inner operations. + + Args: + fn (Callable): Function to apply. + + Returns: + Callable: Traceable function. + """ + + @wraps(fn) + def inner(*args, **kwargs): + + return apply(fn, *args, **kwargs) + + return inner + + +def local(object: Callable | InterventionProxy): + """Helper decorator to add a function to the intervention graph via `.apply(...)` + AND convert all input Proxies to local ones via `.local()`. + + If a non-function is passed in, its assumed to be an `InterventionProxy` and `.local()` is called and returned. + + Args: + object ( Callable | InterventionProxy): Function to apply or Proxy to make local. + + Returns: + Callable | InterventionProxy: Traceable local function or local Proxy. + """ + + if inspect.isroutine(object): + + fn = trace(object) + + @wraps(fn) + def inner(*args, **kwargs): + + args, kwargs = util.apply( + (args, kwargs), lambda x: x.local(), InterventionProxy + ) + + return fn(*args, **kwargs) + + return inner + + return object.local() + + +def remote(object: Callable | Any): + """Helper decorator to add a function to the intervention graph via `.apply(...)` + AND convert all input Proxies to downloaded local ones via `.local()` + AND convert the output to an uploaded remote one via `remote()`. + + If a non-function is passed in, `remote(object)` is called and returned. + + Args: + object ( Callable | Any): Function to apply or object to make remote. + + Returns: + Callable | InterventionProxy: Traceable local -> remote function or remote Proxy. + """ + + if inspect.isroutine(object): + + fn = local(object) + + @wraps(fn) + def inner(*args, **kwargs): + + return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.remote( + fn(*args, **kwargs) + ) + + return inner + + return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.remote(object) diff --git a/src/nnsight/contexts/GraphBasedContext.py b/src/nnsight/contexts/GraphBasedContext.py index 2537343f..1945be65 100755 --- a/src/nnsight/contexts/GraphBasedContext.py +++ b/src/nnsight/contexts/GraphBasedContext.py @@ -136,6 +136,18 @@ def log(self, *data: Any) -> None: data (Any): Data to print. """ self.apply(print, *data) + + def remote(self, data:Any) -> InterventionProxy: + """Streams data remotely when it becomes available locally. + The remote service will block until the local value is uploaded and received. + + Is a no-op when not executing remotely. + + Returns: + InterventionProxy: Proxy. + """ + + return protocols.StreamingUploadProtocol.add(self.graph, data) def bool(self, *args, **kwargs) -> InterventionProxy: """NNsight helper method to create a traceable bool.""" diff --git a/src/nnsight/contexts/Tracer.py b/src/nnsight/contexts/Tracer.py index 1fb24fa6..8d4d55bd 100755 --- a/src/nnsight/contexts/Tracer.py +++ b/src/nnsight/contexts/Tracer.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: from ..models.mixins import RemoteableMixin from ..models.NNsightModel import NNsight + from ..tracing.Node import Node class Tracer(GraphBasedContext, RemoteMixin, BridgeMixin, EditMixin): @@ -179,6 +180,9 @@ def remote_backend_handle_result_value(self, value: Dict[str, Any]) -> None: # TODO : graph mismatch handle. hash json ? for node_name, node_value in value.items(): self.graph.nodes[node_name]._value = node_value + + def remote_backend_get_stream_node(self, name: str, graph_id: str) -> "Node": + return self.graph.nodes[name] def remote_backend_cleanup(self): diff --git a/src/nnsight/contexts/backends/RemoteBackend.py b/src/nnsight/contexts/backends/RemoteBackend.py index 38840964..212b7b22 100644 --- a/src/nnsight/contexts/backends/RemoteBackend.py +++ b/src/nnsight/contexts/backends/RemoteBackend.py @@ -1,7 +1,8 @@ from __future__ import annotations import io -from typing import TYPE_CHECKING, Any, Callable +import weakref +from typing import TYPE_CHECKING, Any, Callable, Tuple import requests import socketio @@ -10,12 +11,14 @@ from ... import CONFIG from ...logger import logger, remote_logger +from ...tracing import protocols from .LocalBackend import LocalBackend, LocalMixin if TYPE_CHECKING: from ...schema.Request import RequestModel from ...schema.Response import ResponseModel + from ...tracing.Node import Node class RemoteMixin(LocalMixin): @@ -53,6 +56,30 @@ def remote_backend_handle_result_value(self, value: Any) -> None: raise NotImplementedError() + # Following two methods are really only necessary because how you get a node in Tracer is different than Session + # due to one have many graphs and the other on one. + def remote_backend_get_stream_node(self, *args) -> "Node": + """Get streaming node on the client side based on arguments returned from `RemoteMixin.remote_stream_format` + + Returns: + Node: Streaming node on the client side. + """ + + raise NotImplementedError() + + @classmethod + def remote_stream_format(self, node: Node) -> Tuple[Any]: + """Returns arguments needed to get the correct streaming node on the client side. + + Args: + node (Node): Streaming node on the server side. + + Returns: + Any: Arguments + """ + + return node.name, node.graph.id + def remote_backend_cleanup(self): raise NotImplementedError() @@ -80,28 +107,36 @@ def __init__( self.ssl = CONFIG.API.SSL if ssl is None else ssl self.api_key = api_key or CONFIG.API.APIKEY self.blocking = blocking - self.handle_result = None self.host = host or CONFIG.API.HOST self.address = f"http{'s' if self.ssl else ''}://{self.host}" self.ws_address = f"ws{'s' if CONFIG.API.SSL else ''}://{self.host}" - def request(self, obj: RemoteMixin): + self.object: RemoteMixin = None + + def request(self) -> "RequestModel": + """Gets RequestModel based on intervention object. + + Returns: + RequestModel: RequestModel + """ - model_key = obj.remote_backend_get_model_key() + model_key = self.object.remote_backend_get_model_key() from ...schema.Request import RequestModel # Create request using pydantic to parse the object itself. - return RequestModel(object=obj, model_key=model_key) + return RequestModel(object=self.object, model_key=model_key) - def __call__(self, obj: RemoteMixin): + def __call__(self, object: RemoteMixin): - self.handle_result = obj.remote_backend_handle_result_value + # We need to reference the object's RemoteMixin methods so we need to access it. + # Make sure its weak reference to avoid reference loops on a potentially large object. + self.object = weakref.proxy(object) if self.blocking: - request = self.request(obj) + request = self.request() # Do blocking request. self.blocking_request(request) @@ -110,25 +145,27 @@ def __call__(self, obj: RemoteMixin): request = None + # If self.job_id is empty, it means were sending a new job. if not self.job_id: - request = self.request(obj) + request = self.request() + + # Otherwise we are getting the status / result of the existing job. + self.non_blocking_request(request) - self.non_blocking_request(request=request) + # Cleanup + self.object.remote_backend_cleanup() - obj.remote_backend_cleanup() - - def handle_response(self, data: Any) -> "ResponseModel": + def handle_response(self, response: "ResponseModel") -> None: """Handles incoming response data. - Parses it into the `ResponseModel` pydantic object. Logs the response object. If the job is completed, retrieve and stream the result from the remote endpoint. Use torch.load to decode and load the `ResultModel` into memory. Use the backend object's .handle_result method to handle the decoded result. Args: - data (Any): Json data to concert to `ResponseModel` + response (Any): Json data to concert to `ResponseModel` Raises: Exception: If the job's status is `ResponseModel.JobStatus.ERROR` @@ -139,59 +176,71 @@ def handle_response(self, data: Any) -> "ResponseModel": from ...schema.Response import ResponseModel, ResultModel - # Load the data into the ResponseModel pydantic class. - response = ResponseModel(**data) - # Log response for user - remote_logger.info(str(response)) + response.log(remote_logger) - # If the status of the response is completed, update the local nodes that the user specified to save. - # Then disconnect and continue. + # If job is completed: if response.status == ResponseModel.JobStatus.COMPLETED: - # Create BytesIO object to store bytes received from server in. - result_bytes = io.BytesIO() - result_bytes.seek(0) - - # Get result from result url using job id. - with requests.get( - url=f"{self.address}/result/{response.id}", - stream=True, - ) as stream: - # Total size of incoming data. - total_size = float(stream.headers["Content-length"]) - - with tqdm( - total=total_size, - unit="B", - unit_scale=True, - desc="Downloading result", - ) as progress_bar: - # chunk_size=None so server determines chunk size. - for data in stream.iter_content(chunk_size=None): - progress_bar.update(len(data)) - result_bytes.write(data) - - # Move cursor to beginning of bytes. - result_bytes.seek(0) - - # Decode bytes with pickle and then into pydantic object. - result: "ResultModel" = ResultModel( - **torch.load( + + # If the response has no result data, it was too big and we need to stream it from the server. + if response.data is None: + # Create BytesIO object to store bytes received from server in. + result_bytes = io.BytesIO() + result_bytes.seek(0) + + # Get result from result url using job id. + with requests.get( + url=f"{self.address}/result/{response.id}", + stream=True, + ) as stream: + # Total size of incoming data. + total_size = float(stream.headers["Content-length"]) + + with tqdm( + total=total_size, + unit="B", + unit_scale=True, + desc="Downloading result", + ) as progress_bar: + # chunk_size=None so server determines chunk size. + for data in stream.iter_content(chunk_size=None): + progress_bar.update(len(data)) + result_bytes.write(data) + + # Move cursor to beginning of bytes. + result_bytes.seek(0) + + # Decode bytes with pickle and then into pydantic object. + result = torch.load( result_bytes, map_location="cpu", weights_only=False ) - ) - # Close bytes - result_bytes.close() + # Close bytes + result_bytes.close() + + else: + + result = response.data + + # Load into pydantic object from dict. + result = ResultModel(**result) # Handle result - self.handle_result(result.value) + # This injects the .saved() values + self.object.remote_backend_handle_result_value(result.value) + + # If were receiving a streamed value: + elif response.status == ResponseModel.JobStatus.STREAM: - # Or if there was some error. - elif response.status == ResponseModel.JobStatus.ERROR: - raise Exception(str(response)) + # First item is arguments on how the RemoteMixin can get the correct StreamingDownload node. + # Second item is the steamed value from the remote service. + args, value = response.data - return response + # Get the local stream node in our intervention graph + node = self.object.remote_backend_get_stream_node(*args) + + # Inject it into the local intervention graph to kick off local execution. + node.set_value(value) def submit_request(self, request: "RequestModel") -> "ResponseModel": """Sends request to the remote endpoint and handles the response object. @@ -203,6 +252,8 @@ def submit_request(self, request: "RequestModel") -> "ResponseModel": (ResponseModel): Response. """ + from ...schema.Response import ResponseModel + response = requests.post( f"{self.address}/request", json=request.model_dump(exclude=["id", "received"]), @@ -211,7 +262,9 @@ def submit_request(self, request: "RequestModel") -> "ResponseModel": if response.status_code == 200: - return self.handle_response(response.json()) + response = ResponseModel(**response.json()) + + return self.handle_response(response) else: @@ -227,6 +280,8 @@ def get_response(self) -> "ResponseModel": (ResponseModel): Response. """ + from ...schema.Response import ResponseModel + response = requests.get( f"{self.address}/response/{self.job_id}", headers={"ndif-api-key": self.api_key}, @@ -234,7 +289,9 @@ def get_response(self) -> "ResponseModel": if response.status_code == 200: - return self.handle_response(response.json()) + response = ResponseModel(**response.json()) + + return self.handle_response(response) else: @@ -244,11 +301,15 @@ def blocking_request(self, request: "RequestModel"): """Send intervention request to the remote service while waiting for updates via websocket. Args: - request (RequestModel): Request. + request (RequestModel):Request. """ from ...schema.Response import ResponseModel + # We need to do some processing / optimizations on both the graph were sending remotely + # and our local intervention graph. In order handle the more complex Protocols for streaming. + preprocess(request, streaming=True) + # Create a socketio connection to the server. with socketio.SimpleClient( logger=logger, reconnection_attempts=10 @@ -265,15 +326,53 @@ def blocking_request(self, request: "RequestModel"): request.session_id = sio.sid # Submit request via - self.submit_request(request) + response = self.submit_request(request) + + # We need to tell the StreamingUploadProtocol how to use our websocket connection + # so it can upload values during execution to our job. + protocols.StreamingUploadProtocol.set( + lambda *args: self.stream_send(*args, job_id=response.id, sio=sio) + ) + + try: + # Loop until + while True: + + # Get pickled bytes value from the websocket. + response = sio.receive()[1] + # Convert to pydantic object. + response = ResponseModel.unpickle(response) + + # Handle the response. + self.handle_response(response) + + # Break when completed. + if response.status == ResponseModel.JobStatus.COMPLETED: + break + + except Exception as e: + + raise e + + finally: - # Loop until - while True: - if ( - self.handle_response(sio.receive()[1]).status - == ResponseModel.JobStatus.COMPLETED - ): - break + # Clear StreamingUploadProtocol state + protocols.StreamingUploadProtocol.set(None) + + def stream_send(self, value: Any, job_id: str, sio:socketio.SimpleClient): + """Upload some value to the remote service for some job id. + + Args: + value (Any): Value to upload + job_id (str): Job id. + sio (socketio.SimpleClient): Connected websocket client. + """ + + from ...schema.Request import StreamValueModel + + request = StreamValueModel(model_key=job_id, value=value) + + sio.emit('stream_upload', data=request.model_dump()) def non_blocking_request(self, request: "RequestModel" = None): """Send intervention request to the remote service if request provided. Otherwise get job status. @@ -316,3 +415,112 @@ def non_blocking_request(self, request: "RequestModel" = None): CONFIG.save() raise e + + +def preprocess(request: "RequestModel", streaming: bool = False): + """Optimizes the local and remote graph to handle streaming. Is required to use streaming protocols. + + Args: + request (RequestModel): Request to optimize. + streaming (bool, optional): If streaming. Defaults to False. + + Raises: + exception: _description_ + """ + + from ...schema.format.functions import get_function_name + from ...schema.format.types import FunctionModel, GraphModel, NodeModel + + # Exceptions might be resolved later during optimization so exceptions + # are stored here to added and removed. + # If there are any still in here after optimization, it raises the first one. + exceptions = {} + + def inner(graph_model: GraphModel): + """Optimizes the given remote GraphModel + + Args: + graph_model (GraphModel): Remote Graph Model to send remotely. + """ + + # GraphModel has an un-serialized reference to the real local Graph. + graph = graph_model.graph + + for node_name, node_model in list(graph_model.nodes.items()): + + # Get local nnsight Node + node = graph.nodes[node_name] + + # Get name of Node.target + function_name = node_model.target.function_name + + # If its a streaming download Node, we need to recursively remove these Nodes from the remote GraphModel. + # This is because we will be executing these nodes only locally when the root streaming node is download. + # This recursion ends at a streaming Upload Node and will resume remote execution of the intervention graph. + if streaming and function_name == get_function_name( + protocols.StreamingDownloadProtocol + ): + + def pop_stream_listeners(node: "Node"): + """Recursively removes listeners of streaming download nodes. + + Args: + node (Node): Node. + """ + + for node in node.listeners: + + # Also reset it to prepare for its local execution. + node.reset() + + if node.target is not protocols.StreamingUploadProtocol: + + # Remove from remote GraphModel + graph_model.nodes.pop(node.name, None) + # Also remove the exception for it. + exceptions.pop( + f"{graph_model.id}_{node.name}", None + ) + + pop_stream_listeners(node) + + # We also need to replace all args / dependencies of Upload Nodes to be the root stream Download Nodes. + # This is because remote Upload nodes cant depend on nodes that will be local of course. + # However it does need to depend on its root stream Download Nodes so the remote service only executes and waits at an Upload + # AFTER it sends any values via the stream Download Nodes. + else: + + graph_model.nodes[node.name].args = [] + graph_model.nodes[node.name].kwargs[node_name] = ( + NodeModel.Reference(name=node_name) + ) + + pop_stream_listeners(node) + + # Recurse into inner graphs. + elif function_name == get_function_name( + protocols.LocalBackendExecuteProtocol + ): + + inner(node_model.args[0].graph) + + + else: + # If its still a node that will be executed remotely: + if node_name in graph_model.nodes: + + # We need to see if its whitelisted. + try: + + FunctionModel.check_function_whitelist(function_name) + # Put exception in dict as it may be removed during further iterations. + except Exception as e: + + exceptions[f"{graph_model.id}_{node_name}"] = e + + inner(request.object.graph) + + # Raise any leftover exceptions + for exception in exceptions.values(): + + raise exception diff --git a/src/nnsight/contexts/session/Session.py b/src/nnsight/contexts/session/Session.py index 851abd14..ca0e98e2 100644 --- a/src/nnsight/contexts/session/Session.py +++ b/src/nnsight/contexts/session/Session.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from ...models.mixins import RemoteableMixin from ...models.NNsightModel import NNsight + from ...tracing.Node import Node class Session(GraphBasedContext, RemoteMixin): @@ -140,6 +141,12 @@ def remote_backend_handle_result_value( graph.alive = False + def remote_backend_get_stream_node(self, name: str, graph_id: str) -> "Node": + + graph = self.bridge.id_to_graph[graph_id] + + return graph.nodes[name] + def remote_backend_cleanup(self): self.bridge = weakref.proxy(self.bridge) diff --git a/src/nnsight/intervention.py b/src/nnsight/intervention.py index 7bc9175d..8671e352 100755 --- a/src/nnsight/intervention.py +++ b/src/nnsight/intervention.py @@ -19,7 +19,6 @@ from typing_extensions import Self from . import util -from .contexts.Conditional import Conditional from .tracing import protocols from .tracing.Graph import Graph from .tracing.Node import Node @@ -52,7 +51,8 @@ def __init__(self, node: Node) -> None: self._grad: InterventionProxy def save(self) -> InterventionProxy: - """Method when called, indicates to the intervention graph to not delete the tensor values of the result. + """Adds a lock Node to prevent its value from being cleared where normally it would be cleared when its no longer needed to save memory. + Used to access values outside of the tracing context, after execution. Returns: InterventionProxy: Proxy. @@ -66,16 +66,34 @@ def save(self) -> InterventionProxy: return self - def stop(self) -> InterventionProxy: - """Method when called, indicates to the intervention graph to stop the execution of the model after this Proxy/Node is completed.. + def local(self) -> InterventionProxy: + """Streams value of this node locally when it becomes available remotely. + This then kicks off execution of the local intervention graph up until it hits an upload Node created from `remote()`. + + Is a no-op when not executing remotely. Returns: InterventionProxy: Proxy. """ - protocols.EarlyStopProtocol.add(self.node.graph, self.node) + return protocols.StreamingDownloadProtocol.add(self.node) - return self + def remote(self) -> InterventionProxy: + """Streams value of this node remotely when it becomes available locally. + The remote service will block until the local value is uploaded and received. + + Is a no-op when not executing remotely. + + Returns: + InterventionProxy: Proxy. + """ + + return protocols.StreamingUploadProtocol.add(self.node.graph, self.node) + + def stop(self) -> None: + """Method when called, indicates to the intervention graph to stop the execution of the model after this Proxy/Node is completed..""" + + protocols.EarlyStopProtocol.add(self.node.graph, self.node) def update(self, value: Union[Node, Any]) -> InterventionProxy: """Updates the value of the Proxy via the creation of the UpdateProtocol node. @@ -182,9 +200,7 @@ def shape(self) -> Collection[torch.Size]: return super().__getattr__("shape") - return util.apply( - self.node.proxy_value, lambda x: x.shape, torch.Tensor - ) + return util.apply(self.node.proxy_value, lambda x: x.shape, torch.Tensor) @property def device(self) -> Collection[torch.device]: @@ -203,9 +219,7 @@ def device(self) -> Collection[torch.device]: return super().__getattr__("device") - return util.apply( - self.node.proxy_value, lambda x: x.device, torch.Tensor - ) + return util.apply(self.node.proxy_value, lambda x: x.device, torch.Tensor) @property def dtype(self) -> Collection[torch.device]: @@ -224,9 +238,7 @@ def dtype(self) -> Collection[torch.device]: return super().__getattr__("dtype") - return util.apply( - self.node.proxy_value, lambda x: x.dtype, torch.Tensor - ) + return util.apply(self.node.proxy_value, lambda x: x.dtype, torch.Tensor) class InterventionProtocol(Protocol): @@ -422,9 +434,7 @@ def intervene( # Updates the count of intervention node calls. # If count matches call_iter, time to inject value into node. - if call_iter != intervention_handler.count( - intervention_node_name - ): + if call_iter != intervention_handler.count(intervention_node_name): continue diff --git a/src/nnsight/schema/Request.py b/src/nnsight/schema/Request.py index 5d7adaab..62069a2e 100644 --- a/src/nnsight/schema/Request.py +++ b/src/nnsight/schema/Request.py @@ -43,8 +43,25 @@ def deserialize(self, model: NNsight) -> "RemoteMixin": handler = DeserializeHandler(model=model) - object = TypeAdapter( + object: OBJECT_TYPES = TypeAdapter( OBJECT_TYPES, config=RequestModel.model_config ).validate_python(json.loads(self.object)) return object.deserialize(handler) + +class StreamValueModel(BaseModel): + + model_config = ConfigDict( + arbitrary_types_allowed=True, protected_namespaces=() + ) + + model_key: str + value: ValueTypes + + def deserialize(self, model:NNsight): + + handler = DeserializeHandler(model=model) + + return self.value.deserialize(handler) + + \ No newline at end of file diff --git a/src/nnsight/schema/Response.py b/src/nnsight/schema/Response.py index c9a15527..12095eba 100644 --- a/src/nnsight/schema/Response.py +++ b/src/nnsight/schema/Response.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io import logging from datetime import datetime from enum import Enum @@ -20,13 +21,15 @@ class ResultModel(BaseModel): def from_graph(cls, graph: Graph) -> Dict[str, Any]: saves = { - name: util.apply(node.value, lambda x: x.detach().cpu(), torch.Tensor) + name: util.apply( + node.value, lambda x: x.detach().cpu(), torch.Tensor + ) for name, node in graph.nodes.items() if node.done() } return saves - + class ResponseModel(BaseModel): class JobStatus(Enum): @@ -35,13 +38,15 @@ class JobStatus(Enum): RUNNING = "RUNNING" COMPLETED = "COMPLETED" LOG = "LOG" + STREAM = "STREAM" ERROR = "ERROR" id: str status: JobStatus - description: str - received: datetime = None + description: Optional[str] = "" + data: Optional[Any] = None + received: Optional[datetime] = None session_id: Optional[str] = None def __str__(self) -> str: @@ -50,7 +55,41 @@ def __str__(self) -> str: def log(self, logger: logging.Logger) -> ResponseModel: if self.status == ResponseModel.JobStatus.ERROR: logger.error(str(self)) + raise SystemExit("Remote exception.") + elif self.status == ResponseModel.JobStatus.STREAM: + pass else: logger.info(str(self)) return self + + def pickle(self) -> bytes: + """Pickles self and returns bytes. + + Returns: + bytes: Pickled ResponseModel + """ + + with io.BytesIO() as file: + + torch.save(self.model_dump(exclude_unset=True), file) + + file.seek(0) + + return file.read() + + @classmethod + def unpickle(cls, data: bytes) -> ResponseModel: + """Loads a ResponseModel from pickled bytes. + + Args: + data (bytes): Pickled ResoonseModel. + + Returns: + ResponseModel: Response. + """ + + with io.BytesIO(data) as file: + return ResponseModel( + **torch.load(file, map_location="cpu", weights_only=False) + ) diff --git a/src/nnsight/schema/format/types.py b/src/nnsight/schema/format/types.py index be97f25b..c885167a 100644 --- a/src/nnsight/schema/format/types.py +++ b/src/nnsight/schema/format/types.py @@ -7,8 +7,15 @@ from typing import Any, Dict, List, Literal, Optional, Union import torch -from pydantic import (BaseModel, ConfigDict, Field, Strict, field_validator, - model_serializer) +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + Strict, + field_validator, + model_serializer, +) from pydantic.functional_validators import AfterValidator from typing_extensions import Annotated @@ -51,12 +58,13 @@ class BaseNNsightModel(BaseModel): def deserialize(self, handler: DeserializeHandler): raise NotImplementedError() + def try_deserialize(value: BaseNNsightModel | Any, handler: DeserializeHandler): - + if isinstance(value, BaseNNsightModel): - + return value.deserialize(handler) - + return value @@ -77,27 +85,25 @@ def deserialize(self, handler: DeserializeHandler) -> Node: target: Union[FunctionModel, FunctionType] args: List[ValueTypes] = [] kwargs: Dict[str, ValueTypes] = {} - condition: None | Union[ - NodeReferenceType, NodeModel.Reference - ] = None - - @model_serializer(mode='wrap') + condition: None | Union[NodeReferenceType, NodeModel.Reference] = None + + @model_serializer(mode="wrap") def serialize_model(self, handler): - + dump = handler(self) - + if self.condition is None: - - dump.pop('condition') - + + dump.pop("condition") + if not self.kwargs: - - dump.pop('kwargs') - + + dump.pop("kwargs") + if not self.args: - - dump.pop('args') - + + dump.pop("args") + return dump def deserialize(self, handler: DeserializeHandler) -> Node: @@ -110,13 +116,14 @@ def deserialize(self, handler: DeserializeHandler) -> Node: target=self.target.deserialize(handler), args=[try_deserialize(value, handler) for value in self.args], kwargs={ - key: try_deserialize(value, handler) for key, value in self.kwargs.items() + key: try_deserialize(value, handler) + for key, value in self.kwargs.items() }, name=self.name, ).node node.cond_dependency = try_deserialize(self.condition, handler) - + if isinstance(node.cond_dependency, Node): node.cond_dependency.listeners.append(weakref.proxy(node)) @@ -128,6 +135,7 @@ def deserialize(self, handler: DeserializeHandler) -> Node: return node + class TensorModel(BaseNNsightModel): type_name: Literal["TENSOR"] = "TENSOR" @@ -153,7 +161,7 @@ def deserialize(self, handler: DeserializeHandler) -> slice: return slice( try_deserialize(self.start, handler), try_deserialize(self.stop, handler), - try_deserialize(self.step, handler) + try_deserialize(self.step, handler), ) @@ -196,7 +204,10 @@ class DictModel(BaseNNsightModel): values: Dict[str, ValueTypes] def deserialize(self, handler: DeserializeHandler) -> dict: - return {key: try_deserialize(value, handler) for key, value in self.values.items()} + return { + key: try_deserialize(value, handler) + for key, value in self.values.items() + } class FunctionWhitelistError(Exception): @@ -209,7 +220,6 @@ class FunctionModel(BaseNNsightModel): function_name: str - @field_validator("function_name") @classmethod def check_function_whitelist(cls, qualname: str) -> str: if qualname not in FUNCTIONS_WHITELIST: @@ -220,6 +230,9 @@ def check_function_whitelist(cls, qualname: str) -> str: return qualname def deserialize(self, handler: DeserializeHandler) -> FUNCTION: + + FunctionModel.check_function_whitelist(self.function_name) + return FUNCTIONS_WHITELIST[self.function_name] @@ -227,13 +240,18 @@ class GraphModel(BaseNNsightModel): type_name: Literal["GRAPH"] = "GRAPH" + # We have a reference to the real Graph in the pydantic to be used by optimization logic + graph: Graph = Field(exclude=True, default=None, validate_default=False) + id: int sequential: bool nodes: Dict[str, Union["NodeModel", "NodeType"]] def deserialize(self, handler: DeserializeHandler) -> Graph: - graph = Graph(validate=False, sequential=self.sequential, graph_id=self.id) + graph = Graph( + validate=False, sequential=self.sequential, graph_id=self.id + ) handler.graph = graph handler.nodes = self.nodes @@ -271,10 +289,14 @@ def deserialize(self, handler: DeserializeHandler) -> Tracer: handler.graph = graph - kwargs = {key: try_deserialize(value, handler) for key, value in self.kwargs.items()} + kwargs = { + key: try_deserialize(value, handler) + for key, value in self.kwargs.items() + } invoker_inputs = [ - try_deserialize(invoker_input, handler) for invoker_input in self.invoker_inputs + try_deserialize(invoker_input, handler) + for invoker_input in self.invoker_inputs ] tracer = Tracer( @@ -342,7 +364,10 @@ def deserialize(self, handler: DeserializeHandler) -> Session: Graph, AfterValidator( lambda value: GraphModel( - id=value.id, sequential=value.sequential, nodes=value.nodes + id=value.id, + sequential=value.sequential, + nodes=value.nodes, + graph=value, ) ), ] @@ -359,27 +384,39 @@ def deserialize(self, handler: DeserializeHandler) -> Session: SliceType = Annotated[ slice, AfterValidator( - lambda value: SliceModel(start=value.start, stop=value.stop, step=value.step) + lambda value: SliceModel( + start=value.start, stop=value.stop, step=value.step + ) ), ] EllipsisType = Annotated[ - type(...), # It will be better to use EllipsisType, but it requires python>=3.10 + type( + ... + ), # It will be better to use EllipsisType, but it requires python>=3.10 AfterValidator(lambda value: EllipsisModel()), ] -ListType = Annotated[list, AfterValidator(lambda value: ListModel(values=value))] +ListType = Annotated[ + list, AfterValidator(lambda value: ListModel(values=value)) +] TupleType = Annotated[ - tuple, Strict(), AfterValidator(lambda value: TupleModel(values=list(value))) + tuple, + Strict(), + AfterValidator(lambda value: TupleModel(values=list(value))), ] -DictType = Annotated[dict, AfterValidator(lambda value: DictModel(values=value))] +DictType = Annotated[ + dict, AfterValidator(lambda value: DictModel(values=value)) +] FunctionType = Annotated[ FUNCTION, - AfterValidator(lambda value: FunctionModel(function_name=get_function_name(value))), + AfterValidator( + lambda value: FunctionModel(function_name=get_function_name(value)) + ), ] NodeReferenceType = Annotated[ @@ -412,7 +449,9 @@ def deserialize(self, handler: DeserializeHandler) -> Session: IteratorType = Annotated[ Iterator, - AfterValidator(lambda value: IteratorModel(graph=value.graph, data=value.data)), + AfterValidator( + lambda value: IteratorModel(graph=value.graph, data=value.data) + ), ] SessionType = Annotated[ diff --git a/src/nnsight/tracing/Node.py b/src/nnsight/tracing/Node.py index 73b253e7..3efd92f4 100755 --- a/src/nnsight/tracing/Node.py +++ b/src/nnsight/tracing/Node.py @@ -403,21 +403,31 @@ def set_value(self, value: Any) -> None: logger.info(f"=> SET({self.name})") + self.update_listeners() + + self.update_dependencies() + + if self.done() and self.redundant(): + self.destroy() + + def update_listeners(self): + """Updates remaining_dependencies of listeners. If they are now fulfilled, execute them.""" + for listener in self.listeners: listener.remaining_dependencies -= 1 if listener.fulfilled() and not self.graph.sequential: listener.execute() + def update_dependencies(self): + """Updates remaining_listeners of dependencies. If they are now redundant, destroy them.""" + for dependency in self.arg_dependencies: dependency.remaining_listeners -= 1 if dependency.redundant(): dependency.destroy() - if self.done() and self.redundant(): - self.destroy() - def destroy(self) -> None: """Removes the reference to the node's value and logs it's destruction.""" @@ -461,7 +471,11 @@ def visualize( styles = { "node": {"color": "black", "shape": "ellipse"}, - "label": (self.target if isinstance(self.target, str) else self.target.__name__), + "label": ( + self.target + if isinstance(self.target, str) + else self.target.__name__ + ), "arg": defaultdict(lambda: {"color": "gray", "shape": "box"}), "arg_kname": defaultdict(lambda: None), "edge": defaultdict(lambda: "solid"), @@ -473,7 +487,9 @@ def visualize( self.target, protocols.Protocol ): styles = self.target.style() - viz_graph.add_node(node_name, label=styles["label"], **styles["node"]) + viz_graph.add_node( + node_name, label=styles["label"], **styles["node"] + ) if ( recursive and self.target == protocols.LocalBackendExecuteProtocol @@ -500,7 +516,9 @@ def visualize( viz_graph, recursive, node_name + "_" ) else: - viz_graph.add_node(node_name, label=styles["label"], **styles["node"]) + viz_graph.add_node( + node_name, label=styles["label"], **styles["node"] + ) def visualize_args(arg_collection): """Recursively visualizes the arguments of this node. @@ -518,9 +536,11 @@ def visualize_args(arg_collection): if isinstance(arg, Iterable): for element in arg: if isinstance(element, Node): - dep_name = element.visualize(viz_graph, recursive, backend_name) + dep_name = element.visualize( + viz_graph, recursive, backend_name + ) iter_val_dependencies.append(dep_name) - + name = node_name if isinstance(arg, torch.Tensor): name += f"_Tensor_{key}" @@ -541,7 +561,9 @@ def visualize_args(arg_collection): viz_graph.add_node(name, label=label, **styles["arg"][key]) for dep_name in iter_val_dependencies: - viz_graph.add_edge(dep_name, name, style="dashed", color="gray") + viz_graph.add_edge( + dep_name, name, style="dashed", color="gray" + ) viz_graph.add_edge(name, node_name, style=styles["edge"][key]) diff --git a/src/nnsight/tracing/protocols.py b/src/nnsight/tracing/protocols.py index 19a52a13..ce643312 100755 --- a/src/nnsight/tracing/protocols.py +++ b/src/nnsight/tracing/protocols.py @@ -1,7 +1,10 @@ +import asyncio import inspect import weakref from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from io import BytesIO +from threading import Thread +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union import torch from torch._subclasses.fake_tensor import FakeCopyMode, FakeTensorMode @@ -138,9 +141,7 @@ def execute(cls, node: "Node") -> None: except: device = None - args, kwargs = node.prepare_inputs( - (node.args, node.kwargs), device=device - ) + args, kwargs = node.prepare_inputs((node.args, node.kwargs), device=device) module_path, *args = args @@ -442,9 +443,7 @@ class BridgeProtocol(Protocol): class BridgeException(Exception): def __init__(self): - super.__init__( - "Must define a Session context to make use of the Bridge" - ) + super.__init__("Must define a Session context to make use of the Bridge") @classmethod def add(cls, node: "Node") -> "InterventionProxy": @@ -485,7 +484,7 @@ def execute(cls, node: "Node") -> None: # Value node is Lock Node's only arg value_node: "Node" = lock_node.args[0] - + if value_node.done(): # Set value to that of the value Node. @@ -672,7 +671,7 @@ def add(cls, graph: "Graph", default: Any = None) -> "InterventionProxy": @classmethod def execute(cls, node: Node) -> None: - + node.set_value(node.args[0]) @classmethod @@ -742,9 +741,7 @@ class ConditionalProtocol(Protocol): attachment_name = "nnsight_conditional_manager" @classmethod - def add( - cls, graph: "Graph", condition: Union["Node", Any] - ) -> "InterventionProxy": + def add(cls, graph: "Graph", condition: Union["Node", Any]) -> "InterventionProxy": return graph.create(target=cls, proxy_value=True, args=[condition]) @@ -791,9 +788,7 @@ def has_conditional(cls, graph: "Graph") -> bool: return cls.attachment_name in graph.attachments.keys() @classmethod - def get_conditional( - cls, graph: "Graph", cond_node_name: str - ) -> "Conditional": + def get_conditional(cls, graph: "Graph", cond_node_name: str) -> "Conditional": """Gets the ConditionalProtocol node by its name. Args: @@ -863,9 +858,7 @@ def is_node_conditioned(cls, node: "Node") -> bool: bool: Whether the Node is conditioned. """ - return node.graph.attachments[cls.attachment_name].is_node_conditioned( - node - ) + return node.graph.attachments[cls.attachment_name].is_node_conditioned(node) @classmethod def style(cls) -> Dict[str, Any]: @@ -900,9 +893,7 @@ class UpdateProtocol(Protocol): """ @classmethod - def add( - cls, node: "Node", new_value: Union[Node, Any] - ) -> "InterventionProxy": + def add(cls, node: "Node", new_value: Union[Node, Any]) -> "InterventionProxy": """Creates an UpdateProtocol node. Args: @@ -937,9 +928,7 @@ def execute(cls, node: "Node") -> None: if value_node.target == BridgeProtocol: value_node._value = new_value bridge = BridgeProtocol.get_bridge(value_node.graph) - lock_node = bridge.id_to_graph[value_node.args[0]].nodes[ - value_node.args[1] - ] + lock_node = bridge.id_to_graph[value_node.args[0]].nodes[value_node.args[1]] value_node = lock_node.args[0] value_node._value = new_value @@ -963,3 +952,72 @@ def style(cls) -> Dict[str, Any]: "arg_kname": defaultdict(lambda: None), # Argument label key word "edge": defaultdict(lambda: "solid"), } # Argument edge display + + +class StreamingDownloadProtocol(Protocol): + + @classmethod + def add(cls, node: Node) -> "InterventionProxy": + """Add streaming download Node to the intervention graph. + + Args: + node (Node): Node to download value of locally when available remotely. + """ + + return node.create(target=cls, proxy_value=None, args=[node]) + + @classmethod + def execute(cls, node: "Node"): + """When executing remotely, the local version of this Node type has its value set directly by `RemoteBackend`, not via `.execute(...)` + The remote version streams the value in a ResponseModel object. + + Is a no-op when not executing remotely. + """ + + value_node = node.args[0] + + node.set_value(value_node.value) + + +class StreamingUploadProtocol(Protocol): + + send: Callable = None + + @classmethod + def set(cls, fn: Callable): + + cls.send = fn + + @classmethod + def add(cls, graph: "Graph", value: Any) -> "InterventionProxy": + """Add streaming upload Node to the intervention graph. + + Args: + graph (Graph): Graph to add Node to. + value (Any): Value to upload remotely when available locally. + """ + + return graph.create(target=cls, proxy_value=None, args=[value]) + + @classmethod + def execute(cls, node: "Node"): + """When executing remotely, the local version of this Node calls `cls.send` to upload the its value to a waiting remote service. + The remote version blocks and waits until it receives the value from its local counterpart. + + Is a no-op when not executing remotely. + + Args: + node (Node): Node to upload remotely. + """ + + value = node.prepare_inputs(node.args[0]) + + if cls.send is not None: + + cls.send(value) + + node.update_dependencies() + + else: + + node.set_value(value) diff --git a/src/nnsight/util.py b/src/nnsight/util.py index 42d52956..83e92275 100755 --- a/src/nnsight/util.py +++ b/src/nnsight/util.py @@ -12,6 +12,7 @@ Optional, Tuple, Type, + TypeVar, Union, ) @@ -22,8 +23,10 @@ # TODO Have an Exception you can raise to stop apply early +T = TypeVar('T') + def apply( - data: Any, fn: Callable, cls: Type, inplace: bool = False + data: Any, fn: Callable[[T], Any], cls: Type[T], inplace: bool = False ) -> Collection: """Applies some function to all members of a collection of a give type (or types)