@@ -87,23 +87,22 @@ def spawn(*args: Any, **kwargs: Any) -> None:
     _collective_op_dtype = None  # type: Any
 
     @staticmethod
-    def _encode_str(x: str, device: torch.device) -> torch.Tensor:
-        # use fix padded size
-        size = 1024
-        if len(x) > size:
-            warnings.warn(f"Input string size {len(x)} is larger than {size} and thus will be truncated")
-            x = x[:size]
-
+    def _encode_str(x: str, device: torch.device, size: int) -> torch.Tensor:
         name = torch.tensor(bytearray(x, "utf-8")).to(device)
         padded_x = torch.zeros(size + 1, device=device, dtype=torch.long)
         padded_x[: len(name)] = name
         padded_x[-1] = len(name)
-        # output is tensor of shape (1, 1025)
+        # output is tensor of shape (1, size + 1)
         return padded_x.unsqueeze(0)
 
+    def _get_max_length(self, x: str, device: torch.device) -> int:
+        size = torch.tensor([len(x),], device=device)
+        size = self._do_all_reduce(size, "MAX")
+        return cast(int, size.item())
+
     @staticmethod
     def _decode_str(xs: torch.Tensor) -> List[str]:
-        # xs.shape = (n, 1025), e.g. (world_size, 1025)
+        # xs.shape = (n, size + 1), e.g. (world_size, size + 1)
         out = [bytearray(x[: x[-1]].tolist()).decode("utf-8") for x in xs]
         return out
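For reference, here is a minimal standalone round-trip of the new dynamic-size encoding (a sketch with `_encode_str`/`_decode_str` rewritten as free functions; the names `encode_str`/`decode_str` are ours, not part of the module):

```python
# Sketch: round-trip of the dynamic-size string encoding (illustrative
# free-function versions of _encode_str / _decode_str).
from typing import List

import torch

def encode_str(x: str, device: torch.device, size: int) -> torch.Tensor:
    name = torch.tensor(bytearray(x, "utf-8")).to(device)
    padded_x = torch.zeros(size + 1, device=device, dtype=torch.long)
    padded_x[: len(name)] = name
    padded_x[-1] = len(name)  # true byte length kept in the last slot
    return padded_x.unsqueeze(0)

def decode_str(xs: torch.Tensor) -> List[str]:
    # each row: `size` padded bytes plus the true length in the last slot
    return [bytearray(x[: x[-1]].tolist()).decode("utf-8") for x in xs]

device = torch.device("cpu")
rows = torch.cat([encode_str(s, device, size=7) for s in ("abc", "long999")])
print(rows.shape)        # torch.Size([2, 8])
print(decode_str(rows))  # ['abc', 'long999']
```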
@@ -144,7 +143,8 @@ def _collective_op(
             tensor = torch.tensor(tensor, device=device, dtype=self._collective_op_dtype)
         elif isinstance(tensor, str):
             tensor_to_str = True
-            tensor = self._encode_str(tensor, device)
+            max_length = self._get_max_length(tensor, device)
+            tensor = self._encode_str(tensor, device, size=max_length)
 
         tensor = self._apply_op(tensor, device, fn, *args, **kwargs)
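Note that `_get_max_length` is itself a collective call (a MAX all-reduce), so every participating rank must reach this branch; afterwards all ranks share one padded width and the gathered rows stack cleanly. A single-process simulation of that agreement, reusing `encode_str` from the sketch above (values are hypothetical):

```python
# Sketch: simulating the MAX all-reduce agreement across three "ranks"
# in one process.
strings_per_rank = ["hi", "a longer string", "mid"]

# each rank contributes len(its string); the MAX all-reduce yields the
# common padded size that all ranks must use before encoding
max_length = max(len(s) for s in strings_per_rank)  # 15

encoded = [encode_str(s, torch.device("cpu"), size=max_length) for s in strings_per_rank]
gathered = torch.cat(encoded)
print(gathered.shape)  # torch.Size([3, 16]), i.e. (world_size, max_length + 1)
```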
@@ -176,20 +176,20 @@ def broadcast(self, tensor: Union[torch.Tensor, float, str], src: int = 0) -> Un
         rank = self.get_rank()
         device = self.device()
         tensor_to_number = tensor_to_str = False
-        if rank != src:
-            if isinstance(tensor, Number):
-                tensor_to_number = True
-                tensor = torch.empty(1, device=self.device(), dtype=torch.float)
-            elif isinstance(tensor, str):
-                tensor_to_str = True
-                tensor = torch.empty(1, 1025, device=self.device(), dtype=torch.long)
-        else:
-            if isinstance(tensor, Number):
-                tensor_to_number = True
+
+        if isinstance(tensor, Number):
+            tensor_to_number = True
+            if rank != src:
+                tensor = torch.empty(1, device=device, dtype=torch.float)
+            else:
                 tensor = torch.tensor([tensor,], device=device, dtype=torch.float)
-            elif isinstance(tensor, str):
-                tensor_to_str = True
-                tensor = self._encode_str(tensor, device)
+        elif isinstance(tensor, str):
+            tensor_to_str = True
+            max_length = self._get_max_length(tensor, device)
+            if rank != src:
+                tensor = torch.empty(1, max_length + 1, device=device, dtype=torch.long)
+            else:
+                tensor = self._encode_str(tensor, device, size=max_length)
 
         tensor = self._apply_op(tensor, device, self._do_broadcast, src)
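The restructured `broadcast` groups the branches by payload type rather than by rank, and both the source and the receivers derive the same `max_length` from the MAX all-reduce, so the receivers' empty buffers match the sender's encoded shape exactly instead of assuming a hard-coded `(1, 1025)`. A quick shape check, again reusing `encode_str` from the first sketch (hypothetical values):

```python
# Sketch: receiver-side buffer and sender-side payload now agree on shape.
max_length = 15  # agreed by all ranks via the MAX all-reduce
recv = torch.empty(1, max_length + 1, dtype=torch.long)                       # rank != src
sent = encode_str("a longer string", torch.device("cpu"), size=max_length)    # rank == src
assert recv.shape == sent.shape == (1, 16)
```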
@@ -201,7 +201,7 @@ def broadcast(self, tensor: Union[torch.Tensor, float, str], src: int = 0) -> Un
         return tensor
 
     @abstractmethod
-    def _do_all_reduce(self, tensor: torch.Tensor, op: str = "sum") -> torch.Tensor:
+    def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
         pass
 
     @abstractmethod
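The default op string also changes from `"sum"` to `"SUM"`, matching the uppercase spelling already used by `_get_max_length` (`"MAX"`). Presumably this lets a concrete backend resolve the name directly against `torch.distributed.ReduceOp`; a sketch of such a lookup (our assumption about the backend-side mapping, not code from this module):

```python
# Sketch: resolving an uppercase op name against torch.distributed.ReduceOp
# (assumed backend mapping, for illustration only).
import torch.distributed as dist

def to_reduce_op(op: str) -> dist.ReduceOp:
    if not hasattr(dist.ReduceOp, op):
        raise ValueError(f"Unsupported reduction op: {op}")
    return getattr(dist.ReduceOp, op)  # e.g. "SUM" -> ReduceOp.SUM, "MAX" -> ReduceOp.MAX
```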
@@ -271,7 +271,7 @@ def create_from_backend(backend: Optional[str] = None, **kwargs: Any) -> "_Seria
     def spawn(*args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("Serial computation model does not implement spawn method")
 
-    def all_reduce(self, tensor: Union[torch.Tensor, float], op: str = "sum") -> Union[torch.Tensor, float]:
+    def all_reduce(self, tensor: Union[torch.Tensor, float], op: str = "SUM") -> Union[torch.Tensor, float]:
         return tensor
 
     def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
@@ -282,14 +282,14 @@ def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Ten
     def broadcast(self, tensor: Union[torch.Tensor, float, str], src: int = 0) -> Union[torch.Tensor, float, str]:
         return tensor
 
-    def _do_all_reduce(self, tensor: torch.Tensor, op: str = "sum") -> torch.Tensor:
-        pass
+    def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
+        return tensor
 
     def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
-        pass
+        return tensor
 
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
-        pass
+        return tensor
 
     def barrier(self) -> None:
         pass
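In the serial model the `_do_*` stubs change from `pass` to `return tensor` because `_get_max_length` now routes through `_do_all_reduce` even in the single-process case, where a `None` return would break `size.item()`. With a world size of one, the identity is the correct reduce/gather/broadcast:

```python
# Sketch: the serial stub must return its input; a single process is
# already "reduced". With the old `pass` (returning None), .item() fails.
import torch

def _do_all_reduce(tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
    return tensor  # identity for world_size == 1

size = _do_all_reduce(torch.tensor([5]), "MAX")
assert size.item() == 5  # old stub: AttributeError, None has no .item()
```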