
[bug-fix] DDP and automatic_optimization=False #4485

Merged: 74 commits merged into master on Nov 10, 2020

Conversation

@tchaton (Contributor) commented Nov 2, 2020

What does this PR do?

With automatic_optimization=False, returning a detached tensor from training_step raised a RuntimeError under DDP.
The root cause is still unclear, but the error is now caught and surfaced as a MisconfigurationException.

This PR also adds a _running_manual_optim flag so that, if the user returns a loss at the end of training_step,
the automatic backward path is not called again and no extra gradients are accumulated (verified in the tests with assert torch.sum(self.layer.weight.grad) == 0).

Closes #4444
Closes #4485

Two extra bugs fixed (see the sketch after this list for the kind of manual-optimization setup being exercised):
  • zero_grad should be called for all optimizers in manual_optimization
  • zero_grad should be called when nothing is returned from training_step
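
For context, a minimal manual-optimization module of the kind this PR exercises might look like the sketch below. The layer size, the single SGD optimizer, and fetching the optimizer via self.trainer.optimizers are illustrative assumptions, not the exact test code; manual_backward is the Lightning hook this PR touches.

```python
import torch
from pytorch_lightning import LightningModule


class ManualOptimModel(LightningModule):
    """Minimal sketch of a module run with Trainer(automatic_optimization=False)."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # one way to reach the optimizer in this Lightning version (assumption)
        opt = self.trainer.optimizers[0]
        loss = self.layer(batch).sum()
        self.manual_backward(loss, opt)  # backward handled manually by the user
        opt.step()
        opt.zero_grad()
        # returning a detached loss (or nothing) must not trigger a second,
        # automatic backward pass; that is what _running_manual_optim guards
        return loss.detach()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
```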

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you to create a separate PR for every change.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@codecov bot commented Nov 2, 2020

Codecov Report

Merging #4485 (b03b4d7) into master (abf1d4b) will increase coverage by 3%.
The diff coverage is 97%.

@@           Coverage Diff           @@
##           master   #4485    +/-   ##
=======================================
+ Coverage      90%     93%    +3%     
=======================================
  Files         116     116            
  Lines        8858    8883    +25     
=======================================
+ Hits         8012    8275   +263     
+ Misses        846     608   -238     

@@ -1094,7 +1095,9 @@ def training_step(...):
self._verify_is_manual_optimization('manual_backward')

# backward
self._running_manual_optim = True
Contributor:
Would it be easier to add an optional arg to backward? or do we need the state _running_manual_optim somewhere else as well?

@tchaton (Contributor, Author):

Great question!
I thought about it, but it would mean broadcasting a new parameter into training_loop, accelerator, etc.
I thought the simpler, the better. And now it defines a scope around manual optimization if we need it.
Maybe an actual context manager would be cleaner there.
@SeanNaren @Borda Any thoughts?
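
For reference, the context-manager variant mentioned above could look roughly like this hypothetical helper (not part of this PR), which flips the flag only for the duration of the backward call:

```python
from contextlib import contextmanager


@contextmanager
def _manual_optim_scope(module):
    """Hypothetical helper: mark `module` as running a manual backward
    for the duration of the block, and always reset the flag afterwards."""
    module._running_manual_optim = True
    try:
        yield
    finally:
        module._running_manual_optim = False


# sketch of how manual_backward could use it:
# with _manual_optim_scope(self):
#     self.trainer.train_loop.backward(loss, optimizer, -1, *args, **kwargs)
```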

@tchaton (Contributor, Author):

I could reproduce the bug with ddp but not locally. Need more investigation.

Contributor:

Yes, I am confused why this bug only appears in ddp. Could you confirm you don't see the behaviour with ddp_cpu?
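
For reference, a ddp_cpu variant of the repro config posted further down in this thread might look like the sketch below; using two processes and dropping 16-bit precision (native AMP needs CUDA) are assumptions, and this was not verified to reproduce the error.

```python
from pytorch_lightning import Trainer

# hypothetical CPU repro: same flags as the GPU config below, but with the
# ddp_cpu accelerator, two processes, and no native AMP
trainer = Trainer(
    max_epochs=1,
    limit_train_batches=10,
    limit_test_batches=0,
    limit_val_batches=0,
    automatic_optimization=False,
    accelerator="ddp_cpu",
    num_processes=2,
)
```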

@tchaton (Contributor, Author):

Need more investigation :)

@tchaton (Contributor, Author):

Hey @s-rog,
I tried passing it as a parameter, and it started to get ugly.
I preferred to revert to my first idea.
I hope you understand 👍

Best regards,
Thomas Chaton

@pep8speaks commented Nov 3, 2020

Hello @tchaton! Thanks for updating this PR.

Line 1089:121: E501 line too long (128 > 120 characters)
Line 1109:79: E231 missing whitespace after ':'
Line 1114:121: E501 line too long (129 > 120 characters)

Comment last updated at 2020-11-10 19:19:10 UTC

@tchaton tchaton changed the title [bug-fix] DDP and automatic_optimization=False [WIP][bug-fix] DDP and automatic_optimization=False Nov 3, 2020
@tchaton (Contributor, Author) commented Nov 3, 2020

    trainer = Trainer(
        max_epochs=1,
        default_root_dir=tmpdir,
        limit_train_batches=10,
        limit_test_batches=0,
        limit_val_batches=0,
        automatic_optimization=False,
        precision=16,
        amp_backend='native',    
        accelerator="ddp",
        gpus=2,
    )
tensors = (tensor([nan], device='cuda:1', grad_fn=<MulBackward0>),), grad_tensors = (tensor([1.], device='cuda:1'),), retain_graph = False, create_graph = False, grad_variables = None

    def backward(
        tensors: _TensorOrTensors,
        grad_tensors: Optional[_TensorOrTensors] = None,
        retain_graph: Optional[bool] = None,
        create_graph: bool = False,
        grad_variables: Optional[_TensorOrTensors] = None,
    ) -> None:
        r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.

        The graph is differentiated using the chain rule. If any of ``tensors``
        are non-scalar (i.e. their data has more than one element) and require
        gradient, then the Jacobian-vector product would be computed, in this
        case the function additionally requires specifying ``grad_tensors``.
        It should be a sequence of matching length, that contains the "vector"
        in the Jacobian-vector product, usually the gradient of the differentiated
        function w.r.t. corresponding tensors (``None`` is an acceptable value for
        all tensors that don't need gradient tensors).

        This function accumulates gradients in the leaves - you might need to zero
        ``.grad`` attributes or set them to ``None`` before calling it.
        See :ref:`Default gradient layouts<default-grad-layouts>`
        for details on the memory layout of accumulated gradients.

        .. note::
            Using this method with ``create_graph=True`` will create a reference cycle
            between the parameter and its gradient which can cause a memory leak.
            We recommend using ``autograd.grad`` when creating the graph to avoid this.
            If you have to use this function, make sure to reset the ``.grad`` fields of your
            parameters to ``None`` after use to break the cycle and avoid the leak.

        Arguments:
            tensors (sequence of Tensor): Tensors of which the derivative will be
                computed.
            grad_tensors (sequence of (Tensor or None)): The "vector" in the Jacobian-vector
                product, usually gradients w.r.t. each element of corresponding tensors.
                None values can be specified for scalar Tensors or ones that don't require
                grad. If a None value would be acceptable for all grad_tensors, then this
                argument is optional.
            retain_graph (bool, optional): If ``False``, the graph used to compute the grad
                will be freed. Note that in nearly all cases setting this option to ``True``
                is not needed and often can be worked around in a much more efficient
                way. Defaults to the value of ``create_graph``.
            create_graph (bool, optional): If ``True``, graph of the derivative will
                be constructed, allowing to compute higher order derivative products.
                Defaults to ``False``.
        """
        if grad_variables is not None:
            warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
            if grad_tensors is None:
                grad_tensors = grad_variables
            else:
                raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
                                   "arguments both passed to backward(). Please only "
                                   "use 'grad_tensors'.")

        tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)

        if grad_tensors is None:
            grad_tensors = [None] * len(tensors)
        elif isinstance(grad_tensors, torch.Tensor):
            grad_tensors = [grad_tensors]
        else:
            grad_tensors = list(grad_tensors)

        grad_tensors = _make_grads(tensors, grad_tensors)
        if retain_graph is None:
            retain_graph = create_graph

        Variable._execution_engine.run_backward(
            tensors, grad_tensors, retain_graph, create_graph,
>           allow_unreachable=True)  # allow_unreachable flag
E       RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases yet.
E       Exception raised from mark_variable_ready at /pytorch/torch/csrc/distributed/c10d/reducer.cpp:453 (most recent call first):
E       frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f2f744b01e2 in /opt/conda/lib/python3.7/site-packages/torch/lib/libc10.so)
E       frame #1: c10d::Reducer::mark_variable_ready(c10d::Reducer::VariableIndex) + 0x4b7 (0x7f2fc05849b7 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
E       frame #2: c10d::Reducer::autograd_hook(c10d::Reducer::VariableIndex) + 0xef (0x7f2fc05870df in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
E       frame #3: <unknown function> + 0xa91246 (0x7f2fc0587246 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
E       frame #4: <unknown function> + 0xa95d16 (0x7f2fc058bd16 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
E       frame #5: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x4dd (0x7f2fb28894dd in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
E       frame #6: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x451 (0x7f2fb288afa1 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
E       frame #7: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x89 (0x7f2fb2883119 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
E       frame #8: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x4a (0x7f2fc0023dea in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
E       frame #9: <unknown function> + 0xbd6df (0x7f2fc0f686df in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
E       frame #10: <unknown function> + 0x76db (0x7f30060bf6db in /lib/x86_64-linux-gnu/libpthread.so.0)
E       frame #11: clone + 0x3f (0x7f3005de8a3f in /lib/x86_64-linux-gnu/libc.so.6)

/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py:127: RuntimeError
------------------------------------------------------------------------------------- Captured log call -------------------------------------------------------------------------------------
INFO     lightning:distributed.py:49 GPU available: True, used: True
INFO     lightning:distributed.py:49 TPU available: False, using: 0 TPU cores
INFO     lightning:accelerator_connector.py:402 LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO     lightning:precision_connector.py:57 Using native 16bit precision.
WARNING  lightning:tensorboard.py:241 Missing logger folder: /tmp/pytest-of-jovyan/pytest-4/test_automatic_optimization_fa0/lightning_logs
INFO     lightning:accelerator.py:206 initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/2
INFO     lightning:lightning.py:1320
  | Name  | Type   | Params
---------------------------------
0 | layer | Linear | 66
===================================================================================== warnings summary ======================================================================================
tests/trainer/optimization/test_manual_optimization.py::test_automatic_optimization_false
  /home/jovyan/pytorch-lightning/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
    warnings.warn(*args, **kwargs)

tests/trainer/optimization/test_manual_optimization.py::test_automatic_optimization_false
  /home/jovyan/pytorch-lightning/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
    warnings.warn(*args, **kwargs)

-- Docs: https://docs.pytest.org/en/stable/warnings.html
================================================================================== short test summary info ==================================================================================
FAILED tests/trainer/optimization/test_manual_optimization.py::test_automatic_optimization_false - RuntimeError: Expected to mark a variable ready only once. This error is caused by one ...
=============================================================================== 1 failed, 2 warnings in 5.09s ===============================================================================

        Variable._execution_engine.run_backward(
            tensors, grad_tensors, retain_graph, create_graph,
>           allow_unreachable=True)  # allow_unreachable flag
E       RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases yet.
E       Exception raised from mark_variable_ready at /pytorch/torch/csrc/distributed/c10d/reducer.cpp:453 (most recent call first):

/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py:127: RuntimeError
------------------------------------------------------------------------------------- Captured log call -------------------------------------------------------------------------------------
INFO     lightning:accelerator_connector.py:402 LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
INFO     lightning:precision_connector.py:57 Using native 16bit precision.
INFO     lightning:accelerator.py:206 initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/2
================================================================================== short test summary info ==================================================================================
FAILED tests/trainer/optimization/test_manual_optimization.py::test_automatic_optimization_false - RuntimeError: Expected to mark a variable ready only once. This error is caused by one ...
===================================================================================== 1 failed in 3.92s =====================================================================================
Exception ignored in: <function tqdm.__del__ at 0x7f2f73634290>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/tqdm/std.py", line 1131, in __del__
  File "/opt/conda/lib/python3.7/site-packages/tqdm/std.py", line 1344, in close
  File "/opt/conda/lib/python3.7/site-packages/tqdm/std.py", line 1523, in display
  File "/opt/conda/lib/python3.7/site-packages/tqdm/std.py", line 1134, in __repr__
  File "/opt/conda/lib/python3.7/site-packages/tqdm/std.py", line 1484, in format_dict

@@ -33,6 +33,7 @@ timit_data/
.Python
ide_layouts/
build/
_build/
Member:

How was this one created? We can keep it, just curious which process created it :]

@tchaton (Contributor, Author):

I added this one so that, when moving the docs from build to _build and serving them with sphinx-serve, the generated files don't get added to git.

CHANGELOG.md (review thread resolved)
pytorch_lightning/core/lightning.py (outdated, review thread resolved)
Borda and others added 2 commits November 10, 2020 19:42
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
@SeanNaren SeanNaren merged commit 7e08b0d into master Nov 10, 2020
@SeanNaren SeanNaren deleted the bugfix/4444_ddp_and_automatic_optimization=False branch November 10, 2020 19:44
@SeanNaren (Contributor): Thanks @tchaton, this was a real refactor/fix!

self.trainer.train_loop.backward(loss, optimizer, -1, *args, **kwargs)
self._running_manual_backward = False

def manual_optimizer_step(self, optimizer: Optimizer, force_optimizer_step:bool = False) -> None:
@williamFalcon (Contributor):

Can we make sure to forward arguments when we do stuff like this?

i.e. right now the user can't use the args of .step()...

so:

def manual_optimizer_step(self, *args, optimizer: Optimizer, force_optimizer_step: bool = False, **kwargs):
    ...  # eventually forwards to:
    optimizer.step(*args, **kwargs)

@tchaton (Contributor, Author):

Thanks @williamFalcon. I will add a PR for it tomorrow.
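
To illustrate the intent of the suggestion, a hypothetical training_step once forwarding is in place might look like the sketch below (not the API merged in this PR); extra positional or keyword arguments given to manual_optimizer_step would reach optimizer.step(), which matters for optimizers such as torch.optim.LBFGS whose .step() takes a closure.

```python
# sketch following the signature proposed above (optimizer becomes keyword-only)
def training_step(self, batch, batch_idx):
    opt = self.trainer.optimizers[0]  # e.g. a torch.optim.LBFGS

    def closure():
        loss = self.layer(batch).sum()
        self.manual_backward(loss, opt)
        return loss

    # the closure would be forwarded through to optimizer.step(closure)
    self.manual_optimizer_step(closure, optimizer=opt)
```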

SeanNaren pushed a commit that referenced this pull request Nov 10, 2020
* resolve bug

* add self._running_manual_optim

* update

* update tests

* update lightning module

* resolve bug

* update tests

* update

* resolve pep8

* update

* replace by `ddp_spawn`

* temporary fix

* update

* update

* move update to training_loop

* make both ddp_spawn

* introduce `manual_optimizer_step`

* update changelog

* added changelog wrong place

* add force_optimizer_step

* update docstring for tests

* update optimizer_step

* update zero_grad

* resolve flake8

* move update into manual_optimizer_step

* add zero_grad

* remove zero_grad tests

* remove manual_backward in AMP, it doesn't help

* update

* loosen tests

* update

* update doc

* add TODO

* Removed unnecessary get model from native amp

* Remove try except with pytest raise

* Add seed, clean up imports, remove try catch to reproduce error

* update code

* update test

* revert back

* formatting

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

(cherry picked from commit 7e08b0d)
SeanNaren pushed a commit that referenced this pull request Nov 10, 2020
SeanNaren pushed a commit that referenced this pull request Nov 10, 2020
SeanNaren pushed a commit that referenced this pull request Nov 11, 2020
Borda pushed a commit that referenced this pull request Nov 11, 2020
rohitgr7 pushed a commit that referenced this pull request Nov 21, 2020
Labels: bug (Something isn't working), distributed (Generic distributed-related topic), priority: 1 (Medium priority task)
Projects: None yet
Development: Successfully merging this pull request may close these issues: Error with DDP and automatic_optimization=False
7 participants