|
25 | 25 | from torch.testing import make_tensor |
26 | 26 | from torch.testing._comparison import TensorLikePair |
27 | 27 | from torch.testing._internal.common_dtype import get_all_dtypes, integral_types |
| 28 | +import torch.mps |
28 | 29 | import torch.backends.mps |
29 | 30 | from torch.distributions import Uniform, Exponential |
30 | 31 | from functools import partial |
@@ -5741,6 +5742,45 @@ def test_mps_generator(self): |
5741 | 5742 | mps_x = torch.randn(5, device='mps', generator=g_mps) |
5742 | 5743 | self.assertEqual(mps_x, mps_y) |
5743 | 5744 |
|
| 5745 | + def test_default_mps_generator(self): |
| 5746 | + # manual seeding on the "default" MPS generator using |
| 5747 | + # the global torch.manual_seed() |
| 5748 | + torch.manual_seed(230) |
| 5749 | + mps_x = torch.randn(5, device='mps') |
| 5750 | + # manual seeding using torch.mps.manual_seed() |
| 5751 | + # which should set the "default" MPS generator |
| 5752 | + # like the global torch.manual_seed() |
| 5753 | + torch.mps.manual_seed(230) |
| 5754 | + mps_y = torch.randn(5, device='mps') |
| 5755 | + # seed values were the same, so the random tensor contents should match |
| 5756 | + self.assertEqual(mps_x, mps_y) |
| 5757 | + |
| 5758 | + # save the default generator's state to restore it later |
| 5759 | + g_state = torch.mps.get_rng_state() |
| 5760 | + |
| 5761 | + # generate random numbers without seeding |
| 5762 | + mps_x = torch.randn(5, device='mps') |
| 5763 | + # in this case, the random results must differ from the last generated random results |
| 5764 | + self.assertNotEqual(mps_x, mps_y) |
| 5765 | + |
| 5766 | + # restore the previously saved state, and the results should match again |
| 5767 | + torch.mps.set_rng_state(g_state) |
| 5768 | + mps_x = torch.randn(5, device='mps') |
| 5769 | + self.assertEqual(mps_x, mps_y) |
| 5770 | + |
| 5771 | + def test_device_synchronize(self): |
| 5772 | + # just running some ops each followed by a synchronize to wait for |
| 5773 | + # MPS stream to finish running each of them |
| 5774 | + net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\ |
| 5775 | + .to(device='mps', dtype=torch.float) |
| 5776 | + |
| 5777 | + x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True) |
| 5778 | + torch.mps.synchronize() |
| 5779 | + x = net1(x) |
| 5780 | + torch.mps.synchronize() |
| 5781 | + x.backward(torch.randn_like(x)) |
| 5782 | + torch.mps.synchronize() |
| 5783 | + |
5744 | 5784 | # Test random_.to and random_.from_int |
5745 | 5785 | def test_random(self): |
5746 | 5786 | def helper(shape, low, high, dtype=torch.int32): |
|
0 commit comments