Commit

style: use postponed evaluation of annotations and update docstring style (#135)
XuehaiPan authored Feb 15, 2023
1 parent c67476b commit e2157de
Showing 66 changed files with 1,165 additions and 1,021 deletions.
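Every file below follows the same recipe: add `from __future__ import annotations` (PEP 563), which makes the interpreter store annotations as plain strings instead of evaluating them at definition time. That is what lets the codebase use PEP 585 builtin generics (`tuple[...]`) and PEP 604 unions (`X | None`) while still supporting Python 3.7+. A minimal standalone sketch of the mechanism (the `clip` function is illustrative, not code from this commit):

```python
# Minimal sketch of PEP 563 behaviour; `clip` is an illustrative
# function, not code from this commit.
from __future__ import annotations


def clip(value: float, bounds: tuple[float, float] | None = None) -> float:
    """Clamp ``value`` into ``bounds`` when bounds are given."""
    if bounds is None:
        return value
    low, high = bounds
    return min(max(value, low), high)


# Annotations are stored as strings and never evaluated at definition
# time, so this module runs even on Python 3.7/3.8, where evaluating
# `tuple[float, float] | None` would raise a TypeError.
print(clip.__annotations__)
# {'value': 'float', 'bounds': 'tuple[float, float] | None', 'return': 'float'}
```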
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Use postponed evaluation of annotations and update docstring style by [@XuehaiPan](https://github.com/XuehaiPan) in [#135](https://github.com/metaopt/torchopt/pull/135).
- Rewrite setup CUDA Toolkit logic by [@XuehaiPan](https://github.com/XuehaiPan) in [#133](https://github.com/metaopt/torchopt/pull/133).

### Fixed
1 change: 0 additions & 1 deletion README.md
@@ -14,7 +14,6 @@
<a href="https://codecov.io/gh/metaopt/torchopt">![CodeCov](https://img.shields.io/codecov/c/gh/metaopt/torchopt)</a>
<a href="https://torchopt.readthedocs.io">![Documentation Status](https://img.shields.io/readthedocs/torchopt?logo=readthedocs)</a>
<a href="https://pepy.tech/project/torchopt">![Downloads](https://static.pepy.tech/personalized-badge/torchopt?period=total&left_color=grey&right_color=blue&left_text=downloads)</a>
<a href="https://github.com/metaopt/torchopt/stargazers">![GitHub Repo Stars](https://img.shields.io/github/stars/metaopt/torchopt?color=brightgreen&logo=github)</a>
<a href="https://github.com/metaopt/torchopt/blob/HEAD/LICENSE">![License](https://img.shields.io/github/license/metaopt/torchopt?label=license&logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCAyNCAyNCIgd2lkdGg9IjI0IiBoZWlnaHQ9IjI0IiBmaWxsPSIjZmZmZmZmIj48cGF0aCBmaWxsLXJ1bGU9ImV2ZW5vZGQiIGQ9Ik0xMi43NSAyLjc1YS43NS43NSAwIDAwLTEuNSAwVjQuNUg5LjI3NmExLjc1IDEuNzUgMCAwMC0uOTg1LjMwM0w2LjU5NiA1Ljk1N0EuMjUuMjUgMCAwMTYuNDU1IDZIMi4zNTNhLjc1Ljc1IDAgMTAwIDEuNUgzLjkzTC41NjMgMTUuMThhLjc2Mi43NjIgMCAwMC4yMS44OGMuMDguMDY0LjE2MS4xMjUuMzA5LjIyMS4xODYuMTIxLjQ1Mi4yNzguNzkyLjQzMy42OC4zMTEgMS42NjIuNjIgMi44NzYuNjJhNi45MTkgNi45MTkgMCAwMDIuODc2LS42MmMuMzQtLjE1NS42MDYtLjMxMi43OTItLjQzMy4xNS0uMDk3LjIzLS4xNTguMzEtLjIyM2EuNzUuNzUgMCAwMC4yMDktLjg3OEw1LjU2OSA3LjVoLjg4NmMuMzUxIDAgLjY5NC0uMTA2Ljk4NC0uMzAzbDEuNjk2LTEuMTU0QS4yNS4yNSAwIDAxOS4yNzUgNmgxLjk3NXYxNC41SDYuNzYzYS43NS43NSAwIDAwMCAxLjVoMTAuNDc0YS43NS43NSAwIDAwMC0xLjVIMTIuNzVWNmgxLjk3NGMuMDUgMCAuMS4wMTUuMTQuMDQzbDEuNjk3IDEuMTU0Yy4yOS4xOTcuNjMzLjMwMy45ODQuMzAzaC44ODZsLTMuMzY4IDcuNjhhLjc1Ljc1IDAgMDAuMjMuODk2Yy4wMTIuMDA5IDAgMCAuMDAyIDBhMy4xNTQgMy4xNTQgMCAwMC4zMS4yMDZjLjE4NS4xMTIuNDUuMjU2Ljc5LjRhNy4zNDMgNy4zNDMgMCAwMDIuODU1LjU2OCA3LjM0MyA3LjM0MyAwIDAwMi44NTYtLjU2OWMuMzM4LS4xNDMuNjA0LS4yODcuNzktLjM5OWEzLjUgMy41IDAgMDAuMzEtLjIwNi43NS43NSAwIDAwLjIzLS44OTZMMjAuMDcgNy41aDEuNTc4YS43NS43NSAwIDAwMC0xLjVoLTQuMTAyYS4yNS4yNSAwIDAxLS4xNC0uMDQzbC0xLjY5Ny0xLjE1NGExLjc1IDEuNzUgMCAwMC0uOTg0LS4zMDNIMTIuNzVWMi43NXpNMi4xOTMgMTUuMTk4YTUuNDE4IDUuNDE4IDAgMDAyLjU1Ny42MzUgNS40MTggNS40MTggMCAwMDIuNTU3LS42MzVMNC43NSA5LjM2OGwtMi41NTcgNS44M3ptMTQuNTEtLjAyNGMuMDgyLjA0LjE3NC4wODMuMjc1LjEyNi41My4yMjMgMS4zMDUuNDUgMi4yNzIuNDVhNS44NDYgNS44NDYgMCAwMDIuNTQ3LS41NzZMMTkuMjUgOS4zNjdsLTIuNTQ3IDUuODA3eiI+PC9wYXRoPjwvc3ZnPgo=)</a>
</div>

109 changes: 109 additions & 0 deletions docs/source/api/api.rst
@@ -285,6 +285,115 @@ Chain
.. autofunction:: chain


Distributed Utilities
=====================

.. currentmodule:: torchopt.distributed

Initialization and Synchronization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::

auto_init_rpc
barrier

.. autofunction:: auto_init_rpc
.. autofunction:: barrier

Process group information
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::

get_world_info
get_world_rank
get_rank
get_world_size
get_local_rank
get_local_world_size
get_worker_id

.. autofunction:: get_world_info
.. autofunction:: get_world_rank
.. autofunction:: get_rank
.. autofunction:: get_world_size
.. autofunction:: get_local_rank
.. autofunction:: get_local_world_size
.. autofunction:: get_worker_id

Worker selection
~~~~~~~~~~~~~~~~

.. autosummary::

on_rank
not_on_rank
rank_zero_only
rank_non_zero_only

.. autofunction:: on_rank
.. autofunction:: not_on_rank
.. autofunction:: rank_zero_only
.. autofunction:: rank_non_zero_only

Remote Procedure Call (RPC)
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::

remote_async_call
remote_sync_call

.. autofunction:: remote_async_call
.. autofunction:: remote_sync_call

Predefined partitioners and reducers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::

dim_partitioner
batch_partitioner
mean_reducer
sum_reducer

.. autofunction:: dim_partitioner
.. autofunction:: batch_partitioner
.. autofunction:: mean_reducer
.. autofunction:: sum_reducer

Function parallelization wrappers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::

parallelize
parallelize_async
parallelize_sync

.. autofunction:: parallelize
.. autofunction:: parallelize_async
.. autofunction:: parallelize_sync

Distributed Autograd
~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchopt.distributed.autograd

.. autosummary::

context
get_gradients
backward
grad

.. autofunction:: context
.. autofunction:: get_gradients
.. autofunction:: backward
.. autofunction:: grad


General Utilities
=================

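The new API section above lists the distributed helpers by name only. The sketch below is a guess at how they compose; the `partitioner=`/`reducer=` keyword arguments are an assumption inferred from the documented names, not confirmed TorchOpt usage:

```python
# Hedged sketch only: the `partitioner=`/`reducer=` keywords are an
# assumption based on the helper names documented above, not a
# confirmed TorchOpt signature.
import torchopt.distributed as todist


@todist.rank_zero_only
def report(step: int) -> None:
    # Runs only on the worker whose global rank is 0.
    print(f'step {step} done on rank {todist.get_rank()}')


@todist.parallelize(partitioner=todist.batch_partitioner, reducer=todist.mean_reducer)
def mean_squared_error(model, batch):
    # Assumed flow: the batch is split across RPC workers along its
    # batch dimension, each worker computes a partial loss, and the
    # partial losses are averaged back on the caller.
    inputs, targets = batch
    return (model(inputs) - targets).square().mean()
```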
7 changes: 0 additions & 7 deletions docs/source/distributed/distributed.rst
@@ -142,7 +142,6 @@ Initialization and Synchronization

.. autosummary::


torchopt.distributed.auto_init_rpc
torchopt.distributed.barrier

@@ -197,7 +196,6 @@ Process group information

.. autosummary::


torchopt.distributed.get_world_info
torchopt.distributed.get_world_rank
torchopt.distributed.get_rank
@@ -228,7 +226,6 @@ Worker selection

.. autosummary::


torchopt.distributed.on_rank
torchopt.distributed.not_on_rank
torchopt.distributed.rank_zero_only
@@ -275,7 +272,6 @@ Remote Procedure Call (RPC)

.. autosummary::


torchopt.distributed.remote_async_call
torchopt.distributed.remote_sync_call

@@ -354,7 +350,6 @@ Predefined partitioners and reducers

.. autosummary::


torchopt.distributed.dim_partitioner
torchopt.distributed.batch_partitioner
torchopt.distributed.mean_reducer
@@ -439,7 +434,6 @@ Function parallelization wrappers

.. autosummary::


torchopt.distributed.parallelize
torchopt.distributed.parallelize_async
torchopt.distributed.parallelize_sync
@@ -490,7 +484,6 @@ Distributed Autograd

.. autosummary::


torchopt.distributed.autograd.context
torchopt.distributed.autograd.get_gradients
torchopt.distributed.autograd.backward
1 change: 1 addition & 0 deletions docs/source/spelling_wordlist.txt
@@ -171,3 +171,4 @@ issubclass
abc
ABCMeta
subclasscheck
ctx
22 changes: 12 additions & 10 deletions tests/helpers.py
@@ -13,11 +13,13 @@
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import copy
import itertools
import os
import random
from typing import Iterable, Optional, Tuple, Union
from typing import Iterable

import numpy as np
import pytest
@@ -137,7 +139,7 @@ def get_model():
@torch.no_grad()
def get_models(
device: torch.types.Device = None, dtype: torch.dtype = torch.float32
) -> Tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]:
) -> tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]:
seed_everything(seed=42)

model_base = get_model().to(dtype=dtype)
@@ -166,12 +168,12 @@ def get_models(

@torch.no_grad()
def assert_model_all_close(
model: Union[nn.Module, Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]],
model: nn.Module | tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]],
model_ref: nn.Module,
model_base: nn.Module,
dtype: torch.dtype = torch.float32,
rtol: Optional[float] = None,
atol: Optional[float] = None,
rtol: float | None = None,
atol: float | None = None,
equal_nan: bool = False,
) -> None:
if isinstance(model, tuple):
@@ -194,8 +196,8 @@ def assert_all_close(
actual: torch.Tensor,
expected: torch.Tensor,
base: torch.Tensor = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
rtol: float | None = None,
atol: float | None = None,
equal_nan: bool = False,
) -> None:
if base is not None:
@@ -223,9 +225,9 @@ def assert_all_close(
def assert_pytree_all_close(
actual: TensorTree,
expected: TensorTree,
base: Optional[TensorTree] = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
base: TensorTree | None = None,
rtol: float | None = None,
atol: float | None = None,
equal_nan: bool = False,
) -> None:
actual_leaves, actual_treespec = pytree.tree_flatten(actual)
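The `tests/helpers.py` hunks above are mechanical rewrites: `typing.Tuple`, `Optional`, and `Union` become builtin generics and PEP 604 unions once the future import is in place. An illustrative signature in the new style (a simplified stand-in, not the actual helper from the file):

```python
from __future__ import annotations  # must precede all other statements

# Old spelling (PEP 484)            New spelling (PEP 585/604)
#   Tuple[float, float]        ->     tuple[float, float]
#   Optional[float]            ->     float | None
#   Union[int, float]          ->     int | float


def assert_close(
    actual: float,
    expected: float,
    rtol: float | None = None,
    atol: float | None = None,
) -> None:
    # Illustrative tolerance check, not the helper from tests/helpers.py.
    rtol = 1e-5 if rtol is None else rtol
    atol = 1e-8 if atol is None else atol
    assert abs(actual - expected) <= atol + rtol * abs(expected)
```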
14 changes: 8 additions & 6 deletions tests/test_alias.py
@@ -13,7 +13,9 @@
# limitations under the License.
# ==============================================================================

from typing import Callable, Tuple
from __future__ import annotations

from typing import Callable

import functorch
import pytest
@@ -107,7 +109,7 @@ def test_sgd(
def test_adam(
dtype: torch.dtype,
lr: float,
betas: Tuple[float, float],
betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
@@ -177,7 +179,7 @@ def test_maml_adam(
outer_lr: float,
inner_lr: float,
inner_update: int,
betas: Tuple[float, float],
betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
@@ -263,7 +265,7 @@ def maml_inner_solver_torchopt(params, data, use_accelerated_op):
def test_adamw(
dtype: torch.dtype,
lr: float,
betas: Tuple[float, float],
betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
@@ -333,8 +335,8 @@ def test_adamw(
def test_adam_accelerated_cuda(
dtype: torch.dtype,
lr: float,
optimizers: Tuple[Callable, torch.optim.Optimizer],
betas: Tuple[float, float],
optimizers: tuple[Callable, torch.optim.Optimizer],
betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
7 changes: 4 additions & 3 deletions tests/test_implicit.py
@@ -13,10 +13,11 @@
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import copy
from collections import OrderedDict
from types import FunctionType
from typing import Tuple

import functorch
import jax
@@ -55,7 +56,7 @@ def forward(self, x):
return self.fc(x)


def get_model_jax(dtype: np.dtype = np.float32) -> Tuple[FunctionType, OrderedDict]:
def get_model_jax(dtype: np.dtype = np.float32) -> tuple[FunctionType, OrderedDict]:
helpers.seed_everything(seed=42)

def func(params, x):
@@ -73,7 +74,7 @@ def func(params, x):
@torch.no_grad()
def get_model_torch(
device: torch.types.Device = None, dtype: torch.dtype = torch.float32
) -> Tuple[nn.Module, data.DataLoader]:
) -> tuple[nn.Module, data.DataLoader]:
helpers.seed_everything(seed=42)

model = FcNet(MODEL_NUM_INPUTS, MODEL_NUM_CLASSES).to(dtype=dtype)
4 changes: 2 additions & 2 deletions tests/test_meta_optim.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

from typing import Tuple
from __future__ import annotations

import torch
import torch.nn.functional as F
@@ -40,7 +40,7 @@ def test_maml_meta_adam(
outer_lr: float,
inner_lr: float,
inner_update: int,
betas: Tuple[float, float],
betas: tuple[float, float],
eps: float,
eps_root: float,
weight_decay: float,
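One caveat worth noting across all the test-file changes: the future import only defers evaluation, it does not validate anything. Runtime tools that resolve the stringified annotations still need an interpreter where the syntax actually evaluates. A standalone sketch (not from the diff):

```python
from __future__ import annotations

import typing


def head(items: list[int]) -> int | None:
    # Illustrative function, not from this commit.
    return items[0] if items else None


# Stored as plain strings under PEP 563 ...
assert head.__annotations__ == {'items': 'list[int]', 'return': 'int | None'}

# ... and evaluated only on demand: resolving `list[int]` needs
# Python >= 3.9 and `int | None` needs Python >= 3.10, even though the
# source parses fine on 3.7 thanks to the future import.
print(typing.get_type_hints(head))
```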
