diff --git a/flax/experimental/nnx/nnx/graph_utils.py b/flax/experimental/nnx/nnx/graph_utils.py index bfa2fcd415..859c09398e 100644 --- a/flax/experimental/nnx/nnx/graph_utils.py +++ b/flax/experimental/nnx/nnx/graph_utils.py @@ -327,17 +327,17 @@ def merge(self, state: State, *states: State) -> Node: def apply( self, state: State, *states: State ) -> ApplyCaller[tuple[State, 'GraphDef[Node]']]: - accessesor = DelayedAccessor() + accessor = DelayedAccessor() def _apply( - accessesor, *args, **kwargs + accessor: DelayedAccessor, *args, **kwargs ) -> tuple[tp.Any, tuple[State, GraphDef[Node]]]: module = self.merge(state, *states) - fn = accessesor(module) + fn = accessor(module) out = fn(*args, **kwargs) return out, graph_flatten(module) - return CallableProxy(_apply, accessesor) # type: ignore + return CallableProxy(_apply, accessor) # type: ignore def make_empty(self) -> Node: return self.merge(State({})) diff --git a/flax/experimental/nnx/nnx/module.py b/flax/experimental/nnx/nnx/module.py index 47af5792a6..1d32678940 100644 --- a/flax/experimental/nnx/nnx/module.py +++ b/flax/experimental/nnx/nnx/module.py @@ -283,21 +283,19 @@ def init(cls: type[M], *args, **kwargs) -> tuple[State, GraphDef[M]]: @classmethod @property def create_abstract(cls: type[M]) -> type[M]: - accessesor = DelayedAccessor() - def lift_rngs(kwargs: dict[str, tp.Any]): if 'rngs' in kwargs and isinstance(kwargs['rngs'], Rngs): kwargs['rngs'] = kwargs['rngs'].copy() return kwargs - def _create_abstract(accessesor, *args, **kwargs): - constructor = accessesor(cls) + def _create_abstract(accessor: DelayedAccessor, *args, **kwargs): + constructor = accessor(cls) state, graphdef = jax.eval_shape( lambda: constructor(*args, **lift_rngs(kwargs)).split() ) return graphdef.merge(state) - return CallableProxy(_create_abstract, accessesor) # type: ignore + return CallableProxy(_create_abstract) # type: ignore def clone(self: M) -> M: return merge(self.split()) @@ -404,15 +402,13 @@ def pop( @property def apply(self: M) -> ApplyCaller[M]: - accessesor = DelayedAccessor() - - def _apply(accessesor, *args, **kwargs) -> tuple[tp.Any, M]: + def _apply(accessor: DelayedAccessor, *args, **kwargs) -> tuple[tp.Any, M]: module = self.clone() - fn = accessesor(module) + fn = accessor(module) out = fn(*args, **kwargs) return out, module - return CallableProxy(_apply, accessesor) # type: ignore + return CallableProxy(_apply) # type: ignore def update(self: M, update: Updates[M], *updates: Updates[M]) -> None: updates = (update, *updates) diff --git a/flax/experimental/nnx/nnx/proxy_caller.py b/flax/experimental/nnx/nnx/proxy_caller.py index 03c6fd2ef2..d67b6f3069 100644 --- a/flax/experimental/nnx/nnx/proxy_caller.py +++ b/flax/experimental/nnx/nnx/proxy_caller.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import dataclasses import typing as tp @@ -20,28 +21,6 @@ A = tp.TypeVar('A') -class _ProxyContext(tpe.Protocol): - def __call__(self, accessor: 'DelayedAccessor', /, *args, **kwargs) -> tp.Any: - ... - - -@dataclasses.dataclass -class CallableProxy: - _proxy_context: _ProxyContext - _proxy_callable: tp.Callable[..., tp.Any] - - def __call__(self, *args, **kwargs): - return self._proxy_context(self._proxy_callable, *args, **kwargs) - - def __getattr__(self, name) -> 'CallableProxy': - return CallableProxy( - self._proxy_context, getattr(self._proxy_callable, name) - ) - - def __getitem__(self, key) -> 'CallableProxy': - return CallableProxy(self._proxy_context, self._proxy_callable[key]) - - def _identity(x): return x @@ -60,6 +39,28 @@ def __getitem__(self, key): return DelayedAccessor(lambda x: x[key]) +class _AccessorCall(tpe.Protocol): + def __call__(self, accessor: DelayedAccessor, /, *args, **kwargs) -> tp.Any: + ... + + +class CallableProxy: + def __init__( + self, callable: _AccessorCall, accessor: DelayedAccessor | None = None + ): + self._callable = callable + self._accessor = DelayedAccessor() if accessor is None else accessor + + def __call__(self, *args, **kwargs): + return self._callable(self._accessor, *args, **kwargs) + + def __getattr__(self, name) -> 'CallableProxy': + return CallableProxy(self._callable, getattr(self._accessor, name)) + + def __getitem__(self, key) -> 'CallableProxy': + return CallableProxy(self._callable, self._accessor[key]) + + class ApplyCaller(tp.Protocol, tp.Generic[A]): def __getattr__(self, __name) -> 'ApplyCaller[A]': ... diff --git a/flax/experimental/nnx/nnx/transforms.py b/flax/experimental/nnx/nnx/transforms.py index 5c400b04f9..8bc90a069e 100644 --- a/flax/experimental/nnx/nnx/transforms.py +++ b/flax/experimental/nnx/nnx/transforms.py @@ -77,7 +77,7 @@ def _check_args(args: tuple[tp.Any, ...]): class LiftedModule(Module, tp.Generic[M]): @abstractmethod - def _call(self, accessesor: DelayedAccessor, *args, **kwargs) -> tp.Any: + def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> tp.Any: ... @property @@ -92,11 +92,11 @@ def __call__(self, *args, **kwargs) -> tp.Any: def call(self) -> tp.Any: module = self - def check_and_call(*args, **kwargs): + def check_and_call(accessor: DelayedAccessor, *args, **kwargs): _check_args(args) - return self._call(*args, **kwargs) + return self._call(accessor, *args, **kwargs) - proxy = CallableProxy(check_and_call, DelayedAccessor()) + proxy = CallableProxy(check_and_call) while isinstance(module._submodule, LiftedModule): module = module._submodule @@ -301,8 +301,8 @@ def jit_call_module(module, *args, **kwargs): def _submodule(self) -> M: return self.jit_module - def _call(self, accessesor: DelayedAccessor, *args, **kwargs) -> Any: - self.accessor = accessesor + def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> Any: + self.accessor = accessor try: out = jit_apply(self.jitted_fn, self.jit_module, args, kwargs) finally: @@ -450,9 +450,9 @@ def __init__( def _submodule(self) -> M: return self.grad_module - def _call(self, accessesor: DelayedAccessor, *args, **kwargs) -> Any: + def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> Any: def grad_call_apply(module, *args, **kwargs): - return accessesor(module)(*args, **kwargs) + return accessor(module)(*args, **kwargs) return grad_apply( self.options, grad_call_apply, self.grad_module, *args, **kwargs @@ -723,7 +723,7 @@ def _submodule(self) -> M: return self.scan_module def _call( - self, accessesor: DelayedAccessor, *args, **kwargs + self, accessor: DelayedAccessor, *args, **kwargs ) -> tuple[tp.Any, tp.Any]: if len(args) < 1: raise TypeError( @@ -733,7 +733,7 @@ def _call( carry_arg, args = args[0], args[1:] def scan_call_apply(module, *args, **kwargs): - return accessesor(module)(*args, **kwargs) + return accessor(module)(*args, **kwargs) return scan_apply( self.options, @@ -1176,12 +1176,12 @@ def _submodule(self) -> M: def _call( self, - accessesor: DelayedAccessor, + accessor: DelayedAccessor, *args, rngs: tp.Optional[rnglib.Rngs] = None, ) -> tp.Any: def remat_call_apply(module, *args, **kwargs): - return accessesor(module)(*args, **kwargs) + return accessor(module)(*args, **kwargs) return remat_apply( self.options, @@ -1366,12 +1366,12 @@ def _submodule(self) -> M: return self.vmap_module def _call( - self, accessesor: DelayedAccessor, *args, **kwargs + self, accessor: DelayedAccessor, *args, **kwargs ) -> tuple[tp.Any, tp.Any]: _check_args(args) def vmap_call_apply(module, *args, **kwargs): - return accessesor(module)(*args, **kwargs) + return accessor(module)(*args, **kwargs) return vmap_apply( self.options,