feat(diff/zero_order): add OOP API for zero-order differentiation #125

Merged · 6 commits · Jan 11, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added


- Add object-oriented modules support for zero-order differentiation by [@XuehaiPan](https://github.com/XuehaiPan) in [#125](https://github.com/metaopt/torchopt/pull/125).

### Changed

4 changes: 2 additions & 2 deletions Makefile
@@ -138,7 +138,7 @@ clang-format: clang-format-install
# Documentation

addlicense: addlicense-install
addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022 -check $(SOURCE_FOLDERS)
addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022-$(shell date +"%Y") -check $(SOURCE_FOLDERS)

docstyle: docs-install
make -C docs clean
@@ -162,7 +162,7 @@ format: py-format-install clang-format-install addlicense-install
$(PYTHON) -m isort --project $(PROJECT_NAME) $(PYTHON_FILES)
$(PYTHON) -m black $(PYTHON_FILES) tutorials
$(CLANG_FORMAT) -style=file -i $(CXX_FILES)
addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022 $(SOURCE_FOLDERS)
addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022-$(shell date +"%Y") $(SOURCE_FOLDERS)

clean-py:
find . -type f -name '*.py[co]' -delete
53 changes: 47 additions & 6 deletions README.md
@@ -247,7 +247,7 @@ Users need to define the stationary condition/objective function and the inner-l
```python
# Inherited from the class ImplicitMetaGradientModule
# Optionally specify the linear solver (conjugate gradient or Neumann series)
class InnerNet(ImplicitMetaGradientModule, linear_solver):
class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver):
def __init__(self, meta_param):
super().__init__()
self.meta_param = meta_param
@@ -293,17 +293,58 @@ Refer to the tutorial notebook [Zero-order Differentiation](tutorials/6_Zero_Ord

#### Functional API <!-- omit in toc -->

For zero-order differentiation, users need to define the forward pass calculation and the noise sampling procedure. TorchOpt provides a decorator that wraps the forward function to enable zero-order differentiation.

```python
# Customize the noise sampling function in ES
def sample(sample_shape):
def distribution(sample_shape):
# Generate a batch of noise samples
# NOTE: The distribution should be spherically symmetric with a constant variance of 1.
...
return sample_noise
return noise_batch

# Distribution can also be an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)`
distribution = torch.distributions.Normal(loc=0, scale=1)

# Specify method and hyper-parameter of ES
@torchopt.diff.zero_order(sample, method)
@torchopt.diff.zero_order(distribution, method)
def forward(params, batch, labels):
# forward process
return output
# Forward process
...
return objective # the returned tensor should be a scalar tensor
```
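
The decorated `forward` then behaves like an ordinary differentiable function: its output works with `torch.autograd.grad`, and the estimated gradients can drive a TorchOpt functional optimizer. Below is a minimal usage sketch modelled on the test added in this pull request; the learning rate and the way `params`, `batch`, and `labels` are constructed are illustrative placeholders rather than part of the documented API.

```python
# Zero-order forward pass through the decorated function
loss = forward(params, batch, labels)

# Estimated gradients with respect to `params`
grads = torch.autograd.grad(loss, params)

# Apply the update with a functional TorchOpt optimizer (illustrative learning rate)
optimizer = torchopt.adam(lr=1e-2)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state)
params = torchopt.apply_updates(params, updates)
```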

#### OOP API <!-- omit in toc -->

TorchOpt also offers an OOP API. Users need to inherit from the class `torchopt.nn.ZeroOrderGradientModule` to construct the network as an `nn.Module` following a classical PyTorch style.
Users need to define the forward process `forward()` and a noise sampling function `sample()`.

```python
# Inherited from the class ZeroOrderGradientModule
# Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling
class Net(ZeroOrderGradientModule, method=method, num_samples=num_samples, sigma=sigma):
def __init__(self, ...):
...

def forward(self, batch):
# Forward process
...
return objective # the returned tensor should be a scalar tensor

def sample(self, sample_shape=torch.Size()):
# Generate a batch of noise samples
# NOTE: The distribution should be spherically symmetric with a constant variance of 1.
...
return noise_batch

# Get model and data
net = Net(...)
data = ...

# Forward pass
loss = net(data)
# Backward pass using zero-order differentiation
grads = torch.autograd.grad(loss, net.parameters())
```
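
In practice the module plugs into a standard PyTorch-style training loop, as exercised by the `test_zero_order_module` test added in this pull request. A minimal sketch reusing `net` and `data` from above (`num_iterations` and the learning rate are illustrative placeholders):

```python
optimizer = torchopt.Adam(net.parameters(), lr=1e-2)

for _ in range(num_iterations):
    loss = net(data)       # zero-order forward pass returning a scalar loss
    optimizer.zero_grad()
    loss.backward()        # gradients estimated via zero-order differentiation
    optimizer.step()       # update the network parameters
```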

--------------------------------------------------------------------------------
58 changes: 49 additions & 9 deletions tests/test_zero_order.py
@@ -1,4 +1,4 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.types

import helpers
@@ -30,20 +31,17 @@ class FcNet(nn.Module):
def __init__(self, dim, out):
super().__init__()
self.fc = nn.Linear(in_features=dim, out_features=out, bias=True)
nn.init.ones_(self.fc.weight)
nn.init.zeros_(self.fc.bias)

def forward(self, x):
return self.fc(x)


@helpers.parametrize(
dtype=[torch.float64, torch.float32],
lr=[1e-2, 1e-3],
method=['naive', 'forward', 'antithetic'],
sigma=[0.01, 0.1, 1],
)
def test_zero_order(dtype: torch.dtype, lr: float, method: str, sigma: float) -> None:
def test_zero_order(lr: float, method: str, sigma: float) -> None:
helpers.seed_everything(42)
input_size = 32
output_size = 1
@@ -59,21 +57,63 @@ def test_zero_order(dtype: torch.dtype, lr: float, method: str, sigma: float) ->
y = torch.randn(input_size) * coef
distribution = torch.distributions.Normal(loc=0, scale=1)

@torchopt.diff.zero_order.zero_order(
@torchopt.diff.zero_order(
distribution=distribution, method=method, argnums=0, sigma=sigma, num_samples=num_samples
)
def forward_process(params, fn, x, y):
y_pred = fn(params, x)
loss = torch.mean((y - y_pred) ** 2)
loss = F.mse_loss(y_pred, y)
return loss

optimizer = torchopt.adam(lr=lr)
opt_state = optimizer.init(params)
opt_state = optimizer.init(params) # init optimizer

for i in range(num_iterations):
opt_state = optimizer.init(params) # init optimizer
loss = forward_process(params, fmodel, x, y) # compute loss

grads = torch.autograd.grad(loss, params) # compute gradients
updates, opt_state = optimizer.update(grads, opt_state) # get updates
params = torchopt.apply_updates(params, updates) # update network parameters


@helpers.parametrize(
lr=[1e-2, 1e-3],
method=['naive', 'forward', 'antithetic'],
sigma=[0.01, 0.1, 1],
)
def test_zero_order_module(lr: float, method: str, sigma: float) -> None:
helpers.seed_everything(42)
input_size = 32
output_size = 1
batch_size = BATCH_SIZE
coef = 0.1
num_iterations = NUM_UPDATES
num_samples = 500

class FcNetWithLoss(
torchopt.nn.ZeroOrderGradientModule, method=method, sigma=sigma, num_samples=num_samples
):
def __init__(self, dim, out):
super().__init__()
self.net = FcNet(dim, out)
self.loss = nn.MSELoss()
self.distribution = torch.distributions.Normal(loc=0, scale=1)

def forward(self, x, y):
return self.loss(self.net(x), y)

def sample(self, sample_shape=torch.Size()):
return self.distribution.sample(sample_shape)

x = torch.randn(batch_size, input_size) * coef
y = torch.randn(input_size) * coef
model_with_loss = FcNetWithLoss(input_size, output_size)

optimizer = torchopt.Adam(model_with_loss.parameters(), lr=lr)

for i in range(num_iterations):
loss = model_with_loss(x, y) # compute loss

optimizer.zero_grad()
loss.backward() # compute gradients
optimizer.step() # update network parameters
3 changes: 2 additions & 1 deletion torchopt/diff/__init__.py
@@ -1,4 +1,4 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,3 +16,4 @@

from torchopt.diff import implicit, zero_order
from torchopt.diff.implicit import ImplicitMetaGradientModule
from torchopt.diff.zero_order import ZeroOrderGradientModule
7 changes: 4 additions & 3 deletions torchopt/diff/implicit/nn/__init__.py
@@ -1,4 +1,4 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,9 +14,10 @@
# ==============================================================================
"""The base class for differentiable implicit meta-gradient models."""

# Preload to resolve circular references
import torchopt.nn.module # pylint: disable=unused-import
import torchopt.nn.module # preload to resolve circular references
from torchopt.diff.implicit.nn.module import ImplicitMetaGradientModule


__all__ = ['ImplicitMetaGradientModule']

del torchopt
6 changes: 4 additions & 2 deletions torchopt/diff/zero_order/__init__.py
@@ -1,4 +1,4 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,10 +17,12 @@
import sys as _sys
from types import ModuleType as _ModuleType

from torchopt.diff.zero_order import nn
from torchopt.diff.zero_order.decorator import zero_order
from torchopt.diff.zero_order.nn import ZeroOrderGradientModule


__all__ = ['zero_order']
__all__ = ['zero_order', 'ZeroOrderGradientModule']


class _CallableModule(_ModuleType): # pylint: disable=too-few-public-methods
23 changes: 23 additions & 0 deletions torchopt/diff/zero_order/nn/__init__.py
@@ -0,0 +1,23 @@
# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The base class for zero-order gradient models."""

import torchopt.nn.module # preload to resolve circular references
from torchopt.diff.zero_order.nn.module import ZeroOrderGradientModule


__all__ = ['ZeroOrderGradientModule']

del torchopt
116 changes: 116 additions & 0 deletions torchopt/diff/zero_order/nn/module.py
@@ -0,0 +1,116 @@
# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The base class for zero-order gradient models."""

# pylint: disable=redefined-builtin

import abc
import functools
from typing import Dict, Optional, Sequence, Tuple, Type, Union

import torch
import torch.nn as nn

from torchopt import pytree
from torchopt.diff.implicit.nn.module import container_context
from torchopt.diff.zero_order.decorator import Method, Samplable, zero_order
from torchopt.typing import Numeric, TupleOfTensors
from torchopt.utils import extract_module_containers


__all__ = ['ZeroOrderGradientModule']


def enable_zero_order_gradients(
cls: Type['ZeroOrderGradientModule'],
method: Method = 'naive',
num_samples: int = 1,
sigma: Numeric = 1.0,
) -> Type['ZeroOrderGradientModule']:
"""Enable zero-order gradient estimation for the :func:`forward` method."""
cls_forward = cls.forward
if getattr(cls_forward, '__zero_order_gradients_enabled__', False):
raise TypeError(
'Zero-order gradient estimation is already enabled for the `forward` method.'
)

@functools.wraps(cls_forward)
def wrapped( # pylint: disable=too-many-locals
self: 'ZeroOrderGradientModule', *input, **kwargs
) -> torch.Tensor:
"""Do the forward pass calculation."""
params_containers = extract_module_containers(self, with_buffers=False)[0]

flat_params: TupleOfTensors
flat_params, params_containers_treespec = pytree.tree_flatten_as_tuple(
params_containers # type: ignore[arg-type]
)

@zero_order(self.sample, argnums=0, method=method, num_samples=num_samples, sigma=sigma)
def forward_fn(
__flat_params: TupleOfTensors, # pylint: disable=unused-argument
*input,
**kwargs,
) -> torch.Tensor:
flat_grad_tracking_params = __flat_params
grad_tracking_params_containers: Tuple[
Dict[str, Optional[torch.Tensor]], ...
] = pytree.tree_unflatten( # type: ignore[assignment]
params_containers_treespec, flat_grad_tracking_params
)

with container_context(
params_containers,
grad_tracking_params_containers,
):
return cls_forward(self, *input, **kwargs)

return forward_fn(flat_params, *input, **kwargs)

wrapped.__zero_order_gradients_enabled__ = True # type: ignore[attr-defined]
cls.forward = wrapped # type: ignore[assignment]
return cls


class ZeroOrderGradientModule(nn.Module, Samplable):
"""The base class for zero-order gradient models."""

def __init_subclass__( # pylint: disable=arguments-differ
cls,
method: Method = 'naive',
num_samples: int = 1,
sigma: Numeric = 1.0,
) -> None:
"""Validate and initialize the subclass."""
super().__init_subclass__()
enable_zero_order_gradients(
cls,
method=method,
num_samples=num_samples,
sigma=sigma,
)

@abc.abstractmethod
def forward(self, *args, **kwargs) -> torch.Tensor:
"""Do the forward pass of the model."""
raise NotImplementedError

@abc.abstractmethod
def sample(
self, sample_shape: torch.Size = torch.Size() # pylint: disable=unused-argument
) -> Union[torch.Tensor, Sequence[Numeric]]:
# pylint: disable-next=line-too-long
"""Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched."""
raise NotImplementedError