Skip to content

Commit

Permalink
Merge pull request #268 from ndif-team/streaming-protocol
Browse files Browse the repository at this point in the history
Streaming protocol
  • Loading branch information
JadenFiotto-Kaufman authored Oct 10, 2024
2 parents 3a31696 + 078e079 commit 2dd23c9
Show file tree
Hide file tree
Showing 12 changed files with 675 additions and 170 deletions.
100 changes: 93 additions & 7 deletions src/nnsight/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# #
# :::: ::: :::: ::: :::::::: ::::::::::: :::::::: ::: ::: ::::::::::: ::::::: :::::::: #
# :+:+: :+: :+:+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: #
Expand All @@ -8,10 +8,10 @@
# #+# #+#+# #+# #+#+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #
# ### #### ### #### ######## ########### ######## ### ### ### ####### ### ######## #
# #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
import os
from functools import wraps
from typing import Dict, Union
from typing import Any, Callable, Dict, Union

from importlib.metadata import version, PackageNotFoundError

Expand Down Expand Up @@ -56,11 +56,11 @@
from torch._subclasses.fake_tensor import FakeTensor


def _bool(self):
def fake_bool(self):
    """Patched ``__bool__`` for ``FakeTensor``: always report truthy.

    Registered on ``FakeTensor`` (via ``Patch(FakeTensor, fake_bool, "__bool__")``)
    so that data-dependent ``if tensor:`` checks succeed during fake/meta tracing
    instead of requiring a concrete value.
    """
    return True


DEFAULT_PATCHER.add(Patch(FakeTensor, _bool, "__bool__"))
DEFAULT_PATCHER.add(Patch(FakeTensor, fake_bool, "__bool__"))


def fake_tensor_new_wrapper(fn):
Expand Down Expand Up @@ -118,10 +118,11 @@ def noop(input: torch.Tensor, *args, **kwargs):
)

import warnings

_str = str
_bool = bool

try:



from torch.amp.autocast_mode import autocast, is_autocast_available

Expand Down Expand Up @@ -555,3 +556,88 @@ def set_module_tensor_to_device(
apply = GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply
log = GlobalTracingContext.GLOBAL_TRACING_CONTEXT.log
cond = GlobalTracingContext.GLOBAL_TRACING_CONTEXT.cond

import inspect

from . import util
from .intervention import InterventionProxy


def trace(fn: Callable):
    """Wrap ``fn`` so that calling it adds a single ``.apply(...)`` node to the
    intervention graph, instead of entering the function during tracing and
    tracing every inner operation.

    Args:
        fn (Callable): Function to apply.

    Returns:
        Callable: Traceable function.
    """

    def deferred(*args, **kwargs):
        # Defer execution: record the call on the graph rather than running it now.
        return apply(fn, *args, **kwargs)

    return wraps(fn)(deferred)


def local(object: Callable | InterventionProxy):
    """Helper decorator to add a function to the intervention graph via
    ``.apply(...)`` AND convert all of its Proxy inputs to local ones via
    ``.local()``.

    If a non-function is passed in, it's assumed to be an
    ``InterventionProxy`` and ``.local()`` is called and returned.

    Args:
        object ( Callable | InterventionProxy): Function to apply or Proxy to make local.

    Returns:
        Callable | InterventionProxy: Traceable local function or local Proxy.
    """

    # Non-function input: treat it as a Proxy and localize it directly.
    if not inspect.isroutine(object):
        return object.local()

    traced = trace(object)

    @wraps(traced)
    def localized(*args, **kwargs):
        # Download every InterventionProxy found in the inputs before applying.
        local_args, local_kwargs = util.apply(
            (args, kwargs), lambda proxy: proxy.local(), InterventionProxy
        )

        return traced(*local_args, **local_kwargs)

    return localized


def remote(object: Callable | Any):
    """Helper decorator to add a function to the intervention graph via
    ``.apply(...)`` AND convert all of its Proxy inputs to downloaded local
    ones via ``.local()`` AND convert its output to an uploaded remote one via
    ``remote()``.

    If a non-function is passed in, ``remote(object)`` is called and returned.

    Args:
        object ( Callable | Any): Function to apply or object to make remote.

    Returns:
        Callable | InterventionProxy: Traceable local -> remote function or remote Proxy.
    """

    # Non-function input: upload it via the global tracing context directly.
    if not inspect.isroutine(object):
        return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.remote(object)

    localized = local(object)

    @wraps(localized)
    def uploaded(*args, **kwargs):
        result = localized(*args, **kwargs)
        # Re-upload the locally-computed result so the remote service can continue.
        return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.remote(result)

    return uploaded
12 changes: 12 additions & 0 deletions src/nnsight/contexts/GraphBasedContext.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,18 @@ def log(self, *data: Any) -> None:
data (Any): Data to print.
"""
self.apply(print, *data)

    def remote(self, data:Any) -> InterventionProxy:
        """Streams data remotely when it becomes available locally.

        The remote service will block until the local value is uploaded and
        received. Is a no-op when not executing remotely.

        Args:
            data (Any): Data to upload once its local value is available.

        Returns:
            InterventionProxy: Proxy for the streamed/uploaded value.
        """

        return protocols.StreamingUploadProtocol.add(self.graph, data)

def bool(self, *args, **kwargs) -> InterventionProxy:
"""NNsight helper method to create a traceable bool."""
Expand Down
4 changes: 4 additions & 0 deletions src/nnsight/contexts/Tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
if TYPE_CHECKING:
from ..models.mixins import RemoteableMixin
from ..models.NNsightModel import NNsight
from ..tracing.Node import Node


class Tracer(GraphBasedContext, RemoteMixin, BridgeMixin, EditMixin):
Expand Down Expand Up @@ -179,6 +180,9 @@ def remote_backend_handle_result_value(self, value: Dict[str, Any]) -> None:
# TODO : graph mismatch handle. hash json ?
for node_name, node_value in value.items():
self.graph.nodes[node_name]._value = node_value

    def remote_backend_get_stream_node(self, name: str, graph_id: str) -> "Node":
        """Resolves a streaming node by name on this tracer's local graph.

        Args:
            name (str): Name of the node to look up in ``self.graph.nodes``.
            graph_id (str): Id of the owning graph. Unused here — a Tracer has
                a single graph, so the name alone identifies the node.

        Returns:
            Node: The matching local node.
        """
        return self.graph.nodes[name]

def remote_backend_cleanup(self):

Expand Down
Loading

0 comments on commit 2dd23c9

Please sign in to comment.