@@ -127,8 +127,8 @@ def auto_model(model: nn.Module) -> nn.Module:
 
     Internally, we perform the following:
 
-    - send model to current :meth:`~ignite.distributed.utils.device()`.
-    - wrap the model with `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1
+    - send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device.
+    - wrap the model with `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1.
     - wrap the model with `torch DataParallel`_ if no distributed context is found and more than one CUDA device is available.
 
     Examples:
@@ -139,6 +139,15 @@ def auto_model(model: nn.Module) -> nn.Module:
 
         model = idist.auto_model(model)
 
+    In addition, with NVidia/Apex it can be used in the following way:
+
+    .. code-block:: python
+
+        import ignite.distributed as idist
+
+        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
+        model = idist.auto_model(model)
+
     Args:
         model (torch.nn.Module): model to adapt.
 
@@ -150,7 +159,10 @@ def auto_model(model: nn.Module) -> nn.Module:
     """
     logger = setup_logger(__name__ + ".auto_model")
 
-    model.to(idist.device())
+    # Move the model's parameters to the device if they are not already there
+    device = idist.device()
+    if not all([p.device == device for p in model.parameters()]):
+        model.to(device)
 
     # distributed data parallel model
     if idist.get_world_size() > 1:
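
Below is a minimal sketch of the selection logic that this change implements, assuming only public APIs (`idist.device()`, `idist.get_world_size()`, `torch.cuda.device_count()`, `torch.nn.parallel.DistributedDataParallel`, `torch.nn.DataParallel`). The helper name `auto_model_sketch` is hypothetical; the real `auto_model` additionally sets up logging and backend-specific arguments such as device ids.

.. code-block:: python

    import torch
    import torch.nn as nn
    import ignite.distributed as idist


    def auto_model_sketch(model: nn.Module) -> nn.Module:
        # Illustrative sketch only, not ignite's actual implementation.
        device = idist.device()

        # Move the parameters only when at least one of them is not already
        # on the target device, mirroring the check introduced above.
        if not all(p.device == device for p in model.parameters()):
            model.to(device)

        if idist.get_world_size() > 1:
            # Distributed context with more than one process: wrap with DDP.
            # (On CUDA, DistributedDataParallel typically also needs device_ids
            # set to the local rank's device.)
            model = nn.parallel.DistributedDataParallel(model)
        elif torch.cuda.device_count() > 1:
            # No distributed context but several visible GPUs: use DataParallel.
            model = nn.DataParallel(model)

        return model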