Skip to content

Commit

Permalink
Merge pull request #285 from ndif-team/local
Browse files Browse the repository at this point in the history
Local
  • Loading branch information
JadenFiotto-Kaufman authored Nov 7, 2024
2 parents a91e04c + 7ff48fe commit f84b069
Show file tree
Hide file tree
Showing 24 changed files with 613 additions and 437 deletions.
44 changes: 8 additions & 36 deletions src/nnsight/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,14 @@ def inner(cls, fake_mode, elem, device, constant=None):

DEFAULT_PATCHER.__enter__()

from .tracing.contexts import GlobalTracingContext

apply = GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply
log = GlobalTracingContext.GLOBAL_TRACING_CONTEXT.log
from .intervention.contexts import GlobalInterventionTracingContext

apply = GlobalInterventionTracingContext.GLOBAL_TRACING_CONTEXT.apply
log = GlobalInterventionTracingContext.GLOBAL_TRACING_CONTEXT.log
local = GlobalInterventionTracingContext.GLOBAL_TRACING_CONTEXT.local
cond = GlobalInterventionTracingContext.GLOBAL_TRACING_CONTEXT.cond
iter = GlobalInterventionTracingContext.GLOBAL_TRACING_CONTEXT.iter
stop = GlobalInterventionTracingContext.GLOBAL_TRACING_CONTEXT.stop

def trace(fn):
"""Helper decorator to add a function to the intervention graph via `.apply(...)`.
Expand All @@ -103,7 +106,7 @@ def trace(fn):

@wraps(fn)
def inner(*args, **kwargs):

return apply(fn, *args, **kwargs)

return inner
Expand All @@ -127,37 +130,6 @@ def inner(*args, **kwargs):
from .intervention.graph import InterventionProxy


# def local(object: Callable | InterventionProxy):
# """Helper decorator to add a function to the intervention graph via `.apply(...)`
# AND convert all input Proxies to local ones via `.local()`.

# If a non-function is passed in, its assumed to be an `InterventionProxy` and `.local()` is called and returned.

# Args:
# object ( Callable | InterventionProxy): Function to apply or Proxy to make local.

# Returns:
# Callable | InterventionProxy: Traceable local function or local Proxy.
# """

# if inspect.isroutine(object):

# fn = trace(object)

# @wraps(fn)
# def inner(*args, **kwargs):

# args, kwargs = util.apply(
# (args, kwargs), lambda x: x.local(), InterventionProxy
# )

# return fn(*args, **kwargs)

# return inner

# return object.local()


# def remote(object: Callable | Any):
# """Helper decorator to add a function to the intervention graph via `.apply(...)`
# AND convert all input Proxies to downloaded local ones via `.local()`
Expand Down
4 changes: 2 additions & 2 deletions src/nnsight/config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
API:
APIKEY: null
APIKEY: rhJV6e3LGhtXlkB47z9W
FORMAT: json
HOST: ndif.dev
JOB_ID: null
SSL: true
ZLIB: true
FORMAT: json
APP:
LOGGING: false
REMOTE_LOGGING: true
8 changes: 8 additions & 0 deletions src/nnsight/intervention/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,10 @@

"""
The `intervention` module extends the `tracing` module to add PyTorch specific interventions to a given computation graph.
It defines its own: protocols, contexts, backends and graph primitives to achieve this.
"""
from .base import NNsight
from .envoy import Envoy
Loading

0 comments on commit f84b069

Please sign in to comment.