Commit b4e81fe

Docs update and auto_model change (#1197)
* Fixes #1174 - Updated docs - auto_model puts params on device if they are not on the device
* Updated docs
* Update auto.py
1 parent 31c1dcc commit b4e81fe

File tree

1 file changed: +15 -3 lines changed


ignite/distributed/auto.py

Lines changed: 15 additions & 3 deletions
@@ -127,8 +127,8 @@ def auto_model(model: nn.Module) -> nn.Module:
 
     Internally, we perform the following:
 
-    - send model to current :meth:`~ignite.distributed.utils.device()`.
-    - wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1
+    - send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device.
+    - wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1.
     - wrap the model to `torch DataParallel`_ if no distributed context is found and more than one CUDA device is available.
 
     Examples:
@@ -139,6 +139,15 @@ def auto_model(model: nn.Module) -> nn.Module:
 
         model = idist.auto_model(model)
 
+    In addition, with NVidia/Apex it can be used in the following way:
+
+    .. code-block:: python
+
+        import ignite.distributed as idist
+
+        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
+        model = idist.auto_model(model)
+
     Args:
         model (torch.nn.Module): model to adapt.
@@ -150,7 +159,10 @@ def auto_model(model: nn.Module) -> nn.Module:
     """
     logger = setup_logger(__name__ + ".auto_model")
 
-    model.to(idist.device())
+    # Put the model's parameters on the device if they are not there already
+    device = idist.device()
+    if not all([p.device == device for p in model.parameters()]):
+        model.to(device)
 
     # distributed data parallel model
     if idist.get_world_size() > 1:
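The guard added in the last hunk means auto_model only calls model.to(device) when some parameter sits on a different device, so parameters that were already placed (for example by amp.initialize in the Apex snippet above) are not moved again. Below is a minimal standalone sketch of the same check, assuming only torch; the helper name move_if_needed is ours, not part of ignite:

import torch
import torch.nn as nn

def move_if_needed(model: nn.Module, device: torch.device) -> nn.Module:
    # Same guard as the commit: call .to(device) only if at least one
    # parameter lives on a different device.
    if not all(p.device == device for p in model.parameters()):
        model.to(device)
    return model

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = move_if_needed(nn.Linear(4, 2), device)  # moves parameters once
model = move_if_needed(model, device)            # second call is a no-op

Note that torch.device("cuda") and torch.device("cuda:0") do not compare equal, so an index-less device only makes the guard conservative: .to() runs again, which is harmless.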

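For the wrapping behavior the updated docstring lists (DistributedDataParallel when the world size is larger than 1, DataParallel when there is no distributed context but several CUDA devices), here is a hedged sketch of the decision rule; wrap_for_parallelism is a hypothetical name, and the actual logic lives in ignite/distributed/auto.py:

import torch
import torch.nn as nn

def wrap_for_parallelism(model: nn.Module, world_size: int) -> nn.Module:
    # Sketch of the docstring's rules only, not the ignite implementation.
    if world_size > 1:
        # Native torch distributed; assumes torch.distributed.init_process_group
        # has already been called in each process.
        model = nn.parallel.DistributedDataParallel(model)
    elif torch.cuda.device_count() > 1:
        # No distributed context, but more than one CUDA device is available.
        model = nn.parallel.DataParallel(model)
    return model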