|
8 | 8 |
|
9 | 9 |
|
10 | 10 | def test_seed_behavior(): |
11 | | - # Test with seed=None |
12 | | - Platform.seed_everything(None) |
| 11 | + # Test with a specific seed |
| 12 | + Platform.seed_everything(42) |
13 | 13 | random_value_1 = random.randint(0, 100) |
14 | 14 | np_random_value_1 = np.random.randint(0, 100) |
15 | 15 | torch_random_value_1 = torch.randint(0, 100, (1, )).item() |
16 | 16 |
|
17 | | - Platform.seed_everything(None) |
| 17 | + Platform.seed_everything(42) |
18 | 18 | random_value_2 = random.randint(0, 100) |
19 | 19 | np_random_value_2 = np.random.randint(0, 100) |
20 | 20 | torch_random_value_2 = torch.randint(0, 100, (1, )).item() |
21 | 21 |
|
22 | | - assert random_value_1 != random_value_2 |
23 | | - assert np_random_value_1 != np_random_value_2 |
24 | | - assert torch_random_value_1 != torch_random_value_2 |
25 | | - |
26 | | - # Test with a specific seed |
27 | | - Platform.seed_everything(42) |
28 | | - random_value_3 = random.randint(0, 100) |
29 | | - np_random_value_3 = np.random.randint(0, 100) |
30 | | - torch_random_value_3 = torch.randint(0, 100, (1, )).item() |
31 | | - |
32 | | - Platform.seed_everything(42) |
33 | | - random_value_4 = random.randint(0, 100) |
34 | | - np_random_value_4 = np.random.randint(0, 100) |
35 | | - torch_random_value_4 = torch.randint(0, 100, (1, )).item() |
36 | | - |
37 | | - assert random_value_3 == random_value_4 |
38 | | - assert np_random_value_3 == np_random_value_4 |
39 | | - assert torch_random_value_3 == torch_random_value_4 |
| 22 | + assert random_value_1 == random_value_2 |
| 23 | + assert np_random_value_1 == np_random_value_2 |
| 24 | + assert torch_random_value_1 == torch_random_value_2 |
0 commit comments