 # pyre-strict
 
 import unittest
-from typing import Dict, List, Union
 from unittest.mock import MagicMock, patch
 
 import torch
 from torch.autograd import Variable
-from torch.distributed import ProcessGroup
-from torch.distributed.tensor import distribute_tensor, DTensor, init_device_mesh, Shard
-from torch.testing._internal.common_utils import (
-    instantiate_parametrized_tests,
-    parametrize,
-)
-from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
-    with_comms,
-)
 from torchrec.optim.clipping import GradientClipping, GradientClippingOptimizer
 from torchrec.optim.test_utils import DummyKeyedOptimizer
 
@@ -240,202 +229,3 @@ def test_clip_no_gradients_norm_meta_device(
         gradient_clipping_optimizer.step()
 
         mock_clip_grad_norm.assert_not_called()
-
-
-@unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
-@instantiate_parametrized_tests
-class TestGradientClippingDTensor(DTensorTestBase):
-    """No tests for Replicated DTensors as handled prior to GradientClippingOptimizer"""
-
-    def _get_params_to_pg(
-        self, params: List[DTensor]
-    ) -> Dict[DTensor, List[ProcessGroup]]:
-        return {param: [param.device_mesh.get_group()] for param in params}
-
-    @with_comms
-    @parametrize("norm_type", ("inf", 1, 2))
-    def test_tensor_and_sharded_dtensor_clip_all_gradients_norm(
-        self, norm_type: Union[float, str]
-    ) -> None:
-        """
-        Test to ensure that the gradient clipping optimizer clips gradients
-        correctly with mixed sharded DTensor and tensor by comparing gradients to its
-        torch.tensor counterpart.
-
-        Note that clipping for DTensor may require communication.
-        """
-
-        # data for testing clipping
-        data_1 = torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
-        data_2 = torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
-        data_1_grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
-        data_2_grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
-
-        # create gradient clipping optimizer containing no dtensor for reference
-        ref_param_1 = torch.nn.Parameter(data_1.clone())
-        ref_param_2 = torch.nn.Parameter(data_2.clone())
-        ref_param_1.grad = data_1_grad.clone()
-        ref_param_2.grad = data_2_grad.clone()
-        ref_keyed_optimizer = DummyKeyedOptimizer(
-            params={"param_1": ref_param_1, "param_2": ref_param_2},
-            state={},
-            param_groups=[{"params": [ref_param_1, ref_param_2]}],
-        )
-        ref_gradient_clipping_optimizer = GradientClippingOptimizer(
-            optimizer=ref_keyed_optimizer,
-            clipping=GradientClipping.NORM,
-            max_gradient=10.0,
-            norm_type=norm_type,
-        )
-        ref_gradient_clipping_optimizer.step()
-
-        # create gradient clipping optimizer containing a DTensor and a tensor
-        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
-        param_1 = distribute_tensor(
-            tensor=torch.tensor(
-                data_1.clone(), requires_grad=True, device=self.device_type
-            ),
-            device_mesh=device_mesh,
-            placements=[Shard(0)],
-        )
-        param_2 = torch.tensor(
-            data_2.clone(), requires_grad=True, device=self.device_type
-        )
-        param_1.grad = distribute_tensor(
-            tensor=data_1_grad.clone(),
-            device_mesh=device_mesh,
-            placements=[Shard(0)],
-        )
-        param_2.grad = data_2_grad.clone()
-        param_to_pgs = self._get_params_to_pg([param_1])
-        keyed_optimizer = DummyKeyedOptimizer(
-            params={"dtensor_param_1": param_1, "dtensor_param_2": param_2},
-            state={},
-            param_groups=[{"params": [param_1, param_2]}],
-        )
-        gradient_clipping_optimizer = GradientClippingOptimizer(
-            optimizer=keyed_optimizer,
-            clipping=GradientClipping.NORM,
-            max_gradient=10.0,
-            norm_type=norm_type,
-            enable_global_grad_clip=True,
-            param_to_pgs=param_to_pgs,  # pyre-ignore[6]
-        )
-        gradient_clipping_optimizer.step()
-
-        for param_group, ref_param_group in zip(
-            gradient_clipping_optimizer.param_groups,
-            ref_gradient_clipping_optimizer.param_groups,
-            strict=True,
-        ):
-            for param, ref_param in zip(
-                param_group["params"], ref_param_group["params"], strict=True
-            ):
-                param_grad = (
-                    param.grad.full_tensor()  # pyre-ignore[16]
-                    if isinstance(param, DTensor)
-                    else param.grad
-                )
-                self.assertEqual(
-                    param_grad,
-                    ref_param.grad,
-                    f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
-                )
-
-    @with_comms
-    @parametrize("norm_type", ("inf", 1, 2))
-    def test_multiple_sharded_dtensors_clip_all_gradients_norm(
-        self, norm_type: Union[float, str]
-    ) -> None:
-        """
-        Test to ensure that the gradient clipping optimizer clips gradients
-        correctly with multiple sharded DTensors by comparing gradients to their
-        torch.tensor counterpart.
-
-        Note that clipping for DTensor may require communication.
-        """
-
-        # data for testing clipping
-        data_1 = torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
-        data_2 = torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
-        data_1_grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
-        data_2_grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
-
-        # create gradient clipping optimizer containing no dtensor for reference
-        ref_param_1 = torch.nn.Parameter(data_1.clone())
-        ref_param_2 = torch.nn.Parameter(data_2.clone())
-        ref_param_1.grad = data_1_grad.clone()
-        ref_param_2.grad = data_2_grad.clone()
-        ref_keyed_optimizer = DummyKeyedOptimizer(
-            params={"param_1": ref_param_1, "param_2": ref_param_2},
-            state={},
-            param_groups=[{"params": [ref_param_1, ref_param_2]}],
-        )
-        ref_gradient_clipping_optimizer = GradientClippingOptimizer(
-            optimizer=ref_keyed_optimizer,
-            clipping=GradientClipping.NORM,
-            max_gradient=10.0,
-            norm_type=norm_type,
-        )
-        ref_gradient_clipping_optimizer.step()
-
-        # create gradient clipping optimizer containing 2 DTensors
-        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
-        param_1 = distribute_tensor(
-            tensor=torch.tensor(
-                data_1.clone(), requires_grad=True, device=self.device_type
-            ),
-            device_mesh=device_mesh,
-            placements=[Shard(0)],
-        )
-        param_2 = distribute_tensor(
-            tensor=torch.tensor(
-                data_2.clone(), requires_grad=True, device=self.device_type
-            ),
-            device_mesh=device_mesh,
-            placements=[Shard(0)],
-        )
-        param_1.grad = distribute_tensor(
-            tensor=data_1_grad.clone(),
-            device_mesh=device_mesh,
-            placements=[Shard(0)],
-        )
-        param_2.grad = distribute_tensor(
-            tensor=data_2_grad.clone(),
-            device_mesh=device_mesh,
-            placements=[Shard(0)],
-        )
-        param_to_pgs = self._get_params_to_pg([param_1, param_2])
-        keyed_optimizer = DummyKeyedOptimizer(
-            params={"dtensor_param_1": param_1, "dtensor_param_2": param_2},
-            state={},
-            param_groups=[{"params": [param_1, param_2]}],
-        )
-        gradient_clipping_optimizer = GradientClippingOptimizer(
-            optimizer=keyed_optimizer,
-            clipping=GradientClipping.NORM,
-            max_gradient=10.0,
-            norm_type=norm_type,
-            enable_global_grad_clip=True,
-            param_to_pgs=param_to_pgs,  # pyre-ignore[6]
-        )
-        gradient_clipping_optimizer.step()
-
-        for param_group, ref_param_group in zip(
-            gradient_clipping_optimizer.param_groups,
-            ref_gradient_clipping_optimizer.param_groups,
-            strict=True,
-        ):
-            for param, ref_param in zip(
-                param_group["params"], ref_param_group["params"], strict=True
-            ):
-                param_grad = (
-                    param.grad.full_tensor()  # pyre-ignore[16]
-                    if isinstance(param, DTensor)
-                    else param.grad
-                )
-                self.assertEqual(
-                    param_grad,
-                    ref_param.grad,
-                    f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
-                )
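For context on what the removed tests verified: both compared clipped DTensor gradients against a plain torch.tensor reference with max_gradient=10.0. The sketch below is a hypothetical standalone script, not part of this diff, that reproduces that reference computation on a single process, assuming GradientClippingOptimizer's NORM clipping follows torch.nn.utils.clip_grad_norm_ semantics.

import torch

# Gradients used by the removed tests, gathered onto one device.
grads = [torch.tensor([12.0, 15.0, 18.0]), torch.tensor([20.0, 30.0, 15.0])]
max_gradient = 10.0  # same max_gradient as in the tests

for norm_type in (float("inf"), 1.0, 2.0):  # mirrors @parametrize("norm_type", ...)
    # Total norm over all parameters: the p-norm of the per-parameter p-norms.
    per_param = torch.stack([torch.linalg.vector_norm(g, norm_type) for g in grads])
    total_norm = torch.linalg.vector_norm(per_param, norm_type)
    # Scale all gradients uniformly when the total norm exceeds max_gradient.
    clip_coef = torch.clamp(max_gradient / (total_norm + 1e-6), max=1.0)
    clipped = [g * clip_coef for g in grads]
    print(norm_type, total_norm.item(), [c.tolist() for c in clipped])

With sharded DTensors the per-parameter norms live on different ranks, so the optimizer has to combine them across each parameter's process group before computing the scale factor, which is why the removed docstrings note that clipping for DTensor may require communication.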