Skip to content

Commit

Permalink
[1/2] Collaborative Strategy (#12842)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Naren authored May 5, 2022
1 parent d337374 commit 1a502c0
Show file tree
Hide file tree
Showing 9 changed files with 921 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `Trainer(deterministic="warn")` to warn instead of fail when a non-deterministic operation is encountered ([#12588](https://github.com/PyTorchLightning/pytorch-lightning/pull/12588))


- Added `CollaborativeStrategy` ([#12842](https://github.com/PyTorchLightning/pytorch-lightning/pull/12842))


- Include a version suffix for new "last" checkpoints of later runs in the same directory ([#12902](https://github.com/PyTorchLightning/pytorch-lightning/pull/12902))


Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/strategies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.strategies.bagua import BaguaStrategy # noqa: F401
from pytorch_lightning.strategies.collaborative import CollaborativeStrategy # noqa: F401
from pytorch_lightning.strategies.ddp import DDPStrategy # noqa: F401
from pytorch_lightning.strategies.ddp2 import DDP2Strategy # noqa: F401
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy # noqa: F401
Expand Down
529 changes: 529 additions & 0 deletions pytorch_lightning/strategies/collaborative.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_FAIRSCALE_FULLY_SHARDED_AVAILABLE,
_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE,
_GROUP_AVAILABLE,
_HIVEMIND_AVAILABLE,
_HOROVOD_AVAILABLE,
_HPU_AVAILABLE,
_HYDRA_AVAILABLE,
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3")
_FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4")
_GROUP_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.group")
_HIVEMIND_AVAILABLE = _package_available("hivemind")
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
_HYDRA_AVAILABLE = _package_available("hydra")
_HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
Expand Down
6 changes: 6 additions & 0 deletions pytorch_lightning/utilities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class _LRScheduler(_Stateful, Protocol):
def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
...

def step(self, epoch: Optional[int] = None) -> None:
...


# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
Expand All @@ -91,6 +94,9 @@ def __init__(
) -> None:
...

def step(self, metrics: Union[float, int, torch.Tensor], epoch: Optional[int] = None) -> None:
...


# todo: improve LRSchedulerType naming/typing
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
Expand Down
1 change: 1 addition & 0 deletions requirements/strategies.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
fairscale>=0.4.5
deepspeed<0.6.0
horovod>=0.21.2,!=0.24.0 # no need to install with [pytorch] as pytorch is already installed
hivemind>=1.0.1; sys_platform == 'linux'
7 changes: 7 additions & 0 deletions tests/helpers/runif.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_DEEPSPEED_AVAILABLE,
_FAIRSCALE_AVAILABLE,
_FAIRSCALE_FULLY_SHARDED_AVAILABLE,
_HIVEMIND_AVAILABLE,
_HOROVOD_AVAILABLE,
_HPU_AVAILABLE,
_IPU_AVAILABLE,
Expand Down Expand Up @@ -84,6 +85,7 @@ def __new__(
omegaconf: bool = False,
slow: bool = False,
bagua: bool = False,
hivemind: bool = False,
**kwargs,
):
"""
Expand Down Expand Up @@ -111,6 +113,7 @@ def __new__(
omegaconf: Require that omry/omegaconf is installed.
slow: Mark the test as slow, our CI will run it in a separate job.
bagua: Require that BaguaSys/bagua is installed.
hivemind: Require that Hivemind is installed.
**kwargs: Any :class:`pytest.mark.skipif` keyword arguments.
"""
conditions = []
Expand Down Expand Up @@ -231,6 +234,10 @@ def __new__(
conditions.append(not _BAGUA_AVAILABLE or sys.platform in ("win32", "darwin"))
reasons.append("Bagua")

if hivemind:
conditions.append(not _HIVEMIND_AVAILABLE or sys.platform in ("win32", "darwin"))
reasons.append("Hivemind")

reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
return pytest.mark.skipif(
*args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs
Expand Down
Loading

0 comments on commit 1a502c0

Please sign in to comment.