from typing import Callable, List, Mapping, Optional, Tuple, Union

import torch
-import torch.distributed as dist

from ignite.distributed.comp_models import (
    _SerialModel,
_need_to_sync = True


-def _sync_model_wrapper(func):
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        if isinstance(_model, _SerialModel) and _need_to_sync:
-            sync()
-        return func(*args, **kwargs)
-
-    return wrapper
-
-
-def sync():
+def sync(temporary=False):
    """Helper method to force this module to synchronize with current distributed context.
    This method should be used when distributed context is manually created or destroyed.
+
+    Args:
+        temporary (bool): If True, distributed model synchronization is done on every call of
+            ``idist.get_*`` methods. This may have a negative performance impact.
    """
    global _model

@@ -67,13 +60,12 @@ def sync():
            continue
        model = comp_model_cls.create_from_context()
        if model is not None:
-            _model = model
+            _set_model(model, temporary=temporary)
            return

    _model = _SerialModel()


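Editorial note: with `_sync_model_wrapper` gone, the lazy guard moves into each getter, while `sync()` stays the explicit hook for manually managed contexts. A minimal usage sketch, assuming a native `gloo` group launched with the usual `env://` variables (`MASTER_ADDR`, `MASTER_PORT`, `RANK`, `WORLD_SIZE`):

```python
import torch.distributed as dist
import ignite.distributed as idist

# Manually create a native context, then make idist adopt it permanently.
dist.init_process_group(backend="gloo", init_method="env://")
idist.sync()  # non-temporary: _need_to_sync flips to False, later idist calls are cheap
print(idist.backend(), idist.get_rank(), idist.get_world_size())

# After tearing the context down, sync() falls back to the serial model.
dist.destroy_process_group()
idist.sync()
```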
-@_sync_model_wrapper
def device() -> torch.device:
    """Returns current device according to current distributed configuration.

@@ -84,10 +76,12 @@ def device() -> torch.device:
    Returns:
        torch.device
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.device()


-@_sync_model_wrapper
def backend() -> Optional[str]:
    """Returns computation model's backend.

@@ -98,6 +92,9 @@ def backend() -> Optional[str]:
    Returns:
        str or None
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.backend()


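For contrast, the serial path needs no setup at all; a sketch of what the new inline guard does on first use (no process group created):

```python
import ignite.distributed as idist

# No distributed context exists, so the first call runs sync(temporary=True),
# falls back to _SerialModel and leaves _need_to_sync == True; a process group
# created later is still picked up by the next idist call.
print(idist.backend())  # None in a serial run
print(idist.device())   # torch.device("cpu"), or a CUDA device if available
```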
@@ -110,7 +107,6 @@ def available_backends() -> Tuple[str]:
110107 return out
111108
112109
113- @_sync_model_wrapper
114110def model_name () -> str :
115111 """Returns distributed configuration name (given by ignite)
116112
@@ -119,51 +115,66 @@ def model_name() -> str:
    - `xla-dist` for XLA distributed configuration

    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.name


-@_sync_model_wrapper
def get_world_size() -> int:
    """Returns world size of current distributed configuration. Returns 1 if no distributed configuration.
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.get_world_size()


-@_sync_model_wrapper
def get_rank() -> int:
    """Returns process rank within current distributed configuration. Returns 0 if no distributed configuration.
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.get_rank()


-@_sync_model_wrapper
def get_local_rank() -> int:
    """Returns local process rank within current distributed configuration. Returns 0 if no distributed configuration.
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.get_local_rank()


-@_sync_model_wrapper
def get_nproc_per_node() -> int:
    """Returns number of processes (or tasks) per node within current distributed configuration.
    Returns 1 if no distributed configuration.
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.get_nproc_per_node()


-@_sync_model_wrapper
def get_nnodes() -> int:
    """Returns number of nodes within current distributed configuration.
    Returns 1 if no distributed configuration.
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.get_nnodes()


-@_sync_model_wrapper
def get_node_rank() -> int:
    """Returns node rank within current distributed configuration.
    Returns 0 if no distributed configuration.
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.get_node_rank()


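The topology getters above all share the same guard. How their values relate, sketched under the standard contiguous rank layout (the layout is an assumption about the launcher, not something ignite enforces):

```python
import ignite.distributed as idist

world_size = idist.get_world_size()   # 1 without a distributed context
rank = idist.get_rank()               # global rank
local_rank = idist.get_local_rank()   # rank within the current node
nproc = idist.get_nproc_per_node()
nnodes = idist.get_nnodes()
node_rank = idist.get_node_rank()

assert world_size == nnodes * nproc
# Contiguous layout assumption: node 0 holds ranks 0..nproc-1, node 1 the next block, etc.
assert rank == node_rank * nproc + local_rank
```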
@@ -291,7 +302,6 @@ def train_fn(local_rank, a, b, c, d=12):
    )


-@_sync_model_wrapper
def all_reduce(tensor: Union[torch.Tensor, Number], op: str = "SUM") -> Union[torch.Tensor, Number]:
    """Helper method to perform all reduce operation.

@@ -303,10 +313,12 @@ def all_reduce(tensor: Union[torch.Tensor, Number], op: str = "SUM") -> Union[to
        torch.Tensor or number

    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.all_reduce(tensor, op)


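A short usage sketch for the reworked `all_reduce` (the loss value is hypothetical): averaging a per-process scalar, which degrades to a no-op in a serial run:

```python
import torch
import ignite.distributed as idist

local_loss = torch.tensor(0.5)  # hypothetical per-process value
total = idist.all_reduce(local_loss, op="SUM")
mean_loss = total / idist.get_world_size()
```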
-@_sync_model_wrapper
def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]:
    """Helper method to perform all gather operation.

@@ -318,13 +330,18 @@ def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor,
        List of strings

    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.all_gather(tensor)


-@_sync_model_wrapper
def barrier():
    """Helper method to synchronize all processes.
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    _model.barrier()


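The same lazy guard now fronts the collective helpers; a combined sketch of `all_gather` and `barrier`:

```python
import ignite.distributed as idist

# Gather a status string from every process (a list of strings comes back),
# then hold everyone at the barrier; both are no-ops in a serial run.
statuses = idist.all_gather("rank {} done".format(idist.get_rank()))
idist.barrier()
```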
@@ -356,11 +373,11 @@ def run(local_rank, *args, **kwargs):
    ComputationModel._ext_local_rank = index


-def _set_model(model):
+def _set_model(model, temporary=False):
    global _model, _need_to_sync
    _model = model
    _need_to_sync = True
-    if not isinstance(_model, _SerialModel):
+    if not isinstance(_model, _SerialModel) and not temporary:
        _need_to_sync = False


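Summarizing the state machine that `temporary` introduces (a reading of the diff, not code taken from it):

```python
# _set_model(_SerialModel())                -> _need_to_sync stays True
# _set_model(native_model)                  -> _need_to_sync = False (sticky)
# _set_model(native_model, temporary=True)  -> _need_to_sync stays True, so every
#     idist.get_* call keeps re-checking the context until a non-temporary
#     sync() installs the model for good.
```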
@@ -408,7 +425,7 @@ def train_fn(local_rank, a, b, c):


    """
-    if not (has_xla_support or dist.is_available()):
+    if not (has_xla_support or has_native_dist_support):
        # nothing to do => serial model
        # maybe warn about this
        return
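The guard now leans on ignite's own capability flag instead of importing `torch.distributed` in this module. For reference, a plausible definition of that flag (it lives in `ignite.distributed.comp_models`; this sketch is an assumption, not part of the diff):

```python
import torch.distributed as dist

# Likely equivalent to the old inline check, just centralized:
has_native_dist_support = dist.is_available()
```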